package ac.essex.ooechs.lcs.test;

import ac.essex.ooechs.lcs.*;
import ac.essex.ooechs.lcs.representation.Condition;
import ac.essex.ooechs.lcs.util.MovingAverage;
import ac.essex.ooechs.lcs.util.Average;
import ac.essex.ooechs.lcs.util.ProportionalSelection;
import ac.essex.ooechs.problems.woods.Woods2Environment;

import java.util.Vector;
import java.util.Hashtable;

/**
 * <p/>
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version,
 * provided that any use properly credits the author.
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details at http://www.gnu.org
 * </p>
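 *
 * <p>An implementation of Wilson's XCS learning classifier system, exercised here on the
 * Woods2 environment (see the main method).</p>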
 *
 * @author Olly Oechsle, University of Essex, Date: 14-Jan-2008
 * @version 1.0
 */
public class XCS_OLD extends Thread {

    public static final int EXPLORE = 1;
    public static final int EXPLOIT = 2;

    protected int time = 0;

    protected Environment environment;

    protected Vector<Classifier> classifiers;

    protected XCSParams_OLD settings;

    protected Vector<Classifier> previousActionSet;

    protected Payoff previousPayoff;

    protected GeneticAlgorithm_OLD ga;

    public XCS_OLD(Environment e, XCSParams_OLD settings) {
        this.environment = e;
        this.settings = settings;
    }

    /**
     * Initialises the rules vector and the genetic algorithm.
     */
    public void initialise() {

        System.out.println("Initialising");

        // start with an empty rules vector
        classifiers = new Vector<Classifier>();

        // initialise the GA
        ga = new GeneticAlgorithm_OLD(this, settings);
    }

    /**
     * Starts the main learning process
     */
    public void learn() {

        System.out.println("XCS in Java\nby Olly Oechsle");

        initialise();

        int maxSteps = 50;

        MovingAverage runs = new MovingAverage(50);
        MovingAverage error = new MovingAverage(50);

        for (int i = 0; i < settings.problemCount; i++)  {

            // try the problem first in explore mode
            solveProblemOnce(EXPLORE, maxSteps);

            // then get a value in exploit mode
            int run = solveProblemOnce(EXPLOIT, maxSteps);

            // add the run to the moving average
            runs.add(run);
            error.add(getMeanError());

            // and print out the running averages every 50 problems
            if (i % 50 == 0) {
                System.out.println(i + ", " + error.getMovingAverage() + ", " + runs.getMovingAverage() + ", " + ga.executionCount);
            }

            // increment the time stamp used for classifier creation and GA scheduling
            time++;

        }

        System.out.println("Done.");

    }

    protected double getMeanFitness() {
        // TODO: Why is fitness getting smaller?
        Average a = new Average();
        for (int i = 0; i < classifiers.size(); i++) {
            Classifier classifier = classifiers.elementAt(i);
            a.add(classifier.fitness);
        }
        return a.getMean();
    }

    protected double getMeanError() {
        Average a = new Average();
        for (int i = 0; i < classifiers.size(); i++) {
            Classifier classifier = classifiers.elementAt(i);
            a.add(classifier.e);
        }
        return a.getMean();
    }

    protected int solveProblemOnce(int mode, int maxsteps) {

        // initialise the environment
        environment.initialise();

        // clear the previous action set and payoff
        previousActionSet = null;
        previousPayoff = null;

        // record how many steps were required
        int run;

        // run for as many steps as allowed
        for (run = 1; run <= maxsteps; run++) {

            // get the payoff from one iteration
            Payoff p = iterateOnce(mode);

            // if we have found food, finish
            if (p.isFinished()) break;

        }

        // return the number of steps required
        return run;

    }


    /**
     * Runs one learning iteration either in explore or in exploit mode.
     */
    protected Payoff iterateOnce(int mode) {

        // get a stimulus from the environment
        InputVector i = environment.getInput();

        // create a match set
        Vector<Classifier> matchSet = new Vector<Classifier>();

        // keep a note of the total prediction of M
        double totalPredictionOfM = 0;

        // the size of m
        int matchSize = 0;

        // also keep a note of the average prediction of the population
        Average populationPrediction = new Average();

        // populate the match set
        for (int j = 0; j < classifiers.size(); j++) {
            Classifier classifier = classifiers.elementAt(j);
            if (classifier.condition.matches(i)) {
                matchSet.add(classifier);
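                // weight by numerosity, as each macroclassifier stands for that many identical microclassifiers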
                totalPredictionOfM += classifier.p * classifier.numerosity;
                matchSize += classifier.numerosity;
            }
            populationPrediction.add(classifier.p);
        }

        // covering operation, triggered when:
        // 1. The match set is empty
        // 2. (currently disabled) The total prediction of [M] is less than PHI times the mean prediction of [P]
        if (matchSet.size() == 0/* || totalPredictionOfM < (settings.phi * populationPrediction.getMean())*/) {
            Condition condition = environment.getConditionToCover(i);
            // get all the available actions
            Vector<Action> actions = environment.getActions();
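            // pick one of them at random for the new covering classifier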
            Action action = actions.elementAt((int) (Math.random() * actions.size()));
            Classifier classifier = new Classifier(condition, action, time);
            matchSet.add(classifier);
            add(classifier);
        }

        // now create a prediction array
        Hashtable<Action, Prediction_OLD> predictionArray = new Hashtable<Action, Prediction_OLD>();

        // populate the prediction array
        for (int j = 0; j < matchSet.size(); j++) {
            Classifier classifier = matchSet.elementAt(j);

            Prediction_OLD p = predictionArray.get(classifier.action);
            if (p == null)  {
                // create a new entry in the prediction array
                predictionArray.put(classifier.action, new Prediction_OLD(classifier));
            } else {
                // update the prediction for this action
                p.addRule(classifier);
            }

            // also, update the match set size estimates (needed for deletion scheme (1))
            classifier.updateMatchSizeEstimate(settings.beta, matchSize);

        }

        // gather the predictions so that an action can be chosen
        Vector<Prediction_OLD> predictions = new Vector<Prediction_OLD>(predictionArray.values());

        if (predictions.size() > 1) {
            //System.out.println("More than one prediction");
        }

        // Choose the best prediction
        Prediction_OLD chosenPrediction = null;

        if (mode == EXPLOIT)  {

            // exploit mode - choose the prediction with highest score            
            for (int j = 0; j < predictions.size(); j++) {
                Prediction_OLD prediction =  predictions.elementAt(j);
                if (chosenPrediction == null || prediction.getPrediction() > chosenPrediction.getPrediction()) chosenPrediction = prediction;
            }

        } else {

            // explore mode - choose randomly from the prediction set.
            chosenPrediction = predictions.elementAt((int) (Math.random() * predictions.size()));

        }

        // send the recommended action to the environment and receive the payoff
        Payoff payoff = environment.takeAction(chosenPrediction.getAction());

        // The action set [A] consists of all the rules which contributed to the chosen prediction
        Vector<Classifier> actionSet = chosenPrediction.getRules();

        if (mode == EXPLORE) {

            // do the reinforcement here, on the previous action set
            if (previousActionSet != null) {
                updateActionSet(previousActionSet, chosenPrediction, payoff);
            }

            // if it only took one step, update the current action set instead.
            if (payoff.isFinished() && previousActionSet == null) {
                updateActionSet(actionSet, chosenPrediction, payoff);
            }

            // set the previous action set (for future reinforcements)
            previousActionSet = actionSet;

            // and the previous payoff
            previousPayoff = payoff;

            // run the GA
            ga.runIfAverageAgeIsHighEnough(actionSet, time);

        }

        return payoff;

    }

    /**
     * Adds a classifier to the population, deleting one first if the population is already full.
     */
    protected void add(Classifier classifier) {

        if (classifiers.size() >= settings.N) {
            delete();
        }

        classifiers.add(classifier);

    }

    /**
     * Removes a classifier from the population (or decrements its numerosity) once the population has reached its maximum size N.
     */
    protected void delete() {
        
        if (classifiers.size() >= settings.N) {

            Classifier r = deletionScheme1();

            if (r.numerosity > 1) {
                r.numerosity--;
            } else {
                classifiers.remove(r);
            }
            
        }

    }

    /**
     * Returns a rule for deletion using Wilson's first deletion scheme, where rules are
     * deleted with probability proportional to their match sizes. This is intended to keep
     * all the match sets (which may be regarded as equivalent to niches) roughly the same
     * size.
     */
    public Classifier deletionScheme1() {

        ProportionalSelection s = new ProportionalSelection(classifiers.size());

        for (int i = 0; i < classifiers.size(); i++) {
            Classifier classifier = classifiers.elementAt(i);
            s.addWeight(classifier.matchSizeEstimate * classifier.numerosity);
        }

        int index = s.chooseIndex();

        return classifiers.elementAt(index);
        
    }

    protected void updateActionSet(Vector<Classifier> actionSet, Prediction_OLD bestPrediction, Payoff payoff) {

        // recalculate accuracy and relative accuracy, based on the prediction error
        recalculateAccuracy(actionSet);

        // adjust the fitness to tend toward the accuracy
        adjustFitness(actionSet);

        // calculate P, the target payoff: the immediate reward plus the discounted best prediction
        double P = (bestPrediction.getPrediction() * settings.discountFactor) + payoff.getAmount();

        // adjust the prediction error using P
        adjustPredictionAndError(actionSet, P);

        // update the experience count of each rule in the action set
        updateExperience(actionSet);

    }

    /**
     * Accuracy is recalculated as a fixed function of the current prediction error.
     * Error rates above the threshold e0 are penalised by a steep power-law falloff.
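     * In the standard XCS formulation this is kappa = alpha * (epsilon / e0)^-nu, with nu = 5 here.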
     */
    protected void recalculateAccuracy(Vector<Classifier> actionSet) {

        double totalAccuracy = 0;

        for (int i = 0; i < actionSet.size(); i++) {

            Classifier classifier = actionSet.elementAt(i);

            double ej = classifier.e;

            double e0 = settings.e0;

            if (ej > e0) {

                classifier.accuracy = settings.alpha * Math.pow( ej / e0 , -5);

            } else {
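                // NOTE: canonical XCS sets the accuracy to 1 when the prediction error
                // is below the threshold e0; here it is set to e0^-5 instead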

                classifier.accuracy = Math.pow(e0, -5);

            }

            totalAccuracy += (classifier.accuracy * classifier.numerosity);

        }

        for (int i = 0; i < actionSet.size(); i++) {

            Classifier classifier = actionSet.elementAt(i);

            classifier.setRelativeAccuracy(classifier.accuracy / totalAccuracy);

        }

    }

    /**
     * Adjusts the fitness of each rule in the set based on its relative accuracy.
     * Fitness tends towards the value of the relative accuracy; the speed at which it
     * converges depends upon the learning rate (beta).
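     * This follows the MAM (moyenne adaptive modifiee) technique used in XCS: until a
     * classifier has been updated at least 1/beta times its fitness is set to a running
     * average, after which the Widrow-Hoff delta rule is applied.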
     */
    protected void adjustFitness(Vector<Classifier> actionSet) {

        double oneOverBeta = 1d / settings.beta;

        for (int i = 0; i < actionSet.size(); i++) {

            Classifier classifier =  actionSet.elementAt(i);

            // Has fitness been adjusted at least 1 / BETA times?
            if (classifier.experience >= oneOverBeta) {

                // have the fitness tend toward the accuracy, dependent on the strength of beta
                classifier.fitness += settings.beta * ((classifier.getRelativeAccuracy() * classifier.numerosity ) - classifier.fitness);

            } else {

                // otherwise, have the fitness be the average of the current and previous relative accuracy values
                classifier.fitness = classifier.relativeAccuracyValues.getMean();

            }


        }

    }

    /**
     * Adjusts the prediction and prediction error of each rule in the action set using a
     * calculated value of P, which incorporates the payoff and the discounted best
     * prediction from the prediction array.
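     * Once a classifier is sufficiently experienced the standard Widrow-Hoff updates are
     * used: e += beta * (|P - p| - e) and p += beta * (P - p).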
     */
    protected void adjustPredictionAndError(Vector<Classifier> actionSet, double P) {

        double oneOverBeta = 1d / settings.beta;

        for (int i = 0; i < actionSet.size(); i++) {

            Classifier classifier = actionSet.elementAt(i);

            // Has this classifier been updated at least 1 / BETA times?
            if (classifier.experience >= oneOverBeta) {

                // Adjust the prediction error toward the absolute difference between P and the current prediction.
                classifier.e += settings.beta * ((Math.abs(P - classifier.p)) - classifier.e);

                // Adjust the prediction toward P itself
                classifier.p += settings.beta * (P - classifier.p);

            } else {

                // otherwise, set the prediction to the average of the current and previous P values
                classifier.previousPValues.add(P);
                // TODO: This can't be right!
                classifier.e = P - classifier.previousPValues.getMean();
                classifier.p = classifier.previousPValues.getMean();

            }

        }

    }

    /**
     * Updates the "experience" of each rule in the action set, that is how many times
     * each of them has had their parameters updated.
     */
    public void updateExperience(Vector<Classifier> actionSet) {

        for (int i = 0; i < actionSet.size(); i++) {
            actionSet.elementAt(i).experience++;
        }

    }

    public static void main(String[] args) {
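        // run the learner on the Woods2 grid-world problem with default parameters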
        Environment e = new Woods2Environment();
        XCS_OLD learner = new XCS_OLD(e, new XCSParams_OLD());
        learner.learn();
    }

}
