package ac.essex.ooechs.ecj.jasmine.problems;

import ac.essex.gp.problems.Problem;
import ac.essex.gp.problems.DataStack;
import ac.essex.gp.multiclass.BasicDRS;
import ac.essex.gp.multiclass.BetterDRS;
import ac.essex.gp.multiclass.PCM;
import ac.essex.gp.multiclass.CachedOutput;
import ac.essex.gp.multiclass.thresholding.VarianceThreshold;
import ac.essex.gp.multiclass.thresholding.EntropyThreshold;
import ac.essex.gp.Evolve;
import ac.essex.gp.nodes.*;
import ac.essex.gp.nodes.Mul;
import ac.essex.gp.nodes.Sub;
import ac.essex.gp.nodes.ercs.PercentageERC;
import ac.essex.gp.nodes.ercs.CustomRangeERC;
import ac.essex.gp.nodes.math.Sin;
import ac.essex.gp.params.GPParams;
import ac.essex.gp.individuals.Individual;
import ac.essex.gp.interfaces.console.ConsoleListener;
import ac.essex.ooechs.ecj.commons.functions.data.DataValueTerminal;
import ac.essex.ooechs.ecj.commons.data.DoubleData;
import ac.essex.ooechs.ecj.jasmine.nodes.classification.FeatureERC;
import ac.ooechs.classify.data.Data;
import ac.ooechs.classify.data.DataStatistics;
import ac.ooechs.classify.data.io.CSVDataReader;
import ac.ooechs.classify.classifier.gp.ProblemSettings;
import ac.ooechs.classify.classifier.gp.GPMulticlassClassificationProblem;

import java.util.Vector;
import java.io.IOException;
import java.io.File;

import ec.gp.GPProblem;
import ec.gp.GPIndividual;
import ec.gp.koza.KozaFitness;
import ec.simple.SimpleProblemForm;
import ec.util.Parameter;
import ec.EvolutionState;

/**
 * ECJ GP problem that treats each individual as a multiclass classifier.
 *
 * <p>The first tree of every individual is executed on each (non held-out, non zero-weight)
 * training sample; the raw outputs are collected into a {@link PCM} which computes thresholds
 * mapping raw outputs to class IDs. The standardized fitness is the weighted error rate
 * {@code (FP + FN) / N} over those samples (lower is better).
 *
 * <p>NOTE(review): sample features are passed to the tree through the static
 * {@code FeatureERC.values}, and the counters below are static, so this class is presumably
 * only safe with single-threaded evaluation — confirm before enabling ECJ eval threads.
 */
public class ECJMulticlassClassificationProblem2 extends GPProblem implements SimpleProblemForm {

    /** Standardized fitness assigned when an individual cannot be scored. */
    private static final float WORST_FITNESS = Integer.MAX_VALUE;

    /** Samples loaded from the training CSV file in the constructor. */
    protected Vector<Data> trainingData;

    /** Fold whose samples are skipped during evaluation (only samples with {@code fold > -1} are affected). */
    protected int fold;

    /** Statistics over {@link #trainingData}; used in {@link #setup} to size {@code FeatureERC}. */
    protected DataStatistics s;

    /** Data object the GP tree writes its result into; the result is read back from {@code input.x}. */
    public DoubleData input;

    /** Which {@link PCM} implementation to build; must match one of the TYPE constants handled in {@link #createPCM()}. */
    int DRSTYPE = BetterDRS.TYPE;

    /** Wall-clock time (ms) recorded once the data sets have been loaded. */
    public static long start;
    /** Running total of {@code ind.size()} over all evaluated individuals. */
    public static long nodeCount;
    /** Running total of single-sample tree evaluations. */
    public static long evals;

    public static void main(String[] args) {
        System.out.println("ECJ Classification Problem is Ready");
        new ECJMulticlassClassificationProblem2();
        System.out.println("Loaded OK");
    }

    /**
     * Loads the training and testing data sets from hard-coded CSV paths.
     * No parameters, because ECJ can't handle this.
     *
     * @throws RuntimeException if either data set cannot be loaded; the original failure is kept as the cause
     */
    public ECJMulticlassClassificationProblem2() {

        try {

            // Alternative data sets used previously:
            //System.out.println("Sat Image Data Set");
            //File trainingFile = new File("/home/ooechs/Desktop/jasmine-data/sat-training.ssv");
            //File testingFile = new File("/home/ooechs/Desktop/jasmine-data/sat-test.ssv");

            //System.out.println("PENDIG Data Set");
            //File trainingFile = new File("/home/ooechs/Desktop/jasmine-data/pendigits-training.csv");
            //File testingFile = new File("/home/ooechs/Desktop/jasmine-data/pendigits-test.csv");

            System.out.println("IONOSPHERE Data Set");
            File trainingFile = new File("/home/ooechs/Data/UCI/Ionosphere/training.csv");
            // NOTE(review): this is the same file as trainingFile — confirm whether a separate
            // test file was intended (the other data sets above use distinct *-test files).
            File testingFile = new File("/home/ooechs/Data/UCI/Ionosphere/training.csv");

            trainingData = new CSVDataReader(trainingFile).getData();
            s = new DataStatistics(trainingData);

            Vector<Data> testingData = new CSVDataReader(testingFile).getData();
            // The result is discarded; presumably constructed for its side effects (e.g. reporting) — confirm.
            new DataStatistics(testingData);

            start = System.currentTimeMillis();

        } catch (Exception e) {
            // keep the underlying failure so the real cause (missing file, parse error, ...) is visible
            throw new RuntimeException("Cannot load", e);
        }
    }

    /**
     * ECJ setup hook: reads the {@code data} parameter to create {@link #input} and tells
     * {@code FeatureERC} how many features the training data has.
     */
    public void setup(final EvolutionState state,
                      final Parameter base) {
        // very important, remember this
        super.setup(state, base);

        // set up the input
        input = (DoubleData) state.parameters.getInstanceForParameterEq(base.push(P_DATA), null, DoubleData.class);
        input.setup(state, base.push(P_DATA));

        FeatureERC.NUM_FEATURES = s.getNumFeatures(); // was 36 on original data set
    }

    /**
     * Sets the cross-validation fold to hold out: samples whose {@code fold} equals this value
     * (and is {@code > -1}) are skipped during evaluation.
     */
    public void setFold(int fold) {
        this.fold = fold;
    }

    /** Creates an empty PCM of the kind selected by {@link #DRSTYPE}. */
    private PCM createPCM() {
        switch (DRSTYPE) {
            case BasicDRS.TYPE:
                return new BasicDRS();
            case BetterDRS.TYPE:
                return new BetterDRS();
            case VarianceThreshold.TYPE:
                return new VarianceThreshold();
            case EntropyThreshold.TYPE:
                return new EntropyThreshold();
            default:
                // previously fell through with a null PCM and failed later with an NPE
                throw new IllegalStateException("Unknown DRS type: " + DRSTYPE);
        }
    }

    /**
     * Runs the individual's first tree on every usable training sample and builds the
     * program classification map from the raw outputs.
     *
     * @return a PCM whose thresholds have been calculated; never null
     * @throws IllegalStateException if {@link #DRSTYPE} is not a supported type
     */
    private PCM buildProgramClassificationMap(final EvolutionState state,
                                              final ec.Individual ind,
                                              final int subpopulation,
                                              final int threadnum) {

        PCM pcm = createPCM();
        GPIndividual gpIndividual = (GPIndividual) ind;

        for (Data d : trainingData) {

            // k fold cross validation: skip the held-out fold
            if (d.fold > -1 && d.fold == fold) {
                continue;
            }

            // zero-weight samples do not contribute
            if (d.weight == 0) {
                continue;
            }

            // expose this sample's features to FeatureERC nodes (static, shared state)
            FeatureERC.values = d.values;

            if (FeatureERC.values == null) {
                throw new RuntimeException("Values are null. Why?");
            }

            // run the individual; its result is written into input.x
            gpIndividual.trees[0].child.eval(state, threadnum, input, stack, gpIndividual, this);

            evals++;

            pcm.addResult(input.x, d.classID);
        }

        pcm.calculateThresholds();

        return pcm;
    }

    /** Stores the given standardized fitness on the individual and marks it evaluated. */
    private static void assignFitness(final EvolutionState state, final ec.Individual ind, final float fitness) {
        KozaFitness f = (KozaFitness) ind.fitness;
        f.setStandardizedFitness(state, fitness);
        f.hits = 0; // hits are not reported by this problem
        ind.evaluated = true;
    }

    /**
     * Evaluates an individual: builds its PCM, then scores it by the weighted error rate
     * {@code (FP + FN) / N} over the cached outputs. Individuals that cannot be scored
     * (no PCM, or no sample weight at all) receive {@link #WORST_FITNESS}.
     */
    public void evaluate(final EvolutionState state,
                         final ec.Individual ind,
                         final int subpopulation,
                         final int threadnum) {

        // execute the individual and get a program classification map
        PCM pcm = buildProgramClassificationMap(state, ind, subpopulation, threadnum);

        nodeCount += ind.size();

        if (pcm == null) {
            // defensive: nothing to score, assign worst fitness
            assignFitness(state, ind, WORST_FITNESS);
            return;
        }

        // weighted confusion-matrix totals; TP/TN are kept for the class-balanced alternative below
        double truePositive = 0;
        double trueNegative = 0;
        double falsePositive = 0;
        double falseNegative = 0;
        double totalWeight = 0;

        // when evaluating fitness, don't bother evaluating the individual again, use cached results
        Vector<CachedOutput> cachedOutputs = pcm.getCachedResults();

        for (CachedOutput output : cachedOutputs) {

            double weight = output.weight;
            totalWeight += weight;

            // get the individual's answer for this data
            int outputClass = pcm.getClassFromOutput(output.rawOutput);

            if (outputClass == 0) {
                // returned false (class 0)
                if (output.expectedClass == 0) {
                    trueNegative += weight;
                } else {
                    falseNegative += weight;
                }
            } else if (outputClass == output.expectedClass) {
                truePositive += weight;
            } else {
                falsePositive += weight;
            }
        }

        // Guard against 0/0 when every sample was skipped: that would give a NaN standardized
        // fitness, which is not a meaningful score for selection.
        // Alternative (equal weight to both classes), kept for reference:
        //   (FP / (FP + TN)) + (FN / (FN + TP))
        double fitness = totalWeight > 0
                ? (falsePositive + falseNegative) / totalWeight
                : WORST_FITNESS;

        assignFitness(state, ind, (float) fitness);

        // clear the cached results to save memory
        pcm.clearCachedResults();
    }
}