// An updated version of this program is available at https://usacoder.blogspot.com/2018/11/updated-qlearning-algorithm-in-java-to.html
package QLearning;
import java.util.Date;
import java.util.Random;
import java.util.TreeMap;
/**
*
* @author joe mcverry
* usacoder@gmail.com
* @since 10-10-2018
* @version 0.0.10
*
* I posted a program using a double moving average algorithm that calls this code;
* see: https://usacoder.blogspot.com/2018/11/a-double-moving-average-qlearning.html
*/
public class QLearningAlgorithm {
enum State {
Cash, Stocks
};
enum CashAction {
HoldCash, BuyStocks
};
enum StockAction {
HoldStocks, GoToCash
};
TreeMap stateActionAndResultsTable = new TreeMap<>();
public TreeMap getSARTable() {
return stateActionAndResultsTable;
}
final static double ALPHA = .618; // learning rate
final static double GAMMA = .9; // if 0 then the immediate state replaces
// the current
// state value;
// if 1 then it's the average of all time steps.
final static double EPISILON = .1; // % of times the action taken is random.
public final static int ITERATIONS = 144; // how many times do we run the
// learning algorithm to fill
// the table (theMap)
final static int LOOKAHEAD = 5; // how far into the future are we testing
public void learning(String symbol, String intervals[], int maxPosition, QLearning qla) {
for (int iterationCount = ITERATIONS; iterationCount > 0; iterationCount--) {
Random randoms = new Random(new Date().getTime());
int position = 0;
while (qla.dataHere(symbol, intervals, position) == false)
position++;
State currentState = State.Cash;
CashAction nextCashAction;
StockAction nextStockAction;
for (; position < maxPosition - LOOKAHEAD; position++) {
String attributeKey = qla.getAttributeValue(symbol, intervals, position);
double thisPrice = qla.getPrice(symbol, position);
double nextPrice = qla.getTrend(symbol, intervals, position + LOOKAHEAD);
//getTrend returns a X-days moving average value; I use a Simple Moving Value based on X = lookahead
double percentChange = nextPrice / thisPrice;
if (stateActionAndResultsTable.get(attributeKey + "_" + CashAction.HoldCash.name()) == null) {
stateActionAndResultsTable.put(attributeKey + "_" + CashAction.HoldCash.name(), 0.);
stateActionAndResultsTable.put(attributeKey + "_" + CashAction.BuyStocks.name(), 0.);
stateActionAndResultsTable.put(attributeKey + "_" + StockAction.GoToCash.name(), 0.);
stateActionAndResultsTable.put(attributeKey + "_" + StockAction.HoldStocks.name(), 0.);
}
double rand = randoms.nextDouble();
if (currentState == State.Cash) {
double holdCash = stateActionAndResultsTable.get(attributeKey + "_" + CashAction.HoldCash.name());
double buyStocks = stateActionAndResultsTable.get(attributeKey + "_" + CashAction.BuyStocks.name());
if (holdCash == buyStocks) {
if (rand < .5) {
nextCashAction = CashAction.HoldCash;
} else {
nextCashAction = CashAction.BuyStocks;
}
} else {
if (rand < EPISILON) {
rand = randoms.nextDouble();
if (rand < .5) {
nextCashAction = CashAction.HoldCash;
} else {
nextCashAction = CashAction.BuyStocks;
}
} else {
if (holdCash > buyStocks) {
nextCashAction = CashAction.HoldCash;
} else {
nextCashAction = CashAction.BuyStocks;
}
}
}
if (nextCashAction == CashAction.HoldCash) {
double was = stateActionAndResultsTable.get(attributeKey + "_" + CashAction.HoldCash.name());
double reward = (1 - percentChange);
double max = (Math.max(stateActionAndResultsTable.get(attributeKey + "_" + CashAction.HoldCash.name()), stateActionAndResultsTable.get(attributeKey + "_" + CashAction.BuyStocks.name())));
was += ALPHA * (reward + (GAMMA * max) - was);
stateActionAndResultsTable.put(attributeKey + "_" + CashAction.HoldCash.name(), was);
currentState = State.Cash;
} else {
double was = stateActionAndResultsTable.get(attributeKey + "_" + CashAction.BuyStocks.name());
double reward = (percentChange - 1);
double max = (Math.max(stateActionAndResultsTable.get(attributeKey + "_" + CashAction.HoldCash.name()), stateActionAndResultsTable.get(attributeKey + "_" + CashAction.BuyStocks.name())));
was += ALPHA * (reward + (GAMMA * max) - was);
stateActionAndResultsTable.put(attributeKey + "_" + CashAction.BuyStocks.name(), was);
currentState = State.Stocks;
}
}
else {
double HoldStocks = stateActionAndResultsTable.get(attributeKey + "_" + StockAction.HoldStocks.name());
double GoToCash = stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name());
if (HoldStocks == GoToCash) {
if (rand < .5) {
nextStockAction = StockAction.HoldStocks;
} else {
nextStockAction = StockAction.GoToCash;
}
} else {
if (rand < EPISILON) {
rand = randoms.nextDouble();
if (rand < .5) {
nextStockAction = StockAction.HoldStocks;
} else {
nextStockAction = StockAction.GoToCash;
}
} else {
if (HoldStocks > GoToCash) {
nextStockAction = StockAction.HoldStocks;
} else {
nextStockAction = StockAction.GoToCash;
}
}
}
if (nextStockAction == StockAction.HoldStocks) {
double was = stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name());
// if price goes up good, if price goes down bad
double reward = percentChange - 1;
double max = Math.max(stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name()), stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name()));
was += ALPHA * (reward + (GAMMA * max) - was);
stateActionAndResultsTable.put(attributeKey + "_" + StockAction.HoldStocks.name(), was);
currentState = State.Stocks;
} else { // go to cash
double was = stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name());
double reward = 1 - percentChange;
double max = Math.max(stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name()), stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name()));
was += ALPHA * (reward + (GAMMA * max) - was);
stateActionAndResultsTable.put(attributeKey + "_" + StockAction.GoToCash.name(), was);
currentState = State.Cash;
}
}
}
}
}
}
public TreeMap
return stateActionAndResultsTable;
}
final static double ALPHA = .618; // learning rate
final static double GAMMA = .9; // if 0 then the immediate state replaces
// the current
// state value;
// if 1 then it's the average of all time steps.
final static double EPISILON = .1; // % of times the action taken is random.
public final static int ITERATIONS = 144; // how many times do we run the
// learning algorithm to fill
// the table (theMap)
final static int LOOKAHEAD = 5; // how far into the future are we testing
public void learning(String symbol, String intervals[], int maxPosition, QLearning qla) {
for (int iterationCount = ITERATIONS; iterationCount > 0; iterationCount--) {
Random randoms = new Random(new Date().getTime());
int position = 0;
while (qla.dataHere(symbol, intervals, position) == false)
position++;
State currentState = State.Cash;
CashAction nextCashAction;
StockAction nextStockAction;
for (; position < maxPosition - LOOKAHEAD; position++) {
String attributeKey = qla.getAttributeValue(symbol, intervals, position);
double thisPrice = qla.getPrice(symbol, position);
double nextPrice = qla.getTrend(symbol, intervals, position + LOOKAHEAD);
double percentChange = nextPrice / thisPrice;
if (stateActionAndResultsTable.get(attributeKey + "_" + CashAction.HoldCash.name()) == null) {
stateActionAndResultsTable.put(attributeKey + "_" + CashAction.HoldCash.name(), 0.);
stateActionAndResultsTable.put(attributeKey + "_" + CashAction.BuyStocks.name(), 0.);
stateActionAndResultsTable.put(attributeKey + "_" + StockAction.GoToCash.name(), 0.);
stateActionAndResultsTable.put(attributeKey + "_" + StockAction.HoldStocks.name(), 0.);
}
double rand = randoms.nextDouble();
if (currentState == State.Cash) {
double holdCash = stateActionAndResultsTable.get(attributeKey + "_" + CashAction.HoldCash.name());
double buyStocks = stateActionAndResultsTable.get(attributeKey + "_" + CashAction.BuyStocks.name());
if (holdCash == buyStocks) {
if (rand < .5) {
nextCashAction = CashAction.HoldCash;
} else {
nextCashAction = CashAction.BuyStocks;
}
} else {
if (rand < EPISILON) {
rand = randoms.nextDouble();
if (rand < .5) {
nextCashAction = CashAction.HoldCash;
} else {
nextCashAction = CashAction.BuyStocks;
}
} else {
if (holdCash > buyStocks) {
nextCashAction = CashAction.HoldCash;
} else {
nextCashAction = CashAction.BuyStocks;
}
}
}
if (nextCashAction == CashAction.HoldCash) {
double was = stateActionAndResultsTable.get(attributeKey + "_" + CashAction.HoldCash.name());
double reward = (1 - percentChange);
double max = (Math.max(stateActionAndResultsTable.get(attributeKey + "_" + CashAction.HoldCash.name()), stateActionAndResultsTable.get(attributeKey + "_" + CashAction.BuyStocks.name())));
was += ALPHA * (reward + (GAMMA * max) - was);
stateActionAndResultsTable.put(attributeKey + "_" + CashAction.HoldCash.name(), was);
currentState = State.Cash;
} else {
double was = stateActionAndResultsTable.get(attributeKey + "_" + CashAction.BuyStocks.name());
double reward = (percentChange - 1);
double max = (Math.max(stateActionAndResultsTable.get(attributeKey + "_" + CashAction.HoldCash.name()), stateActionAndResultsTable.get(attributeKey + "_" + CashAction.BuyStocks.name())));
was += ALPHA * (reward + (GAMMA * max) - was);
stateActionAndResultsTable.put(attributeKey + "_" + CashAction.BuyStocks.name(), was);
currentState = State.Stocks;
}
}
else {
double HoldStocks = stateActionAndResultsTable.get(attributeKey + "_" + StockAction.HoldStocks.name());
double GoToCash = stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name());
if (HoldStocks == GoToCash) {
if (rand < .5) {
nextStockAction = StockAction.HoldStocks;
} else {
nextStockAction = StockAction.GoToCash;
}
} else {
if (rand < EPISILON) {
rand = randoms.nextDouble();
if (rand < .5) {
nextStockAction = StockAction.HoldStocks;
} else {
nextStockAction = StockAction.GoToCash;
}
} else {
if (HoldStocks > GoToCash) {
nextStockAction = StockAction.HoldStocks;
} else {
nextStockAction = StockAction.GoToCash;
}
}
}
if (nextStockAction == StockAction.HoldStocks) {
double was = stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name());
// if price goes up good, if price goes down bad
double reward = percentChange - 1;
double max = Math.max(stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name()), stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name()));
was += ALPHA * (reward + (GAMMA * max) - was);
stateActionAndResultsTable.put(attributeKey + "_" + StockAction.HoldStocks.name(), was);
currentState = State.Stocks;
} else { // go to cash
double was = stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name());
double reward = 1 - percentChange;
double max = Math.max(stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name()), stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name()));
was += ALPHA * (reward + (GAMMA * max) - was);
stateActionAndResultsTable.put(attributeKey + "_" + StockAction.GoToCash.name(), was);
currentState = State.Cash;
}
}
}
}
}
}
// No comments:
// Post a Comment