package ac.essex.gp.ooechs.novelty2;

import ac.essex.gp.problems.Problem;
import ac.essex.gp.problems.DataStack;
import ac.essex.gp.training.TrainingImage;
import ac.essex.gp.Evolve;
import ac.essex.gp.multiclass.BetterDRS;
import ac.essex.gp.multiclass.PCM;
import ac.essex.gp.multiclass.CachedOutput;
import ac.essex.gp.multiclass.BasicDRS;
import ac.essex.gp.individuals.Individual;
import ac.essex.gp.nodes.Add;
import ac.essex.gp.nodes.Sub;
import ac.essex.gp.nodes.imaging.texture.Mean;
import ac.essex.gp.nodes.imaging.texture.Range;
import ac.essex.gp.nodes.imaging.texture.Variance;
import ac.essex.gp.nodes.imaging.features.PerimeterMean;
import ac.essex.gp.nodes.imaging.features.PerimeterStdDev;
import ac.essex.gp.nodes.logic.Less;
import ac.essex.gp.nodes.logic.More;
import ac.essex.gp.nodes.logic.Equals;
import ac.essex.gp.nodes.ercs.BoolERC;
import ac.essex.gp.nodes.ercs.PercentageERC;
import ac.essex.gp.nodes.ercs.SmallIntERC;
import ac.essex.gp.nodes.math.Abs;
import ac.essex.gp.params.GPParams;
import ac.essex.gp.params.NodeConstraints;
import ac.essex.gp.interfaces.graphical.GraphicalListener;
import ac.essex.ooechs.imaging.commons.PixelLoader;

import java.util.Vector;
import java.io.File;

/**
 * <p>
 * Searches for features that define the true class.
 * </p>
 * <p>
 * Training images are loaded from directories and labelled either {@link #TRUE} or
 * {@link #FALSE}. Each evolved individual is executed once per image; its raw numeric
 * outputs are handed to a {@link BetterDRS} instance, which computes thresholds and maps
 * each output back to a class. Fitness penalises true images that were not classified
 * as TRUE and false images that were classified as TRUE, so a score of 0 is perfect.
 * </p>
 *
 * @author Olly Oechsle, University of Essex, Date: 16-Apr-2008
 * @version 1.0
 */
public class FeatureFindingProblemDRS extends Problem {

    /** Class ID for images that belong to the class whose features are being searched for. */
    public static final int TRUE = 1;

    /** Class ID for counter-example images. */
    public static final int FALSE = 2;

    /** Directory of TRUE images used by {@link #main} when no first argument is given. */
    private static final String DEFAULT_TRUE_DIRECTORY = "/home/ooechs/Desktop/birds/faces/mandarin";

    /** Directory of FALSE images used by {@link #main} when no second argument is given. */
    private static final String DEFAULT_FALSE_DIRECTORY = "/home/ooechs/ecj-training/faces/essex/false1/";

    /**
     * Weight of the "missed true image" term in the fitness function. The
     * false-positive term always has weight 1.
     */
    private static final double TRUE_MISS_WEIGHT = 1;

    /**
     * Position written to the data stack (via setX/setY) before every execution.
     * NOTE(review): presumably the sampling point within each image; confirm against DataStack.
     */
    protected int x, y;

    /** All loaded training images, in the order they were loaded. */
    protected Vector<TrainingImage> data;

    /**
     * Loads the TRUE and FALSE training directories and starts evolution with a graphical listener.
     *
     * @param args optional: {@code args[0]} overrides the TRUE image directory and
     *             {@code args[1]} overrides the FALSE image directory
     */
    public static void main(String[] args) {
        String trueDirectory = args.length > 0 ? args[0] : DEFAULT_TRUE_DIRECTORY;
        String falseDirectory = args.length > 1 ? args[1] : DEFAULT_FALSE_DIRECTORY;

        FeatureFindingProblemDRS problem = new FeatureFindingProblemDRS();
        problem.load(trueDirectory, TRUE);
        problem.load(falseDirectory, FALSE);
        Evolve e = new Evolve(problem, new GraphicalListener());
        e.start();
    }

    /** Creates a problem with an empty training set; call {@link #load} to add images. */
    public FeatureFindingProblemDRS() {
        data = new Vector<TrainingImage>();
    }

    /**
     * Adds every loadable image in a directory (non-recursively) to the training set.
     * <p>
     * Loading is best-effort: sub-directories are skipped, and a file that cannot be
     * loaded is reported on stderr and skipped rather than aborting the load. A path
     * that is not a readable directory is also reported, and nothing is added.
     * </p>
     *
     * @param directoryPath directory containing the images
     * @param classID       label for every image loaded ({@link #TRUE} or {@link #FALSE})
     */
    public void load(String directoryPath, int classID) {

        File directory = new File(directoryPath);
        // isDirectory() is false for non-existent paths, so it also covers exists().
        if (!directory.isDirectory()) {
            System.err.println("Not a directory, nothing loaded: " + directoryPath);
        } else {
            // listFiles() returns null on an I/O error even when the directory exists.
            File[] files = directory.listFiles();
            if (files == null) {
                System.err.println("Could not list directory, nothing loaded: " + directoryPath);
            } else {
                for (File imageFile : files) {
                    if (imageFile.isDirectory()) continue;
                    try {
                        PixelLoader image = new PixelLoader(imageFile);
                        data.add(new TrainingImage(image.getBufferedImage(), classID));
                    } catch (Exception err) {
                        // Deliberately best-effort: skip the file, but don't hide the failure.
                        System.err.println("Skipping unloadable file " + imageFile + ": " + err);
                    }
                }
            }
        }

        System.out.println("Training size: " + data.size());

    }

    /** Returns the human-readable name of this problem. */
    public String getName() {
        return "Novel GP Feature Problem (DRS)";
    }

    /**
     * Registers the function and terminal set and sets the return type to
     * {@link NodeConstraints#NUMBER}.
     */
    public void initialise(Evolve evolve, GPParams params) {

        // arithmetic
        params.registerNode(new Add());
        params.registerNode(new Sub());
        params.registerNode(new Abs());

        // texture measures
        params.registerNode(new Mean());
        params.registerNode(new Range());
        params.registerNode(new Variance());

        // perimeter features
        params.registerNode(new PerimeterMean());
        params.registerNode(new PerimeterStdDev());

        params.registerNode(new PercentageERC(NodeConstraints.PARAMETER));

        // erc for any comparison or mathematical functions
        params.registerNode(new SmallIntERC());

        params.setReturnType(NodeConstraints.NUMBER);

    }

    /** Leaves the GP parameters unchanged. */
    public void customiseParameters(GPParams gpParams) {
        // Do nothing for the moment
    }

    /**
     * Scores an individual against the whole training set.
     * <p>
     * Fitness = {@code TRUE_MISS_WEIGHT * (1 - trueRate) + falseRate}, where trueRate is the
     * fraction of TRUE images classified as TRUE and falseRate is the fraction of FALSE
     * images classified as TRUE. A class with no training images contributes a rate of 0
     * instead of producing NaN. Hits are set to the number of TRUE images found, mistakes
     * to the number of FALSE images classified as TRUE, and the fitted PCM is attached to
     * the individual.
     * </p>
     */
    public void evaluate(Individual individual, DataStack dataStack, Evolve evolve) {

        PCM pcm = new BetterDRS();

        // Pass 1: collect the raw output for every training image.
        for (TrainingImage signal : data) {
            dataStack.setImage(signal);
            dataStack.setX(x);
            dataStack.setY(y);

            double result = individual.execute(dataStack);

            pcm.addResult(result, signal.classID);
        }

        pcm.calculateThresholds();

        // Pass 2: classify each cached output with the fitted thresholds.
        int totalTrueData = 0;
        int totalFalseData = 0;
        int trueHits = 0;
        int falseHits = 0;

        Vector<CachedOutput> outputcache = pcm.getCachedResults();
        for (CachedOutput cachedOutput : outputcache) {

            boolean returnedTrue = pcm.getClassFromOutput(cachedOutput.rawOutput) == TRUE;

            if (cachedOutput.expectedClass == TRUE) {
                totalTrueData++;
                if (returnedTrue) trueHits++;
            } else {
                totalFalseData++;
                if (returnedTrue) falseHits++;
            }

        }

        double truePercentage = ratio(trueHits, totalTrueData);
        double falsePercentage = ratio(falseHits, totalFalseData);

        double fitness = TRUE_MISS_WEIGHT * (1 - truePercentage) + falsePercentage;

        individual.setKozaFitness(fitness);
        individual.setHits(trueHits);
        individual.setMistakes(falseHits);
        individual.setPCM(pcm);

    }

    /** Returns count / total, or 0 when total is 0 (avoids a NaN fitness). */
    private static double ratio(int count, int total) {
        return total == 0 ? 0 : count / (double) total;
    }

    /** Sets the x position written to the data stack before each execution. */
    public void setX(int x) {
        this.x = x;
    }

    /** Sets the y position written to the data stack before each execution. */
    public void setY(int y) {
        this.y = y;
    }

}
