What do all these layers do in an LLM?
We can describe the architecture, training procedure, and inference flow of an LLM. But what about the hidden, internal states? Do we understand what an LLM computes inside before it shows a result to the user?
We can write a simple prompt that asks an LLM to repeat words: "Repeat words: dog -> dog, pet -> pet, house -> ?". We expect the LLM to return "house". But what if, during inference, we patch the model's hidden state and replace the representation of "house" with something else, for example a hidden state taken from a layer we are interested in? In other words: we pick a layer in a source LLM that we want to inspect, take its hidden state, and inject it into a target LLM that runs the word-repetition prompt. The target LLM then repeats, i.e. decodes, the hidden state from the source LLM, so we can read the decoded word. This is the Patchscopes framework.
Applying Patchscopes, we can watch how an LLM builds its answer layer by layer. The framework also improves accuracy on multi-hop reasoning from 19.57% to 50%.
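Here is a minimal sketch of this kind of activation patching, assuming a Llama-style model from Hugging Face transformers. The model name, the repetition prompt, the layer indices, and the placeholder position are illustrative assumptions, not the exact setup from the paper.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"  # assumption: any causal LM with a .model.layers list
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model.eval()

source_prompt = "The capital of France is"                    # where the hidden state comes from
target_prompt = "Repeat words: dog -> dog, pet -> pet, x ->"  # the repetition prompt
source_layer, target_layer = 15, 3                            # layers to read from / patch into

# 1) Run the source prompt and keep the hidden state of its last token at the
#    source layer (hidden_states[i] is the output of layer i, [0] is the embeddings).
src_ids = tok(source_prompt, return_tensors="pt").input_ids
with torch.no_grad():
    src_out = model(src_ids, output_hidden_states=True)
patched_state = src_out.hidden_states[source_layer][0, -1]    # shape: (hidden_dim,)

# 2) Run the target prompt, overwriting the hidden state of the placeholder token
#    at the target layer with the captured source state.
tgt_ids = tok(target_prompt, return_tensors="pt").input_ids
patch_pos = tgt_ids.shape[1] - 2   # rough guess for the placeholder position (assumption)

def patch_hook(module, inputs, output):
    hidden = output[0] if isinstance(output, tuple) else output
    if hidden.shape[1] > patch_pos:                    # only during the prompt pass
        hidden[0, patch_pos] = patched_state.to(hidden.dtype)
    return output

handle = model.model.layers[target_layer].register_forward_hook(patch_hook)
with torch.no_grad():
    gen = model.generate(tgt_ids, max_new_tokens=3)
handle.remove()

print(tok.decode(gen[0, tgt_ids.shape[1]:]))  # the target model "decodes" the source state
```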
Attention is the core of the Transformer architecture. Attention heads compute how important each token in the prompt is for the current prediction and help the LLM generate high-quality answers. Recent research shows that attention heads play different roles. For instance, there is the Induction Head, which recognizes repeated patterns and predicts that a sequence seen earlier will continue the same way.
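To make the induction pattern concrete, here is a small sketch that scores one attention head for induction behaviour: for every token that appeared earlier in the sequence, it checks how much attention goes to the position right after the previous occurrence (the classic A B ... A -> B pattern). The scoring heuristic and the toy example are illustrative assumptions, not the exact metric from the paper.

```python
import numpy as np

def induction_score(tokens, attn):
    """Score one attention head for induction behaviour.

    tokens : list of token ids, length T
    attn   : (T, T) attention matrix of a single head (rows = query positions)

    For every position whose token appeared earlier, check how much attention
    goes to the position right after the previous occurrence - the A B ... A -> B
    induction pattern. A simple heuristic, not the paper's exact metric.
    """
    scores, last_seen = [], {}
    for t, token in enumerate(tokens):
        if token in last_seen:
            prev = last_seen[token]
            if prev + 1 < t:
                scores.append(attn[t, prev + 1])  # weight on the token after the previous occurrence
        last_seen[token] = t
    return float(np.mean(scores)) if scores else 0.0

# Example: a repeated random sequence makes induction heads easy to spot.
tokens = [5, 9, 2, 7, 5, 9, 2, 7]
attn = np.full((8, 8), 1e-3)
for t in range(4, 8):                 # an idealized induction head attends one step
    attn[t, t - 4 + 1] = 1.0          # after the previous occurrence of its token
print(induction_score(tokens, attn))  # close to 1.0 for this idealized head
```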
There is one more type of attention head: the Retrieval Head. When a generated token is copied from the input sequence and a head's strongest attention points at that copied token, we say the head performs a copy-paste operation. Counting how often this happens gives a Retrieval Score for each head. Only 3% to 6% of heads have a retrieval score above 10%.
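A rough sketch of how such a Retrieval Score could be computed: a decoding step counts as copy-paste when the head's strongest attention weight lands on a context position that holds exactly the token being generated, and the score is the fraction of steps where this happens. This is a simplified reading of the idea, not the paper's implementation.

```python
import numpy as np

def retrieval_score(input_ids, generated_ids, attn_per_step):
    """Fraction of generated tokens that a head 'copy-pastes' from the input.

    input_ids     : list of context token ids, length T
    generated_ids : list of generated token ids, length G
    attn_per_step : (G, T) attention of one head; row g is the attention over
                    the context while generating token g
    """
    copies = 0
    for g, token in enumerate(generated_ids):
        top_pos = int(np.argmax(attn_per_step[g]))     # position the head attends to most
        if input_ids[top_pos] == token:                # it holds the token being generated
            copies += 1
    return copies / max(len(generated_ids), 1)

# Toy usage: this head always points back at the position it copies from.
ctx = [11, 42, 7, 99]
gen = [42, 99]
attn = np.array([[0.1, 0.7, 0.1, 0.1],    # step 0 attends mostly to position 1 (token 42)
                 [0.1, 0.1, 0.1, 0.7]])   # step 1 attends mostly to position 3 (token 99)
print(retrieval_score(ctx, gen, attn))    # 1.0
```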
Retrieval Heads are responsible for in-context learning: they find important information in the prompt and pass it on to later layers. Disabling non-retrieval heads barely affects performance on Needle-in-a-Haystack tasks, but disabling retrieval heads causes a large drop in quality. Moreover, Retrieval Heads are good at detecting hallucinations: when these heads are not activated during inference, it is a sign that the model is hallucinating.
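An ablation experiment like this can be sketched by zeroing out the contribution of selected heads. The sketch below hooks the attention output projection of a Llama-style transformers model and erases the slice belonging to each disabled head; the module path, head layout, and chosen (layer, head) pairs are assumptions for illustration, not the paper's code.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"          # illustrative model id
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

heads_to_disable = {(12, 3), (12, 7), (20, 1)}   # (layer, head) pairs - arbitrary example
head_dim = model.config.hidden_size // model.config.num_attention_heads

def make_pre_hook(heads):
    def pre_hook(module, args):
        hidden = args[0].clone()                  # (batch, seq, num_heads * head_dim)
        for h in heads:
            hidden[..., h * head_dim:(h + 1) * head_dim] = 0  # erase this head's contribution
        return (hidden,) + args[1:]
    return pre_hook

handles = []
for layer_idx in {layer for layer, _ in heads_to_disable}:
    heads = [h for layer, h in heads_to_disable if layer == layer_idx]
    o_proj = model.model.layers[layer_idx].self_attn.o_proj
    handles.append(o_proj.register_forward_pre_hook(make_pre_hook(heads)))

# ... run the Needle-in-a-Haystack evaluation with the ablated model here ...

for h in handles:
    h.remove()
```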
Interestingly, not all layers are needed to generate a response. The technique is called model pruning and can be applied to any neural network. In the ShortGPT paper, a new metric called Block Influence is defined: it measures how much each layer contributes to the final result. Removing low-contribution layers should not impact the overall quality too much. Using this technique, it is possible to remove ten of the 40 layers in Llama-2 13B with a performance drop of less than 5% on popular benchmarks.
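A minimal sketch of the Block Influence idea: compare each layer's input and output hidden states with cosine similarity, so layers that barely change the representation get a low score and become pruning candidates. This assumes hidden states collected with output_hidden_states=True and simplifies the paper's metric.

```python
import torch
import torch.nn.functional as F

def block_influence(hidden_states):
    """Rough Block Influence score per layer.

    hidden_states : tuple of (batch, seq, dim) tensors from a forward pass with
                    output_hidden_states=True, where hidden_states[i] is the
                    input of layer i and hidden_states[i + 1] is its output.

    BI_i = 1 - mean cosine similarity between a layer's input and output tokens;
    layers with low BI change the representation the least and are candidates
    for pruning. Simplified sketch of the ShortGPT metric.
    """
    scores = []
    for i in range(len(hidden_states) - 1):
        x, y = hidden_states[i], hidden_states[i + 1]
        cos = F.cosine_similarity(x, y, dim=-1)   # (batch, seq)
        scores.append(1.0 - cos.mean().item())
    return scores  # one score per layer; prune the lowest-scoring layers first

# Usage (assuming `model` and `inputs` are already defined):
# out = model(**inputs, output_hidden_states=True)
# print(block_influence(out.hidden_states))
```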
References:
https://arxiv.org/abs/2404.15574 - Retrieval Head Mechanistically Explains Long-Context Factuality
https://arxiv.org/abs/2401.06102 - Patchscopes: A Unifying Framework for Inspecting Hidden Representations of Language Models
https://arxiv.org/abs/2403.03853 - ShortGPT: Layers in Large Language Models are More Redundant Than You Expect
https://arxiv.org/abs/2209.11895 - In-context Learning and Induction Heads