package ac.essex.ooechs.problems;

import ac.essex.ooechs.imaging.shapes.ExtraShapeData;
import ac.essex.ooechs.imaging.shapes.SegmentedShape;
import ac.essex.ooechs.imaging.jasmine.JasmineProject;
import ac.essex.ooechs.imaging.jasmine.JasmineImage;
import ac.essex.ooechs.imaging.jasmine.JasmineClass;
import ac.essex.ooechs.kmeans.KMeansClusterer;
import ac.essex.ooechs.kmeans.ClusterClass;
import ac.essex.ooechs.kmeans.KMeansSolution;
import ac.essex.ooechs.kmeans.DataPoint;

import java.util.Vector;
import java.io.File;
import java.text.DecimalFormat;

/**
 * Attempts to cluster the shapes found in a Jasmine Project.
 *
 * @author Olly Oechsle, University of Essex, Date: 27-Apr-2007
 * @version 1.0
 */
public class ShapeClusteringProblem {

    /** Directory holding the default project files used when no paths are passed to {@link #main(String[])}. */
    private static final String DEFAULT_PROJECT_DIR = "/home/ooechs/Desktop/JasmineProjects/";

    /** Default training project file name, resolved against {@link #DEFAULT_PROJECT_DIR}. */
    private static final String DEFAULT_TRAINING_DATA = "ANPR-224.jasmine";

    /** Default unseen (evaluation) project file name, resolved against {@link #DEFAULT_PROJECT_DIR}. */
    private static final String DEFAULT_UNSEEN_DATA = "ANPR-Unseen.jasmine";

    /** Number of times {@link #main(String[])} repeats the train/test cycle. */
    private static final int REPETITIONS = 5;

    /** Shapes with fewer pixels than this are ignored for both training and testing. */
    private static final int MIN_SHAPE_PIXELS = 50;

    /**
     * Loads a training and an unseen Jasmine project and runs the clustering experiment
     * {@value #REPETITIONS} times.
     *
     * @param args optional: {@code args[0]} is the path of the training project file and
     *             {@code args[1]} the path of the unseen project file. Any argument that is
     *             missing falls back to the built-in default location.
     * @throws Exception if either project cannot be loaded
     */
    public static void main(String[] args) throws Exception {

        String trainingPath = (args != null && args.length > 0)
                ? args[0]
                : DEFAULT_PROJECT_DIR + DEFAULT_TRAINING_DATA;
        String unseenPath = (args != null && args.length > 1)
                ? args[1]
                : DEFAULT_PROJECT_DIR + DEFAULT_UNSEEN_DATA;

        JasmineProject trainingProject = JasmineProject.load(new File(trainingPath));
        JasmineProject unseenProject = JasmineProject.load(new File(unseenPath));
        ShapeClusteringProblem p = new ShapeClusteringProblem();

        // NOTE(review): the experiment is repeated, presumably because the clusterer's
        // result can differ between runs - confirm against KMeansClusterer.
        for (int i = 0; i < REPETITIONS; i++) {
            p.run(trainingProject, unseenProject);
        }

    }

    /**
     * Trains a k-means clusterer on the shapes of the training project, then prints the
     * classification results for both the training project and the unseen project.
     * The number of clusters equals the number of distinct shape classes found among the
     * training shapes.
     *
     * @param trainingProject project whose shapes are used to build the clusters
     * @param unseenProject   project used only for evaluation
     */
    public void run(JasmineProject trainingProject, JasmineProject unseenProject) {

        Vector<SegmentedShape> trainingShapes = collectUsableShapes(trainingProject);

        // Without any data there would be zero classes and therefore zero clusters;
        // there is nothing meaningful to train or evaluate.
        if (trainingShapes.isEmpty()) {
            System.err.println("No shapes with at least " + MIN_SHAPE_PIXELS
                    + " pixels found in the training project; nothing to cluster.");
            return;
        }

        Vector<ExtraShapeData> shapes = new Vector<ExtraShapeData>(trainingShapes.size());
        Vector<Integer> distinctClassIDs = new Vector<Integer>(10);
        // Only the size of this list is used below (as k), one entry per distinct class.
        Vector<ClusterClass> classes = new Vector<ClusterClass>(10);

        for (int i = 0; i < trainingShapes.size(); i++) {
            SegmentedShape shape = trainingShapes.elementAt(i);

            shapes.add(new ExtraShapeData(shape));

            if (!distinctClassIDs.contains(shape.classID)) {
                distinctClassIDs.add(shape.classID);
                JasmineClass c = trainingProject.getShapeClass(shape.classID);
                classes.add(new ClusterClass(c.classID, c.name));
            }
        }

        // create a k-means clusterer with one cluster per distinct class
        KMeansClusterer clusterer = new KMeansClusterer(classes.size());

        // add the data to the clusterer
        for (int i = 0; i < shapes.size(); i++) {
            ExtraShapeData shape = shapes.elementAt(i);
            clusterer.add(new DataPoint(makePositionFromShape(shape), shape.getClassID()));
        }

        System.err.println("Training on " + shapes.size() + " shape(s).");

        // run the clusterer and fetch the solution
        clusterer.run();
        KMeansSolution s = clusterer.getSolution();

        // evaluate on the data it was trained on, then on the unseen data
        test(s, trainingProject);
        test(s, unseenProject);

    }

    /**
     * Classifies every usable shape of the given project with the solution and prints the
     * number of shapes whose predicted class equals their labelled class, the total number
     * of shapes tested, and the resulting percentage.
     *
     * @param s       the clustering solution to evaluate
     * @param project the project whose shapes are classified
     */
    public void test(KMeansSolution s, JasmineProject project) {

        System.out.println("Results for: " + project.getName());

        Vector<SegmentedShape> shapes = collectUsableShapes(project);

        int correct = 0;

        for (int i = 0; i < shapes.size(); i++) {
            ExtraShapeData shape = new ExtraShapeData(shapes.elementAt(i));
            DataPoint object = new DataPoint(makePositionFromShape(shape), shape.getClassID());
            int c = s.test(object);
            if (c == shape.getClassID()) {
                correct++;
            }
        }

        // Guard against 0/0, which would otherwise produce NaN and a garbage percentage.
        double percentage = shapes.isEmpty() ? 0.0 : (correct / (double) shapes.size()) * 100;

        System.out.println("TP: " + correct);
        System.out.println("OF: " + shapes.size());
        System.out.println("  = " + new DecimalFormat("0.0").format(percentage) + "%");

    }

    /**
     * Builds the feature vector that represents a shape as a point for the clusterer.
     * The features appear in a fixed order; training and testing both go through this
     * method, so the positions always line up.
     *
     * @param shape the shape to describe
     * @return a new array holding one value per feature
     */
    public double[] makePositionFromShape(ExtraShapeData shape) {

        // An array initializer keeps the length and the feature list in sync automatically.
        return new double[] {
                shape.countCorners(),
                shape.countHollows(),
                shape.getBalanceX(),
                shape.getBalanceY(),
                shape.getDensity(),
                shape.getAspectRatio(),
                shape.getJoints(),
                shape.getEnds(),
                shape.getRoundness(),
                shape.getRoughness(4),
                shape.getRoughness(8),
                shape.getEndBalanceX(),
                shape.getEndBalanceY(),
                shape.getClosestEndToCog(),
                shape.getClosestPixelToCog(),
                shape.getHorizontalSymmetry(),
                shape.getVerticalSymmetry(),
                shape.getInverseHorizontalSymmetry(),
                shape.getInverseVerticalSymmetry()
        };

    }

    /**
     * Gathers, in image order, every shape of the project that has at least
     * {@link #MIN_SHAPE_PIXELS} pixels. Shared by training and testing so that both apply
     * exactly the same filter.
     */
    private static Vector<SegmentedShape> collectUsableShapes(JasmineProject project) {

        Vector<SegmentedShape> usable = new Vector<SegmentedShape>(100);

        for (int i = 0; i < project.getImages().size(); i++) {
            JasmineImage image = project.getImages().elementAt(i);
            for (int j = 0; j < image.getShapes().size(); j++) {
                SegmentedShape shape = image.getShapes().elementAt(j);
                if (shape.pixels.size() >= MIN_SHAPE_PIXELS) {
                    usable.add(shape);
                }
            }
        }

        return usable;

    }

}
