package ac.essex.ooechs.problems;

import ac.essex.ooechs.kmeans.ClusterClass;
import ac.essex.ooechs.kmeans.KMeansClusterer;
import ac.essex.ooechs.kmeans.DataPoint;
import ac.essex.ooechs.kmeans.KMeansSolution;
import ac.essex.ooechs.imaging.commons.StatisticsSolver;

import java.io.File;
import java.util.Vector;
import java.text.DecimalFormat;

import ac.essex.gp.nodes.generic.CSVFeature;

/**
 * Measures how well a k-means clustering solution classifies the data held by
 * {@link CSVFeature}.
 * <p>
 * Constructing an instance feeds every training sample to a {@link KMeansClusterer}
 * configured with the requested {@code k}, runs it, and then checks the resulting
 * {@link KMeansSolution} against the training data and the test data. The percentage
 * of test samples whose predicted class equals the recorded class is stored in
 * {@link #testResult}. {@link #main(String[])} repeats this for a range of {@code k}
 * values and prints the mean and maximum test percentage for each.
 * </p>
 * <p>
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version,
 * provided that any use properly credits the author.
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details at http://www.gnu.org
 * </p>
 *
 * @author Olly Oechsle, University of Essex, Date: 25-Oct-2007
 * @version 1.0
 */
public class GenericClassificationProblem {

    /** Training file used when no path is given on the command line. */
    private static final String DEFAULT_TRAINING_PATH = "/home/ooechs/Desktop/sat-training.ssv";

    /** Test file used when no path is given on the command line. */
    private static final String DEFAULT_TESTING_PATH = "/home/ooechs/Desktop/sat-test.ssv";

    /** Exclusive upper bound of the k values tried by {@link #main(String[])}. */
    private static final int MAX_K = 30;

    /** Number of independent clustering runs performed for each k. */
    private static final int RUNS_PER_K = 10;

    /**
     * Percentage (0-100) of test samples classified correctly by the most recent
     * {@link #test(KMeansSolution, int)} call made with {@code CSVFeature.TESTING}.
     */
    double testResult = 0;

    /**
     * Loads the training and test files into {@link CSVFeature}, then for every k from
     * the number of distinct classes up to (but excluding) {@value #MAX_K} performs
     * {@value #RUNS_PER_K} clustering runs and prints the mean and maximum test percentage.
     *
     * @param args optional: {@code args[0]} is the training file path and {@code args[1]}
     *             the test file path; either falls back to the built-in default when absent
     * @throws Exception propagated from data loading or from a clustering run
     */
    public static void main(String[] args) throws Exception {

        // Defaults are the satellite data set; other data sets (e.g. the colour
        // segmentation pixel features of 25th Oct 2007) can be passed as arguments.
        File training = new File(args.length > 0 ? args[0] : DEFAULT_TRAINING_PATH);
        File testing = new File(args.length > 1 ? args[1] : DEFAULT_TESTING_PATH);

        CSVFeature.loadTrainingData(training);
        CSVFeature.loadTestData(testing);

        final int classCount = CSVFeature.getDistinctClasses().size();
        System.out.println("Class count: " + classCount);

        // One formatter is enough for the whole sweep.
        final DecimalFormat f = new DecimalFormat("0.0");

        // Start at one cluster per class; fewer clusters than classes is not tried.
        for (int k = classCount; k < MAX_K; k++) {

            StatisticsSolver s = new StatisticsSolver();

            // Repeat the run several times and aggregate the test percentages.
            for (int i = 0; i < RUNS_PER_K; i++) {
                GenericClassificationProblem gcp = new GenericClassificationProblem(k, training, testing);
                s.addData(gcp.testResult);
            }

            System.out.println("K = " + k + ", mean=" + f.format(s.getMean()) + ", max=" + f.format(s.getMax()));

        }

    }

    /**
     * Clusters the training data currently held by {@link CSVFeature} into {@code k}
     * clusters and evaluates the solution; the test percentage ends up in {@link #testResult}.
     * <p>
     * The data must already have been loaded into {@link CSVFeature}; the two file
     * arguments are not read here.
     * </p>
     *
     * @param k        value passed to the {@link KMeansClusterer} constructor
     * @param training not used by this constructor; retained for interface compatibility
     * @param testing  not used by this constructor; retained for interface compatibility
     * @throws Exception propagated from the clusterer
     */
    public GenericClassificationProblem(int k, File training, File testing) throws Exception {

        // NOTE(review): this list is never handed to the clusterer or read again. It is
        // kept only in case constructing ClusterClass has side effects - confirm and remove.
        Vector<ClusterClass> classes = new Vector<ClusterClass>();
        for (int i = 0; i < CSVFeature.getDistinctClasses().size(); i++) {
            Integer classID = CSVFeature.getDistinctClasses().elementAt(i);
            classes.add(new ClusterClass(classID, "class " + i));
        }

        // create a k-means clusterer
        KMeansClusterer clusterer = new KMeansClusterer(k);
        clusterer.verbose = false;

        // add every training sample, together with its class ID, to the clusterer
        final int trainingSize = CSVFeature.getTrainingDataSize();
        for (int i = 0; i < trainingSize; i++) {
            clusterer.add(new DataPoint(CSVFeature.getData(CSVFeature.TRAINING, i), CSVFeature.getTrainingClassID(i)));
        }

        // run the clusterer
        clusterer.run();

        // get the solution
        KMeansSolution s = clusterer.getSolution();

        // Evaluate on both sets. Only the TESTING pass stores its result; the TRAINING
        // pass is kept because test() is overridable and a subclass may report it.
        test(s, CSVFeature.TRAINING);
        test(s, CSVFeature.TESTING);

    }

    /**
     * Counts how many samples of the selected data set the solution assigns to their
     * recorded class and, for {@code CSVFeature.TESTING}, stores the hit rate as a
     * percentage in {@link #testResult}.
     * <p>
     * An empty data set is reported as 0% rather than the NaN that 0/0 would produce,
     * so that the value can safely be aggregated by the caller.
     * </p>
     *
     * @param s    the clustering solution to evaluate
     * @param mode {@code CSVFeature.TRAINING} to evaluate the training data; any other
     *             value evaluates the test data
     */
    protected void test(KMeansSolution s, int mode) {

        int size = mode == CSVFeature.TRAINING ? CSVFeature.getTrainingDataSize() : CSVFeature.getTestDataSize();

        // number of samples whose predicted class matches the recorded class
        int TP = 0;

        for (int i = 0; i < size; i++) {

            int output = s.test(CSVFeature.getData(mode, i));

            int classID = mode == CSVFeature.TRAINING ? CSVFeature.getTrainingClassID(i) : CSVFeature.getTestClassID(i);

            if (output == classID) {
                TP++;
            }

        }

        // Guard the division: with no samples the ratio would be NaN.
        double percentage = size == 0 ? 0 : (TP / (double) size) * 100;

        if (mode == CSVFeature.TESTING) {
            testResult = percentage;
        }

    }

}
