package ac.essex.ooechs.adaboost;

import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Vector;

/* loaded from: input_file:ac/essex/ooechs/adaboost/AdaBoost.class */
public abstract class AdaBoost implements Serializable {

    /**
     * Weak hypotheses chosen in each boosting round. Slots for rounds that never ran, or whose
     * learner was rejected, are {@code null}.
     */
    protected AdaBoostLearner[] h;

    /**
     * Per-round beta values {@code e / (1 - e)}, where {@code e} is the weighted training error of
     * {@code h[t]}. In {@link #classify(AdaBoostSample)} hypothesis {@code h[t]} votes with weight
     * {@code log10(1 / b[t])}. If boosting stopped on a zero-error learner, that slot keeps
     * {@code b[t] == 0} and therefore has infinite vote weight.
     */
    protected double[] b;

    /** Snapshot of the ensemble (from {@link #getSolution()}) after each completed round. */
    protected Vector<AdaBoostSolution> solutions;

    /** Training-set TP count of the ensemble, parallel to {@link #solutions}. */
    protected Vector<Integer> scores;

    /**
     * Returns the training samples. Called once at the start of {@link #boost(int)}.
     *
     * @return the training samples
     */
    public abstract AdaBoostSample[] getSamples();

    /**
     * Trains one weak learner for the current round.
     *
     * @param samples the training samples
     * @param weights current sample weights, parallel to {@code samples}, summing to 1;
     *     presumably the learner should focus on higher-weighted samples (TODO confirm in
     *     subclasses)
     * @return the trained weak learner
     */
    protected abstract AdaBoostLearner weakLearn(AdaBoostSample[] samples, double[] weights);

    /**
     * Runs AdaBoost for at most {@code rounds} rounds over {@link #getSamples()}.
     *
     * <p>Each round trains a weak learner on the current weight distribution and measures its
     * weighted error {@code e}. Then:
     * <ul>
     *   <li>if {@code e > 0.5}, the learner is discarded and boosting stops;</li>
     *   <li>if {@code e == 0}, the learner is kept with {@code b[t] == 0} and boosting stops;</li>
     *   <li>otherwise {@code b[t] = e / (1 - e)}. Every correctly classified sample's weight is
     *       multiplied by {@code b[t]` and the weights are renormalised. A snapshot of the
     *       ensemble is recorded in {@link #solutions} together with its training TP count.</li>
     * </ul>
     * Boosting also stops once the ensemble classifies every training sample correctly.
     * {@link #finish()} is always called at the end.
     *
     * @param rounds maximum number of rounds; also the length of {@link #h} and {@link #b}
     */
    public void boost(int rounds) {
        solutions = new Vector<>(rounds);
        scores = new Vector<>(rounds);
        AdaBoostSample[] samples = getSamples();
        int n = samples.length;

        // Start from a uniform distribution over the training samples.
        double[] weights = new double[n];
        Arrays.fill(weights, 1.0d / n);

        h = new AdaBoostLearner[rounds];
        b = new double[rounds];

        for (int t = 0; t < rounds; t++) {
            message("ADABOOST: t=" + t);
            AdaBoostLearner learner = weakLearn(samples, weights);
            h[t] = learner;

            // Weighted error of this round's learner; also remember which samples it got right.
            double error = 0.0d;
            boolean[] correct = new boolean[n];
            for (int s = 0; s < n; s++) {
                correct[s] = learner.classify(samples[s], null) == samples[s].getLabel();
                if (correct[s]) {
                    learner.individualTP++;
                } else {
                    error += weights[s];
                }
            }

            if (error > 0.5d) {
                // Worse than chance on the weighted set: AdaBoost cannot use this learner.
                message("ADABOOST: Rejecting individual: e=" + error + ", TP=" + learner.individualTP + " / " + n);
                h[t] = null;
                break;
            }
            if (error == 0.0d) {
                // Perfect learner: b[t] stays 0, so classify() gives it infinite vote weight.
                // NOTE(review): no snapshot is added to solutions for this round.
                message("ADABOOST: Ideal Individual");
                break;
            }

            b[t] = error / (1.0d - error);
            reweight(weights, correct, b[t]);

            int tp = test(samples);
            AdaBoostSolution solution = getSolution();
            // Sanity check: the exported snapshot should agree with this object's own vote.
            if (tp != test(samples, solution)) {
                message("Solution is WRONG");
            }
            solutions.add(solution);
            scores.add(tp);
            if (tp == n) {
                message("Found ideal classifier.");
                break;
            }
            onIterationEnd();
        }
        finish();
    }

    /**
     * Scales each correctly classified sample's weight by {@code beta}, then renormalises the
     * weights so they sum to 1. Misclassified samples keep their weight before normalisation.
     */
    private static void reweight(double[] weights, boolean[] correct, double beta) {
        double total = 0.0d;
        for (int s = 0; s < weights.length; s++) {
            if (correct[s]) {
                weights[s] *= beta;
            }
            total += weights[s];
        }
        for (int s = 0; s < weights.length; s++) {
            weights[s] /= total;
        }
    }

    /** Hook called after each non-final round; by default reports the best snapshot so far. */
    public void onIterationEnd() {
        message("Best so far: " + getBestSolution());
    }

    /** Hook called once when {@link #boost(int)} returns; by default prints a completion message. */
    public void finish() {
        message("AdaBoost complete");
    }

    /**
     * Emits a progress message. The default prints to standard output; subclasses may override
     * this to redirect or silence output.
     *
     * @param str the message text
     */
    public void message(String str) {
        System.out.println(str);
    }

    /**
     * Classifies a sample by a weighted vote of the boosted hypotheses.
     *
     * <p>Each non-null {@code h[t]} adds {@code log10(1 / b[t])} to the score of the label it
     * predicts. The label with the highest total wins. On a tie at the top, the label that reached
     * that score first wins.
     *
     * @param adaBoostSample the sample to classify
     * @return the winning label, or {@code -1} if {@link #boost(int)} has not run or no label
     *     received a positive score
     */
    public int classify(AdaBoostSample adaBoostSample) {
        if (h == null) {
            return -1;
        }
        // Keyed by label, so any int label is supported (not just a fixed 0..99 range).
        Map<Integer, Double> votes = new HashMap<>();
        int best = -1;
        // Double.MIN_VALUE is the smallest *positive* double: effectively only strictly positive
        // scores can win, so hypotheses with b[t] == 1 (zero weight) cannot decide on their own.
        double bestScore = Double.MIN_VALUE;
        for (int t = 0; t < h.length; t++) {
            if (h[t] == null) {
                continue;
            }
            int label = h[t].classify(adaBoostSample, null);
            double score = votes.merge(label, Math.log10(1.0d / b[t]), Double::sum);
            // Scores only grow, so tracking the running maximum yields the final argmax.
            if (score > bestScore) {
                bestScore = score;
                best = label;
            }
        }
        return best;
    }

    /**
     * Counts the samples this ensemble labels correctly and reports the accuracy via
     * {@link #message(String)}.
     *
     * @param adaBoostSampleArr samples to evaluate; an empty array yields 0
     * @return number of correctly classified samples
     */
    public int test(AdaBoostSample[] adaBoostSampleArr) {
        int tp = 0;
        for (AdaBoostSample sample : adaBoostSampleArr) {
            if (classify(sample) == sample.getLabel()) {
                tp++;
            }
        }
        report("Adaboost TP: ", tp, adaBoostSampleArr.length);
        return tp;
    }

    /**
     * Counts the samples the given solution labels correctly and reports the accuracy via
     * {@link #message(String)}.
     *
     * @param adaBoostSampleArr samples to evaluate; an empty array yields 0
     * @param adaBoostSolution the solution to evaluate
     * @return number of correctly classified samples
     */
    public int test(AdaBoostSample[] adaBoostSampleArr, AdaBoostSolution adaBoostSolution) {
        int tp = 0;
        for (AdaBoostSample sample : adaBoostSampleArr) {
            if (adaBoostSolution.classify(sample, null) == sample.getLabel()) {
                tp++;
            }
        }
        report("Solution TP: ", tp, adaBoostSampleArr.length);
        return tp;
    }

    /** Formats and emits a "TP / total (pct%, e=err)" line for the two {@code test} overloads. */
    private void report(String prefix, int tp, int total) {
        // Divide in floating point: int / int would truncate the ratio to 0 (or 1).
        double accuracy = total == 0 ? 0.0d : (double) tp / total;
        DecimalFormat decimalFormat = new DecimalFormat("0.000");
        message(prefix + tp + " / " + total + " (" + decimalFormat.format(accuracy * 100.0d)
                + "%, e=" + decimalFormat.format(1.0d - accuracy) + ")");
    }

    /**
     * Builds a snapshot of the current ensemble. Every slot of {@link #h} and {@link #b} is passed
     * to {@code addHypothesis}, including {@code null} hypotheses from rounds not yet run.
     * NOTE(review): whether {@code addHypothesis} skips {@code null} entries is defined in
     * {@code AdaBoostSolution}; confirm there.
     *
     * @return a new solution sized to {@code h.length}
     */
    protected AdaBoostSolution getSolution() {
        AdaBoostSolution adaBoostSolution = new AdaBoostSolution(h.length);
        for (int t = 0; t < h.length; t++) {
            adaBoostSolution.addHypothesis(h[t], b[t]);
        }
        return adaBoostSolution;
    }

    /**
     * Returns the recorded snapshot with the highest training TP count. Among equal scores, the
     * snapshot whose {@code h} array is shorter wins; otherwise the earliest one is kept.
     *
     * @return the best snapshot, or {@code null} if {@link #boost(int)} has not run or recorded
     *     no snapshot
     */
    public AdaBoostSolution getBestSolution() {
        if (solutions == null) {
            return null;
        }
        AdaBoostSolution best = null;
        int bestScore = -1;
        for (int k = 0; k < solutions.size(); k++) {
            AdaBoostSolution candidate = solutions.get(k);
            int score = scores.get(k);
            boolean better = best == null
                    || score > bestScore
                    || (score == bestScore && candidate.h.length < best.h.length);
            if (better) {
                best = candidate;
                bestScore = score;
            }
        }
        return best;
    }
}
