Lab41 Reading Group: Swapout: Learning an Ensemble of Deep Architectures
This article originally appeared on Lab41's blog: Gab41. It is reposted here with permission.
Next up for the reading group is a paper about a new stochastic training method written by Saurabh Singh, Derek Hoiem, and David Forsyth of the University of Illinois at Urbana–Champaign. Their new training method is like dropout, stochastic depth, and ResNets but with its own special twist. I recommend picking up the paper after going through this post; it is very readable and includes an excellent section on performing inference with a stochastically trained network that I will only touch on.
As you may recall, dropout works by randomly setting individual neuron outputs in a network to zero, essentially dropping those neurons from training and hence forcing the network to use a variety of signals instead of over-training on one. Stochastic depth (covered in a previous post) is similar, but instead of dropping neurons it bypasses whole layers! We can think of these operations a little more mathematically, but first I’ll have to define some notation.
I’ll use block to mean a set of layers in some specific configuration (for example, a convolution followed by a ReLU), and a unit to be one of the computational nodes within the block (basically a neuron). X will be the input from the previous block, and F(X) will be the output from a unit within the current block.
Using this notation, we can think about ResNets as consisting of blocks where every unit in the block always reports X + F(X). A standard feed-forward layer can be viewed in this framework as well, with each unit always reporting F(X). The paper includes a figure, which I’ve edited and included below, showing feed-forward and ResNets in this scheme:
Things become more interesting when we start thinking about stochastic training methods in this manner. Dropout can be thought of as randomly selecting the output for each unit from the following set of possible outcomes: {0, F(X)}. Likewise, stochastic depth can be thought of as randomly selecting between the outcomes {X, F(X)} for each block, so that every unit in the block returns X or F(X) together. Both of these training methods are shown in the figure below, which has again been modified from the paper:
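The difference between a per-unit choice and a per-block choice can be sketched in a few lines of Python. This is an illustration of the post's framing rather than a faithful implementation of either method: the function names and `keep_prob` are assumptions, the rescaling that practical dropout applies is omitted, and in the original stochastic depth formulation on residual blocks the surviving path is X + F(X) rather than F(X).

```python
import random


def dropout_block(f_x, keep_prob, rng):
    """Per unit: report F(X) with probability keep_prob, otherwise 0."""
    return [fi if rng.random() < keep_prob else 0.0 for fi in f_x]


def stochastic_depth_block(x, f_x, keep_prob, rng):
    """Per block: one coin flip decides whether all units report F(X) or all report X."""
    if rng.random() < keep_prob:
        return list(f_x)
    return list(x)


rng = random.Random(0)           # seeded only to make the demo repeatable
x = [1.0, 2.0, 3.0, 4.0]         # block input X
f_x = [10.0, 20.0, 30.0, 40.0]   # pretend values of F(X)

print(dropout_block(f_x, 0.5, rng))              # a mix of zeros and F(X) entries
print(stochastic_depth_block(x, f_x, 0.5, rng))  # all of X or all of F(X)
```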
So now that I’ve laid the groundwork, what does swapout add? Well, add isn’t really the right word: swapout combines! It randomly selects from the four possible outcomes mentioned above: feed-forward, ResNet, dropout, and stochastic depth. The authors do this by allowing each unit to randomly select from the following outcomes: {0, X, F(X), X + F(X)}. Therefore, swapout samples from every possible stochastic depth and ResNet architecture, both with and without dropout!
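One way to realize this four-way choice is to give each unit two independent coin flips, one gating X and one gating F(X). The sketch below assumes that construction; the names `swapout_block`, `p1`, and `p2` are illustrative and not taken from the post.

```python
import random


def swapout_block(x, f_x, p1, p2, rng):
    """Per unit: y = t1 * x + t2 * f, with t1 ~ Bernoulli(p1) and t2 ~ Bernoulli(p2)."""
    out = []
    for xi, fi in zip(x, f_x):
        t1 = 1.0 if rng.random() < p1 else 0.0  # keep the skip path X?
        t2 = 1.0 if rng.random() < p2 else 0.0  # keep the block output F(X)?
        out.append(t1 * xi + t2 * fi)           # one of 0, X, F(X), X + F(X)
    return out


rng = random.Random(0)
x = [1.0, 1.0, 1.0, 1.0]
f_x = [10.0, 10.0, 10.0, 10.0]
print(swapout_block(x, f_x, 0.5, 0.5, rng))  # each entry is 0.0, 1.0, 10.0, or 11.0
```

Setting `p1 = 0, p2 = 1` recovers the feed-forward case and `p1 = p2 = 1` recovers the ResNet case, which is one way to see why swapout contains the other schemes.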
In addition to swapout, the authors define a simpler version called skipforward. Skipforward only allows units to select from the outcomes {X, F(X)}, limiting the choice to stochastic depth and feed-forward behavior. Both of these architectures are shown in the figure below, which is again from the paper with modification:
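A skipforward block can be sketched with a single coin flip per unit: the unit either passes its input through or reports the block output, never both and never zero. The names `skipforward_block` and `p_skip` below are illustrative assumptions.

```python
import random


def skipforward_block(x, f_x, p_skip, rng):
    """Per unit: report X with probability p_skip, otherwise F(X)."""
    out = []
    for xi, fi in zip(x, f_x):
        skip = rng.random() < p_skip
        out.append(xi if skip else fi)
    return out


rng = random.Random(0)
x = [1.0, 2.0, 3.0, 4.0]
f_x = [10.0, 20.0, 30.0, 40.0]
print(skipforward_block(x, f_x, 0.5, rng))  # each entry comes from X or from F(X)
```

Unlike the per-block choice in stochastic depth, different units in the same block can make different choices here.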
One of the dilemmas when using stochastic training methods is: how do I use the network at inference time? During training, the network is constantly mutating as units pick different ways of behaving, but at inference time the network needs to be roughly static so that the same input will always yield the same prediction. We can make the network static in two ways:
- Deterministic inference: All random values are replaced by their expected values. For example, a unit that was dropped half the time has its output scaled by 0.5.
- Stochastic inference: Several versions of the network are randomly sampled at the end of training and their predictions are averaged. A unit that is dropped half the time would appear in roughly half of the sampled networks.
Although it seems like deterministic inference should be faster (because it does not require running multiple networks), it has some drawbacks. The first is that you cannot actually calculate the true expected value for a swapout network, only approximate it. The second is that batch normalization, one of the most powerful training methodologies, does not work with deterministic inference. The authors conclude (through testing) that stochastic inference works best.
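A toy example can show why the deterministic version is only an approximation. The sketch below stacks two swapout-style blocks around a hypothetical element-wise nonlinearity `block_f`; every name and number in it is an assumption made for illustration, not something from the paper. For the first input (0.6) with both keep probabilities at 0.5, plugging in the mask means gives about 0.2, whereas enumerating all mask combinations by hand gives a true expected output of 0.3, which the sampled average approaches.

```python
import random


def block_f(x):
    """Hypothetical nonlinear block: element-wise ReLU(2 * x - 1)."""
    return [max(0.0, 2.0 * xi - 1.0) for xi in x]


def swapout_layer(x, p1, p2, rng=None):
    """One swapout block; with rng=None the masks are replaced by their means."""
    out = []
    for xi, fi in zip(x, block_f(x)):
        if rng is None:
            t1, t2 = p1, p2  # deterministic inference
        else:
            t1 = 1.0 if rng.random() < p1 else 0.0
            t2 = 1.0 if rng.random() < p2 else 0.0
        out.append(t1 * xi + t2 * fi)
    return out


def forward(x, depth, p1, p2, rng=None):
    """Run the input through `depth` swapout blocks."""
    for _ in range(depth):
        x = swapout_layer(x, p1, p2, rng)
    return x


def stochastic_inference(x, depth, p1, p2, n_samples, rng):
    """Average the outputs of n_samples randomly sampled networks."""
    total = [0.0] * len(x)
    for _ in range(n_samples):
        y = forward(x, depth, p1, p2, rng)
        total = [t + yi for t, yi in zip(total, y)]
    return [t / n_samples for t in total]


x = [0.6, 1.0]
det = forward(x, depth=2, p1=0.5, p2=0.5)
sto = stochastic_inference(x, 2, 0.5, 0.5, n_samples=2000, rng=random.Random(0))
print("deterministic:", det)
print("stochastic   :", sto)
```

The gap appears because the expected value of a nonlinear function is generally not the function of the expected value, and the mismatch compounds as blocks are stacked.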
The authors test swapout and skipforward networks against networks trained with stochastic depth, dropout, and various ResNet architectures. They conclude:
- Swapout improves results as compared with ResNet
- Stochastic inference beats deterministic, even with only a few averaged results
- Increasing the width of a network greatly improves performance
One final note on the paper: the way they define the various operations as random selections from a set of possible outcomes is, for me, a very intuitive way to think about them. I would love to see other papers use a similar framework for describing their network modifications!