package ac.essex.gp.neural;

import java.util.ArrayList;
import java.util.Vector;

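/**
 * A simple feed-forward neural network with a single hidden layer, trained by
 * propagating output errors backwards through the layers.
 * <p/>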
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version,
 * provided that any use properly credits the author.
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details at http://www.gnu.org
 * </p>
 *
 * @author Olly Oechsle, University of Essex, Date: 06-Mar-2007
 * @version 1.0
 */
public class NeuralNet {

    /** Number of passes over the training set used by {@link #trainNetwork(Vector)}. */
    private static final int DEFAULT_EPOCHS = 1000;

    /**
     * The trainable layers in feed-forward order: the hidden layer followed by the
     * output layer. The input layer is deliberately NOT stored here, so index 0 is
     * the first hidden layer.
     */
    public ArrayList<NeuronLayer> layers;

    NeuronLayer inputLayer;
    NeuronLayer outputLayer;

    // NOTE(review): not read anywhere in this class; presumably consumed elsewhere — confirm.
    double learningRate = 1.00;

    /**
     * Demo: trains a 2-2-1 network to behave like a logical AND gate, printing the
     * network output for every sample and the summed absolute error for every epoch.
     */
    public static void main(String[] args) {

        // Two inputs (each 0 or 1) and one output (0 or 1).
        // Target: an AND gate, i.e. the full truth table
        // input | output
        // 0 0   | 0
        // 0 1   | 0
        // 1 0   | 0
        // 1 1   | 1
        Vector<TrainingData> trainingData = new Vector<TrainingData>(4);
        trainingData.add(new TrainingData(new double[]{0, 0}, new double[]{0}));
        trainingData.add(new TrainingData(new double[]{0, 1}, new double[]{0}));
        trainingData.add(new TrainingData(new double[]{1, 0}, new double[]{0}));
        trainingData.add(new TrainingData(new double[]{1, 1}, new double[]{1}));

        NeuralNet nn = new NeuralNet(2, 2, 1);
        nn.trainNetwork(trainingData);

    }

    /**
     * Builds a network with one input layer, one hidden layer and one output layer,
     * links the layers to each other and asks the input layer to connect its neurons.
     *
     * @param numInputs  number of neurons in the input layer
     * @param numHidden  number of neurons in the hidden layer
     * @param numOutputs number of neurons in the output layer
     */
    public NeuralNet(int numInputs, int numHidden, int numOutputs) {

        layers = new ArrayList<NeuronLayer>(2);

        // The input layer only carries the raw input values; it is kept out of
        // "layers", which is iterated for forward passes and parameter updates.
        inputLayer = new NeuronLayer(numInputs);

        NeuronLayer hidden = new NeuronLayer(numHidden);
        link(inputLayer, hidden);
        layers.add(hidden);

        outputLayer = new NeuronLayer(numOutputs);
        link(hidden, outputLayer);
        layers.add(outputLayer);

        // NOTE(review): presumably walks the "right" links to wire up every pair of
        // adjacent layers — confirm in NeuronLayer.connectNeurons().
        inputLayer.connectNeurons();

        System.out.println("Created NET");

    }

    /** Sets the left/right links so that {@code right} directly follows {@code left}. */
    private static void link(NeuronLayer left, NeuronLayer right) {
        left.right = right;
        right.left = left;
    }

    /**
     * Trains the network for {@value #DEFAULT_EPOCHS} epochs over the given samples.
     *
     * @param trainingData the samples presented, in order, once per epoch
     */
    public void trainNetwork(Vector<TrainingData> trainingData) {
        trainNetwork(trainingData, DEFAULT_EPOCHS);
    }

    /**
     * Trains the network for a given number of epochs. After each sample the value of
     * the first output neuron is printed; after each epoch the summed error is printed.
     *
     * @param trainingData the samples presented, in order, once per epoch
     * @param epochs       number of passes over {@code trainingData}
     * @throws IllegalArgumentException if {@code epochs} is negative
     */
    public void trainNetwork(Vector<TrainingData> trainingData, int epochs) {
        if (epochs < 0) {
            throw new IllegalArgumentException("epochs must be >= 0, was " + epochs);
        }
        for (int epoch = 0; epoch < epochs; epoch++) {
            double error = 0;
            for (int j = 0; j < trainingData.size(); j++) {
                error += trainNetwork(trainingData.elementAt(j));
                System.out.println("#" + j + " " + outputLayer.neurons[0].output);
            }
            System.out.println(epoch + ": " + error);
        }
    }

    /**
     * Performs one training step on a single sample: forward pass, error
     * back-propagation, then a parameter update on every trainable neuron.
     *
     * @param t the sample; {@code t.inputs} must have at least as many entries as
     *          there are input neurons, and {@code t.outputs} at least as many as
     *          there are output neurons
     * @return the sum of absolute output errors (target - actual) for this sample,
     *         measured before the parameter update
     */
    public double trainNetwork(TrainingData t) {
        loadInputs(t);
        feedForward();
        double totalError = computeOutputDeltas(t);
        computeHiddenDeltas();
        updateFreeParams();
        return totalError;
    }

    /** Copies the sample's input values into the outputs of the input-layer neurons. */
    private void loadInputs(TrainingData t) {
        for (int i = 0; i < inputLayer.neurons.length; i++) {
            inputLayer.neurons[i].output = t.inputs[i];
        }
    }

    /** Recomputes every trainable neuron's output, from the hidden layer to the output layer. */
    private void feedForward() {
        for (NeuronLayer layer : layers) {
            for (Neuron n : layer.neurons) {
                n.updateOutput();
            }
        }
    }

    /**
     * Passes (target - actual) to each output neuron's {@code setDeltaValue} and
     * returns the sum of the absolute differences.
     */
    private double computeOutputDeltas(TrainingData t) {
        double totalError = 0;
        for (int i = 0; i < outputLayer.neurons.length; i++) {
            Neuron neuron = outputLayer.neurons[i];
            double delta = t.outputs[i] - neuron.getOutputValue();
            neuron.setDeltaValue(delta);
            totalError += Math.abs(delta);
        }
        return totalError;
    }

    /**
     * Propagates error backwards into the hidden layers, starting at the layer just
     * before the output layer. Each hidden neuron receives the weighted sum of the
     * deltas of the neurons it feeds. Index 0 is the first hidden layer (the input
     * layer is not in {@code layers}), so the loop must include it.
     */
    private void computeHiddenDeltas() {
        for (int i = layers.size() - 2; i >= 0; i--) {
            for (Neuron someNeuron : layers.get(i).neurons) {
                double errorFactor = 0;
                for (Connection connection : someNeuron.forwardConnections) {
                    errorFactor += connection.to.deltaValue * connection.weight;
                }
                someNeuron.setDeltaValue(errorFactor);
            }
        }
    }

    /** Asks every trainable neuron to update its free parameters from its current delta. */
    private void updateFreeParams() {
        for (NeuronLayer layer : layers) {
            for (Neuron n : layer.neurons) {
                n.updateFreeParams();
            }
        }
    }

}
