Tutorial: Neural network training

From stats++ wiki

In the following tutorial, we will learn how to train a neural network on a dataset using the executable version of mlp in stats++. We will also focus on the aspects of training that need attention before the network is trained and while checking the output of the trained network.

stats++ implements a feed-forward neural network[1]. See here[2] for details on how to create the neural network in stats++. Once the network is created, it is ready for training. We will discuss the important aspects of training, starting with the choice of cost function, followed by the choice of a training method that suits the problem at hand. We will then discuss the problem of overfitting and how to address it. Finally, we will run an example problem to demonstrate the capability of the network.

Cost Function

Training a feed-forward neural network involves updating the weights that connect its units, including those of the hidden layers. These updates are driven by the error at the output units. The error can be measured by different norm-like functions, which we call cost functions. For regression problems we prefer the mean squared error function[3], while the cross-entropy function[4] is used for binary (binomial) and multiclass (multinomial) classification problems.
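The two cost functions can be written down in a few lines. The following is a minimal sketch in plain Python, not the stats++ API: the function names are hypothetical, the optional factor 1/2 that some texts put in front of the squared error is omitted, and the cross entropy assumes one-hot targets and output rows that are already class probabilities (for example from a softmax layer).

```python
import math

def mean_squared_error(targets, outputs):
    """Mean squared error over all output values (hypothetical helper)."""
    if len(targets) != len(outputs):
        raise ValueError("targets and outputs must have the same length")
    return sum((t - y) ** 2 for t, y in zip(targets, outputs)) / len(targets)

def cross_entropy(targets, probabilities, eps=1e-12):
    """Mean cross entropy for one-hot target rows and predicted probability rows."""
    total = 0.0
    for t_row, p_row in zip(targets, probabilities):
        # Clamp each probability away from 0 so that log() stays finite.
        total -= sum(t * math.log(max(p, eps)) for t, p in zip(t_row, p_row))
    return total / len(targets)

print(mean_squared_error([1.0, 2.0], [1.5, 2.0]))  # 0.125
print(cross_entropy([[0, 1]], [[0.5, 0.5]]))       # about 0.6931 (= ln 2)
```

The squared error penalizes the distance between real-valued outputs, while the cross entropy only looks at the probability assigned to the correct class, which is why it pairs naturally with classification outputs.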

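The shape of the cost function determines how quickly the network trains: when the cost surface has nearly spherical curvature (equal in every direction, like an n-sphere), training is faster. When the surface is instead elongated, a momentum term is often added to the weight update. Momentum does not change the cost function itself; it accumulates successive steps, which damps oscillations across steep directions and speeds up progress along shallow ones, so training behaves more as it would on a spherical surface.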
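The effect of momentum can be illustrated with a minimal sketch in plain Python, assuming a two-weight quadratic cost whose curvature is 1 in one direction and 50 in the other (an elongated rather than spherical bowl). The cost, the learning rate and the momentum coefficient are hypothetical choices, and the update is the classical heavy-ball form of momentum, not necessarily the form used in stats++.

```python
def quadratic_loss(w, curvatures):
    """Hypothetical toy cost f(w) = 0.5 * sum(c_i * w_i^2); unequal c_i give an elongated bowl."""
    return 0.5 * sum(c * x * x for c, x in zip(curvatures, w))

def descend(w, curvatures, lr, momentum=0.0, steps=200):
    """Gradient descent with an optional heavy-ball momentum term (momentum=0 is plain descent)."""
    w = list(w)
    velocity = [0.0] * len(w)
    for _ in range(steps):
        grad = [c * x for c, x in zip(curvatures, w)]
        # The velocity keeps a decaying memory of previous steps.
        velocity = [momentum * v - lr * g for v, g in zip(velocity, grad)]
        w = [x + v for x, v in zip(w, velocity)]
    return w

curvatures = [1.0, 50.0]  # one shallow and one steep direction
start = [1.0, 1.0]
plain = descend(start, curvatures, lr=0.035)
heavy = descend(start, curvatures, lr=0.035, momentum=0.9)
print(quadratic_loss(plain, curvatures), quadratic_loss(heavy, curvatures))
```

The steep direction caps the usable learning rate (plain descent diverges there once lr exceeds 2/50), so without momentum the shallow direction is covered slowly; with momentum the same 200 steps end at a lower cost.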

Training Methods

Methods for training the neural network fall broadly into two classes. The first class contains methods such as back propagation, which use only first-derivative information of the cost function with respect to the weights of the network. The second class uses both first- and second-derivative information; Quasi-Resilient Propagation is one of them. The advantage of using the second derivative is that training reaches a local minimum much faster than when only the first derivative is used. A clear drawback is that training may then get stuck in a local minimum of the cost function that is not the global minimum. Using only the first derivative is slower but can be more advantageous in such cases, although it, too, guarantees convergence only to a local minimum.
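The speed difference between the two classes can be seen on a single weight. The following minimal sketch uses Newton's method as a stand-in for second-derivative training and a fixed-learning-rate gradient step as the first-derivative method; it is not Quasi-Resilient Propagation or the stats++ implementation, and the toy cost f(w) = exp(w) - 2w (minimum at w = ln 2) and the learning rate 0.1 are hypothetical.

```python
import math

# Hypothetical toy cost f(w) = exp(w) - 2w; its only minimum is at w = ln 2.
def cost_gradient(w):
    return math.exp(w) - 2.0

def cost_second_derivative(w):
    return math.exp(w)

def gradient_step(w, lr=0.1):
    """First-derivative update: move against the gradient with a fixed learning rate."""
    return w - lr * cost_gradient(w)

def newton_step(w):
    """Second-derivative update: scale the gradient by the inverse curvature."""
    return w - cost_gradient(w) / cost_second_derivative(w)

def minimize(step, w=0.0, tol=1e-8, max_iter=10_000):
    """Apply step() until |f'(w)| < tol; return the final w and the iteration count."""
    for i in range(max_iter):
        if abs(cost_gradient(w)) < tol:
            return w, i
        w = step(w)
    return w, max_iter

w_first, n_first = minimize(gradient_step)
w_second, n_second = minimize(newton_step)
print(n_first, n_second)  # the Newton iteration needs far fewer steps
```

On a cost with several minima, the Newton iteration simply heads for the nearest stationary point, which is the local-minimum drawback described above.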

Apart from the training method, there is one more choice to make: whether to use stochastic learning or batch learning. In stochastic learning, the error is calculated at each iteration from a single randomly chosen data point, while in batch learning the average error over all the data points is used. Stochastic learning introduces noise into the calculated derivatives, which is why second-derivative training methods are poorly suited to it, but it typically reaches the optimal solution faster than batch learning.
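The two schemes differ only in which error drives each update. A minimal sketch, assuming a hypothetical one-weight linear model y = w * x with squared error, toy data of true slope 2 plus a small deterministic perturbation, and hand-picked learning rates (none of these come from stats++):

```python
import random

# Hypothetical toy data for y = w * x: true slope 2 plus a small alternating perturbation.
xs = [0.1 * i for i in range(1, 21)]
ys = [2.0 * x + 0.01 * (-1) ** i for i, x in enumerate(xs, start=1)]

def point_gradient(w, x, y):
    """Derivative with respect to w of the per-point cost 0.5 * (w*x - y)^2."""
    return (w * x - y) * x

def batch_learning(epochs=50, lr=0.5):
    w = 0.0
    for _ in range(epochs):
        # One update per pass, using the gradient averaged over all data points.
        g = sum(point_gradient(w, x, y) for x, y in zip(xs, ys)) / len(xs)
        w -= lr * g
    return w

def stochastic_learning(epochs=50, lr=0.2, seed=0):
    rng = random.Random(seed)
    w = 0.0
    for _ in range(epochs * len(xs)):
        # One update per randomly chosen point: cheap, but a noisy gradient estimate.
        i = rng.randrange(len(xs))
        w -= lr * point_gradient(w, xs[i], ys[i])
    return w

print(batch_learning(), stochastic_learning())
```

Both runs visit the same number of data points, but the stochastic version makes one update per point instead of one per pass, and its final weight keeps jittering slightly around the batch solution because each update sees only one point.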

Fitting vs Overfitting

During training, it can happen that the error keeps decreasing on the training data even though the error on unseen data in the domain of interest goes up. To detect this, a portion of the data is held back from training and used only to calculate the error of the network at intervals during training. This is called a validation test, and it is very important for preventing overfitting. The validation error alone does not settle when to stop training, however: the validation error may rise slightly and then start decreasing again. Such a rise follows a local minimum of the validation error rather than the global minimum, which makes the stopping criterion tricky. A third segment of the data is kept solely for testing the network once training is finished, to measure how well it has trained. This is the generalization data.
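Reserving the three segments amounts to shuffling the samples and cutting them into training, validation and generalization sets. The following is a minimal sketch; the function name and the 60/20/20 proportions are hypothetical choices, not stats++ defaults.

```python
import random

def split_dataset(samples, train_frac=0.6, validation_frac=0.2, seed=0):
    """Shuffle samples and split them into training, validation and generalization sets."""
    if train_frac <= 0 or validation_frac <= 0 or train_frac + validation_frac >= 1:
        raise ValueError("fractions must be positive and leave room for generalization data")
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)  # avoid any ordering bias in the original data
    n_train = round(len(shuffled) * train_frac)
    n_valid = round(len(shuffled) * validation_frac)
    training = shuffled[:n_train]
    validation = shuffled[n_train:n_train + n_valid]  # used only to monitor the error
    generalization = shuffled[n_train + n_valid:]     # untouched until training is over
    return training, validation, generalization

training, validation, generalization = split_dataset(range(100))
print(len(training), len(validation), len(generalization))  # 60 20 20
```

The generalization set is deliberately never looked at while choosing when to stop, so its error remains an honest estimate of performance on unseen data.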

In stats++, the problem of finding the global minimum of the validation error is addressed by fitting a function through the validation errors and then finding the global minimum of that function. Overfitting can also be controlled in another way. It is often accompanied by some of the weights becoming much larger than the others, which can be avoided by limiting the size of the weights with a standard L2 weight penalty added to the cost function.
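Both ideas can be sketched in plain Python. The text does not say which function stats++ fits through the validation errors, so the quadratic least-squares fit below is an assumption chosen for illustration; the helper names, the penalty strength `lam` and the synthetic validation curve (a parabola with alternating bumps that create local minima) are likewise hypothetical.

```python
def _solve3(m, rhs):
    """Solve a 3x3 linear system by Gaussian elimination with partial pivoting."""
    a = [row[:] + [r] for row, r in zip(m, rhs)]
    for col in range(3):
        pivot = max(range(col, 3), key=lambda r: abs(a[r][col]))
        a[col], a[pivot] = a[pivot], a[col]
        for r in range(col + 1, 3):
            f = a[r][col] / a[col][col]
            for k in range(col, 4):
                a[r][k] -= f * a[col][k]
    x = [0.0] * 3
    for r in (2, 1, 0):
        x[r] = (a[r][3] - sum(a[r][k] * x[k] for k in range(r + 1, 3))) / a[r][r]
    return x

def fitted_stopping_epoch(epochs, val_errors):
    """Fit e ~ a + b*s + c*s^2 with s = t - mean(t) by least squares (assumed model)
    and return the epoch in the observed range where the fitted curve is lowest."""
    mean_t = sum(epochs) / len(epochs)
    s = [t - mean_t for t in epochs]  # centering keeps the normal equations well scaled
    S = [sum(x ** k for x in s) for k in range(5)]
    T = [sum(x ** k * e for x, e in zip(s, val_errors)) for k in range(3)]
    a, b, c = _solve3([[S[0], S[1], S[2]], [S[1], S[2], S[3]], [S[2], S[3], S[4]]], T)
    lo, hi = min(epochs), max(epochs)
    if c > 0:
        # Global minimum of the fitted parabola, clamped to the observed epochs.
        return min(max(mean_t - b / (2 * c), lo), hi)

    def fitted(t):
        return a + b * (t - mean_t) + c * (t - mean_t) ** 2

    # A concave or flat fit has its lowest point at one of the ends.
    return lo if fitted(lo) <= fitted(hi) else hi

def l2_penalized_cost(data_cost, weights, lam):
    """Standard L2 weight penalty: data cost + 0.5 * lam * sum of squared weights."""
    return data_cost + 0.5 * lam * sum(w * w for w in weights)

def l2_penalty_gradient(weights, lam):
    """Extra gradient term lam * w, which shrinks every weight toward zero."""
    return [lam * w for w in weights]

# Synthetic validation curve: a bowl centred on epoch 30 with alternating local bumps.
epochs = list(range(61))
val_errors = [0.5 + 0.001 * (t - 30) ** 2 + 0.01 * (-1) ** t for t in epochs]
print(fitted_stopping_epoch(epochs, val_errors))   # close to 30
print(l2_penalized_cost(1.0, [3.0, 4.0], 0.1))     # 2.25
```

Stopping at the first rise of the raw validation error would halt at epoch 0 of this curve, whereas the fitted curve ignores the bumps and points to the bottom of the overall trend.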

Example Problem

Restaurant Visitor Prediction: data obtained from Kaggle