package ac.essex.ooechs.imaging.gp.problems.classification.distance;

import ac.essex.gp.Evolve;
import ac.essex.gp.individuals.Individual;
import ac.essex.gp.interfaces.graphical.GraphicalListener;
import ac.essex.gp.multiclass.CachedOutput;
import ac.essex.gp.nodes.Add;
import ac.essex.gp.nodes.Div;
import ac.essex.gp.nodes.Mean;
import ac.essex.gp.nodes.Mul;
import ac.essex.gp.nodes.PercentDiff;
import ac.essex.gp.nodes.Sub;
import ac.essex.gp.nodes.ercs.LargeIntERC;
import ac.essex.gp.nodes.ercs.PercentageERC;
import ac.essex.gp.nodes.ercs.SmallDoubleERC;
import ac.essex.gp.nodes.ercs.SmallIntERC;
import ac.essex.gp.nodes.ercs.TinyDoubleERC;
import ac.essex.gp.nodes.generic.CSVFeature;
import ac.essex.gp.nodes.math.Exp;
import ac.essex.gp.nodes.math.Log;
import ac.essex.gp.params.GPParams;
import ac.essex.gp.problems.DataStack;
import ac.essex.gp.tree.Terminal;
import ac.essex.ooechs.imaging.commons.PixelLoader;
import ac.essex.ooechs.imaging.commons.fast.FastStatistics;
import ac.essex.ooechs.imaging.commons.util.CSVWriter;
import ac.essex.ooechs.imaging.commons.util.ImageFilenameFilter;
import ac.essex.ooechs.imaging.commons.window.data.Window;
import ac.essex.ooechs.imaging.commons.window.data.WindowClass;
import ac.essex.ooechs.imaging.commons.window.util.WindowFeatures;
import ac.essex.ooechs.imaging.gp.problems.classification.BasicClassificationProblem;
import ac.essex.ooechs.imaging.gp.problems.classification.icvs08.experiments.ICVSExperiments;
import java.awt.Color;
import java.io.File;
import java.io.IOException;
import java.util.Vector;

/* loaded from: input_file:ac/essex/ooechs/imaging/gp/problems/classification/distance/DistanceFunctionProblem.class */
/**
 * Classification problem in which an evolved program is treated as a "distance function":
 * a sample is assigned to the target class when the program's output lies within a
 * threshold of the mean output observed for target-class training samples.
 *
 * <p>{@link #evaluate} computes that mean and threshold and stores them on the individual
 * (custom value 2 and custom value 1 respectively); {@link #execute} reads them back to
 * classify a sample.
 */
public class DistanceFunctionProblem extends BasicClassificationProblem {
    /** Location the feature CSV is written to by default, and read from by {@link #main}. */
    private static final String DEFAULT_CSV_PATH = "/home/ooechs/Desktop/newpipes.csv";

    /** The threshold search in {@link #evaluate} continues only while precision exceeds this. */
    private static final double MIN_PRECISION = 0.95d;

    /** The threshold search in {@link #evaluate} stops once the threshold reaches this value. */
    private static final double MAX_THRESHOLD = 10.0d;

    /** Amount by which the threshold grows on each search iteration. */
    private static final double THRESHOLD_STEP = 1.0d;

    /** Value returned by {@link #execute} when a sample is not assigned to the target class. */
    private static final int NO_MATCH = -1;

    /** ID of the single class this problem tries to recognise. */
    protected int correctClassID;

    /**
     * Regenerates the feature CSV from two hard-coded image directories and then starts an
     * evolution run that uses the CSV for both file arguments of the problem.
     *
     * @param strArr command-line arguments (unused)
     * @throws IOException if generating the CSV fails
     */
    public static void main(String[] strArr) throws IOException {
        generateCSVFile(
                new File("/home/ooechs/Desktop/pipe-images3/foreground/"),
                new File("/home/ooechs/Desktop/pipe-images3/background/"));
        GraphicalListener listener = new GraphicalListener();
        File csv = new File(DEFAULT_CSV_PATH);
        new Evolve(new DistanceFunctionProblem(csv, csv, 3), listener).start();
    }

    /**
     * Writes one CSV row of window features plus class ID per image found in the two
     * directories, saving the result to the default CSV location.
     *
     * @param file directory of foreground images (processed first)
     * @param file2 directory of background images
     * @throws IOException if saving the CSV fails
     */
    public static void generateCSVFile(File file, File file2) throws IOException {
        generateCSVFile(file, file2, new File(DEFAULT_CSV_PATH));
    }

    /**
     * Writes one CSV row of window features plus class ID per image found in the two
     * directories, saving the result to {@code output}.
     *
     * @param foregroundDir directory of foreground images (processed first)
     * @param backgroundDir directory of background images
     * @param output file the CSV is saved to
     * @throws IOException if saving the CSV fails
     */
    public static void generateCSVFile(File foregroundDir, File backgroundDir, File output) throws IOException {
        // NOTE(review): both classes are created with class ID 1, so every row written below
        // carries the same label, and main() asks for class 3. This looks unintentional, but
        // the intended IDs cannot be determined from this file - confirm before relying on it.
        WindowClass background = new WindowClass("Background", 1, Color.RED);
        WindowClass foreground = new WindowClass("Foreground", 1, Color.GREEN);
        CSVWriter writer = new CSVWriter();
        appendImageRows(writer, foregroundDir, foreground);
        appendImageRows(writer, backgroundDir, background);
        writer.save(output);
    }

    /**
     * Appends one row (features followed by the class ID) to {@code writer} for every image
     * file directly inside {@code directory}, then prints how many rows were added.
     *
     * <p>Best effort: any failure is reported via its stack trace and ends processing of this
     * directory, but does not propagate, so the caller can continue and still save the rows
     * collected so far.
     */
    private static void appendImageRows(CSVWriter writer, File directory, WindowClass windowClass) {
        int saved = 0;
        try {
            // listFiles() yields null (not an empty array) when the path is not a directory
            // or cannot be read; report that plainly instead of failing with an opaque NPE.
            File[] candidates = directory == null ? null : directory.listFiles();
            if (candidates == null) {
                System.err.println("Cannot list images in " + directory + "; skipping.");
                return;
            }
            for (File candidate : candidates) {
                if (!ImageFilenameFilter.isImage(candidate)) {
                    continue;
                }
                PixelLoader pixels = new PixelLoader(candidate);
                // A single window covering the whole image.
                Window wholeImage = new Window(pixels.getWidth(), pixels.getHeight(), 0, 0, windowClass);
                writer.addData(WindowFeatures.getFeatures(pixels, wholeImage));
                writer.addData(windowClass.getClassID());
                writer.newLine();
                saved++;
            }
            System.out.println("\nSaved: " + saved + " images.");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Creates the problem.
     *
     * @param file first data file, passed through to the superclass
     * @param file2 second data file, passed through to the superclass
     * @param correctClassID ID of the class this problem should recognise
     */
    public DistanceFunctionProblem(File file, File file2, int correctClassID) {
        super(file, file2);
        this.correctClassID = correctClassID;
    }

    /**
     * Classifies the sample currently loaded in {@code dataStack}.
     *
     * @return {@link #correctClassID} if the individual's output is strictly within
     *         custom value 1 (threshold) of custom value 2 (mean); otherwise {@code -1}
     */
    @Override // ac.essex.ooechs.imaging.gp.problems.classification.BasicClassificationProblem
    public int execute(Individual individual, DataStack dataStack) {
        double distance = Math.abs(individual.execute(dataStack) - individual.getCustomValue2());
        return distance < individual.getCustomValue1() ? this.correctClassID : NO_MATCH;
    }

    /**
     * Loads the data and registers the function set, the ephemeral random constants and one
     * terminal per feature with {@code gPParams}.
     */
    public void initialise(Evolve evolve, GPParams gPParams) {
        CSVFeature.NORMALISING = ICVSExperiments.NORMALISING;
        loadData(evolve);

        // Functions.
        gPParams.registerNode(new Add());
        gPParams.registerNode(new Mul());
        gPParams.registerNode(new Sub());
        gPParams.registerNode(new Div());
        gPParams.registerNode(new Mean());
        gPParams.registerNode(new PercentDiff());
        gPParams.registerNode(new Log());
        gPParams.registerNode(new Exp());

        // Ephemeral random constants.
        gPParams.registerNode(new SmallIntERC());
        gPParams.registerNode(new SmallDoubleERC());
        gPParams.registerNode(new TinyDoubleERC());
        gPParams.registerNode(new PercentageERC());
        gPParams.registerNode(new LargeIntERC());

        // Feature terminals.
        for (Terminal feature : getFeatures()) {
            registerTerminal(gPParams, feature);
        }

        gPParams.setIgnoreNonTerminalWarnings(true);
        // NOTE(review): 2 is a return-type code defined by the GP library - confirm its meaning.
        gPParams.setReturnType(2);
    }

    /** No parameter customisation is performed for this problem. */
    public void customiseParameters(GPParams gPParams) {
    }

    /**
     * Scores {@code individual} on the training set and stores the resulting classifier
     * parameters on it.
     *
     * <p>The individual is run once per training sample. The mean output over target-class
     * samples becomes the centre. A threshold is then grown from 1 in steps of 1: at each
     * step, samples whose output is strictly within the threshold of the mean are counted as
     * accepted, and precision is the fraction of accepted samples that belong to the target
     * class. Growth stops once precision is no longer above 0.95 or the threshold reaches 10.
     *
     * <p>Results written to the individual, all taken from the final step:
     * Koza fitness = (number of target-class samples) - (accepted target-class samples x
     * precision); mistakes = accepted non-target samples; custom value 1 = threshold;
     * custom value 2 = mean.
     */
    public void evaluate(Individual individual, DataStack dataStack, Evolve evolve) {
        Vector<CachedOutput> outputs = new Vector<CachedOutput>();
        FastStatistics targetStats = new FastStatistics();
        int targetCount = 0;

        // Run the program once per sample and cache the result so the threshold search
        // below does not have to re-execute it.
        for (int sample = 0; sample < getTrainingCount(); sample++) {
            setupDataStackForTraining(dataStack, sample);
            int expectedClass = getTrainingClassID(sample);
            double output = individual.execute(dataStack);
            if (expectedClass == this.correctClassID) {
                targetStats.addData((float) output);
                targetCount++;
            }
            outputs.add(new CachedOutput(output, expectedClass));
        }

        double mean = targetStats.getMean();
        double threshold = 0.0d;
        double precision = 1.0d;
        int hits = 0;
        int mistakes = 0;

        while (precision > MIN_PRECISION && threshold < MAX_THRESHOLD) {
            threshold += THRESHOLD_STEP;
            hits = 0;
            mistakes = 0;
            for (CachedOutput cached : outputs) {
                if (Math.abs(cached.rawOutput - mean) < threshold) {
                    if (cached.expectedClass == this.correctClassID) {
                        hits++;
                    } else {
                        mistakes++;
                    }
                }
            }
            int accepted = hits + mistakes;
            // When nothing falls inside the threshold, precision keeps its previous value.
            if (accepted > 0) {
                precision = (double) hits / accepted;
            }
        }

        individual.setKozaFitness(targetCount - (hits * precision));
        individual.setMistakes(mistakes);
        individual.setCustomValue1(threshold);
        individual.setCustomValue2(mean);
    }

    /**
     * Builds the Java source preamble for an exported classifier: constant declarations for
     * the class ID, threshold and mean, a {@code classify} method applying the same rule as
     * {@link #execute}, and the opening signature of the {@code eval} method.
     *
     * @return the generated source text, ending with the {@code eval} signature (no body)
     */
    public String getMethodSignature(Individual individual) {
        StringBuilder source = new StringBuilder();
        source.append("final int correctClassID = ").append(this.correctClassID);
        source.append(";\nfinal double threshold = ").append(individual.getCustomValue1());
        source.append(";\n");
        source.append("final double mean = ").append(individual.getCustomValue2());
        source.append(";\n\n");
        source.append("public int classify(double[] feature) {\n");
        source.append("  if (Math.abs(eval(feature) - mean) < threshold) return ");
        source.append(this.correctClassID);
        source.append(";\n  else return -1;\n}\n\n");
        source.append("protected double eval(double[] feature)");
        return source.toString();
    }
}
