Layer-Aware Task Arithmetic (LATA)
- Layer-Aware Task Arithmetic (LATA) is a methodology for merging transformer models by leveraging the linearity of internal submodules to isolate task-specific adaptations.
- It employs layer-specific weighting using cosine similarity between task and instruction vectors to enhance multi-task performance and facilitate task forgetting.
- Experimental results on models like Llama and Gemma show that LATA reduces perplexity and improves accuracy by focusing on critical task-specific layers.
Layer-Aware Task Arithmetic (LATA) is a methodology for merging multiple fine-tuned transformer-based models, designed to enhance multi-task learning and model editing by leveraging the layer-wise structure and internal linearity properties of LLMs. LATA introduces explicit, principled mechanisms for either amplifying or attenuating the influence of particular submodules or layers, enabling the disentanglement of task-specific knowledge from general-purpose or instruction-following behavior during model merging. The approach relies on both empirical and theoretical findings that demonstrate substantially greater linearity in internal submodules than in the model as a whole, and operationalizes these insights by assigning task-specific, layer-aware weights during the arithmetic combination of different model parameters (Dai et al., 15 Apr 2025, Chen et al., 27 Feb 2025).
1. Motivations and Limitations of Traditional Task Arithmetic
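The canonical Task Arithmetic (TA) framework treats each fine-tuning run as a shift in parameter space, a "task vector" $\tau_t = \theta_t - \theta_{\text{pre}}$, where $\theta_{\text{pre}}$ is the pre-trained model and $\theta_t$ the model fine-tuned on task $t$. Task vectors support additive (multi-task) and subtractive (task forgetting) operations. Standard TA merges models by adding or subtracting task vectors with a single global coefficient $\lambda$, applied uniformly across all layers:
$$\theta_{\text{merged}} = \theta_{\text{pre}} + \lambda \sum_{t=1}^{T} \tau_t \quad \text{(multi-task)}, \qquad \theta_{\text{forget}} = \theta_{\text{pre}} - \lambda\,\tau_t \quad \text{(task forgetting)}$$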
While TA makes model merging simple, it cannot disentangle task-specific adaptations from generic instruction-following updates, because both are entangled within each $\tau_t$. Direct addition therefore causes interference: since every task vector carries a similar instruction-following component, summing several of them reinforces that component cumulatively, degrading utility (higher perplexity) and causing alignment issues (Chen et al., 27 Feb 2025). Fine-grained analysis reveals that only a subset of layers encodes genuine task-specific deltas, while the remaining layers primarily encode global instruction-following patterns.
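A minimal sketch of global TA, assuming toy models stored as dicts that map a layer name to a flattened list of weights; `task_vector` and `task_arithmetic` are hypothetical helper names, not an existing library API. In the usage lines, both task vectors share the same `layer0` delta (a stand-in for an instruction-following component), so uniform addition doubles it, which is the cumulative interference described above.

```python
from typing import Dict, List

Params = Dict[str, List[float]]  # layer name -> flattened weights (toy stand-in for real tensors)


def task_vector(finetuned: Params, pretrained: Params) -> Params:
    """tau_t = theta_t - theta_pre, computed layer by layer."""
    return {k: [f - p for f, p in zip(finetuned[k], pretrained[k])] for k in pretrained}


def task_arithmetic(pretrained: Params, taus: List[Params], lam: float, forget: bool = False) -> Params:
    """Global TA: the same coefficient lam for every layer; subtract instead of add when forgetting."""
    sign = -1.0 if forget else 1.0
    merged = {k: list(v) for k, v in pretrained.items()}
    for tau in taus:
        for k in merged:
            merged[k] = [m + sign * lam * d for m, d in zip(merged[k], tau[k])]
    return merged


# Usage: layer0 carries an identical (instruction-like) delta in both fine-tuned models.
pre = {"layer0": [0.0, 0.0], "layer1": [0.0, 0.0]}
math_model = {"layer0": [1.0, 0.0], "layer1": [2.0, 0.0]}
code_model = {"layer0": [1.0, 0.0], "layer1": [0.0, 3.0]}
taus = [task_vector(math_model, pre), task_vector(code_model, pre)]
merged = task_arithmetic(pre, taus, lam=1.0)
print(merged)  # {'layer0': [2.0, 0.0], 'layer1': [2.0, 3.0]}
```

Because `lam` is shared by every layer, TA has no way to keep the `layer1` deltas while suppressing the duplicated `layer0` delta.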
2. Submodule Linearity and Empirical Foundations
Recent work shows that after fine-tuning, the linearity property (the model's output interpolates linearly between the pre-trained and fine-tuned states) holds far more closely within internal submodules (layers, self-attention blocks, MLPs) than for the end-to-end model. Formally, for a submodule $f$ with pre-trained weights $\theta_{\text{pre}}$, fine-tuned weights $\theta_{\text{ft}}$, and interpolation coefficient $\alpha \in [0, 1]$, the submodule is linear if
$$f\big(x;\ \theta_{\text{pre}} + \alpha(\theta_{\text{ft}} - \theta_{\text{pre}})\big) \approx (1 - \alpha)\, f(x; \theta_{\text{pre}}) + \alpha\, f(x; \theta_{\text{ft}})$$
for all inputs $x$ and all $\alpha \in [0, 1]$. Empirical assessments using the non-linearity score (NL), defined as the squared deviation between the output at interpolated weights and this linear expectation, show a full-model NL of roughly 2–3, whereas submodules stay below 0.3 (often around 0.1 or lower) (Dai et al., 15 Apr 2025). This pronounced linearity is what makes closed-form merging at the module level possible.
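The NL score can be illustrated on toy modules. This is a minimal sketch, assuming an unnormalized mean squared deviation over a small grid of interpolation coefficients; the grid, the lack of normalization, and the two toy modules are illustrative choices rather than Dai et al.'s exact protocol. A module that is linear in its weights scores zero up to rounding, while the same module followed by `tanh` scores clearly above zero.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def interpolate(theta_pre: Vector, theta_ft: Vector, alpha: float) -> Vector:
    """theta_pre + alpha * (theta_ft - theta_pre)."""
    return [p + alpha * (f - p) for p, f in zip(theta_pre, theta_ft)]


def nl_score(
    f: Callable[[Vector, Vector], Vector],
    theta_pre: Vector,
    theta_ft: Vector,
    inputs: Sequence[Vector],
    alphas: Sequence[float] = (0.25, 0.5, 0.75),
) -> float:
    """Mean squared gap between f at interpolated weights and the linear interpolation of outputs."""
    total, count = 0.0, 0
    for x in inputs:
        out_pre, out_ft = f(x, theta_pre), f(x, theta_ft)
        for a in alphas:
            actual = f(x, interpolate(theta_pre, theta_ft, a))
            expected = [(1 - a) * u + a * v for u, v in zip(out_pre, out_ft)]
            total += sum((y - e) ** 2 for y, e in zip(actual, expected))
            count += 1
    return total / count


def linear_module(x: Vector, w: Vector) -> Vector:
    # 2x2 weight matrix stored row-major; the output is linear in w
    return [w[0] * x[0] + w[1] * x[1], w[2] * x[0] + w[3] * x[1]]


def tanh_module(x: Vector, w: Vector) -> Vector:
    # same matrix followed by tanh, so the output is nonlinear in w
    return [math.tanh(v) for v in linear_module(x, w)]


theta_pre, theta_ft = [0.0, 0.0, 0.0, 0.0], [3.0, 0.0, 0.0, 3.0]
inputs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
nl_linear = nl_score(linear_module, theta_pre, theta_ft, inputs)
nl_tanh = nl_score(tanh_module, theta_pre, theta_ft, inputs)
```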
3. Layer-Aware Weighting and Disentanglement
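LATA operationalizes the disentanglement of task-specific and instruction-following knowledge by computing layerwise alignment scores between fine-tuned "complex" vectors and pure instruction deltas. For each layer $l$ and task $t$:
- Instruction vector: $\tau_{\text{ins}}^{(l)} = \theta_{\text{ins}}^{(l)} - \theta_{\text{pre}}^{(l)}$, the update that turns the pre-trained model into an instruction-following model $\theta_{\text{ins}}$
- Complex vector: $\tau_{c,t}^{(l)} = \theta_t^{(l)} - \theta_{\text{pre}}^{(l)}$, which, like the TA task vector, mixes task-specific and instruction-following updates
- Task vector: $\tau_{\text{task},t}^{(l)} = \tau_{c,t}^{(l)} - \tau_{\text{ins}}^{(l)}$, the task-specific remainder
Layerwise alignment is estimated via cosine similarity between the flattened layer parameters:
$$s_t^{(l)} = \cos\big(\tau_{c,t}^{(l)}, \tau_{\text{ins}}^{(l)}\big) = \frac{\langle \tau_{c,t}^{(l)}, \tau_{\text{ins}}^{(l)} \rangle}{\|\tau_{c,t}^{(l)}\|\,\|\tau_{\text{ins}}^{(l)}\|}$$
LATA then maps these scores to a layerwise weight $w_t^{(l)}$ via one of several schemes (e.g., Linear Drop-by-Rank, Log Drop-by-Rank, or Drop-with-Threshold) that up-weight task-specific layers (low similarity) and down-weight instruction-like layers (high similarity). The final, "pure" per-layer update is
$$\hat{\tau}_t^{(l)} = w_t^{(l)}\,\tau_{c,t}^{(l)},$$
and concatenating $\hat{\tau}_t^{(l)}$ over all layers yields the update $\hat{\tau}_t$ that is added (merging) or subtracted (forgetting).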
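The scoring and weighting step can be sketched as follows. The three weighting schemes are named in the source but their formulas are not given here, so `linear_drop_by_rank`, `log_drop_by_rank`, and `drop_with_threshold` are hypothetical interpretations (weights decaying linearly or logarithmically with similarity rank, or a hard 0/1 cut at a threshold), not Chen et al.'s exact definitions; layers are toy dicts of flattened weight lists.

```python
import math
from typing import Dict, List

Params = Dict[str, List[float]]  # layer name -> flattened per-layer delta (toy stand-in)
Scores = Dict[str, float]


def cosine(u: List[float], v: List[float]) -> float:
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    if nu == 0.0 or nv == 0.0:
        return 0.0  # assumption: an all-zero delta counts as unaligned
    return sum(a * b for a, b in zip(u, v)) / (nu * nv)


def layer_similarities(complex_vec: Params, instr_vec: Params) -> Scores:
    """s_t^(l) = cos(tau_c,t^(l), tau_ins^(l)) for every layer l."""
    return {name: cosine(complex_vec[name], instr_vec[name]) for name in complex_vec}


def _ranks(sims: Scores) -> Dict[str, int]:
    # rank 0 = lowest similarity (most task-specific); ties broken by layer name
    order = sorted(sims, key=lambda name: (sims[name], name))
    return {name: r for r, name in enumerate(order)}


def linear_drop_by_rank(sims: Scores) -> Scores:
    # hypothetical: weight falls linearly from 1 (rank 0) to 0 (last rank)
    ranks, n = _ranks(sims), len(sims)
    return {name: 1.0 if n == 1 else 1.0 - r / (n - 1) for name, r in ranks.items()}


def log_drop_by_rank(sims: Scores) -> Scores:
    # hypothetical: weight falls from 1 to 0 along a logarithmic curve in rank
    ranks, n = _ranks(sims), len(sims)
    return {name: 1.0 if n == 1 else 1.0 - math.log(1 + r) / math.log(n) for name, r in ranks.items()}


def drop_with_threshold(sims: Scores, threshold: float) -> Scores:
    # hypothetical: keep layers whose similarity is below the threshold, drop the rest
    return {name: 1.0 if s < threshold else 0.0 for name, s in sims.items()}


def pure_update(complex_vec: Params, weights: Scores) -> Params:
    """hat_tau_t^(l) = w_t^(l) * tau_c,t^(l), layer by layer."""
    return {name: [weights[name] * d for d in complex_vec[name]] for name in complex_vec}


# Usage with three toy layers: l0 is parallel to the instruction delta, l1 orthogonal to it.
instr = {"l0": [1.0, 0.0], "l1": [1.0, 1.0], "l2": [0.0, 1.0]}
complex_t = {"l0": [2.0, 0.0], "l1": [1.0, -1.0], "l2": [1.0, 1.0]}
sims = layer_similarities(complex_t, instr)
weights = linear_drop_by_rank(sims)
update = pure_update(complex_t, weights)
print(weights)  # {'l1': 1.0, 'l2': 0.5, 'l0': 0.0}
```

For forgetting, the same `update` would be subtracted from the model rather than added.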
4. Closed-Form Merging via Least Squares
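Building on the empirical linearity of submodules, LATA merges $T$ fine-tuned models $\theta_1, \dots, \theta_T$ that share a pre-trained model $\theta_{\text{pre}}$, using per-module learned weights:
$$\theta_{\text{merged}}^{(m)} = \theta_{\text{pre}}^{(m)} + \sum_{t=1}^{T} \alpha_t^{(m)} \big(\theta_t^{(m)} - \theta_{\text{pre}}^{(m)}\big)$$
The optimal weights $\alpha^{(m)} = (\alpha_1^{(m)}, \dots, \alpha_T^{(m)})$ are found by minimizing the feature-space squared error between the output of the merged module and that of each fine-tuned module over a held-out sample set. Under submodule linearity, the merged module's output is approximately affine in $\alpha^{(m)}$, so the problem reduces to a quadratic minimization with the closed-form solution:
$$\alpha^{(m)} = \big(A^{(m)}\big)^{-1} b^{(m)},$$
where the entries of the $T \times T$ matrix $A^{(m)}$ and the vector $b^{(m)}$ depend on the empirical feature deltas $\Delta_t^{(m)}(x) = f^{(m)}(x; \theta_t^{(m)}) - f^{(m)}(x; \theta_{\text{pre}}^{(m)})$, computed for each module and task over held-out examples (about 30 examples per task suffice for stability) (Dai et al., 15 Apr 2025). The approach requires no retraining and is computationally efficient.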
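The closed form follows in a few lines. This derivation is a reconstruction under stated assumptions, not quoted from Dai et al.: it extends submodule linearity to combinations of several task vectors, weights every task's held-out set equally, and writes $\mathcal{D}_t$ for the inputs that module $m$ receives on task $t$'s held-out data (the superscript $(m)$ on $\alpha$ is dropped for readability).

```latex
% Linearized merged output (extending submodule linearity to T task vectors):
f^{(m)}\big(x;\theta^{(m)}_{\text{merged}}\big) \approx f^{(m)}\big(x;\theta^{(m)}_{\text{pre}}\big) + \sum_{k=1}^{T} \alpha_k\, \Delta^{(m)}_k(x)

% Objective: match each fine-tuned module on its own task's inputs D_t,
% using f^{(m)}(x; \theta_t^{(m)}) = f^{(m)}(x; \theta_{pre}^{(m)}) + \Delta_t^{(m)}(x):
\mathcal{L}(\alpha) = \sum_{t=1}^{T} \sum_{x \in \mathcal{D}_t} \Big\| \sum_{k=1}^{T} \alpha_k\, \Delta^{(m)}_k(x) - \Delta^{(m)}_t(x) \Big\|^2

% Setting \partial\mathcal{L}/\partial\alpha_j = 0 for j = 1, ..., T gives A\alpha = b with
A_{jk} = \sum_{t=1}^{T} \sum_{x \in \mathcal{D}_t} \big\langle \Delta^{(m)}_j(x),\, \Delta^{(m)}_k(x) \big\rangle,
\qquad
b_j = \sum_{t=1}^{T} \sum_{x \in \mathcal{D}_t} \big\langle \Delta^{(m)}_j(x),\, \Delta^{(m)}_t(x) \big\rangle
```

$A$ is a Gram matrix of the feature deltas, so it is symmetric positive semidefinite and invertible whenever the deltas of different tasks are linearly independent on the held-out inputs.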
5. Algorithmic Workflow
The LATA procedure can be instantiated as follows:
- For each task $t$, run held-out data $\mathcal{D}_t$ through the base model to collect intermediate features at each module $m$, denoting the resulting module inputs $X_t^{(m)}$.
- For each module $m$ and task $t$, compute the feature delta $\Delta_t^{(m)}(x) = f^{(m)}(x; \theta_t^{(m)}) - f^{(m)}(x; \theta_{\text{pre}}^{(m)})$ on the collected inputs.
- Construct $A^{(m)}$ and $b^{(m)}$ by summing inner products of feature deltas across tasks and examples.
- Solve $A^{(m)} \alpha^{(m)} = b^{(m)}$ for the optimal merging weights $\alpha^{(m)}$ of each module.
- Form the merged model by adding the weighted parameter differences $\alpha_t^{(m)} \big(\theta_t^{(m)} - \theta_{\text{pre}}^{(m)}\big)$ module by module.
In the layer-alignment variant, the alignment scores and layerwise weights modulate the construction of task vectors before merging. Chen et al. (27 Feb 2025) summarize this process in pseudocode.
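The per-module workflow can be sketched end to end in plain Python. This is a minimal illustration, assuming a toy 2×2 linear module, weights flattened into lists, and module inputs that have already been collected from the base model; `merge_module`, `solve`, and `linear_2x2` are hypothetical names, and a real implementation would operate on framework tensors for every module of the network.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Module = Callable[[Vector, Vector], Vector]  # (module input x, flattened weights) -> module output


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def solve(A: List[List[float]], b: List[float]) -> List[float]:
    """Solve A x = b by Gaussian elimination with partial pivoting (fine for small T x T systems)."""
    n = len(b)
    M = [row[:] + [rhs] for row, rhs in zip(A, b)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        if abs(M[pivot][col]) < 1e-12:
            raise ValueError("A is singular: feature deltas are linearly dependent")
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            M[r] = [a - factor * c for a, c in zip(M[r], M[col])]
    x = [0.0] * n
    for r in reversed(range(n)):
        x[r] = (M[r][n] - sum(M[r][c] * x[c] for c in range(r + 1, n))) / M[r][r]
    return x


def merge_module(
    f: Module, theta_pre: Vector, thetas: List[Vector], held_out: List[List[Vector]]
) -> Tuple[Vector, Vector]:
    """Closed-form merge of one module; held_out[t] holds this module's inputs for task t."""
    T = len(thetas)
    A = [[0.0] * T for _ in range(T)]
    b = [0.0] * T
    for t, inputs in enumerate(held_out):
        for x in inputs:
            base_out = f(x, theta_pre)
            # Delta_k(x) = f(x; theta_k) - f(x; theta_pre) for every task k
            deltas = [[y - y0 for y, y0 in zip(f(x, th), base_out)] for th in thetas]
            for j in range(T):
                b[j] += dot(deltas[j], deltas[t])
                for k in range(T):
                    A[j][k] += dot(deltas[j], deltas[k])
    alpha = solve(A, b)
    # theta_merged = theta_pre + sum_k alpha_k * (theta_k - theta_pre)
    merged = list(theta_pre)
    for a_k, th in zip(alpha, thetas):
        merged = [m + a_k * (w - p) for m, w, p in zip(merged, th, theta_pre)]
    return merged, alpha


def linear_2x2(x: Vector, w: Vector) -> Vector:
    # toy module: 2x2 weight matrix stored row-major
    return [w[0] * x[0] + w[1] * x[1], w[2] * x[0] + w[3] * x[1]]


theta_pre = [0.0, 0.0, 0.0, 0.0]
theta_1 = [1.0, 0.0, 0.0, 0.0]  # task 1 edits only the first output row
theta_2 = [0.0, 0.0, 0.0, 1.0]  # task 2 edits only the second output row
held_out = [[[1.0, 0.0]], [[1.0, 1.0]]]  # one collected module input per task
merged, alpha = merge_module(linear_2x2, theta_pre, [theta_1, theta_2], held_out)
print(alpha, merged)  # [0.5, 1.0] [0.5, 0.0, 0.0, 1.0]
```

Task 1's delta is down-weighted to 0.5 because on task 2's input it adds an unwanted first output component, while task 2's delta never disturbs task 1's input and keeps weight 1.0.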
6. Experimental Validation and Core Findings
LATA has been evaluated across multiple model architectures (e.g., Llama-2-7B, Llama-2-13B, Gemma-2-9B, Llama-3-8B) and task domains (math: GSM8K, coding: HumanEval/CodeAlpaca, translation: zh–en, utility: WikiText-2, alignment: harmful-content removal, multilingual QA/NLI). Key findings include:
- Improvements in multi-task evaluation metrics, both absolute and relative, over global TA and other baselines. For Llama-2-13B, merging three models raises accuracy by roughly 2.6 to 2.8 points absolute (Task Arithmetic: ≈48.24; LATA layer-level: 50.80; attention/MLP-level: 51.05) (Dai et al., 15 Apr 2025).
- Quantitative improvements are seen in perplexity, downstream accuracy, and pass@1 rates across diverse tasks (e.g., Gemma-2-9B: TA perplexity 11.88 vs LATA 10.43; GSM8K acc: 0.824 → 0.843; HumanEval: 0.616 → 0.628) (Chen et al., 27 Feb 2025).
- Ablations show that only a small fraction (~10%) of layers, identified as task-specific, are critical for high task accuracy; discarding the updates in the remaining layers preserves or even improves performance.
- Subtractive LATA for task forgetting substantially reduces harm scores (German Llama-3: GPT-4 harm score falls from 3.60 to 2.57) with negligible loss of QA/NLI utility in other languages or tasks.
7. Limitations and Prospects
LATA assumes that all models share an identical architecture and consistent pre-training/fine-tuning regimes. Model-level LATA (no decomposition) fails because of full-model nonlinearity, while overly fine-grained merging at the attention-head level is unstable, since functional specialization prevents linear combination. The method adds computational cost from per-layer analysis and depends on correct calibration of layer selection, weight functions, and task coefficients. Its application to encoder models, heterogeneous architectures, and less standard pre-training remains untested. Research directions include automatic tuning of layerwise thresholds, continual task addition/removal, and further theoretical grounding in the linear subspace structure of task vs. instruction updates (Chen et al., 27 Feb 2025, Dai et al., 15 Apr 2025).
8. Significance and Future Directions
LATA demonstrates that the effective merging and editing of LLMs depends on understanding and leveraging the distribution of task information across layers and submodules. By isolating and selectively combining only the most informative, task-specific updates, LATA enables robust multi-task model construction, efficient task forgetting, and fine-grained behavior editing without incurring the utility degradation or spurious interference endemic to traditional global merging. Future extensions may include dynamic or automatic discovery of task-relevant modules, cross-family or cross-architecture task transfer, and integration with interpretability techniques to guide module-level editing. These advances point toward a principled, modular approach to scalable model reuse and continual learning in next-generation LLMs (Chen et al., 27 Feb 2025, Dai et al., 15 Apr 2025).