Deep Dips #4: Training neural networks
In previous posts in this series, I’ve covered the basics of neural network models, talked about embedding models, and explained how transformers work. So, if you’ve read these, then you should already have an idea of what neural networks look like and how they’re used.
Next I want to cover the training and configuration of neural networks. I’m going to start off with backpropagation and gradient descent, which are the bread-and-butter of most neural network training. Then I’ll say a bit about reinforcement learning, gradient-free methods, and neural architecture search. Finally, a few words on approaches which are biologically more plausible.
Measuring performance
When a neural network is first created, its weights and biases (also known as its parameters) are set to random values. So at this point it’s unlikely to do anything useful. Training a neural network involves working out more appropriate values for these parameters, so that the neural network does do something useful.
You can tell how useful a neural network is by doing a forward pass. That is, you provide input, you propagate this through the network, and you get an output. Assuming you know what the output should be for a particular input, you can then work out how close it got to the right answer. Doing this for a whole bunch of input-output examples provides a measure of how well the neural network is doing. The accumulated difference between the correct outputs and the outputs the neural network actually gave is known as the loss, and the specific method used to determine this is known as the loss function.
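To make this concrete, here's a minimal Python sketch of one common loss function, mean squared error, computed over a handful of examples. The forward pass and the numbers are just placeholders, not any real network:

```python
# A minimal sketch of a loss calculation: mean squared error over a
# batch of (input, target) examples. `forward` stands in for whatever
# the neural network's forward pass is — here just a placeholder.

def forward(x):
    # placeholder forward pass: in a real network this would propagate
    # x through all the layers and return the network's output
    return 0.5 * x

def mse_loss(inputs, targets):
    # accumulate the squared differences between predictions and targets
    total = 0.0
    for x, t in zip(inputs, targets):
        prediction = forward(x)
        total += (prediction - t) ** 2
    return total / len(inputs)

print(mse_loss([1.0, 2.0, 3.0], [1.0, 1.5, 2.0]))  # smaller is better
```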
In most cases, training a neural network amounts to attempting to minimise this loss, getting it as close to zero as possible. To use an unnecessarily complicated and opaque term — which, alas, people often do — this is known as empirical risk minimisation. There are various ways of doing this, but the most popular is something commonly referred to as backpropagation, or just backprop. However, it really consists of two different things: backpropagation and gradient descent, so I’ll cover these separately.
Backpropagation
Backpropagation is a method to determine how you need to tweak each parameter of a neural network in order to reduce the loss. It achieves this through the magic of differentiation, which is a mathematical procedure that takes a function and tells you how its output will change if you change one of its inputs.
Basically, backpropagation moves backwards through a neural network (hence the name), starting at the outputs1 and moving towards the inputs. As it does this, it uses differentiation2 to calculate values called gradients, and attaches one of these to each component of the neural network. This procedure rests on something called the chain rule, which allows you to work out the gradient for a particular component based on gradients you’ve already calculated further downstream — hence why backpropagation moves from the outputs to the inputs.
Gradients are just numbers that capture how much the loss of the neural network will change if you tweak the value at that point. And we’re really just interested in the gradients that are associated with weights and biases. So, if a weight has a high gradient associated with it, then we know a small change to that weight would have a large effect on the loss. Conversely, if a weight has a low gradient, then you’d need a larger tweak in order to have the same effect. Gradients can be positive or negative. For weights with positive gradients, increasing the weight would increase the loss — so this means we want to reduce these weights in order to improve the behaviour of the neural network. Likewise, we want to increase the values of weights for those with negative gradients. And similarly for biases.
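As a toy illustration of the chain rule at work (not the full backpropagation algorithm), here's how the gradients fall out for a single neuron with a squared-error loss. All the values are made up:

```python
# Toy example: one neuron, y = w * x + b, with loss L = (y - target)^2.
# The chain rule says dL/dw = dL/dy * dy/dw, and dL/db = dL/dy * dy/db.

x, target = 2.0, 1.0       # one training example
w, b = 0.5, 0.1            # current parameter values

# forward pass
y = w * x + b
loss = (y - target) ** 2

# backward pass: gradients via the chain rule
dL_dy = 2 * (y - target)   # derivative of the loss w.r.t. the output
dy_dw = x                  # derivative of the output w.r.t. the weight
dy_db = 1.0                # derivative of the output w.r.t. the bias

dL_dw = dL_dy * dy_dw      # gradient attached to the weight
dL_db = dL_dy * dy_db      # gradient attached to the bias

print(dL_dw, dL_db)
```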
Gradient descent
And this is where gradient descent comes in. Having completed backpropagation, you’ll now have a gradient value for each parameter in the neural network, which tells you how you need to tweak each weight and bias in order to reduce the overall loss. Gradient descent is just an iterative process for doing this tweaking.
An iteration in gradient descent is known as an epoch (strictly speaking, an epoch is one pass through the whole training set, but for vanilla gradient descent the two amount to the same thing). During each epoch, you carry out a forward pass to calculate the loss, you then do backpropagation (also referred to as the backward pass in this context) to calculate the gradients, and you then tweak every parameter. The size of the tweak applied to each parameter is its gradient multiplied by a (typically constant) positive value called the learning rate, and the tweak is made in the opposite direction to the gradient, so that the loss goes down. Over a series of epochs, gradient descent will move the parameters from their initial random values to values which minimise the loss — eventually leading to a trained neural network.
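As a concrete sketch, here's what the loop looks like for the same toy one-neuron example as above. A real network would have vastly more parameters and a whole training set, but the structure is the same:

```python
# Bare-bones gradient descent on the toy one-neuron example:
# repeat forward pass -> backward pass -> parameter tweaks.

x, target = 2.0, 1.0
w, b = 0.5, 0.1
learning_rate = 0.05       # a positive constant; tweaks go against the gradient

for epoch in range(20):
    # forward pass: compute the output and the loss
    y = w * x + b
    loss = (y - target) ** 2

    # backward pass: gradients for each parameter (chain rule, as above)
    dL_dy = 2 * (y - target)
    dL_dw = dL_dy * x
    dL_db = dL_dy * 1.0

    # gradient descent step: move each parameter against its gradient
    w -= learning_rate * dL_dw
    b -= learning_rate * dL_db

    print(f"epoch {epoch}: loss = {loss:.4f}")
```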
Or at least that’s the aim. In practice, you need to find an appropriate learning rate for the problem you’re trying to solve. Too high, and you’ll find it overshoots the optimal parameter values. Too low, and it’ll take ages to converge. And this is assuming that it does converge, since gradient descent is not guaranteed to find optimal values.
Tweaking gradient descent
There are a bunch of methods which have been developed to improve the behaviour of gradient descent in one way or another.
Some of these are concerned with how much training data you pass through the neural network at each epoch. In the original formulation, forward passes were made for every sample in the training data, and then gradient values were calculated across all of these before any tweaking was done. But this was very time consuming if you had a large training set. So one simple innovation was to use less data during each epoch, and this generally leads to faster learning. If you use only one item of data each time, it’s known as stochastic gradient descent. If you use more than one, but not all the data, then it’s called mini-batching. The amount you use is known as the batch size.
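In sketch form, the difference is just how many examples you average the gradients over before each tweak. Here's a rough illustration using the same toy model, with made-up data and an arbitrary batch size:

```python
import random

# Illustrative sketch of mini-batching: each update uses `batch_size`
# randomly chosen examples rather than the whole training set.
# The model is the same toy one-neuron model as before.

data = [(1.0, 0.9), (2.0, 2.1), (3.0, 2.9), (4.0, 4.2)]  # (input, target) pairs
w, b = 0.0, 0.0
learning_rate = 0.05
batch_size = 2             # a batch size of 1 would be stochastic gradient descent

for step in range(200):
    batch = random.sample(data, batch_size)

    # average the gradients over the mini-batch
    dL_dw, dL_db = 0.0, 0.0
    for x, target in batch:
        y = w * x + b
        dL_dy = 2 * (y - target)
        dL_dw += dL_dy * x / batch_size
        dL_db += dL_dy / batch_size

    w -= learning_rate * dL_dw
    b -= learning_rate * dL_db

print(w, b)   # should end up close to w ≈ 1, b ≈ 0 for this data
```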
Another issue with vanilla gradient descent is that the learning rate is a constant. But as you approach the optimum, it can actually be useful to perform smaller tweaks. Conversely, at the beginning of training, larger tweaks can be better to speed up convergence. This is addressed through the use of adaptive learning rates. The simplest approach is to gradually decrease the learning rate using a learning schedule.
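A minimal sketch of such a schedule might look like this, with an arbitrary decay factor:

```python
# A simple learning schedule: exponentially decay the learning rate
# over epochs (the decay factor here is just an illustrative choice).

initial_learning_rate = 0.1
decay = 0.95

for epoch in range(10):
    learning_rate = initial_learning_rate * (decay ** epoch)
    print(f"epoch {epoch}: learning rate = {learning_rate:.4f}")
    # ... the forward pass, backward pass and parameter tweaks would go here
```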
But a more successful approach is used by Adam, which is short for adaptive moment estimation, and is currently the most popular form of gradient descent used for neural networks. One of its ingredients is momentum: adding a fraction of a gradient's value from the previous epoch onto its value in the current epoch, which helps to make gradient values more stable across epochs, particularly when using small batch sizes. On top of this, Adam keeps a running average of each gradient's squared value, and uses it to individually adapt the learning rate for each parameter.
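Here's a single-parameter sketch of the Adam update rule, using the default hyperparameter values suggested in the original paper. The parameter and gradient values are made up:

```python
import math

# Minimal sketch of the Adam update for a single parameter.
# m is the running average of gradients (the momentum / first moment),
# v is the running average of squared gradients (the second moment),
# and v is what individually scales the learning rate for this parameter.

learning_rate = 0.001
beta1, beta2, eps = 0.9, 0.999, 1e-8   # common default values
m, v = 0.0, 0.0

def adam_step(param, grad, m, v, t):
    m = beta1 * m + (1 - beta1) * grad          # momentum term
    v = beta2 * v + (1 - beta2) * grad ** 2     # squared-gradient term
    m_hat = m / (1 - beta1 ** t)                # bias correction for early steps
    v_hat = v / (1 - beta2 ** t)
    param -= learning_rate * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

# one illustrative step, with a made-up parameter value and gradient
param, m, v = adam_step(param=0.5, grad=0.2, m=m, v=v, t=1)
print(param)
```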
Another common extension is the catchily-named L2 regularisation. Like most things in neural networks, this is not as complicated as it sounds, and basically involves adding a penalty term to the loss function that is proportional to the sum of the squared parameter values. So, if lots of weights and biases have large values, then the loss is artificially increased to apply pressure towards keeping them small. This helps to prevent overfitting by limiting the effective capacity of the network. A related concept is weight decay, which is used in a variant of Adam called AdamW.
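In sketch form, the penalty is simply bolted onto the ordinary loss. The lambda value controlling its strength is an arbitrary illustrative choice:

```python
# Sketch of L2 regularisation: add a penalty proportional to the sum of
# squared parameter values onto the ordinary loss, so large weights make
# the loss worse. `lam` (lambda) controls how strong the pressure is.

def l2_regularised_loss(plain_loss, parameters, lam=0.01):
    penalty = sum(p ** 2 for p in parameters)
    return plain_loss + lam * penalty

weights = [0.5, -1.2, 0.0, 2.0]
print(l2_regularised_loss(plain_loss=0.3, parameters=weights, lam=0.01))
```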
Reinforcement learning
This all assumes that after each forward pass, you can measure how well the neural network did, and use this information to tweak its parameters. Whilst this is true in many cases, there are situations where a measure of performance is not available until much later, or is in some way incomplete. In these circumstances it's common to use some kind of reinforcement learning.
A typical example is playing a computer game. The goal may be to complete a level, but to do so may involve carrying out a lot of other actions that contribute in some way towards whether and how quickly the level is completed. These are actions like avoiding monsters, not falling into pits, collecting power-ups. How does each of these actions contribute towards the eventual score? Or, at an even lower level, how does each move made by the player contribute towards these? Well, this is unclear. But if you’re using a neural network’s forward pass to determine the next move, then you need to know this information in order to correctly tweak its parameters during the subsequent backwards pass.
Reinforcement learning, in a nutshell, involves remembering all the actions that contributed to the outcome, and then apportioning reward to all those actions that led to it3. And this reward may be positive or negative depending on whether the goal was achieved. So, if a neural network — through a series of moves generated by forward passes — completed a level and received points as a consequence, these points will be apportioned back to all the individual moves that contributed to receiving them. And these rewards can then be used to drive the backward passes and update the neural network’s parameters4.
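One common way of doing this apportioning, though by no means the only one, is to give each move a discounted share of whatever reward arrived after it. A minimal sketch:

```python
# Sketch of one common credit assignment scheme: each move receives the
# discounted sum of all rewards that came after it, so moves closer to the
# reward get a larger share. gamma is the discount factor.

def discounted_returns(rewards, gamma=0.99):
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

# e.g. no reward for the first four moves, then 10 points on the fifth
print(discounted_returns([0, 0, 0, 0, 10]))
```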
Optimisation without gradients
Backpropagation isn’t the only fish in the sea. It’s a pretty big one, and could feed you for much of your neural network training life, but there are also some pretty tasty minnows. An important group of minnows is the gradient-free optimisers — which, as the name suggests, do not calculate or use gradients when learning neural network parameters.
Optimisation is a pretty meaty subject in itself, so I’m not going to say much about how it works. But in a nutshell, it treats the parameters of a neural network as a list of numbers that have to be correctly determined in order to minimise some loss function. Any optimiser that works with numbers can be applied to this task, but a common approach is to use an evolutionary algorithm. This is based upon an analogue of natural selection — start with a population of random solutions (i.e. a bunch of neural networks with random weights and biases) and then iteratively kill off the weak ones and breed new solutions using the best ones. Breeding is done using mutation (which randomly changes a small number of parameter values) and crossover (which splices together parts of two existing solutions).
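Here's a toy sketch of this process, optimising a short list of numbers standing in for a small network's parameters. The loss function, target values and population sizes are all just illustrative:

```python
import random

# Toy sketch of an evolutionary algorithm optimising a list of numbers
# (standing in for a small network's parameters) to minimise a loss.
# The loss here is just squared distance from an arbitrary target vector.

TARGET = [1.0, -2.0, 0.5]

def loss(params):
    return sum((p - t) ** 2 for p, t in zip(params, TARGET))

def mutate(params):
    # randomly nudge one parameter value
    child = list(params)
    i = random.randrange(len(child))
    child[i] += random.gauss(0, 0.1)
    return child

def crossover(a, b):
    # splice two parent solutions together at a random point
    point = random.randrange(1, len(a))
    return a[:point] + b[point:]

# start with a population of random solutions
population = [[random.uniform(-3, 3) for _ in range(3)] for _ in range(20)]

for generation in range(100):
    # keep the best half, discard the rest
    population.sort(key=loss)
    survivors = population[:10]
    # breed replacements from the survivors
    children = [mutate(crossover(*random.sample(survivors, 2))) for _ in range(10)]
    population = survivors + children

best = min(population, key=loss)
print(best, loss(best))
```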
Evolutionary algorithms tend to be better at avoiding local optima than gradient descent, which means they can in principle train better neural networks. However, gradient-free optimisers don’t scale anywhere near as well as backpropagation, which means in practice they can only optimise the parameters of relatively small neural networks. Anything beyond a few thousand parameters would be a push, and this is quite a limitation given that nowadays we’re routinely dealing with billions of parameters. But they do still have niches, and one of these is neural architecture search.
Neural architecture search
Before training a neural network, you have to choose an architecture. This includes things like how many layers there are, how many neurons there are in each layer, how the layers are connected, whether there are any residual connections, and which activation functions are used in each layer. Wouldn’t it be great if there were a way of choosing these things for you?
This is where neural architecture search (or NAS) comes in. It’s a way of learning the neural network architecture that is optimal for a particular task. There are various ways of doing this, including flavours of reinforcement learning, but the most successful approach is arguably evolutionary algorithms. This again rests on their ability to find good solutions and not get stuck in local optima. Scalability is much less of an issue here, because they’re only being used to learn a relatively small number of architectural parameters, rather than a huge number of weights and biases.
If you’re interested in doing a bit of NAS, it’s worth mentioning that it can be a very expensive procedure, since it basically involves generating a large number of candidate architectures and then training all their weights and biases each time (using backpropagation) in order to find out how good they are. So, it’s only really advisable for smaller neural networks, or for people with a huge pile of GPUs5.
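To give a flavour of it, here's a toy sketch of evolutionary NAS over a made-up search space. The expensive part, actually training each candidate's weights and measuring how well it performs, is replaced here by an arbitrary stand-in score so that the sketch stays self-contained:

```python
import random

# Toy sketch of evolutionary neural architecture search. Each candidate is
# a small set of architectural choices; in reality, evaluating a candidate
# means training its weights (e.g. with backpropagation) and measuring
# validation performance, which is the expensive part.

SEARCH_SPACE = {
    "num_layers": [2, 3, 4, 5],
    "layer_width": [16, 32, 64, 128],
    "activation": ["relu", "tanh", "sigmoid"],
}

def random_architecture():
    return {key: random.choice(options) for key, options in SEARCH_SPACE.items()}

def mutate(arch):
    # change one architectural choice at random
    child = dict(arch)
    key = random.choice(list(SEARCH_SPACE))
    child[key] = random.choice(SEARCH_SPACE[key])
    return child

def evaluate(arch):
    # STAND-IN for "train this candidate and return its validation loss";
    # this arbitrary formula just gives the search something to optimise.
    return abs(arch["num_layers"] - 3) + abs(arch["layer_width"] - 64) / 64

population = [random_architecture() for _ in range(10)]
for generation in range(20):
    population.sort(key=evaluate)
    survivors = population[:5]
    population = survivors + [mutate(random.choice(survivors)) for _ in range(5)]

print(min(population, key=evaluate))
```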
Biologically more plausible approaches
And that’s pretty much it for training neural networks. However, I don’t want to leave without mentioning that although neural networks are modelled upon the structure of the brain, backpropagation takes pretty much no inspiration from biology. Yes, it works well enough, but I wouldn’t assume this is the final destination. In the longer term, it seems plausible that we could improve neural network training by learning more about how biological brains actually learn.
But currently we don’t really know how biological brains learn. There are theories, and these theories underlie some of the more biologically-plausible approaches to training neural networks. Leading amongst these is the idea of Hebbian learning, often summarised as “neurons that fire together, wire together”. This can be seen within a method called spike-timing dependent plasticity which is used to train spiking neural networks, which are themselves a more biologically-plausible model of brains. However, in its current form, Hebbian learning seems to be in no danger of challenging backpropagation as the go-to neural network trainer.
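For completeness, here's what a basic Hebbian weight update looks like in sketch form, with made-up activity values and learning rate:

```python
# Sketch of a basic Hebbian update: a connection between two neurons is
# strengthened in proportion to how strongly they are active together
# ("neurons that fire together, wire together"). eta is a learning rate.

def hebbian_update(weight, pre_activity, post_activity, eta=0.01):
    return weight + eta * pre_activity * post_activity

w = 0.1
print(hebbian_update(w, pre_activity=0.8, post_activity=0.9))
```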
Too long; didn’t read
Neural networks are typically trained using backpropagation and gradient descent. Backpropagation is a way of working out how you need to tweak the network’s parameters in order to reduce the loss — which is a measure of how close a neural network is to its target behaviour. Gradient descent is an iterative procedure for carrying out these tweaks, with the aim of minimising the loss. There’s no guarantee it will lead to an optimal neural network, but more recent innovations like mini-batching, momentum and weight decay improve its behaviour, and are central to modern optimisers like Adam. Evolutionary algorithms are also sometimes used to train neural networks, and are particularly effective at optimising the architecture of a neural network. In the future, we might expect further improvements by looking more closely at how brains learn — that is, once we work out how they learn!
1. Well, actually at the output of the loss function, which also gets involved in the whole differentiation thing.
2. Specifically, partial differentiation and the chain rule. For an accessible introduction to this process, see this video by Andrej Karpathy.
3. Typically with some kind of strategy that assigns more points to more recent moves. This whole area is known as credit assignment.
4. Although exactly how this is done varies significantly. If you want to know more, you could look into deep Q-learning. This is the approach that DeepMind famously used to play a bunch of Atari games.
5. So you won’t be surprised to learn that early work in this area was done by Google.