
AutoMixAlign (AMA): Multi-Task LLM Alignment

Updated 22 May 2026
  • The paper presents AMA, which minimizes worst-case per-task excess loss by adaptively mixing data to align LLMs on objectives like helpfulness and coding ability.
  • The methodology employs two variants—adaptive reweighting (AMA-R) and adaptive resampling (AMA-S)—using online learning techniques to achieve O(1/√T) convergence.
  • Implications include improved performance over standard methods, reducing underperformance in critical tasks and offering robust, scalable solutions for multi-objective LLM alignment.

AutoMixAlign (AMA) is a theoretically-grounded algorithmic framework for multi-task preference optimization in aligning LLMs by adaptive data mixing. AMA introduces task-adaptive training workflows for optimizing LLM alignment across multiple objectives, such as helpfulness, harmlessness, and coding ability, by systematically minimizing worst-case excess loss relative to task-specific specialist models. The framework consists of two main variants: adaptive reweighting (AMA-R) and adaptive resampling (AMA-S), both capable of minimizing the maximum per-task clipped excess loss using online learning techniques with provable convergence rates (Corrado et al., 31 May 2025).

1. Multi-Task Preference Alignment Problem Formulation

Given $k$ preference-optimization tasks with corresponding datasets $D_1, \dots, D_k$, where each $D_i = \{z = (x, y_+, y_-)\}$, the goal is to train a generalist LLM $\pi_\theta$ using Direct Preference Optimization (DPO) against a reference model $\pi_\mathrm{ref}$. The DPO per-example loss is defined as:

$$\ell(\theta; z) = -\log \sigma\left(\beta \left(\log\frac{\pi_\theta(y_+|x)}{\pi_\mathrm{ref}(y_+|x)} - \log\frac{\pi_\theta(y_-|x)}{\pi_\mathrm{ref}(y_-|x)}\right)\right)$$

Specialist models $\theta^*_i$ are obtained by training independently on each $D_i$. For a generalist model $\theta$, the average loss on $D_i$ is $L_i(\theta) = \frac{1}{|D_i|}\sum_{z \in D_i} \ell(\theta; z)$, and the reference specialist loss is $L_i(\theta^*_i)$. The clipped excess loss is:

$$\mathcal{E}_i(\theta) = \max\{L_i(\theta) - L_i(\theta^*_i),\ 0\}$$

AMA seeks $\theta$ that minimizes the worst-case excess loss across all tasks:

$$\min_\theta \max_{i \in \{1, \dots, k\}} \mathcal{E}_i(\theta)$$

Introducing $\alpha \in \Delta^k$ (probability simplex), this is equivalently

$$\min_\theta \max_{\alpha \in \Delta^k} \sum_{i=1}^{k} \alpha_i\, \mathcal{E}_i(\theta)$$
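The quantities in this formulation can be made concrete with a small numerical sketch. The code below is illustrative only: the function names, the value of `beta`, and the per-task loss numbers are assumptions made for the example and are not taken from the paper. It computes a DPO per-example loss from log-probabilities, the clipped excess loss per task, and checks that the maximum over the simplex coincides with the maximum over tasks.

```python
import numpy as np

def dpo_loss(logp_pos, logp_neg, ref_logp_pos, ref_logp_neg, beta=0.1):
    """Per-example DPO loss from sequence log-probabilities (beta is illustrative)."""
    margin = beta * ((logp_pos - ref_logp_pos) - (logp_neg - ref_logp_neg))
    # -log(sigmoid(m)) equals log(1 + exp(-m)); logaddexp evaluates it stably.
    return np.logaddexp(0.0, -margin)

def clipped_excess(generalist_losses, specialist_losses):
    """Clipped excess loss max(L_i(theta) - L_i(theta_i*), 0) for each task."""
    gap = (np.asarray(generalist_losses, dtype=float)
           - np.asarray(specialist_losses, dtype=float))
    return np.maximum(gap, 0.0)

# Made-up average losses for k = 3 tasks (not results from the paper).
generalist = [0.60, 0.45, 0.30]
specialist = [0.50, 0.50, 0.25]

excess = clipped_excess(generalist, specialist)
worst_case = excess.max()  # max over tasks

# A linear function over the simplex is maximized at a vertex, so putting
# all weight on the worst task recovers the same value.
alpha_star = np.eye(len(excess))[excess.argmax()]
simplex_value = alpha_star @ excess

print(excess, worst_case, simplex_value)
```

In this toy example the second task has a generalist loss below its specialist loss, so clipping sets its excess to zero and it cannot attract weight away from the tasks that are actually lagging.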

2. AMA-R: Adaptive Reweighting via Minimax Optimization

AMA-R casts the objective as a two-player minimax game. The $\alpha$-player (task weights) performs exponentiated-gradient (EG) ascent to emphasize tasks where current generalist losses most exceed the specialist baseline, while the $\theta$-player updates parameters to minimize the weighted excess loss $\sum_{i=1}^{k} \alpha_i\, \mathcal{E}_i(\theta)$, where $\mathcal{E}_i(\theta)$ is the clipped excess loss of task $i$. Algorithmic steps per iteration:

  • $\alpha$-player: EG update $\alpha_i^{t+1} \propto \alpha_i^{t} \exp\left(\eta_\alpha\, \mathcal{E}_i(\theta^{t})\right)$
  • $\theta$-player: Stochastic gradient descent on $\sum_{i=1}^{k} \alpha_i^{t+1}\, \mathcal{E}_i(\theta)$

In practice, $\alpha$ is smoothed: $\alpha_i = (1 - c)\, q_i + c/k$, where $q$ is the EG internal weight and $c$ is a smoothing parameter. Under convexity, AMA-R converges at rate $O(1/\sqrt{T})$, as implied by Sagawa et al. (2019) (Corrado et al., 31 May 2025).
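One iteration of the task-weight update can be sketched as follows. This is a minimal illustration, not the authors' implementation: the function name `eg_step`, the step size `eta_alpha`, the smoothing value `c`, and the assumption that smoothing mixes the internal weights with the uniform distribution are all choices made for this example.

```python
import numpy as np

def eg_step(q, excess, eta_alpha=1.0, c=0.1):
    """One exponentiated-gradient ascent step on the internal task weights q.

    Returns the new internal weights and the smoothed weights alpha that
    would multiply the per-task excess losses in the model update.
    Assumption: smoothing mixes q with the uniform distribution.
    """
    q = np.asarray(q, dtype=float)
    excess = np.asarray(excess, dtype=float)
    # Multiplicative update in log space, shifted for numerical stability.
    logits = np.log(q) + eta_alpha * excess
    logits -= logits.max()
    q_new = np.exp(logits)
    q_new /= q_new.sum()
    # Smoothing keeps every task weight at least c / k.
    alpha = (1.0 - c) * q_new + c / q_new.size
    return q_new, alpha

q0 = np.full(3, 1.0 / 3.0)            # start from uniform weights
excess = np.array([0.10, 0.00, 0.05])  # made-up clipped excess losses
q1, alpha1 = eg_step(q0, excess)

# The model update would then take a gradient step on this weighted loss.
weighted_loss = alpha1 @ excess
print(q1, alpha1, weighted_loss)
```

Tasks with larger excess loss gain weight multiplicatively, while the smoothing term prevents any task from being dropped entirely.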

3. AMA-S: Adaptive Resampling via Bandit Algorithms

AMA-S adaptively adjusts the sampling distribution over tasks using the bandit algorithm EXP3, rather than reweighting objective components. At each iteration:

  • A minibatch is formed by first sampling task counts $(n_1, \dots, n_k) \sim \mathrm{Multinomial}(b, \alpha)$, where $b$ is the batch size and $\alpha$ is the smoothed distribution.
  • The loss gradient is computed on the batch using the clipped excess loss.
  • The internal distribution $q$ is updated via an EXP3 exponential-weights step driven by $\hat{\mathcal{E}}_i$, the empirical average excess loss for task $i$ in the minibatch.

This process is a bandit-style solver for the minimax problem $\min_\theta \max_{\alpha \in \Delta^k} \sum_{i=1}^{k} \alpha_i\, \mathcal{E}_i(\theta)$. With convex loss and boundedness assumptions, $O(1/\sqrt{T})$ convergence of the worst-case per-task excess loss is guaranteed, with a bound that includes a term controlling the $\alpha$-player's regret.
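The resampling loop can be sketched as below. This is a simplified illustration under stated assumptions rather than the algorithm as published: the helper names, the step size `eta`, the smoothing value `c`, the uniform-mixing form of the smoothing, and the plain exponential-weights update are all assumptions, and the estimator in the paper may differ (for example, by importance-weighting the excess-loss estimates by the sampling probabilities, as EXP3 typically does).

```python
import numpy as np

def smooth(q, c):
    """Mix internal weights with the uniform distribution (assumed form)."""
    return (1.0 - c) * q + c / q.size

def sample_task_counts(alpha, batch_size, rng):
    """Draw how many examples each task contributes to the minibatch."""
    return rng.multinomial(batch_size, alpha)

def exp_weights_step(q, excess_hat, eta):
    """Exponential-weights update of the internal distribution q."""
    logits = np.log(q) + eta * excess_hat
    logits -= logits.max()  # shift for numerical stability
    q_new = np.exp(logits)
    return q_new / q_new.sum()

rng = np.random.default_rng(0)
q = np.full(3, 1.0 / 3.0)
alpha = smooth(q, c=0.1)
counts = sample_task_counts(alpha, batch_size=32, rng=rng)

# Hypothetical minibatch estimates of per-task excess loss; a task with no
# examples in the batch contributes no signal in this sketch.
excess_hat = np.where(counts > 0, np.array([0.20, 0.00, 0.05]), 0.0)
q = exp_weights_step(q, excess_hat, eta=1.0)
print(counts, q)
```

Unlike reweighting, the adaptation here changes which examples enter the batch, so the gradient step itself can use an unweighted batch loss.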

4. Algorithmic Procedures

Summary of the two core AMA variants:

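Variant | Adaptation Mechanism | Update Rule for Task Distribution
AMA-R | Objective reweighting | EG: $\alpha_i^{t+1} \propto \alpha_i^{t} \exp\left(\eta_\alpha\, \mathcal{E}_i(\theta^{t})\right)$
AMA-S | Data resampling | EXP3: exponential-weights update on minibatch excess-loss estimates $\hat{\mathcal{E}}_i$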

Both algorithms return the average model parameters $\bar{\theta} = \frac{1}{T}\sum_{t=1}^{T} \theta^{t}$. Smoothing and hyperparameter settings, such as the smoothing parameter $c$ and the learning rates for the model and task-weight updates, follow the empirical recommendations of Corrado et al.
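If "average model parameters" is read as the uniform average of the iterates over training steps (an assumption made here), it can be maintained incrementally without storing every iterate. The class below is a hypothetical helper written for illustration, with parameters represented as a flat NumPy vector.

```python
import numpy as np

class RunningAverage:
    """Incrementally tracks the uniform mean of the parameter iterates."""

    def __init__(self):
        self.mean = None
        self.count = 0

    def update(self, theta):
        theta = np.asarray(theta, dtype=float)
        self.count += 1
        if self.mean is None:
            self.mean = theta.copy()
        else:
            # mean_t = mean_{t-1} + (theta_t - mean_{t-1}) / t
            self.mean += (theta - self.mean) / self.count
        return self.mean

avg = RunningAverage()
for theta_t in ([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]):  # toy iterates
    avg.update(theta_t)
print(avg.mean)
```

For a 7B-parameter model the same recurrence would be applied tensor by tensor, at the cost of one extra copy of the weights.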

Practical guidelines include:

  • Precomputing specialist losses $L_i(\theta^*_i)$ for all tasks $i = 1, \dots, k$
  • Using clipped excess loss to prevent overfitting to easier tasks
  • Regular checkpointing, with model selection based on confidence-interval overlap in multitask accuracy

5. Experimental Setup and Empirical Results

Experiments use Zephyr-7B SFT Full as the base LLM, with DPO and AMA for generalist tuning over 1–3 epochs (batch size 256, gradient accumulation 4, per-device batch 8, AdamW, 10% warmup; the remaining hyperparameter values are reported in Corrado et al.). Specialist models are trained for each task with identical hyperparameters.

Task domains:

  • Helpfulness: Chatbot Arena 2024 (LLM-as-judge), UltraFeedback
  • Coding: CodeUltraFeedback, MBPP, HumanEval
  • Harmlessness: SafeRLHF, Toxigen

Baselines:

  • Standard: Uniform sampling over all data, minimize total loss
  • Standard-Uniform: Uniform task-level sampling
  • Model Averaging: Uniform parameter-wise average of all specialists
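The Model Averaging baseline can be illustrated with a short sketch. The function name and the representation of each specialist as a dictionary of NumPy arrays are assumptions made for this example; real checkpoints would be framework-specific state dictionaries with identical keys and shapes.

```python
import numpy as np

def average_specialists(state_dicts):
    """Uniform parameter-wise average of specialist checkpoints.

    Assumes every checkpoint has the same parameter names and shapes.
    """
    names = state_dicts[0].keys()
    return {
        name: np.mean([sd[name] for sd in state_dicts], axis=0)
        for name in names
    }

# Two toy "specialists" with a weight vector and a bias term.
specialist_a = {"w": np.array([1.0, 2.0]), "b": np.array([0.0])}
specialist_b = {"w": np.array([3.0, 4.0]), "b": np.array([2.0])}

merged = average_specialists([specialist_a, specialist_b])
print(merged)
```

This baseline requires no additional training beyond the specialists, which is why it serves as a natural comparison point for AMA.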

Key empirical findings:

Setup | Standard | Standard-Uniform | Model Averaging | AMA-R | AMA-S
Helpfulness (Arena) + Coding (CodeUF) | 35.14% | 35.47% | 39.43% | 41.75% | 40.19%
Helpfulness (UltraFB) + Harmlessness | 49.37% | 53.21% | 50.51% | 53.50% | 53.81%
Help + Code + Harmlessness | 50.95% | 44.96% | 51.50% | 54.38% | 53.18%

AMA improves the average metric by up to 9.42 percentage points over the standard training baselines; the largest gap is in the three-task setup, where AMA-R reaches 54.38% against 44.96% for Standard-Uniform. Notably, Figure X from the source demonstrates that AMA-S rapidly rebalances sampling toward the task with greater excess loss, then stabilizes task allocation as losses converge. This suggests that AMA-S effectively tracks and mitigates task-level underperformance during multitask alignment (Corrado et al., 31 May 2025).
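The headline improvement can be checked directly against the reported table. The snippet below copies the table values and computes, for each setup, the gap between the better AMA variant and each baseline; the variable names are illustrative.

```python
results = {
    "Helpfulness (Arena) + Coding (CodeUF)": {
        "Standard": 35.14, "Standard-Uniform": 35.47,
        "Model Averaging": 39.43, "AMA-R": 41.75, "AMA-S": 40.19,
    },
    "Helpfulness (UltraFB) + Harmlessness": {
        "Standard": 49.37, "Standard-Uniform": 53.21,
        "Model Averaging": 50.51, "AMA-R": 53.50, "AMA-S": 53.81,
    },
    "Help + Code + Harmlessness": {
        "Standard": 50.95, "Standard-Uniform": 44.96,
        "Model Averaging": 51.50, "AMA-R": 54.38, "AMA-S": 53.18,
    },
}
baselines = ("Standard", "Standard-Uniform", "Model Averaging")

# Gap (in percentage points) between the better AMA variant and each baseline.
gains = {}
for setup, row in results.items():
    best_ama = max(row["AMA-R"], row["AMA-S"])
    for name in baselines:
        gains[(setup, name)] = round(best_ama - row[name], 2)

largest = max(gains, key=gains.get)
print(largest, gains[largest])
```

The largest gap is 9.42 points, against Standard-Uniform in the three-task setup. Against the Standard baseline the gains are 6.61, 4.44, and 3.43 points across the three setups, and the better AMA variant leads every baseline in every row.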

6. Theoretical Properties and Extensions

AMA variants offer $O(1/\sqrt{T})$ convergence under convexity and bounded loss, and are robust to the number of tasks $k$. Specialist training is parallelizable, and total compute is approximately double the standard DPO regime, independent of $k$.

AMA methodologies extend directly to any DPO-style training, and are potentially applicable to RLHF, PPO, or supervised fine-tuning by substituting appropriate loss functions. Use of unclipped excess loss is discouraged as it leads to overfitting trivial tasks. Regular checkpointing, smoothing, and careful model selection using confidence intervals are recommended for practical deployment.

7. Broader Implications and Future Work

AMA addresses the critical challenge of data mixture selection in LLM alignment by algorithmically optimizing mixture weights or sampling distributions to directly target the minimax excess loss, sidestepping reliance on large-scale ablation or subjective heuristics. A plausible implication is that this approach could generalize to related domains in safe or robust multi-objective optimization. Future work may further explore extensions to non-convex LLM landscapes, broader RLHF settings, or adaptive mixture optimization in large-scale distributed fine-tuning (Corrado et al., 31 May 2025).

References (1)
