package ac.essex.gp.problems.picking;

import ac.essex.gp.problems.Problem;
import ac.essex.gp.params.ADFNodeParams;
import ac.essex.gp.params.GPParams;
import ac.essex.gp.params.NodeParams;
import ac.essex.gp.Evolve;
import ac.essex.gp.individuals.Individual;
import ac.essex.gp.util.DataStack;
import ac.essex.gp.util.JavaWriter;
import ac.essex.gp.tree.Terminal;
import ac.essex.gp.nodes.Return;
import ac.essex.gp.nodes.ADFNode;
import ac.essex.gp.nodes.shape.*;
import ac.essex.gp.nodes.ercs.*;
import ac.essex.gp.nodes.logic.*;
import ac.essex.gp.interfaces.GraphicalInterface;
import ac.essex.ooechs.imaging.commons.apps.shapes.SegmentedShape;
import ac.essex.ooechs.imaging.commons.apps.jasmine.JasmineProject;
import ac.essex.ooechs.imaging.commons.apps.jasmine.JasmineImage;
import ac.essex.ooechs.imaging.commons.apps.jasmine.JasmineClass;

import java.util.Vector;
import java.util.Hashtable;
import java.util.Enumeration;
import java.io.File;
import java.io.PrintStream;
import java.io.FileOutputStream;
import java.io.FileNotFoundException;

/**
 * The idea of this problem is not to generate a single solution, but to find a number of
 * unique, good solutions (weak classifiers) - the building blocks to solve the larger problem.
 * <p/>
 * Good classifiers are retained as they are discovered, with additional fitness criteria
 * based on the uniqueness of each candidate relative to the classifiers already retained.
 * Thus a high performing classifier, once chosen, will have its fitness decrease immediately
 * afterward, allowing the GP problem to discover multiple classifiers in a single GP simulation.
 * <p/>
 * These retained classifiers can be saved as ADFs and then used in a subsequent GP problem.
 *
 * @author Olly Oechsle, University of Essex, Date: 28-Mar-2007
 * @version 1.0
 */
public class ClassifierPickingProblem2 extends Problem {

    /**
     * Store classifiers here as they are evolved. They will then affect the fitness of the
     * population as no two retained classifiers which return the same result are allowed.
     */
    protected Vector<ADFNodeParams> retainedClassifiers;

    /**
     * Keeps track of which exact instances of training data have been solved.
     */
    protected Vector<TrainingClass> trainingClasses;

    /**
     * For every classifier stored, also store its results on the training set (encoded
     * as a string of '1'/'0' characters, one per training instance). This allows
     * us to compare if a classifier returns the same results or not. Clearly we only
     * want unique classifiers.
     */
    protected Vector<String> savedClassifierResults;


    /**
     * Classes for which there exists a classifier that can solve this particular class.
     */
    protected Vector<Integer> classesFullySolved;


    /**
     * This particular problem concentrates on shape classification, so the
     * training data consists of shapes. SegmentedShape is a class in the imaging library.
     */
    public Vector<SegmentedShape> trainingData;

    /**
     * Keep a record of all the different shape class IDs
     * and how many instances of each exist.
     */
    public Hashtable<Integer, Integer> classes;

    /**
     * Allows us to assign a unique ID to each ADF node.
     */
    public long idCounter = 0;

    /**
     * Should the evolved program be written to a java file?
     */
    protected boolean writeToFile = false;

    /**
     * The training data is created using Jasmine, and we can use the project file to
     * get all that information out easily.
     */
    JasmineProject project;

    public PrintingBuffer buffer;

    // Accumulates the generated "if (methodN()) return classID;" lines in the order
    // classifiers were discovered; emitted inside classify() by onFinish().
    StringBuffer executionOrderBuffer = new StringBuffer();

    protected String className;
    File outputFile;

    protected GraphicalInterface gi;

    /**
     * Entry point: loads a Jasmine project and runs the GP simulation on it.
     * You'll want to change the hard-coded project path below.
     */
    public static void main(String[] args) throws Exception {
        // run the GP problem - this is done using the Evolve class.
        //JasmineProject project = JasmineProject.load(new File("/home/ooechs/Desktop/JasmineProjects/Alphabet.jasmine"));
        JasmineProject project = JasmineProject.load(new File("/home/ooechs/Desktop/JasmineProjects/ANPR.jasmine"));
        //JasmineProject project = JasmineProject.load(new File("/home/ooechs/Desktop/JasmineProjects/Pasta.jasmine"));
        //JasmineProject project = JasmineProject.load(new File("/home/ooechs/Desktop/JasmineProjects/Skin Lesions.jasmine"));
        //JasmineProject project = JasmineProject.load(new File("/home/ooechs/Desktop/JasmineProjects/OCR.jasmine"));
        GraphicalInterface i = new GraphicalInterface();
        new Evolve(new ClassifierPickingProblem2(project, i, true), i).run();
    }

    /**
     * Creates a new classifier picking problem. Training data and class information
     * comes from a Jasmine Project file. (Jasmine is part of the imaging library)
     *
     * @param project     the Jasmine project supplying training shapes and class info
     * @param gi          the GUI to notify when a good individual is found; may be null
     * @param writeToFile if true, System.out is redirected into a generated .java file
     */
    public ClassifierPickingProblem2(JasmineProject project, GraphicalInterface gi, boolean writeToFile) {
        this.writeToFile = writeToFile;
        this.project = project;
        this.gi = gi;
        buffer = new PrintingBuffer(writeToFile);
    }

    /**
     * Provides a name for the problem so that it can be identified via the
     * user interface.
     */
    public String getName() {
        return "Classifier Picking Problem";
    }

    private int distinctClassCount;

    /**
     * Returns how many distinct classes the problem must solve.
     */
    public int getClassCount() {
        return distinctClassCount;
    }

    GPParams params;
    Evolve e;

    /**
     * Initialises the problem. This is where the training data is loaded
     * and the GP params object initialised with Nodes to use. The return
     * object should also be set up.
     */
    public void initialise(Evolve e, GPParams params) {

        this.params = params;
        this.e = e;

        // the classifiers we want to keep
        retainedClassifiers = new Vector<ADFNodeParams>(100);

        // keeps a record of which classes have been partly solved and which are fully solved
        trainingClasses = new Vector<TrainingClass>(100);

        // and the classifier results vectors (easiest way to ensure uniqueness)
        savedClassifierResults = new Vector<String>(100);

        // and which classes have been solved
        classesFullySolved = new Vector<Integer>(100);

        // initialise the training data
        trainingData = new Vector<SegmentedShape>(10);

        // the number of distinct classes there are
        distinctClassCount = 0;

        try {

            // keep a record of every distinct classID and record how many instances of
            // each classID there are.
            classes = new Hashtable<Integer, Integer>(100);

            // look through the images in the project. Each image may be segmented, and some
            // of the segmented shapes may be marked up and accessible - we'll use them as training data.
            for (int i = 0; i < project.getImages().size(); i++) {

                JasmineImage image = project.getImages().elementAt(i);

                if (image.getShapes().size() > 0) {
                    for (int j = 0; j < image.getShapes().size(); j++) {
                        SegmentedShape shape = image.getShapes().elementAt(j);
                        // ignore tiny shapes - presumably too small to be reliable training data
                        if (shape.pixels.size() >= 50) {

                            trainingData.add(shape);

                            // count how many instances of each classID we've seen
                            Integer classCount = classes.get(shape.classID);

                            if (classCount == null) {
                                // class isn't in the hashtable yet - create a new entry
                                classes.put(shape.classID, 1);
                                distinctClassCount++;
                            } else {
                                // class is already in the hashtable - increment the class count
                                classes.put(shape.classID, classCount + 1);
                            }

                            TrainingClass trainingClass = getTrainingClass(shape.classID);

                            if (trainingClass == null) {
                                trainingClass = new TrainingClass(shape.classID);
                                trainingClasses.add(trainingClass);
                            }

                            trainingClass.addInstance(trainingData.size() - 1);

                        }
                    }
                }

            }

            // halt if there isn't sufficient data
            if (trainingData.size() == 0) {
                e.fatal("No shapes defined - GP cannot proceed without training data.");
            }

            // use the distinct classes to setup the Return ERC. The return ERC is a special
            // ERC that returns classIDs. So that the GP works effectively, the ReturnERC needs
            // to be given all the classIDs so it doesn't return meaningless classes.
            Return.classes = new int[distinctClassCount];
            Enumeration<Integer> classIDs = classes.keys();
            int i = 0;
            while (classIDs.hasMoreElements()) {
                Integer classID = classIDs.nextElement();
                Return.classes[i] = classID;
                i++;
            }

        } catch (Exception err) {
            // print the trace before fatal(), which may terminate the run -
            // otherwise the cause of the failure would be lost.
            err.printStackTrace();
            e.fatal("GP system cannot load Jasmine project.");
        }

        // after loading training data, we need to specify which Nodes the classifier is permitted to use.

        // boolean logic
        params.registerNode(new More());
        params.registerNode(new Less());
        params.registerNode(new Equals());
        params.registerNode(new Between());

        params.registerNode(new AND());
        params.registerNode(new OR());
        params.registerNode(new NOT());
        params.registerNode(new BoolERC());

        //params.setAutomaticRangeTypingEnabled(true);

        // range typed ERCs
        if (!params.isAutomaticRangeTypingEnabled()) {
            params.registerNode(new SmallIntERC());
            params.registerNode(new SmallDoubleERC());
            params.registerNode(new TinyDoubleERC());
            params.registerNode(new PercentageERC());
            params.registerNode(new LargeIntERC());
        }

        // shape attributes
        registerNode(new Corners());
        registerNode(new CountHollows());
        registerNode(new BalanceX());
        registerNode(new BalanceY());
        registerNode(new Density());
        registerNode(new AspectRatio());
        registerNode(new Joints());
        registerNode(new Ends());
        registerNode(new Roundness());
        registerNode(new BalanceXEnds());
        registerNode(new BalanceYEnds());
        registerNode(new ClosestEndToCog());
        registerNode(new ClosestPixelToCog());
        registerNode(new HorizontalSymmetry());
        registerNode(new VerticalSymmetry());
        registerNode(new InverseHorizontalSymmetry());
        registerNode(new InverseVerticalSymmetry());

        // as each of these is a binary classifier, we need to specify that boolean values are to be returned
        params.setReturnType(NodeParams.BOOLEAN);

        className = project.getName().replaceAll(" ", "_") + System.currentTimeMillis();

        if (writeToFile) {

            try {

                // NOTE(review): hard-coded output directory - only valid on the author's machine.
                outputFile = new File("/home/ooechs/ecj-imaging/src/ac/essex/ooechs/imaging/commons/apps/jasmine/results/" + className + ".sxgp");

                // everything printed from here on goes into the generated source file
                System.setOut(new PrintStream(new FileOutputStream(outputFile, true)));

            } catch (FileNotFoundException err) {

                err.printStackTrace();
                System.exit(1);

            }

        }

        buffer.append("package ac.essex.ooechs.imaging.commons.apps.jasmine.results;\n" +
                "\n" +
                "import ac.essex.ooechs.imaging.commons.apps.shapes.SegmentedShape;\n" +
                "\n" +
                "/**\n" +
                " * ANPR Classifier. This program was evolved automatically using SXGP.\n" +
                " *\n" +
                " * @author Olly Oechsle, University of Essex, Date: 16-Feb-2007\n" +
                " * @version 1.0\n" +
                " */\n" +
                "public class " + className + " extends ShapeClassifier {");

    }

    /**
     * Tunes the GP parameters for this problem: range typing on, low selection
     * pressure, small trees and a large initial population.
     */
    public void customiseParameters(GPParams params) {
        // this enables the range typing so that different functions are tied with certain ERCs
        // this is only applicable to functions which implement the constraints mechanism, which less(), more(), between() and equals() do.
        params.setNodeChildConstraintsEnabled(true);

        // keep mutation and crossover probabilities low as they are not fully compatible with range checking
        params.setPointMutationProbability(0.10);
        params.setCrossoverProbability(0.10);

        // make sure that jittering and erc mutation are likely
        params.setERCjitterProbability(0.25);
        params.setERCmutateProbability(0.25);

        // create low selection pressure so the variability doesn't diminish
        params.setTournamentSizePercentage(0.5);

        params.setBreedSizePercentage(0.5);

        // elites count for nothing in this scenario
        params.setEliteCount(0);

        // keep the retained classifiers small
        params.setMaxTreeSize(20);

        // we want to use numeric ERCs but not numeric functions, so ensure there are no errors
        params.setIgnoreNonTerminalWarnings(true);

        // don't take too long
        params.setGenerations(10);

        // use a large initial population size so all different structures may be evaluated.
        params.setPopulationSize(7500);
    }

    /**
     * Returns the training class data structure associated with a given classID,
     * or null if no such class has been registered.
     */
    public TrainingClass getTrainingClass(int classID) {
        for (int i = 0; i < trainingClasses.size(); i++) {
            TrainingClass c = trainingClasses.elementAt(i);
            if (c.classID == classID) return c;
        }
        return null;
    }


    private int uniqueRangeID = 100;

    /**
     * Registers a terminal node with the GP params. If automatic range typing is
     * enabled, an auto-range ERC is built first by running the terminal over every
     * training shape so the ERC learns the terminal's value distribution per class.
     */
    public void registerNode(Terminal n) {
        if (params.isAutomaticRangeTypingEnabled()) {
            AutorangeERC erc = new AutorangeERC(params, uniqueRangeID);
            // now go through the training data and initialise the ERC
            DataStack data = new DataStack();
            for (int j = 0; j < trainingClasses.size(); j++) {
                TrainingClass trainingClass = trainingClasses.elementAt(j);
                for (int i = 0; i < trainingData.size(); i++) {
                    data.shape = trainingData.elementAt(i);
                    if (data.shape.classID == trainingClass.classID) {
                        // execute the terminal once per matching shape and feed the value to the ERC
                        erc.addData(n.execute(data));
                    }
                }

                buffer.append("---");
            }

            uniqueRangeID++;
            params.registerNode(erc);
            // tie the terminal to its ERC so only range-compatible pairs are combined
            n.setRangeID(erc.getRangeID());
        }
        params.registerNode(n);
    }



    /**
     * Evaluates a single individual, fitness should be assigned using the
     * setKozaFitness() method on the individual. An individual that is unique,
     * discriminating and achieves at least one TP with no FPs on an unsolved
     * class is retained immediately via addClassifier().
     */
    public void evaluate(Individual ind, DataStack data) {

        boolean classifierWasAdded = false;

        totalIndividuals++;

        // The Evolve class passes the problem a single individual to be evaluated.
        // The individual must be evaluated against each item in the training data.
        // We'll store the results of the evaluation (which is a sequence of true/false decisions)
        // encoded in a string, one '1'/'0' character per training instance.
        StringBuilder results = new StringBuilder(trainingData.size());

        int returnedTrue = 0;
        int returnedFalse = 0;

        // note exactly which classes were returned true, and how many times this occurred per class.
        Hashtable<Integer, Integer> classesReturnedTrue = new Hashtable<Integer, Integer>(trainingData.size());
        Hashtable<Integer, Integer> classesReturnedFalse = new Hashtable<Integer, Integer>(trainingData.size());

        Vector<Integer> indexesReturnedTrue = new Vector<Integer>(50);
        Vector<Integer> indexesReturnedFalse = new Vector<Integer>(50);

        // iterate through each item of training data
        for (int i = 0; i < trainingData.size(); i++) {

            // the data stack allows the individual access to data, so we'll put the shape onto the stack
            data.shape = trainingData.elementAt(i);

            // execute the individual
            boolean result = ind.execute(data) == 1;

            // record the number of TPs;FPs
            if (result) {
                returnedTrue++;
                indexesReturnedTrue.add(i);
                Integer timesReturnedThisClassTrue = classesReturnedTrue.get(data.shape.classID);
                if (timesReturnedThisClassTrue == null) {
                    classesReturnedTrue.put(data.shape.classID, 1);
                } else {
                    classesReturnedTrue.put(data.shape.classID, timesReturnedThisClassTrue + 1);
                }
            } else {
                returnedFalse++;
                indexesReturnedFalse.add(i);
                Integer timesReturnedThisClassFalse = classesReturnedFalse.get(data.shape.classID);
                if (timesReturnedThisClassFalse == null) {
                    classesReturnedFalse.put(data.shape.classID, 1);
                } else {
                    classesReturnedFalse.put(data.shape.classID, timesReturnedThisClassFalse + 1);
                }
            }

            // record the result
            results.append(result ? "1" : "0");

        }

        String resultString = results.toString();

        // now we've evaluated the individual it is time to look at the results more closely.
        // our first criterion is that the individual be unique - that is it produces a different set of results
        // to any individual that we have already got.

        boolean isUnique = !savedClassifierResults.contains(resultString);

        if (!isUnique) {
            totalNonUniqueClassifiers++;
        }

        // clearly it needs to be able to discriminate, otherwise it is worthless
        boolean cantDiscriminate = returnedTrue == 0 || returnedFalse == 0;

        if (cantDiscriminate) {
            totalNonDiscriminatingClassifiers++;
        }

        // as well as understanding the classifier's uniqueness, we'd also like to see how
        // many different classes the classifier is able to find.
        // The minority response is treated as the "positive" output: if fewer shapes
        // returned true, examine the true set; otherwise examine the false set and
        // remember that the classifier's output must be inverted.

        Enumeration<Integer> classEnumeration;

        boolean invert = false;
        if (returnedTrue < returnedFalse) {
            classEnumeration = classesReturnedTrue.keys();
        } else {
            classEnumeration = classesReturnedFalse.keys();
            invert = true;
            indexesReturnedTrue = indexesReturnedFalse;
        }

        // since enumerations are difficult to manage (not easy to iterate through them multiple times)
        // we'll convert the data into a vector, called classesIdentified
        Vector<Integer> classesIdentified = new Vector<Integer>(trainingData.size());
        while (classEnumeration.hasMoreElements()) {
            classesIdentified.add(classEnumeration.nextElement());
        }

        double bestFitness = Integer.MAX_VALUE;

        // Iterate through each class, and for each one
        // calculate its fitness as a classifier for that particular class.
        // Assign the individual whatever fitness is the lowest.

        if (isUnique && !cantDiscriminate) {

            for (int i = 0; i < classesIdentified.size(); i++) {
                double TP = 0;
                double FP = 0;
                double totalTP;

                int currentClassID = classesIdentified.elementAt(i);

                // Get some data about this class
                TrainingClass c = getTrainingClass(currentClassID);

                // If the class has already been solved we don't need a classifier
                // to solve it again.
                if (c.isFullySolved()) {
                    continue;
                }

                // The total TP is the unsolved instances that this classifier needs
                totalTP = c.getUnsolvedCount();

                // Now iterate through every instance for which this individual returned TRUE (we are finding TP/FP only)
                for (int j = 0; j < indexesReturnedTrue.size(); j++) {

                    // the training dataID that was returned true
                    int trainingDataID = indexesReturnedTrue.elementAt(j);

                    // get the associated classID of this training data
                    int dataClassID = trainingData.elementAt(trainingDataID).classID;

                    // if this is an instance of the current class
                    if (dataClassID == currentClassID) {

                        // this data is an instance of the current class

                        // if it has NOT been solved yet, award a TP
                        if (!c.isSolved(trainingDataID)) {
                            TP++;
                        }

                    } else {

                        // this data is NOT an instance of the current class

                        // if this class has NOT been fully solved, this is a false positive
                        if (!getTrainingClass(dataClassID).fullySolved) {
                            FP++;
                        }

                    }


                }

                // make sure TP is at least one, otherwise the fitness will be zero which is the ideal fitness.
                if (TP > 0) {

                    // calculate the fitness
                    // NOTE(review): a larger TP yields a LARGER value here, yet lower Koza
                    // fitness is better - confirm this ratio is intentionally oriented this way.
                    double fitness = TP / (totalTP + FP);

                    // keep the lowest (best) fitness
                    if (fitness < bestFitness) {
                        bestFitness = fitness;
                    }

                }

                // here's the nice part. If this classifier works a little bit, then add it to the classifier set.
                // We define that as having at least one TP and no FPs.
                if (FP == 0 && TP > 0) {
                    addClassifier(TP, ind, invert, resultString, c, indexesReturnedTrue);
                    classifierWasAdded = true;
                    break;
                }

            }

        }

        ind.setKozaFitness(bestFitness);
        ind.setHits(totalInstancesSolved);

        // once every class has a classifier we can stop the evolution early
        if (classesFullySolved.size() == getClassCount()) {
            e.stopFlag = true;
        }

        if (classifierWasAdded && gi != null) {
            gi.onGoodIndividual(ind);
        }

    }

    /**
     * Adds a successful individual to the classifier list, records which training
     * instances it solves, and prints it as a java method.
     *
     * @param TP             the number of (previously unsolved) true positives it scored
     * @param ind            the individual to retain
     * @param invert         true if the classifier identifies its class by returning FALSE
     * @param results        the individual's encoded results over the whole training set
     * @param c              the training class this classifier (partly) solves
     * @param solvedIndexes  indexes into trainingData that this classifier returned positive for
     */
    public void addClassifier(double TP, Individual ind, boolean invert, String results, TrainingClass c, Vector<Integer> solvedIndexes) {

        e.requestFreshPopulation();

        // definitely add this classifier - its a gem
        idCounter++;
        ADFNode node = new ADFNode(idCounter, ind.getTree(), NodeParams.BOOLEAN);
        ADFNodeParams n = node.createNodeParamsObject();

        // save this classifier
        retainedClassifiers.add(n);
        totalClassifiersAdded++;

        // remember this classifier's results
        savedClassifierResults.add(results);

        // register which instances are solved by this classifier
        for (int j = 0; j < solvedIndexes.size(); j++) {
            Integer trainingDataIndex = solvedIndexes.elementAt(j);
            c.registerInstanceSolved(trainingDataIndex);
        }

        JasmineClass jc = project.getShapeClass(c.classID);

        if (c.fullySolved) {
            // all instances of this class have been solved, so mark the class as completely solved.
            classesFullySolved.add(c.classID);
            buffer.append("// Class " + jc.name + " is fully solved. (" + (getClassCount() - classesFullySolved.size()) + " remaining.)");
        }

        // Makes the code easier to understand / debug
        String comment = "Returns " + (invert ? "false" : "true") + " for classes: ";

        // And state what the objective of the classifier is
        if (c.fullySolved) {
            comment += ", fully identifies: " + jc.name;
        } else {
            comment += ", partly identifies " + TP + " / " + c.getTotalInstances() + " instances of " + jc.name + "  ( now " + c.getPercentageSolved() + "% solved)";
        }

        comment += "\n * " + results;

        // print out the individual, as a java method
        buffer.append(JavaWriter.toJava(node, comment));

        // record the call to this classifier in the generated classify() method
        if (!invert) {
            executionOrderBuffer.append("if (method" + idCounter + "()) return " + c.classID + "; // " + jc.name + "\n");
        } else {
            executionOrderBuffer.append("if (!method" + idCounter + "()) return " + c.classID + "; // " + jc.name + "\n");
        }

        calculateTotalInstancesSolved();

    }

    int totalInstancesSolved = 0;
    int totalInstances = 0;

    /**
     * Recomputes the running totals of solved instances and total instances
     * across all training classes.
     */
    public void calculateTotalInstancesSolved() {
        totalInstances = 0;
        totalInstancesSolved = 0;
        for (int i = 0; i < trainingClasses.size(); i++) {
            TrainingClass c = trainingClasses.elementAt(i);
            totalInstancesSolved += c.totalSolved;
            totalInstances += c.totalInstances;
        }
    }

    /**
     * Called when the evolution ends: writes out the generated classify() method,
     * the solved-instance statistics and renames the output file to a .java file.
     */
    public void onFinish(Individual bestOfGeneration, Evolve e) {

        buffer.append("    SegmentedShape shape;\n    public int classify(SegmentedShape shape) {\n" +
                "        this.shape = shape;");
        buffer.append("if (new LetterDetector().classify(shape) == LetterDetector.NOT_LETTER) return -1;");
        buffer.append(executionOrderBuffer.toString());
        buffer.append("\treturn -1;\n");
        buffer.append("\t}");

        calculateTotalInstancesSolved();

        buffer.append("// TOTAL INSTANCES SOLVED: " + totalInstancesSolved);
        buffer.append("// of " + totalInstances);

        buffer.append("}");

        if (writeToFile)
            outputFile.renameTo(new File("/home/ooechs/ecj-imaging/src/ac/essex/ooechs/imaging/commons/apps/jasmine/results/" + className + ".java"));

    }

    private int totalIndividuals;
    private int totalClassifiersAdded;
    private int totalNonUniqueClassifiers;
    private int totalNonDiscriminatingClassifiers;

    /**
     * Prints end-of-generation statistics into the buffer and resets the counters.
     * Always returns null.
     */
    public Object describe(Individual ind, DataStack data, int index) {
        buffer.append(" // End of generation");
        buffer.append(" // Total individuals: " + totalIndividuals);
        buffer.append(" // Total classifiers added: " + totalClassifiersAdded);
        buffer.append(" // Total non-unique classifiers discarded: " + totalNonUniqueClassifiers);
        buffer.append(" // Total non-discriminating classifiers discarded: " + totalNonDiscriminatingClassifiers);
        totalIndividuals = 0;
        totalClassifiersAdded = 0;
        totalNonUniqueClassifiers = 0;
        totalNonDiscriminatingClassifiers = 0;
        return null;
    }


}
