package ac.essex.ooechs.problems;

import ac.essex.gp.nodes.generic.CSVFeature;
import ac.essex.ooechs.imaging.commons.StatisticsSolver;
import ac.essex.ooechs.kmeans.ClusterClass;
import ac.essex.ooechs.kmeans.DataPoint;
import ac.essex.ooechs.kmeans.KMeansClusterer;
import ac.essex.ooechs.kmeans.KMeansSolution;
import java.io.File;
import java.text.DecimalFormat;
import java.util.Vector;

/* loaded from: input_file:ac/essex/ooechs/problems/GenericClassificationProblem.class */
/**
 * Evaluates k-means clustering as a classifier on data loaded through {@link CSVFeature}.
 *
 * <p>Each instance clusters the (statically loaded) training data with a given K and then
 * records the percentage of correctly classified points on the training set
 * ({@link #trainingResult}) and on the test set ({@link #testResult}).
 */
public class GenericClassificationProblem {

    /** Data-set selector used with {@code CSVFeature.getData} for the training set. */
    private static final int TRAINING_SET = 1;

    /** Data-set selector used with {@code CSVFeature.getData} for the test set. */
    private static final int TEST_SET = 2;

    /** Exclusive upper bound on the K values tried by {@link #main}. */
    private static final int MAX_K = 30;

    /** Number of independent clustering runs averaged for each K in {@link #main}. */
    private static final int RUNS_PER_K = 10;

    private static final String DEFAULT_TRAINING_PATH = "/home/ooechs/Desktop/sat-training.ssv";
    private static final String DEFAULT_TEST_PATH = "/home/ooechs/Desktop/sat-test.ssv";

    /** Percentage (0-100) of test-set points classified correctly; set by the constructor. */
    double testResult = 0.0d;

    /** Percentage (0-100) of training-set points classified correctly; set by the constructor. */
    double trainingResult = 0.0d;

    /**
     * Loads the training and test files, then for every K from the number of distinct classes
     * up to (but excluding) {@value #MAX_K} runs {@value #RUNS_PER_K} clusterings and prints the
     * mean and maximum test-set accuracy.
     *
     * @param strArr optional overrides: {@code [trainingFile [testFile]]}; the original
     *               hard-coded paths are used for any argument that is not supplied
     * @throws Exception if loading the data or clustering fails
     */
    public static void main(String[] strArr) throws Exception {
        boolean hasArgs = strArr != null;
        File trainingFile = new File(hasArgs && strArr.length > 0 ? strArr[0] : DEFAULT_TRAINING_PATH);
        File testFile = new File(hasArgs && strArr.length > 1 ? strArr[1] : DEFAULT_TEST_PATH);
        CSVFeature.loadTrainingData(trainingFile);
        CSVFeature.loadTestData(testFile);

        int classCount = CSVFeature.getDistinctClasses().size();
        System.out.println("Class count: " + classCount);

        // Formatter is loop-invariant; build it once.
        DecimalFormat decimalFormat = new DecimalFormat("0.0");
        for (int k = classCount; k < MAX_K; k++) {
            StatisticsSolver statisticsSolver = new StatisticsSolver();
            for (int run = 0; run < RUNS_PER_K; run++) {
                statisticsSolver.addData(new GenericClassificationProblem(k, trainingFile, testFile).testResult);
            }
            System.out.println("K = " + k + ", mean=" + decimalFormat.format(statisticsSolver.getMean()) + ", max=" + decimalFormat.format(statisticsSolver.getMax()));
        }
    }

    /**
     * Clusters the training data into {@code i} clusters and scores the resulting solution on
     * both the training and the test set, storing the results in {@link #trainingResult} and
     * {@link #testResult}.
     *
     * @param i     number of clusters (K) passed to {@link KMeansClusterer}
     * @param file  training file; NOTE(review): unused, data comes from CSVFeature's static state
     * @param file2 test file; NOTE(review): unused, data comes from CSVFeature's static state
     * @throws Exception if clustering fails
     */
    public GenericClassificationProblem(int i, File file, File file2) throws Exception {
        // NOTE(review): this list is built but never read. It is kept only in case the
        // ClusterClass constructor has side effects; confirm and remove if it does not.
        Vector<ClusterClass> vector = new Vector<ClusterClass>();
        for (int classIndex = 0; classIndex < CSVFeature.getDistinctClasses().size(); classIndex++) {
            vector.add(new ClusterClass(((Integer) CSVFeature.getDistinctClasses().elementAt(classIndex)).intValue(), "class " + classIndex));
        }

        KMeansClusterer kMeansClusterer = new KMeansClusterer(i);
        kMeansClusterer.verbose = false;
        for (int row = 0; row < CSVFeature.getTrainingDataSize(); row++) {
            kMeansClusterer.add(new DataPoint(CSVFeature.getData(TRAINING_SET, row), CSVFeature.getTrainingClassID(row)));
        }
        kMeansClusterer.run();

        KMeansSolution solution = kMeansClusterer.getSolution();
        test(solution, TRAINING_SET);
        test(solution, TEST_SET);
    }

    /**
     * Computes the percentage of points in one data set that {@code kMeansSolution} assigns to
     * their labelled class.
     *
     * <p>The result is stored in {@link #trainingResult} when {@code i} selects the training
     * set, and in {@link #testResult} when it selects the test set. An empty data set yields 0.
     *
     * @param kMeansSolution the clustering solution to evaluate
     * @param i              {@code 1} for the training set; any other value reads the test set
     *                       (only {@code 2} stores into {@link #testResult})
     */
    protected void test(KMeansSolution kMeansSolution, int i) {
        boolean training = i == TRAINING_SET;
        int size = training ? CSVFeature.getTrainingDataSize() : CSVFeature.getTestDataSize();

        int correct = 0;
        for (int row = 0; row < size; row++) {
            if (kMeansSolution.test(CSVFeature.getData(i, row)) == (training ? CSVFeature.getTrainingClassID(row) : CSVFeature.getTestClassID(row))) {
                correct++;
            }
        }

        // Floating-point division: the previous int/int form truncated every result to 0 or 100
        // and threw ArithmeticException on an empty data set.
        double accuracy = size == 0 ? 0.0d : 100.0d * correct / size;

        if (i == TEST_SET) {
            this.testResult = accuracy;
        } else if (i == TRAINING_SET) {
            this.trainingResult = accuracy;
        }
    }
}
