Friday, October 26, 2018

A Java Based QLearning Algorithm For Getting Buy and Sell Indicators

Updated program at https://usacoder.blogspot.com/2018/11/updated-qlearning-algorithm-in-java-to.html

package QLearning;

import java.util.Date;
import java.util.Random;
import java.util.TreeMap;

/**
 *
 * @author joe mcverry
 * usacoder@gmail.com
 * @since 10-10-2018
 * @version 0.0.10
 *
 *  I posted a program using a double moving algorithm that calls this code
 *  see: https://usacoder.blogspot.com/2018/11/a-double-moving-average-qlearning.html
 */

public class QLearningAlgorithm {

    enum State {
        Cash, Stocks
    };

    enum CashAction {
        HoldCash, BuyStocks
    };

    enum StockAction {
        HoldStocks, GoToCash
    };

    TreeMap stateActionAndResultsTable = new TreeMap<>();

    public TreeMap getSARTable() {
        return stateActionAndResultsTable;
    }

    final static double ALPHA = .618; // learning rate

    final static double GAMMA = .9; // if 0 then the immediate state replaces
                                    // the current
    // state value;
    // if 1 then it's the average of all time steps.

    final static double EPISILON = .1; // % of times the action taken is random.

    public final static int ITERATIONS = 144; // how many times do we run the
                                                // learning algorithm to fill
                                                // the table (theMap)

    final static int LOOKAHEAD = 5; // how far into the future are we testing

    public void learning(String symbol, String intervals[], int maxPosition, QLearning qla) {

        for (int iterationCount = ITERATIONS; iterationCount > 0; iterationCount--) {
            Random randoms = new Random(new Date().getTime());
            int position = 0;
            while (qla.dataHere(symbol, intervals, position) == false)
                position++;

            State currentState = State.Cash;
            CashAction nextCashAction;
            StockAction nextStockAction;

            for (; position < maxPosition - LOOKAHEAD; position++) {
                String attributeKey = qla.getAttributeValue(symbol, intervals, position);
                double thisPrice = qla.getPrice(symbol, position);
                double nextPrice = qla.getTrend(symbol, intervals, position + LOOKAHEAD);

//getTrend returns a X-days moving average value; I use a Simple Moving Value based on X = lookahead
                double percentChange = nextPrice / thisPrice;
                if (stateActionAndResultsTable.get(attributeKey + "_" + CashAction.HoldCash.name()) == null) {
                    stateActionAndResultsTable.put(attributeKey + "_" + CashAction.HoldCash.name(), 0.);
                    stateActionAndResultsTable.put(attributeKey + "_" + CashAction.BuyStocks.name(), 0.);
                    stateActionAndResultsTable.put(attributeKey + "_" + StockAction.GoToCash.name(), 0.);
                    stateActionAndResultsTable.put(attributeKey + "_" + StockAction.HoldStocks.name(), 0.);
                }
                double rand = randoms.nextDouble();
                if (currentState == State.Cash) {
                    double holdCash = stateActionAndResultsTable.get(attributeKey + "_" + CashAction.HoldCash.name());
                    double buyStocks = stateActionAndResultsTable.get(attributeKey + "_" + CashAction.BuyStocks.name());
                    if (holdCash == buyStocks) {
                        if (rand < .5) {
                            nextCashAction = CashAction.HoldCash;
                        } else {
                            nextCashAction = CashAction.BuyStocks;
                        }
                    } else {
                        if (rand < EPISILON) {
                            rand = randoms.nextDouble();
                            if (rand < .5) {
                                nextCashAction = CashAction.HoldCash;
                            } else {
                                nextCashAction = CashAction.BuyStocks;
                            }
                        } else {
                            if (holdCash > buyStocks) {
                                nextCashAction = CashAction.HoldCash;
                            } else {
                                nextCashAction = CashAction.BuyStocks;
                            }
                        }
                    }
                    if (nextCashAction == CashAction.HoldCash) {
                        double was = stateActionAndResultsTable.get(attributeKey + "_" + CashAction.HoldCash.name());
                        double reward = (1 - percentChange);
                        double max = (Math.max(stateActionAndResultsTable.get(attributeKey + "_" + CashAction.HoldCash.name()), stateActionAndResultsTable.get(attributeKey + "_" + CashAction.BuyStocks.name())));
                        was += ALPHA * (reward + (GAMMA * max) - was);
                        stateActionAndResultsTable.put(attributeKey + "_" + CashAction.HoldCash.name(), was);
                        currentState = State.Cash;
                    } else {
                        double was = stateActionAndResultsTable.get(attributeKey + "_" + CashAction.BuyStocks.name());
                        double reward = (percentChange - 1);
                        double max = (Math.max(stateActionAndResultsTable.get(attributeKey + "_" + CashAction.HoldCash.name()), stateActionAndResultsTable.get(attributeKey + "_" + CashAction.BuyStocks.name())));
                        was += ALPHA * (reward + (GAMMA * max) - was);
                        stateActionAndResultsTable.put(attributeKey + "_" + CashAction.BuyStocks.name(), was);
                        currentState = State.Stocks;
                    }
                }
                else {
                    double HoldStocks = stateActionAndResultsTable.get(attributeKey + "_" + StockAction.HoldStocks.name());
                    double GoToCash = stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name());
                    if (HoldStocks == GoToCash) {
                        if (rand < .5) {
                            nextStockAction = StockAction.HoldStocks;
                        } else {
                            nextStockAction = StockAction.GoToCash;
                        }
                    } else {
                        if (rand < EPISILON) {
                            rand = randoms.nextDouble();
                            if (rand < .5) {
                                nextStockAction = StockAction.HoldStocks;
                            } else {
                                nextStockAction = StockAction.GoToCash;
                            }
                        } else {
                            if (HoldStocks > GoToCash) {
                                nextStockAction = StockAction.HoldStocks;
                            } else {
                                nextStockAction = StockAction.GoToCash;
                            }
                        }
                    }
                    if (nextStockAction == StockAction.HoldStocks) {
                        double was = stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name());
                        // if price goes up good, if price goes down bad
                        double reward = percentChange - 1;
                        double max = Math.max(stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name()), stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name()));
                        was += ALPHA * (reward + (GAMMA * max) - was);
                        stateActionAndResultsTable.put(attributeKey + "_" + StockAction.HoldStocks.name(), was);
                        currentState = State.Stocks;
                    } else { // go to cash
                        double was = stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name());
                        double reward = 1 - percentChange;
                        double max = Math.max(stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name()), stateActionAndResultsTable.get(attributeKey + "_" + StockAction.GoToCash.name()));
                        was += ALPHA * (reward + (GAMMA * max) - was);
                        stateActionAndResultsTable.put(attributeKey + "_" + StockAction.GoToCash.name(), was);
                        currentState = State.Cash;
                    }
                }
            }
        }
    }
 }

No comments:

Post a Comment