package ac.essex.ooechs.imaging.gp.problems.classification.boosting;

import ac.essex.gp.Evolve;
import ac.essex.gp.interfaces.console.ConsoleListener;
import ac.essex.gp.params.GPParams;
import ac.essex.ooechs.adaboost.AdaBoost;
import ac.essex.ooechs.adaboost.AdaBoostLearner;
import ac.essex.ooechs.adaboost.AdaBoostSample;
import ac.essex.ooechs.imaging.gp.problems.classification.BasicClassificationProblem;
import ac.essex.ooechs.imaging.gp.problems.classification.rangeselection.DRSProblem;
import java.io.File;

/* loaded from: input_file:ac/essex/ooechs/imaging/gp/problems/classification/boosting/Adaboost_GeneticProgramming.class */
/**
 * AdaBoost driver whose weak learners are produced by genetic programming runs
 * over a {@link BasicClassificationProblem}.
 *
 * <p>Samples handed to the boosting base class carry the <em>index</em> of an
 * item in the problem's training (or test) set together with its class ID; the
 * actual data stays inside the problem object.
 */
public class Adaboost_GeneticProgramming extends AdaBoost {
    /** Classification problem that supplies the data and is evolved against on every round. */
    protected BasicClassificationProblem problem;

    /**
     * Stand-alone entry point: loads a DRS problem from two hard-coded files,
     * prunes the training set to 400, runs 20 boosting rounds and prints the
     * elapsed wall-clock time in milliseconds.
     *
     * @param strArr command-line arguments (ignored)
     * @throws Exception if loading the data or boosting fails
     */
    public static void main(String[] strArr) throws Exception {
        DRSProblem drsProblem = new DRSProblem(new File("/home/ooechs/Desktop/sat-training.ssv"), new File("/home/ooechs/Desktop/sat-test.ssv"));
        drsProblem.verbose = false;
        drsProblem.loadData(null);
        drsProblem.pruneTrainingSize(400);
        long startMillis = System.currentTimeMillis();
        new Adaboost_GeneticProgramming(drsProblem).boost(20);
        System.out.println("Time: " + (System.currentTimeMillis() - startMillis));
    }

    /**
     * Evaluates the current ensemble on the test samples.
     *
     * <p>{@code GPAdaBoostM1Learner.TRAINING} is a global static flag; it is
     * switched off for the duration of the test call (presumably so learners
     * read from the test set rather than the training set - confirm against
     * GPAdaBoostM1Learner). The flag is restored in a {@code finally} block so
     * that an exception thrown while testing cannot leave the learners stuck
     * in test mode.
     */
    public void onIterationEnd() {
        System.out.println("TESTING:");
        GPAdaBoostM1Learner.TRAINING = false;
        try {
            test(getTestSamples());
        } finally {
            GPAdaBoostM1Learner.TRAINING = true;
        }
    }

    /**
     * Creates a booster for the given problem.
     *
     * @param basicClassificationProblem problem supplying training/test data; stored as-is
     */
    public Adaboost_GeneticProgramming(BasicClassificationProblem basicClassificationProblem) {
        this.problem = basicClassificationProblem;
    }

    /**
     * Builds one sample per training item.
     *
     * @return array of length {@code problem.getTrainingCount()}; element {@code i}
     *         wraps the boxed index {@code i} and the training class ID at that index
     */
    public AdaBoostSample[] getSamples() {
        // Read the size once instead of re-querying the problem on every iteration.
        int trainingCount = this.problem.getTrainingCount();
        AdaBoostSample[] samples = new AdaBoostSample[trainingCount];
        for (int i = 0; i < trainingCount; i++) {
            samples[i] = new AdaBoostSample(Integer.valueOf(i), this.problem.getTrainingClassID(i));
        }
        return samples;
    }

    /**
     * Builds one sample per test item.
     *
     * @return array of length {@code problem.getTestCount()}; element {@code i}
     *         wraps the boxed index {@code i} and the test class ID at that index
     */
    public AdaBoostSample[] getTestSamples() {
        // Read the size once instead of re-querying the problem on every iteration.
        int testCount = this.problem.getTestCount();
        AdaBoostSample[] samples = new AdaBoostSample[testCount];
        for (int i = 0; i < testCount; i++) {
            samples[i] = new AdaBoostSample(Integer.valueOf(i), this.problem.getTestClassID(i));
        }
        return samples;
    }

    /**
     * Trains one weak learner by running a GP evolution on the problem with the
     * supplied sample weights, and prints the best individual's hit count.
     *
     * @param adaBoostSampleArr current samples (not read here; the problem holds the data)
     * @param dArr per-sample weights, passed to {@code problem.setWeights}
     * @return a learner wrapping the problem and the best evolved individual
     */
    protected AdaBoostLearner weakLearn(AdaBoostSample[] adaBoostSampleArr, double[] dArr) {
        GPParams params = new GPParams();
        params.setGenerations(100);
        // NOTE(review): unit of the max-time value is not visible here - confirm in GPParams.
        params.setMaxTime(200);
        // Weights must be installed before the run so fitness evaluation can see them.
        this.problem.setWeights(dArr);
        Evolve evolve = new Evolve(this.problem, new ConsoleListener(1), params);
        evolve.run();
        System.out.println(evolve.getBestIndividual().getHits() + " hits");
        return new GPAdaBoostM1Learner(this.problem, evolve.getBestIndividual());
    }
}
