package ac.ooechs.classify.experiments;

import ac.essex.gp.Evolve;
import ac.essex.gp.interfaces.console.ConsoleListener;
import ac.ooechs.classify.classifier.GPClassifier;
import ac.ooechs.classify.classifier.fusion.ClassifierFusion;
import ac.ooechs.classify.classifier.gp.GPMulticlassClassificationProblem;
import ac.ooechs.classify.data.Data;
import ac.ooechs.classify.data.DataStatistics;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;
import java.util.Vector;

/**
 * Stratified k-fold cross-validation driver for the GP multiclass classifier.
 *
 * <p>{@link #splitData} assigns every sample a fold index, spreading each class across the
 * folds. {@link #evolveSolution} then evolves and scores classifiers fold by fold and prints
 * the results to standard output.
 */
public class KFoldCrossValidation {

    /** Number of independent GP runs evolved for each fold in {@link #evolveSolution}. */
    private static final int RUNS_PER_FOLD = 5;

    /** Fusion mode whose result is reported as "Majority hits" (see ClassifierFusion). */
    private static final int FUSION_MODE_MAJORITY = 1;

    /** Fusion mode whose result is reported as "Committee hits" (see ClassifierFusion). */
    private static final int FUSION_MODE_COMMITTEE = 2;

    /** Seed value that requests an unseeded (non-reproducible) random generator. */
    private static final int UNSEEDED = -1;

    /** Number of folds the data is split into. */
    protected int numFolds;

    /** Random source used to shuffle samples within each class bin. */
    protected Random r;

    /**
     * Holds the samples of a single class and hands them out in random order, each at most
     * once.
     */
    class DataBin {
        /** Label that every sample added to this bin must carry. */
        int classID;

        /** Largest number of samples this bin has held at any one time (not the current count). */
        int size = 0;

        /** Samples not yet handed out by {@link #getData()}. */
        Vector<Data> data = new Vector<>(50);

        /**
         * @param i the class ID this bin accepts
         */
        public DataBin(int i) {
            this.classID = i;
        }

        /**
         * Adds a sample to this bin.
         *
         * @param data sample to add; its label must equal this bin's class ID
         * @throws RuntimeException if the sample's label does not match this bin
         */
        public void add(Data data) {
            if (data.getLabel() != this.classID) {
                throw new RuntimeException("Wrong data in bin " + this.classID);
            }
            this.data.add(data);
            this.size = Math.max(this.size, this.data.size());
        }

        /**
         * Returns the high-water mark of this bin's sample count.
         *
         * @return the largest number of samples held at once
         */
        public int size() {
            return this.size;
        }

        /**
         * Removes and returns a uniformly chosen sample from this bin.
         *
         * @return a random remaining sample, or {@code null} once the bin is empty
         */
        public Data getData() {
            if (this.data.isEmpty()) {
                return null;
            }
            // nextDouble() lies in [0, 1), so the index is always within [0, size).
            int index = (int) (this.data.size() * KFoldCrossValidation.this.getRandomNumber());
            // Remove by position: remove(Object) would match via equals() and could drop a
            // different, equal-comparing sample (leaving the chosen one to be handed out again).
            return this.data.remove(index);
        }
    }

    /**
     * Returns the next pseudo-random double in [0, 1) from this instance's generator.
     *
     * @return a uniformly distributed value in [0, 1)
     */
    public double getRandomNumber() {
        return this.r.nextDouble();
    }

    /**
     * Creates a cross-validation driver.
     *
     * @param numFolds number of folds; must be at least 1
     * @param seed     seed for the random generator, or {@code -1} for an unseeded generator
     * @throws IllegalArgumentException if {@code numFolds < 1}
     */
    public KFoldCrossValidation(int numFolds, int seed) {
        if (numFolds < 1) {
            throw new IllegalArgumentException("numFolds must be at least 1, got " + numFolds);
        }
        this.numFolds = numFolds;
        this.r = (seed == UNSEEDED) ? new Random() : new Random(seed);
    }

    /**
     * Evolves and scores classifiers for every fold, printing per-fold and total hit counts.
     *
     * <p>For each fold, {@value #RUNS_PER_FOLD} GP runs are evolved with seeds
     * {@code baseSeed}, {@code baseSeed + 1}, and so on. The resulting classifiers are also
     * combined in a {@link ClassifierFusion} and scored in two fusion modes. The fold's score
     * is the maximum of the best individual classifier and the two fused results. The total
     * across folds is printed against {@code data.size()}.
     *
     * <p>NOTE(review): the fold index passed to {@code setFold}/{@code getHits}/{@code tryHits}
     * is presumably the held-out fold. If so, choosing the best of several classifiers on that
     * same fold makes the total an optimistic estimate. Confirm against the problem class.
     *
     * @param baseSeed   seed for the first GP run of each fold
     * @param data       the full data set; samples presumably carry fold indices (see
     *                   {@link #splitData})
     * @param statistics supplies the class IDs used by the classifier fusion
     * @param problem    the GP problem; its fold and seed settings are mutated here
     */
    public void evolveSolution(int baseSeed, Vector<Data> data, DataStatistics statistics,
            GPMulticlassClassificationProblem problem) {
        int totalHits = 0;
        for (int fold = 0; fold < this.numFolds; fold++) {
            System.out.println("Fold: " + fold);
            problem.setFold(fold);

            int bestHits = 0;
            ClassifierFusion fusion = new ClassifierFusion(statistics.getClassIDs());
            for (int run = 0; run < RUNS_PER_FOLD; run++) {
                problem.getProblemSettings().seed = baseSeed + run;
                Evolve evolve = new Evolve(problem, new ConsoleListener(ConsoleListener.SILENT));
                evolve.run();
                GPClassifier classifier = new GPClassifier(evolve.getBestIndividual());
                int hits = classifier.getHits(data, fold)[0];
                bestHits = Math.max(bestHits, hits);
                System.out.println("c[" + run + "] : " + hits);
                fusion.add(classifier);
            }

            fusion.setMode(FUSION_MODE_MAJORITY);
            int majorityHits = fusion.tryHits(data, fold);
            fusion.setMode(FUSION_MODE_COMMITTEE);
            int committeeHits = fusion.tryHits(data, fold);

            System.out.println("Best individual hits: " + bestHits);
            System.out.println("Majority hits: " + majorityHits);
            System.out.println("Committee hits: " + committeeHits);

            totalHits += Math.max(bestHits, Math.max(majorityHits, committeeHits));
        }
        System.out.println("Total =========");
        System.out.println("Hits: " + totalHits + " / " + data.size());
        System.out.println("===============");
    }

    /**
     * Assigns every sample a fold index so that each class is spread evenly across the folds.
     *
     * <p>Samples are grouped by class (in ascending class-ID order) and each group is drawn in
     * random order. Samples are dealt round-robin to folds {@code 0..numFolds-1}. The
     * round-robin position carries over from one class to the next. The resulting fold sizes
     * are printed. Each sample's {@code fold} field is overwritten.
     *
     * @param vector         samples to split
     * @param dataStatistics supplies the set of valid class IDs
     * @throws IllegalArgumentException if a sample's label is not one of the class IDs
     */
    public void splitData(Vector<Data> vector, DataStatistics dataStatistics) {
        // A TreeMap accepts arbitrary (non-contiguous) class IDs. It also iterates them in
        // ascending order, matching the previous classID-1 array layout, so seeded splits are
        // unchanged.
        Map<Integer, DataBin> bins = new TreeMap<>();
        for (int k = 0; k < dataStatistics.getClassIDs().size(); k++) {
            int classID = dataStatistics.getClassIDs().elementAt(k).intValue();
            bins.put(classID, new DataBin(classID));
        }

        for (Data sample : vector) {
            DataBin bin = bins.get(sample.getLabel());
            if (bin == null) {
                throw new IllegalArgumentException("Sample label " + sample.getLabel()
                        + " is not one of the class IDs in dataStatistics");
            }
            bin.add(sample);
        }

        int[] foldSizes = new int[this.numFolds];
        int nextFold = 0;
        for (DataBin bin : bins.values()) {
            for (Data sample = bin.getData(); sample != null; sample = bin.getData()) {
                sample.fold = nextFold;
                foldSizes[nextFold]++;
                nextFold = (nextFold + 1) % this.numFolds;
            }
        }

        for (int f = 0; f < foldSizes.length; f++) {
            System.out.println("fold[" + f + "] size=" + foldSizes[f]);
        }
    }
}
