package ac.essex.ooechs.ecj.jasmine.problems;

import ac.essex.ooechs.ecj.jasmine.nodes.classification.TrainingData;
import ac.essex.ooechs.ecj.jasmine.nodes.classification.FeatureERC;
import ac.essex.ooechs.ecj.jasmine.nodes.classification.Loader;
import ac.essex.ooechs.ecj.commons.data.DoubleData;
import ac.essex.gp.multiclass.ProgramClassificationMap;

import java.util.Vector;
import java.io.File;
import java.io.IOException;

import ec.EvolutionState;
import ec.Individual;
import ec.simple.SimpleProblemForm;
import ec.gp.GPIndividual;
import ec.gp.GPProblem;
import ec.gp.koza.KozaFitness;
import ec.util.Parameter;

/**
 * <p>
 * ECJ genetic-programming problem that evolves a single-tree classifier over
 * feature data loaded from files. Each individual's first tree is run once per
 * training sample (the sample's features are published through
 * {@link FeatureERC#values}), a {@link ProgramClassificationMap} converts the
 * raw outputs into class thresholds, and the standardized fitness is the number
 * of misclassified training samples. The whole run is stopped once
 * {@link #MAX_TIME} milliseconds have elapsed since {@link #setup}, printing the
 * most recent summary line produced by {@link #describe}.
 * </p>
 * <p/>
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version,
 * provided that any use properly credits the author.
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details at http://www.gnu.org
 * </p>
 *
 * @author Olly Oechsle, University of Essex, Date: 25-Oct-2007
 * @version 1.0
 */
public class JasmineClassificationProblemDRS extends GPProblem implements SimpleProblemForm {

    // NOTE(review): neither of these two fields is read in this class; evaluate()
    // uses JasmineClassificationProblem.usesImaging instead. Kept public in case
    // other code references them.
    public static String SEPARATOR = ",";
    public static boolean usesImaging = false;

    // run for 10 minutes (milliseconds)
    final int MAX_TIME = 600 * 1000;

    /** Samples used both for fitness evaluation and for the training % in describe(). */
    public Vector<TrainingData> trainingData;

    /** Held-out samples, used only for the testing % reported by describe(). */
    public Vector<TrainingData> unseenTrainingData;

    // the data object used by all individuals; the tree writes its output to input.x
    public DoubleData input;

    // Not read in this class; set alongside the dataset choice in setup().
    float numberOfClasses;

    // Wall-clock time at the end of setup(); basis for the MAX_TIME cut-off.
    long startTime;

    // Number of evaluate() calls so far; reported in the CSV summary.
    static int evaluations = 0;

    /**
     * Sets up the shared GP data object and loads the training and testing sets.
     * The dataset is currently chosen by editing the hard-coded paths below.
     * If either file cannot be read, the run is aborted with a fatal error.
     */
    public void setup(EvolutionState state, Parameter base) {

        super.setup(state, base);

        // set up the input
        input = (DoubleData) state.parameters.getInstanceForParameterEq(base.push(P_DATA), null, DoubleData.class);
        input.setup(state, base.push(P_DATA));

        // Experiment 2 - Postures, 18 Oct 2007
        //File dataFolder = new File("/home/ooechs/Desktop/jasmine-data/");
        //File training = new File(dataFolder, "postures_training.csv");
        //File testing = new File(dataFolder, "postures_testing.csv");

        // Experiment 1 - Colour Recognition, 6th February 2008
        File training = new File("/home/ooechs/Desktop/jasmine-data/car_colours_training.csv");
        File testing = new File("/home/ooechs/Desktop/jasmine-data/car_colours_testing.csv");
        numberOfClasses = 7;
        FeatureERC.NUM_FEATURES = 10;

        //File training = new File("/home/ooechs/Desktop/jasmine-data/sat-training.ssv");
        //File testing = new File("/home/ooechs/Desktop/jasmine-data/sat-test.ssv");
        //numberOfClasses = 6;
        //FeatureERC.NUM_FEATURES = 36;// // was 36 on original data set
        //JasmineClassificationProblem.SEPARATOR = " ";

        // end setup

        try {

            Loader l = new Loader();
            trainingData = l.getTrainingData(training, Loader.TRAINING);
            unseenTrainingData = l.getTrainingData(testing, Loader.TESTING);

        } catch (IOException e) {

            // Previously System.exit(0), which reported success to the caller.
            // fatal() prints the message and terminates the run with an error status.
            e.printStackTrace();
            state.output.fatal("Could not load classification data from " + training
                    + " and " + testing + ": " + e);

        }

        startTime = System.currentTimeMillis();

    }

    /**
     * Evaluates one individual on the training set. The standardized fitness is
     * the number of misclassified training samples, and the KozaFitness hits field
     * is the number classified correctly. Terminates the JVM once MAX_TIME has elapsed.
     */
    public void evaluate(final EvolutionState state,
                         final Individual ind,
                         final int subpopulation,
                         final int threadnum) {

        evaluations++;

        // Reset the shared flag before running the program. NOTE(review): it is
        // presumably set to true by feature nodes during eval() (not visible here);
        // programs that leave it false receive the worst possible fitness below.
        JasmineClassificationProblem.usesImaging = false;

        ProgramClassificationMap map = buildClassificationMap(state, threadnum, (GPIndividual) ind);

        int hits = map.getHits();
        float mistakes = trainingData.size() - hits;

        if (!JasmineClassificationProblem.usesImaging) {
            mistakes = Integer.MAX_VALUE;
        }

        KozaFitness f = (KozaFitness) ind.fitness;
        f.setStandardizedFitness(state, mistakes);
        f.hits = hits;
        ind.evaluated = true;

        // Hard time limit: print the latest describe() summary and stop the JVM.
        // csv is still null if describe() has not been called yet.
        long timeElapsed = System.currentTimeMillis() - startTime;
        if (timeElapsed > MAX_TIME) {
            System.out.println(csv);
            System.exit(0);
        }

    }

    // Raw count of correctly classified testing samples from the last describe() call.
    static double lastTestingTP = -1;

    // Last summary line built by describe():
    // ",training %,testing %,time ms,1,size,evaluations"
    static String csv = null;

    /**
     * Computes training and testing accuracy for an individual and stores them,
     * with elapsed time, tree size and evaluation count, in {@link #csv}.
     * A percentage is reported as NaN when its dataset is missing or empty.
     *
     * @param ind       the individual (must be a GPIndividual)
     * @param state     the evolution state
     * @param threadnum thread used to evaluate the tree
     * @param i1        unused
     * @param i2        unused
     */
    public void describe(Individual ind, EvolutionState state, int threadnum, int i1, int i2) {

        if (trainingData == null) {
            // setup() did not load any data, so there is nothing to describe.
            return;
        }

        GPIndividual gpInd = (GPIndividual) ind;

        // Thresholds are always derived from the training data only.
        ProgramClassificationMap map = buildClassificationMap(state, threadnum, gpInd);

        // Previously guarded by the *testing* set check, so training accuracy was
        // reported as 0 whenever no testing data was available.
        int trainingTP = countCorrect(state, threadnum, gpInd, map, trainingData);

        int testingTP = 0;
        int testingSize = 0;
        if (unseenTrainingData != null) {
            testingTP = countCorrect(state, threadnum, gpInd, map, unseenTrainingData);
            testingSize = unseenTrainingData.size();
        }

        double trainingPercentage = percentage(trainingTP, trainingData.size());
        double testingPercentage = percentage(testingTP, testingSize);
        long time = System.currentTimeMillis() - startTime;

        //String header = "Run, Training %, Testing %, Time (ms), Classifiers, Total Nodes, Evaluations";
        csv = "," + trainingPercentage + "," + testingPercentage + "," + time + ",1," + ind.size() + "," + evaluations;

        lastTestingTP = testingTP;

    }

    /**
     * Runs the individual's first tree on one sample and returns the raw output.
     * The sample's features are exposed to the tree via the static FeatureERC.values.
     */
    private double runOn(EvolutionState state, int threadnum, GPIndividual ind, TrainingData sample) {
        FeatureERC.values = sample.getFeatures();
        ind.trees[0].child.eval(state, threadnum, input, stack, ind, this);
        return input.x;
    }

    /** Builds and calibrates a classification map from the individual's outputs on the training set. */
    private ProgramClassificationMap buildClassificationMap(EvolutionState state, int threadnum, GPIndividual ind) {
        ProgramClassificationMap map = new ProgramClassificationMap();
        for (int j = 0; j < trainingData.size(); j++) {
            TrainingData sample = trainingData.elementAt(j);
            map.addResult(runOn(state, threadnum, ind, sample), sample.getClassID());
        }
        map.calculateThresholds();
        return map;
    }

    /** Counts the samples in {@code data} whose mapped class equals their labelled class. */
    private int countCorrect(EvolutionState state, int threadnum, GPIndividual ind,
                             ProgramClassificationMap map, Vector<TrainingData> data) {
        int correct = 0;
        for (int j = 0; j < data.size(); j++) {
            TrainingData sample = data.elementAt(j);
            int predicted = map.getClassFromOutput(runOn(state, threadnum, ind, sample));
            if (predicted == sample.getClassID()) {
                correct++;
            }
        }
        return correct;
    }

    /** Returns {@code count / total * 100}, or NaN when {@code total} is zero (no data). */
    private static double percentage(double count, int total) {
        if (total == 0) {
            return Double.NaN;
        }
        return (count / total) * 100;
    }

}
