package ac.essex.ooechs.imaging.jasmine.util;

import java.util.Hashtable;
import java.util.Vector;

/* loaded from: input_file:ac/essex/ooechs/imaging/jasmine/util/LDA.class */
/**
 * Linear Discriminant Analysis (LDA) classifier for any number of integer-labelled classes.
 *
 * <p>Typical use: {@link #add(double[], int) add} labelled training samples, call
 * {@link #compute()} to build the model, then {@link #classify(double[]) classify} new feature
 * vectors. {@link #compute()} may be called again after further samples have been added.
 *
 * <p>For each class {@code i} with mean {@code mu_i} and prior {@code p_i}, the score of a
 * sample {@code x} is {@code x^T C^-1 mu_i - 0.5 * mu_i^T C^-1 mu_i + ln(p_i)}, where {@code C}
 * is the prior-weighted sum of the per-class covariance matrices. The highest-scoring class wins.
 *
 * <p>NOTE(review): each class covariance is computed about the <em>global</em> mean, not the
 * class's own mean (see {@link Group#getCovarianceMatrix()}); textbook LDA uses the class mean.
 * Confirm this is intended.
 */
public class LDA {
    /** Number of features every sample must have. */
    protected int numFeatures;
    /** Mean of all training samples; valid after {@link #compute()}. */
    protected double[] globalMean;
    /** Inverse of the pooled covariance matrix; set by {@link #compute()}. */
    double[][] inverted;
    /** Prior probability of each class, indexed like {@link #classes}; set by {@link #compute()}. */
    double[] priorProbabilities;
    /** Total number of training samples added across all classes. */
    protected int totalSamples = 0;
    /** Class labels in the order they were first seen. */
    protected Vector<Integer> classes = new Vector<>(5);
    /** Training data grouped by class label. */
    protected Hashtable<Integer, Group> groups = new Hashtable<>(5);
    /** Running per-feature sum of every training sample, accumulated in insertion order. */
    private double[] globalSum;

    /** Training samples and summary statistics for a single class label. */
    public final class Group {
        /** Mean of (sample - global mean) over this group; set by {@link #adjustDataByGlobalMean()}. */
        double[] adjustedGroupMean;
        /** Mean of this group's raw samples; set by {@link #adjustDataByGlobalMean()}. */
        double[] unadjustedMean;
        /** This group's training samples (private copies of the arrays passed to {@link #add}). */
        Vector<double[]> data = new Vector<>();
        /** The class label this group represents. */
        int classID;

        /**
         * Creates an empty group.
         *
         * @param i the class label
         */
        public Group(int i) {
            this.classID = i;
            this.unadjustedMean = new double[LDA.this.numFeatures];
        }

        /**
         * Adds one sample to this group and to the enclosing model's global totals.
         *
         * @param dArr the feature vector; must have exactly {@code numFeatures} entries. A copy is
         *     stored, so later changes to the caller's array do not affect the model.
         * @throws IllegalArgumentException if the vector has the wrong length
         */
        public void add(double[] dArr) {
            LDA.this.checkSample(dArr);
            double[] sample = dArr.clone();
            this.data.add(sample);
            for (int k = 0; k < sample.length; k++) {
                LDA.this.globalSum[k] += sample[k];
            }
            LDA.this.totalSamples++;
        }

        /**
         * Recomputes {@link #unadjustedMean} and {@link #adjustedGroupMean} from the stored
         * samples, using the enclosing model's current {@code globalMean}. Unlike an in-place
         * division, this is safe to call repeatedly.
         */
        public void adjustDataByGlobalMean() {
            int n = LDA.this.numFeatures;
            double[] rawSum = new double[n];
            double[] centredSum = new double[n];
            for (double[] sample : this.data) {
                for (int k = 0; k < n; k++) {
                    rawSum[k] += sample[k];
                    centredSum[k] += sample[k] - LDA.this.globalMean[k];
                }
            }
            int size = this.data.size();
            double[] mean = new double[n];
            double[] adjusted = new double[n];
            for (int k = 0; k < n; k++) {
                mean[k] = rawSum[k] / size;
                adjusted[k] = centredSum[k] / size;
            }
            this.unadjustedMean = mean;
            this.adjustedGroupMean = adjusted;
        }

        /**
         * Returns this group's covariance matrix, computed about the global mean.
         *
         * @return a new {@code numFeatures x numFeatures} matrix whose entry {@code [r][c]} is the
         *     average over this group's samples of
         *     {@code (x[r] - globalMean[r]) * (x[c] - globalMean[c])}
         */
        public double[][] getCovarianceMatrix() {
            int n = LDA.this.numFeatures;
            double[][] cov = new double[n][n];
            double[] centred = new double[n];
            for (double[] sample : this.data) {
                for (int k = 0; k < n; k++) {
                    centred[k] = sample[k] - LDA.this.globalMean[k];
                }
                // Accumulate the outer product of the centred sample.
                for (int r = 0; r < n; r++) {
                    for (int c = 0; c < n; c++) {
                        cov[r][c] += centred[r] * centred[c];
                    }
                }
            }
            double size = this.data.size();
            for (double[] row : cov) {
                for (int c = 0; c < n; c++) {
                    row[c] /= size;
                }
            }
            return cov;
        }
    }

    /**
     * Creates an empty model.
     *
     * @param numFeatures the number of features every sample must have
     */
    public LDA(int numFeatures) {
        this.numFeatures = numFeatures;
        this.globalMean = new double[numFeatures];
        this.globalSum = new double[numFeatures];
    }

    /** Rejects samples whose length differs from {@link #numFeatures}. */
    private void checkSample(double[] sample) {
        if (sample.length != this.numFeatures) {
            throw new IllegalArgumentException(
                    "Expected " + this.numFeatures + " features but got " + sample.length);
        }
    }

    /**
     * Adds a labelled training sample, creating the class's group on first use.
     *
     * @param dArr the feature vector; must have exactly {@code numFeatures} entries
     * @param i the class label
     * @throws IllegalArgumentException if the vector has the wrong length
     */
    public void add(double[] dArr, int i) {
        // Validate before registering a new group so a bad sample cannot leave an empty group
        // behind (an empty group would later produce 0/0 means).
        checkSample(dArr);
        Integer label = Integer.valueOf(i);
        Group group = this.groups.get(label);
        if (group == null) {
            group = new Group(i);
            this.classes.add(label);
            this.groups.put(label, group);
        }
        group.add(dArr);
    }

    /**
     * Builds the model from all samples added so far.
     *
     * <p>Computes the global mean, the per-class means, the class priors, and the inverse of
     * the prior-weighted pooled covariance. May be called again after more samples are added.
     * NOTE(review): behaviour for a singular pooled covariance depends on
     * {@code Matrix.invert} — confirm.
     *
     * @throws IllegalStateException if no training samples have been added
     */
    public void compute() {
        if (this.totalSamples == 0) {
            throw new IllegalStateException("No training samples have been added");
        }
        // Recompute from the running sum rather than dividing globalMean in place, so repeated
        // calls do not keep shrinking the mean.
        for (int k = 0; k < this.numFeatures; k++) {
            this.globalMean[k] = this.globalSum[k] / this.totalSamples;
        }
        int classCount = this.classes.size();
        this.priorProbabilities = new double[classCount];
        double[][] pooled = new double[this.numFeatures][this.numFeatures];
        for (int c = 0; c < classCount; c++) {
            Group group = this.groups.get(this.classes.elementAt(c));
            group.adjustDataByGlobalMean();
            double[][] cov = group.getCovarianceMatrix();
            // Cast first: int / int would truncate every prior to 0 (or 1 with a single class).
            double prior = (double) group.data.size() / this.totalSamples;
            this.priorProbabilities[c] = prior;
            for (int r = 0; r < this.numFeatures; r++) {
                for (int col = 0; col < this.numFeatures; col++) {
                    pooled[r][col] += prior * cov[r][col];
                }
            }
        }
        this.inverted = Matrix.invert(pooled);
    }

    /**
     * Returns the label of the class with the highest discriminant score for a sample.
     *
     * @param dArr the feature vector (only its first {@code numFeatures} entries are read)
     * @return the winning class label, or {@code -1} if no classes have been added
     * @throws IllegalStateException if {@link #compute()} has not been run since the last new
     *     class was added
     */
    public int classify(double[] dArr) {
        int classCount = this.classes.size();
        if (classCount == 0) {
            return -1;
        }
        if (this.inverted == null
                || this.priorProbabilities == null
                || this.priorProbabilities.length != classCount) {
            throw new IllegalStateException(
                    "Model is not up to date; call compute() after adding training samples");
        }
        int bestLabel = -1;
        double bestScore = 0.0d;
        // Separate flag: -1 is a legitimate class label, so it cannot double as "none yet".
        boolean haveBest = false;
        for (int c = 0; c < classCount; c++) {
            Integer label = this.classes.elementAt(c);
            double[] mean = this.groups.get(label).unadjustedMean;
            double sampleTerm = 0.0d; // x^T C^-1 mu
            double meanTerm = 0.0d;   // mu^T C^-1 mu
            for (int r = 0; r < this.numFeatures; r++) {
                double weighted = 0.0d; // (C^-1 mu)[r]
                for (int k = 0; k < this.numFeatures; k++) {
                    weighted += this.inverted[r][k] * mean[k];
                }
                sampleTerm += weighted * dArr[r];
                meanTerm += weighted * mean[r];
            }
            double score = (sampleTerm - (0.5d * meanTerm)) + Math.log(this.priorProbabilities[c]);
            if (!haveBest || score > bestScore) {
                haveBest = true;
                bestScore = score;
                bestLabel = label.intValue();
            }
        }
        return bestLabel;
    }

    /**
     * Classifies one sample, prints "Correct" or "Wrong", and reports whether it matched.
     *
     * @param dArr the feature vector
     * @param i the expected class label
     * @return {@code true} if the predicted label equals {@code i}
     */
    public boolean test(double[] dArr, int i) {
        boolean correct = classify(dArr) == i;
        System.out.println(correct ? "Correct" : "Wrong");
        return correct;
    }

    /**
     * Re-classifies every stored training sample.
     *
     * @return the fraction classified correctly, or NaN if there are no samples
     */
    public double test() {
        double correct = 0.0d;
        double total = 0.0d;
        for (Integer label : this.classes) {
            Group group = this.groups.get(label);
            for (double[] sample : group.data) {
                total += 1.0d;
                if (classify(sample) == group.classID) {
                    correct += 1.0d;
                }
            }
        }
        return correct / total;
    }

    /** Small two-class, two-feature demo; prints the training-set accuracy. */
    public static void main(String[] strArr) {
        LDA lda = new LDA(2);
        lda.add(new double[]{2.95d, 6.63d}, 1);
        lda.add(new double[]{2.53d, 7.79d}, 1);
        lda.add(new double[]{3.57d, 5.65d}, 1);
        lda.add(new double[]{3.16d, 5.47d}, 1);
        lda.add(new double[]{2.58d, 4.46d}, 2);
        lda.add(new double[]{2.16d, 6.22d}, 2);
        lda.add(new double[]{3.27d, 3.52d}, 2);
        lda.compute();
        System.out.println(lda.test());
    }
}
