This paper introduces "Mimic In-Context Learning" (MimIC), a novel method to enhance the In-Context Learning (ICL) capabilities of Large Multimodal Models (LMMs) by learning stable and generalizable "shift effects" from In-Context Demonstrations (ICDs). The authors observe that ICL performance in LMMs is highly sensitive to ICD configurations due to the synergistic effects of multimodal data. Traditional ICL relies on ICDs during inference, which can be computationally expensive and sensitive to the choice and order of these demonstrations. Previous "shift vector" based methods, which aim to learn a general mapping function from ICDs, have limitations in their approximation of how ICDs influence model behavior.
The core idea of MimIC is to more rigorously approximate the mathematical effect of ICDs, which is to add "shift vectors" to the hidden representations of query tokens within Transformer-based models. MimIC introduces lightweight learnable modules into LMMs with four key enhancements, illustrated by the code sketches that follow:
- Inserting shift vectors after attention layers: Unlike previous methods that place them after Feed-Forward Network (FFN) layers, MimIC aligns with the mathematical derivation showing that the shift effect should occur post-attention.
- Assigning a shift vector to each attention head: This allows each head to learn a unique shift, capturing distinct representation shifts for different aspects of the input.
- Making shift magnitude query-dependent: The scaling factor of the shift vector is dynamically adjusted based on the current query, allowing for more nuanced adaptations.
- Employing a layer-wise alignment loss: This loss function ensures that the hidden states of the MimIC-enhanced LMM (processing only the query) closely match the hidden states of the original LMM when performing standard ICL (processing query and ICDs).
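To show how the first three enhancements fit into a Transformer layer, here is a minimal PyTorch sketch of a per-head shift module applied directly after attention; the class name `MimICShift`, the tensor layout, and the zero initialization are illustrative assumptions rather than the paper's exact implementation.

```python
import torch
import torch.nn as nn

class MimICShift(nn.Module):
    """Per-head learnable shift vectors with a query-dependent magnitude,
    applied to the attention output before the FFN (illustrative sketch)."""

    def __init__(self, num_heads: int, head_dim: int):
        super().__init__()
        # One learnable shift vector per attention head (zero init is an assumption).
        self.shift_vectors = nn.Parameter(torch.zeros(num_heads, head_dim))
        # Per-head linear map from the query vector to log Z1,
        # which makes the shift magnitude query-dependent.
        self.z1_weight = nn.Parameter(torch.zeros(num_heads, head_dim))
        self.z1_bias = nn.Parameter(torch.zeros(num_heads))

    def forward(self, attn_out, queries, z2):
        # attn_out, queries: (batch, seq, num_heads, head_dim)
        # z2: (batch, seq, num_heads), the sum of exponentiated attention
        #     scores over the query's own context, taken from the attention layer
        log_z1 = torch.einsum("bshd,hd->bsh", queries, self.z1_weight) + self.z1_bias
        mu = log_z1.exp() / (log_z1.exp() + z2)  # query-dependent magnitude
        # Add the per-head shift right after attention, before the FFN.
        return attn_out + mu.unsqueeze(-1) * self.shift_vectors
```

Placing this module after the attention output (rather than after the FFN) mirrors the derivation the paper uses to justify where the shift effect arises.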
The training process involves two parallel LMMs:
- The original LMM processes a query along with ICDs to generate target hidden states at each layer.
- The MimIC LMM processes only the query, using its learnable MimIC attention heads to produce shifted hidden states.
The total loss combines this layer-wise alignment loss (L_align) with a standard language modeling loss (L_lm) to maintain task performance, i.e., L = L_align + λ · L_lm, where λ balances the two terms. By training with randomly selected ICDs, MimIC learns a general shift pattern. During inference, MimIC no longer requires ICDs, leading to significant speed improvements.
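A minimal PyTorch-style sketch of one training step under this two-branch setup is shown below; the helper methods `hidden_states_for_query` and `forward_with_hidden_states`, as well as the weight `lambda_weight`, are hypothetical names introduced for illustration.

```python
import torch
import torch.nn.functional as F

def mimic_training_step(frozen_lmm, mimic_lmm, icd_inputs, query_inputs,
                        labels, lambda_weight=1.0):
    """One training step of the two-branch MimIC setup (illustrative sketch)."""
    # Branch 1: the frozen original LMM sees ICDs + query and provides the
    # target hidden states of the query tokens at every layer.
    with torch.no_grad():
        target_states = frozen_lmm.hidden_states_for_query(icd_inputs, query_inputs)

    # Branch 2: the MimIC LMM sees only the query; its learnable shift modules
    # must reproduce the shifted hidden states without any ICDs.
    mimic_states, logits = mimic_lmm.forward_with_hidden_states(query_inputs)

    # Layer-wise alignment loss: L2-style distance between the two branches.
    align_loss = sum(F.mse_loss(h_m, h_t)
                     for h_m, h_t in zip(mimic_states, target_states))

    # Standard language modeling loss on the query's answer tokens.
    lm_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))

    return align_loss + lambda_weight * lm_loss
```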
Implementation Details:
- The shift effect is approximated by a learnable vector v in each attention head.
- The query-dependent magnitude is determined by a trainable linear layer f that approximates log Z1, where Z1 denotes the sum of exponentiated attention scores between the query and the ICD tokens.
- The MimIC attention head output is Attn(q, K, V) + μ(q) · v, where μ(q) = Z1 / (Z1 + Z2) and Z2 is the corresponding sum over the query's own context.
```python
import torch
import torch.nn.functional as F

def mimic_attention(query_q, keys_K, values_V, learnable_v, linear_f):
    # query_q: (d,) query vector; keys_K, values_V: (n, d) keys/values of the
    # query's own context; learnable_v: (d,) learnable shift vector;
    # linear_f: trainable linear layer mapping the query to log Z1
    d = query_q.size(-1)

    # Standard attention over the query's own context (independent of ICDs)
    scores = (keys_K @ query_q) / d ** 0.5
    standard_attn_output = F.softmax(scores, dim=-1) @ values_V

    # Approximate Z1 (sum of exponentiated attention scores over the ICDs)
    log_Z1_approx = linear_f(query_q)
    Z1_approx = torch.exp(log_Z1_approx)

    # Z2: sum of exponentiated attention scores over the query's own context
    Z2 = torch.exp(scores).sum()

    # Approximate query-dependent magnitude
    mu_approx = Z1_approx / (Z1_approx + Z2)

    # Shift vector application
    return standard_attn_output + mu_approx * learnable_v
```
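As a quick sanity check (not from the paper), the sketch above can be exercised with random tensors, instantiating `linear_f` as a small `nn.Linear`:

```python
import torch
import torch.nn as nn

d, n = 64, 10                                      # head dim and context length (arbitrary)
linear_f = nn.Linear(d, 1)                         # trainable approximation of log Z1
learnable_v = torch.zeros(d, requires_grad=True)   # learnable shift vector

q = torch.randn(d)
K, V = torch.randn(n, d), torch.randn(n, d)

out = mimic_attention(q, K, V, learnable_v, linear_f)
print(out.shape)  # torch.Size([64])
```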
Experiments and Results:
MimIC was evaluated on Idefics-9b and Idefics2-8b-base models across VQAv2, OK-VQA, and COCO Captioning tasks.
- Performance: MimIC consistently outperformed standard ICL (e.g., 32-shot ICL), previous shift vector-based methods (Function Vector, Task Vector, and LIVE), and LoRA. For instance, on Idefics-9b, MimIC achieved a 3.46% accuracy improvement on VQAv2, a 3.57% improvement on OK-VQA, and a 9.00-point CIDEr gain on COCO Captioning compared to 32-shot ICL.
- Data Efficiency: MimIC matched 32-shot ICL performance even when trained with guidance from only 1-shot ICL. It also required far fewer training samples than methods like LIVE to reach strong performance, surpassing LIVE's best result with one-eighth of the training data.
- Ablation Studies: Confirmed the importance of each of MimIC's four enhancements (multi-head shift vectors, query-dependent magnitude, placement after attention layers, and layer-wise alignment loss); the multi-head, query-dependent shift magnitude proved especially critical.
- Alignment: MimIC showed closer alignment (smaller L2 distance) to traditional ICL in latent space than other methods, including a KL-divergence variant of MimIC (MimIC-KL) and LoRA.
- Hallucinations: MimIC generated fewer hallucinations in image captioning than other non-zero-shot methods while maintaining high recall. Although hallucinations increased slightly as more simulated "shots" were used during training, they remained lower than with standard ICL.
Key Contributions:
- Provides a more rigorous mathematical approximation of ICL's shift effects in LMMs, highlighting flaws in previous methods.
- Proposes a feasible method (MimIC) to achieve this approximation with few learnable parameters integrated into attention heads.
- Demonstrates consistent improvements over ICL, prior shift vector methods, and LoRA across multiple tasks and LMM architectures, with better data efficiency and reduced inference latency.
The paper concludes that MimIC effectively learns the ICL shift effect, offering competitive few-shot performance with reduced latency, fewer training samples than comparable methods, and fewer parameters than LoRA, while often achieving better results and reducing hallucinations. The code is available at https://github.com/Kamichanw/MimIC.