package ac.essex.gp.multiclass;

import java.io.Serializable;
import java.util.Vector;

/**
 * A slot-based program classification map (a variant of DRS — presumably
 * "Dynamic Range Selection"; NOTE(review): confirm the intended expansion).
 * <p>
 * The range of raw program outputs observed during training, [MIN, MAX], is
 * divided into {@code numSlots} equal-width slots (plus one extra slot so that
 * an output equal to MAX has somewhere to go). After
 * {@link #calculateThresholds()} each slot is mapped to the class that was
 * expected most often for outputs falling into it. Slots that received no
 * training data map to {@code defaultClass}.
 *
 * @author Olly Oechsle, University of Essex, Date: 27-Feb-2008
 */
public class BetterDRS extends PCM implements Serializable {

    public static final int TYPE = 2;

    /** Class assigned to slots that received no training data. */
    private int defaultClass = -1;

    /** Number of equal-width divisions of [MIN, MAX]. */
    private int numSlots = 50;

    /** Class ID per slot; holds numSlots + 1 entries (index numSlots is used when raw == MAX). */
    private int[] slots;

    /*
     * Running bounds of the observed outputs. They start "inverted" so that the
     * first addResult call overwrites both. MAX must start at the most
     * NEGATIVE finite double: Double.MIN_VALUE is the smallest POSITIVE double
     * and would leave MAX stuck near zero when every output is negative. A
     * finite sentinel (not -Infinity) keeps toJava output compilable.
     */
    private double MIN = Double.MAX_VALUE, MAX = -Double.MAX_VALUE;

    /** Per-slot, per-class tallies built by calculateThresholds; null until then. */
    protected int[][] slotCount;

    /** Creates a map with 50 slots and a default class of -1. */
    public BetterDRS() {
        this(50);
    }

    /**
     * Creates a map with the given number of slots and a default class of -1.
     *
     * @param numSlots number of equal-width slots to divide the output range into
     */
    public BetterDRS(int numSlots) {
        this(numSlots, -1);
    }

    /**
     * Creates a map with the given number of slots and default class.
     *
     * @param numSlots     number of equal-width slots to divide the output range into
     * @param defaultClass class returned for slots that saw no training data
     */
    public BetterDRS(int numSlots, int defaultClass) {
        this.defaultClass = defaultClass;
        this.numSlots = numSlots;
        // One extra slot: getSlotIndex returns numSlots for raw == MAX.
        slots = new int[numSlots + 1];
    }

    /**
     * Recreates an already-trained map. This is the constructor targeted by
     * {@link #toJava()}. Instances built this way have no {@link #slotCount},
     * so {@link #getConfidence(int, double)} cannot be used on them.
     *
     * @param MIN   lowest observed output
     * @param MAX   highest observed output
     * @param slots class ID per slot (length numSlots + 1)
     */
    public BetterDRS(double MIN, double MAX, int[] slots) {
        this.MIN = MIN;
        this.MAX = MAX;
        this.slots = slots;
        this.numSlots = slots.length - 1;
    }

    /**
     * Gives a result to the program classification map. This information can be used
     * to calculate good values for the threshold using the calculateThresholds method.
     *
     * @param output  The raw output from the program
     * @param classID The class that we expect to see
     */
    public void addResult(double output, int classID) {
        addResult(output, classID, 1);
    }

    /**
     * Gives a result to the program classification map. This information can be used
     * to calculate good values for the threshold using the calculateThresholds method.
     * Also widens the observed [MIN, MAX] output range if necessary.
     *
     * @param output  The raw output from the program
     * @param classID The class that we expect to see
     * @param weight  The weight associated with this piece of data
     */
    public void addResult(double output, int classID, double weight) {
        super.addResult(output, classID, weight);
        if (output > MAX) MAX = output;
        if (output < MIN) MIN = output;
    }

    /**
     * Builds the per-slot class tallies from the cached training results. It then
     * assigns each slot the class with the highest tally. Ties go to the lowest
     * class ID, and slots with no data get {@code defaultClass}.
     * <p>
     * NOTE(review): tallies are unweighted (each cached result counts once);
     * any weight passed to addResult is not used here. Confirm this is intended.
     */
    public void calculateThresholds() {
        // The tally is indexed by expectedClass, so size it to cover every class
        // ID present either in the known class list or in the cached results.
        int maxClassID = 0;
        for (int i = 0; i < classes.size(); i++) {
            int classID = classes.elementAt(i);
            if (classID > maxClassID) maxClassID = classID;
        }
        for (int i = 0; i < cachedResults.size(); i++) {
            CachedOutput cachedOutput = cachedResults.elementAt(i);
            if (cachedOutput.expectedClass > maxClassID) maxClassID = cachedOutput.expectedClass;
        }

        slotCount = new int[slots.length][maxClassID + 1];

        for (int i = 0; i < cachedResults.size(); i++) {
            CachedOutput cachedOutput = cachedResults.elementAt(i);
            // which slot does the output fit into?
            int slot = getSlotIndex(cachedOutput.rawOutput);
            slotCount[slot][cachedOutput.expectedClass]++;
        }

        for (int slot = 0; slot < slots.length; slot++) {
            int bestClass = defaultClass;
            int highest = 0;
            for (int classID = 0; classID < slotCount[slot].length; classID++) {
                // Strict '>' keeps the lowest class ID on ties.
                if (slotCount[slot][classID] > highest) {
                    bestClass = classID;
                    highest = slotCount[slot][classID];
                }
            }
            slots[slot] = bestClass;
        }
    }

    /**
     * Maps a raw output onto a slot index in [0, numSlots]. Values outside
     * [MIN, MAX] are clamped first.
     *
     * @param raw raw program output
     * @return slot index; 0 when the observed range is empty or zero-width
     */
    public int getSlotIndex(double raw) {
        double range = MAX - MIN;
        if (!(range > 0)) {
            // No data yet, or every observed output was identical: one effective slot.
            return 0;
        }
        if (raw > MAX) raw = MAX;
        else if (raw < MIN) raw = MIN;
        // The cast truncates toward zero; a NaN raw value casts to 0.
        return (int) (((raw - MIN) / range) * numSlots);
    }

    /**
     * Given an output, returns the classification.
     *
     * @param raw raw program output
     * @return class ID of the slot the output falls into
     */
    public int getClassFromOutput(double raw) {
        return slots[getSlotIndex(raw)];
    }

    /**
     * Returns the DRS classifier's confidence that it relates to a particular
     * classID. This is the fraction of training results in the output's slot
     * whose expected class was {@code classID}.
     * Warning - Will not work on individuals initialised using the toJava
     * method code: there slotCount is null and this throws NullPointerException.
     *
     * @param classID class to query
     * @param raw     raw program output
     * @return confidence in [0, 1]; 0 if the slot saw no training data or
     *         classID is outside the tallied range
     */
    public float getConfidence(int classID, double raw) {
        int[] counts = slotCount[getSlotIndex(raw)];
        int total = 0;
        for (int count : counts) {
            total += count;
        }
        if (total == 0) {
            // Empty slot: avoid 0/0 producing NaN.
            return 0;
        }
        float qty = (classID >= 0 && classID < counts.length) ? counts[classID] : 0;
        return qty / total;
    }

    /**
     * Turns the map into java that allows it to be reinstantiated via the
     * {@code BetterDRS(double, double, int[])} constructor.
     *
     * @return a Java expression statement such as
     *         {@code new BetterDRS(0.0,1.0,new int[]{1,2});}
     */
    public String toJava() {
        StringBuilder buffer = new StringBuilder("new BetterDRS(");
        buffer.append(toJavaLiteral(MIN));
        buffer.append(',');
        buffer.append(toJavaLiteral(MAX));
        buffer.append(",new int[]{");
        for (int i = 0; i < slots.length; i++) {
            if (i > 0) buffer.append(',');
            buffer.append(slots[i]);
        }
        buffer.append("});");
        return buffer.toString();
    }

    /**
     * Renders a double as valid Java source. Double.toString yields "Infinity"
     * and "NaN", which are not Java literals, so those cases use named constants.
     */
    private static String toJavaLiteral(double value) {
        if (Double.isNaN(value)) return "Double.NaN";
        if (value == Double.POSITIVE_INFINITY) return "Double.POSITIVE_INFINITY";
        if (value == Double.NEGATIVE_INFINITY) return "Double.NEGATIVE_INFINITY";
        return String.valueOf(value);
    }

    /** Small manual demo of training and querying a 3-slot map. */
    public static void main(String[] args) {
        BetterDRS drs = new BetterDRS(3);
        drs.addResult(1, 5);
        drs.addResult(1, 5);
        drs.addResult(2, 5);
        drs.addResult(2, 7);
        drs.addResult(2, 5);
        drs.addResult(2, 8);
        drs.addResult(3, 8);

        drs.calculateThresholds();
        System.out.println(drs.getHits());
        for (int i = 1; i <= 3; i++) {
            int classID = drs.getClassFromOutput(i);
            float confidence1 = drs.getConfidence(5, i);
            float confidence2 = drs.getConfidence(8, i);
            float confidence3 = drs.getConfidence(7, i);
            System.out.println(i + ", " + classID + ", " + confidence1 + ", " + confidence2 + ", " + confidence3);
        }
    }

}
