package QLearning;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Date;
import java.util.Random;
import java.util.TreeMap;
/**
*
* QLearningAlgorithm
*
* @author joe mcverry
*
* write to me - usacoder on the gmail server
*
*
* @version 0.0.20
*
* 1. simplified by removing some arguments passed to the learning step.
*
* 2. broke learning steps into three methods with learning now calling
* learnAStep and takeAStep.
*
* 3. added divergence to allow breaking out of learning if the state
* table values don't change (much).
*
* 4. state table (aka StateActionAndResultsTable) can be stored and
* read from a file
*
*/
public class QLearningAlgorithm {
enum State {
Initial, Cash, Stocks
};
/**
* state/action table.
*
* key is the attribute (market state) plus the action taken (Cash or Stocks).
*
* value is a double[2]:
*
* [0] - the accumulated reward value learned for this state/action pair
*
* [1] - # of times we've taken this action, used for debugging.
*
*/
TreeMap<String, double[]> stateActionAndResultsTable = new TreeMap<>();
private File objectFile;
public void storeSARTableToFile() throws Exception {
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(objectFile));
oos.writeObject(stateActionAndResultsTable);
oos.flush();
oos.close();
}
@SuppressWarnings("unchecked")
public void loadSARTableFromFile(File file) throws Exception {
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(file));
stateActionAndResultsTable = (TreeMap<String, double[]>) ois.readObject();
ois.close();
objectFile = file;
}
public TreeMap<String, double[]> getSARTable() {
return stateActionAndResultsTable;
}
/**
* ALPHA is the learning rate.
*
* I set alpha high at the beginning and adjust it after every
* iteration via ALPHA = 1 - ALPHA * .5, whose fixed point is 2/3,
* so it converges towards .66666666 (1 -> .5 -> .75 -> .625 -> ...).
*
* If you don't want this, set ALPHA to the value you want
* and set CONVERGE_ALPHA to false.
*
*/
double ALPHA = 1.; // learning rate
public void setALPHA(double aLPHA) {
ALPHA = aLPHA;
}
public void setCONVERGE_ALPHA(boolean cONVERGE_ALPHA) {
CONVERGE_ALPHA = cONVERGE_ALPHA;
}
boolean CONVERGE_ALPHA = true;
/**
* GAMMA is the replacement rate: if 0, only the immediate reward updates the
* current state value; if 1, the value effectively averages over all time steps.
*/
double GAMMA = .32;
public void setGAMMA(double gAMMA) {
GAMMA = gAMMA;
}
/**
* EPSILON is the fraction of steps where the action taken is random (exploration).
*/
double EPSILON = .1;
public void setEPSILON(double ePSILON) {
EPSILON = ePSILON;
}
/**
* ITERATIONS is how many times we run the learning algorithm to fill the
* state table.
*/
static int ITERATIONS = 2500;
public static void setITERATIONS(int iTERATIONS) {
ITERATIONS = iTERATIONS;
}
/**
* do we run the iterations all the way down to zero, or can we break out early
* when the map values stop changing? see DIVERGENCE_LIMIT.
*/
static boolean TEST_DIVERGENCE = true;
public void setTestDivergence(boolean tEST_DIVERGENCE)
{
TEST_DIVERGENCE = tEST_DIVERGENCE;
}
static double DIVERGENCE_LIMIT = 1e-7;
public void setDivergenceLimit(double dIVERGENCE_LIMIT)
{
DIVERGENCE_LIMIT = dIVERGENCE_LIMIT;
}
/**
* TRENDLENGTH is what I use to look into the future to get results from
* the trendMap (just a simple moving average).
*/
static int TRENDLENGTH = 5;
Random randoms;
public static void setTrendLength(int inTrendLength) {
TRENDLENGTH = inTrendLength;
}
public QLearningAlgorithm() {
randoms = new Random(0);
}
public QLearningAlgorithm(String SARTableFileName) throws Exception {
if (SARTableFileName == null) {
stateActionAndResultsTable = new TreeMap<>();
} else if (SARTableFileName.length() == 0) {
stateActionAndResultsTable = new TreeMap<>();
} else {
File file = new File(SARTableFileName);
if (file.exists() == false) {
System.err.println("sar table file (" + SARTableFileName + ") does not exist, starting from scratch");
stateActionAndResultsTable = new TreeMap<>();
objectFile = file; // remember the file so storeSARTableToFile() can create it later
} else {
loadSARTableFromFile(file);
}
}
randoms = new Random(0);
}
/**
* @param symbol
* @param maxPosition
* - where do we stop
* @param qL
* - QLearner object, which should be the object calling this
* method.
* @return the final divergence value; divergence tests the
* state table to see if anything changed between iterations
*/
public double learning(String symbol, int maxPosition, QLearner qL) {
double divergentSum = 0;
TreeMap<String, Double> testTable = new TreeMap<>();
for (int iterationCount = ITERATIONS; iterationCount > 0; iterationCount--) {
randoms = new Random(new Date().getTime());
int position = 0;
/**
* why test for data? because, for example, if using a 15 day
* moving average the first 14 entries in the m.a. table would be
* 0 and therefore useless.
*/
while (qL.dataHere(symbol, position) == false)
position++;
State currentState = State.Cash;
State previousState = State.Initial;
for (; position < maxPosition - TRENDLENGTH; position++) {
if (previousState != State.Initial)
learnAStep(symbol, position - 1, qL, currentState);
currentState = takeAStep(symbol, position, qL, currentState);
previousState = currentState;
}
if (TEST_DIVERGENCE) {
if (iterationCount < ITERATIONS / 2) {
divergentSum = 0;
for (String key : stateActionAndResultsTable.keySet()) {
// keys added since the last snapshot won't be in testTable yet
divergentSum += Math.abs(stateActionAndResultsTable.get(key)[0] - testTable.getOrDefault(key, 0.0));
}
if (divergentSum < DIVERGENCE_LIMIT)
return divergentSum;
}
for (String key : stateActionAndResultsTable.keySet()) {
double d = stateActionAndResultsTable.get(key)[0];
testTable.put(key, d); // autoboxing instead of the deprecated new Double(d)
}
}
/**
* if CONVERGE_ALPHA is true, ALPHA will converge towards .66666; feel free to remove this.
*/
if (CONVERGE_ALPHA)
ALPHA = 1 - ALPHA * .5; // converges towards .6666666666666
}
return divergentSum;
}
/**
* @param symbol
* @param position
* - where is the data
* @param qL
* - QLearner object
* @param thisState
* - the position held (Cash or Stocks)
*/
private void learnAStep(String symbol, int position, QLearner qL, State thisState) {
String attributeKey = qL.getAttributeValue(symbol, position);
double nextPrice = qL.getTrend(symbol, position + 1);
double lastPrice = qL.getPrice(symbol, position);
/*
* getTrend returns a moving average value, I use a simple moving
* average with the m.a. interval set to LOOKAHEAD
*/
double percentChange = (nextPrice / lastPrice);
if (stateActionAndResultsTable.get(attributeKey + "_" + State.Cash.name()) == null) {
stateActionAndResultsTable.put(attributeKey + "_" + State.Cash.name(), new double[2]);
stateActionAndResultsTable.put(attributeKey + "_" + State.Stocks.name(), new double[2]);
}
double reward;
if (thisState == State.Cash)
/**
* a ratio < 1 while in cash is good; flip it to get a good reward.
* if in cash and the ratio is > 1, the reward takes a ding.
*/
reward = 1 / percentChange;
else
/**
* a ratio > 1 is good for a stock position; otherwise
* the reward shouldn't be so good.
*/
reward = percentChange;
/**
* how'd we do for the action taken prior to this step.
*/
double valueForThisAction = stateActionAndResultsTable.get(attributeKey + "_" + thisState.name())[0];
/**
* max is the stronger of the Cash and Stocks values for this particular
* attribute - roughly its strength
*/
double max = (Math.max(stateActionAndResultsTable.get(attributeKey + "_" + State.Cash.name())[0],
stateActionAndResultsTable.get(attributeKey + "_" + State.Stocks.name())[0]));
/**
* apply the learning rate
*/
valueForThisAction = valueForThisAction * (1 - ALPHA);
/**
* and add the reward plus the replacement rate times that strength
*/
valueForThisAction += ALPHA * (reward + (GAMMA * max) - valueForThisAction);
stateActionAndResultsTable.get(attributeKey + "_" + thisState.name())[0] = valueForThisAction;
stateActionAndResultsTable.get(attributeKey + "_" + thisState.name())[1] += 1;
}
/**
* @param symbol
* @param position
* - where is the data
* @param qL
* - QLearner object
* @param currentState
* - the position held (Cash or Stocks)
* @return newState
*/
private State takeAStep(String symbol, int position, QLearner qL, State currentState) {
String attributeKey = qL.getAttributeValue(symbol, position);
/*
* choose the next position (Cash or Stocks) for this attribute key,
* epsilon-greedily from the state/action table
*/
if (stateActionAndResultsTable.get(attributeKey + "_" + State.Cash.name()) == null) {
stateActionAndResultsTable.put(attributeKey + "_" + State.Cash.name(), new double[2]);
stateActionAndResultsTable.put(attributeKey + "_" + State.Stocks.name(), new double[2]);
}
double rand = randoms.nextDouble();
double holdCash = stateActionAndResultsTable.get(attributeKey + "_" + State.Cash.name())[0];
double buyStocks = stateActionAndResultsTable.get(attributeKey + "_" + State.Stocks.name())[0];
if (holdCash == buyStocks) {
/** either way the same so take a random step */
if (rand < .5) {
return State.Cash;
} else {
return State.Stocks;
}
} else {
if (rand < EPSILON) {
/** explore: with probability EPSILON take a random step */
rand = randoms.nextDouble();
if (rand < .5) {
return State.Cash;
} else {
return State.Stocks;
}
} else { /** take the stronger step */
if (holdCash > buyStocks) {
return State.Cash;
} else {
return State.Stocks;
}
}
}
// unreachable
}
}
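To show how the pieces fit together, here is a minimal driver sketch. It only uses calls that appear in the class above; the QLearner instance, the symbol, maxPosition, and the "sarTable.ser" file name are assumptions of mine, since the QLearner implementation isn't shown in this post.

package QLearning;

/**
 * Minimal driver sketch. The QLearner passed in is whatever implementation the
 * rest of the project provides; it only needs the dataHere/getAttributeValue/
 * getTrend/getPrice calls used by learning(). The "sarTable.ser" file name is
 * just an example.
 */
public class QLearningDemo {

    public static double run(QLearner qL, String symbol, int maxPosition) throws Exception {
        // an existing table file is loaded and reused; a missing file starts a fresh table
        QLearningAlgorithm algo = new QLearningAlgorithm("sarTable.ser");
        QLearningAlgorithm.setITERATIONS(2500);
        algo.setEPSILON(0.1);  // 10% of steps explore randomly
        algo.setGAMMA(0.32);   // replacement rate

        double divergence = algo.learning(symbol, maxPosition, qL);
        System.out.println("final divergence: " + divergence);

        algo.storeSARTableToFile(); // persist the learned table for the next run
        return divergence;
    }
}

On the first run the table file won't exist yet, so the constructor starts from scratch and storeSARTableToFile() creates it; later runs pick up where the last one left off.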