package ac.ooechs.classify.classifier.gp;

import ac.essex.gp.Evolve;
import ac.essex.gp.individuals.Individual;
import ac.essex.gp.multiclass.BasicDRS;
import ac.essex.gp.multiclass.BetterDRS;
import ac.essex.gp.multiclass.CachedOutput;
import ac.essex.gp.multiclass.PCM;
import ac.essex.gp.nodes.Add;
import ac.essex.gp.nodes.Div;
import ac.essex.gp.nodes.Max;
import ac.essex.gp.nodes.Mean;
import ac.essex.gp.nodes.Min;
import ac.essex.gp.nodes.Mul;
import ac.essex.gp.nodes.PercentDiff;
import ac.essex.gp.nodes.Sub;
import ac.essex.gp.nodes.ercs.PercentageERC;
import ac.essex.gp.nodes.ercs.SmallDoubleERC;
import ac.essex.gp.nodes.ercs.SmallIntERC;
import ac.essex.gp.nodes.ercs.TinyDoubleERC;
import ac.essex.gp.nodes.math.Exp;
import ac.essex.gp.nodes.math.Log;
import ac.essex.gp.params.GPParams;
import ac.essex.gp.problems.DataStack;
import ac.essex.gp.problems.Problem;
import ac.ooechs.classify.data.Data;
import ac.ooechs.classify.data.DataStatistics;
import ac.ooechs.classify.data.io.CSVDataReader;
import java.io.File;
import java.io.IOException;
import java.util.Vector;

/* loaded from: input_file:ac/ooechs/classify/classifier/gp/GPMulticlassClassificationProblem.class */
public class GPMulticlassClassificationProblem extends Problem {

    /** Dataset path used by {@link #main(String[])} when no path is supplied on the command line. */
    private static final String DEFAULT_TRAINING_FILE = "/home/ooechs/Desktop/jasmine-data/sat-training.ssv";

    /** Samples the evolved programs are executed on and scored against. */
    protected Vector<Data> trainingData;

    /**
     * Run configuration: applied to the {@link Evolve} engine and to the {@link GPParams},
     * and its {@code DRSMethod} field selects the DRS implementation used for class mapping.
     */
    protected ProblemSettings p;

    /**
     * Loads a training set and starts an evolutionary run on it.
     *
     * @param strArr optional command-line arguments; when present, {@code strArr[0]} is the
     *               path of the training file, otherwise {@link #DEFAULT_TRAINING_FILE} is used
     * @throws IOException propagated from reading the training file
     */
    public static void main(String[] strArr) throws IOException {
        String path = (strArr != null && strArr.length > 0) ? strArr[0] : DEFAULT_TRAINING_FILE;
        Vector<Data> data = new CSVDataReader(new File(path)).getData();
        // Constructed only for its side effects; the instance itself is discarded.
        // NOTE(review): presumably reports summary statistics about the data set — confirm.
        new DataStatistics(data);
        // NOTE(review): the meaning of (60, 2357, 2) is defined by ProblemSettings; the values are
        // kept from the original experiment.
        new Evolve(new GPMulticlassClassificationProblem(new ProblemSettings(60, 2357, 2), data)).start();
    }

    /**
     * Creates a classification problem.
     *
     * @param problemSettings run configuration applied during initialisation
     * @param vector          training samples; the first sample's column count determines how
     *                        many input terminals are registered
     */
    public GPMulticlassClassificationProblem(ProblemSettings problemSettings, Vector<Data> vector) {
        this.p = problemSettings;
        this.trainingData = vector;
    }

    /** @return the settings this problem was created with */
    public ProblemSettings getProblemSettings() {
        return this.p;
    }

    /** @return a human-readable name combining the problem type and its settings */
    public String getName() {
        return "Classification " + this.p.toString();
    }

    /**
     * Applies the settings to the engine and registers the function set, the ephemeral random
     * constants and one {@link DataValueTerminal} per input column.
     *
     * @throws IllegalStateException if there is no training data to infer the column count from
     */
    public void initialise(Evolve evolve, GPParams gPParams) {
        if (this.trainingData == null || this.trainingData.isEmpty()) {
            throw new IllegalStateException(
                    "Training data is empty; cannot determine the number of input columns");
        }
        this.p.apply(evolve);

        // Function set.
        gPParams.registerNode(new Add());
        gPParams.registerNode(new Mul());
        gPParams.registerNode(new Sub());
        gPParams.registerNode(new Div());
        gPParams.registerNode(new Mean());
        gPParams.registerNode(new PercentDiff());
        gPParams.registerNode(new Log());
        gPParams.registerNode(new Exp());
        gPParams.registerNode(new Max());
        gPParams.registerNode(new Min());

        // Ephemeral random constants.
        gPParams.registerNode(new SmallIntERC());
        gPParams.registerNode(new SmallDoubleERC());
        gPParams.registerNode(new TinyDoubleERC());
        gPParams.registerNode(new PercentageERC());

        // One terminal per input column. The column count is taken from the first sample only.
        // NOTE(review): assumes every sample has the same number of columns.
        int columnCount = this.trainingData.elementAt(0).getColumnCount();
        for (int column = 0; column < columnCount; column++) {
            gPParams.registerNode(new DataValueTerminal(column));
        }
    }

    /** Applies the problem settings to the GP parameters. */
    public void customiseParameters(GPParams gPParams) {
        this.p.apply(gPParams);
    }

    /**
     * Runs the individual on every training sample with non-zero weight. Each raw output is
     * paired with the sample's expected class and added to a DRS map. The map's thresholds are
     * calculated once all samples have been added.
     *
     * <p>{@code DRSMethod == 1} selects {@link BasicDRS}; any other value selects {@link BetterDRS}.
     * Inputs reach the program through the static {@code DataValueTerminal.currentValues} field,
     * so concurrent evaluation of different individuals in one JVM would interfere.
     *
     * @return the populated map, or {@code null} as soon as an execution leaves
     *         {@code dataStack.usesImaging} false
     */
    private PCM buildProgramClassificationMap(DataStack dataStack, Individual individual) {
        BasicDRS drs = this.p.DRSMethod == 1 ? new BasicDRS() : new BetterDRS();
        for (int i = 0; i < this.trainingData.size(); i++) {
            Data sample = this.trainingData.elementAt(i);
            if (sample.weight == 0.0f) {
                continue; // zero-weight samples do not contribute to the map
            }
            DataValueTerminal.currentValues = sample.values;
            double output = individual.execute(dataStack);
            // NOTE(review): presumably DataValueTerminal raises this flag when a program reads an
            // input column, so programs that ignore the data are rejected — confirm.
            if (!dataStack.usesImaging) {
                return null;
            }
            drs.addResult(output, sample.classID);
        }
        drs.calculateThresholds();
        return drs;
    }

    /**
     * Scores an individual by its weighted error rate over the training data.
     *
     * <p>Koza fitness = (weight of samples assigned a wrong non-zero class + weight of samples
     * with a non-zero expected class that were assigned class 0) / total weight of all cached
     * results. Hits count samples assigned their correct non-zero class. Mistakes count both
     * kinds of error. The alternative fitness is the class-0 miss weight alone. Samples that are
     * expected to be class 0 and are assigned class 0 are neither hits nor mistakes.
     *
     * <p>An individual receives the worst fitness in two cases: no map can be built for it, or
     * its cached results have no positive total weight. The second case would otherwise produce
     * a NaN fitness.
     */
    public void evaluate(Individual individual, DataStack dataStack, Evolve evolve) {
        PCM pcm = buildProgramClassificationMap(dataStack, individual);
        if (pcm == null) {
            individual.setWorstFitness();
            return;
        }

        float totalWeight = 0.0f;
        float wrongClassWeight = 0.0f; // assigned a non-zero class other than the expected one
        float missedWeight = 0.0f;     // assigned class 0 although a non-zero class was expected
        int hits = 0;
        int mistakes = 0;

        Vector<?> cachedResults = pcm.getCachedResults();
        for (int i = 0; i < cachedResults.size(); i++) {
            CachedOutput output = (CachedOutput) cachedResults.elementAt(i);
            totalWeight += output.weight;
            int predicted = pcm.getClassFromOutput(output.rawOutput);
            if (predicted == 0) {
                if (output.expectedClass != 0) {
                    mistakes++;
                    missedWeight += output.weight;
                }
            } else if (predicted == output.expectedClass) {
                hits++;
            } else {
                mistakes++;
                wrongClassWeight += output.weight;
            }
        }

        // Negated comparison so that a NaN total is also rejected, not just zero or negative.
        if (!(totalWeight > 0.0f)) {
            individual.setWorstFitness();
            return;
        }

        individual.setKozaFitness((wrongClassWeight + missedWeight) / totalWeight);
        individual.setHits(hits);
        individual.setMistakes(mistakes);
        individual.setAlternativeFitness(missedWeight);
        // The per-sample cache is cleared before the map is attached to the individual.
        // NOTE(review): presumably so that stored individuals do not retain every sample output.
        pcm.clearCachedResults();
        individual.setPCM(pcm);
    }
}
