Bayesian Multi-Task Prompt Tuning
- BMTPT is a Bayesian framework for prompt tuning that models full posterior distributions over source prompts, capturing both positive and negative inter-task correlations.
- It employs Stein Variational Gradient Descent to approximate the posterior with an ensemble of prompt vectors, guiding target adaptation through a MAP objective.
- Empirical results on T5-base show BMTPT’s superior transfer efficacy and parameter efficiency compared to traditional prompt-tuning methods in few-shot settings.
Bayesian Multi-Task Prompt Tuning (BMTPT) is a Bayesian framework for prompt-based adaptation of large pre-trained language models (PLMs), designed to enable parameter-efficient transfer to new tasks while systematically accounting for both positive and negative correlations among multiple source tasks. It differs from prior approaches by modeling the full posterior distribution over source prompts and using this distribution to regularize and initialize target prompt adaptation, thereby improving transfer efficacy and stability while maintaining strict parameter budgets (Lee et al., 2024).
1. Background and Motivations
Prompt tuning (PT) replaces standard fine-tuning of PLMs with the optimization of a compact “soft prompt”: a matrix $\theta \in \mathbb{R}^{l \times d}$, where $l$ is the prompt length and $d$ the embedding dimension. The prompt is prepended to each model input while the underlying model remains frozen. PT matches or nearly matches the performance of full fine-tuning when using very large models, with only a small fraction of a percent of the parameters updated, on the order of $0.01\%$ (Lee et al., 2024).
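As a minimal sketch of the prepending step, the snippet below concatenates a soft prompt with a sequence of input embeddings using plain Python lists. The function name `prepend_soft_prompt` and the toy sizes are illustrative assumptions, not taken from the paper; in a real system the rows would be tensors and only the prompt rows would receive gradients.

```python
from typing import List

Vector = List[float]


def prepend_soft_prompt(prompt: List[Vector],
                        input_embeddings: List[Vector]) -> List[Vector]:
    """Return the sequence [prompt rows; input rows] seen by the frozen model."""
    d = len(prompt[0])
    if any(len(row) != d for row in prompt + input_embeddings):
        raise ValueError("all rows must share the embedding dimension d")
    # Copy rows so the caller's prompt and inputs are not aliased.
    return [row[:] for row in prompt] + [row[:] for row in input_embeddings]


# Toy sizes (illustrative only): l = 2 prompt rows, d = 3, four input tokens.
prompt = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
tokens = [[1.0, 0.0, 0.0] for _ in range(4)]
extended = prepend_soft_prompt(prompt, tokens)
print(len(extended), len(extended[0]))
```

The extended sequence has $l$ extra rows, and the prompt rows are the only trainable quantities.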
Multi-task variants of PT have been introduced to further enhance generalization and transfer. Notable examples include:
- SPoT: Selects the source prompt closest in task embedding space to initialize a target prompt.
- ATTEMPT: Learns an attention-weighted mixture of all source task prompts.
- MPT: Factorizes source task prompts into full-rank shared and low-rank task-specific components.
However, these prior approaches train source prompts independently and aggregate them using heuristics that fail to account for inter-task correlations. They ignore that some source tasks may reinforce each other while others interfere destructively. BMTPT addresses this limitation by modeling the full posterior distribution over source prompts, enabling explicit handling of both positive and negative transfer.
2. Bayesian Formulation for Multi-Task Transfer
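Formally, let $D_1, \dots, D_K$ denote the datasets for $K$ source tasks and $\theta$ the flattened soft prompt. Under an (uninformative) prior $p(\theta)$, and assuming conditional independence among tasks, the posterior over $\theta$ is:
$$p(\theta \mid D_1, \dots, D_K) \propto p(\theta) \prod_{k=1}^{K} p(D_k \mid \theta)$$
where each likelihood factor is
$$p(D_k \mid \theta) = \prod_{(x, y) \in D_k} p(y \mid x, \theta)$$
with $p(y \mid x, \theta)$ the PLM output probability for $y$ given the prompt and input $x$. During transfer to a new target task with dataset $D_T$, this posterior serves as the prior in a MAP objective:
$$\theta_T^{*} = \arg\max_{\theta_T} \left[ \log p(D_T \mid \theta_T) + \log p(\theta_T \mid D_1, \dots, D_K) \right]$$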
This formalization allows the transfer procedure to explicitly integrate prior information from correlated source tasks, rather than treating them as independent or aggregating by simple mean or attention mechanisms.
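To make the connection to ordinary training losses explicit, the log-posterior can be expanded as follows. This is a worked restatement under the stated assumptions (flat prior, conditionally independent tasks), writing $\theta$ for the flattened prompt, $D_k$ for the $k$-th of $K$ source datasets, and introducing $\mathcal{L}_k$ as shorthand for the summed cross-entropy loss on task $k$ (a symbol chosen here for illustration).

$$\log p(\theta \mid D_1, \dots, D_K) = \log p(\theta) + \sum_{k=1}^{K} \sum_{(x, y) \in D_k} \log p(y \mid x, \theta) + \text{const}$$

$$\mathcal{L}_k(\theta) = -\sum_{(x, y) \in D_k} \log p(y \mid x, \theta) \quad\Longrightarrow\quad \nabla_\theta \log p(\theta \mid D_1, \dots, D_K) = -\nabla_\theta \sum_{k=1}^{K} \mathcal{L}_k(\theta)$$

With a flat prior the $\log p(\theta)$ term contributes no gradient, so the posterior score is simply the negative gradient of the multi-task cross-entropy loss.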
3. Stein Variational Gradient Descent and Posterior Approximation
Since directly sampling from the posterior $p(\theta \mid D_1, \dots, D_K)$ is intractable, BMTPT employs Stein Variational Gradient Descent (SVGD) to transport an ensemble of $M$ “particles” (candidate prompt vectors) toward the posterior, capturing diverse modes and inter-task dependencies. At step $t$, the $i$-th prompt particle is updated:
$$\theta_i^{t+1} = \theta_i^{t} + \epsilon_t \, \phi(\theta_i^{t})$$
with the SVGD functional
$$\phi(\theta) = \frac{1}{M} \sum_{j=1}^{M} \left[ k(\theta_j^{t}, \theta) \, \nabla_{\theta_j^{t}} \log p(\theta_j^{t} \mid D_1, \dots, D_K) + \nabla_{\theta_j^{t}} k(\theta_j^{t}, \theta) \right]$$
using a positive-definite kernel $k(\cdot, \cdot)$ (commonly RBF, bandwidth selected by the median heuristic) and step size $\epsilon_t$. The first term pulls particles toward high posterior density; the second induces repulsion for sample diversity and combats mode collapse. In practice, gradients reduce to those of summed cross-entropy losses. BMTPT also adopts “damped SVGD”, scaling down self-gradients to further discourage collapse.
SVGD produces a population $\{\theta_i\}_{i=1}^{M}$ of prompt samples approximating the source prompt posterior.
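The update can be illustrated on a toy problem. The sketch below runs plain SVGD (without the damping variant) on scalar particles with a one-dimensional Gaussian target standing in for the prompt posterior. The target mean `MU`, the fixed bandwidth `h`, the step size, and all function names are assumptions made for illustration; in BMTPT the score would come from the PLM's cross-entropy losses and the particles would be full prompt vectors.

```python
import math
from typing import Callable, List


def rbf(a: float, b: float, h: float) -> float:
    """RBF kernel k(a, b) = exp(-(a - b)^2 / h)."""
    return math.exp(-((a - b) ** 2) / h)


def svgd_step(particles: List[float],
              grad_log_p: Callable[[float], float],
              step_size: float,
              h: float) -> List[float]:
    """One synchronous SVGD update of every particle."""
    m = len(particles)
    updated = []
    for xi in particles:
        phi = 0.0
        for xj in particles:
            k = rbf(xj, xi, h)
            attract = k * grad_log_p(xj)           # pulls toward high density
            repulse = 2.0 * (xi - xj) / h * k      # d/dxj k(xj, xi): pushes xi away from xj
            phi += attract + repulse
        updated.append(xi + step_size * phi / m)
    return updated


MU = 3.0  # hypothetical posterior mean of the toy target N(MU, 1)


def grad_log_p(x: float) -> float:
    """Score of N(MU, 1)."""
    return -(x - MU)


particles = [-1.0, -0.5, 0.0, 0.5, 1.0]
for _ in range(500):
    particles = svgd_step(particles, grad_log_p, step_size=0.1, h=1.0)

mean = sum(particles) / len(particles)
spread = max(particles) - min(particles)
print(mean, spread)
```

The particles drift toward the target mean while the repulsive term keeps them from collapsing onto a single point, which is the behavior the two terms of the functional are meant to produce.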
4. Prompt Aggregation and Target Task Adaptation
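Upon SVGD convergence, the $M$ source prompt particles $\{\theta_i\}_{i=1}^{M}$ are treated as approximate samples from $p(\theta \mid D_1, \dots, D_K)$. For adaptation to a target task, initialization and regularization are constructed as follows:
- The aggregate mean $\bar{\theta}$ is computed:
$$\bar{\theta} = \frac{1}{M} \sum_{i=1}^{M} \theta_i$$
- Target prompt $\theta_T$ is initialized to $\bar{\theta}$.
- The MAP objective for $\theta_T$ uses a quadratic penalty, assuming a Gaussian mixture prior centered on the particles (with isotropic component variance $\sigma^2$):
$$\theta_T^{*} = \arg\min_{\theta_T} \left[ -\log p(D_T \mid \theta_T) + \frac{1}{2 \sigma^2 M} \sum_{i=1}^{M} \lVert \theta_T - \theta_i \rVert_2^2 \right]$$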
This ensures the target prompt remains close to the learned source prompt distribution, enforcing an inductive bias towards previously transferable knowledge while adapting efficiently to the target.
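A minimal sketch of this adaptation step follows, assuming the penalty takes the form $\frac{1}{2\sigma^2 M}\sum_i \lVert \theta_T - \theta_i \rVert_2^2$ over the $M$ particles $\theta_i$, and replacing the PLM's negative log-likelihood with a toy quadratic loss. The names `adapt`, `penalty_grad`, `sigma2`, and `TARGET_OPT` are hypothetical and chosen only for illustration.

```python
from typing import Callable, List

Vector = List[float]


def particle_mean(particles: List[Vector]) -> Vector:
    """Coordinate-wise mean of the source particles."""
    m = len(particles)
    return [sum(p[d] for p in particles) / m for d in range(len(particles[0]))]


def penalty_grad(theta: Vector, particles: List[Vector], sigma2: float) -> Vector:
    """Gradient of (1 / (2 * sigma2 * M)) * sum_i ||theta - theta_i||^2."""
    m = len(particles)
    return [sum(theta[d] - p[d] for p in particles) / (sigma2 * m)
            for d in range(len(theta))]


def adapt(particles: List[Vector],
          nll_grad: Callable[[Vector], Vector],
          sigma2: float = 1.0,
          lr: float = 0.1,
          steps: int = 500) -> Vector:
    """Gradient descent on task loss plus quadratic penalty, started at the mean."""
    theta = particle_mean(particles)
    for _ in range(steps):
        g_task = nll_grad(theta)
        g_prior = penalty_grad(theta, particles, sigma2)
        theta = [t - lr * (a + b) for t, a, b in zip(theta, g_task, g_prior)]
    return theta


# Toy setup: three 2-D particles (mean [1, 2]) and a quadratic target loss
# 0.5 * ||theta - TARGET_OPT||^2 standing in for -log p(D_T | theta).
source_particles = [[0.0, 1.0], [2.0, 1.0], [1.0, 4.0]]
TARGET_OPT = [3.0, 0.0]


def nll_grad(theta: Vector) -> Vector:
    return [t - o for t, o in zip(theta, TARGET_OPT)]


theta_t = adapt(source_particles, nll_grad)
print(theta_t)
```

Because the sum of squared distances to the particles equals the squared distance to their mean plus a constant, the penalty pulls the target prompt toward the particle mean; with `sigma2 = 1.0` in this toy case the solution settles halfway between the mean and the target optimum.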
5. Detailed Algorithmic Workflow
The BMTPT procedure is as follows (notation as above):
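- Inputs: Source datasets $D_1, \dots, D_K$; target dataset $D_T$; number of particles $M$.
- Particle Initialization: Initialize $\theta_i^{0}$ for $i = 1, \dots, M$.
- SVGD Posterior Learning: For $t = 0, \dots, T_{\max} - 1$:
  - Sample minibatches from each $D_k$.
  - Compute per-particle gradients $\nabla_{\theta_i^{t}} \log p(\theta_i^{t} \mid D_1, \dots, D_K)$.
  - Compute SVGD transformation $\phi(\theta_i^{t})$ and update particles.
- Aggregate Posterior Mean: $\bar{\theta} = \frac{1}{M} \sum_{i=1}^{M} \theta_i$
- Target Prompt Adaptation:
  - Initialize $\theta_T \leftarrow \bar{\theta}$.
  - Minimize the quadratic-regularized MAP objective by gradient descent.
- Output: Fine-tuned prompt for target $\theta_T^{*}$.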
6. Empirical Evaluation and Quantitative Results
Experiments use T5-base as the backbone (220M parameters), with prompt length $l$, embedding dimension $d$, $M$ particles, and 100K SVGD steps. Source tasks encompass MNLI, QQP, QNLI, SST-2, SQuAD, and ReCoRD, while target tasks include a broad range of GLUE, SuperGLUE, MRQA, and other benchmarks (Lee et al., 2024).
Key quantitative outcomes for T5-base:
| Task Suite | BMTPT | MPT | ATTEMPT | Vanilla PT | Full FT |
|---|---|---|---|---|---|
| GLUE (avg score) | 88.7 | 85.6 | 83.4 | 72.2 | 84.9 |
| SuperGLUE | 74.6 | 74.1 | — | 57.8 | 73.9 |
- BMTPT achieves the highest reported scores while training only a small fraction of the backbone’s parameters for the prompt, outperforming adapter- and hypernetwork-based multi-task transfer baselines.
- In few-shot settings (4/16/32 examples), BMTPT consistently outperforms MPT and vanilla prompt tuning, e.g., averaging +10 points over MPT on 16-shot GLUE.
- Ablations show that removing the Bayesian regularizer reduces GLUE scores by ≈1.0 point, that a larger particle count yields no benefit over a smaller one, and that performance is robust to subsampling source tasks during SVGD.
7. Parameter Efficiency, Implementation Aspects, and Discussion
BMTPT’s design obviates any need for auxiliary models, unlike ATTEMPT (which uses an attention network) or MPT (which requires a teacher model for distillation). Only $M$ prompt vectors $\{\theta_i\}_{i=1}^{M}$ are retained during the source stage; for prompt length $l$ and embedding dimension $d$, this corresponds to $M \cdot l \cdot d$ parameters vs. 220M for T5-base. During target adaptation, the requirements are a full-rank prompt matrix $\theta_T \in \mathbb{R}^{l \times d}$ and a low-rank factorization, which together still amount to only a small fraction of the model’s parameters.
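The parameter budget can be checked with simple arithmetic. In the sketch below the backbone size is the 220M figure quoted for T5-base, while the prompt length, embedding dimension, and particle count (`L_ASSUMED`, `D_ASSUMED`, `M_ASSUMED`) are illustrative placeholder values rather than settings confirmed by the text; the low-rank factorization is omitted.

```python
def prompt_fraction(prompt_len: int,
                    embed_dim: int,
                    backbone_params: int,
                    n_prompts: int = 1) -> float:
    """Fraction of backbone parameters used by n_prompts soft prompts."""
    return n_prompts * prompt_len * embed_dim / backbone_params


BACKBONE = 220_000_000   # T5-base size quoted in the text
L_ASSUMED = 100          # hypothetical prompt length
D_ASSUMED = 768          # hypothetical embedding dimension
M_ASSUMED = 5            # hypothetical number of SVGD particles

single = prompt_fraction(L_ASSUMED, D_ASSUMED, BACKBONE)
ensemble = prompt_fraction(L_ASSUMED, D_ASSUMED, BACKBONE, n_prompts=M_ASSUMED)
print(f"one prompt: {single:.4%} of the backbone")
print(f"{M_ASSUMED} particles: {ensemble:.4%} of the backbone")
```

Even the full particle ensemble stays well below one percent of the backbone under these assumed sizes, and only a single prompt is carried into target adaptation.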
Hyperparameters include the particle count $M$ for the SVGD ensemble, separate learning rates for SVGD and for target fine-tuning, and an RBF kernel bandwidth set by the median heuristic. These settings balance diversity and convergence of prompt representations.
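For reference, one common variant of the median heuristic sets the bandwidth to the median squared pairwise distance divided by $\log(M + 1)$, for use in a kernel of the form $k(a, b) = \exp(-\lVert a - b \rVert^2 / h)$. The sketch below implements that variant; the text does not specify which form BMTPT uses, so the formula and the name `median_bandwidth` should be read as assumptions.

```python
import math
from statistics import median
from typing import List


def median_bandwidth(particles: List[List[float]]) -> float:
    """Bandwidth h = median squared pairwise distance / log(M + 1)."""
    m = len(particles)
    if m < 2:
        raise ValueError("need at least two particles")
    sq_dists = [
        sum((a - b) ** 2 for a, b in zip(particles[i], particles[j]))
        for i in range(m)
        for j in range(i + 1, m)
    ]
    return median(sq_dists) / math.log(m + 1)


# Three toy 2-D particles with squared pairwise distances 25, 100, 25.
toy_particles = [[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]]
h = median_bandwidth(toy_particles)
print(h)
```

Recomputing the bandwidth at every SVGD step keeps the kernel scale matched to the current spread of the particles, which avoids hand-tuning it.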
By explicitly encoding the distribution of transferable knowledge across source tasks, BMTPT enables parameter-minimal, state-of-the-art adaptation for varied NLP targets, accounts for both positive and negative transfer, and provides a unified Bayesian perspective on prompt transfer in large-scale multi-task scenarios (Lee et al., 2024).