package ac.essex.gp.markov;

import java.util.Vector;
import java.util.Hashtable;

/**
 * Decoding
 * Finding the hidden states that generated the observed
 * output. The most probable sequence of hidden states given
 * a sequence of observations is calculated using the Viterbi
 * Algorithm.
 * <p/>
 * Viterbi is also used in Natural Language Processing to tag
 * words with their syntactic class (noun, verb etc). Words are the
 * observable states and syntactic classes are hidden states.
 *
 * @author Olly Oechsle, University of Essex, Date: 05-Mar-2007
 * @version 1.0
 */
public class ViterbiAlgorithm {

    /**
     * Demonstration. Builds a small weather model (hidden states Sun / Cloudy / Rainy,
     * observed states Dry / Dryish / Damp / Soggy), decodes the observation sequence
     * Soggy, Damp, Dry and prints the resulting probabilities and path to standard output.
     *
     * @param args ignored
     */
    public static void main(String[] args) {

        Vector<ObservedState> observedStates = new Vector<ObservedState>(10);

        ObservedState dry = new ObservedState(0, "Dry");
        ObservedState dryish = new ObservedState(1, "Dryish");
        ObservedState damp = new ObservedState(2, "Damp");
        ObservedState soggy = new ObservedState(3, "Soggy");
        observedStates.add(dry);
        observedStates.add(dryish);
        observedStates.add(damp);
        observedStates.add(soggy);

        Vector<HiddenState> hiddenStates = new Vector<HiddenState>(10);

        // 1. create the states
        HiddenState sun = new HiddenState(0, "Sun");
        HiddenState cloud = new HiddenState(1, "Cloudy");
        HiddenState rain = new HiddenState(2, "Rainy");
        hiddenStates.add(sun);
        hiddenStates.add(cloud);
        hiddenStates.add(rain);

        // 2. set the state at time 0 in the pi vector
        PiVector initialProbabilities = new PiVector(hiddenStates);
        initialProbabilities.setProbability(sun, 0.63);
        initialProbabilities.setProbability(cloud, 0.17);
        initialProbabilities.setProbability(rain, 0.20);

        // 3. define the state transition probabilities
        // NOTE(review): forward_viterbi reads these as getProbability(currentState, nextState).
        // If setProbability(a, b, p) uses the same (from, to) order, the Cloudy row sums to
        // 0.875 and the Rainy row to 1.125, whereas every column sums to 1 - the table may be
        // transposed. Confirm against StateTransitionMatrix before relying on these numbers.
        StateTransitionMatrix stm = new StateTransitionMatrix(hiddenStates);
        stm.setProbability(sun, sun, 0.5);
        stm.setProbability(sun, cloud, 0.25);
        stm.setProbability(sun, rain, 0.25);

        stm.setProbability(cloud, sun, 0.375);
        stm.setProbability(cloud, cloud, 0.125);
        stm.setProbability(cloud, rain, 0.375);

        stm.setProbability(rain, sun, 0.125);
        stm.setProbability(rain, cloud, 0.625);
        stm.setProbability(rain, rain, 0.375);

        // 4. Set up the confusion matrix (hidden state -> observed state probabilities)
        ConfusionMatrix confusionMatrix = new ConfusionMatrix(hiddenStates, observedStates);
        confusionMatrix.setProbability(sun, dry, 0.60);
        confusionMatrix.setProbability(sun, dryish, 0.20);
        confusionMatrix.setProbability(sun, damp, 0.15);
        confusionMatrix.setProbability(sun, soggy, 0.05);

        confusionMatrix.setProbability(cloud, dry, 0.25);
        confusionMatrix.setProbability(cloud, dryish, 0.25);
        confusionMatrix.setProbability(cloud, damp, 0.25);
        confusionMatrix.setProbability(cloud, soggy, 0.25);

        confusionMatrix.setProbability(rain, dry, 0.05);
        confusionMatrix.setProbability(rain, dryish, 0.10);
        confusionMatrix.setProbability(rain, damp, 0.35);
        confusionMatrix.setProbability(rain, soggy, 0.50);

        // All this data is stored in the HiddenMarkovModel class
        HiddenMarkovModel hmm = new HiddenMarkovModel(hiddenStates, observedStates, initialProbabilities, stm, confusionMatrix);

        // What do we observe?
        Vector<ObservedState> observations = new Vector<ObservedState>(10);
        observations.add(soggy);
        observations.add(damp);
        observations.add(dry);

        ViterbiAlgorithm va = new ViterbiAlgorithm();
        ViterbiPath t = va.forward_viterbi(observations, hmm);

        System.out.println("Done>");
        System.out.println("t.totalProbability: " + t.totalProbability);
        System.out.println("t.viterbiPathProbabililty: " + t.viterbiPathProbabililty);
        System.out.print("t.path: ");
        for (int i = 0; i < t.path.size(); i++) {
            System.out.print(t.path.elementAt(i) + " > ");
        }

    }

    /**
     * Convenience overload that unpacks the model's components and delegates to
     * {@link #forward_viterbi(Vector, Vector, PiVector, StateTransitionMatrix, ConfusionMatrix)}.
     *
     * @param observedStates the sequence of observations to decode
     * @param hmm            the model supplying hidden states, initial, transition and
     *                       confusion (emission) probabilities
     * @return the decoding result, see the delegate method
     */
    public ViterbiPath forward_viterbi(Vector<ObservedState> observedStates, HiddenMarkovModel hmm) {
        return forward_viterbi(observedStates, hmm.hiddenStates, hmm.initialProbabilities, hmm.stateTransitionMatrix, hmm.confusionMatrix);
    }


    /**
     * Finds the most likely sequence of hidden states that would explain a set of observed states.
     * <p/>
     * For every observation, and for every candidate next state, each predecessor state
     * contributes the factor
     * {@code transition(current, next) * confusion(current, observation)}; i.e. the
     * observation is attributed to the state being left, not the state being entered.
     * These contributions are summed (total probability) and maximised (Viterbi path).
     * <p/>
     * Because the path starts with an initial state and one state is appended per
     * observation, a non-empty returned path holds {@code observations.size() + 1} states.
     *
     * @param observations  The sequence of observations
     * @param hiddenStates  The set of hidden states
     * @param initialProbabilities Start Probabilities
     * @param stateTransitionMatrix Transition Probabilities
     * @param confusionMatrix Emission Probabilities
     * @return a {@link ViterbiPath} holding the summed probability over all final states,
     *         the best path and that path's probability. The path is empty (never
     *         {@code null}) and its probability is 0 when no state sequence has a
     *         probability greater than zero.
     */
    public ViterbiPath forward_viterbi(Vector<ObservedState> observations, Vector<HiddenState> hiddenStates, PiVector initialProbabilities, StateTransitionMatrix stateTransitionMatrix, ConfusionMatrix confusionMatrix) {

        // Initialise the possible paths by looking at all possible hidden states from which to start.
        // Each starts as a one-element path whose total and Viterbi probability are both
        // the state's initial probability.
        Hashtable<HiddenState, ViterbiPath> paths = new Hashtable<HiddenState, ViterbiPath>();
        for (HiddenState state : hiddenStates) {
            double prob = initialProbabilities.getProbability(state);

            Vector<HiddenState> v_path = new Vector<HiddenState>();
            v_path.add(state);

            paths.put(state, new ViterbiPath(prob, v_path, prob));
        }

        // The main loop considers each observation in sequence. We're looking for the sequence
        // of states that is most likely to reflect these observations.
        for (ObservedState observation : observations) {

            // A second mapping which will store the new paths, keyed by the state they end in
            Hashtable<HiddenState, ViterbiPath> newPaths = new Hashtable<HiddenState, ViterbiPath>();

            // consider each possible state to which we could move.
            for (HiddenState nextState : hiddenStates) {

                double totalProbability = 0;
                ViterbiPath bestPredecessor = null;
                double highestViterbiPathProbability = 0;

                // consider every transition from currentState -> nextState
                for (HiddenState currentState : hiddenStates) {

                    ViterbiPath currentPath = paths.get(currentState);

                    // probability of taking this step: currentState moves to nextState
                    // and currentState is responsible for the observation
                    double p = stateTransitionMatrix.getProbability(currentState, nextState) * confusionMatrix.getProbability(currentState, observation);

                    // every predecessor contributes to the total ...
                    totalProbability += currentPath.totalProbability * p;

                    // ... but only the best one defines the Viterbi path. Strict '>' keeps
                    // the first predecessor on ties and ignores zero-probability steps.
                    double v_prob = currentPath.viterbiPathProbabililty * p;
                    if (v_prob > highestViterbiPathProbability) {
                        bestPredecessor = currentPath;
                        highestViterbiPathProbability = v_prob;
                    }
                }

                // Extend a copy of the winning predecessor's path (copied once, after the
                // winner is known). If nextState cannot be reached with positive
                // probability, the path stays empty and its probability stays 0.
                Vector<HiddenState> path = new Vector<HiddenState>();
                if (bestPredecessor != null) {
                    path.addAll(bestPredecessor.path);
                    path.add(nextState);
                }

                newPaths.put(nextState, new ViterbiPath(totalProbability, path, highestViterbiPathProbability));
            }

            // swap over the lookup tables - we only need to see one step behind
            paths = newPaths;
        }

        // We now have one candidate path per final state. Sum the totals and pick the most likely path.
        double totalProbability = 0;
        Vector<HiddenState> path = new Vector<HiddenState>();
        double highestViterbiPathProbability = 0;

        for (HiddenState state : hiddenStates) {
            ViterbiPath t = paths.get(state);
            totalProbability += t.totalProbability;

            if (t.viterbiPathProbabililty > highestViterbiPathProbability) {
                path = t.path;
                highestViterbiPathProbability = t.viterbiPathProbabililty;
            }
        }

        return new ViterbiPath(totalProbability, path, highestViterbiPathProbability);

    }

    /**
     * Result holder for the Viterbi computation: a total probability, a state path and the
     * probability of that path.
     * <p/>
     * Declared public and static because it is the return type of the public
     * {@code forward_viterbi} methods and needs no reference to an enclosing instance.
     */
    public static class ViterbiPath {

        /**
         * Total probability of all paths from the starting state to the current state
         */
        protected final double totalProbability;

        /**
         * The Viterbi path up to the current state
         */
        protected final Vector<HiddenState> path;

        /**
         * The probability of the Viterbi path up to the current state
         * (field name keeps its original spelling)
         */
        protected final double viterbiPathProbabililty;

        /**
         * @param prob                   total probability of all paths
         * @param path                   the Viterbi path; stored by reference, not copied
         * @param viterbiPathProbability probability of {@code path}
         */
        public ViterbiPath(double prob, Vector<HiddenState> path, double viterbiPathProbability) {
            this.totalProbability = prob;
            this.path = path;
            this.viterbiPathProbabililty = viterbiPathProbability;
        }

        /**
         * @return the total probability of all paths
         */
        public double getTotalProbability() {
            return totalProbability;
        }

        /**
         * @return the Viterbi path (the stored instance, not a copy)
         */
        public Vector<HiddenState> getPath() {
            return path;
        }

        /**
         * @return the probability of the Viterbi path
         */
        public double getViterbiPathProbability() {
            return viterbiPathProbabililty;
        }

    }

}
