AutoMixAlign (AMA): Multi-Task LLM Alignment
- The paper presents AMA, which minimizes worst-case per-task excess loss by adaptively mixing data to align LLMs on objectives like helpfulness and coding ability.
- The methodology employs two variants, adaptive reweighting (AMA-R) and adaptive resampling (AMA-S), which use online learning techniques to achieve O(1/√T) convergence.
- Implications include improved performance over standard methods, reducing underperformance in critical tasks and offering robust, scalable solutions for multi-objective LLM alignment.
AutoMixAlign (AMA) is a theoretically grounded algorithmic framework for multi-task preference optimization that aligns LLMs through adaptive data mixing. It optimizes alignment across multiple objectives, such as helpfulness, harmlessness, and coding ability, by systematically minimizing the worst-case excess loss relative to task-specific specialist models. The framework has two main variants, adaptive reweighting (AMA-R) and adaptive resampling (AMA-S), both of which minimize the maximum per-task clipped excess loss using online learning techniques with provable convergence rates (Corrado et al., 31 May 2025).
1. Multi-Task Preference Alignment Problem Formulation
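Given $k$ preference-optimization tasks with datasets $D_1, \dots, D_k$, where each $D_i = \{(x, y_w, y_l)\}$ contains prompts $x$ with preferred responses $y_w$ and dispreferred responses $y_l$, the goal is to train a generalist LLM $\pi_\theta$ using Direct Preference Optimization (DPO) relative to a reference model $\pi_{\text{ref}}$. The DPO per-example loss is

$$\ell(\theta; x, y_w, y_l) = -\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right).$$

Specialist models $\theta_i$ are obtained by training independently on each $D_i$. For a generalist model $\theta$, the average loss on $D_i$ is $L_i(\theta) = \frac{1}{|D_i|} \sum_{(x, y_w, y_l) \in D_i} \ell(\theta; x, y_w, y_l)$, and the reference specialist loss is $L_i(\theta_i)$. The clipped excess loss is

$$e_i(\theta) = \max\{L_i(\theta) - L_i(\theta_i),\ 0\}.$$

AMA seeks $\theta^\star$ that minimizes the worst-case excess loss across all tasks:

$$\theta^\star = \arg\min_\theta \max_{i \in \{1, \dots, k\}} e_i(\theta).$$

Introducing $\alpha \in \Delta^k$ (the probability simplex over the $k$ tasks), this is equivalently

$$\min_\theta \max_{\alpha \in \Delta^k} \sum_{i=1}^{k} \alpha_i\, e_i(\theta).$$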
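To make these definitions concrete, here is a minimal Python sketch of the per-example DPO loss, the clipped excess loss, and the worst-case objective. It assumes per-example losses are already available as plain floats (for example, computed from summed sequence log-probabilities); the function names and the default `beta=0.1` are illustrative choices, not values taken from the source.

```python
import math

def neg_log_sigmoid(m):
    """Numerically stable -log(sigmoid(m))."""
    if m >= 0:
        return math.log1p(math.exp(-m))
    return -m + math.log1p(math.exp(m))

def dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """Per-example DPO loss from summed sequence log-probabilities.

    logp_w / logp_l: log pi_theta(y_w|x), log pi_theta(y_l|x)
    ref_logp_w / ref_logp_l: the same quantities under pi_ref
    beta: illustrative default, not a value from the source
    """
    margin = beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l))
    return neg_log_sigmoid(margin)

def clipped_excess(generalist_losses, specialist_losses):
    """e_i = max(L_i(theta) - L_i(theta_i), 0), each L_i an average over D_i."""
    l_gen = sum(generalist_losses) / len(generalist_losses)
    l_spec = sum(specialist_losses) / len(specialist_losses)
    return max(l_gen - l_spec, 0.0)

def worst_case_excess(tasks):
    """max_i e_i(theta); tasks is a list of (generalist_losses, specialist_losses)."""
    return max(clipped_excess(gen, spec) for gen, spec in tasks)
```

With identical policy and reference log-probabilities the margin is zero and the DPO loss is log 2; a task on which the generalist already beats its specialist contributes an excess of exactly 0 because of the clipping.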
2. AMA-R: Adaptive Reweighting via Minimax Optimization
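AMA-R casts the objective as a two-player minimax game. The $\alpha$-player (task weights) performs exponentiated-gradient (EG) ascent to emphasize tasks where the current generalist loss most exceeds the specialist baseline, while the $\theta$-player updates the parameters to minimize the weighted excess loss:

$$\min_\theta \max_{\alpha \in \Delta^k} \sum_{i=1}^{k} \alpha_i\, e_i(\theta).$$

Algorithmic steps per iteration $t$:
- $\alpha$-player: EG update $\alpha_i^{t+1} = \dfrac{\alpha_i^{t} \exp\big(\eta_\alpha\, e_i(\theta_t)\big)}{\sum_{j=1}^{k} \alpha_j^{t} \exp\big(\eta_\alpha\, e_j(\theta_t)\big)}$
- $\theta$-player: stochastic gradient descent on $\sum_{i=1}^{k} \tilde\alpha_i^{t+1} e_i(\theta)$, i.e. $\theta_{t+1} = \theta_t - \eta_\theta \nabla_\theta \sum_{i} \tilde\alpha_i^{t+1} e_i(\theta_t)$
In practice, the weights are smoothed: $\tilde\alpha = (1 - c)\,\alpha + \frac{c}{k}\mathbf{1}$, where $\alpha$ is the EG internal weight and $c \in [0, 1]$ is a smoothing parameter. Under convexity, the worst-case excess loss of the averaged iterate converges at rate $O(1/\sqrt{T})$, as implied by Sagawa et al. (2019) (Corrado et al., 31 May 2025).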
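The AMA-R loop can be illustrated on a hypothetical convex toy problem: a scalar parameter θ and two tasks with quadratic losses L_i(θ) = (θ − c_i)² + b_i, so specialist i attains loss b_i at θ = c_i. Full-batch gradients stand in for SGD, and the names (`ama_r`, `CENTERS`, `OFFSETS`) and step sizes are assumptions for illustration. The sketch shows the EG weight update, the smoothing toward uniform, the weighted descent step, and iterate averaging.

```python
import math

# Hypothetical convex toy: scalar theta, two tasks with
# L_i(theta) = (theta - CENTERS[i])**2 + OFFSETS[i].
# Specialist i sits at theta_i = CENTERS[i], so L_i(theta_i) = OFFSETS[i].
CENTERS = [-1.0, 3.0]
OFFSETS = [0.5, 0.2]

def task_loss(i, theta):
    return (theta - CENTERS[i]) ** 2 + OFFSETS[i]

def task_grad(i, theta):
    return 2.0 * (theta - CENTERS[i])

def ama_r(steps=5000, eta_theta=0.05, eta_alpha=0.05, c=0.1, theta=0.0):
    k = len(CENTERS)
    specialist = list(OFFSETS)      # precomputed specialist losses L_i(theta_i)
    alpha = [1.0 / k] * k           # EG internal weights
    alpha_s = alpha
    theta_sum = 0.0
    for _ in range(steps):
        excess = [max(task_loss(i, theta) - specialist[i], 0.0) for i in range(k)]
        # alpha-player: exponentiated-gradient ascent, then renormalize
        alpha = [a * math.exp(eta_alpha * e) for a, e in zip(alpha, excess)]
        total = sum(alpha)
        alpha = [a / total for a in alpha]
        # smooth toward uniform so no task weight drops below c / k
        alpha_s = [(1 - c) * a + c / k for a in alpha]
        # theta-player: gradient step on sum_i alpha_s[i] * e_i(theta);
        # a task whose excess is clipped to 0 contributes no gradient
        grad = sum(alpha_s[i] * task_grad(i, theta)
                   for i in range(k) if excess[i] > 0)
        theta -= eta_theta * grad
        theta_sum += theta
    return theta_sum / steps, alpha_s   # averaged iterate, final smoothed weights

theta_bar, weights = ama_r()
print(round(theta_bar, 2))  # 1.0
```

The averaged iterate settles at θ = 1, where both excess losses equal 4 and the smoothed weights balance at 0.5. Minimizing the worst *raw* loss would instead land at θ = 0.9625, because the specialist offsets b_i would also be balanced; measuring loss relative to the specialists removes that irreducible part.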
3. AMA-S: Adaptive Resampling via Bandit Algorithms
AMA-S adaptively adjusts the sampling distribution over tasks using the bandit algorithm EXP3, rather than reweighting objective components. At each iteration:
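- A minibatch of size $B$ is formed by first sampling task counts $(n_1, \dots, n_k) \sim \mathrm{Multinomial}(B, \tilde p)$, where $\tilde p = (1 - c)\,p + \frac{c}{k}\mathbf{1}$ is the smoothed distribution, and then drawing $n_i$ examples from $D_i$.
- The loss gradient is computed on the batch using the clipped excess loss.
- The internal distribution $p$ is updated via $p_i^{t+1} \propto p_i^{t} \exp\big(\eta_p\, \hat e_i^{\,t}\big)$, where $\hat e_i^{\,t}$ is the empirical average excess loss for task $i$ in the minibatch.
This process is a bandit-style solver for the minimax problem $\min_\theta \max_{p \in \Delta^k} \sum_{i=1}^{k} p_i\, e_i(\theta)$. Under convexity and boundedness assumptions, the worst-case per-task excess loss converges at rate $O(1/\sqrt{T})$:

$$\mathbb{E}\Big[\max_i e_i(\bar\theta_T)\Big] - \min_\theta \max_i e_i(\theta) \le \frac{R_T^{\theta} + R_T^{p}}{T},$$

where $R_T^{p}$ bounds the regret of the $p$-player and $R_T^{\theta}$ that of the $\theta$-player; both grow as $O(\sqrt{T})$.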
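The AMA-S loop can be sketched on the same hypothetical toy problem. Each step samples task counts from a multinomial over the smoothed distribution (via `random.choices`), takes a gradient step on the minibatch-average clipped excess loss, and applies the exponential-weights update to the internal distribution. Assumptions beyond the text, flagged here and in the comments: every example of a task has the same loss (so the batch average equals e_i(θ)), a task absent from the batch gets an excess estimate of 0, and no importance weighting is applied; the source's exact EXP3 estimator may differ.

```python
import math
import random

# Same hypothetical toy: L_i(theta) = (theta - CENTERS[i])**2 + OFFSETS[i].
CENTERS = [-1.0, 3.0]
OFFSETS = [0.5, 0.2]   # specialist losses L_i(theta_i)

def task_loss(i, theta):
    return (theta - CENTERS[i]) ** 2 + OFFSETS[i]

def task_grad(i, theta):
    return 2.0 * (theta - CENTERS[i])

def ama_s(steps=5000, batch=64, eta_theta=0.05, eta_p=0.05, c=0.1, seed=0):
    rng = random.Random(seed)
    k = len(CENTERS)
    p = [1.0 / k] * k                       # internal sampling distribution
    theta, theta_sum = 0.0, 0.0
    for _ in range(steps):
        p_s = [(1 - c) * q + c / k for q in p]      # smoothed distribution
        # task counts (n_1, ..., n_k) ~ Multinomial(batch, p_s)
        draws = rng.choices(range(k), weights=p_s, k=batch)
        counts = [draws.count(i) for i in range(k)]
        # empirical per-task excess: noiseless toy, so it equals e_i(theta)
        # for tasks present in the batch (absent tasks get 0 by assumption)
        excess = [max(task_loss(i, theta) - OFFSETS[i], 0.0) if counts[i] else 0.0
                  for i in range(k)]
        # theta step on the minibatch average of the clipped excess loss
        grad = sum(counts[i] * task_grad(i, theta)
                   for i in range(k) if excess[i] > 0) / batch
        theta -= eta_theta * grad
        theta_sum += theta
        # exponential-weights update of p from the observed excess losses
        p = [q * math.exp(eta_p * e) for q, e in zip(p, excess)]
        total = sum(p)
        p = [q / total for q in p]
    return theta_sum / steps, p

theta_bar, p = ama_s()
```

Unlike AMA-R, the task weights never multiply the loss directly: they enter the θ-gradient only through how many examples of each task the sampled batch contains, which matches the reweighted gradient in expectation.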
4. Algorithmic Procedures
Summary of the two core AMA variants:
| Variant | Adaptation Mechanism | Update Rule for Task Distribution |
|---|---|---|
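| AMA-R | Objective reweighting | EG: $\alpha_i \propto \alpha_i \exp\big(\eta_\alpha\, e_i(\theta_t)\big)$ |
| AMA-S | Data resampling | EXP3: $p_i \propto p_i \exp\big(\eta_p\, \hat e_i\big)$, with batches sampled from $\tilde p$ |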
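Both algorithms return the averaged model parameters $\bar\theta_T = \frac{1}{T}\sum_{t=1}^{T} \theta_t$. The smoothing parameter $c$ and the learning rates for the parameter and task-weight updates are hyperparameters for which the source gives empirically recommended settings.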
Practical guidelines include:
- Precomputing the specialist losses $\ell(\theta_i; z)$ for every example $z = (x, y_w, y_l) \in D_i$ and every task $i$
- Using clipped excess loss to prevent overfitting to easier tasks
- Regular checkpointing, model selection using confidence interval overlap in multitask accuracy
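The checkpoint-selection guideline does not spell out its statistical procedure. One simple reading, sketched here as an assumption, is to attach a normal-approximation (Wald) 95% interval to each checkpoint's accuracy and treat checkpoints whose intervals overlap as indistinguishable. The helper names are hypothetical.

```python
import math

def wald_interval(accuracy, n, z=1.96):
    """Normal-approximation (Wald) interval for an accuracy measured on n examples.

    z=1.96 gives roughly 95% coverage; the result is clipped to [0, 1].
    """
    half_width = z * math.sqrt(accuracy * (1.0 - accuracy) / n)
    return max(0.0, accuracy - half_width), min(1.0, accuracy + half_width)

def intervals_overlap(a, b):
    """True if closed intervals a = (lo, hi) and b = (lo, hi) intersect."""
    return a[0] <= b[1] and b[0] <= a[1]

# Hypothetical checkpoints: 54% vs. 50% multitask accuracy on 500 eval examples.
print(intervals_overlap(wald_interval(0.54, 500), wald_interval(0.50, 500)))  # True
```

With 5,000 evaluation examples the same 4-point gap no longer overlaps, so whether two checkpoints count as distinguishable depends on the evaluation-set size as well as on the accuracy difference.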
5. Experimental Setup and Empirical Results
Experiments use Zephyr-7B SFT Full as the base LLM, with DPO and AMA used for generalist tuning over 1–3 epochs (batch size 256, gradient accumulation 4, per-device batch size 8, AdamW with the remaining optimizer settings as reported in the source, 10% warmup). Specialist models are trained for each task with identical hyperparameters.
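The batch settings fit the usual data-parallel relation between global batch size, per-device batch size, gradient-accumulation steps, and device count. Assuming that relation holds here, the implied device count works out as follows; it is an inference from the reported numbers, not a reported figure.

$$
B_{\text{global}} = b_{\text{device}} \times g_{\text{accum}} \times N_{\text{devices}}
\;\Longrightarrow\;
N_{\text{devices}} = \frac{256}{8 \times 4} = 8
$$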
Task domains:
- Helpfulness: Chatbot Arena 2024 (LLM-as-judge), UltraFeedback
- Coding: CodeUltraFeedback, MBPP, HumanEval
- Harmlessness: SafeRLHF, Toxigen
Baselines:
- Standard: Uniform sampling over all data, minimize total loss
- Standard-Uniform: Uniform task-level sampling
- Model Averaging: Uniform parameter-wise average of all specialists
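The Model Averaging baseline is simple enough to sketch directly. This version assumes checkpoints are plain dicts mapping parameter names to flat float lists (real checkpoints would hold framework tensors) and that all specialists share one architecture, which parameter-wise averaging requires; `average_models` is a hypothetical name.

```python
def average_models(state_dicts):
    """Uniform parameter-wise average of k specialist checkpoints.

    Each checkpoint maps parameter names to flat lists of floats (a stand-in
    for framework tensors); all checkpoints must share names and shapes.
    """
    names = list(state_dicts[0])
    k = len(state_dicts)
    for sd in state_dicts:
        if set(sd) != set(names):
            raise ValueError("checkpoints have different parameter names")
    averaged = {}
    for name in names:
        tensors = [sd[name] for sd in state_dicts]
        if len({len(t) for t in tensors}) != 1:
            raise ValueError(f"shape mismatch for parameter {name!r}")
        # zip(*tensors) yields, for each position, the k specialist values
        averaged[name] = [sum(vals) / k for vals in zip(*tensors)]
    return averaged

specialists = [{"w": [1.0, 2.0], "b": [0.0]}, {"w": [3.0, 4.0], "b": [1.0]}]
print(average_models(specialists))  # {'w': [2.0, 3.0], 'b': [0.5]}
```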
Key empirical findings (average metric across the tasks in each setup, %):
| Setup | Standard | Standard-Uniform | Model Averaging | AMA-R | AMA-S |
|---|---|---|---|---|---|
| Helpfulness (Arena) + Coding (CodeUF) | 35.14% | 35.47% | 39.43% | 41.75% | 40.19% |
| Helpfulness (UltraFB) + Harmlessness | 49.37% | 53.21% | 50.51% | 53.50% | 53.81% |
| Help + Code + Harmlessness | 50.95% | 44.96% | 51.50% | 54.38% | 53.18% |
AMA improves the average metric by up to 9.42 percentage points over standard training (AMA-R vs. Standard-Uniform in the three-task setup: 54.38% vs. 44.96%). Figure X of the source shows that AMA-S rapidly shifts sampling toward the task with the larger excess loss and then stabilizes its task allocation as the losses converge, suggesting that AMA-S tracks and mitigates task-level underperformance during multitask alignment (Corrado et al., 31 May 2025).
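The headline gain can be checked against the results table. This snippet copies the table values from the source (with abbreviated setup labels) and computes, in percentage points, how far the better AMA variant lies above each baseline; `ama_gains` is a hypothetical helper name.

```python
# Average-metric results (%) copied from the table; labels abbreviated:
# "Help + Code" = Helpfulness (Arena) + Coding (CodeUF), and so on.
RESULTS = {
    "Help + Code": {"Standard": 35.14, "Standard-Uniform": 35.47,
                    "Model Averaging": 39.43, "AMA-R": 41.75, "AMA-S": 40.19},
    "Help + Harm": {"Standard": 49.37, "Standard-Uniform": 53.21,
                    "Model Averaging": 50.51, "AMA-R": 53.50, "AMA-S": 53.81},
    "Help + Code + Harm": {"Standard": 50.95, "Standard-Uniform": 44.96,
                           "Model Averaging": 51.50, "AMA-R": 54.38, "AMA-S": 53.18},
}
BASELINES = ("Standard", "Standard-Uniform", "Model Averaging")

def ama_gains(results):
    """Percentage-point gain of the better AMA variant over each baseline."""
    return {
        setup: {b: round(max(row["AMA-R"], row["AMA-S"]) - row[b], 2) for b in BASELINES}
        for setup, row in results.items()
    }

gains = ama_gains(RESULTS)
# Largest gain over the two standard-training baselines
best = max((g[b], setup, b) for setup, g in gains.items()
           for b in ("Standard", "Standard-Uniform"))
print(best)  # (9.42, 'Help + Code + Harm', 'Standard-Uniform')
```

Against plain Standard training the largest gain is 6.61 points (Helpfulness + Coding), so the 9.42 figure specifically reflects the weak Standard-Uniform result in the three-task setup.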
6. Theoretical Properties and Extensions
AMA variants offer $O(1/\sqrt{T})$ convergence under convexity and bounded losses, and are robust to the number of tasks $k$. Specialist training is parallelizable, and total compute is approximately double that of standard DPO, independent of $k$.
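The compute claim follows from a short accounting, assuming disjoint task datasets with $|D| = \sum_i |D_i|$, specialists trained for the same number of epochs $E$ as the generalist, a per-example training cost $\kappa$, and ignoring the one-time forward pass that precomputes specialist losses:

$$
C_{\text{specialists}} = \sum_{i=1}^{k} E\,\kappa\,|D_i| = E\,\kappa\,|D| = C_{\text{DPO}},
\qquad
C_{\text{AMA}} \approx C_{\text{specialists}} + C_{\text{DPO}} = 2\,C_{\text{DPO}}.
$$

Because the $k$ specialist terms sum to the size of the full dataset, splitting the data into more tasks only changes how the same examples are divided among specialists, not the total cost.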
The AMA recipe extends directly to any DPO-style training and could potentially be applied to RLHF methods such as PPO, or to supervised fine-tuning, by substituting the appropriate loss function. Using the unclipped excess loss is discouraged because it leads to overfitting on trivial tasks. Regular checkpointing, smoothing, and careful model selection using confidence intervals are recommended for practical deployment.
7. Broader Implications and Future Work
AMA addresses the critical challenge of data mixture selection in LLM alignment by algorithmically optimizing mixture weights or sampling distributions to directly target the minimax excess loss, sidestepping reliance on large-scale ablations or subjective heuristics. A plausible implication is that this approach could generalize to related problems in safe or robust multi-objective optimization. Future work may explore extensions to non-convex LLM loss landscapes, broader RLHF settings, or adaptive mixture optimization in large-scale distributed fine-tuning (Corrado et al., 31 May 2025).