Study notes: Well-behaved training in Deep Neural Networks

Study notes for "The lazy (NTK) and rich (µP) regimes: a gentle tutorial" by Dhruva Karkada, and "A Spectral Condition for Feature Learning" by Greg Yang, James B. Simon, and Jeremy Bernstein.

Core Idea

For a deep neural network, it is crucial to ensure well-behaved training, in which updates in response to the training data are relevant, efficient, and balanced across neurons and layers. One effective way to achieve this is to carefully design the initialization strategy for each layer and assign layer-specific learning rates. In this framework, the training dynamics are controlled by a single degree of freedom, known as richness, which governs the extent of feature learning versus kernel-like (lazy) behavior.

Mathematical Notation

The scaling notations \(O\), \(\Omega\), and \(\Theta\) will be used throughout this note. Given functions \(f\) and \(g\), we say \(f=O(g)\) if \(f\) grows no faster than \(g\), i.e., there exists a constant \(C>0\) such that \(f(n)\leq C\cdot g(n)\) for sufficiently large \(n\); we say \(f=\Omega(g)\) if \(f\) grows at least as fast as \(g\), i.e., there exists a constant \(c>0\) such that \(f(n)\geq c\cdot g(n)\) for sufficiently large \(n\); and we say \(f=\Theta(g)\) if both \(f=O(g)\) and \(f=\Omega(g)\) hold.
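As a quick sanity check of the notation, here is a small numerical illustration (the example function is mine, not from the notes): \(f(n)=3n^2+5n\) is \(\Theta(n^2)\), since \(f(n)/n^2\) approaches a positive constant for large \(n\).

```python
# f(n) = 3n^2 + 5n is Theta(n^2): the ratio f(n)/n^2 approaches the
# positive constant 3, so f = O(n^2) and f = Omega(n^2).
def f(n):
    return 3 * n**2 + 5 * n

ratios = [f(n) / n**2 for n in (10**2, 10**4, 10**6)]
print(ratios)  # decreasing toward 3
```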

Regarding vector and matrix operations, I use \(\mathbf{a} \otimes \mathbf{b}\) to denote the outer product of two vectors \(\mathbf{a}\) and \(\mathbf{b}\). \(\Delta\) denotes a change in a quantity, in particular the change across a single optimization step.

Regarding deep learning, I focus primarily on the multilayer perceptron (MLP). In this note, I use \(W_l\) to denote the weight matrix of the \(l\)th layer, initialized with i.i.d. entries of standard deviation \(\sigma_l\); \(g_l\) to denote the scalar gradient multiplier at layer \(l\); and \(h_l(\mathbf{x})\in \mathbb{R}^{n_l}\) to denote the features of input \(\mathbf{x}\) at layer \(l\), with width \(n_l\). Formally, \(h_l(\mathbf{x})\) is defined recursively as

\[h_l(\mathbf{x})=\left\{ \begin{array}{ll} g_lW_lh_{l-1}(\mathbf{x}) & \text{if } l > 0 \\ \mathbf{x} & \text{if } l = 0 \end{array} \right.\]
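The recursion above can be sketched in code. This is a minimal illustration of the parametrized forward pass \(h_l = g_l W_l h_{l-1}\), written linearly exactly as in the definition; the widths, multipliers, and init scales below are illustrative choices of mine, not values prescribed by the notes.

```python
import numpy as np

rng = np.random.default_rng(0)
sizes = [3, 256, 256, 1]                       # n_0, n_1, n_2, n_3
g = [1.0, 1.0, 1.0]                            # gradient multipliers g_1..g_3
sigma = [1 / np.sqrt(n) for n in sizes[:-1]]   # entrywise init std per layer
Ws = [rng.normal(0.0, sigma[l], size=(sizes[l + 1], sizes[l]))
      for l in range(len(g))]

def forward(x, Ws, g):
    """h_0 = x; h_l = g_l * W_l @ h_{l-1} for l = 1, 2, 3."""
    h = x
    for W, g_l in zip(Ws, g):
        h = g_l * (W @ h)
    return h

out = forward(rng.normal(size=sizes[0]), Ws, g)
```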

In addition, \(\mathcal{L}\) is used to denote the scalar loss.

Desideratum

What are the criteria for calling training well-behaved? Let’s first consider how the signal from a single input \(\mathbf{x}\) changes after a gradient descent step. At layer \(l\) (for \(l>0\)), after updating the weights by \(\Delta W_l\), the new representation becomes

\[h_l+\Delta h_l = g_l(W_l+\Delta W_l)(h_{l-1}+\Delta h_{l-1})\]

By substituting the expression for \(h_l\), we obtain an expression for the representation update as the sum of three terms:

\[\Delta h_l = \underbrace{g_l\Delta W_l h_{l-1}}_{\text{layer}} + \underbrace{g_lW_l\Delta h_{l-1}}_{\text{passthrough}} + \underbrace{g_l\Delta W_l \Delta h_{l-1}}_{\text{interaction}}\]

The first term, layer, is induced by the update \(\Delta W_l\) of the current layer’s weights; the second term, passthrough, is induced by the update \(\Delta h_{l-1}\) of the previous representation passing through the old weights \(W_l\); the third term, interaction, captures the interplay between the weight update and the previous representation update.
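The decomposition is exact, not just an approximation, which a quick numerical check confirms (the random matrices here are stand-ins for the actual gradient updates):

```python
import numpy as np

rng = np.random.default_rng(1)
n_prev = n_cur = 64
g_l = 1.0
W = rng.normal(0.0, 1 / np.sqrt(n_prev), size=(n_cur, n_prev))
dW = 1e-3 * rng.normal(size=(n_cur, n_prev))   # stand-in for Delta W_l
h_prev = rng.normal(size=n_prev)
dh_prev = 1e-3 * rng.normal(size=n_prev)       # stand-in for Delta h_{l-1}

# Delta h_l from the exact forward passes...
dh = g_l * (W + dW) @ (h_prev + dh_prev) - g_l * W @ h_prev
# ...equals the sum of the three terms exactly.
layer = g_l * dW @ h_prev
passthrough = g_l * W @ dh_prev
interaction = g_l * dW @ dh_prev
assert np.allclose(dh, layer + passthrough + interaction)
```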

For training to be well-behaved, we expect each representation update \(\Delta h_l\) to be controlled. Specifically, well-behaved training must satisfy the following three criteria:

\[\left| \frac{\partial \mathcal{L}}{\partial h_l}^{T} \Delta h_l \right| = \Theta(1)\]

\[\left\Vert g_l \Delta W_l h_{l-1} \right\Vert = \Theta(\Vert\Delta h_l\Vert)\]

\[\left\Vert \Delta h_k \right\Vert = \Theta(1)\]

Assumptions

We assume the hidden widths are large and of the same order \(n\), while the input and output dimensions stay constant:

\[n = \Theta(n_1) = \Theta(n_2) \gg 1 = \Theta(n_3) = \Theta(n_0)\]

We also assume bounded representations in the natural scale, i.e., \(\Vert h_l(\mathbf{x})\Vert^2 = \Theta(n_l)\) for the hidden layers.

Richness Scale Derivation

Although there are several hyperparameters, namely the \(g_l\) and \(\sigma_l\), it turns out that, once the criteria and assumptions above are imposed, they are governed by a single degree of freedom called richness. I will try to present the derivation intuitively. For convenience, I write \(a \sim b\) for \(a = \Theta(b)\).

Claim 1: \(g^2_l \sigma^2_l n_{l-1}\sim 1\) for \(l=1,2\), and \(g^2_3 \sigma^2_3 n_{2}=O(1)\).

\[\Vert h_l(\mathbf{x})\Vert^2 = g_l^2 \Vert W_lh_{l-1}(\mathbf{x})\Vert^2\]

By the assumption of bounded representations, we have \(\Vert h_l(\mathbf{x})\Vert^2\sim n_l\), and thus

\[n_l \sim g_l^2 \Vert W_lh_{l-1}(\mathbf{x})\Vert^2\]

Since \(W_l\) has i.i.d. entries of standard deviation \(\sigma_l\), we have \(\Vert W_l h_{l-1}(\mathbf{x})\Vert^2 \sim n_l\,\sigma_l^2\,\Vert h_{l-1}(\mathbf{x})\Vert^2 \sim n_l\,\sigma_l^2\,n_{l-1}\). Substituting this in and cancelling \(n_l\) gives \(g_l^2\sigma_l^2 n_{l-1}\sim 1\) for the hidden layers, which justifies Claim 1. For the output layer, \(\Vert h_3\Vert^2\) need only be bounded rather than of order \(n_3\), so the condition relaxes to \(g_3^2 \sigma_3^2 n_2 = O(1)\).
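The key step, \(\Vert W h\Vert^2 \sim n_{\text{out}}\,\sigma^2\,\Vert h\Vert^2\) for a matrix with i.i.d. \(\mathcal{N}(0,\sigma^2)\) entries, can be checked numerically (the width and \(\sigma\) below are illustrative):

```python
import numpy as np

rng = np.random.default_rng(2)
n_in = n_out = 512
sigma = 0.05                                   # illustrative init std
h = rng.normal(size=n_in)                      # ||h||^2 ~ n_in
W = rng.normal(0.0, sigma, size=(n_out, n_in))

# E||W h||^2 = n_out * sigma^2 * ||h||^2, so this ratio concentrates
# near 1 as the width grows.
ratio = np.linalg.norm(W @ h) ** 2 / (n_out * sigma**2 * np.linalg.norm(h) ** 2)
```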

Claim 2: Changes of representation in both hidden layers are of the same scale, i.e., \(\Vert\Delta h_1\Vert\sim \Vert\Delta h_2\Vert\)

We call this common scale the richness and denote it \(\Vert\Delta h\Vert\), i.e., \(\Vert\Delta h\Vert\sim \Vert\Delta h_1\Vert\sim \Vert\Delta h_2\Vert\).

Claim 3: Each \(\sigma_l\) must be chosen with respect to the richness \(\Vert\Delta h\Vert\); its scale is the reciprocal of the richness, i.e., \(\sigma_l \sim \frac{1}{\Vert\Delta h\Vert}\).

With those three claims, we can summarize the scales of all relevant quantities in the table below:

| | \(g_l\) | \(\sigma_l\) | \(\Vert h_l\Vert\) | \(\Vert g_l\Delta W_l h_{l-1}\Vert\) | \(\Vert g_l W_l \Delta h_{l-1}\Vert\) |
|---|---|---|---|---|---|
| \(l=1\) | \(\frac{\Vert \Delta h\Vert}{n_0}\) | \(\frac{1}{\Vert \Delta h\Vert}\) | \(\sqrt{n}\) | \(\Vert \Delta h\Vert\) | 0 |
| \(l=2\) | \(\frac{\Vert \Delta h\Vert}{n}\) | \(\frac{1}{\Vert \Delta h\Vert}\) | \(\sqrt{n}\) | \(\Vert \Delta h\Vert\) | \(\Vert \Delta h\Vert\) |
| \(l=3\) | \(\frac{1}{n}\) | \(\frac{1}{\Vert \Delta h\Vert}\) | \(\frac{\sqrt{n_3}}{\Vert \Delta h \Vert}\) | 1 | 1 |
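The table can be instantiated numerically. Below is a hypothetical helper (mine, not from the notes) that plugs a width \(n\) and richness \(r = \Vert\Delta h\Vert\) into the table's scalings, up to \(\Theta(1)\) constants: \(g_1 = r/n_0\), \(g_2 = r/n\), \(g_3 = 1/n\), and \(\sigma_l = 1/r\) for every layer.

```python
import numpy as np

def table_scales(n, r, n0=1):
    """Scalings from the table for width n and richness r = ||Delta h||."""
    return {
        1: {"g": r / n0, "sigma": 1 / r},
        2: {"g": r / n, "sigma": 1 / r},
        3: {"g": 1 / n, "sigma": 1 / r},
    }

lazy = table_scales(n=1024, r=1.0)            # lazy regime: r = Theta(1)
rich = table_scales(n=1024, r=np.sqrt(1024))  # rich regime: r = Theta(sqrt(n))
```

Comparing the two regimes makes the trade-off concrete: as the richness grows, every \(\sigma_l\) shrinks while \(g_1\) and \(g_2\) grow, shifting the scale of the weights from their initialization toward their updates.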

Significance of Richness Scale

U