package ac.essex.gp.problems.examples.noisereduction;

import ac.essex.gp.problems.Problem;
import ac.essex.gp.problems.DataStack;
import ac.essex.gp.problems.ImagingProblem;
import ac.essex.gp.Evolve;
import ac.essex.gp.treebuilders.TreeBuilder;
import ac.essex.gp.individuals.Individual;
import ac.essex.gp.nodes.*;
import ac.essex.gp.nodes.logic.More;
import ac.essex.gp.nodes.logic.Less;
import ac.essex.gp.nodes.generic.BasicFeature;
import ac.essex.gp.nodes.ercs.CustomRangeIntegerERC;
import ac.essex.gp.nodes.ercs.PercentageERC;
import ac.essex.gp.nodes.ercs.BoolERC;
import ac.essex.gp.nodes.math.*;
import ac.essex.gp.params.GPParams;
import ac.essex.gp.params.NodeConstraints;
import ac.essex.gp.interfaces.console.ConsoleListener;
import ac.essex.gp.interfaces.GPActionListener;
import ac.essex.gp.interfaces.graphical.GraphicalListener;
import ac.essex.ooechs.imaging.commons.PixelLoader;
import ac.essex.ooechs.imaging.commons.util.ImageFilenameFilter;
import ac.essex.ooechs.imaging.commons.util.panels.ImageFrame;

import java.awt.image.BufferedImage;
import java.awt.*;
import java.io.File;
import java.util.Vector;

/**
 * Genetic-programming problem that tries to evolve a per-pixel filter which turns a
 * noisy greyscale image back into its clean counterpart.
 *
 * <p>Training data is a list of (clean, noisy) image pairs. For every interior pixel of
 * a noisy image a feature vector is obtained from
 * {@code EvolvedNoiseReducer.getFeatures}, the individual is executed on it, and the
 * clamped result is compared with the grey value of the same pixel in the clean image.
 */
public class NoiseReductionProblem extends ImagingProblem {

    /** Directory scanned by {@link #main(String[])} when none is given on the command line. */
    private static final String DEFAULT_IMAGE_DIR = "M:\\pc\\desktop\\noise\\noise";

    /**
     * Width in pixels of the image border that is never evaluated. Presumably needed
     * because the features look at neighbouring pixels - confirm against
     * {@code EvolvedNoiseReducer.getFeatures}.
     */
    private static final int BORDER = 1;

    /** A pixel counts as a hit when its absolute error is strictly below this value. */
    private static final int HIT_TOLERANCE = 2;

    /** Largest grey value; evolved outputs are clamped into {@code [0, MAX_GREY]}. */
    private static final int MAX_GREY = 255;

    /** The (clean, noisy) image pairs used for training, in insertion order. */
    Vector<TrainingPair> pairs;

    /** Reusable buffers for {@link #describe}: the filtered image and a copy of the noisy input. */
    private BufferedImage out, original;

    /** Lookup table from grey level (0-255) to the packed RGB value of that grey. */
    private int[] rgb;

    /** Window showing the noisy input; created the first time {@link #describe} draws anything. */
    ImageFrame originalFrame;

    /**
     * Runs an evolution on the first image in the image directory whose name starts with
     * "key", paired with the identically named file in the "noise5" sub-directory.
     *
     * @param args optional; {@code args[0]} overrides the default image directory
     * @throws Exception if loading the images or running the evolution fails
     */
    public static void main(String[] args) throws Exception {

        NoiseReductionProblem p = new NoiseReductionProblem();

        File dir = new File(args != null && args.length > 0 ? args[0] : DEFAULT_IMAGE_DIR);
        File noisyDir = new File(dir, "noise5");

        Vector<File> imageFiles = ImageFilenameFilter.getImages(dir);
        // NOTE(review): whether getImages can return null (e.g. missing directory) is not
        // visible here; guard anyway so the failure mode is the warning below, not an NPE.
        if (imageFiles != null) {
            for (File file : imageFiles) {
                if (file.getName().startsWith("key")) {
                    PixelLoader clean = new PixelLoader(file);
                    PixelLoader noisy = new PixelLoader(new File(noisyDir, file.getName()));
                    p.addPair(clean, noisy);
                    // Only the first matching image is used for training.
                    break;
                }
            }
        }

        if (p.getTrainingCount() == 0) {
            System.err.println("Warning: no training image starting with \"key\" found in " + dir);
        }

        Evolve.seed = 2360;

        long start = System.currentTimeMillis();
        Evolve e = new Evolve(p, new GraphicalListener());
        e.run();
        long end = System.currentTimeMillis();
        System.out.println("Time: " + (end - start));
        System.out.println(e.getBestIndividual().getTree(0).toLisp());
    }

    /** Creates a problem with no training pairs; add them with {@link #addPair}. */
    public NoiseReductionProblem() {
        pairs = new Vector<TrainingPair>();
    }

    /**
     * Adds one training example.
     *
     * @param clean the target image
     * @param noisy the degraded version of {@code clean} that the individual sees
     */
    public void addPair(PixelLoader clean, PixelLoader noisy) {
        pairs.add(new TrainingPair(clean, noisy));
    }

    /** @return the display name of this problem */
    public String getName() {
        return "Noise Reduction Problem";
    }

    /**
     * @param index position of the training pair
     * @return the filename reported by the clean image of that pair
     */
    public String getTrainingImageName(int index) {
        return pairs.elementAt(index).cleanImage.getFilename();
    }

    /** @return the number of training pairs added so far */
    public int getTrainingCount() {
        return pairs.size();
    }

    /**
     * Registers the function and terminal set and sets the tree return type to NUMBER.
     * The registration order is kept exactly as it always was, in case node selection
     * (and therefore seeded reproducibility) depends on it.
     */
    public void initialise(Evolve e, GPParams params) {

        // conditional and comparison nodes
        params.registerNode(new If_FP());
        params.registerNode(new More());
        params.registerNode(new Less());

        // arithmetic nodes
        params.registerNode(new Add());
        params.registerNode(new Mul());
        params.registerNode(new Sub());
        params.registerNode(new Div());
        params.registerNode(new Squared());
        params.registerNode(new Sqrt());
        params.registerNode(new Hypot());
        params.registerNode(new Abs());
        params.registerNode(new Cos());

        // ephemeral random constants
        params.registerNode(new CustomRangeIntegerERC(0, 255));
        params.registerNode(new CustomRangeIntegerERC(0, 8));
        params.registerNode(new PercentageERC());
        params.registerNode(new BoolERC());
        params.registerNode(new Mean());

        // one terminal per entry of the feature vector
        for (int i = 0; i < EvolvedNoiseReducer.NUM_FEATURES; i++) {
            params.registerNode(new BasicFeature(i));
        }

        params.setReturnType(NodeConstraints.NUMBER);

    }

    /** Sets population size, generation count, operator probabilities and the tree builder. */
    public void customiseParameters(GPParams params) {
        params.setPopulationSize(500);
        params.setGenerations(100);
        params.setMutationProbability(0.3);
        params.setCrossoverProbability(0.7);
        params.setTreeBuilder(GPParams.FULL);
    }

    /**
     * Scores an individual on every training pair.
     *
     * <p>The Koza fitness is the sum, over all evaluated pixels, of the absolute difference
     * between the clamped output and the clean grey value (presumably lower is better -
     * confirm against {@code Individual.setKozaFitness}). Hits are pixels whose error is
     * below {@link #HIT_TOLERANCE}; all other pixels are mistakes.
     *
     * @param ind  the individual to score; receives the fitness, hits and mistakes
     * @param data stack whose {@code features} field is overwritten for every pixel
     * @param e    the running evolution (unused)
     */
    public void evaluate(Individual ind, DataStack data, Evolve e) {

        double totalAbsoluteError = 0;
        int hits = 0;
        int mistakes = 0;

        for (TrainingPair trainingPair : pairs) {

            PixelLoader noisyImage = trainingPair.noisyImage;
            PixelLoader cleanImage = trainingPair.cleanImage;

            // Stay inside BOTH images so a size mismatch cannot read outside the clean one.
            int endY = Math.min(noisyImage.getHeight(), cleanImage.getHeight()) - BORDER;
            int endX = Math.min(noisyImage.getWidth(), cleanImage.getWidth()) - BORDER;

            for (int ny = BORDER; ny < endY; ny++) {
                for (int nx = BORDER; nx < endX; nx++) {

                    data.features = EvolvedNoiseReducer.getFeatures(noisyImage, nx, ny);

                    int expected = cleanImage.getGreyValue(nx, ny);
                    int result = clampToGrey((int) ind.execute(data));

                    int diff = Math.abs(result - expected);
                    totalAbsoluteError += diff;

                    if (diff < HIT_TOLERANCE) {
                        hits++;
                    } else {
                        mistakes++;
                    }

                }
            }

        } // for every pair

        ind.setKozaFitness(totalAbsoluteError);
        ind.setHits(hits);
        ind.setMistakes(mistakes);

    }

    /**
     * Renders the output of an individual on one training image.
     *
     * <p>Only does work for a {@link GraphicalListener}. On the first such call it also
     * prints the summed absolute error between the noisy and clean image and opens a
     * window showing the noisy input.
     *
     * @param listener the listener requesting the description
     * @param ind      the individual to run
     * @param data     stack whose {@code features} field is overwritten for every pixel
     * @param index    position of the training pair to render
     * @return the filtered image (a buffer that is reused between calls), or {@code null}
     *         when the listener is not graphical
     */
    public BufferedImage describe(GPActionListener listener, Individual ind, DataStack data, int index) {

        PixelLoader noisyImage = pairs.elementAt(index).noisyImage;
        PixelLoader cleanImage = pairs.elementAt(index).cleanImage;

        if (!(listener instanceof GraphicalListener)) {
            return null;
        }

        int width = noisyImage.getWidth();
        int height = noisyImage.getHeight();

        if (rgb == null) {
            rgb = new int[MAX_GREY + 1];
            for (int i = 0; i < rgb.length; i++) {
                rgb[i] = new Color(i, i, i).getRGB();
            }
        }

        // Reallocate when the image size changes; reusing buffers sized for another
        // image would either overflow them or return an image of the wrong size.
        if (out == null || out.getWidth() != width || out.getHeight() != height) {
            out = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
            original = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        }

        // The baseline error is only reported once, so only accumulate it once.
        boolean firstCall = originalFrame == null;
        double noisyImageError = 0;
        int cleanWidth = cleanImage.getWidth();
        int cleanHeight = cleanImage.getHeight();

        for (int ny = BORDER; ny < height - BORDER; ny++) {
            for (int nx = BORDER; nx < width - BORDER; nx++) {

                data.features = EvolvedNoiseReducer.getFeatures(noisyImage, nx, ny);

                int result = clampToGrey((int) ind.execute(data));
                int noisyGrey = noisyImage.getGreyValue(nx, ny);

                // Skip pixels the clean image does not have (size mismatch).
                if (firstCall && nx < cleanWidth && ny < cleanHeight) {
                    noisyImageError += Math.abs(noisyGrey - cleanImage.getGreyValue(nx, ny));
                }

                out.setRGB(nx, ny, rgb[result]);
                original.setRGB(nx, ny, rgb[noisyGrey]);

            }
        }

        if (firstCall) {
            System.out.println("Noisy Image Error: " + noisyImageError);
            originalFrame = new ImageFrame(original, "Original");
        }

        return out;

    }

    /** Limits a raw program output to the valid grey range {@code [0, MAX_GREY]}. */
    private static int clampToGrey(int value) {
        if (value > MAX_GREY) return MAX_GREY;
        if (value < 0) return 0;
        return value;
    }

    /** A clean target image together with its noisy counterpart. */
    static class TrainingPair {

        public final PixelLoader noisyImage, cleanImage;

        TrainingPair(PixelLoader cleanImage, PixelLoader noisyImage) {
            this.noisyImage = noisyImage;
            this.cleanImage = cleanImage;
        }
    }


}

