package ac.ooechs.classify.classifier.gp.multiclass;

import ac.essex.gp.Evolve;
import ac.essex.gp.individuals.Individual;
import ac.essex.gp.interfaces.console.ConsoleListener;
import ac.essex.gp.multiclass.BasicDRS;
import ac.essex.gp.multiclass.BetterDRS;
import ac.essex.gp.multiclass.CachedOutput;
import ac.essex.gp.multiclass.PCM;
import ac.essex.gp.nodes.Add;
import ac.essex.gp.nodes.Div;
import ac.essex.gp.nodes.Max;
import ac.essex.gp.nodes.Mean;
import ac.essex.gp.nodes.Min;
import ac.essex.gp.nodes.Mul;
import ac.essex.gp.nodes.PercentDiff;
import ac.essex.gp.nodes.Sub;
import ac.essex.gp.nodes.ercs.CustomRangeERC;
import ac.essex.gp.nodes.ercs.PercentageERC;
import ac.essex.gp.nodes.math.Cubed;
import ac.essex.gp.nodes.math.Ln;
import ac.essex.gp.nodes.math.Sin;
import ac.essex.gp.nodes.math.Squared;
import ac.essex.gp.params.GPParams;
import ac.essex.gp.problems.DataStack;
import ac.essex.gp.problems.Problem;
import ac.essex.ooechs.imaging.commons.fast.FastStatistics;
import ac.ooechs.classify.classifier.gp.DataValueTerminal;
import ac.ooechs.classify.classifier.gp.ProblemSettings;
import ac.ooechs.classify.data.Data;
import ac.ooechs.classify.data.DataStatistics;
import ac.ooechs.classify.data.io.CSVDataReader;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Vector;

/* loaded from: input_file:ac/ooechs/classify/classifier/gp/multiclass/GPMulticlassClassificationProblem.class */
/**
 * Genetic-programming problem that evolves a single real-valued program for multiclass
 * classification.
 *
 * <p>The evolved program's raw output is turned into a class label by a program
 * classification map ({@link PCM}). The map is built from the training data by either
 * {@link BasicDRS} or {@link BetterDRS}, chosen by {@link ProblemSettings#DRSMethod}.
 *
 * <p>Fitness (Koza fitness, lower is better) is the <em>sum</em> over classes of each
 * class's weighted misclassification rate. Classes with zero total weight are skipped.
 *
 * <p>NOTE(review): evaluation communicates input values through the static field
 * {@code DataValueTerminal.currentValues}, so instances are not safe to evaluate
 * concurrently.
 */
public class GPMulticlassClassificationProblem extends Problem {

    /** {@code DRSMethod} value that selects {@link BasicDRS}; any other value selects {@link BetterDRS}. */
    private static final int BASIC_DRS = 1;

    /** {@code DRSMethod} value used by {@link #main} to select {@link BetterDRS}. */
    private static final int BETTER_DRS = 2;

    /** Default training file used by {@link #main} when no path is given on the command line. */
    private static final String DEFAULT_TRAINING_FILE = "/home/ooechs/Desktop/jasmine-data/sat-training.ssv";

    /** Default test file used by {@link #main} when no path is given on the command line. */
    private static final String DEFAULT_TEST_FILE = "/home/ooechs/Desktop/jasmine-data/sat-test.ssv";

    /** Samples used to build the classification map and to compute fitness. */
    protected Vector<Data> trainingData;

    /** Run configuration (DRS method, seed, and other evolution parameters). */
    protected ProblemSettings p;

    /**
     * Experiment harness that evolves a classifier for several seeds and reports the results.
     *
     * <p>For each seed in [2357, 2362) and each value {@code t} in {2, 7}, it prints the best
     * individual's training fitness and its error on the held-out test file. {@code t} is
     * passed as the third {@link ProblemSettings} constructor argument. It then prints
     * summary statistics grouped by DRS method.
     *
     * @param args optional; {@code args[0]} is the training file and {@code args[1]} the
     *     test file. Both default to the original hard-coded paths.
     * @throws IOException if either data file cannot be read
     */
    public static void main(String[] args) throws IOException {
        String trainingPath = args.length > 0 ? args[0] : DEFAULT_TRAINING_FILE;
        String testPath = args.length > 1 ? args[1] : DEFAULT_TEST_FILE;
        Vector<Data> training = new CSVDataReader(new File(trainingPath)).getData();
        Vector<Data> testing = new CSVDataReader(new File(testPath)).getData();
        // The results are discarded; these are constructed only for any side effects.
        // NOTE(review): confirm whether DataStatistics mutates/normalises the data.
        new DataStatistics(training);
        new DataStatistics(testing);

        FastStatistics betterTraining = new FastStatistics();
        FastStatistics betterTesting = new FastStatistics();
        FastStatistics basicTraining = new FastStatistics();
        FastStatistics basicTesting = new FastStatistics();

        // Only the better DRS is run; change this value to collect basic-DRS statistics.
        final int drsMethod = BETTER_DRS;

        for (int seed = 2357; seed < 2362; seed++) {
            for (int t = 2; t <= 7; t += 5) {
                System.out.println("seed= " + seed + ", t=" + t + ", drs: " + drsMethod);
                ProblemSettings settings = new ProblemSettings(-1, seed, t);
                settings.DRSMethod = drsMethod;

                GPMulticlassClassificationProblem problem = new GPMulticlassClassificationProblem(settings, training);
                Evolve evolve = new Evolve(problem, new ConsoleListener(ConsoleListener.SILENT));
                evolve.run();
                Individual best = evolve.getBestIndividual();

                boolean basic = settings.DRSMethod == BASIC_DRS;
                FastStatistics trainingStats = basic ? basicTraining : betterTraining;
                FastStatistics testingStats = basic ? basicTesting : betterTesting;

                // Score on the held-out test set (previously the training set was re-scored).
                float testError = (float) problem.test(best, testing);

                trainingStats.addData((float) best.getKozaFitness());
                System.out.print(best.getKozaFitness());
                System.out.println("," + testError);
                testingStats.addData(testError);
            }
        }
        System.out.println("Basic DRS (Training): " + basicTraining);
        System.out.println("Basic DRS (Testing): " + basicTesting);
        System.out.println("---");
        System.out.println("Better DRS (Training): " + betterTraining);
        System.out.println("Better DRS (Testing): " + betterTesting);
        System.out.println("---");
    }

    /**
     * Creates the problem.
     *
     * @param problemSettings run configuration
     * @param vector training samples; used both to build the classification map and for fitness
     */
    public GPMulticlassClassificationProblem(ProblemSettings problemSettings, Vector<Data> vector) {
        this.p = problemSettings;
        this.trainingData = vector;
    }

    /** Returns the settings this problem was created with. */
    public ProblemSettings getProblemSettings() {
        return this.p;
    }

    /** Returns a human-readable name that includes the settings' {@code toString()}. */
    public String getName() {
        return "Multiclass Classification " + this.p.toString();
    }

    /**
     * Applies the settings to {@code evolve} and registers the function/terminal set.
     *
     * <p>The set is: arithmetic and aggregate functions, math functions, ephemeral random
     * constants, and one {@link DataValueTerminal} per input column. The column count is
     * taken from the first training sample.
     *
     * @throws IllegalStateException if there is no training data to size the terminal set from
     */
    public void initialise(Evolve evolve, GPParams gPParams) {
        this.p.apply(evolve);
        gPParams.registerNode(new Add());
        gPParams.registerNode(new Mul());
        gPParams.registerNode(new Sub());
        gPParams.registerNode(new Div());
        gPParams.registerNode(new Mean());
        gPParams.registerNode(new PercentDiff());
        gPParams.registerNode(new Ln());
        gPParams.registerNode(new Squared());
        gPParams.registerNode(new Cubed());
        gPParams.registerNode(new Sin());
        gPParams.registerNode(new Max());
        gPParams.registerNode(new Min());
        gPParams.registerNode(new PercentageERC());
        gPParams.registerNode(new CustomRangeERC(0.0d, 255.0d));
        gPParams.registerNode(new CustomRangeERC(0.0d, 25.0d));
        if (this.trainingData.isEmpty()) {
            throw new IllegalStateException(
                    "Cannot initialise: training data is empty, so the number of input columns is unknown");
        }
        int columnCount = this.trainingData.elementAt(0).getColumnCount();
        for (int column = 0; column < columnCount; column++) {
            gPParams.registerNode(new DataValueTerminal(column));
        }
    }

    /** Delegates parameter customisation to the problem settings. */
    public void customiseParameters(GPParams gPParams) {
        this.p.apply(gPParams);
    }

    /**
     * Runs {@code individual} on every non-zero-weight training sample and builds a
     * classification map from its outputs.
     *
     * @return the thresholded map, or {@code null} if {@code dataStack.usesImaging} is unset
     *     after any execution. NOTE(review): presumably this rejects programs that never read
     *     their inputs; confirm in DataStack/DataValueTerminal. This method never resets the flag.
     */
    private PCM buildProgramClassificationMap(DataStack dataStack, Individual individual) {
        BasicDRS drs = this.p.DRSMethod == BASIC_DRS ? new BasicDRS() : new BetterDRS();
        for (Data sample : this.trainingData) {
            if (sample.weight == 0.0f) {
                continue; // zero-weight samples are excluded entirely
            }
            DataValueTerminal.currentValues = sample.values;
            double output = individual.execute(dataStack);
            if (!dataStack.usesImaging) {
                return null;
            }
            drs.addResult(output, sample.classID);
        }
        drs.calculateThresholds();
        return drs;
    }

    /**
     * Evaluates {@code individual} on the training data.
     *
     * <p>Sets its Koza fitness (sum of per-class weighted error rates), its hits and
     * mistakes, and its classification map. Individuals rejected by the map builder get the
     * worst fitness.
     */
    public void evaluate(Individual individual, DataStack dataStack, Evolve evolve) {
        PCM pcm = buildProgramClassificationMap(dataStack, individual);
        if (pcm == null) {
            individual.setWorstFitness();
            return;
        }
        ClassErrorTally tally = new ClassErrorTally();
        Vector<?> cachedResults = pcm.getCachedResults();
        for (Object entry : cachedResults) {
            CachedOutput cached = (CachedOutput) entry;
            tally.add(cached.expectedClass, pcm.getClassFromOutput(cached.rawOutput), cached.weight);
        }
        individual.setKozaFitness(tally.summedClassErrorRate());
        individual.setHits(tally.hits);
        individual.setMistakes(tally.mistakes);
        // The cached outputs are no longer needed once fitness is known; free them before storing the map.
        pcm.clearCachedResults();
        individual.setPCM(pcm);
    }

    /**
     * Scores {@code individual} against this problem's <em>training</em> data.
     *
     * <p>Kept for backward compatibility; use {@link #test(Individual, Vector)} to measure
     * error on held-out data.
     *
     * @return sum of per-class weighted error rates, or {@link Double#MAX_VALUE} if the
     *     individual has no classification map
     */
    public double test(Individual individual) {
        return test(individual, this.trainingData);
    }

    /**
     * Scores {@code individual} on {@code data} using the classification map stored on the
     * individual by {@link #evaluate}.
     *
     * @param individual an evaluated individual
     * @param data samples to score; zero-weight samples are skipped
     * @return sum of per-class weighted error rates, or {@link Double#MAX_VALUE} if the
     *     individual has no classification map
     */
    public double test(Individual individual, Vector<Data> data) {
        PCM pcm = individual.getPCM();
        if (pcm == null) {
            return Double.MAX_VALUE;
        }
        ClassErrorTally tally = new ClassErrorTally();
        for (Data sample : data) {
            if (sample.weight == 0.0f) {
                continue;
            }
            DataValueTerminal.currentValues = sample.values;
            // CachedOutput is used so the per-sample weight matches what evaluate() sees.
            CachedOutput cached = new CachedOutput(individual.execute(new DataStack()), sample.classID);
            tally.add(cached.expectedClass, pcm.getClassFromOutput(cached.rawOutput), cached.weight);
        }
        return tally.summedClassErrorRate();
    }

    /**
     * Accumulates per-class weighted errors plus hit/mistake counts.
     *
     * <p>Shared by {@link #evaluate} and {@link #test(Individual, Vector)}. The per-class
     * arrays grow on demand, so class IDs are no longer limited to {@code [0, 20)}.
     */
    private static final class ClassErrorTally {
        /** Initial per-class capacity (the former hard limit). */
        private static final int INITIAL_CLASSES = 20;

        private double[] errorWeight = new double[INITIAL_CLASSES];
        private double[] totalWeight = new double[INITIAL_CLASSES];
        int hits;
        int mistakes;

        /** Records one sample of class {@code expectedClass} that was predicted as {@code predictedClass}. */
        void add(int expectedClass, int predictedClass, double weight) {
            if (expectedClass < 0) {
                throw new IllegalArgumentException("Class IDs must be non-negative, got " + expectedClass);
            }
            if (expectedClass >= totalWeight.length) {
                int newLength = Math.max(expectedClass + 1, totalWeight.length * 2);
                totalWeight = Arrays.copyOf(totalWeight, newLength);
                errorWeight = Arrays.copyOf(errorWeight, newLength);
            }
            totalWeight[expectedClass] += weight;
            if (predictedClass != expectedClass) {
                errorWeight[expectedClass] += weight;
                mistakes++;
            } else {
                hits++;
            }
        }

        /** Sum of errorWeight/totalWeight over classes with non-zero total weight, in ascending class order. */
        double summedClassErrorRate() {
            double sum = 0.0d;
            for (int c = 0; c < totalWeight.length; c++) {
                if (totalWeight[c] != 0.0d) {
                    sum += errorWeight[c] / totalWeight[c];
                }
            }
            return sum;
        }
    }
}
