package ac.essex.gp.problems;

import ac.essex.gp.individuals.Individual;
import ac.essex.gp.nodes.*;
import ac.essex.gp.nodes.logic.*;
import ac.essex.gp.nodes.ercs.*;
import ac.essex.gp.nodes.shape.*;
import ac.essex.gp.params.GPParams;
import ac.essex.gp.params.NodeParams;
import ac.essex.gp.params.ADFNodeParams;
import ac.essex.gp.util.*;
import ac.essex.gp.util.ClassResults;
import ac.essex.gp.Evolve;
import ac.essex.gp.interfaces.GraphicalInterface;
import ac.essex.gp.problems.coevolve.TestResults;
import ac.essex.ooechs.imaging.commons.apps.jasmine.JasmineImage;
import ac.essex.ooechs.imaging.commons.apps.jasmine.JasmineProject;
import ac.essex.ooechs.imaging.commons.apps.jasmine.JasmineClass;
import ac.essex.ooechs.imaging.commons.apps.shapes.SegmentedShape;

import java.io.File;
import java.util.HashSet;
import java.util.Set;
import java.util.Vector;

/**
 * Evolves classifiers that assign segmented shapes to classes based on
 * shape attributes (corners, hollows, symmetry, density and so on). This is
 * a natural follow-on from a successful segmentation algorithm: Jasmine can
 * run a segmenter to generate shapes, which are then labelled by hand and
 * used here as training data.
 *
 * @author Olly Oechsle, University of Essex, Date: 06-Feb-2007
 * @version 1.0
 */
public class ShapeClassificationProblem extends Problem {

    /** Project loaded by {@link #main} when no path is given on the command line. */
    private static final String DEFAULT_PROJECT_PATH =
            "/home/ooechs/Desktop/JasmineProjects/Numberplates.jasmine";

    /** Shapes with fewer pixels than this are ignored as training data. */
    private static final int MIN_SHAPE_PIXELS = 50;

    /**
     * Fitness penalty applied for each class in the training data that an
     * individual never classifies correctly. It is meant to outweigh
     * individual misclassifications (which cost 1 each) so that evolution is
     * pushed to cover every class. Note this only dominates while the training
     * set holds fewer shapes than this value.
     */
    private static final int CLASS_MISS_PENALTY = 100;

    // NOTE(review): not read anywhere in this class; presumably consulted
    // elsewhere in the package - TODO confirm before removing.
    final boolean coevolutionProblem = true;

    /**
     * Loads a Jasmine project and starts evolution with the graphical interface.
     *
     * @param args optional; {@code args[0]} is the path of the .jasmine project
     *             file to load (defaults to {@link #DEFAULT_PROJECT_PATH})
     * @throws Exception if the project cannot be loaded
     */
    public static void main(String[] args) throws Exception {
        String path = args.length > 0 ? args[0] : DEFAULT_PROJECT_PATH;
        JasmineProject project = JasmineProject.load(new File(path));
        new Evolve(new ShapeClassificationProblem(project), new GraphicalInterface()).start();
    }

    /** @return the human-readable name of this problem */
    public String getName() {
        return "Shape Classification Problem";
    }

    // the training shapes; populated by initialise()
    public Vector<SegmentedShape> trainingData;

    // the distinct class IDs that occur in trainingData; populated by initialise()
    public Vector<Integer> distinctClasses;

    // the project supplying the images and their marked-up shapes
    JasmineProject project;

    /**
     * Creates the problem from a Jasmine project file.
     *
     * @param jasmineFile the .jasmine project file to load
     * @throws Exception if the project cannot be loaded
     */
    public ShapeClassificationProblem(File jasmineFile) throws Exception {
        project = JasmineProject.load(jasmineFile);
    }

    /**
     * Creates the problem from an already loaded Jasmine project.
     *
     * @param p the project supplying the training shapes
     */
    public ShapeClassificationProblem(JasmineProject p) {
        project = p;
    }

    /**
     * Loads the training shapes and restricts the {@code Return} node to the
     * classes found in them. It then registers the function and terminal set
     * and configures the tree builder.
     *
     * @param e      the running evolution, used to report fatal errors
     * @param params the GP parameters to register nodes with
     */
    public void initialise(Evolve e, GPParams params) {

        trainingData = new Vector<SegmentedShape>(10);
        distinctClasses = new Vector<Integer>(100);

        try {
            loadShapesFromProject();
        } catch (Exception err) {
            // keep the underlying cause rather than silently discarding it
            e.fatal("GP system cannot load Jasmine project: " + err);
        }

        // Checked outside the try block so that, in case fatal() throws, this
        // message is not replaced by the generic load-failure message above.
        if (trainingData.isEmpty()) {
            e.fatal("No shapes defined - GP cannot proceed without training data.");
        }

        // Only allow Return to output classes that actually occur in the training data.
        // Note: Return.classes is static, so this is shared by every Return node.
        Return.classes = new int[distinctClasses.size()];
        for (int i = 0; i < distinctClasses.size(); i++) {
            Return.classes[i] = distinctClasses.elementAt(i);
        }

        registerShapeClassificationNodes(params);

        // We have numeric terminals but no numeric functions, so this must be
        // enabled, otherwise the tree builder complains.
        params.setIgnoreNonTerminalWarnings(true);

        params.setReturnType(NodeParams.SUBSTATEMENT);
    }

    /**
     * Copies every shape with at least {@link #MIN_SHAPE_PIXELS} pixels from the
     * project's images into {@link #trainingData}. Each new class ID is recorded
     * in {@link #distinctClasses}.
     */
    private void loadShapesFromProject() {
        for (int i = 0; i < project.getImages().size(); i++) {
            JasmineImage image = project.getImages().elementAt(i);
            for (int j = 0; j < image.getShapes().size(); j++) {
                SegmentedShape shape = image.getShapes().elementAt(j);
                if (shape.pixels.size() < MIN_SHAPE_PIXELS) {
                    continue;
                }
                trainingData.add(shape);
                if (!distinctClasses.contains(shape.classID)) {
                    distinctClasses.add(shape.classID);
                }
            }
        }
    }

    /**
     * Registers the control-flow, boolean, ERC and shape-attribute nodes that
     * make up this problem's function and terminal set.
     */
    private void registerShapeClassificationNodes(GPParams params) {
        // control flow
        params.registerNode(new If());
        params.registerNode(new Return());

        // boolean functions
        params.registerNode(new More());
        params.registerNode(new Less());
        params.registerNode(new Equals());
        params.registerNode(new Between());

        // boolean logic
        params.registerNode(new AND());
        params.registerNode(new OR());
        params.registerNode(new NOT());

        // range typed ERCs
        params.registerNode(new SmallIntERC());
        params.registerNode(new SmallDoubleERC());
        params.registerNode(new TinyDoubleERC());
        params.registerNode(new PercentageERC());
        params.registerNode(new BoolERC());

        // shape attributes
        params.registerNode(new Corners());
        params.registerNode(new CountHollows());
        params.registerNode(new BalanceX());
        params.registerNode(new BalanceY());
        params.registerNode(new Density());
        params.registerNode(new AspectRatio());
        params.registerNode(new Joints());
        params.registerNode(new Ends());
        params.registerNode(new Roundness());
        params.registerNode(new BalanceXEnds());
        params.registerNode(new BalanceYEnds());
        params.registerNode(new ClosestEndToCog());
        params.registerNode(new ClosestPixelToCog());
        params.registerNode(new HorizontalSymmetry());
        params.registerNode(new VerticalSymmetry());
        params.registerNode(new InverseHorizontalSymmetry());
        params.registerNode(new InverseVerticalSymmetry());
    }

    /**
     * Sets this problem's run parameters: generations, population size,
     * maximum tree size and tournament size.
     *
     * @param params the GP parameters to modify
     */
    public void customiseParameters(GPParams params) {
        params.setGenerations(100000);
        params.setPopulationSize(500);
        params.setMaxTreeSize(30);
        params.setTournamentSizePercentage(0.15);
    }

    /**
     * @return the number of classes {@code Return} may output; only valid
     *         after {@link #initialise} has set {@code Return.classes}
     */
    public int getClassCount() {
        return Return.classes.length;
    }

    /**
     * Runs a single ADF classifier over every training shape, treating an
     * output of exactly 1 as a positive answer. The results are bound to the
     * classifier and also returned.
     *
     * @param classifier the ADF classifier to test
     * @return one result per training shape, together with its true class ID
     */
    public TestResults testClassifier(ADFNodeParams classifier) {

        ADFNode c = classifier.getInstance();

        TestResults results = new TestResults(trainingData.size());

        for (int i = 0; i < trainingData.size(); i++) {
            DataStack data = new DataStack();
            data.shape = trainingData.elementAt(i);
            int result = (int) c.execute(data);
            results.setResult(i, result == 1, data.shape.classID);
        }

        classifier.setTestResults(results);

        return results;
    }

    /**
     * Computes the Koza fitness of an individual, where lower is better.
     * Each misclassified training shape costs 1, and each training class the
     * individual never classifies correctly costs an additional
     * {@link #CLASS_MISS_PENALTY}. Individuals that leave
     * {@code data.usesImaging} false get {@code Integer.MAX_VALUE}.
     *
     * @param ind  the individual to evaluate
     * @param data the data stack used to pass each shape to the individual
     */
    public void evaluate(Individual ind, DataStack data) {

        int mistakes = 0;
        int hits = 0;

        // IDs of the classes this individual classified correctly at least once.
        // A set is used (rather than an array indexed by class ID) so that every
        // class ID, including 0 and large values, is tracked.
        Set<Integer> classesHit = new HashSet<Integer>();

        for (int i = 0; i < trainingData.size(); i++) {

            data.shape = trainingData.elementAt(i);

            int result = (int) ind.execute(data);

            if (result == data.shape.classID) {
                hits++;
                classesHit.add(result);
            } else {
                mistakes++;
            }
        }

        // Every hit class comes from a training shape, so it is also in
        // distinctClasses; the difference is the number of classes never hit.
        int classesNotHit = distinctClasses.size() - classesHit.size();
        mistakes += classesNotHit * CLASS_MISS_PENALTY;

        if (!data.usesImaging) {
            mistakes = Integer.MAX_VALUE;
        }

        ind.setKozaFitness(mistakes);
        ind.setHits(hits);
    }

    /**
     * Re-runs an individual over the training data and tallies hits and misses
     * per class for display. Prints a warning to stderr if the total number of
     * hits disagrees with the value stored on the individual.
     *
     * @param ind   the individual to describe
     * @param data  the data stack used to pass each shape to the individual
     * @param index not used by this implementation
     * @return per-class hit/miss counts
     */
    public ClassResults describe(Individual ind, DataStack data, int index) {

        ClassResults results = new ClassResults();

        int hits = 0;

        for (int i = 0; i < trainingData.size(); i++) {

            data.shape = trainingData.elementAt(i);
            int classID = data.shape.classID;

            if (results.getClassResult(classID) == null) {
                JasmineClass c = project.getShapeClass(classID);
                if (c != null) {
                    results.addClass(c.name, c.classID);
                } else {
                    // the project has no definition for this ID; avoid an NPE
                    results.addClass("Class " + classID, classID);
                }
            }

            int result = (int) ind.execute(data);

            if (result == classID) {
                results.addHit(classID);
                hits++;
            } else {
                results.addMiss(classID);
            }
        }

        if (hits != ind.getHits()) {
            System.err.println("// Wrong hits value: Should be " + hits + " but is " + ind.getHits());
        }

        return results;
    }

}
