Study notes: Well-behaved training in Deep Neural Networks
Study notes for The lazy (NTK) and rich (µP) regimes: A gentle tutorial by Dhruva Karkada and A Spectral Condition for Feature Learning by Greg Yang, James B. Simon and Jeremy Bernstein.
Core Idea
For a deep neural network, it is crucial to ensure well-behaved training, where updates in response to the training data are relevant, efficient, and balanced across neurons and layers. One effective approach to achieving this is by carefully designing the initialization strategy for each layer and assigning layer-specific learning rates. In this framework, the training dynamics are controlled by a single degree of freedom, known as richness, which governs the extent of feature learning versus kernel-like (lazy) behavior.
Mathematical Notation
The scaling notations \(O\), \(\Omega\) and \(\Theta\) are used throughout this note. Given functions \(f\) and \(g\), we say \(f=O(g)\) if \(f\) scales no faster than \(g\), i.e., there exists a constant \(C>0\) such that \(f(n)\leq C\cdot g(n)\) for sufficiently large \(n\); we say \(f=\Omega(g)\) if \(f\) scales at least as fast as \(g\), i.e., there exists a constant \(c>0\) such that \(f(n)\geq c\cdot g(n)\) for sufficiently large \(n\); we say \(f=\Theta(g)\) if \(f=O(g)\) and \(f=\Omega(g)\).
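As a quick worked example of these definitions (my own illustration, not taken from either source), take \(f(n)=3n^2+5n\). For every \(n\geq 1\),
\[3n^2 \leq 3n^2+5n \leq 8n^2\]
so \(f=\Omega(n^2)\) with \(c=3\) and \(f=O(n^2)\) with \(C=8\), hence \(f=\Theta(n^2)\). The same \(f\) is also \(O(n^3)\), but it is not \(\Omega(n^3)\).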
Regarding vector and matrix operations, I use \(\mathbf{a} \otimes \mathbf{b}\) to denote the outer product of two vectors \(\mathbf{a}\) and \(\mathbf{b}\). \(\Delta\) is used to denote changes, especially changes across a single optimization step.
Regarding deep learning, I primarily focus on the multilayer perceptron (MLP). In this note, I use \(W_l\) to denote the weight matrix of the \(l\)th layer, \(g_l\) to denote the scalar gradient multiplier at layer \(l\), and \(h_l(\mathbf{x})\in \mathbb{R}^{n_l}\) to denote the features of input \(\mathbf{x}\) at layer \(l\), which has size \(n_l\). Formally, \(h_l(\mathbf{x})\) is recursively defined as
\[h_l(\mathbf{x})=\left\{ \begin{array}{ll} g_lW_lh_{l-1}(\mathbf{x}) & \text{if } l > 0 \\ \mathbf{x} & \text{if } l = 0 \end{array} \right.\]
In addition, \(\mathcal{L}\) is used to denote the scalar loss.
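The recursion above can be written as a short NumPy sketch. The layer sizes, the values of \(\sigma_l\) and \(g_l\), and the function name `forward` are illustrative assumptions of mine, not something prescribed by either source.

```python
import numpy as np

def forward(x, weights, multipliers):
    """Return [h_0, h_1, ..., h_k] for the linear MLP h_l = g_l * W_l @ h_{l-1}."""
    hs = [x]  # h_0 is the input itself
    for W, g in zip(weights, multipliers):
        hs.append(g * (W @ hs[-1]))
    return hs

rng = np.random.default_rng(0)

# Illustrative sizes n_0, n_1, n_2, n_3 (assumed values, not from the sources).
sizes = [4, 256, 256, 2]
# Illustrative hyper-parameters: sigma_l = 1 and g_l = 1 / sqrt(n_{l-1}).
sigmas = [1.0, 1.0, 1.0]
multipliers = [1.0 / np.sqrt(n_in) for n_in in sizes[:-1]]

# W_l has shape (n_l, n_{l-1}) with i.i.d. N(0, sigma_l^2) entries.
weights = [
    sigma * rng.standard_normal((n_out, n_in))
    for sigma, n_in, n_out in zip(sigmas, sizes[:-1], sizes[1:])
]

x = rng.standard_normal(sizes[0])
hs = forward(x, weights, multipliers)
print([h.shape for h in hs])
```

Returning every \(h_l\) rather than only the output makes it easy to inspect per-layer norms later.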
Desideratum
What criteria make training well-behaved? Let's first consider how the signal from a single input \(\mathbf{x}\) changes after a gradient descent step. At layer \(l>0\), after the weights are updated by \(\Delta W_l\), the new representation becomes
\[h_l+\Delta h_l = g_l(W_l+\Delta W_l)(h_{l-1}+\Delta h_{l-1})\]
Expanding the product and subtracting \(h_l = g_lW_lh_{l-1}\), we obtain the representation update as the sum of three terms:
\[\Delta h_l = \underbrace{g_l\Delta W_l h_{l-1}}_{\text{layer}} + \underbrace{g_lW_l\Delta h_{l-1}}_{\text{passthrough}} + \underbrace{g_l\Delta W_l \Delta h_{l-1}}_{\text{interaction}}\]
The first term, layer, is induced by the update of the current layer's weights, \(\Delta W_l\); the second term, passthrough, is induced by the update of the previous representation, \(\Delta h_{l-1}\), passing through the old weights \(W_l\); the third term, interaction, captures the interaction between the weight update and the previous representation update.
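The three-term decomposition is an exact algebraic identity, which can be checked numerically. The sketch below uses random stand-ins for \(\Delta W_l\) and \(\Delta h_{l-1}\) (they are not real gradient updates), and all sizes and values are assumptions chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes and multiplier (assumed values).
n_prev, n_l = 8, 16
g = 0.5

W = rng.standard_normal((n_l, n_prev))            # old weights W_l
dW = 0.01 * rng.standard_normal((n_l, n_prev))    # stand-in for Delta W_l
h_prev = rng.standard_normal(n_prev)              # old h_{l-1}
dh_prev = 0.01 * rng.standard_normal(n_prev)      # stand-in for Delta h_{l-1}

# Representation before and after the update.
h_old = g * (W @ h_prev)
h_new = g * ((W + dW) @ (h_prev + dh_prev))
delta_h = h_new - h_old

# The three terms of the decomposition.
layer = g * (dW @ h_prev)
passthrough = g * (W @ dh_prev)
interaction = g * (dW @ dh_prev)

# The identity holds up to floating-point error.
assert np.allclose(delta_h, layer + passthrough + interaction)

for name, term in [("layer", layer), ("passthrough", passthrough), ("interaction", interaction)]:
    print(name, np.linalg.norm(term))
```

Because the interaction term is a product of two small perturbations, its norm here comes out much smaller than the other two.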
For training to be well-behaved, we expect each representation update \(\Delta h_l\) to be controlled. Specifically, well-behaved training must satisfy the following three criteria:
- Useful Update Criterion (UUC): each representation update should contribute to optimizing the loss. The scale of that contribution should be the same for each representation update. Mathematically speaking, for \(l\geq 1\),
- Maximality Criterion (MAX): A layer’s weight update \(\Delta W_l\) should contribute non-negligibly to the following representation update \(\Delta h_l\). Any layer’s weight should not be frozen in training. Mathematically speaking, for \(l\geq 1\),
- Nontriviality Criterion (NTC): the updates to the outputs should not scale with the width of the hidden layers. This ensures that the loss decreases at a width-independent rate. Mathematically speaking, for the last layer \(k\),
Assumptions
- Model Simplicity: we temporarily ignore the activation functions between layers. Moreover, we focus on a 3-layer MLP, \(h_3(\mathbf{x}):= g_3W_3g_2W_2g_1W_1\mathbf{x}=g_3W_3h_2(\mathbf{x})\).
- Wide Hidden Layers: we assume both hidden layers have the same width \(n\), which is far larger than either the input or the output size: \(n_1=n_2=n\gg n_0, n_3\).
- Gaussian Initialization: we assume each layer's weight matrix \(W_l\) is initialized from a Gaussian with scale \(\sigma_l\), i.e., every entry of \(W_l\) is drawn independently from \(\mathcal{N}(0, \sigma_l^2)\).
- Bounded Input: we assume each entry of the input \(\mathbf{x}\) is \(\Theta(1)\), independent of \(n_0\), so that \(\Vert h_0(\mathbf{x})\Vert^2 =\Vert \mathbf{x}\Vert^2=\Theta(n_0)\).
- Bounded Representations: we assume the entries of the hidden representations at initialization, \(h_1^{i}\) and \(h_2^{j}\), are \(\Theta(1)\). Given the bounded input assumption, this can be enforced by adjusting the values of \(g_l\) and \(\sigma_l\).
Richness Scale Derivation
Although there are several hyper-parameters, including \(g_l\) and \(\sigma_l\), it turns out that once the criteria and assumptions above are imposed, they are governed by a single degree of freedom, called richness. I will try to present the derivation in an intuitive way. For convenience, I will use \(a \sim b\) to denote \(a =\Theta(b)\).
Claim 1: \(g^2_l \sigma^2_l n_{l-1}\sim 1\) for \(l=1,2\), and \(g^2_3 \sigma^2_3 n_{2}=O(1)\).
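- Intuition of Proof: Consider the recursive formula for the representations for \(1\leq l\leq 3\): \(h_l(\mathbf{x})=g_lW_lh_{l-1}(\mathbf{x})\). Taking the squared norm of both sides, we get
\[\Vert h_l(\mathbf{x})\Vert^2 = g_l^2 \Vert W_lh_{l-1}(\mathbf{x})\Vert^2\]
By the assumptions of bounded input and bounded representations, we have \(\Vert h_l(\mathbf{x})\Vert^2\sim n_l\) for \(l=0,1,2\). Moreover, at initialization the entries of \(W_l\) are i.i.d. \(\mathcal{N}(0,\sigma_l^2)\) and independent of \(h_{l-1}\), so each of the \(n_l\) entries of \(W_lh_{l-1}\) has variance \(\sigma_l^2\Vert h_{l-1}\Vert^2\), which gives \(\Vert W_lh_{l-1}\Vert^2\sim n_l\sigma_l^2\Vert h_{l-1}\Vert^2\). Thus, for \(l=1,2\),
\[n_l \sim g_l^2 \Vert W_lh_{l-1}(\mathbf{x})\Vert^2 \sim g_l^2\sigma_l^2\, n_l\, n_{l-1}\]
Dividing both sides by \(n_l\) yields \(g_l^2\sigma_l^2n_{l-1}\sim 1\). For the output layer, if we only require that the outputs stay bounded, \(\Vert h_3\Vert^2=O(n_3)\), the same computation gives \(g_3^2\sigma_3^2n_2=O(1)\). This justifies Claim 1.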
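Claim 1 can be sanity-checked numerically: if \(g_l\) is chosen so that \(g_l^2\sigma_l^2n_{l-1}=1\), then \(\Vert h_l\Vert^2/n_l\) should be close to 1 regardless of the particular \(\sigma_l\). The sketch below is a minimal check at initialization; the sizes and the \(\sigma_l\) values are arbitrary assumptions of mine, and the input is rescaled so that \(\Vert\mathbf{x}\Vert^2=n_0\) exactly.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes: small input, wide hidden layers (assumed values).
n0, n = 16, 1024

# Arbitrary initialization scales; Claim 1 then fixes the multipliers.
sigma1, sigma2 = 0.5, 2.0
g1 = 1.0 / (sigma1 * np.sqrt(n0))   # g_1^2 * sigma_1^2 * n_0 = 1
g2 = 1.0 / (sigma2 * np.sqrt(n))    # g_2^2 * sigma_2^2 * n_1 = 1

# Input with Theta(1) entries, rescaled so that ||x||^2 = n_0.
x = rng.standard_normal(n0)
x *= np.sqrt(n0) / np.linalg.norm(x)

# Gaussian initialization with scales sigma_1 and sigma_2.
W1 = sigma1 * rng.standard_normal((n, n0))
W2 = sigma2 * rng.standard_normal((n, n))

h1 = g1 * (W1 @ x)
h2 = g2 * (W2 @ h1)

# Both ratios should be close to 1, i.e. hidden entries are Theta(1).
ratio1 = (h1 @ h1) / n
ratio2 = (h2 @ h2) / n
print(ratio1, ratio2)
```

Changing `sigma1` or `sigma2` while keeping the same formulas for `g1` and `g2` should leave both ratios near 1, which is the content of the claim for the hidden layers.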
Claim 2: The representation changes in both hidden layers are of the same scale, i.e., \(\Vert\Delta h_1\Vert\sim \Vert\Delta h_2\Vert\).
We call this common scale the richness and denote it by \(\Vert\Delta h\Vert\), i.e., \(\Vert\Delta h\Vert\sim \Vert\Delta h_1\Vert\sim \Vert\Delta h_2\Vert\).
Claim 3: All \(\sigma_l\) must be chosen with respect to the richness \(\Vert\Delta h\Vert\): each \(\sigma_l\) scales as the reciprocal of \(\Vert\Delta h\Vert\), i.e., \(\sigma_l \sim \frac{1}{\Vert\Delta h\Vert}\).
With those three claims, we can summarize the scales of all relevant quantities in the table below:
Layer | \(g_l\) | \(\sigma_l\) | \(\Vert h_l\Vert\) | \(\Vert g_l\Delta W_l h_{l-1}\Vert\) | \(\Vert g_l W_l \Delta h_{l-1}\Vert\) |
---|---|---|---|---|---|
\(l=1\) | \(\frac{\Vert \Delta h\Vert}{\sqrt{n_0}}\) | \(\frac{1}{\Vert \Delta h\Vert}\) | \(\sqrt{n}\) | \(\Vert \Delta h\Vert\) | 0 |
\(l=2\) | \(\frac{\Vert \Delta h\Vert}{\sqrt{n}}\) | \(\frac{1}{\Vert \Delta h\Vert}\) | \(\sqrt{n}\) | \(\Vert \Delta h\Vert\) | \(\Vert \Delta h\Vert\) |
\(l=3\) | \(\frac{1}{\sqrt{n}}\) | \(\frac{1}{\Vert \Delta h\Vert}\) | \(\frac{\sqrt{n_3}}{\Vert \Delta h \Vert}\) | 1 | 1 |
Significance of Richness Scale
U