Class NeuralNetwork

  • All Implemented Interfaces:
    java.io.Serializable

    public class NeuralNetwork
    extends java.lang.Object
    implements java.io.Serializable
    A class for a neural network that takes inputs, sends them through successive layers, and generates outputs.
    Author:
    Bhagat
    See Also:
    Serialized Form
    • Field Detail

      • shape

        private int[] shape
        an array defining the shape of the network
      • numOfInputs

        private int numOfInputs
        the number of inputs
      • numOfOutputs

        private int numOfOutputs
        the number of outputs
      • numsOfHiddens

        private int[] numsOfHiddens
        the number of hidden nodes in each hidden layer
      • weights

        private Matrix[] weights
        the weights for each layer
      • bias

        private Matrix[] bias
        the bias for each layer
      • learingRate

        private double learingRate
        the learning rate
      • activationFunction

        private SerializableFunction<java.lang.Double,java.lang.Double> activationFunction
        the activation function
      • defaultActivationFunction

        public static SerializableFunction<java.lang.Double,java.lang.Double> defaultActivationFunction
        the default activation function: the sigmoid function, which maps any real number into the open interval (0, 1)
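For reference, the sigmoid is sigmoid(x) = 1 / (1 + e^(-x)). A minimal sketch of how it could be written as a serializable function follows. The SerializableFunction interface declared here is an assumption (a java.util.function.Function that also extends java.io.Serializable); the project's own interface may be declared differently, and SigmoidDemo and SIGMOID are hypothetical names.

```java
import java.io.Serializable;
import java.util.function.Function;

// Assumption: a stand-in for the project's SerializableFunction, modeled as a
// Function that is also Serializable. The real declaration may differ.
interface SerializableFunction<T, R> extends Function<T, R>, Serializable {}

final class SigmoidDemo {
    // sigmoid(x) = 1 / (1 + e^(-x)); every real x maps into the open interval (0, 1)
    static final SerializableFunction<Double, Double> SIGMOID =
            x -> 1.0 / (1.0 + Math.exp(-x));

    public static void main(String[] args) {
        for (double x : new double[] {-10, -1, 0, 1, 10}) {
            System.out.printf("sigmoid(%5.1f) = %.5f%n", x, SIGMOID.apply(x));
        }
    }
}
```

Giving the lambda a target type that extends Serializable is what allows it to be written out with the rest of the object; a lambda typed as a plain Function is not serializable.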
    • Constructor Detail

      • NeuralNetwork

        public NeuralNetwork(SerializableFunction<java.lang.Double,java.lang.Double> activationFunction,
                             int... shape)
        Creates a NeuralNetwork with the specified activation function and shape.
        Parameters:
        activationFunction - the activation function
        shape - an array defining the shape of the NeuralNetwork
      • NeuralNetwork

        public NeuralNetwork(int... shape)
        Creates a NeuralNetwork with a specified shape
        Parameters:
        shape - an array defining the shape of the NeuralNetwork
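The constructors do not spell out how shape is read. Judging from the numOfInputs, numsOfHiddens, and numOfOutputs fields, a likely reading is {inputs, hidden layer sizes..., outputs}. The sketch below assumes that, and also assumes each weight matrix has one row per node in the next layer and one column per node in the previous layer. ShapeDemo and weightDimensions are hypothetical names, not part of this API.

```java
import java.util.Arrays;

final class ShapeDemo {
    // Hypothetical helper: {rows, cols} of each weight matrix, assuming
    // shape = {inputs, hidden sizes..., outputs}.
    static int[][] weightDimensions(int... shape) {
        if (shape.length < 2) {
            throw new IllegalArgumentException("shape needs at least an input and an output layer");
        }
        int[][] dims = new int[shape.length - 1][];
        for (int i = 0; i < dims.length; i++) {
            // weights[i] maps layer i (shape[i] nodes) to layer i + 1 (shape[i + 1] nodes)
            dims[i] = new int[] {shape[i + 1], shape[i]};
        }
        return dims;
    }

    public static void main(String[] args) {
        // 2 inputs, one hidden layer of 4 nodes, 1 output
        for (int[] d : weightDimensions(2, 4, 1)) {
            System.out.println(Arrays.toString(d));
        }
    }
}
```

Under that assumption, new NeuralNetwork(2, 4, 1) would describe 2 inputs, one hidden layer of 4 nodes, and 1 output, with weight matrices of 4×2 and 1×4.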
    • Method Detail

      • feedForward

        public void feedForward(DataPoint dataPoint)
        Runs the feed-forward algorithm on the data point and stores the resulting outputs in the data point.
        Parameters:
        dataPoint - the data point
      • feedForward

        public double[] feedForward(double... inputs)
        Runs the feed-forward algorithm to compute the network's guess for the given inputs.
        Parameters:
        inputs - an array of inputs
        Returns:
        an array of outputs
      • feedForward

        public Vector feedForward(Vector inputs)
        Runs the feed-forward algorithm to compute the network's guess for the given inputs.
        Parameters:
        inputs - a Vector that holds the inputs
        Returns:
        a Vector that holds the outputs
      • feedForward

        public Matrix feedForward(Matrix inputs)
        Runs the feed-forward algorithm to compute the network's guess for the given inputs.
        Parameters:
        inputs - a Matrix that holds the inputs
        Returns:
        a Matrix that holds the outputs
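A standard feed-forward pass computes activation(W·x + b) for one layer and feeds the result into the next layer as its input. The sketch below shows that loop on plain double arrays instead of the project's Matrix and Vector classes. It assumes the sigmoid is applied at every layer, including the output, and that each weight row belongs to one node of the next layer; FeedForwardSketch and its methods are hypothetical.

```java
import java.util.Arrays;

final class FeedForwardSketch {
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // One layer: sigmoid(W * x + b), where W has one row per output node.
    static double[] layer(double[][] w, double[] b, double[] x) {
        double[] out = new double[w.length];
        for (int r = 0; r < w.length; r++) {
            double sum = b[r];
            for (int c = 0; c < x.length; c++) {
                sum += w[r][c] * x[c];
            }
            out[r] = sigmoid(sum);
        }
        return out;
    }

    // Full pass: each layer's output becomes the next layer's input.
    static double[] feedForward(double[][][] weights, double[][] biases, double... inputs) {
        double[] activations = inputs;
        for (int i = 0; i < weights.length; i++) {
            activations = layer(weights[i], biases[i], activations);
        }
        return activations;
    }

    public static void main(String[] args) {
        // A 2-2-1 network with hand-picked, illustrative weights
        double[][][] weights = {
            {{0.5, -0.5}, {0.25, 0.75}},
            {{1.0, -1.0}}
        };
        double[][] biases = {{0.0, 0.0}, {0.0}};
        System.out.println(Arrays.toString(feedForward(weights, biases, 1.0, 0.0)));
    }
}
```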
      • train

        public void train(DataSet dataSet)
        Trains the network on every data point in the data set.
        Parameters:
        dataSet - the data set to train the network with
      • train

        public void train(DataSet dataSet,
                          int epoch)
        Trains the network on every data point in the data set, training on each point epoch times.
        Parameters:
        dataSet - the data set to train the network with
        epoch - the number of epochs, i.e. how many times to train on each data point
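The documentation leaves open how the repeated passes are ordered. One common reading, assumed in this sketch, is epoch full passes over the data set, training once on each point per pass. EpochSketch and trainStep are hypothetical names standing in for train(DataSet, int) and train(DataPoint).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

final class EpochSketch {
    // Hypothetical: runs `epochs` full passes, calling trainStep once per point per pass.
    static <P> void train(List<P> dataSet, int epochs, Consumer<P> trainStep) {
        for (int e = 0; e < epochs; e++) {
            for (P point : dataSet) {
                trainStep.accept(point);
            }
        }
    }

    public static void main(String[] args) {
        List<String> seen = new ArrayList<>();
        train(List.of("a", "b", "c"), 2, seen::add);
        System.out.println(seen); // [a, b, c, a, b, c]
    }
}
```

The other reading, training on each point epoch times in a row before moving to the next, would swap the two loops; the page does not say which one the class uses.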
      • train

        public void train(DataPoint dataPoint)
        Trains the network on a single data point.
        Parameters:
        dataPoint - the data point
      • train

        public void train(double[] inputs,
                          double[] targets)
        Trains the network on the given inputs and the known target outputs for those inputs.
        Parameters:
        inputs - the inputs
        targets - the targets
      • train

        public void train(Vector inputs,
                          Vector targets)
        Trains the network on the given inputs and the known target outputs for those inputs.
        Parameters:
        inputs - the inputs
        targets - the targets
      • train

        public void train(Matrix inputs,
                          Matrix targets)
        Trains the network on the given inputs and the known target outputs for those inputs.
        Parameters:
        inputs - the inputs
        targets - the targets
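Training on one (inputs, targets) pair is typically a single step of gradient descent with backpropagation. The sketch below assumes a sigmoid at every layer, a squared-error loss 0.5 * sum((t - y)^2), and updates scaled by the learning rate; none of these choices are confirmed by this page. TrainSketch, forward, loss, and train are hypothetical names that use plain arrays in place of Matrix and Vector.

```java
import java.util.Arrays;

final class TrainSketch {
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Forward pass keeping every layer's activations: a[0] = inputs, a[w.length] = outputs.
    static double[][] forward(double[][][] w, double[][] b, double[] inputs) {
        double[][] a = new double[w.length + 1][];
        a[0] = inputs;
        for (int l = 0; l < w.length; l++) {
            a[l + 1] = new double[w[l].length];
            for (int r = 0; r < w[l].length; r++) {
                double sum = b[l][r];
                for (int c = 0; c < a[l].length; c++) {
                    sum += w[l][r][c] * a[l][c];
                }
                a[l + 1][r] = sigmoid(sum);
            }
        }
        return a;
    }

    // Assumed loss: 0.5 * sum((t - y)^2)
    static double loss(double[][][] w, double[][] b, double[] inputs, double[] targets) {
        double[] y = forward(w, b, inputs)[w.length];
        double sum = 0;
        for (int i = 0; i < y.length; i++) {
            sum += 0.5 * (targets[i] - y[i]) * (targets[i] - y[i]);
        }
        return sum;
    }

    // One gradient-descent step on a single (inputs, targets) pair; updates w and b in place.
    static void train(double[][][] w, double[][] b, double learningRate,
                      double[] inputs, double[] targets) {
        double[][] a = forward(w, b, inputs);
        double[] y = a[w.length];
        // Output error term dE/dz = (y - t) * y * (1 - y), using sigmoid'(z) = y * (1 - y)
        double[] delta = new double[y.length];
        for (int r = 0; r < y.length; r++) {
            delta[r] = (y[r] - targets[r]) * y[r] * (1 - y[r]);
        }
        for (int l = w.length - 1; l >= 0; l--) {
            double[] prev = a[l];
            // Propagate the error to layer l before w[l] changes (skipped for the input layer)
            double[] prevDelta = new double[prev.length];
            if (l > 0) {
                for (int c = 0; c < prev.length; c++) {
                    double s = 0;
                    for (int r = 0; r < delta.length; r++) {
                        s += w[l][r][c] * delta[r];
                    }
                    prevDelta[c] = s * prev[c] * (1 - prev[c]);
                }
            }
            // Gradient step: w -= lr * delta * prev^T, b -= lr * delta
            for (int r = 0; r < delta.length; r++) {
                for (int c = 0; c < prev.length; c++) {
                    w[l][r][c] -= learningRate * delta[r] * prev[c];
                }
                b[l][r] -= learningRate * delta[r];
            }
            delta = prevDelta;
        }
    }

    public static void main(String[] args) {
        double[][][] w = {{{0.5, -0.5}, {0.25, 0.75}}, {{1.0, -1.0}}};
        double[][] b = {{0.0, 0.0}, {0.0}};
        double[] x = {1.0, 0.0};
        double[] t = {1.0};
        System.out.println("loss before: " + loss(w, b, x, t));
        for (int i = 0; i < 100; i++) {
            train(w, b, 0.5, x, t);
        }
        System.out.println("loss after:  " + loss(w, b, x, t));
        System.out.println("outputs: " + Arrays.toString(forward(w, b, x)[w.length]));
    }
}
```

The previous layer's error term is computed before w[l] is updated, so each step uses the gradient of the weights exactly as they were during the forward pass.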
      • test

        public double test(DataSet dataSet)
        Tests the network on every data point in the data set.
        Parameters:
        dataSet - the testing data set
        Returns:
        the accuracy
      • test

        public double test(DataSet dataSet,
                           boolean log)
        Tests the network on every data point in the data set.
        Parameters:
        dataSet - the testing data set
        log - whether to log each test
        Returns:
        the accuracy
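This page does not say how test computes its accuracy. For a classifier with one-hot targets, a common choice is the fraction of points whose largest output sits at the same index as the largest target; the sketch assumes that. AccuracySketch, argMax, and accuracy are hypothetical, and a single-output network would need a threshold (for example 0.5) instead of argmax.

```java
final class AccuracySketch {
    // Index of the largest value; ties resolve to the first occurrence.
    static int argMax(double[] v) {
        int best = 0;
        for (int i = 1; i < v.length; i++) {
            if (v[i] > v[best]) {
                best = i;
            }
        }
        return best;
    }

    // Fraction of samples whose predicted class (argmax of outputs)
    // matches the target class (argmax of the one-hot targets).
    static double accuracy(double[][] outputs, double[][] targets) {
        if (outputs.length == 0 || outputs.length != targets.length) {
            throw new IllegalArgumentException("need matching, non-empty outputs and targets");
        }
        int correct = 0;
        for (int i = 0; i < outputs.length; i++) {
            if (argMax(outputs[i]) == argMax(targets[i])) {
                correct++;
            }
        }
        return (double) correct / outputs.length;
    }

    public static void main(String[] args) {
        double[][] outputs = {{0.9, 0.1}, {0.3, 0.7}, {0.6, 0.4}};
        double[][] targets = {{1, 0}, {0, 1}, {0, 1}};
        System.out.println(accuracy(outputs, targets)); // 0.6666666666666666
    }
}
```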
      • readObject

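        private void readObject(java.io.ObjectInputStream in)
                         throws java.io.IOException,
                                java.lang.ClassNotFoundException
        Custom deserialization hook, invoked by java.io.ObjectInputStream when a serialized NeuralNetwork is read back.
        Throws:
        java.io.IOException - if an I/O error occurs while reading
        java.lang.ClassNotFoundException - if the class of a serialized object cannot be found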
      • serialize

        public void serialize(java.lang.String filename)
        Serializes the neural network into the named file.
        Parameters:
        filename - the file name
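Because NeuralNetwork implements java.io.Serializable, serialize(filename) is presumably built on java.io.ObjectOutputStream; that is an assumption, as is the reading side, which this page does not document. The sketch shows a round trip through a temporary file using a hypothetical stand-in class, Model, that declares a private readObject hook of the same shape as NeuralNetwork.readObject. Every field reachable from the object must itself be serializable, which is presumably why the activation function is typed as SerializableFunction rather than a plain Function.

```java
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

final class SerializeSketch {
    // Hypothetical stand-in for NeuralNetwork; any Serializable object is written the same way.
    static final class Model implements Serializable {
        private static final long serialVersionUID = 1L;
        double learningRate;
        double[][] weights;

        Model(double learningRate, double[][] weights) {
            this.learningRate = learningRate;
            this.weights = weights;
        }

        // Private readObject hook, the same mechanism NeuralNetwork declares:
        // restore the default fields, then rebuild any transient state here.
        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.defaultReadObject();
        }
    }

    static void serialize(Object obj, String filename) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(filename))) {
            out.writeObject(obj);
        }
    }

    static Object deserialize(String filename) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(filename))) {
            return in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        File file = File.createTempFile("model", ".ser");
        file.deleteOnExit();
        serialize(new Model(0.1, new double[][] {{0.5, -0.5}}), file.getPath());
        Model copy = (Model) deserialize(file.getPath());
        System.out.println(copy.learningRate + " " + copy.weights[0][1]); // 0.1 -0.5
    }
}
```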
      • getWeights

        public Matrix[] getWeights()
        Returns:
        the weights
      • setWeights

        public void setWeights(Matrix[] weights)
        Parameters:
        weights - the weights to set
      • getBias

        public Matrix[] getBias()
        Returns:
        the bias
      • setBias

        public void setBias(Matrix[] bias)
        Parameters:
        bias - the bias to set
      • getNumOfInputs

        public int getNumOfInputs()
        Returns:
        the numOfInputs
      • getNumOfOutputs

        public int getNumOfOutputs()
        Returns:
        the numOfOutputs
      • getNumsOfHiddens

        public int[] getNumsOfHiddens()
        Returns:
        the numsOfHiddens
      • getActivationFunction

        public SerializableFunction<java.lang.Double,java.lang.Double> getActivationFunction()
        Returns:
        the activationFunction
      • setActivationFunction

        public void setActivationFunction(SerializableFunction<java.lang.Double,java.lang.Double> activationFunction)
        Parameters:
        activationFunction - the activationFunction to set
      • getLearingRate

        public double getLearingRate()
        Returns:
        the learingRate
      • setLearingRate

        public void setLearingRate(double learingRate)
        Parameters:
        learingRate - the learingRate to set