In-context learning capabilities of Transformers

"Uncovering mesa-optimization algorithms in Transformers"

Paper link: https://arxiv.org/abs/2309.05858

Why do Transformers perform so well, and where does the in-context learning ability they give many large language models come from? The Transformer has become the dominant architecture in deep learning, but the theoretical basis for its strong performance remains insufficiently understood.
Recently, researchers at Google AI, ETH Zurich, and Google DeepMind set out to answer this question. In their new paper, they reverse-engineer trained Transformers and find optimization algorithms running inside them.

The authors show that minimizing a generic autoregressive loss gives rise to an auxiliary gradient-based optimization algorithm running in the forward pass of the Transformer. This phenomenon has recently been termed "mesa-optimization." Furthermore, the researchers find that the resulting mesa-optimization algorithms exhibit in-context few-shot learning capabilities, independent of model size. The new results therefore complement earlier accounts of how few-shot learning emerges in large language models.

The researchers hypothesize that the success of Transformers stems from an architectural bias toward implementing a mesa-optimization algorithm in the forward pass, which (i) defines an internal learning objective and (ii) optimizes it.

Figure 1: Illustration of the new hypothesis: optimizing the weights θ of an autoregressive Transformer f_θ gives rise to a mesa-optimization algorithm implemented in the forward pass of the model. As the input sequence s_1, ..., s_t is processed up to time step t, the Transformer (i) creates an internal training set of input-target pairs, (ii) defines an internal objective function over this dataset, which measures the performance of an internal model with weights W, and (iii) optimizes this objective and uses the learned model to predict future elements.
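As a sketch of what such an internal objective can look like, assume a linear internal model W and a squared-error loss over the input-target pairs (s_{t'}, s_{t'+1}). The least-squares form matches the problem the mesa layer is described as solving; the exact index range and the factor 1/2 are assumptions made here for illustration.

```latex
L_t(W) = \frac{1}{2} \sum_{t'=1}^{t-1} \lVert s_{t'+1} - W s_{t'} \rVert^2,
\qquad
\hat{s}_{t+1} = \hat{W}_t \, s_t,
\quad
\hat{W}_t \approx \arg\min_W L_t(W).
```

Steps (i) and (ii) of the caption correspond to forming the pairs and the sum, and step (iii) to approximating the minimizer and applying it to the current element s_t.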

Contributions of this study include:

  • Generalizes the theory of von Oswald et al. and shows how Transformers can, in theory, autoregressively predict the next element of a sequence by optimizing an internally constructed objective with gradient-based methods.

  • Reverse-engineers Transformers trained on a simple sequence modeling task and finds strong evidence that their forward pass implements a two-step algorithm: (i) early self-attention layers build an internal training set by grouping and copying tokens, thereby implicitly defining an internal objective function, and (ii) deeper layers optimize this objective to generate predictions.

  • Shows that, as with LLMs, simple autoregressively trained models also become in-context learners, and that prompt tuning, which is crucial for improving in-context learning in LLMs, also improves performance in this setting.

  • Motivated by the finding that attention layers attempt to implicitly optimize an internal objective function, introduces the mesa layer, a new type of attention layer that efficiently solves a least-squares optimization problem instead of taking only a single gradient step toward the optimum. Experiments show that a single mesa layer outperforms deep linear and softmax self-attention Transformers on simple sequential tasks while offering better interpretability.

  • In preliminary language modeling experiments, replacing standard self-attention layers with the mesa layer gives promising results, demonstrating the strong in-context learning capabilities of this layer.
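The contrast drawn in the list above, between taking a single gradient step and solving the least-squares problem outright, can be sketched in NumPy. This is a toy illustration, not the authors' mesa layer: the function names, the i.i.d. Gaussian inputs, the orthogonal ground-truth matrix, the step size 1/t, and the small ridge term `lam` are all assumptions.

```python
import numpy as np

def one_gd_step_predict(inputs, targets, query, eta):
    """Prediction after a single gradient step from W_0 = 0."""
    w1 = eta * targets.T @ inputs          # W_1 = eta * sum_i y_i x_i^T
    return w1 @ query

def least_squares_predict(inputs, targets, query, lam=1e-3):
    """Prediction from the ridge-regularized least-squares solution."""
    dim = inputs.shape[1]
    gram = inputs.T @ inputs + lam * np.eye(dim)
    w_hat = np.linalg.solve(gram, inputs.T @ targets).T
    return w_hat @ query

rng = np.random.default_rng(0)
dim, t = 4, 30
w_star, _ = np.linalg.qr(rng.normal(size=(dim, dim)))  # assumed ground truth
inputs = rng.normal(size=(t, dim))                      # assumed i.i.d. inputs
targets = inputs @ w_star.T                             # noise-free targets
query = rng.normal(size=dim)
truth = w_star @ query

err_gd = np.linalg.norm(
    one_gd_step_predict(inputs, targets, query, eta=1.0 / t) - truth)
err_ls = np.linalg.norm(
    least_squares_predict(inputs, targets, query) - truth)
print(f"one GD step error: {err_gd:.4f}, least-squares error: {err_ls:.6f}")
```

In this toy setting the least-squares prediction is close to exact, while the single gradient step is only as good as the empirical input covariance is close to the identity.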

Recent work has shown that Transformers explicitly trained to solve few-shot tasks in context can implement gradient descent (GD) algorithms. Here, the authors show that these results generalize to autoregressive sequence modeling, the standard approach to training LLMs.

The authors first analyze Transformers trained on simple linear dynamics, where each sequence is generated by a different W* to prevent memorization across sequences. In this simple setup, they demonstrate a Transformer that creates a mesa dataset and then uses preconditioned GD to optimize the mesa objective.
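A minimal sketch of this kind of data, assuming the dynamics s_{t+1} = W* s_t plus small Gaussian noise. The helper names `sample_sequence` and `mesa_dataset`, the choice of a random orthogonal W*, the dimensions, and the noise level are hypothetical choices for illustration, not details taken from the text above.

```python
import numpy as np

def sample_sequence(dim=4, length=20, noise_std=0.01, rng=None):
    """Sample one sequence s_1..s_T from s_{t+1} = W* s_t + noise.

    W* is redrawn on every call, so it cannot be memorized across
    sequences and has to be inferred from the context.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Assumption: W* is a random orthogonal matrix (keeps the norm stable).
    w_star, _ = np.linalg.qr(rng.normal(size=(dim, dim)))
    states = [rng.normal(size=dim)]
    for _ in range(length - 1):
        noise = noise_std * rng.normal(size=dim)
        states.append(w_star @ states[-1] + noise)
    return np.stack(states), w_star

def mesa_dataset(seq):
    """Pair each element with its successor: inputs s_t, targets s_{t+1}."""
    return seq[:-1], seq[1:]

rng = np.random.default_rng(0)
seq, w_star = sample_sequence(rng=rng)
inputs, targets = mesa_dataset(seq)
print(seq.shape, inputs.shape, targets.shape)
```

The pairs returned by `mesa_dataset` play the role of the internal training set: a model that recovers W* from them can predict the next element of this particular sequence.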

The study trains deep Transformers on a token structure that aggregates adjacent sequence elements. Interestingly, this simple preprocessing leads to extremely sparse weight matrices (fewer than 1% of the weights are non-zero), which makes it possible to reverse-engineer the algorithm.

For a single-layer linear self-attention model, the weights correspond to one GD step. For deep Transformers, interpretation becomes harder. The study therefore relies on linear probing and examines whether hidden activations predict the autoregressive targets or the preconditioned inputs.
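The correspondence between one GD step and linear self-attention can be checked numerically. The sketch below assumes a squared-error objective, initial weights W_0 = 0, keys equal to past inputs, values equal to their successors, and the current element as the query; the learning rate value and the random sequence are placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, t = 3, 8
seq = rng.normal(size=(t, dim))        # s_1 .. s_t (any sequence works here)
inputs, targets = seq[:-1], seq[1:]    # pairs (s_t', s_t'+1)
query = seq[-1]                        # current element s_t
eta = 0.1                              # learning rate (hypothetical value)

# One gradient step on L(W) = 1/2 * sum ||s_{t'+1} - W s_{t'}||^2 from W_0 = 0.
w0 = np.zeros((dim, dim))
grad = -(targets - inputs @ w0.T).T @ inputs
w1 = w0 - eta * grad
pred_gd = w1 @ query

# Linear self-attention (no softmax): keys = inputs, values = targets.
scores = inputs @ query                # k_t'^T q for each past position
pred_attn = eta * (targets.T @ scores)

print(np.allclose(pred_gd, pred_attn))
```

Both routes compute eta times the sum of s_{t'+1} (s_{t'} · s_t), which is why a single linear attention head can express this update.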

Interestingly, the predictions of both probes improve gradually with network depth. This finding suggests that preconditioned GD is hidden inside the model.

Figure 2: Reverse engineering of a trained linear self-attention layer.

The study found that the trained layer can be described perfectly by the construction when all of its degrees of freedom are used, including not only a learned learning rate η but also a learned set of initial weights W_0. Importantly, as shown in Figure 2, the learned one-step algorithm is still far outperformed by a single mesa layer.

Note that under a simple setting of its weights, which base optimization finds easily, the mesa layer solves the task studied here optimally. This result demonstrates the advantage of hard-coding an inductive bias toward mesa-optimization.

Guided by theoretical insights into the multi-layer case, the authors first analyze deep linear and softmax attention-only Transformers. They format the input according to a 4-channel structure [formula image not available], which corresponds to the choice W_0 = 0.

As with the single-layer model, the authors see clear structure in the weights of the trained model. As a first reverse-engineering analysis, the study exploits this structure and builds an algorithm (RevAlg-d, where d is the number of layers) with 16 parameters per head in each layer (instead of 3200). The authors find that this compressed but still complex expression can describe the trained model. In particular, it allows almost lossless interpolation between the weights of the actual Transformer and those of RevAlg-d.

While the RevAlg-d expression explains a trained multi-layer Transformer with few free parameters, it is difficult to interpret as a mesa-optimization algorithm. The authors therefore use linear regression probing (Alain & Bengio, 2017; Akyürek et al., 2023) to look for signatures of the hypothesized mesa-optimization algorithm.
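The idea of a linear regression probe can be sketched as follows: fit a least-squares linear map from a layer's activations to a quantity of interest and check how well it decodes. The activations below are synthetic stand-ins, not taken from any trained model, and `linear_probe_mse` is a hypothetical helper; a real probe would also be evaluated on held-out data.

```python
import numpy as np

def linear_probe_mse(activations, targets):
    """Fit a least-squares linear map activations -> targets; return its MSE."""
    probe, *_ = np.linalg.lstsq(activations, targets, rcond=None)
    return float(np.mean((activations @ probe - targets) ** 2))

rng = np.random.default_rng(0)
n, hidden, dim = 200, 16, 4
targets = rng.normal(size=(n, dim))     # quantity we try to decode
mixing = rng.normal(size=(dim, hidden))

# Synthetic stand-ins for two layers (assumption, for illustration only):
# a "deep" layer that encodes the target linearly plus a little noise,
# and a "shallow" layer that carries no information about the target.
deep_acts = targets @ mixing + 0.01 * rng.normal(size=(n, hidden))
shallow_acts = rng.normal(size=(n, hidden))

mse_deep = linear_probe_mse(deep_acts, targets)
mse_shallow = linear_probe_mse(shallow_acts, targets)
print(f"deep probe MSE: {mse_deep:.5f}, shallow probe MSE: {mse_shallow:.3f}")
```

A low probe error indicates that the quantity is linearly decodable from that layer, which is the kind of evidence the probing analysis relies on.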

For the deep linear self-attention Transformer shown in Figure 3, both probe targets can be linearly decoded, and decoding performance increases with sequence length and network depth. Base optimization therefore discovers a hybrid algorithm that descends layer by layer on the original mesa-objective L_t(W) while simultaneously improving the condition number of the mesa-optimization problem. This leads to a rapid decrease in the mesa-objective L_t(W). Performance also improves markedly with depth.

The rapid decrease of the autoregressive mesa-objective L_t(W) can therefore be attributed to stepwise (layer-by-layer) mesa-optimization on progressively better preconditioned data.
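To illustrate why better conditioning speeds up the descent, the sketch below takes a single gradient step from W_0 = 0 using progressively better approximations of the inverse Gram matrix as a preconditioner. The truncated Neumann series is an assumed stand-in for whatever the trained layers actually compute, and the anisotropic inputs, the ridge term, and the term counts are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, t = 4, 50
scales = np.array([3.0, 1.0, 0.5, 0.2])   # anisotropic inputs: ill-conditioned
inputs = rng.normal(size=(t, dim)) * scales
w_star, _ = np.linalg.qr(rng.normal(size=(dim, dim)))
targets = inputs @ w_star.T

gram = inputs.T @ inputs + 1e-3 * np.eye(dim)
alpha = 1.0 / np.linalg.eigvalsh(gram).max()

def neumann_inverse(matrix, alpha, terms):
    """Approximate matrix^{-1} by alpha * sum_{k<terms} (I - alpha*matrix)^k."""
    eye = np.eye(matrix.shape[0])
    power, total = eye.copy(), np.zeros_like(matrix)
    for _ in range(terms):
        total += power
        power = power @ (eye - alpha * matrix)
    return alpha * total

def error_after_preconditioned_step(terms):
    """Distance to W* after one preconditioned step from W_0 = 0."""
    precond = neumann_inverse(gram, alpha, terms)
    w1 = targets.T @ inputs @ precond      # step size 1
    return float(np.linalg.norm(w1 - w_star))

# terms=1 gives precond = alpha * I, i.e. a plain gradient step.
errors = [error_after_preconditioned_step(k) for k in (1, 10, 100, 1000)]
print([round(e, 4) for e in errors])
```

With one term the step is plain GD; as the preconditioner improves, the same single step lands closer to the least-squares solution.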

Figure 3: Reverse engineering of multi-layer Transformers trained on constructed token inputs.

This shows that a Transformer trained on constructed tokens makes its predictions through mesa-optimization. Interestingly, when the sequence elements are given directly instead, the Transformer constructs the tokens itself by grouping elements, which the research team calls "creating the mesa dataset".

Conclusion

This study shows that Transformer models can develop gradient-based inference algorithms when trained on sequence prediction tasks under a standard autoregressive objective. Recent results obtained in multi-task, meta-learning settings therefore carry over to the conventional self-supervised LLM training setting.

In addition, the study finds that the learned autoregressive inference algorithms can be repurposed to solve supervised in-context learning tasks without retraining, which explains the results within a single unified framework.

So how does this relate to in-context learning? The authors argue that a Transformer trained on an autoregressive sequence task learns a suitable form of mesa-optimization and can therefore perform few-shot in-context learning without any fine-tuning.

The authors hypothesize that mesa-optimization also occurs in LLMs and improves their in-context learning capabilities. Interestingly, the study also observes that effectively adapting prompts for LLMs can lead to substantial improvements in in-context learning.

 

 

 

Origin blog.csdn.net/Angelina_Jolie/article/details/133136574