package ac.essex.ooechs.facedetection.evolved.gp;

import ac.essex.gp.problems.DataStack;
import ac.essex.gp.Evolve;
import ac.essex.gp.individuals.Individual;
import ac.essex.gp.nodes.constants.FixedValueTerminal;
import ac.essex.gp.nodes.ercs.CustomRangeParameterERC;
import ac.essex.gp.params.GPParams;
import ac.essex.gp.params.NodeConstraints;
import ac.essex.gp.interfaces.console.ConsoleListener;
import ac.essex.ooechs.facedetection.util.DataSet;
import ac.essex.ooechs.facedetection.util.IntegralTrainingImage;
import ac.essex.ooechs.facedetection.evolved.gp.nodes.EvolvedTwoRectangleFeature;
import ac.essex.ooechs.facedetection.evolved.gp.nodes.Rectangle;
import ac.essex.ooechs.facedetection.util.AbstractDetectionProblem;

/**
 * <p/>
 * In this experiment we measure how well GP solves a face detection problem
 * using a labelled data set. Individuals are built from
 * {@link EvolvedTwoRectangleFeature} and {@link Rectangle} nodes and classify
 * each training image as {@code DataSet.TRUE} when their output is positive,
 * {@code DataSet.FALSE} otherwise.
 * </p>
 * <p/>
 * Fitness (lower is better) is the weighted miss rate on {@code DataSet.TRUE}
 * images plus the weighted false-positive rate on the other images. When
 * {@link #weighting} is enabled, training image weights are moved each
 * generation towards images that the previous generation misclassified most.
 * </p>
 *
 * @author Olly Oechsle, University of Essex, Date: 12-Jun-2008
 * @version 1.0
 */
public class GPFeatureFinderBoolean extends AbstractDetectionProblem {

    /**
     * Number of times each training image was misclassified during the current
     * generation. Despite the name, this counts both false positives and false negatives.
     */
    protected int[] FPCounts;

    /** Sum of {@link #FPCounts}: total misclassifications recorded this generation. */
    protected int totalFP = 0;

    /** Not read by this class; retained for compatibility. */
    double bestFitness = Double.MAX_VALUE;

    /** When true, training image weights are adapted between generations. */
    static boolean weighting = false;

    /** Forwarded to {@code GPParams.setERCOptimisationEnabled} in {@link #customiseParameters}. */
    static boolean optimisation = false;

    /** Fraction of the gap between an image's current and target weight closed per generation. */
    private static final double WEIGHT_LEARNING_RATE = 0.33;

    /**
     * Runs the experiment grid: ERC optimisation off then on, each with weighting
     * disabled then enabled, over five consecutive seeds starting at 2357.
     *
     * @param args ignored
     */
    public static void main(String[] args) {

        DataSet training = new DataSet("/home/ooechs/Data/faces/train/face/", "/home/ooechs/Data/faces/train/non-face");
        DataSet testing = new DataSet("/home/ooechs/Data/faces/test/face/", "/home/ooechs/Data/faces/test/non-face");

        int[] seeds = new int[5];
        for (int i = 0; i < seeds.length; i++) {
            seeds[i] = 2357 + i;
        }

        System.out.println("ERC OPTIMISATION IS OFF");
        optimisation = false;
        runWithAndWithoutWeighting(training, testing, seeds);

        System.out.println("NOW TURNING ON ERC OPTIMISATION");
        optimisation = true;
        runWithAndWithoutWeighting(training, testing, seeds);
    }

    /** Runs every seed once with weighting disabled, then once with it enabled. */
    private static void runWithAndWithoutWeighting(DataSet training, DataSet testing, int[] seeds) {
        System.out.println("NO WEIGHTING --------------------------------");
        weighting = false;
        runAllSeeds(training, testing, seeds);

        System.out.println("WEIGHTING STARTED -------------------------");
        weighting = true;
        runAllSeeds(training, testing, seeds);
    }

    /** Performs one silent evolutionary run per seed using the current static settings. */
    private static void runAllSeeds(DataSet training, DataSet testing, int[] seeds) {
        for (int i = 0; i < seeds.length; i++) {
            int seed = seeds[i];
            System.out.println("Running: " + i + ", seed=" + seed + " --------------------------");
            Evolve.seed = seed;
            GPFeatureFinderBoolean p = new GPFeatureFinderBoolean(training, testing);
            new Evolve(p, new ConsoleListener(ConsoleListener.SILENT)).run();
            // Encourage release of the previous run's population before the next run starts.
            System.gc();
        }
    }

    /**
     * @param training data used to compute fitness
     * @param testing  held-out data, passed through to the superclass
     */
    public GPFeatureFinderBoolean(DataSet training, DataSet testing) {
        super(training, testing);
    }

    public String getName() {
        return "Face Detection Problem - Feature Evolution";
    }

    /**
     * Registers the node set: the two-rectangle feature, rectangle nodes, a fixed
     * 0 terminal constrained to {@code NodeConstraints.FEATURE}, and an ERC
     * (presumably ranging over 1..15 — confirm against CustomRangeParameterERC).
     */
    public void initialise(Evolve evolve, GPParams params) {

        params.registerNode(new EvolvedTwoRectangleFeature());
        params.registerNode(new Rectangle());
        params.registerNode(new FixedValueTerminal(0, NodeConstraints.FEATURE));
        params.registerNode(new CustomRangeParameterERC(1, 15));

    }

    /** Evolves for 50 generations; ERC optimisation follows the static {@link #optimisation} flag. */
    public void customiseParameters(GPParams params) {
        params.setGenerations(50);
        params.setERCOptimisationEnabled(optimisation);
    }

    /**
     * Optionally re-weights the training images from the previous generation's
     * mistake counts, then resets the counters for the new generation.
     */
    public void onGenerationStart() {
        super.onGenerationStart();
        int size = trainingData.size();
        if (weighting && FPCounts != null && FPCounts.length == size) {
            redistributeWeights();
        }
        FPCounts = new int[size];
        totalFP = 0;
    }

    /**
     * Moves each image's weight towards a target proportional to (mistakes + 1).
     * The targets are normalised so that they average 1 across the training set.
     */
    private void redistributeWeights() {
        int size = trainingData.size();
        // Each image contributes (mistakes + 1), so the normalising total is totalFP + size.
        // This is never zero for non-empty data, unlike totalFP alone.
        double totalMistakeCount = totalFP + (double) size;
        for (int i = 0; i < size; i++) {
            IntegralTrainingImage trainingImage = trainingData.elementAt(i);
            // The +1 ensures every image keeps some weight.
            int mistakeCount = FPCounts[i] + 1;
            double target = (mistakeCount * (double) size) / totalMistakeCount;
            trainingImage.weight += (target - trainingImage.weight) * WEIGHT_LEARNING_RATE;
        }
    }

    /**
     * Scores an individual on the training data.
     * <p/>
     * Individuals that do not use the image, or that return the same class for
     * every image, receive the worst fitness and their mistakes are not recorded.
     * Counting them would distort the per-image difficulty used for weighting.
     * Otherwise Koza fitness is {@code (1 - weighted TRUE hit rate) + weighted
     * false-positive rate}, and the alternative fitness is the raw mistake count.
     */
    public void evaluate(Individual individual, DataStack dataStack, Evolve evolve) {

        int[] results = classifyTrainingData(individual, dataStack);
        if (results == null || !producesBothClasses(results)) {
            individual.setWorstFitness();
            return;
        }

        int size = results.length;
        if (FPCounts == null || FPCounts.length != size) {
            // Normally allocated by onGenerationStart(); guard against evaluation before it.
            FPCounts = new int[size];
            totalFP = 0;
        }

        double totalTrueWeight = 0;
        double totalFalseWeight = 0;
        double trueWeight = 0;
        double falseWeight = 0;
        int hits = 0;
        int mistakes = 0;

        for (int i = 0; i < size; i++) {
            IntegralTrainingImage trainingImage = trainingData.elementAt(i);
            int result = results[i];
            boolean returnedTrue = result == DataSet.TRUE;

            if (result == trainingImage.classID) {
                hits++;
            } else {
                mistakes++;
                FPCounts[i]++;
                totalFP++;
            }

            if (trainingImage.classID == DataSet.TRUE) {
                totalTrueWeight += trainingImage.weight;
                if (returnedTrue) {
                    trueWeight += trainingImage.weight;
                }
            } else {
                totalFalseWeight += trainingImage.weight;
                if (returnedTrue) {
                    falseWeight += trainingImage.weight;
                }
            }
        }

        double truePercentage = fraction(trueWeight, totalTrueWeight);
        double falsePercentage = fraction(falseWeight, totalFalseWeight);
        double weightedFitness = (1 - truePercentage) + falsePercentage;

        individual.setKozaFitness(weightedFitness);
        individual.setAlternativeFitness(mistakes);
        individual.setHits(hits);
        individual.setMistakes(mistakes);
    }

    /**
     * Executes the individual on every training image.
     *
     * @return the {@code DataSet.TRUE}/{@code DataSet.FALSE} decision per image, or
     *         {@code null} as soon as {@code dataStack.usesImaging} is found false
     *         after an execution (presumably the individual ignored the image)
     */
    private int[] classifyTrainingData(Individual individual, DataStack dataStack) {
        dataStack.usesImaging = false;
        int[] results = new int[trainingData.size()];
        for (int i = 0; i < results.length; i++) {
            dataStack.setIntegralImage(trainingData.elementAt(i).getImage());
            double raw = individual.execute(dataStack);
            if (!dataStack.usesImaging) {
                return null;
            }
            results[i] = raw > 0 ? DataSet.TRUE : DataSet.FALSE;
        }
        return results;
    }

    /** Returns true if at least two entries of {@code results} differ. */
    private static boolean producesBothClasses(int[] results) {
        for (int i = 1; i < results.length; i++) {
            if (results[i] != results[0]) {
                return true;
            }
        }
        return false;
    }

    /** Returns {@code part / whole}, or 0 when {@code whole} is 0 (e.g. a class absent from the data). */
    private static double fraction(double part, double whole) {
        return whole == 0 ? 0 : part / whole;
    }

}
