package ac.ooechs.oil;

import ac.ooechs.oil.pipeclassification.NewPipelineClassifier;
import ac.ooechs.oil.pipeclassification.PipelineClassifier;
import ac.ooechs.oil.edgedetection.NewEdgeSegmenter;
import ac.ooechs.oil.segmentation.PipelineSegmenter;
import ac.ooechs.oil.util.MedianFilter;
import ac.essex.ooechs.imaging.commons.segmentation.Segmenter;
import ac.essex.ooechs.imaging.commons.util.panels.FileTree;
import ac.essex.ooechs.imaging.commons.util.ImageFilenameFilter;
import ac.essex.ooechs.imaging.commons.edge.hough.local.LocalHoughTransform;
import ac.essex.ooechs.imaging.commons.Pixel;
import ac.essex.ooechs.imaging.commons.PixelLoader;
import ac.essex.ooechs.imaging.commons.fast.IntegralImage;
import ac.essex.ooechs.imaging.commons.subpixel.LineExtractor;
import ac.essex.ooechs.imaging.commons.window.util.WindowFeatures;
import ac.essex.ooechs.imaging.shapes.SegmentedShape;
import ac.essex.ooechs.imaging.shapes.Grouper;
import ac.essex.ooechs.imaging.shapes.ExtraShapeData;
import ac.essex.ooechs.imaging.shapes.ShapePixel;
import ac.essex.ooechs.imaging.gp.problems.classification.distance.DistanceClassifier;

import javax.swing.*;
import java.awt.event.ActionListener;
import java.awt.event.ActionEvent;
import java.awt.image.BufferedImage;
import java.awt.*;
import java.io.File;
import java.util.Vector;

/**
 * Second pipelines GUI. You can move your mouse around to see the hough transform
 * working on local areas of the image.
 */
public class PipelinesGUI_AutomaticLocalHough3 extends JFrame implements ActionListener {

    protected ImagePanelWindow p;
    protected JButton play, save;

    protected NewPipelineClassifier classifier = new NewPipelineClassifier();
    protected Segmenter segmenter;

    protected JLabel progress;
    protected BufferedImage newImage;

    protected FileTree f;
    protected static File directory;

    protected int[][] results;

    int length = 15;
    int errors = 0;

    public static void main(String[] args) {
        String file = "/home/ooechs/Desktop/Documents/Papers/Pipelines/data";
        directory = new File(file);
        if (!directory.exists()) {
            JFileChooser c = new JFileChooser();
            if (c.showOpenDialog(null) == JFileChooser.APPROVE_OPTION) {
                directory = c.getSelectedFile();
                if (!directory.isDirectory()) {
                    directory = directory.getParentFile();
                }
            } else {
                System.exit(0);
            }
        }
        new PipelinesGUI_AutomaticLocalHough3(directory, new NewEdgeSegmenter());
    }

    public PipelinesGUI_AutomaticLocalHough3(File file, Segmenter segmenter) {
        super("Oil Pipelines Solution version 0.3");


        try {
            UIManager.setLookAndFeel(UIManager.getSystemLookAndFeelClassName());
        } catch (Exception ex) {
            System.err.println("Unable to load native look and feel");
        }

        this.segmenter = segmenter;

        p = new ImagePanelWindow(false) {

            public void onNewWindow(int x, int y) {

                BufferedImage overlay = getImage();

                if (overlay != null && results != null) {

                    if (results[x][y] == 1) {

                        LocalHoughTransform h = getLine(results, x, y);

                        double angle = h.getAngle();

                        int dy = (int) (length * Math.sin(angle));
                        int dx = (int) (length * Math.cos(angle));

                        // reject "weak" lines
                        if (h.getStrength() < 0.10/* || h.getPoints() < 5*/) {
                            return;
                        }

                        Pixel lineStart = new Pixel(x + dx, y + dy);
                        Pixel lineEnd = new Pixel(x - dx, y - dy);

                        Graphics g = overlay.getGraphics();
                        g.setColor(Color.RED);
                        //g.drawLine(lineStart.x, lineStart.y, lineEnd.x, lineEnd.y);

                        if (Math.abs(angle + 0.78539) < 0.001) {
                            System.out.println("Naughty line: " + angle);
                            PixelLoader subimage = new PixelLoader(overlay.getSubimage(x - houghSize - 1, y - houghSize - 1, (houghSize * 2) + 1, (houghSize * 2) + 1));
                            try {
                                subimage.saveAs("/home/ooechs/Desktop/naughtyline.png");
                            } catch (Exception e) {
                                e.printStackTrace();
                            }
                        }

                    }

                }

            }

        };

        f = new FileTree(file, new ImageFilenameFilter()) {

            /**
             * Called when a file is selected in the tree
             */
            public void onSelectFile(File f) {
                loadImage(f);
            }

        };

        play = new JButton("Generate Movie");
        play.addActionListener(this);

        save = new JButton("Save");
        save.addActionListener(this);

        JToolBar toolbar = new JToolBar();
        toolbar.add(play);
        toolbar.add(save);

        JScrollPane scrollPane = new JScrollPane(f);
        scrollPane.setPreferredSize(new Dimension(200, -1));

        JPanel main = new JPanel(new BorderLayout());
        main.add(toolbar, BorderLayout.NORTH);
        main.add(new JScrollPane(p));
        progress = new JLabel();
        main.add(progress, BorderLayout.SOUTH);

        JSplitPane splitPane = new JSplitPane(JSplitPane.HORIZONTAL_SPLIT, true, scrollPane, main);
        getContentPane().add(splitPane, BorderLayout.CENTER);
        setDefaultCloseOperation(EXIT_ON_CLOSE);
        setExtendedState(MAXIMIZED_BOTH);
        setSize(900, 668);
        setLocationRelativeTo(null);
        setVisible(true);

    }

    public void actionPerformed(ActionEvent e) {
        if (e.getSource() == play) {
            try {
                File movieDir = new File("/home/ooechs/Desktop/movieframes");
                movieDir.mkdir();

                String[] children = f.getChildren();
                for (int i = 0; i < children.length; i++) {
                    String child = children[i];
                    System.out.println(child);
                    PixelLoader image = new PixelLoader(new File(directory, child));
                    p.setImage(image);
                    segment();
                    p.repaint();
                    String name;
                    if (i < 10) {
                        name = "000" + i;
                    } else {
                        if (i < 100) {
                            name = "00" + i;
                        } else {
                            if (i < 1000) {
                                name = "0" + i;
                            } else {
                                name = String.valueOf(i);
                            }
                        }
                    }
                    System.out.println(name);
                    PixelLoader img = new PixelLoader(p.getImage());
                    img.saveAs(new File(movieDir, "frame" + name + ".png"));
                }
            } catch (Exception e4) {
                e4.printStackTrace();
            }
        }
        if (e.getSource() == save) {
            try {
                File f = new File("/home/ooechs/Desktop/pipes.png");
                new PixelLoader(p.getImage()).saveAs(f);
            } catch (Exception er) {
                er.printStackTrace();
            }
        }
    }

    protected void loadImage(File f) {
        if (f != null) {
            try {
                PixelLoader image = new PixelLoader(f);
                p.setImage(image);
                new Thread() {
                    public void run() {
                        segment();
                        p.repaint();
                    }
                }.start();
            } catch (Exception err) {
                err.printStackTrace();
            }
        }
    }

    int n = 0;

    PipelineSegmenter s = new PipelineSegmenter();

    int counter = 0;

    protected void segment() {
        if (segmenter == null) return;

        BufferedImage bufferedImage = p.getImage();

        Graphics2D g = null;

        if (newImage == null || newImage.getWidth() != bufferedImage.getWidth() || newImage.getHeight() != bufferedImage.getHeight()) {
            newImage = new BufferedImage(bufferedImage.getWidth(), bufferedImage.getHeight(), BufferedImage.TYPE_INT_ARGB);
            g = (Graphics2D) newImage.getGraphics();
        }

        if (g == null) {
            g = (Graphics2D) newImage.getGraphics();
            // clear the overlay
            g.setColor(Color.BLACK);
            //g.fillRect(0, 0, bufferedImage.getWidth(), bufferedImage.getHeight());

        }

        g.drawImage(bufferedImage, 0, 0, null);

        long start = System.currentTimeMillis();

        PixelLoader pl = new PixelLoader(bufferedImage);

        results = new int[bufferedImage.getWidth()][bufferedImage.getHeight()];

        for (int y = 1; y < bufferedImage.getHeight() - 1; y++) {
            for (int x = 1; x < bufferedImage.getWidth() - 1; x++) {
                if (segmenter.segment(pl, x, y) == 2) {
                    // only find the brightest part of the pipe - or the central region
                    if (pl.getGreyValue(x, y) < 220) continue;

                    results[x][y] = 1;

                }
            }
        }

        // now find the individual pipelines by skeletonising the shapes from the segmenter
        Vector<SegmentedShape> shapes = new Grouper().getShapes(results);

        // clear the results array
        results = new int[bufferedImage.getWidth()][bufferedImage.getHeight()];

        for (int i = 0; i < shapes.size(); i++) {
            SegmentedShape shape = shapes.elementAt(i);
            if (shape.getMass() <= 50) continue;
            if (shape.originalValue == 0) continue;

            ExtraShapeData esd = new ExtraShapeData(shape);
            esd.skeletonise();

            // draw the shape on the overlay
            for (int y = 0; y < esd.boundingHeight; y++) {
                for (int x = 0; x < esd.boundingWidth; x++) {
                    ShapePixel p = esd.array[x][y];
                    if (p != null) {
                        results[p.x][p.y] = 1;
                        //newImage.setRGB(p.x, p.y, Color.WHITE.getRGB());
                    }
                }
            }
        }

         double[][] angles = new double[bufferedImage.getWidth()][bufferedImage.getHeight()];

        // go through the results array
        for (int y = length + houghSize; y < bufferedImage.getHeight() - length - houghSize; y++) {
            for (int x = length + houghSize; x < bufferedImage.getWidth() - length - houghSize; x++) {

                angles[x][y]  = -10;

                if (results[x][y] == 1) {

                    LocalHoughTransform h = getLine(results, x, y);

                    angles[x][y] = h.getAngle();

                    // reject "weak" lines
                    if (h.getStrength() < 0.10/* || h.getPoints() < 5*/) {
                        //results[x][y] = 0;
                        //newImage.setRGB(x,y,Color.RED.getRGB());
                        errors++;
                        continue;
                    }

                }
            }
        }


        MedianFilter f = new MedianFilter();

            for (int y = length + houghSize; y < bufferedImage.getHeight() - length - houghSize; y++) {
                for (int x = length + houghSize; x < bufferedImage.getWidth() - length - houghSize; x++) {

                    double angle = angles[x][y];

                    if (angle < -5) continue;

                    int neighbourSize = 2;

                    for (int dy = -neighbourSize; dy <= neighbourSize; dy++) {
                        for (int dx = -neighbourSize; dx < neighbourSize; dx++) {

                            if (angles[x + dx][y + dy] >= -5) {
                                f.addData(angles[x + dx][y + dy]);
                            }

                        }
                    }

                    //angle = f.getMedian();
                    f.clear();

                    int dy = (int) (length * Math.sin(angle));
                    int dx = (int) (length * Math.cos(angle));

                    Pixel lineStart = new Pixel(x + dx, y + dy);
                    Pixel lineEnd = new Pixel(x - dx, y - dy);

                    //g.setColor(Color.RED);
                    //g.drawLine(lineStart.x, lineStart.y,  lineEnd.x, lineEnd.y);

                    PixelLoader subImage = LineExtractor.extract(pl, length*2, lineStart, lineEnd);

                    IntegralImage image = subImage.getIntegralImage();

                    try {

                        if (classifier.eval(image) == 1) {

                            // this second level detector removes errors
                            if (classifier.eval2(image) != 1) {
                                results[x][y] = 0;
                            }

                            // this third level detector removes any other errors
                   /*         if (classifier.eval3(image) != 1) {
                                results[x][y] = 0;
                            }*/


                        } else {
                            results[x][y] = 0;
                        }

                        String folder = results[x][y] == 1? "foreground" : "background";

                        if (Math.random() < 0.1) {
                            //subImage.saveAs(new File("/home/ooechs/Desktop/images/"  + folder + "/np_" + counter + ".bmp"));
                            counter++;
                        }

                    } catch (Exception e) {
                        System.err.println(e.toString());
                    }

                }

            }
        
        BufferedImage overlay = new BufferedImage(bufferedImage.getWidth(), bufferedImage.getHeight(), BufferedImage.TYPE_INT_ARGB);
        Graphics og = overlay.getGraphics();

        Composite oldcomp = g.getComposite();
        g.setComposite(AlphaComposite.getInstance(AlphaComposite.SRC_OVER, 0.5f));


        for (int y = length + houghSize; y < bufferedImage.getHeight() - length - houghSize; y++) {
            for (int x = length + houghSize; x < bufferedImage.getWidth() - length - houghSize; x++) {

                if (results[x][y] == 1) {

                    int neighbours = 0;
                    int neighbourSize = 2;
                    for (int dy = -neighbourSize; dy <= neighbourSize; dy++) {
                        for (int dx = -neighbourSize; dx < neighbourSize; dx++) {

                            if (results[x + dx][y + dy] == 1) {
                                neighbours++;
                            }

                        }
                    }

                    if (neighbours < 3) {
                        //g.setColor(Color.BLUE);
                        //g.drawOval(x - 5, y - 5, 10, 10);
                    } else {
 
                        og.setColor(Color.GREEN);
                        og.fillOval(x - 5, y - 5, 10, 10);
                        //overlay.setRGB(x, y, Color.GREEN.getRGB());
                    }

                }

            }
        }



        int borderSize = 20;
        g.setColor(Color.WHITE);
        g.drawRect(borderSize, borderSize, pl.getWidth() - (borderSize * 2), pl.getHeight() - (borderSize * 2));
        g.setColor(Color.BLACK);
        // top
        g.fillRect(0,0,pl.getWidth(), borderSize);
        // bottom
        g.fillRect(0,pl.getHeight() - borderSize,pl.getWidth(), borderSize);
        // left
        g.fillRect(0,borderSize, borderSize, pl.getHeight() - (borderSize * 2));
        // right
        g.fillRect(pl.getWidth() - borderSize, borderSize, borderSize, pl.getHeight() - (borderSize * 2));

        g.drawImage(overlay, borderSize, borderSize, pl.getWidth() - borderSize, pl.getHeight() - borderSize, borderSize, borderSize, pl.getWidth() - borderSize, pl.getHeight() - borderSize, null);

        g.setComposite(oldcomp);

        // display the work we've done
        p.setImage(newImage);

        long end = System.currentTimeMillis();
        progress.setText("Time: " + (end - start));

        System.out.println("Errors: " + errors);
    }

    protected int houghSize = 10;

    public LocalHoughTransform getLine(int[][] results, int x, int y) {

        // do a local hough transform here
        LocalHoughTransform h = new LocalHoughTransform(36);

        for (int dy = -houghSize; dy <= houghSize; dy++) {
            for (int dx = -houghSize; dx < houghSize; dx++) {

                if (results[x + dx][y + dy] == 1) {
                    h.addPoint(dx, dy);
                }

            }
        }

        return h;

    }

    protected int bad = 0;

    public void evaluatePipeline(Graphics g, Pixel start, Pixel end, PixelLoader subImage, boolean swap) {

        PipelineClassifier classifier = new PipelineClassifier();

        DistanceClassifier c = new DistanceClassifier();

        try {

            int windowWidth = subImage.getWidth();
            int windowHeight = 5;

            int dx = end.x - start.x;
            int dy = end.y - start.y;

            int maxHeight = subImage.getHeight() - windowHeight;

            // create the windows we want to test
            for (int y = 0; y < maxHeight; y++) {

                ac.essex.ooechs.imaging.commons.window.data.Window window;
                if (swap) {
                    window = new ac.essex.ooechs.imaging.commons.window.data.Window(windowWidth, windowHeight, 0, maxHeight - y - 1, null);
                } else {
                    window = new ac.essex.ooechs.imaging.commons.window.data.Window(windowWidth, windowHeight, 0, y, null);
                }


                double[] features = WindowFeatures.getFeatures(subImage, window);

                if (classify(features) == -1) {
                    if (y % 5 == 0) {
                        //PixelLoader subsubImage = subImage.getSubImage(new ImageWindow(window));
                        //subsubImage.saveAs("/home/ooechs/Desktop/bad/bad" + bad + ".bmp");
                        bad++;
                    }
                }

                double mean = -17.83;
                double maxDist = 71.17;

                double dist = Math.abs(eval(features) - mean);

                // set a colour
                int red = (int) ((dist / maxDist) * 255);
                if (red > 255) red = 255;
                int green = 255 - red;

                g.setColor(new Color(red, green, 0));

                // draw the point, but translate it onto the original image
                double py = y / (double) subImage.getHeight();

                int nx = start.x + (int) (dx * py);
                int ny = start.y + (int) (dy * py);

                g.drawOval(nx - 2, ny - 2, 4, 4);

            }

        } catch (Exception e) {
            e.printStackTrace();
        }

    }

    public double eval(double[] feature) {
        double node8 = feature[3] * 0.45445273909203254;
        double node6 = feature[16] * node8;
        double node5 = node6 / -0.6493517511388426;
        double node3 = 0.0 + node5;
        double node2 = feature[26] != 0 ? node3 / feature[26] : 0;
        double node1 = node2 - feature[18];
        return (node1 + feature[25]) / 2;
    }

    final int correctClassID = 1;
    final double threshold = 10.0;
    final double mean = -17.834495544433594;

    public int classify(double[] feature) {
        if (Math.abs(eval(feature) - mean) < threshold) return 1;
        else return -1;
    }


    public void evaluatePipeline2(Graphics g, Pixel start, Pixel end, PixelLoader subImage) {

        PipelineClassifier classifier = new PipelineClassifier();

        try {

            int windowWidth = subImage.getWidth();
            int windowHeight = 10;

            int dx = end.x - start.x;
            int dy = end.y - start.y;

            // create the windows we want to test
            for (int y = 0; y < subImage.getHeight() - windowHeight; y++) {
                ac.essex.ooechs.imaging.commons.window.data.Window window = new ac.essex.ooechs.imaging.commons.window.data.Window(windowWidth, windowHeight, 0, y, null);

                int classID = classifier.classify(subImage, window);

                // draw the point, but translate it onto the original image
                double py = y / (double) subImage.getHeight();

                int nx = start.x + (int) (dx * py);
                int ny = start.y + (int) (dy * py);

                // draw the point on the screen
                if (classID == 0) {
                    g.setColor(Color.RED);
                } else {
                    g.setColor(Color.GREEN);
                    g.drawOval(nx - 2, ny - 2, 4, 4);
                }


            }

        } catch (Exception e) {
            e.printStackTrace();
        }

    }
}
