package ac.ooechs.classify;

import ac.ooechs.classify.data.io.DataReader;
import ac.ooechs.classify.data.io.CSVDataReader;
import ac.ooechs.classify.data.Data;
import ac.ooechs.classify.data.DataStatistics;
import ac.ooechs.classify.data.DataNormaliser;
import ac.ooechs.classify.classifier.knn.NearestNeighbourClassifier;
import ac.ooechs.classify.evaluation.ClassifierTestResults;
import ac.ooechs.classify.evaluation.ClassifierTester;
import ac.ooechs.classify.evaluation.ClassConfusion;
import ac.ooechs.classify.classifier.gp.GPClassificationProblem;
import ac.ooechs.classify.classifier.gp.ProblemSettings;
import ac.ooechs.classify.classifier.gp.BoostedDetectionProblem;
import ac.ooechs.classify.classifier.gp.GPOneClassClassificationProblem;
import ac.ooechs.classify.classifier.GPClassifier;
import ac.ooechs.classify.classifier.MulticlassGPClassifier;
import ac.ooechs.classify.classifier.Classifier;
import ac.ooechs.classify.classifier.AdaboostClassifier;
import ac.essex.gp.Evolve;
import ac.essex.gp.individuals.Individual;
import ac.essex.gp.interfaces.console.ConsoleListener;

import java.util.Vector;
import java.util.Collections;
import java.io.IOException;
import java.io.File;

/**
 * Experimental driver that builds a classifier for a data set by combining a
 * nearest-neighbour baseline with genetically-programmed (GP) binary classifiers.
 * <p>
 * The current {@link #generateClassifier(DataReader, DataReader)} pipeline only evolves a
 * classifier for a single, hard-coded class ({@link #EXPERIMENT_CLASS_ID}); assembling
 * per-class GP classifiers into a multi-class classifier is not implemented yet.
 *
 * @author Olly Oechsle, University of Essex, Date: 01-Aug-2008
 * @version 1.0
 */
public class SuperClassifier2 {

    /**
     * Runtime budget passed to {@link ProblemSettings} for each GP run.
     * NOTE(review): the unit is whatever ProblemSettings expects (presumably seconds) — confirm.
     */
    int GP_RUNTIME = 150;

    /** Number of independent GP runs (each with a different seed); the fittest individual is kept. */
    int GP_RUNS = 1;

    /**
     * The only class ID that {@link #generateClassifier} currently evolves a GP classifier for.
     * NOTE(review): this looks like an experiment stub; the intended pipeline evolves one
     * classifier per class, ordered by how easily nearest-neighbour separates them.
     */
    private static final int EXPERIMENT_CLASS_ID = 4;

    protected Vector<Data> trainingData, testingData;

    protected DataStatistics trainingStatistics, testingStatistics;

    protected DataNormaliser normaliser;

    /**
     * Loads the training and testing data and computes class statistics. Then benchmarks a
     * nearest-neighbour classifier on every class and normalises both data sets. Finally, evolves
     * a GP classifier for {@link #EXPERIMENT_CLASS_ID} and prints its confusion matrix on the
     * training data.
     *
     * @param training reader supplying the training data
     * @param testing  reader supplying the testing data
     * @throws IOException if either reader fails to load its data
     */
    public void generateClassifier(DataReader training, DataReader testing) throws IOException {

        DataStatistics.classIDMapping.reset();

        // Load the training and testing data.
        trainingData = training.getData();
        testingData = testing.getData();

        // Establish the number of classes and class frequencies.
        trainingStatistics = new DataStatistics(trainingData);
        testingStatistics = new DataStatistics(testingData);

        // TODO: search for additional features in the data using Principal Components Analysis.

        // Benchmark each class (one-vs-rest) with a nearest-neighbour classifier on the raw data.
        NearestNeighbourClassifier knn = new NearestNeighbourClassifier(trainingData);
        Vector<ClassifierTestResults> testResults = new Vector<ClassifierTestResults>(trainingStatistics.getClassCount());
        for (int classID : trainingStatistics.getClassIDs()) {
            testResults.add(ClassifierTester.testBinarySingleClass(knn, testingData, classID));
        }

        // Normalise after KNN, because normalising makes KNN slightly worse in some cases.
        normaliser = new DataNormaliser(trainingData);
        normaliser.normalise(trainingData);
        normaliser.normalise(testingData);

        // Order the classes by how easily KNN solved them; this ordering is meant to decide
        // which classes GP learns first once the multi-class pipeline is implemented.
        Collections.sort(testResults);

        Classifier classifier = learnInSingleRun(EXPERIMENT_CLASS_ID);
        ClassifierTestResults r = ClassifierTester.testBinary(trainingStatistics.getClassIDs(), classifier, trainingData, EXPERIMENT_CLASS_ID);
        r.printConfusionMatrix();

        // TODO: for every class in testResults except the hardest, evolve a GP classifier and add it
        // to a MulticlassGPClassifier; make the hardest class the default class, report training and
        // testing accuracy via test(...), and save the resulting classifier.
    }

    /**
     * Measures the accuracy of a classifier on a data set.
     *
     * @param c    the classifier to evaluate
     * @param data the labelled examples to classify
     * @return the fraction of examples whose predicted class equals {@code classID};
     *         0 when {@code data} is empty (avoids the NaN produced by 0/0)
     */
    public float test(Classifier c, Vector<Data> data) {
        if (data.isEmpty()) {
            return 0f;
        }
        int correct = 0;
        for (Data d : data) {
            if (c.classify(d) == d.classID) {
                correct++;
            }
        }
        return correct / (float) data.size();
    }

    /**
     * Evolves {@link #GP_RUNS} one-class GP classifiers for {@code classID}, seeded 2357, 2358, and
     * so on. It keeps the individual with the lowest Koza fitness and prints training/testing
     * results for every run.
     * <p>
     * The kept classifier's hit rate is its hits divided by the effective (weighted) training size.
     * If that rate exceeds 0.9 and the estimated error falls below 2% of the training set, the
     * class is removed from future training via {@link #removeClasses(int)}. If the rate is 0.9 or
     * lower, the classifier is passed to {@link #learnInSeveralRuns(GPClassifier, int)}.
     *
     * @param classID the class to learn
     * @return the best classifier found
     * @throws IllegalStateException if {@link #GP_RUNS} is not positive (no run would produce a result)
     */
    public GPClassifier learnInSingleRun(int classID) {
        if (GP_RUNS <= 0) {
            throw new IllegalStateException("GP_RUNS must be positive, was " + GP_RUNS);
        }
        Individual best = null;
        for (int run = 0; run < GP_RUNS; run++) {
            System.out.println("RUN: " + run);
            int seed = 2357 + run;
            ProblemSettings settings = new ProblemSettings(GP_RUNTIME, seed, 7);
            GPOneClassClassificationProblem p = new GPOneClassClassificationProblem(classID, settings, trainingData);
            ConsoleListener listener = new ConsoleListener(ConsoleListener.LOW_VERBOSITY);
            Evolve e = new Evolve(p, listener);
            e.run();
            Individual ind = e.getBestIndividual();
            // Keep the individual with the lowest Koza fitness seen so far.
            if (best == null || ind.getKozaFitness() < best.getKozaFitness()) {
                best = ind;
            }

            GPClassifier runClassifier = new GPClassifier(classID, ind);
            System.out.println("GP Training Results: " + ClassifierTester.testBinarySingleClass(runClassifier, trainingData, classID));
            System.out.println("GP Testing Results: " + ClassifierTester.testBinarySingleClass(runClassifier, testingData, classID));
        }
        GPClassifier c = new GPClassifier(classID, best);
        // Check that this classifier is any good. Guard against a zero effective size (every
        // example's weight zeroed), which would otherwise yield Infinity/NaN and could wrongly
        // pass the threshold below.
        float effectiveSize = getEffectiveTrainingSize();
        float result = effectiveSize > 0 ? c.ind.getHits() / effectiveSize : 0f;
        // NOTE(review): heuristic estimate of potential mistakes — the overall miss rate scaled
        // by this class's frequency in the training data.
        float potentialFP = (1 - result) * trainingStatistics.getClassCount(classID);
        float twoPercent = trainingData.size() * 0.02f;
        if (result > 0.9) {
            if (potentialFP < twoPercent) {
                // Good enough that this class can be discarded from the training data in future.
                removeClasses(classID);
            }
            return c;
        }
        // The classifier isn't up to scratch - try to improve it.
        return learnInSeveralRuns(c, classID);
    }

    /**
     * Learns {@code classID} by boosting a GP classification problem for 10 rounds
     * (seed 2357), then prints the training and testing results of the boosted classifier.
     *
     * @param classID the class to learn
     * @return an {@link AdaboostClassifier} wrapping the best boosted solution
     */
    public Classifier learnInSingleRunAdaBoost(int classID) {

        int seed = 2357;

        ProblemSettings settings = new ProblemSettings(GP_RUNTIME, seed, 7);
        GPClassificationProblem p = new GPClassificationProblem(classID, -1, settings, trainingData);

        BoostedDetectionProblem bp = new BoostedDetectionProblem(p);

        bp.boost(10);

        Classifier classifier = new AdaboostClassifier(bp.getBestSolution());
        System.out.println("GP Training Results: " + ClassifierTester.testBinarySingleClass(classifier, trainingData, classID));
        System.out.println("GP Testing Results: " + ClassifierTester.testBinarySingleClass(classifier, testingData, classID));

        return classifier;
    }

    /**
     * Sets the weight of every training example of {@code classID} to 0, so those examples
     * no longer count towards {@link #getEffectiveTrainingSize()}.
     * NOTE(review): whether the GP problems also ignore zero-weight examples is defined elsewhere — confirm.
     *
     * @param classID the class whose training examples are discarded
     */
    public void removeClasses(int classID) {
        for (Data data : trainingData) {
            if (data.classID == classID) {
                data.weight = 0;
            }
        }
    }

    /**
     * @return the sum of the weights of all training examples
     */
    public float getEffectiveTrainingSize() {
        float total = 0;
        for (Data data : trainingData) {
            total += data.weight;
        }
        return total;
    }

    /**
     * Placeholder for improving a weak classifier by further subdividing the problem.
     * It tests {@code c} on the training data, sorts the resulting class confusions, prints
     * them, and returns {@code c} unchanged.
     *
     * @param c       the under-performing classifier
     * @param classID the class {@code c} was trained to recognise
     * @return {@code c}, unchanged
     */
    public GPClassifier learnInSeveralRuns(GPClassifier c, int classID) {
        // Subject the classifier to a more rigorous test to find which classes it confuses most.
        ClassifierTestResults results = ClassifierTester.testBinarySingleClass(c, trainingData, classID);
        Collections.sort(results.getClassConfusions());
        for (ClassConfusion classConfusion : results.getClassConfusions()) {
            System.out.println(classConfusion);
        }
        return c;
    }

    /**
     * Runs the classifier-generation pipeline.
     *
     * @param args optional: {@code args[0]} = training file, {@code args[1]} = testing file.
     *             If fewer than two arguments are given, the original hard-coded SAT data paths are used.
     * @throws IOException if the data cannot be read
     */
    public static void main(String[] args) throws IOException {
        boolean pathsGiven = args.length >= 2;
        String trainingPath = pathsGiven ? args[0] : "/home/ooechs/Desktop/jasmine-data/sat-training.ssv";
        String testingPath = pathsGiven ? args[1] : "/home/ooechs/Desktop/jasmine-data/sat-test.ssv";
        CSVDataReader training = new CSVDataReader(new File(trainingPath));
        CSVDataReader testing = new CSVDataReader(new File(testingPath));
        SuperClassifier2 s = new SuperClassifier2();
        s.generateClassifier(training, testing);
    }

}