package ac.essex.ooechs.facedetection.evolved.gp;

import ac.essex.gp.problems.DataStack;
import ac.essex.gp.Evolve;
import ac.essex.gp.interfaces.console.ConsoleListener;
import ac.essex.gp.multiclass.PCM;
import ac.essex.gp.multiclass.BetterDRS;
import ac.essex.gp.multiclass.CachedOutput;
import ac.essex.gp.individuals.Individual;
import ac.essex.gp.nodes.ercs.CustomRangeParameterERC;
import ac.essex.gp.nodes.ercs.SmallIntERC;
import ac.essex.gp.nodes.ercs.LargeIntERC;
import ac.essex.gp.nodes.ercs.PercentageERC;
import ac.essex.gp.nodes.constants.FixedValueTerminal;
import ac.essex.gp.nodes.*;
import ac.essex.gp.nodes.math.Pow;
import ac.essex.gp.params.GPParams;
import ac.essex.gp.params.NodeConstraints;
import ac.essex.ooechs.facedetection.util.DataSet;
import ac.essex.ooechs.facedetection.util.IntegralTrainingImage;
import ac.essex.ooechs.facedetection.evolved.gp.nodes.EvolvedTwoRectangleFeature;
import ac.essex.ooechs.facedetection.evolved.gp.nodes.Rectangle;
import ac.essex.ooechs.facedetection.util.AbstractDetectionProblem;

import java.util.Vector;

/**
 * <p/>
 * In this experiment we want to see how good GP is at solving
 * a face detection problem using a data set.
 * </p>
 *
 * @author Olly Oechsle, University of Essex, Date: 12-Jun-2008
 * @version 1.0
 */
public class GPFeatureFinderDRS extends AbstractDetectionProblem {

    protected int FPCounts[];
    protected int totalFP = 0;

    double bestFitness = Double.MAX_VALUE;

    static boolean weighting = false;
    static boolean optimisation = false;

    public static void main(String[] args) {
        DataSet training = new DataSet("/home/ooechs/Desktop/pipe-images2/foreground", "/home/ooechs/Desktop/pipe-images2/background");
        weighting = true;
        Evolve.seed =2361;
        GPFeatureFinderDRS p = new GPFeatureFinderDRS(training, null);
        new Evolve(p, new ConsoleListener(ConsoleListener.LOW_VERBOSITY)).run();
    }

    public GPFeatureFinderDRS(DataSet training, DataSet testing) {
        super(training, testing);
    }

    public String getName() {
        return "Face Detection Problem - Feature Evolution";
    }

    public void initialise(Evolve evolve, GPParams params) {

        params.registerNode(new IF_FP());

        params.registerNode(new Add());
        params.registerNode(new Sub());
        params.registerNode(new Mul());
        params.registerNode(new Div());

        params.registerNode(new Pow());
        params.registerNode(new Min());
        params.registerNode(new Max());

        params.registerNode(new SmallIntERC());
        params.registerNode(new LargeIntERC());
        params.registerNode(new PercentageERC());


        params.registerNode(new EvolvedTwoRectangleFeature());
        params.registerNode(new Rectangle());
        params.registerNode(new FixedValueTerminal(0, NodeConstraints.FEATURE));
        params.registerNode(new CustomRangeParameterERC(1,15));


        params.setReturnType(NodeConstraints.NUMBER);

    }

    public void customiseParameters(GPParams params) {
        params.setGenerations(50);
        params.setERCOptimisationEnabled(optimisation);
    }

    public void onGenerationStart() {
        super.onGenerationStart();
        if (FPCounts != null) {
            // redistribute the weights
            for (int i = 0; i < trainingData.size(); i++) {
                IntegralTrainingImage trainingImage = trainingData.elementAt(i);
                // add the one to avoid divide by zero and ensure that everything has some weight
                int mistakeCount = FPCounts[i] + 1;

                // adapt the training image weight towards the new value
                double newValue = (mistakeCount * trainingData.size()) / (double) totalFP;

                double difference = newValue - trainingImage.weight;

                difference *= 0.33;

                // learning rate
                if (weighting) trainingImage.weight += difference;

                //trainingImage.weight = (mistakeCount * trainingData.size()) / (double) totalFP;
                //System.out.println("Training Image Weight: " + trainingImage.weight);
            }
        }
        FPCounts = new int[trainingData.size()];
        totalFP = 0;
    }

    public void evaluate(Individual individual, DataStack dataStack, Evolve evolve) {

        // parameters for calculating fitness
        double totalTrueWeight = 0;
        double totalFalseWeight = 0;
        double trueWeight = 0;
        double falseWeight = 0;

        int totalTrue = 0;
        int totalFalse = 0;
        int TP = 0;
        int FP = 0;

        int hits = 0;
        int mistakes = 0;

        PCM pcm = new BetterDRS();

        dataStack.usesImaging = false;

        // go through all the training data
        for (int i = 0; i < trainingData.size(); i++) {

            IntegralTrainingImage trainingImage = trainingData.elementAt(i);

            dataStack.setIntegralImage(trainingImage.getImage());

            double result = individual.execute(dataStack);


            if (!dataStack.usesImaging ) {
                // Important - do not include the mistakes of stupid individuals
                // otherwise the definition of difficulty and therefore the weights
                // of training data are inaccurate.
                individual.setWorstFitness();
                return;
            }

            pcm.addResult(result, trainingImage.classID);

        }

        // calculate the PCM thresholds
        pcm.calculateThresholds();

        // figure out the fitness
        Vector<CachedOutput> outputcache = pcm.getCachedResults();

        int lastResult = 0;
        boolean goodClassifier = false;
        for (int i = 0; i < outputcache.size(); i++) {
            CachedOutput cachedOutput = outputcache.elementAt(i);
            int result = pcm.getClassFromOutput(cachedOutput.rawOutput);
            if (lastResult != result) {
                goodClassifier = true;
                break;
            }
        }

        if (!goodClassifier) {
            individual.setWorstFitness();
            return;
        }

        for (int i = 0; i < outputcache.size(); i++) {

            CachedOutput cachedOutput = outputcache.elementAt(i);

            int result = pcm.getClassFromOutput(cachedOutput.rawOutput);

            IntegralTrainingImage trainingImage = trainingData.elementAt(i);

            boolean returnedTrue = result == 1;

            if (result == cachedOutput.expectedClass) {
                hits++;
            } else {
                mistakes++;
                    FPCounts[i]++;
                    totalFP++;
            }

           if (cachedOutput.expectedClass == 1) {
                totalTrueWeight+=trainingImage.weight;
                totalTrue++;
                if (returnedTrue) {
                    trueWeight += trainingImage.weight;
                    TP++;
                }
            } else {
                totalFalseWeight+= trainingImage.weight;
                totalFalse++;
                if (returnedTrue) {
                    falseWeight += trainingImage.weight;
                    FP++;
                }
            }


        }

        // make sure the individual doesn't hog too much memory
        pcm.clearCachedResults();

        double x = 1;

        double truePercentage = trueWeight / totalTrueWeight;
        double falsePercentage = falseWeight / totalFalseWeight;
        double weightedFitness = x * (1 - truePercentage) + falsePercentage;
        //double weightedFitness = falseWeight / totalFalseWeight;

 /*       double unweightedTruePercentage = TP / (double) totalTrue;
        double unweightedFalsePercentage = FP / (double) totalFalse;
        double unweightedFitness = x * (1 - unweightedTruePercentage) + unweightedFalsePercentage;
*/
        individual.setKozaFitness(weightedFitness);
        individual.setAlternativeFitness(mistakes);
/*
        if (individual.getKozaFitness() < bestFitness) {
            bestFitness = individual.getKozaFitness();
            bestIndividual = individual;
        }*/

        individual.setHits(hits);
        individual.setMistakes(mistakes);
        individual.setPCM(pcm);

    }

}
