Wednesday, November 28, 2018

Updated QLearning Algorithm in Java to Generate Buy/Sell Signals

package QLearning;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Date;
import java.util.Random;
import java.util.TreeMap;

/**
 *
 * QLearningAlgorithm
 *
 * @author joe mcverry
 *
 * write to me - usacoder on the gmail server
 *
 *
 * @version 0.0.20
 *
 *          1. simplified by removing some arguments passed to the learning step.
 *
 *          2. broke the learning step into three methods, with learning now
 *          calling learnAStep and takeAStep.
 *
 *          3. added a divergence test to allow breaking out of learning if the
 *          state table values don't change (much).
 *
 *          4. the state table (aka StateActionAndResultsTable) can be stored to
 *          and read from a file
 *
 */

public class QLearningAlgorithm {

    enum State {
        Initial, Cash, Stocks
    };
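
    /*
     * How the states read from the code below: Initial only marks the step
     * before the first action is taken; Cash means we are out of the market
     * and Stocks means we are holding a position.
     */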

    /**
     * state action table,
     *
     * key is the state plus the action taken
     *
     * value is a double[2]:
     *
     * [0] - total reward
     *
     * [1] - # of times we've taken this action, used for debugging.
     *
     */
    TreeMap<String, double[]> stateActionAndResultsTable = new TreeMap<>();
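    // example entry (the attribute value "UP" here is purely illustrative):
    //   key   "UP_Stocks"
    //   value { 12.34, 57.0 }  -> total reward so far, # of times visited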

    private File objectFile;

    public void storeSARTableToFile() throws Exception {
        ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(objectFile));
        oos.writeObject(stateActionAndResultsTable);
        oos.flush();
        oos.close();
    }

    @SuppressWarnings("unchecked")
    public void loadSARTableFromFile(File file) throws Exception {
        ObjectInputStream ois = new ObjectInputStream(new FileInputStream(file));
        stateActionAndResultsTable = (TreeMap<String, double[]>) ois.readObject();
        ois.close();
        objectFile = file;
    }

    public TreeMap<String, double[]> getSARTable() {
        return stateActionAndResultsTable;
    }

    /**
     * ALPHA is the learning rate.
     *
     * I set ALPHA high at the beginning and converge it after every
     * iteration; the logic has it going towards .66666666
     *
     * If you don't want to do this, set ALPHA to the value you want
     * and set CONVERGE_ALPHA to false.
     *
     */
    double ALPHA = 1.; // learning rate
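    // with CONVERGE_ALPHA on, ALPHA = 1 - ALPHA * .5 runs
    // 1 -> .5 -> .75 -> .625 -> .6875 -> ... and settles at 2/3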
   
    public void setALPHA(double aLPHA) {
        ALPHA = aLPHA;
    }

    public void setCONVERGE_ALPHA(boolean cONVERGE_ALPHA) {
        CONVERGE_ALPHA = cONVERGE_ALPHA;
    }

    boolean CONVERGE_ALPHA = true;

    /**
     * GAMMA is the replacement rate. If 0, then the immediate state replaces the
     * current state value; if 1, then it's the average of all time steps.
     */
    double GAMMA = .32;
   
    public void setGAMMA(double gAMMA) {
        GAMMA = gAMMA;
    }


    /**
     * EPSILON is the fraction of times the action taken is random.
     */
    double EPSILON = .1;

    public void setEPSILON(double ePSILON) {
        EPSILON = ePSILON;
    }

    /**
     * ITERATIONS is how many times we run the learning algorithm to fill the
     * table
     */

    static int ITERATIONS = 2500;

    public static void setITERATIONS(int iTERATIONS) {
        ITERATIONS = iTERATIONS;
    }


    /**
     * do we take iterations down to zero, or can we break out early when the map
     * values don't change (much)? see DIVERGENCE_LIMIT.
     */
    static boolean TEST_DIVERGENCE = true;
   
    public void setTestDivergence(boolean tEST_DIVERGENCE)
    {
        TEST_DIVERGENCE = tEST_DIVERGENCE;
    }

    static double DIVERGENCE_LIMIT = 1e-7;

    public void setDivergenceLimit(double dIVERGENCE_LIMIT)
    {
        DIVERGENCE_LIMIT = dIVERGENCE_LIMIT;
    }




    /**
     * TRENDLENGTH is what I use to look into the future to get results from
     * the trendMap (just a simple moving average).
     */
    static int TRENDLENGTH = 5;

    Random randoms;

    public static void setTrendLength(int inTrendLength) {
        TRENDLENGTH = inTrendLength;
    }

    public QLearningAlgorithm() {
        randoms = new Random(0);
    }

    public QLearningAlgorithm(String SARTableFileName) throws Exception {
        if (SARTableFileName == null || SARTableFileName.length() == 0) {
            stateActionAndResultsTable = new TreeMap<>();
        } else {
            File file = new File(SARTableFileName);
            objectFile = file;
            if (file.exists()) {
                loadSARTableFromFile(file);
            } else {
                System.err.println("sar table file (" + SARTableFileName + ") does not exist, starting from scratch");
                stateActionAndResultsTable = new TreeMap<>();
            }
        }
        randoms = new Random(0);
    }

    /**
     * @param symbol
     * @param maxPosition
     *            - where do we stop
     * @param qL
     *            - QLearner object, which should be the object calling this
     *            method.
     * @return a double of the final divergence value; the divergence test checks
     *         the state table to see if anything changed
     */
    public double learning(String symbol, int maxPosition, QLearner qL) {

        double divergentSum = 0;
        TreeMap<String, Double> testTable = new TreeMap<>();
        for (int iterationCount = ITERATIONS; iterationCount > 0; iterationCount--) {
            randoms = new Random(new Date().getTime());
            int position = 0;
            while (qL.dataHere(symbol, position) == false)
                /**
                 * why test for data? - because, for example, if using a 15 day
                 * moving average the first 14 entries in the ma table would be
                 * 0 and therefore useless.
                 */
                position++;

            State currentState = State.Cash;
            State previousState = State.Initial;

            for (; position < maxPosition - TRENDLENGTH; position++) {
                if (previousState != State.Initial)
                    learnAStep(symbol, position - 1, qL, currentState);
                currentState = takeAStep(symbol, position, qL, currentState);
                previousState = currentState;
            }

            if (TEST_DIVERGENCE) {
                if (iterationCount < ITERATIONS / 2) {
                    divergentSum = 0;
                    for (String key : stateActionAndResultsTable.keySet()) {
                        divergentSum += Math.abs(
                                stateActionAndResultsTable.get(key)[0] - testTable.getOrDefault(key, 0.));
                    }
                    if (divergentSum < DIVERGENCE_LIMIT)
                        return divergentSum;
                }

                for (String key : stateActionAndResultsTable.keySet()) {
                    double d = stateActionAndResultsTable.get(key)[0];
                    testTable.put(key, d);
                }
            }
            /**
             * If ALPHA starts at 1 then it will converge towards .66666; feel free
             * to remove this.
             */
            if (CONVERGE_ALPHA)
                ALPHA = 1 - ALPHA * .5; // converges towards .6666666666666

        }

        return divergentSum;
    }

    /**
     * @param symbol
     * @param position
     *            - where is the data
     * @param qL
     *            - QLearner object
     * @param thisState
     *            (buy or sell)
     */
    private void learnAStep(String symbol, int position, QLearner qL, State thisState) {

        String attributeKey = qL.getAttributeValue(symbol, position);
        double nextPrice = qL.getTrend(symbol, position + 1);
        double lastPrice = qL.getPrice(symbol, position);
        /*
         * getTrend returns a moving average value, I use a simple moving
         * average with the m.a. interval set to LOOKAHEAD
         */
        double percentChange = (nextPrice / lastPrice);
        if (stateActionAndResultsTable.get(attributeKey + "_" + State.Cash.name()) == null) {
            stateActionAndResultsTable.put(attributeKey + "_" + State.Cash.name(), new double[2]);
            stateActionAndResultsTable.put(attributeKey + "_" + State.Stocks.name(), new double[2]);
        }

        double reward;
        if (thisState == State.Cash)
            /**
             * if the percentage is < 1 while in cash that's good, so flip it to
             * get a good reward; if in cash and the percentage is > 1 then ding
             * the reward.
             */
            reward = 1 / percentChange;
        else
            /**
             * if the reward is > 1 that's good for an in-stock position; otherwise
             * the reward shouldn't be so good.
             */
            reward = percentChange;
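        /*
         * e.g. with lastPrice 100 and nextPrice (the trend) 102, percentChange
         * is 1.02, so the reward is 1.02 when in Stocks and 1 / 1.02 (about
         * 0.98) when in Cash.
         */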

        /**
         * how'd we do for the action taken prior to this step.
         */
        double valueForThisAction = stateActionAndResultsTable.get(attributeKey + "_" + thisState.name())[0];

        /**
         * max is something like the strength value for this particular
         * attribute
         */

        double max = (Math.max(stateActionAndResultsTable.get(attributeKey + "_" + State.Cash.name())[0],
                stateActionAndResultsTable.get(attributeKey + "_" + State.Stocks.name())[0]));

        /**
         * apply the learning rate
         */
        valueForThisAction = valueForThisAction * (1 - ALPHA);
        /**
         * and apply the reward with replacement rate with its strength
         */
        valueForThisAction += ALPHA * (reward + (GAMMA * max) - valueForThisAction);
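        /*
         * e.g. with ALPHA .6667, GAMMA .32, reward 1.02, an old value of 1.0
         * and max 1.1: the first line scales 1.0 down to about .3333 and the
         * second adds .6667 * (1.02 + .32 * 1.1 - .3333), leaving roughly 1.03.
         */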

        stateActionAndResultsTable.get(attributeKey + "_" + thisState.name())[0] = valueForThisAction;
        stateActionAndResultsTable.get(attributeKey + "_" + thisState.name())[1] += 1;

    }

    /**
     * @param symbol
     * @param position
     *            - where is the data
     * @param qL
     *            - QLearner object
     * @param currentState
     *            (buy or sell)
     * @return the new state
     */

    private State takeAStep(String symbol, int position, QLearner qL, State currentState) {

        String attributeKey = qL.getAttributeValue(symbol, position);
        if (stateActionAndResultsTable.get(attributeKey + "_" + State.Cash.name()) == null) {
            stateActionAndResultsTable.put(attributeKey + "_" + State.Cash.name(), new double[2]);
            stateActionAndResultsTable.put(attributeKey + "_" + State.Stocks.name(), new double[2]);

        }
        double rand = randoms.nextDouble();

        double holdCash = stateActionAndResultsTable.get(attributeKey + "_" + State.Cash.name())[0];
        double buyStocks = stateActionAndResultsTable.get(attributeKey + "_" + State.Stocks.name())[0];
        if (holdCash == buyStocks) {
            /** either way the same so take a random step */
            if (rand < .5) {
                return State.Cash;
            } else {
                return State.Stocks;
            }
        } else {
            if (rand < EPSILON) {
                /** with probability EPSILON take a random step */
                rand = randoms.nextDouble();
                if (rand < .5) {
                    return State.Cash;
                } else {
                    return State.Stocks;
                }
            } else { /** take the stronger step */
                if (holdCash > buyStocks) {
                    return State.Cash;
                } else {
                    return State.Stocks;
                }
            }
        }

        // unreachable
    }

}
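
The QLearner type that drives the algorithm is not part of this listing, so below is a minimal, hypothetical sketch of how the class above might be exercised. It assumes QLearner is an interface declaring the four methods the algorithm calls (dataHere, getAttributeValue, getTrend and getPrice); the price series, the up/down attribute, and the symbol name are made up purely for illustration.

package QLearning;

public class QLearningExampleDriver implements QLearner { // assumes QLearner is an interface

    // toy price series, purely for illustration
    double[] prices = { 100, 101, 99, 98, 102, 104, 103, 105, 107, 106, 108, 110 };

    public boolean dataHere(String symbol, int position) {
        return position >= 0 && position < prices.length;
    }

    public String getAttributeValue(String symbol, int position) {
        // toy attribute: did the price rise or fall since the prior bar
        if (position == 0)
            return "FLAT";
        return prices[position] >= prices[position - 1] ? "UP" : "DOWN";
    }

    public double getTrend(String symbol, int position) {
        // stand-in for the simple moving average the algorithm expects:
        // average the next few prices, clipped to the end of the array
        int end = Math.min(position + 5, prices.length);
        double sum = 0;
        for (int i = position; i < end; i++)
            sum += prices[i];
        return sum / (end - position);
    }

    public double getPrice(String symbol, int position) {
        return prices[position];
    }

    public static void main(String[] args) throws Exception {
        QLearningExampleDriver driver = new QLearningExampleDriver();
        QLearningAlgorithm algo = new QLearningAlgorithm();
        double divergence = algo.learning("TEST", driver.prices.length, driver);
        System.out.println("final divergence: " + divergence);
        for (String key : algo.getSARTable().keySet())
            System.out.println(key + " -> " + algo.getSARTable().get(key)[0]);
    }
}

The printout at the end shows, for each attribute value, the learned Cash and Stocks values; the buy/sell signal at a given bar is simply whichever of the two is stronger, which is what takeAStep returns when it is not exploring.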
