package ac.essex.ooechs.facedetection.learners.gp;

import ac.essex.gp.problems.Problem;
import ac.essex.gp.params.GPParams;
import ac.essex.gp.params.NodeParams;
import ac.essex.gp.Evolve;
import ac.essex.gp.interfaces.GraphicalListener;
import ac.essex.gp.interfaces.graphical.GraphicalListener;
import ac.essex.gp.nodes.haar.*;
import ac.essex.gp.nodes.ercs.LargeIntERC;
import ac.essex.gp.nodes.ercs.BoolERC;
import ac.essex.gp.nodes.imaging.haar.OneRectFeature;
import ac.essex.gp.training.haar.HaarTrainingSet;
import ac.essex.gp.util.DataStack;
import ac.essex.gp.individuals.Individual;
import ac.essex.ooechs.imaging.commons.HaarRegions;
import ac.essex.ooechs.imaging.commons.PixelLoader;
import ac.essex.ooechs.imaging.commons.util.ImageFilenameFilter;

import java.io.File;
import java.util.Vector;

/**
 * Attempts to solve the Face Detector Problem using GP.
 * <p>
 * Face images are evaluated with the detection window made to fill the whole image; non-face
 * images are evaluated on a randomly positioned {@value #WINDOW_WIDTH}x{@value #WINDOW_HEIGHT}
 * window. Every evolved individual is scored against the complete training set.
 *
 * @author Olly Oechsle, University of Essex, Date: 15-May-2007
 * @version 1.0
 */
public class GPLearner extends Problem {

    // window has these dimensions
    public static final int WINDOW_WIDTH = 32;
    public static final int WINDOW_HEIGHT = 40;

    // window is divided into blocks
    public static final int WINDOWBLOCKSX = 16;
    public static final int WINDOWBLOCKSY = 20;

    /** Class ID given to face training examples. */
    public static final int FACE = 1;
    /** Class ID given to non-face training examples. */
    public static final int NOT_FACE = 2;

    /** Weight ("alpha") applied to true positives in the fitness function. */
    private static final double TP_WEIGHT = 1;
    /** Weight ("beta") applied to false positives in the fitness function. */
    private static final double FP_WEIGHT = 3;

    /** Labelled training examples; populated by {@link #initialise}. */
    public Vector<HaarTrainingSet> trainingData;

    public static void main(String[] args) {
        new Evolve(new GPLearner(), new GraphicalListener()).start();
    }

    /**
     * Loads the face / non-face training images and registers the GP node set.
     * The training directories are hard-coded below.
     *
     * @param e      the running evolution (unused here)
     * @param params GP parameters to register nodes and the return type with
     */
    public void initialise(Evolve e, GPParams params) {

        trainingData = new Vector<HaarTrainingSet>();
        //final File trueDirectory1 = new File("/home/ooechs/ecj-training/faces/essex/mit/test/scaled16x20/");
        final File trueDirectory2 = new File("/home/ooechs/ecj-training/faces/essex/mit/test/scaled32x40/");
        //final File trueDirectory3 = new File("/home/ooechs/ecj-training/faces/essex/mit/test/scaled48x60/");

        final File falseDirectory1 = new File("/home/ooechs/ecj-training/faces/essex/mit/false2/");
        final File falseDirectory2 = new File("/home/ooechs/ecj-training/faces/essex/false1/");
        //addImageFiles(trueDirectory1, trainingData, FACE);
        addImageFiles(trueDirectory2, trainingData, FACE);
        //addImageFiles(trueDirectory3, trainingData, FACE);

        // false training data
        addImageFiles(falseDirectory1, trainingData, NOT_FACE);
        addImageFiles(falseDirectory2, trainingData, NOT_FACE);

        // Choose which nodes to use in GP
/*        params.registerNode(new Add());
        params.registerNode(new Mul());
        params.registerNode(new Sub());*/

        params.registerNode(new OneRectFeature());
        //params.registerNode(new TwoRectangleFeature());
        //params.registerNode(new ThreeRectangleFeature());
        params.registerNode(new AdjacencyERC());
        params.registerNode(new ShapeERC());
        params.registerNode(new WidthERC());
        params.registerNode(new HeightERC());
        params.registerNode(new XERC());
        params.registerNode(new YERC());

        params.registerNode(new LargeIntERC());
        //params.registerNode(new TinyDoubleERC());
        //params.registerNode(new SmallIntERC());
        params.registerNode(new BoolERC());

        params.registerNode(new ClassifierNode());

        params.setReturnType(NodeParams.BOOLEAN);

        params.setIgnoreTerminalWarnings(true);

    }

    /**
     * Loads every image file found directly inside {@code directory} (no recursion) and appends
     * it to {@code trainingData} with the given class ID. Non-image files are ignored. An image
     * that fails to load is reported on standard error and skipped, so one bad file does not stop
     * the rest of the directory from loading.
     *
     * @param directory    the directory to scan
     * @param trainingData the list the loaded examples are appended to
     * @param classID      the class ID ({@link #FACE} or {@link #NOT_FACE}) given to every image
     * @return the number of images successfully loaded
     * @throws RuntimeException if {@code directory} does not exist or its contents cannot be listed
     */
    protected int addImageFiles(File directory, Vector<HaarTrainingSet> trainingData, int classID) {
        if (!directory.exists()) throw new RuntimeException("Directory does not exist: " + directory.getAbsolutePath());

        // listFiles() returns null (rather than throwing) if the path is not a directory or on an I/O error.
        File[] files = directory.listFiles();
        if (files == null) throw new RuntimeException("Cannot list directory: " + directory.getAbsolutePath());

        int counter = 0;
        for (File f : files) {
            try {
                if (ImageFilenameFilter.isImage(f)) {
                    trainingData.add(new HaarTrainingSet(new HaarRegions(new PixelLoader(f)), classID));
                    counter++;
                }
            } catch (Exception ex) {
                // Best effort: skip this file but keep loading the rest of the directory.
                System.err.println("Could not load training image " + f.getAbsolutePath() + ": " + ex);
            }
        }
        System.out.println("\nLoaded: " + counter + " images.");
        return counter;
    }

    /**
     * Restricts the maximum tree size of evolved individuals.
     *
     * @param params GP parameters to customise
     */
    public void customiseParameters(GPParams params) {
        params.setMaxTreeSize(15);
    }

    /** @return the human-readable name of this problem */
    public String getName() {
        return "Face Detection Problem";
    }

    /** @return the number of classes ({@link #FACE} and {@link #NOT_FACE}) */
    public int getClassCount() {
        return 2;
    }

    /**
     * Runs {@code ind} on every training example and sets its fitness.
     * <p>
     * Face examples use a window that fills the whole image; non-face examples use a window at a
     * fresh random position on every call, so an individual's score can vary between evaluations.
     * An individual that always gives the same answer, or that finds no faces at all, receives the
     * worst fitness. Otherwise fitness is {@code (faces + beta * FP) / (alpha * TP)}, which falls as
     * true positives rise and grows with false positives (presumably minimised, as with Koza
     * standardised fitness — confirm against {@code Individual.setKozaFitness}).
     *
     * @param ind  the individual to evaluate
     * @param data the data stack passed to the individual; its {@code haar} field is overwritten
     * @param e    the running evolution (unused here)
     */
    public void evaluate(Individual ind, DataStack data, Evolve e) {

        int truePositives = 0;
        int falsePositives = 0;
        int trueNegatives = 0;
        int falseNegatives = 0;
        int faceCount = 0;

        for (HaarTrainingSet example : trainingData) {
            data.haar = example.image;
            boolean isFace = example.classID == FACE;

            if (isFace) {
                data.haar.makeWindowFillImage(WINDOWBLOCKSX, WINDOWBLOCKSY);
            } else {
                int x = randomOffset(data.haar.getWidth(), WINDOW_WIDTH);
                int y = randomOffset(data.haar.getHeight(), WINDOW_HEIGHT);
                data.haar.setWindowPosition(x, y, WINDOW_WIDTH, WINDOW_HEIGHT, WINDOWBLOCKSX, WINDOWBLOCKSY);
            }

            // NOTE(review): the tree's return type is registered as BOOLEAN, yet its output is compared
            // with class IDs 1 (FACE) / 2 (NOT_FACE). If execute() yields 0/1 rather than a class ID,
            // non-faces can never be scored as true negatives — confirm what ClassifierNode returns.
            int classID = (int) ind.execute(data);
            boolean correct = classID == example.classID;

            if (isFace) {
                faceCount++;
                if (correct) {
                    truePositives++;
                } else {
                    falseNegatives++;
                }
            } else {
                if (correct) {
                    trueNegatives++;
                } else {
                    falsePositives++;
                }
            }
        }

        // How often the individual answered "face" vs. "not face", regardless of correctness.
        int returnedTrue = truePositives + falsePositives;
        int returnedFalse = trueNegatives + falseNegatives;

        if (returnedFalse == 0 || returnedTrue == 0 || truePositives == 0) {
            // Reject classifiers that always give the same answer, or that find no faces at all
            // (the latter would otherwise divide by zero below and produce an infinite fitness).
            ind.setWorstFitness();
        } else {
            // Algebraically equal to 1 / ((TP * alpha) / (faces + FP * beta)).
            double fitness = (faceCount + FP_WEIGHT * falsePositives) / (TP_WEIGHT * truePositives);
            ind.setKozaFitness(fitness);
            ind.setHits(truePositives);
            ind.setMistakes(falsePositives);
        }

    }

    /**
     * Picks a uniformly random window offset along one axis so the window stays inside the image.
     * Every offset from 0 to {@code imageSize - windowSize} inclusive can be chosen; when the image
     * is no larger than the window, 0 is returned instead of a negative offset.
     *
     * @param imageSize  image extent along this axis
     * @param windowSize window extent along this axis
     * @return an offset in {@code [0, max(0, imageSize - windowSize)]}
     */
    private static int randomOffset(int imageSize, int windowSize) {
        int slack = imageSize - windowSize;
        if (slack <= 0) {
            return 0;
        }
        // Math.random() is in [0, 1), so this yields 0..slack inclusive.
        return (int) (Math.random() * (slack + 1));
    }


}
