package ac.ooechs.classify;

import ac.essex.gp.cluster.GPClient;
import ac.essex.gp.cluster.GPClientListener;
import ac.essex.gp.individuals.Individual;
import ac.essex.ooechs.imaging.commons.fast.FastStatistics;
import ac.ooechs.classify.classifier.Classifier;
import ac.ooechs.classify.classifier.GPClassifier;
import ac.ooechs.classify.classifier.MulticlassGPClassifier;
import ac.ooechs.classify.classifier.gp.FNRefinementProblem;
import ac.ooechs.classify.classifier.gp.GPClassificationProblem;
import ac.ooechs.classify.classifier.gp.ProblemSettings;
import ac.ooechs.classify.classifier.knn.NearestNeighbourClassifier;
import ac.ooechs.classify.data.Data;
import ac.ooechs.classify.data.DataNormaliser;
import ac.ooechs.classify.data.DataStatistics;
import ac.ooechs.classify.data.io.DataReader;
import ac.ooechs.classify.evaluation.ClassifierTestResults;
import ac.ooechs.classify.evaluation.ClassifierTester;
import ac.ooechs.classify.experiments.Experiment18_VarianceThresholding;
import ac.ooechs.classify.ui.GPStatusBox;
import java.awt.Component;
import java.io.IOException;
import java.util.Collections;
import java.util.Hashtable;
import java.util.Vector;
import javax.swing.JOptionPane;

/* loaded from: input_file:ac/ooechs/classify/DistributedSuperClassifier.class */
/**
 * Builds a multiclass classifier by evolving one binary GP classifier per class on a set of
 * {@link GPClient}s running in parallel, then attempting a false-negative refinement pass for
 * each learned class.
 *
 * <p>The flow driven by {@link #generateClassifier()} is:
 * <ol>
 *   <li>load training and testing data;</li>
 *   <li>score a nearest-neighbour classifier per class on the testing data and sort those
 *       per-class results (the order comes from {@code ClassifierTestResults}' natural ordering);</li>
 *   <li>normalise both data sets with a normaliser built from the training data;</li>
 *   <li>for every class except the last in that order, run GP on all clients, keep the
 *       individual with the lowest Koza fitness, and try to refine it;</li>
 *   <li>the last class becomes the multiclass classifier's default class.</li>
 * </ol>
 *
 * <p>Coordination with the clients uses this object's monitor: each {@link Listener} registers
 * one active run when constructed and deregisters it when its client finishes or reports a
 * server error; the learning thread waits until the active-run count drops to zero.
 */
public class DistributedSuperClassifier extends Thread {
    public static final String VERSION = "1.1";

    /**
     * Run time handed to every GP problem through {@code ProblemSettings}. This is static and
     * therefore shared by all instances; the constructor overwrites it.
     */
    public static int gpRuntime = Experiment18_VarianceThresholding.runtime;

    /** First random seed handed to the GP clients; see {@link #clientSeed(int)}. */
    private static final int BASE_SEED = 2357;

    protected Vector<Data> trainingData;
    protected Vector<Data> testingData;
    protected DataStatistics trainingStatistics;
    protected DataStatistics testingStatistics;
    protected DataNormaliser normaliser;
    public Vector<GPClient> clients;
    protected SuperClassifierListener listener;
    /** Running sum of training samples belonging to the classes handled so far. */
    protected int maxTraining;
    /** Running sum of testing samples belonging to the classes handled so far. */
    protected int maxTesting;
    protected MulticlassGPClassifier multiclassClassifier;
    protected DataReader training;
    protected DataReader testing;
    /** Number of GP runs registered but not yet finished; guarded by this object's monitor. */
    private int activeRuns;

    /**
     * Callback attached to one {@link GPClient}. It keeps the enclosing classifier's active-run
     * count in step with the client and forwards progress to an optional {@link GPStatusBox}.
     */
    /* loaded from: input_file:ac/ooechs/classify/DistributedSuperClassifier$Listener.class */
    public class Listener implements GPClientListener {
        /** Not used: the active-run count lives in the enclosing class. Kept for compatibility. */
        protected int activeRuns = 0;
        /** Optional status display; may stay null if {@link #setProgressBar} is never called. */
        protected GPStatusBox bar = null;

        /** Registers one more active GP run with the enclosing classifier. */
        public Listener() {
            DistributedSuperClassifier.this.registerThread();
        }

        /** Attaches a status box and immediately tells it the run has started. */
        public void setProgressBar(GPStatusBox statusBox) {
            this.bar = statusBox;
            statusBox.onStart();
        }

        /** Shows the server error in a dialog and releases this run's slot. */
        public void onServerError(String message) {
            JOptionPane.showMessageDialog((Component) null, "Error in one of the server threads: " + message);
            DistributedSuperClassifier.this.deregisterThread();
        }

        /** Forwards a progress update to the status box, if one is attached. */
        public void onStatusUpdate(int i, double d, int i2, int i3) {
            if (this.bar == null) {
                return;
            }
            int runtime = DistributedSuperClassifier.gpRuntime;
            // Guard against a zero runtime (constructor argument 0), which would throw here.
            // NOTE(review): this is integer division; if GPStatusBox.update expects a fraction
            // of the run completed it will read 0 until i3 reaches gpRuntime. Confirm signature.
            int progress = runtime > 0 ? i3 / runtime : 0;
            this.bar.update(i, d, i2, i3, progress);
        }

        /** Closes the status box (if any) and releases this run's slot. */
        public void onFinish(Individual individual) {
            if (this.bar != null) {
                this.bar.onFinish();
            }
            DistributedSuperClassifier.this.deregisterThread();
        }
    }

    /**
     * @param runtime value multiplied by 60 and stored in the static {@link #gpRuntime}
     *                (presumably minutes converted to seconds — confirm with ProblemSettings)
     * @param clients the GP clients that run problems in parallel
     */
    public DistributedSuperClassifier(int runtime, Vector<GPClient> clients) {
        this.clients = clients;
        gpRuntime = runtime * 60;
    }

    /** Sets the listener that receives all status, result and error callbacks. */
    public void setListener(SuperClassifierListener superClassifierListener) {
        this.listener = superClassifierListener;
    }

    /**
     * Sets where the data comes from.
     *
     * @param trainingReader source of the training data
     * @param testingReader  source of the testing data
     */
    public void setData(DataReader trainingReader, DataReader testingReader) {
        this.training = trainingReader;
        this.testing = testingReader;
    }

    /** Thread entry point: runs {@link #generateClassifier()} and reports any failure. */
    @Override // java.lang.Thread, java.lang.Runnable
    public void run() {
        try {
            generateClassifier();
        } catch (Exception e) {
            e.printStackTrace();
            this.listener.onError(e.toString());
        }
    }

    /**
     * Runs the whole learning pipeline and hands the finished multiclass classifier to
     * {@code listener.onFinished}.
     *
     * @throws IOException           if reading the training or testing data fails
     * @throws IllegalStateException if {@link #setData} was not called with both readers
     */
    public void generateClassifier() throws IOException {
        if (this.training == null || this.testing == null) {
            this.listener.onError("Training/Testing data not set up!");
            throw new IllegalStateException("Training/Testing data not set up");
        }
        // Reset so a second call does not accumulate on top of the previous run.
        this.maxTraining = 0;
        this.maxTesting = 0;

        this.listener.onStatusUpdate("Loading the data");
        this.trainingData = this.training.getData();
        this.testingData = this.testing.getData();
        this.trainingStatistics = new DataStatistics(this.trainingData);
        this.testingStatistics = new DataStatistics(this.testingData);

        this.listener.onStatusUpdate("Running k-Nearest Neighbour");
        NearestNeighbourClassifier knn = new NearestNeighbourClassifier(this.trainingData);
        Vector<ClassifierTestResults> knnResults = new Vector<>(this.trainingStatistics.getClassCount());
        for (int c = 0; c < this.trainingStatistics.getClassIDs().size(); c++) {
            int classID = this.trainingStatistics.getClassIDs().elementAt(c).intValue();
            knnResults.add(ClassifierTester.testBinarySingleClass(knn, this.testingData, classID));
        }

        this.listener.onStatusUpdate("Normalising data");
        this.normaliser = new DataNormaliser(this.trainingData);
        this.normaliser.normalise(this.trainingData);
        this.normaliser.normalise(this.testingData);

        // The order in which classes are learned is the natural ordering of the k-NN results.
        Collections.sort(knnResults);
        this.multiclassClassifier = new MulticlassGPClassifier(this.trainingStatistics.getClassCount());
        int lastSlot = knnResults.size() - 1;
        for (int slot = 0; slot < knnResults.size(); slot++) {
            int classID = knnResults.elementAt(slot).getClassID();
            this.maxTraining += this.trainingStatistics.getClassCount(classID);
            this.maxTesting += this.testingStatistics.getClassCount(classID);
            if (slot < lastSlot) {
                GPClassifier learned = learnInSingleRun(classID, slot);
                ClassifierTestResults trainResults =
                        ClassifierTester.testBinarySingleClass(learned, this.trainingData, classID);
                ClassifierTestResults testResults =
                        ClassifierTester.testBinarySingleClass(learned, this.testingData, classID);
                this.listener.onGPRunsComplete(classID, trainResults, testResults);
            } else {
                // The final class is not learned by GP; it becomes the default class.
                this.multiclassClassifier.defaultClass = classID;
                onClassifierUpdated();
            }
        }
        this.listener.onFinished(this.multiclassClassifier);
    }

    /**
     * Prints the multiclass classifier's accuracy with refinement off and then on. It then
     * reports the refined accuracies plus the fraction of training/testing samples covered by
     * the classes handled so far. Leaves {@code GPClassifier.REFINEMENT_ON} set to true.
     */
    public void onClassifierUpdated() {
        GPClassifier.REFINEMENT_ON = false;
        System.out.println("Without Refinement: ");
        System.out.println(test(this.multiclassClassifier, this.trainingData) + ", " + test(this.multiclassClassifier, this.testingData));
        GPClassifier.REFINEMENT_ON = true;
        System.out.println("With refinement:");
        float trainAccuracy = test(this.multiclassClassifier, this.trainingData);
        float testAccuracy = test(this.multiclassClassifier, this.testingData);
        System.out.println(trainAccuracy + ", " + testAccuracy);
        // Divide in floating point: with int/int the coverage fractions truncate to 0 until all
        // classes are counted (assuming getDataCount() returns an int — confirm).
        this.listener.onClassifierUpdated(trainAccuracy, testAccuracy,
                (float) this.maxTraining / this.trainingStatistics.getDataCount(),
                (float) this.maxTesting / this.testingStatistics.getDataCount());
    }

    /**
     * Returns the fraction (0..1) of {@code data} whose {@code classID} equals the classifier's
     * prediction.
     *
     * @param classifier classifier to evaluate
     * @param data       labelled samples
     * @return accuracy in [0, 1], or 0 when {@code data} is empty
     */
    public float test(Classifier classifier, Vector<Data> data) {
        if (data.isEmpty()) {
            return 0.0f;
        }
        int correct = 0;
        for (Data sample : data) {
            if (classifier.classify(sample) == sample.classID) {
                correct++;
            }
        }
        // Floating-point division: int/int would truncate every imperfect score to 0.
        return (float) correct / data.size();
    }

    /** Records one more active GP run. */
    public synchronized void registerThread() {
        this.activeRuns++;
        System.out.println("Registered thread: " + this.activeRuns);
    }

    /** Records that a GP run ended and wakes waiting threads once none remain. */
    public synchronized void deregisterThread() {
        this.activeRuns--;
        System.out.println("Deregistered thread: " + this.activeRuns);
        if (this.activeRuns == 0) {
            System.out.println("Notifying");
            notifyAll();
        }
    }

    /**
     * Blocks until every registered run has deregistered. It loops on the count, so spurious
     * wake-ups and zero-client rounds are handled. An interrupt ends the wait early and
     * restores the thread's interrupt flag.
     */
    private synchronized void awaitActiveRuns() {
        System.out.println("Started waiting");
        try {
            while (this.activeRuns > 0) {
                wait();
            }
            System.out.println("Finished waiting");
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /** Seed for client {@code index}: 2357, 2358, 2358, 2359, 2359, ... */
    private static int clientSeed(int index) {
        return BASE_SEED + (index + 1) / 2;
    }

    /** Third ProblemSettings argument: 2 for even-indexed clients, 7 for odd ones (meaning not visible here). */
    private static int clientVariant(int index) {
        return index % 2 == 0 ? 2 : 7;
    }

    /** Attaches a fresh {@link Listener} (registering one active run) and starts the client. */
    private Listener attachListenerAndStart(GPClient client) {
        Listener clientListener = new Listener();
        client.setListener(clientListener);
        client.start();
        return clientListener;
    }

    /**
     * Evolves a binary classifier for one class on all clients in parallel. It keeps the
     * individual with the lowest Koza fitness, installs it in the multiclass classifier, and
     * attempts a refinement pass.
     *
     * @param classID class to learn
     * @param slot    position of this class in the multiclass classifier
     * @return the learned binary classifier
     * @throws IllegalStateException if no client produced a solution
     */
    public synchronized GPClassifier learnInSingleRun(int classID, int slot) {
        Vector<Listener> listeners = new Vector<>();
        for (int c = 0; c < this.clients.size(); c++) {
            GPClient client = this.clients.elementAt(c);
            ProblemSettings settings = new ProblemSettings(gpRuntime, clientSeed(c), clientVariant(c), slot % 2 == 0 ? 0 : 1);
            client.setProblem(new GPClassificationProblem(classID, -1, settings, this.trainingData));
            listeners.add(attachListenerAndStart(client));
        }
        this.listener.onStatusUpdate("Learning Class " + classID);
        this.listener.onListenersUpdated(listeners);
        awaitActiveRuns();

        Individual best = processClients(classID);
        if (best == null) {
            throw new IllegalStateException("No GP client returned a solution for class " + classID);
        }
        GPClassifier classifier = new GPClassifier(classID, best);
        this.multiclassClassifier.set(classifier, slot);
        onClassifierUpdated();
        refineClassifier(classifier, classID);
        // If the error rate scaled by this class's size is under 2% of the training set, zero the
        // weight of this class's training samples (see removeClasses).
        float errorRate = 1.0f - test(classifier, this.trainingData);
        if (errorRate * this.trainingStatistics.getClassCount(classID) < this.trainingData.size() * 0.02f) {
            removeClasses(classID);
        }
        return classifier;
    }

    /**
     * Collects the clients' solutions, reports fitness statistics to the listener, and returns
     * the solution with the lowest Koza fitness.
     *
     * @param classID class the solutions were evolved for (passed through to the listener)
     * @return the best solution, or {@code null} if no client returned one
     */
    public Individual processClients(int classID) {
        FastStatistics overall = new FastStatistics();
        // Per-problem fitness statistics. They are built but not currently reported:
        // onGPFitnessStatistics receives null. TODO(review): pass them on or drop them.
        Hashtable<String, FastStatistics> byProblem = new Hashtable<>();
        Individual best = null;
        for (GPClient client : this.clients) {
            Individual solution = client.getSolution();
            if (solution == null) {
                continue;
            }
            overall.addData((float) solution.getKozaFitness());
            if (best == null || solution.getKozaFitness() < best.getKozaFitness()) {
                best = solution;
            }
            String problemName = client.getProblemName();
            FastStatistics problemStats = byProblem.get(problemName);
            if (problemStats == null) {
                problemStats = new FastStatistics();
                byProblem.put(problemName, problemStats);
            }
            problemStats.addData((float) solution.getKozaFitness());
        }
        this.listener.onGPFitnessStatistics(classID, overall, null);
        return best;
    }

    /**
     * Runs an FN-refinement problem on all clients when the classifier has false negatives on
     * the training data. The listener is notified only if the error estimate improved.
     *
     * @param classifier classifier to refine (handed to every {@code FNRefinementProblem})
     * @param classID    class the classifier recognises
     */
    public synchronized void refineClassifier(GPClassifier classifier, int classID) {
        ClassifierTestResults before = ClassifierTester.testBinarySingleClass(classifier, this.trainingData, classID);
        if (before.getFalseNegatives().isEmpty()) {
            return;
        }
        this.listener.onStatusUpdate("Refining class " + classID + " e= " + before.getFalseNegatives().size());
        // Refinement set: this class's false negatives plus every sample of other classes.
        Vector<Data> refinementData = new Vector<>(1000);
        refinementData.addAll(before.getFalseNegatives());
        for (Data sample : this.trainingData) {
            if (sample.classID != classID) {
                refinementData.add(sample);
            }
        }

        Vector<Listener> listeners = new Vector<>();
        for (int c = 0; c < this.clients.size(); c++) {
            GPClient client = this.clients.elementAt(c);
            ProblemSettings settings = new ProblemSettings(gpRuntime, clientSeed(c), clientVariant(c));
            client.setProblem(new FNRefinementProblem(classID, classifier, settings, refinementData));
            listeners.add(attachListenerAndStart(client));
        }
        this.listener.onListenersUpdated(listeners);
        awaitActiveRuns();

        Individual refinement = processClients(classID);
        if (refinement != null) {
            System.out.println(refinement.toJava());
        }
        // NOTE(review): nothing here installs `refinement` into `classifier`; any change measured
        // below must come from FNRefinementProblem, which holds `classifier` — confirm.
        ClassifierTestResults after = ClassifierTester.testBinarySingleClass(classifier, this.trainingData, classID);
        if (after.getErrorEstimate() >= before.getErrorEstimate()) {
            System.out.println("Boo! It didn't reduce the error at all");
            return;
        }
        System.out.println("Hurrah! It has reduced the error!");
        this.listener.onClassifierRefined(before.getPercentageCorrect(), after.getPercentageCorrect());
        onClassifierUpdated();
    }

    /** Placeholder: currently does nothing. */
    public synchronized void refineMulticlassClassifier() {
    }

    /**
     * Sets the weight of every training sample of the given class to 0. The samples remain in
     * {@link #trainingData}.
     */
    public void removeClasses(int classID) {
        for (Data sample : this.trainingData) {
            if (sample.classID == classID) {
                sample.weight = 0.0f;
            }
        }
    }

    /** Returns the sum of the weights of all training samples. */
    public float getEffectiveTrainingSize() {
        float total = 0.0f;
        for (Data sample : this.trainingData) {
            total += sample.weight;
        }
        return total;
    }
}
