KL-Divergence Part 2: Gradient Descent
The previous tutorial was a non-technical introduction to KL-divergence that let you perform your own variational inference, matching distributions with KL-divergence as your guide. This post continues from there, showing how the process can be automated using gradient descent.
Gradient descent is so prevalent in computer science that even its Wikipedia page contains a nice intuitive introduction (a rare honour for anything even slightly technical).
The basic idea is that you improve your current best guess by taking small steps in whatever direction locally looks best. In the example shown in the cartoon, if you want to go down, it makes sense to keep heading in that general direction. Unfortunately, that strategy doesn't always work out as you might hope: sometimes it leads you only to the closest low point (a local minimum) and not the overall lowest point (the global minimum).
The gradient of a function at a point represents the slope of the function at that point. By following the slope downhill (stepping in the direction opposite to the gradient), you can make sure you are heading in the (locally) correct direction. This is why it is called gradient descent. Again, I'm keeping this post as non-technical as possible. If you're interested in more details, this post is an excellent next step.
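To make the idea concrete, here is a minimal Python sketch of plain gradient descent on a made-up one-dimensional landscape, \(f(x) = x^4 - 3x^2 + x\), which has a deep valley near \(x \approx -1.30\) and a shallower one near \(x \approx 1.13\). The function, starting points and step size are hypothetical choices for illustration, not anything from the demo. Each step moves against the slope, and where you end up depends on where you start.

```python
def f(x):
    """A hypothetical 1-D landscape with two valleys of different depths."""
    return x**4 - 3 * x**2 + x

def df(x):
    """The slope (derivative) of f at x."""
    return 4 * x**3 - 6 * x + 1

def gradient_descent(grad, x0, step_size=0.01, n_steps=1000):
    """Repeatedly take a small step downhill, i.e. against the slope."""
    x = x0
    for _ in range(n_steps):
        x = x - step_size * grad(x)
    return x

x_right = gradient_descent(df, x0=2.0)   # starts on the right-hand slope
x_left = gradient_descent(df, x0=-2.0)   # starts on the left-hand slope
print(f"start  2.0 -> x = {x_right:.3f}, f(x) = {f(x_right):.3f}")
print(f"start -2.0 -> x = {x_left:.3f}, f(x) = {f(x_left):.3f}")
```

Starting from 2.0, the walker settles in the shallower valley near \(x \approx 1.13\) and never finds the deeper one near \(x \approx -1.30\), which is exactly the "closest low point" problem described above.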
The demo below shows how this works on our simple setup from before. We have a known distribution \(q\), a hidden distribution \(p\), and we want to match \(q\) to \(p\). We know that \(KL(q,p)\) provides a value measuring how different \(q\) is from \(p\), with \(KL(q,p) = 0\) being an exact match. So, if we can compute the KL-divergence, we can use gradient descent to reduce it step by step until (hopefully) we reach \(KL(q,p) = 0\).
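Here is a minimal Python sketch of that automation. It assumes (the post doesn't say) that \(q\) and \(p\) are both one-dimensional Gaussians, \(q = N(\mu_q, \sigma_q^2)\) and \(p = N(\mu_p, \sigma_p^2)\), because then \(KL(q,p) = \log\frac{\sigma_p}{\sigma_q} + \frac{\sigma_q^2 + (\mu_q - \mu_p)^2}{2\sigma_p^2} - \frac{1}{2}\) has a closed form with simple gradients. The function names and target values are hypothetical. Note that the code needs \(p\) to compute the gradient: \(p\) is hidden from you, not from the computer.

```python
import math

def kl_gauss(mu_q, s_q, mu_p, s_p):
    """KL(q || p) for 1-D Gaussians q = N(mu_q, s_q^2) and p = N(mu_p, s_p^2)."""
    return math.log(s_p / s_q) + (s_q**2 + (mu_q - mu_p) ** 2) / (2 * s_p**2) - 0.5

def kl_grads(mu_q, log_s_q, mu_p, s_p):
    """Gradient of KL(q || p) with respect to q's mean and log standard deviation."""
    s_q = math.exp(log_s_q)
    d_mu = (mu_q - mu_p) / s_p**2
    d_log_s = s_q**2 / s_p**2 - 1.0
    return d_mu, d_log_s

def fit_q(mu_p, s_p, mu_q=0.0, s_q=1.0, step_size=0.1, n_steps=500):
    """Hypothetical helper: move q towards p by gradient descent on KL(q || p)."""
    log_s_q = math.log(s_q)  # optimise log(sigma) so that sigma stays positive
    for _ in range(n_steps):
        d_mu, d_log_s = kl_grads(mu_q, log_s_q, mu_p, s_p)
        mu_q -= step_size * d_mu
        log_s_q -= step_size * d_log_s
    return mu_q, math.exp(log_s_q)

# Hypothetical hidden target p = N(3, 0.5^2); q starts at N(0, 1).
mu_q, s_q = fit_q(mu_p=3.0, s_p=0.5)
print(f"q = N({mu_q:.3f}, {s_q:.3f}^2), KL(q,p) = {kl_gauss(mu_q, s_q, 3.0, 0.5):.2e}")
```

Optimising \(\log \sigma_q\) rather than \(\sigma_q\) directly is a design choice: it stops a large step from ever producing a negative standard deviation.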
As before, we can use the KL measure in two different directions, \(KL(q,p)\) and \(KL(p,q)\), yielding slightly different values. Even with these slight differences, the behaviour of gradient descent is radically different (as can be seen by using the buttons below).
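To check that the two directions really do give different numbers, here is a small sketch that again assumes one-dimensional Gaussians (an assumption, not necessarily the demo's distributions). The example distributions \(q = N(0, 1)\) and \(p = N(1, 2^2)\) are hypothetical.

```python
import math

def kl_gauss(mu_a, s_a, mu_b, s_b):
    """KL(a || b) for 1-D Gaussians a = N(mu_a, s_a^2) and b = N(mu_b, s_b^2)."""
    return math.log(s_b / s_a) + (s_a**2 + (mu_a - mu_b) ** 2) / (2 * s_b**2) - 0.5

# Hypothetical example distributions: q = N(0, 1), p = N(1, 2^2).
forward = kl_gauss(0.0, 1.0, 1.0, 2.0)  # KL(q, p)
reverse = kl_gauss(1.0, 2.0, 0.0, 1.0)  # KL(p, q)
print(f"KL(q,p) = {forward:.4f}, KL(p,q) = {reverse:.4f}")
```

This prints roughly 0.4431 for \(KL(q,p)\) and 1.3069 for \(KL(p,q)\): the same pair of distributions, but a different value (and a different landscape for gradient descent to walk on) depending on the direction.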
Gradient descent is fairly effective in this simple example because the KL function only has one minimum. The algorithm (with a sensible step size) will eventually get there, though it may take an infuriatingly long time to converge. The step size is incredibly important to gradient descent: it says how far to go at each step. If it's too small, it'll take too long to reach the destination. Even worse, you will get stuck in every little minimum along the way. For example, an ant performing gradient descent on the cartoon might end up at the starred location: with the goal of always heading downwards, it wouldn't be able to make it past the first divot in the ground. On the other hand, if the step size is too large you overshoot: a giant may step over the entire mountain and never be able to land at the destination.
You can play around with the step size in the example above, but be warned: the algorithm is very sensitive to it. Bonus: set the step size to a negative value (-0.02 works well) for a demonstration of 'machine running away'. A negative step size turns descent into ascent, so every step makes the KL-divergence larger.
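To see the step-size sensitivity in isolation, here is a sketch that keeps only \(q\)'s mean free. It assumes (a simplification, not the demo's actual setup) that \(q\) and \(p\) are unit-variance Gaussians, for which \(KL(q,p) = \frac{(\mu_q - \mu_p)^2}{2}\) and the slope is simply \(\mu_q - \mu_p\). Each step multiplies the remaining error by \(1 - \text{step size}\), so for this toy function any step size above 2 overshoots further every time (the giant), while a negative step size pushes \(\mu_q\) away from the target. The target \(\mu_p = 3\) and the helper names are hypothetical.

```python
def kl_mean_only(mu_q, mu_p=3.0):
    """KL(q || p) for unit-variance Gaussians: (mu_q - mu_p)^2 / 2."""
    return 0.5 * (mu_q - mu_p) ** 2

def grad_kl_mean_only(mu_q, mu_p=3.0):
    """Slope of kl_mean_only with respect to mu_q."""
    return mu_q - mu_p

def run(step_size, mu_q=0.0, n_steps=100):
    """Plain gradient descent on mu_q; returns the final mean and KL."""
    for _ in range(n_steps):
        mu_q -= step_size * grad_kl_mean_only(mu_q)
    return mu_q, kl_mean_only(mu_q)

# Too small, sensible, too large (the giant), and negative (running away).
for step_size in (0.001, 0.5, 2.5, -0.02):
    mu_q, kl = run(step_size)
    print(f"step size {step_size:>6}: mu_q = {mu_q:.4g}, KL = {kl:.4g}")
```

Starting from a KL of 4.5, after 100 steps the 0.001 run has barely moved, 0.5 lands essentially on \(\mu_p = 3\), 2.5 blows up, and -0.02 ends further from the target than it started.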
A simple way to improve gradient descent is to add momentum. This works just like it sounds: instead of coming to a halt after each step, some of the previous momentum is carried over each time, just like a cheese rolling down a hill, a coin spinning round a circular charity cyclone donation box, or a human in one of these.
You can play around with the momentum in the example above. The starting value of 0 means that no momentum is preserved. A value of 1 would mean that all of it is. Sensible values to try are 0.9 and 0.99.
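Here is a minimal sketch of momentum on the same toy problem as before (unit-variance Gaussians with only \(q\)'s mean free, target \(\mu_p = 3\); all hypothetical). It uses the common "heavy ball" update, which is an assumption: the demo may implement momentum slightly differently. A velocity keeps a fraction `momentum` of the previous step and adds the new downhill push, so with `momentum=0.0` it reduces to plain gradient descent.

```python
def grad_kl_mean_only(mu_q, mu_p=3.0):
    """Slope of KL(q || p) = (mu_q - mu_p)^2 / 2 for unit-variance Gaussians."""
    return mu_q - mu_p

def descend(momentum, step_size=0.01, mu_q=0.0, n_steps=200):
    """Gradient descent with heavy-ball momentum; momentum=0.0 is plain descent."""
    velocity = 0.0
    for _ in range(n_steps):
        # Keep a fraction of the previous step, then add the new downhill push.
        velocity = momentum * velocity - step_size * grad_kl_mean_only(mu_q)
        mu_q += velocity
    return mu_q

for momentum in (0.0, 0.9, 0.99):
    print(f"momentum {momentum}: mu_q = {descend(momentum):.5f} (target 3.0)")
```

With a deliberately small step size of 0.01 and 200 steps, plain gradient descent is still about 0.4 short of the target, while momentum 0.9 ends within 0.001 of it, overshooting and oscillating a little on the way in like the rolling cheese.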
In the next installment in this series I will explore how gradient descent performs in more complex domains, before finally turning to a gentle introduction to some of the maths going on in the background. In the meantime, as it's good enough to mention twice, this is the place to go for more on the topic.