package ac.ooechs.classify;

import ac.essex.gp.cluster.GPClient;
import ac.essex.gp.individuals.Individual;
import ac.essex.gp.util.DeepCopy;
import ac.ooechs.classify.classifier.Classifier;
import ac.ooechs.classify.classifier.GPClassifier;
import ac.ooechs.classify.classifier.MulticlassGPClassifier;
import ac.ooechs.classify.classifier.gp.FNRefinementProblem;
import ac.ooechs.classify.classifier.gp.GPOneClassClassificationProblem;
import ac.ooechs.classify.classifier.gp.ProblemSettings;
import ac.ooechs.classify.classifier.knn.NearestNeighbourClassifier;
import ac.ooechs.classify.data.Data;
import ac.ooechs.classify.data.DataNormaliser;
import ac.ooechs.classify.data.DataStatistics;
import ac.ooechs.classify.data.TrainingDataStore;
import ac.ooechs.classify.evaluation.ClassifierTestResults;
import ac.ooechs.classify.evaluation.ClassifierTester;
import ac.ooechs.classify.tasks.DistributedGPClientListener;
import ac.ooechs.classify.tasks.DistributedTask;
import ac.ooechs.classify.ui.ClassColours;
import ac.ooechs.classify.ui.GPStatusBox;
import java.io.File;
import java.io.IOException;
import java.util.Collections;
import java.util.Vector;

/* loaded from: input_file:ac/ooechs/classify/DistributedSuperClassifier2.class */
/**
 * Distributed experiment runner that builds a {@link MulticlassGPClassifier} from one-class GP
 * classifiers and then repeatedly schedules false-negative / false-positive refinement runs.
 *
 * <p>As implemented below, {@link #generateClassifier()} does the following:
 * <ul>
 *   <li>orders the classes by the natural ordering of their k-nearest-neighbour
 *       {@link ClassifierTestResults};</li>
 *   <li>queues one {@link LearnClassTask} for every class except the last;</li>
 *   <li>makes the last class the classifier's default class.</li>
 * </ul>
 * Whenever no queued task is ready, {@link #getReadyTasks(Vector)} asks
 * {@link #getRefinementTasks()} for more work. It finishes when nothing worthwhile remains or
 * the run is stopped.
 */
public class DistributedSuperClassifier2 extends DistributedExperimentRunner {
    public static final String VERSION = "1.27";
    /** Training samples; loaded in {@link #generateClassifier()}. */
    protected Vector<Data> trainingData;
    /** Testing samples; loaded in {@link #generateClassifier()}. */
    protected Vector<Data> testingData;
    protected DataStatistics trainingStatistics;
    protected DataStatistics testingStatistics;
    protected DataNormaliser normaliser;
    protected int maxTraining;
    protected int maxTesting;
    /** The classifier under construction; shared by all task callbacks (used as a lock). */
    public volatile MulticlassGPClassifier multiclassClassifier;
    /** Per-class snapshots of the training data, keyed by class ID. */
    TrainingDataStore store = new TrainingDataStore();

    /**
     * Task that evolves a one-class GP classifier for a single class. When it finishes, the
     * result is installed into {@link #multiclassClassifier} at a fixed slot.
     */
    class LearnClassTask extends DistributedTask {
        protected int classID;
        protected GPClassifier c;
        protected Vector<Data> data;
        protected int classifierIndex;

        /**
         * @param classID         the class to learn
         * @param classifierIndex slot in {@link #multiclassClassifier} that receives the result
         */
        public LearnClassTask(int classID, int classifierIndex) {
            super(DistributedSuperClassifier2.this, ClassColours.getColour(classID));
            this.classifierIndex = classifierIndex;
            this.classID = classID;
            this.data = DistributedSuperClassifier2.this.store.get(classID);
        }

        public String toString() {
            return "Learning Class " + this.classID;
        }

        /** Configures the client with a one-class problem and a status-box listener, then starts it. */
        @Override // ac.ooechs.classify.tasks.DistributedTask
        public void startClient(GPClient client, ProblemSettings problemSettings) {
            DistributedSuperClassifier2 outer = DistributedSuperClassifier2.this;
            outer.listener.onClassUpdated(this.classID, "Learning");
            client.setProblem(new GPOneClassClassificationProblem(this.classID, problemSettings, this.data));
            DistributedGPClientListener clientListener = new DistributedGPClientListener(this);
            GPStatusBox statusBox = outer.listener.getStatusBox(client.getID());
            statusBox.setColor(ClassColours.getColour(this.classID));
            statusBox.setName(toString());
            clientListener.setProgressBar(statusBox);
            client.setListener(clientListener);
            client.start();
        }

        /**
         * Builds a {@link GPClassifier} from the clients' result and installs it. Reports the
         * training accuracy before and after the change.
         */
        @Override // ac.ooechs.classify.tasks.DistributedTask
        public void taskFinished(Vector<GPClient> clients) {
            DistributedSuperClassifier2 outer = DistributedSuperClassifier2.this;
            outer.listener.onClassUpdated(this.classID, "Learned");
            System.out.println("--- Task: " + this + " finishing.");
            Individual best = outer.processClients(clients, this.classID);
            this.c = new GPClassifier(this.classID, best);
            MulticlassGPClassifier classifier = outer.multiclassClassifier;
            // Same lock RefineClassTask.taskFinished uses, so a concurrent refinement cannot
            // observe (or be measured against) a half-updated classifier.
            synchronized (classifier) {
                float before = outer.test(classifier, outer.trainingData);
                classifier.set(this.c, this.classifierIndex);
                float after = outer.test(classifier, outer.trainingData);
                System.out.println("Finished task, adding class: " + this.classID);
                System.out.println("Fitness before: " + before + ", fitness after: " + after);
                if (before > after) {
                    System.out.println("Classifier is worse after adding class: " + this.classID);
                    System.out.println("Individual fitness was: " + best.getKozaFitness());
                }
                System.out.println("----");
                int[] hits = outer.getHits(this.c, this.data);
                outer.listener.onClassUpdated(this.classID, hits[0], hits[1]);
                outer.onClassifierUpdated(classifier);
            }
        }

        public GPClassifier getClassifier() {
            return this.c;
        }
    }

    /**
     * Task that evolves an extra refinement individual for an existing {@link GPClassifier}. It
     * targets either the classifier's false negatives or its false positives. The refinement is
     * kept only if it raises training accuracy of the multiclass classifier.
     */
    public class RefineClassTask extends DistributedTask implements Comparable {
        public static final int FALSE_NEGATIVES = 1;
        public static final int FALSE_POSITIVES = 2;
        protected int classID;
        protected GPClassifier toRefine;
        protected Vector<Data> data;
        protected Vector<Data> refinedTrainingData;
        protected ClassifierTestResults testResults;
        protected int errors;
        protected float refinementPotential;
        protected int mode;

        /**
         * @param toRefine the classifier to refine
         * @param classID  its class ID
         * @param mode     {@link #FALSE_NEGATIVES}; any other value is treated as false positives
         */
        public RefineClassTask(GPClassifier toRefine, int classID, int mode) {
            super(DistributedSuperClassifier2.this, ClassColours.getColour(classID));
            this.toRefine = toRefine;
            this.mode = mode;
            this.classID = classID;
            this.data = DistributedSuperClassifier2.this.store.get(classID);
        }

        /**
         * Tests the classifier on this class's stored data and counts the errors for this
         * task's mode.
         *
         * <p>Side effects: sets {@link #testResults}, {@link #errors} and, when there are
         * errors, {@link #refinementPotential}. The potential is the error count divided by the
         * class's training sample count. Must be called before {@link #init()}.
         *
         * @return {@code false} if there are no errors to refine
         */
        public boolean isWorthwhile() {
            this.testResults = ClassifierTester.testBinarySingleClass(this.toRefine, this.data, this.classID);
            if (this.mode == FALSE_NEGATIVES) {
                this.errors = this.testResults.getFalseNegatives().size();
            } else {
                this.errors = this.testResults.getFalsePositives().size();
            }
            if (this.errors == 0) {
                return false;
            }
            // Float division: with int operands this ratio truncated to 0 for almost every task.
            this.refinementPotential = (float) this.errors
                    / DistributedSuperClassifier2.this.trainingStatistics.getClassCount(this.classID);
            return true;
        }

        public String toString() {
            return this.mode == FALSE_NEGATIVES
                    ? "Refining Class " + this.classID + " FN=" + this.errors
                    : "Refining Class " + this.classID + " FP=" + this.errors;
        }

        /**
         * Builds the refinement training set. It consists of the misclassified samples for this
         * mode, plus every stored sample outside the class (false-negative mode) or inside the
         * class (false-positive mode).
         *
         * @return {@code false} if there are no misclassified samples for this mode
         */
        @Override // ac.ooechs.classify.tasks.DistributedTask
        public boolean init() {
            boolean refiningFalseNegatives = this.mode == FALSE_NEGATIVES;
            boolean nothingToRefine = refiningFalseNegatives
                    ? this.testResults.getFalseNegatives().isEmpty()
                    : this.testResults.getFalsePositives().isEmpty();
            if (nothingToRefine) {
                return false;
            }
            this.refinedTrainingData = new Vector<>(1000);
            if (refiningFalseNegatives) {
                this.refinedTrainingData.addAll(this.testResults.getFalseNegatives());
            } else {
                this.refinedTrainingData.addAll(this.testResults.getFalsePositives());
            }
            for (Data sample : this.data) {
                boolean inClass = sample.classID == this.classID;
                // FN mode keeps out-of-class samples; FP mode keeps in-class samples.
                if (inClass != refiningFalseNegatives) {
                    this.refinedTrainingData.add(sample);
                }
            }
            return true;
        }

        /** Configures the client with a refinement problem and a status-box listener, then starts it. */
        @Override // ac.ooechs.classify.tasks.DistributedTask
        public void startClient(GPClient client, ProblemSettings problemSettings) {
            DistributedSuperClassifier2 outer = DistributedSuperClassifier2.this;
            outer.listener.onClassUpdated(this.classID, "Refining");
            client.setProblem(new FNRefinementProblem(this.classID, this.toRefine, problemSettings, this.refinedTrainingData));
            DistributedGPClientListener clientListener = new DistributedGPClientListener(this);
            GPStatusBox statusBox = outer.listener.getStatusBox(client.getID());
            statusBox.setColor(ClassColours.getColour(this.classID));
            statusBox.setName(toString());
            clientListener.setProgressBar(statusBox);
            client.setListener(clientListener);
            client.start();
        }

        /** Orders tasks by descending {@link #refinementPotential}. */
        @Override // java.lang.Comparable
        public int compareTo(Object obj) {
            RefineClassTask other = (RefineClassTask) obj;
            return Float.compare(other.refinementPotential, this.refinementPotential);
        }

        /**
         * Tentatively attaches the evolved refinement and keeps it only if training accuracy
         * strictly improves. Otherwise it removes the refinement and marks this mode as failed
         * on the classifier.
         */
        @Override // ac.ooechs.classify.tasks.DistributedTask
        public void taskFinished(Vector<GPClient> clients) {
            DistributedSuperClassifier2 outer = DistributedSuperClassifier2.this;
            outer.listener.onClassUpdated(this.classID, "Refined");
            System.out.println("--- Task: " + this + " finishing.");
            MulticlassGPClassifier classifier = outer.multiclassClassifier;
            synchronized (classifier) {
                float before = outer.test(classifier, outer.trainingData);
                Individual refinement = outer.processClients(clients, this.classID);
                boolean refiningFalseNegatives = this.mode == FALSE_NEGATIVES;
                if (refiningFalseNegatives) {
                    this.toRefine.FNRefinements.add(refinement);
                } else {
                    this.toRefine.FPRefinements.add(refinement);
                }
                if (outer.test(classifier, outer.trainingData) > before) {
                    System.out.println("Hurrah! It has reduced the error!");
                    int[] hits = outer.getHits(this.toRefine, this.data);
                    outer.listener.onClassUpdated(this.classID, hits[0], hits[1]);
                    outer.onClassifierUpdated(classifier);
                } else {
                    System.out.println("Boo! It didn't reduce the error at all");
                    if (refiningFalseNegatives) {
                        this.toRefine.FNRefinements.remove(refinement);
                        this.toRefine.lastFNRefinementFailed = true;
                    } else {
                        this.toRefine.FPRefinements.remove(refinement);
                        this.toRefine.lastFPRefinementFailed = true;
                    }
                }
            }
        }
    }

    @Override // ac.ooechs.classify.DistributedExperimentRunner
    public void saveClassifier(File file) {
        this.multiclassClassifier.save(file);
    }

    @Override // ac.ooechs.classify.DistributedExperimentRunner
    public String getVersion() {
        return VERSION;
    }

    /**
     * Loads and normalises the data, orders the classes using a k-nearest-neighbour pass, and
     * queues one {@link LearnClassTask} per class except the last, which becomes the default
     * class. Then starts the tasks.
     *
     * @throws IllegalStateException if training or testing data has not been configured
     * @throws IOException           propagated from data loading
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    @Override // ac.ooechs.classify.DistributedExperimentRunner
    public void generateClassifier() throws IOException {
        if (this.training == null || this.testing == null) {
            this.listener.onError("Training/Testing data not set up!");
            throw new IllegalStateException("Training/Testing data not set up");
        }
        this.listener.onStatusUpdate("Loading the data");
        this.trainingData = this.training.getData();
        this.testingData = this.testing.getData();
        this.trainingStatistics = new DataStatistics(this.trainingData);
        this.testingStatistics = new DataStatistics(this.testingData);

        this.listener.onStatusUpdate("Assessing class difficulty");
        NearestNeighbourClassifier knn = new NearestNeighbourClassifier(this.trainingData);
        // NOTE(review): class difficulty is measured on the *testing* data, so the test set
        // influences training order — confirm this is intended.
        Vector perClassResults = new Vector(this.trainingStatistics.getClassCount());
        for (int i = 0; i < this.trainingStatistics.getClassIDs().size(); i++) {
            int classID = this.trainingStatistics.getClassIDs().elementAt(i).intValue();
            System.out.println("Running k-nearest neighbour on class: " + classID);
            perClassResults.add(ClassifierTester.testBinarySingleClass(knn, this.testingData, classID, false));
        }
        System.out.println("Getting Confusion Matrix");
        ClassifierTester.testMulticlass(this.trainingStatistics.getClassIDs(), knn, this.testingData).printConfusionMatrix();

        this.listener.onStatusUpdate("Normalising data");
        this.normaliser = new DataNormaliser(this.trainingData);
        this.normaliser.normalise(this.trainingData);
        this.normaliser.normalise(this.testingData);

        // Ordering comes from ClassifierTestResults' natural order.
        Collections.sort(perClassResults);
        this.multiclassClassifier = new MulticlassGPClassifier(this.trainingStatistics.getClassCount());
        int lastIndex = perClassResults.size() - 1;
        for (int index = 0; index <= lastIndex; index++) {
            ClassifierTestResults results = (ClassifierTestResults) perClassResults.elementAt(index);
            int classID = results.getClassID();
            if (index < lastIndex) {
                System.out.println("Classify Class " + classID + ", e=" + ((int) results.getErrorEstimate()));
                // Snapshot before removeClasses zeroes this class's weights in trainingData.
                this.store.put(classID, (Vector) new DeepCopy().copy(this.trainingData));
                addTask(new LearnClassTask(classID, index));
                removeClasses(classID);
            } else {
                this.multiclassClassifier.defaultClass = classID;
                this.listener.onClassUpdated(classID, this.trainingStatistics.getClassCount(classID), 0);
                this.listener.onClassUpdated(classID, "Default");
                onClassifierUpdated(this.multiclassClassifier);
                System.out.println("Default Class " + classID);
            }
        }
        runTasks();
    }

    /**
     * Returns the tasks in {@code vector} that are ready to execute. When none are ready, this
     * either finishes the run (stopped, or nothing left to do) or appends fresh refinement tasks
     * to {@code vector} and calls {@code runTasks()}. In those cases it returns {@code null}.
     */
    @Override // ac.ooechs.classify.DistributedExperimentRunner
    public synchronized Vector<DistributedTask> getReadyTasks(Vector<DistributedTask> vector) {
        Vector<DistributedTask> ready = new Vector<>();
        for (int i = 0; i < vector.size(); i++) {
            DistributedTask task = vector.elementAt(i);
            if (task.readyToExecute()) {
                ready.add(task);
            }
        }
        if (!ready.isEmpty()) {
            return ready;
        }
        if (this.stop) {
            this.listener.onFinished(this.multiclassClassifier);
            return null;
        }
        vector.addAll(getRefinementTasks());
        if (vector.isEmpty()) {
            this.listener.onFinished(this.multiclassClassifier);
            return null;
        }
        runTasks();
        return null;
    }

    /**
     * Creates false-negative and false-positive refinement tasks for every learned classifier.
     * A mode is skipped if its previous refinement failed or {@link RefineClassTask#isWorthwhile()}
     * says there is nothing to refine.
     *
     * @return the tasks, sorted by descending refinement potential
     */
    public Vector<RefineClassTask> getRefinementTasks() {
        Vector<RefineClassTask> tasks = new Vector<>();
        for (GPClassifier classifier : this.multiclassClassifier.classifiers) {
            if (classifier == null) {
                continue;
            }
            boolean scheduled = false;
            if (!classifier.lastFNRefinementFailed) {
                RefineClassTask fnTask = new RefineClassTask(classifier, classifier.classID, RefineClassTask.FALSE_NEGATIVES);
                if (fnTask.isWorthwhile()) {
                    tasks.add(fnTask);
                    scheduled = true;
                }
            }
            if (!classifier.lastFPRefinementFailed) {
                RefineClassTask fpTask = new RefineClassTask(classifier, classifier.classID, RefineClassTask.FALSE_POSITIVES);
                if (fpTask.isWorthwhile()) {
                    tasks.add(fpTask);
                    scheduled = true;
                }
            }
            this.listener.onClassUpdated(classifier.classID, scheduled ? "Scheduled" : "Finished");
        }
        Collections.sort(tasks);
        for (int i = 0; i < tasks.size(); i++) {
            System.out.println(i + ". " + tasks.elementAt(i));
        }
        StringBuilder console = new StringBuilder();
        console.append("--- Adding new tasks:\n");
        for (RefineClassTask task : tasks) {
            // One task per line (previously the newline went to stdout instead of the buffer).
            console.append(task).append(' ').append(task.refinementPotential).append('\n');
        }
        this.listener.appendToConsole(console.toString());
        return tasks;
    }

    /**
     * Rough runtime estimate.
     *
     * <p>The arithmetic: {@code (i - 1) * 2} tasks are shared across {@code clients.size() / i2}
     * parallel groups. Each step of the counting loop below adds {@code i3}, and a constant 2 is
     * added at the end.
     *
     * @return the estimate, or -1 if {@code i2 <= 0} or there are fewer than {@code i2} clients
     */
    @Override // ac.ooechs.classify.DistributedExperimentRunner
    public int estimateRuntime(int i, int i2, int i3) {
        if (i2 <= 0) {
            // The previous subtraction loop never terminated for i2 == 0 (and ~forever for < 0).
            return -1;
        }
        int taskCount = (i - 1) * 2;
        int parallelGroups = this.clients.size() / i2;
        if (parallelGroups == 0) {
            return -1;
        }
        // NOTE(review): remaining is divided by parallelGroups here and then decremented by
        // parallelGroups again below — possibly a double division; behaviour kept as-is.
        int remaining = taskCount / parallelGroups;
        int estimate = 0;
        while (remaining > 0) {
            remaining -= parallelGroups;
            estimate += i3;
        }
        return estimate + 2;
    }

    /**
     * Logs training/testing accuracy with refinements disabled and then enabled, and reports the
     * refined figures to the listener.
     */
    public void onClassifierUpdated(Classifier classifier) {
        System.out.println("--- Classifier updated:");
        GPClassifier.REFINEMENT_ON = false;
        try {
            System.out.println("Without Refinement: ");
            System.out.println(test(classifier, this.trainingData) + ", " + test(classifier, this.testingData));
        } finally {
            // Always restore the global flag, even if classification throws.
            GPClassifier.REFINEMENT_ON = true;
        }
        System.out.println("With refinement:");
        float trainingAccuracy = test(classifier, this.trainingData);
        float testingAccuracy = test(classifier, this.testingData);
        System.out.println(trainingAccuracy + ", " + testingAccuracy);
        // NOTE(review): the last two arguments use int division if getDataCount() returns int.
        this.listener.onClassifierUpdated(trainingAccuracy, testingAccuracy,
                this.maxTraining / this.trainingStatistics.getDataCount(),
                this.maxTesting / this.testingStatistics.getDataCount());
    }

    /**
     * Computes the fraction of samples whose predicted class equals their {@code classID}.
     *
     * @return accuracy in [0, 1]; 0 for an empty vector
     */
    public synchronized float test(Classifier classifier, Vector<Data> vector) {
        if (vector.isEmpty()) {
            return 0f;
        }
        int correct = 0;
        for (Data sample : vector) {
            if (classifier.classify(sample) == sample.classID) {
                correct++;
            }
        }
        // Float division: int / int always truncated the accuracy to 0 or 1.
        return (float) correct / vector.size();
    }

    /**
     * Counts predictions with a class ID greater than 0.
     *
     * @return {@code [hits, weightedMisses]}. {@code hits} counts correct predictions.
     *     {@code weightedMisses} accumulates {@code weight} (truncated to int at each step) for
     *     incorrect predictions.
     */
    public synchronized int[] getHits(Classifier classifier, Vector<Data> vector) {
        int hits = 0;
        int weightedMisses = 0;
        for (Data sample : vector) {
            int predicted = classifier.classify(sample);
            if (predicted <= 0) {
                continue;
            }
            if (predicted == sample.classID) {
                hits++;
            } else {
                weightedMisses = (int) (weightedMisses + sample.weight);
            }
        }
        return new int[]{hits, weightedMisses};
    }

    /** Sets the weight of every training sample of class {@code i} to zero. */
    public void removeClasses(int i) {
        for (Data sample : this.trainingData) {
            if (sample.classID == i) {
                sample.weight = 0.0f;
            }
        }
    }

    /** Returns the sum of the weights of all training samples. */
    public float getEffectiveTrainingSize() {
        float total = 0.0f;
        for (Data sample : this.trainingData) {
            total += sample.weight;
        }
        return total;
    }
}
