package ac.essex.ooechs.knn;

import ac.essex.gp.nodes.generic.CSVFeature;
import ac.essex.ooechs.kmeans.ClusterClass;
import ac.essex.ooechs.kmeans.KMeansSolution;

import java.io.File;
import java.text.DecimalFormat;

/**
 * Small command-line harness that feeds a CSV training set into a
 * {@link KNearestNeighbour} classifier and prints, for each class ID,
 * how many samples of that class were classified correctly.
 * <p>
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version,
 * provided that any use properly credits the author.
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details at http://www.gnu.org
 * </p>
 *
 * @author Olly Oechsle, University of Essex, Date: 25-Oct-2007
 * @version 1.0
 */
public class KNNTester {

    /** Data set used for both training and testing when no path is given on the command line. */
    private static final String DEFAULT_DATA_FILE = "/home/ooechs/Desktop/toy-classification.csv";

    /** First class ID (inclusive) for which a result row is printed. */
    private static final int FIRST_CLASS_ID = 1;

    /** Last class ID (inclusive) for which a result row is printed. */
    private static final int LAST_CLASS_ID = 7;

    /**
     * Entry point.
     *
     * @param args optional: {@code args[0]} is the training file and {@code args[1]} the
     *             testing file. With no arguments the built-in default file is used for
     *             both; with one argument that file is used for both.
     * @throws Exception propagated from the {@link #KNNTester(File, File)} constructor
     */
    public static void main(String[] args) throws Exception {

        File training = new File(args.length > 0 ? args[0] : DEFAULT_DATA_FILE);
        File testing = args.length > 1 ? new File(args[1]) : training;

        // Argument order must match the constructor: (training, testing).
        new KNNTester(training, testing);

    }

    /**
     * Loads both data sets through {@link CSVFeature}, adds every training sample to a
     * {@link KNearestNeighbour} instance and prints one result row per class ID in
     * {@code FIRST_CLASS_ID..LAST_CLASS_ID}, evaluated against the testing data.
     * <p>
     * NOTE(review): this constructor calls the overridable {@link #test} method; a subclass
     * override would run before the subclass constructor body has executed.
     *
     * @param training file handed to {@code CSVFeature.loadTrainingData}
     * @param testing  file handed to {@code CSVFeature.loadTestData}
     * @throws Exception if loading either file fails
     */
    public KNNTester(File training, File testing) throws Exception {

        CSVFeature.loadTrainingData(training);

        CSVFeature.loadTestData(testing);

        // create the nearest-neighbour classifier
        // (the constructor argument is presumably k, the neighbour count -- confirm against KNearestNeighbour)
        KNearestNeighbour knn = new KNearestNeighbour(2);

        // add every training sample, with its class label, to the classifier
        for (int i = 0; i < CSVFeature.getTrainingDataSize(); i++) {
            double[] features = CSVFeature.getData(CSVFeature.TRAINING, i);
            int classID = CSVFeature.getTrainingClassID(i);
            if (i == 0) {
                // Echo the first sample as a quick sanity check that the file was parsed.
                System.out.println(java.util.Arrays.toString(features) + ", " + classID);
            }
            knn.add(features, classID);
        }

        System.err.println("Training on " + CSVFeature.getTrainingDataSize() + " samples(s).");

        // results: one "classID, correct / total (percentage%)," row per class
        for (int classID = FIRST_CLASS_ID; classID <= LAST_CLASS_ID; classID++) {
            System.out.print(classID + ", ");
            System.out.println(test(knn, classID, CSVFeature.TESTING));
        }

    }

    /**
     * Classifies every sample of the selected data set whose expected class equals
     * {@code classID} and summarises how many were classified correctly.
     *
     * @param knn     classifier to evaluate
     * @param classID only samples whose expected class equals this ID are evaluated
     * @param mode    {@code CSVFeature.TRAINING} to evaluate the training data; any other
     *                value evaluates the test data
     * @return a string of the form {@code "TP / N (P%),"} where {@code N} is the number of
     *         samples of the class, {@code TP} the number classified correctly and
     *         {@code P} the percentage with three decimals (0.000 when {@code N} is zero)
     */
    protected String test(KNearestNeighbour knn, int classID, int mode) {

        boolean useTraining = mode == CSVFeature.TRAINING;
        int size = useTraining ? CSVFeature.getTrainingDataSize() : CSVFeature.getTestDataSize();

        int total = 0;
        int correct = 0;

        for (int i = 0; i < size; i++) {

            int expected = useTraining ? CSVFeature.getTrainingClassID(i) : CSVFeature.getTestClassID(i);

            if (expected != classID) {
                continue;
            }

            total++;

            int output = knn.classify(CSVFeature.getData(mode, i));

            if (output == expected) {
                correct++;
            }

        }

        // A class with no samples would otherwise divide by zero and report NaN.
        double percentage = total == 0 ? 0.0 : (correct / (double) total) * 100;
        DecimalFormat f = new DecimalFormat("0.000");
        return correct + " / " + total + " (" + f.format(percentage) + "%),";

    }


}

