package ac.essex.gp.problems;

import ac.essex.ooechs.imaging.commons.apps.jasmine.JasmineProject;
import ac.essex.ooechs.imaging.commons.apps.jasmine.JasmineImage;
import ac.essex.ooechs.imaging.commons.apps.jasmine.JasmineClass;
import ac.essex.ooechs.imaging.commons.apps.shapes.SegmentedShape;
import ac.essex.gp.Evolve;
import ac.essex.gp.interfaces.GraphicalInterface;
import ac.essex.gp.util.DataStack;
import ac.essex.gp.util.JavaWriter;
import ac.essex.gp.individuals.Individual;
import ac.essex.gp.nodes.Return;
import ac.essex.gp.nodes.If;
import ac.essex.gp.nodes.shape.*;
import ac.essex.gp.nodes.ercs.*;
import ac.essex.gp.nodes.logic.*;
import ac.essex.gp.params.GPParams;
import ac.essex.gp.params.NodeParams;

import java.io.File;
import java.util.Locale;
import java.util.Vector;

/**
 * Evolves a classifier that distinguishes one particular class of segmented shapes
 * from all other classes in a Jasmine project.
 * <p>
 * Only shapes with at least {@value #MIN_SHAPE_PIXELS} pixels are used for training.
 * An individual scores a hit when it returns the target class ID for a shape of that
 * class. It makes a mistake when it misses such a shape, or when it returns anything
 * other than {@code -1} for a shape of another class.
 *
 * @author Olly Oechsle, University of Essex, Date: 03-Apr-2007
 * @version 1.0
 */
public class BinaryClassificationProblem extends Problem {

    /** Project file loaded by {@link #main(String[])} when no path argument is given. */
    private static final String DEFAULT_PROJECT_PATH = "/home/ooechs/Desktop/JasmineProjects/Pasta.jasmine";

    /** Highest class ID tried by {@link #main(String[])} when no count argument is given. */
    private static final int DEFAULT_MAX_CLASS_ID = 36;

    /** Shapes with fewer pixels than this are excluded from the training set. */
    private static final int MIN_SHAPE_PIXELS = 50;

    /** Weight applied to each mistake in the fitness denominator. */
    private static final double MISTAKE_WEIGHT = 5;

    /** Value an individual returns to reject a shape (i.e. "not the target class"). */
    private static final int NO_CLASS = -1;

    // the training shapes
    public Vector<SegmentedShape> trainingData;

    // the different classes (not populated within this class)
    public Vector<JasmineClass> classes;

    // the project that we're using
    protected JasmineProject project;

    // the class we want to classify
    protected int classToClassify;

    // set by onFinish: whether the best individual was acceptable
    protected boolean success;

    // name of the target class, set by onFinish
    protected String className;

    // number of target-class shapes seen in the most recent evaluate() call
    int totalPossibleHits = 0;

    /**
     * Evolves one binary classifier per class ID. It then prints a Java {@code classify}
     * method that tries each successful classifier in turn.
     *
     * @param args optional. {@code args[0]} is the path to the Jasmine project file, and
     *             {@code args[1]} is the highest class ID to try (IDs 1..N are tried).
     *             When omitted, the original hard-coded defaults are used.
     * @throws Exception if the project cannot be loaded
     */
    public static void main(String[] args) throws Exception {

        File projectFile = new File(args.length > 0 ? args[0] : DEFAULT_PROJECT_PATH);
        int maxClassId = args.length > 1 ? Integer.parseInt(args[1]) : DEFAULT_MAX_CLASS_ID;

        Vector<BinaryClassificationProblem> solutions = new Vector<BinaryClassificationProblem>(10);
        JasmineProject project = JasmineProject.load(projectFile);
        long start = System.currentTimeMillis();
        for (int i = 1; i <= maxClassId; i++) {
            BinaryClassificationProblem p = new BinaryClassificationProblem(project, i);
            Evolve e = new Evolve(p, new GraphicalInterface());
            // run evolve in a non-threaded way
            e.run();
            if (p.success) {
                solutions.add(p);
            }
        }
        long end = System.currentTimeMillis();
        System.out.println("SegmentedShape shape;");
        System.out.println("public int classify(SegmentedShape shape) {");
        System.out.println("\tthis.shape = shape;");
        for (BinaryClassificationProblem problem : solutions) {
            // must match the method name generated in onFinish
            System.out.println("\tif (" + toMethodName(problem.className) + "() != -1) return " + problem.classToClassify + "; // " + problem.className);
        }
        System.out.println("\treturn -1;");
        System.out.println("}");
        System.out.println("// Discovered " + solutions.size() + " classifiers in: " + (end - start) + "ms.");
    }

    /**
     * @param project         the project supplying images and shape classes
     * @param classToClassify ID of the class the evolved classifier should recognise
     */
    public BinaryClassificationProblem(JasmineProject project, int classToClassify) {
        this.classToClassify = classToClassify;
        this.project = project;
    }

    /** Returns a human-readable name that includes the target class's name. */
    public String getName() {
        return "Binary Classification Problem: " + getTargetClassName();
    }

    /** Returns 2: "target class" vs "everything else". */
    public int getClassCount() {
        return 2;
    }

    /**
     * Loads the training shapes, validates that the target class is represented, and
     * registers the GP node set.
     * <p>
     * On any failure, {@code e.fatal} is called and the method returns without
     * registering nodes.
     */
    public void initialise(Evolve e, GPParams params) {

        trainingData = new Vector<SegmentedShape>(10);
        boolean hasTargetShapes = false;

        // only the data gathering is guarded, so failures reported by e.fatal below
        // are never misreported as a project-loading error
        try {
            for (int i = 0; i < project.getImages().size(); i++) {
                JasmineImage image = project.getImages().elementAt(i);
                for (int j = 0; j < image.getShapes().size(); j++) {
                    SegmentedShape shape = image.getShapes().elementAt(j);
                    if (shape.pixels.size() >= MIN_SHAPE_PIXELS) {
                        trainingData.add(shape);
                        if (shape.classID == classToClassify) {
                            hasTargetShapes = true;
                        }
                    }
                }
            }
        } catch (Exception err) {
            e.fatal("GP system cannot load Jasmine project: " + err);
            return;
        }

        // checked first: with no shapes at all, the "class missing" message would be misleading
        if (trainingData.isEmpty()) {
            e.fatal("No shapes defined - GP cannot proceed without training data.");
            return;
        }

        if (!hasTargetShapes) {
            JasmineClass c = project.getShapeClass(classToClassify);
            if (c == null) {
                e.fatal("No training data for class# " + classToClassify);
            } else {
                e.fatal("No training data for class: " + c.name);
            }
            return;
        }

        Return.classes = new int[] { classToClassify, NO_CLASS };

        registerNodes(params);

        // set up additional parameters
        params.setReturnType(NodeParams.SUBSTATEMENT);
    }

    /** Registers the function and terminal set used by this problem. */
    private void registerNodes(GPParams params) {
        params.registerNode(new If());
        params.registerNode(new Return());

        // boolean functions
        params.registerNode(new More());
        params.registerNode(new Less());
        params.registerNode(new Equals());
        params.registerNode(new Between());

        // boolean logic
        params.registerNode(new AND());
        params.registerNode(new OR());
        params.registerNode(new NOT());
        params.registerNode(new BoolERC());

        // range typed ERCs
        params.registerNode(new SmallIntERC());
        params.registerNode(new SmallDoubleERC());
        params.registerNode(new TinyDoubleERC());
        params.registerNode(new PercentageERC());

        // shape attributes
        params.registerNode(new CountHollows());
        params.registerNode(new BalanceX());
        params.registerNode(new BalanceY());
        params.registerNode(new Density());
        params.registerNode(new AspectRatio());
        params.registerNode(new Joints());
        params.registerNode(new Ends());
    }

    /** Sets the run parameters (population, tree size, generations, selection, mutation, elitism). */
    public void customiseParameters(GPParams params) {
        params.setPopulationSize(500);
        params.setMaxTreeSize(55);
        params.setNodeChildConstraintsEnabled(true);
        params.setGenerations(500);
        params.setTournamentSizePercentage(0.50);
        params.setPointMutationProbability(0.50);
        params.setEliteCount(1);
        // we want to use numeric ERCs but not numeric functions, so ensure there are no errors
        params.setIgnoreNonTerminalWarnings(true);
    }

    /**
     * Runs the individual over every training shape and sets its Koza fitness,
     * hits and mistakes.
     * <p>
     * Fitness is {@code 1 - hits / (totalPossibleHits + 5 * mistakes)}, which is 0
     * when every target shape is hit with no mistakes. Correct rejections of other
     * classes are deliberately not counted as hits. An individual that leaves
     * {@code data.usesImaging} false, or returns fewer than two distinct values over
     * the training set, gets {@code Integer.MAX_VALUE} mistakes.
     */
    public void evaluate(Individual ind, DataStack data) {

        int mistakes = 0;
        int hits = 0;
        totalPossibleHits = 0;

        Vector<Integer> distinctReturnValues = new Vector<Integer>(10);

        // reset before running; if still false afterwards the individual is penalised
        // (presumably the shape nodes set it when they read the shape - TODO confirm)
        data.usesImaging = false;

        for (int i = 0; i < trainingData.size(); i++) {

            // set up the shape on the stack
            data.shape = trainingData.elementAt(i);

            int result = (int) ind.execute(data);

            if (!distinctReturnValues.contains(result)) {
                distinctReturnValues.add(result);
            }

            if (data.shape.classID == classToClassify) {
                totalPossibleHits++;
                // a target shape must be recognised; missing it is a mistake
                if (result == classToClassify) {
                    hits++;
                } else {
                    mistakes++;
                }
            } else if (result != NO_CLASS) {
                // any other shape must be rejected; correct rejections earn nothing
                mistakes++;
            }
        }

        // an individual that ignores the shape, or always answers the same, is useless
        if (!data.usesImaging || distinctReturnValues.size() < 2) {
            mistakes = Integer.MAX_VALUE;
        }

        // computed in double, so the Integer.MAX_VALUE penalty cannot overflow
        double fitness = 1 - (hits / (totalPossibleHits + mistakes * MISTAKE_WEIGHT));

        ind.setKozaFitness(fitness);
        ind.setHits(hits);
        ind.setMistakes(mistakes);
    }

    /**
     * Prints the best individual as Java source. Sets {@link #className}, and sets
     * {@link #success} when mistakes are fewer than hits.
     */
    public void onFinish(Individual ind, Evolve e) {
        className = getTargetClassName();
        System.out.println("// Discovers instances of " + className);
        // "mistakes" counts both missed target shapes and wrongly accepted other shapes
        System.out.println("// TP:" + ind.getHits() + "/" + totalPossibleHits + " Mistakes: " + ind.getMistakes() + ", Fitness:  " + ind.getKozaFitness());
        String java = JavaWriter.toJava(ind, toMethodName(className));
        System.out.println(java);
        success = ind.getMistakes() < ind.getHits();
        if (!success) {
            System.out.println("// Cannot find acceptable solution for " + className);
        }
    }

    /** Returns the target class's name, or {@code "Class<N>"} if the project has no such class. */
    private String getTargetClassName() {
        JasmineClass c = project.getShapeClass(classToClassify);
        return (c != null) ? c.name : "Class" + classToClassify;
    }

    /**
     * Builds the generated finder-method name for a class: {@code "find"} followed by
     * the upper-cased class name.
     * <p>
     * {@link Locale#ROOT} keeps the name independent of the default locale (e.g.
     * Turkish dotted/dotless i). Characters that are illegal in Java identifiers are
     * replaced with {@code '_'} so the printed code compiles.
     */
    private static String toMethodName(String name) {
        String upper = name.toUpperCase(Locale.ROOT);
        StringBuilder sb = new StringBuilder("find");
        for (int i = 0; i < upper.length(); i++) {
            char ch = upper.charAt(i);
            sb.append(Character.isJavaIdentifierPart(ch) ? ch : '_');
        }
        return sb.toString();
    }
}
