package ac.essex.ooechs.facedetection.util;

import javax.swing.*;
import java.awt.event.*;
import java.awt.*;
import java.io.File;
import java.io.FilenameFilter;
import java.util.Vector;
import java.util.Hashtable;

import ac.essex.ooechs.imaging.commons.PixelLoader;
import ac.essex.ooechs.imaging.commons.HaarRegions;
import ac.essex.ooechs.imaging.commons.util.Region;
import ac.essex.ooechs.imaging.commons.cmu.GroundTruthReader;
import ac.essex.ooechs.imaging.commons.cmu.FaceDefinition;
import ac.essex.ooechs.imaging.commons.util.panels.ImagePanel;
import ac.essex.ooechs.imaging.commons.Pixel;
import ac.essex.ooechs.ecj.commons.util.ObjectClass;
import ac.essex.ooechs.ecj.commons.fitness.SimpleFitnessCalculator;
import ac.essex.ooechs.ecj.haar.solutions.FaceDetector;
import ac.essex.ooechs.ecj.haar.solutions.CombinedFaceDetector2;
import ac.essex.ooechs.ecj.haar.problems.FaceDetectorProblem;
import ac.essex.ooechs.facedetection.util.Combiner;
import ac.essex.ooechs.facedetection.solutions.EvolvedFaceDetector;

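/**
 * Interactive Swing viewer for testing evolved face detectors. Moving the mouse over an
 * image evaluates the detector on the window at that position, clicking toggles an overlay
 * of every detection in the image, and the Tools menu batch-evaluates detectors against
 * the ground-truth annotations.
 * <p>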
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version,
 * provided that any use properly credits the author.
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details at http://www.gnu.org
 * </p>
 *
 * @author Olly Oechsle, University of Essex, Date: 19-Sep-2006
 * @version 1.1
 */
public class WindowSearchGUI extends JFrame implements ActionListener {

    public static final int DEFAULT_WINDOW_WIDTH = FaceDetectorProblem.WINDOW_WIDTH;
    public static final int DEFAULT_WINDOW_HEIGHT = FaceDetectorProblem.WINDOW_HEIGHT;

    /** Ground-truth annotation file read by the constructor. */
    private static final String TRUTH_FILE = "/home/ooechs/ecj-training/faces/mit+cmu/truth.txt";

    /** Directory scanned by {@link #runAll()} for compiled detector classes. */
    private static final String DETECTOR_DIRECTORY = "/home/ooechs/ecj-common/classes/ac/essex/ooechs/ecj/haar/solutions/";

    /** Package that the classes found in the detector directory belong to. */
    private static final String DETECTOR_PACKAGE = "ac.essex.ooechs.ecj.haar.solutions.";

    /** File-name suffix of compiled classes. */
    private static final String CLASS_SUFFIX = ".class";

    /** Number of window sizes tried on each image during a batch run. */
    private static final int SCALES_PER_IMAGE = 3;

    // current size of the search window, in pixels
    int windowWidth = DEFAULT_WINDOW_WIDTH;
    int windowHeight = DEFAULT_WINDOW_HEIGHT;

    // block counts passed to HaarRegions.setWindowPosition; also used as the step (in pixels)
    // by which the window is grown or shrunk
    int SEGMENTSX = FaceDetectorProblem.WINDOWBLOCKSX;
    int SEGMENTSY = FaceDetectorProblem.WINDOWBLOCKSY;

    // status labels: detector decision and mouse position
    JLabel result;
    JLabel position;

    File imageDirectory;

    JButton next, back;

    JButton plus, minus;

    // log area at the bottom of the window
    JTextArea messages;

    JMenuItem run, runAll;

    // null when the ground-truth file could not be found
    GroundTruthReader truth;

    /**
     * Opens the viewer on the MIT+CMU test set using an {@link EvolvedFaceDetector}.
     *
     * @param args ignored
     */
    public static void main(String[] args) throws Exception {

        // Detectors evaluated previously (generated class names) and observations about them:
        //   SAVED_WEAK_CLASSIFIERS_32x40FaceDetector005   - gets rid of false positives in some areas
        //   SAVED_WEAK_CLASSIFIERS_32x40FaceDetector016   - many true positives, only a few false positives
        //   min_max_mean3_FaceDetector426                 - low FP rate, only about 75% true positives
        //   MEDIUM_WEAK_CLASSIFIERS_32x40FaceDetector1666 - 80% TP/FP
        //   BIG_WEAK_CLASSIFIERS_32x40FaceDetector087, HUGE_ALL_CLASSIFIERS_32x40FaceDetector373,
        //   HUGE_80_WEAK_CLASSIFIERS_32x40FaceDetector1935
        final FaceDetector solution = new EvolvedFaceDetector();

        // alternative test set: /home/ooechs/ecj-training/faces/essex/mit/test/cropped
        final File testImages = new File("/home/ooechs/ecj-training/faces/mit+cmu/test");

        // Swing components should be created on the event dispatch thread.
        SwingUtilities.invokeLater(new Runnable() {
            public void run() {
                new WindowSearchGUI(testImages, solution);
            }
        });

    }

    PixelLoader image;

    ImagePanel ip;
    HaarRegions haar;
    FaceDetector solution;

    // list of all the image files in the directory (non-null once loadImages() has run)
    File[] images;

    // index of our position through the array of files
    int index = 0;

    /**
     * Builds and shows the viewer window.
     *
     * @param imageDirectory directory whose .bmp/.gif/.jpg files are browsed
     * @param solution       detector evaluated under the mouse and drawn by the overlay
     */
    public WindowSearchGUI(File imageDirectory, FaceDetector solution) {

        super("Evolved Face Detection using Genetic Programming");

        // load ground truth; without it every image is treated as having no annotated faces
        File f = new File(TRUTH_FILE);
        if (!f.exists()) {
            System.err.println("File does not exist: " + f.getAbsolutePath());
        } else {
            truth = new GroundTruthReader(f);
        }

        this.imageDirectory = imageDirectory;

        loadImages();

        this.solution = solution;

        Container c = getContentPane();

        // MENU
        run = new JMenuItem("Run");
        run.addActionListener(this);

        runAll = new JMenuItem("Run All");
        runAll.addActionListener(this);

        JMenu tools = new JMenu("Tools");
        tools.add(run);
        tools.add(runAll);

        JMenuBar bar = new JMenuBar();
        bar.add(tools);
        setJMenuBar(bar);

        // UPPER PART: status labels on the left, navigation/size buttons on the right
        JPanel buttonPanel = new JPanel(new GridLayout(1, 2));

        position = new JLabel("");
        result = new JLabel("");

        JPanel status = new JPanel(new FlowLayout(FlowLayout.LEFT));
        status.add(position);
        status.add(result);

        JPanel buttons = new JPanel(new FlowLayout(FlowLayout.LEFT));

        next = new JButton("Next");
        back = new JButton("Prev");
        next.addActionListener(this);
        back.addActionListener(this);

        plus = new JButton("+");
        minus = new JButton("-");
        plus.addActionListener(this);
        minus.addActionListener(this);

        buttons.add(back);
        buttons.add(next);
        buttons.add(minus);
        buttons.add(plus);

        buttonPanel.add(status);
        buttonPanel.add(buttons);

        c.add(buttonPanel, BorderLayout.NORTH);

        // CENTER PART
        ip = new WindowPanel();

        // load the first image in the directory
        load(true);

        c.add(ip, BorderLayout.CENTER);

        // BOTTOM PART
        messages = new JTextArea();

        JScrollPane scroller = new JScrollPane(messages);
        scroller.setPreferredSize(new Dimension(100, 100));

        c.add(scroller, BorderLayout.SOUTH);

        // EXIT JAVA WHEN WINDOW CLOSES
        addWindowListener(new WindowAdapter() {
            @Override
            public void windowClosing(WindowEvent e) {
                System.exit(0);
            }
        });

        // SHOW WINDOW
        setSize(800, 600);
        setVisible(true);

    }

    /**
     * Dispatches button and menu events to the matching action.
     */
    public void actionPerformed(ActionEvent e) {

        Object source = e.getSource();

        if (source == next) {
            next(true);
        } else if (source == back) {
            back(true);
        } else if (source == plus) {
            increaseWindowSize();
        } else if (source == minus) {
            decreaseWindowSize();
        } else if (source == run) {
            run();
        } else if (source == runAll) {
            runAll();
        }

    }

    /**
     * Pairs every compiled "Preliminary_" detector with every "Secondary_" detector found in the
     * detector directory and builds a {@link CombinedFaceDetector2} for each pair.
     *
     * @throws IllegalStateException if the detector directory cannot be listed
     */
    public void runAll() {

        File detectorDirectory = new File(DETECTOR_DIRECTORY);

        // listFiles() returns null when the path is missing or is not a directory
        File[] detectors = detectorDirectory.listFiles();
        if (detectors == null) {
            throw new IllegalStateException("DetectorDirectory doesn't exist: " + detectorDirectory.getAbsolutePath());
        }

        Vector<String> preliminaryDetectors = new Vector<String>(30);
        Vector<String> secondaryDetectors = new Vector<String>(30);

        System.out.println("Loading detectors...");

        for (int i = 0; i < detectors.length; i++) {
            String fileName = detectors[i].getName();

            // only compiled classes can be instantiated; skip anything else in the directory
            if (!fileName.endsWith(CLASS_SUFFIX)) {
                continue;
            }

            String className = fileName.substring(0, fileName.length() - CLASS_SUFFIX.length());

            if (className.startsWith("Preliminary_")) {
                preliminaryDetectors.add(className);
            } else if (className.startsWith("Secondary_")) {
                secondaryDetectors.add(className);
            }
        }

        System.out.println("Loaded detectors.\n\n");

        Vector<CombinedFaceDetector2> combinedDetectors = new Vector<CombinedFaceDetector2>(10);

        // create the combinations; each pair gets its own fresh detector instances
        for (int i = 0; i < preliminaryDetectors.size(); i++) {
            String name1 = preliminaryDetectors.elementAt(i);
            for (int j = 0; j < secondaryDetectors.size(); j++) {
                String name2 = secondaryDetectors.elementAt(j);
                try {
                    FaceDetector detector1 = (FaceDetector) Class.forName(DETECTOR_PACKAGE + name1).newInstance();
                    FaceDetector detector2 = (FaceDetector) Class.forName(DETECTOR_PACKAGE + name2).newInstance();
                    combinedDetectors.add(new CombinedFaceDetector2(detector1, detector2));
                } catch (Exception e) {
                    System.err.println("Could not combine " + name1 + " with " + name2 + ": " + e);
                }
            }
        }

        System.out.println("Created " + combinedDetectors.size() + " combined detectors.");

        // NOTE(review): the combinations are built but never evaluated. run(Vector<FaceDetector>)
        // only accepts a Vector<FaceDetector>; if CombinedFaceDetector2 is a FaceDetector, copy the
        // combinations into such a vector and pass it to run(...). TODO confirm the type hierarchy.
    }

    /**
     * Evaluates the detector combinations produced by {@link Combiner#createCombinations()}
     * over every image in the directory.
     */
    public void run() {
        Vector<FaceDetector> detectors = new Combiner().createCombinations();
        run(detectors);
    }

    /**
     * Evaluates each detector over every image in the directory at {@value #SCALES_PER_IMAGE}
     * window sizes and prints running TP/FP totals plus a fitness per detector. The viewer's
     * current image and window size are restored afterwards.
     *
     * @param combinedDetectors detectors to evaluate
     */
    public void run(Vector<FaceDetector> combinedDetectors) {

        // remember what the user is looking at, so that the displayed image and haar agree afterwards
        int savedIndex = index;
        int savedWidth = windowWidth;
        int savedHeight = windowHeight;

        float bestFitness = 0;

        for (int i = 0; i < combinedDetectors.size(); i++) {
            FaceDetector combinedDetector = combinedDetectors.elementAt(i);

            System.out.println(combinedDetector + ":");

            // totals accumulated over the whole image set
            int TP = 0;
            int FP = 0;
            int totalTP = 0;

            // visit every image, including the first one (index 0)
            for (index = 0; index < images.length; index++) {

                if (!load(false)) {
                    continue;
                }

                // wipe out hits left by previous detectors, and count the annotated faces
                for (int j = 0; j < faceRegions.size(); j++) {
                    faceRegions.elementAt(j).setHits(0);
                    totalTP++;
                }

                windowWidth = DEFAULT_WINDOW_WIDTH;
                windowHeight = DEFAULT_WINDOW_HEIGHT;

                for (int scale = 0; scale < SCALES_PER_IMAGE; scale++) {

                    haar.setWindowPosition(0, 0, windowWidth, windowHeight, SEGMENTSX, SEGMENTSY);

                    Vector<Pixel> found = combinedDetector.getObjects(haar);
                    combinedDetector.calculateTPandFP(found, windowWidth, windowHeight, faceRegions);

                    FP += combinedDetector.getFP();

                    // try the next, larger window size (no UI refresh needed during a batch run)
                    windowWidth += SEGMENTSX;
                    windowHeight += SEGMENTSY;

                    System.out.print(".");
                }

                // a face counts as found if any scale hit its region
                for (int j = 0; j < faceRegions.size(); j++) {
                    if (faceRegions.elementAt(j).getHits() > 0) {
                        TP++;
                    }
                }

                System.out.println(filename + ": TP: " + TP + ", FP: " + FP);
            }

            float fitness = new SimpleFitnessCalculator(1, 3).getFitness(TP, FP, totalTP);

            System.out.println();

            if (fitness > bestFitness) {
                bestFitness = fitness;
                System.out.println("*** BEST DETECTOR: " + combinedDetector);
                System.out.println("Fitness: " + fitness);
            } else {
                System.out.println("No better individual found, fitness: " + fitness);
            }
            System.out.println("TP: " + TP + " / " + totalTP);
            System.out.println("FP: " + FP);

        }

        // restore the viewer state
        windowWidth = savedWidth;
        windowHeight = savedHeight;
        index = savedIndex;
        load(true);

    }

    /**
     * Lists the .bmp, .gif and .jpg files in the image directory and rewinds to the first one.
     * An unreadable directory yields an empty list.
     */
    public void loadImages() {

        File[] found = imageDirectory.listFiles(new FilenameFilter() {
            public boolean accept(File dir, String name) {
                return name.endsWith(".bmp") || name.endsWith(".gif") || name.endsWith(".jpg");
            }
        });

        // listFiles() returns null when the directory is missing or unreadable
        if (found == null) {
            System.err.println("Cannot list image directory: " + imageDirectory.getAbsolutePath());
            found = new File[0];
        }

        images = found;
        index = 0;
    }

    /**
     * Moves to the next image.
     *
     * @param show whether to display the image in the panel
     * @return false if already at the last image
     */
    public boolean next(boolean show) {

        if (index < images.length - 1) {
            index++;
            load(show);
            return true;
        }
        return false;

    }

    /**
     * Moves to the previous image.
     *
     * @param show whether to display the image in the panel
     * @return false if already at the first image
     */
    public boolean back(boolean show) {

        if (index > 0) {
            index--;
            load(show);
            return true;
        }
        return false;

    }

    /**
     * Grows the search window by one step and refreshes the overlay if it is shown.
     */
    public void increaseWindowSize() {
        windowWidth += SEGMENTSX;
        windowHeight += SEGMENTSY;
        refreshMask();
    }

    /**
     * Shrinks the search window by one step (never to zero or negative size) and refreshes the
     * overlay if it is shown.
     */
    public void decreaseWindowSize() {
        if (windowWidth - SEGMENTSX <= 0 || windowHeight - SEGMENTSY <= 0) {
            return;
        }
        windowWidth -= SEGMENTSX;
        windowHeight -= SEGMENTSY;
        refreshMask();
    }

    /** Recomputes and displays the detection overlay at the current window size, if it is showing. */
    private void refreshMask() {
        if (!showingMask) {
            return;
        }
        if (haar != null) {
            // apply the new window size before scanning, as run(...) does
            try {
                haar.setWindowPosition(0, 0, windowWidth, windowHeight, SEGMENTSX, SEGMENTSY);
            } catch (RuntimeException e) {
                result.setText("Out of bounds");
                return;
            }
        }
        PixelLoader mask = createMask();
        if (mask != null) {
            ip.setImage(mask);
        }
        ip.repaint();
    }

    private Vector<Region> faceRegions;

    private String filename;

    // HaarRegions per image file, so that revisiting an image does not rebuild it
    private Hashtable<File, HaarRegions> cache;

    /**
     * Loads the image at {@code index} along with its ground-truth regions and HaarRegions.
     *
     * @param show whether to display the image in the panel
     * @return true if the image was loaded successfully
     */
    private boolean load(boolean show) {

        if (cache == null) {
            cache = new Hashtable<File, HaarRegions>(10);
        }

        if (images == null || index < 0 || index >= images.length) {
            return false;
        }

        try {

            image = new PixelLoader(images[index]);

            // remember the filename (also the key used by the ground truth reader)
            filename = image.getFile().getName();

            faceRegions = new Vector<Region>(10);

            // ground truth is optional: without it the image has no annotated faces
            if (truth != null) {
                Vector<FaceDefinition> faces = truth.getFaces(filename);
                if (faces != null) {
                    for (int i = 0; i < faces.size(); i++) {
                        faceRegions.add(faces.elementAt(i).getRegion());
                    }
                }
            }

            // a freshly loaded image is shown without the overlay
            showingMask = false;

            haar = cache.get(images[index]);

            if (haar == null) {
                haar = new HaarRegions(image);
                cache.put(images[index], haar);
            }

            if (show) {
                ip.setImage(image);
            }

            return true;

        } catch (Exception e) {
            JOptionPane.showMessageDialog(this, "Could not load image: " + e.getMessage());
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Runs the detector on the window whose top-left corner is (x, y) and reports the decision in
     * the status label.
     *
     * @throws RuntimeException propagated from HaarRegions when the window does not fit the image
     */
    public void eval(int x, int y) {

        if (haar == null) {
            result.setText("");
            return;
        }

        haar.setWindowPosition(x, y, windowWidth, windowHeight, SEGMENTSX, SEGMENTSY);

        int decision = (int) solution.calculate(haar);

        result.setText(decision != ObjectClass.NO_CLASS ? " FOUND OBJECT" : "");

    }

    // false initially: the panel starts out showing the plain image
    boolean showingMask = false;

    // detections from the most recent createMask() call
    Vector<Pixel> objects;

    /**
     * Draws every detection in the current image onto a fresh copy of it and logs TP/FP.
     *
     * @return the annotated copy, or null if no image is loaded or drawing failed
     */
    public PixelLoader createMask() {

        if (image == null || haar == null) {
            return null;
        }

        try {

            // create a copy of the image to draw on
            PixelLoader mask = new PixelLoader(image.getFile());

            objects = solution.getObjects(haar);

            System.out.println("There are: " + objects.size() + " objects in the scene.");

            int width = haar.getWindowWidth();
            int height = haar.getWindowHeight();

            Graphics g = mask.getBufferedImage().getGraphics();
            try {
                solution.drawObjects(objects, width, height, g, faceRegions);
            } finally {
                g.dispose();
            }

            solution.calculateTPandFP(objects, width, height, faceRegions);

            messages.append("TP: " + solution.getTP() + ", FP:" + solution.getFP() + "\n");

            return mask;

        } catch (Exception e) {
            JOptionPane.showMessageDialog(this, e.getMessage());
            e.printStackTrace();
        }

        return null;

    }

    /**
     * Switches the panel between the plain image and the detection overlay.
     */
    public void toggle() {
        if (showingMask) {
            ip.setImage(image);
            showingMask = false;
        } else {
            showingMask = true;
            // build the overlay on the EDT; createMask() logs the TP/FP figures itself
            SwingUtilities.invokeLater(new Runnable() {
                public void run() {
                    PixelLoader mask = createMask();
                    // ignore the result if the user toggled back in the meantime
                    if (mask != null && showingMask) {
                        ip.setImage(mask);
                    }
                }
            });
        }
    }

    /**
     * Image panel that tracks the mouse, evaluates the detector at the pointer and outlines the
     * current search window.
     */
    class WindowPanel extends ImagePanel {

        // top-left corner of the search window (last mouse position)
        int x, y;

        public WindowPanel() {

            addMouseMotionListener(new MouseMotionAdapter() {
                @Override
                public void mouseMoved(MouseEvent e) {
                    x = e.getX();
                    y = e.getY();
                    position.setText("x=" + x + ", y=" + y);
                    repaint();
                    try {
                        eval(x, y);
                    } catch (RuntimeException e1) {
                        result.setText("Out of bounds");
                    }
                }
            });
            addMouseListener(new MouseAdapter() {
                @Override
                public void mousePressed(MouseEvent e) {
                    toggle();
                }
            });

        }

        @Override
        public void paintComponent(Graphics g) {
            super.paintComponent(g);
            g.drawRect(x, y, windowWidth, windowHeight);
        }

    }

}
