SLOT: Sample-specific Language Model Optimization at Test-time (2505.12392v2)
Abstract: We propose SLOT (Sample-specific LLM Optimization at Test-time), a novel and parameter-efficient test-time inference approach that enhances an LLM's ability to respond more accurately to individual prompts. Existing LLMs often struggle with complex instructions, leading to poor performance on those not well represented among general samples. To address this, SLOT conducts a few optimization steps at test-time to update a lightweight, sample-specific parameter vector. This vector is added to the final hidden layer before the output head, and enables efficient adaptation by caching the last-layer features during per-sample optimization. By minimizing the cross-entropy loss on the input prompt only, SLOT helps the model better align with and follow each given instruction. In experiments, we demonstrate that our method outperforms the compared models across multiple benchmarks and LLMs. For example, Qwen2.5-7B with SLOT achieves an accuracy gain of 8.6% on GSM8K from 57.54% to 66.19%, while DeepSeek-R1-Distill-Llama-70B with SLOT achieves a SOTA accuracy of 68.69% on GPQA among 70B-level models. Our code is available at https://github.com/maple-research-lab/SLOT.
Summary
- The paper introduces SLOT, a method that optimizes a sample-specific parameter vector at test-time to reduce cross-entropy loss on input prompts.
- It employs a few iterations with an AdamW optimizer, efficiently reusing cached hidden features to adapt LLM responses for complex instructions.
- SLOT boosts reasoning performance on benchmarks like GSM8K and AIME24, achieving improvements up to 10% with minimal computational overhead.
This paper introduces SLOT (Sample-specific LLM Optimization at Test-time), a method to enhance LLM performance on individual prompts, particularly those with complex or unfamiliar instructions. SLOT operates by performing a few optimization steps at test-time to update a lightweight, sample-specific parameter vector, denoted as δ. This vector is added to the hidden features of the LLM's final layer, just before the output classification head. The optimization minimizes the cross-entropy loss on the input prompt itself, effectively making the model better "understand" the specific instruction before generating a response.
How SLOT Works
SLOT operates in two main stages during inference for each prompt:
- Prompt Stage (Optimization):
- A sample-specific parameter vector δ ∈ R^{1×d} (where d is the hidden dimension of the LLM) is initialized (typically with zeros).
- The LLM processes the input prompt x = (x_1, …, x_n) to obtain the final hidden features H ∈ R^{n×d}. These features H are cached.
- For a small number of iterations T (e.g., T=3):
- The modified hidden features are computed as H′ = H + δ, with δ broadcast across the sequence length.
- Logits are calculated as logits = H′ W_LM^⊤, where W_LM is the LM head's weight matrix.
- The cross-entropy loss L(δ) = −∑_{i=1}^{n−1} log p(x_{i+1} | x_{1:i}, δ) is computed on the input prompt (i.e., predicting each next token in the prompt given the preceding tokens).
- The gradient ∇_δ L(δ) is computed.
- δ is updated using an optimizer (e.g., AdamW): δ^{(t+1)} = OptimizerStep(δ^{(t)}, ∇_δ L(δ^{(t)})).
- The final optimized vector is denoted δ_opt.
- Efficiency Note: Because δ only modifies the final hidden features, the computationally expensive forward pass through the main body of the LLM to get H is done only once. Each optimization step for δ only involves operations on H (which is cached) and the LM head, making it very fast.
- Generation Stage (Response Generation):
- The optimized δ_opt is reused.
- The LLM generates the response token by token, autoregressively.
- For each new token to be generated:
- The LLM computes the hidden features for the current last token, H_last.
- The hidden features are modified: H′_last = H_last + δ_opt.
- The logits for the next token are computed: logits_next = W_LM H′_last.
- The next token is sampled (e.g., greedy decoding or softmax sampling).
- This process continues until an end-of-sequence token is generated or the maximum length is reached.
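The two stages above can be condensed into a minimal PyTorch sketch. This is not the authors' implementation: it assumes a Hugging Face-style causal LM whose transformer body is exposed as `model.model` and whose output head is `model.lm_head`, uses greedy decoding, and omits KV caching for clarity. Hyperparameter values follow those reported in the paper.

```python
import torch
import torch.nn.functional as F

def optimize_delta(model, input_ids, T=3, lr=0.01):
    """Prompt stage: fit a sample-specific delta on the cached prompt features."""
    model.requires_grad_(False)  # freeze all weights; only delta is trainable
    with torch.no_grad():
        H = model.model(input_ids).last_hidden_state  # (1, n, d), computed once and cached
    delta = torch.zeros(1, 1, H.size(-1), device=H.device, dtype=H.dtype,
                        requires_grad=True)
    opt = torch.optim.AdamW([delta], lr=lr, weight_decay=1e-8, eps=1e-5)
    for _ in range(T):
        opt.zero_grad()
        logits = model.lm_head(H + delta)             # only the LM head is re-run
        loss = F.cross_entropy(                       # next-token loss on the prompt
            logits[:, :-1].reshape(-1, logits.size(-1)),
            input_ids[:, 1:].reshape(-1))
        loss.backward()                               # gradient reaches delta only
        opt.step()
    return delta.detach()

def generate_with_delta(model, input_ids, delta, max_new_tokens=256, eos_id=None):
    """Generation stage: add delta_opt to the last hidden state at every step."""
    seq = input_ids
    for _ in range(max_new_tokens):
        with torch.no_grad():
            h_last = model.model(seq).last_hidden_state[:, -1:]  # no KV cache, for clarity
        next_logits = model.lm_head(h_last + delta)[:, -1]
        next_tok = next_logits.argmax(-1, keepdim=True)  # greedy decoding for simplicity
        seq = torch.cat([seq, next_tok], dim=-1)
        if eos_id is not None and next_tok.item() == eos_id:
            break
    return seq
```

Because `H` is computed once under `torch.no_grad()`, each of the T steps backpropagates only through the LM head, which is what makes the prompt stage cheap.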
The core idea is that by optimizing δ to make the input prompt more likely under the adapted model, SLOT helps the LLM align better with the specific nuances and requirements of that individual prompt.
Implementation Details and Considerations
Key Parameters and Hyperparameters:
- δ Vector: A small vector of size 1×d. Its dimensionality matches the LLM's hidden size.
- Initialization: δ is initialized to zeros (δ^{(0)} = 0). This ensures the base LLM performance is the starting point.
- Optimization Steps (T): Typically a small number, e.g., T=3 to T=5. The paper shows benefits even with few steps.
- Optimizer: AdamW is used.
- Learning Rate (η): e.g., 0.01 (experiments also explore 0.05, 0.1, 0.2).
- Weight Decay: 1×10⁻⁸.
- Epsilon: 1×10⁻⁵.
- Gradient Clipping: Mentioned as potentially applicable but often unnecessary for few-step optimization.
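Taken together, these settings amount to the following setup (a sketch mirroring the appendix snippet later in this section; the hidden size is an example value, not from the paper):

```python
import torch

d = 3584  # hidden size of the target LLM (example value for a 7B-class model)
delta = torch.zeros(1, 1, d, requires_grad=True)  # delta^(0) = 0: start from the base model
optimizer = torch.optim.AdamW([delta], lr=0.01, weight_decay=1e-8, eps=1e-5)
# Gradient clipping is optional and usually unnecessary for T <= 5 steps:
# torch.nn.utils.clip_grad_norm_([delta], max_norm=1.0)
```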
Computational Overhead:
- Prompt Stage: The main overhead comes from the T optimization steps. However, since the features H from the LLM body are cached, each step only involves:
- Adding δ to H.
- A forward pass through the LM head (W_LM).
- Loss computation.
- A backward pass through the LM head to get gradients for δ.
- Optimizer update for δ. This is significantly cheaper than full backpropagation through the entire LLM.
- Generation Stage: The overhead is minimal, involving only an element-wise addition of δ_opt to the last hidden state for each token generated (O(d) operations).
- Overall Inference Time: The paper reports a modest increase in inference time. For example, with Qwen-2.5-7B on GSM8K, 5 SLOT iterations increased inference time by only 7.9% compared to the baseline. Ablation studies (Table 4) show prompt-processing throughput (S_I) decreases by about 12-25% with T=1 depending on the learning rate, but doesn't significantly worsen with more iterations due to caching. Generation throughput (S_O) is only slightly reduced.
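As a back-of-envelope check on why the optimization steps are cheap (the shapes below are illustrative assumptions for a Qwen2.5-7B-class model, not figures from the paper):

```python
# Rough FLOPs per prompt token, assuming ~2 FLOPs per parameter per token for
# a forward pass and a backward pass costing ~2x a forward pass.
d, V, n_params = 3584, 152_000, 7.6e9
full_forward = 2 * n_params      # one forward pass through the whole model
head_forward = 2 * d * V         # the LM-head matmul only
delta_step = 3 * head_forward    # forward + backward through the head per delta step
print(f"one delta step ~ {delta_step / full_forward:.2f}x of a full forward pass")
# -> one delta step ~ 0.22x of a full forward pass; T=3 steps add ~0.65x on top
#    of the single cached forward pass through the model body
```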
Pseudocode (Algorithm 1 in the paper):
```
Algorithm: Sample-specific LLM Optimization at Test-time (SLOT)
Input:  pre-trained LLM M, input tokens x = (x_1, ..., x_n),
        optimization steps T, learning rate eta, optimizer params lambda
Output: generated sequence y

// Phase 1: optimize delta on the input prompt
delta = initialize_zeros(1, d)
optimizer_state = initialize_optimizer_state()
H_cached = M_pre_LM(x)                       // features before the output head, cached
for t = 0 to T-1:
    H_prime = H_cached + delta               // broadcast delta over all n positions
    logits = W_LM * H_prime
    loss = CrossEntropyLoss(logits[..., :-1, :], x[2:n])  // predict next token in prompt
    gradients_delta = compute_gradients(loss, delta)
    delta = optimizer_step(delta, gradients_delta, eta, lambda)  // e.g., AdamW
delta_opt = delta

// Phase 2: generate using the optimized delta
y = ()
current_sequence = x
while not (eos_generated or max_length_reached):
    H_last = M_pre_LM(current_sequence)[-1, :]   // last hidden state
    H_prime_last = H_last + delta_opt            // reuse the optimized delta
    logits_next = W_LM * H_prime_last
    next_token = sample_from_softmax(logits_next)
    append next_token to y
    append next_token to current_sequence
return y
```
Code Snippet Explanation (from Appendix):
The provided code snippet illustrates the core logic within the model's forward pass:
```python
# Fragment from inside the model's forward pass
# (assumes: import os, torch; from torch import nn; batch_size = 1).
prompt_only = os.environ.get("prompt_only", "False") == "True"
if prompt_only:
    # --- Prompt Stage: Optimization ---
    with torch.enable_grad():  # ensure gradients are computed for delta
        # Initialize delta (could be a class member or passed around);
        # here, a new Parameter is created for each optimization phase.
        delta_param = nn.Parameter(
            0.0 * torch.randn([1, 1, hidden_states.shape[-1]]).to(hidden_states))
        optimizer = torch.optim.AdamW([delta_param], lr=0.01,
                                      weight_decay=1e-8, eps=1e-5)
        for _ in range(3):  # T = 3 optimization steps
            optimizer.zero_grad()
            # H' = H + delta; hidden_states is the cached H
            transformed_hidden = hidden_states + delta_param
            logits = self.lm_head(transformed_hidden)
            # Language-modeling loss on the prompt: shift logits and labels
            # so position i predicts token i+1 (input_ids is x)
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            loss.backward()   # computes gradients w.r.t. delta_param
            optimizer.step()  # updates delta_param
    # Store the optimized delta for reuse in the generation stage;
    # detach to prevent further gradient tracking
    self.delta = delta_param.detach()
    # Apply the optimized delta immediately for the first generated token
    hidden_states = hidden_states + self.delta
    # Flip the flag so the next forward pass (generation) takes the else branch
    os.environ["prompt_only"] = "False"
else:
    # --- Generation Stage: reuse delta ---
    if getattr(self, "delta", None) is not None:  # check delta was initialized
        hidden_states = hidden_states + self.delta
    # else: delta not available (e.g., the first call was not prompt_only)
```
Key takeaways from the code:
- An environment variable `prompt_only` is used to switch between the optimization phase and the generation phase. This is a practical way to control behavior within the `forward` method during a multi-step inference process.
- `hidden_states` (representing H) is assumed to be the output of the LLM body before the LM head; it is cached implicitly by being reused across the optimization loop.
- `delta` is initialized as an `nn.Parameter` so that `torch.optim` can update it.
- The loss is standard cross-entropy for language modeling, calculated on the input prompt.
- After optimization, the optimized `delta` is stored (e.g., as `self.delta`) and then added to `hidden_states` for subsequent token generation.
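A hypothetical driver for the snippet could look like the following. It assumes the branch above has been patched into the model's `forward` (e.g., for a Qwen2.5 checkpoint); the model name and prompt are illustrative, not from the paper's scripts:

```python
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B")  # assumed patched as above
prompt_ids = tok("Janet has 3 apples and buys 2 more...", return_tensors="pt").input_ids

os.environ["prompt_only"] = "True"   # first forward pass takes the optimization branch
model(input_ids=prompt_ids)          # runs the T delta steps and stores self.delta
# the branch flips the flag to "False", so generation reuses the optimized delta
out = model.generate(prompt_ids, max_new_tokens=256)
print(tok.decode(out[0], skip_special_tokens=True))
```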
Logit Modulation Vector (LMV)
SLOT's effect can be interpreted as a direct modulation of the output logits. The paper defines the Logit Modulation Vector (LMV) as:
LMV ≜ W_LM δ_opt^⊤ ∈ R^{|V|}
This vector represents an additive shift to the logits for every token in the vocabulary. Observations on GSM8K (Figure 3) show:
- Increased probability: tokens related to reasoning (e.g., "reasoning") are enhanced.
- Decreased probability: numerical tokens (e.g., "0", "1", "2") and common function words (e.g., "should", "will") are suppressed. Interestingly, the end-of-text token `eos` is also suppressed, potentially encouraging longer, more detailed reasoning chains.
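Given an optimized δ_opt, the LMV and the token shifts described above can be inspected in a few lines (a sketch, assuming `model`, `tok`, and a `delta_opt` of shape (1, 1, d) as in the earlier sketch):

```python
import torch

with torch.no_grad():
    W = model.lm_head.weight            # (|V|, d) LM-head weight matrix
    lmv = W @ delta_opt.reshape(-1)     # (|V|,): additive logit shift per vocabulary token
boosted = torch.topk(lmv, 10).indices
suppressed = torch.topk(-lmv, 10).indices
print("boosted:   ", tok.convert_ids_to_tokens(boosted.tolist()))
print("suppressed:", tok.convert_ids_to_tokens(suppressed.tolist()))
```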
Experimental Results and Practical Applications
SLOT demonstrates significant improvements across various models and benchmarks:
- Qwen-7B:
- C-Eval: +8.55 on 'Hard' subset, +4.19 on 'STEM', +1.05 overall average.
- GSM8K: +3.0 (51.2% to 54.2%).
- HumanEval: +1.8 (29.9% to 31.7%).
- Qwen2.5-7B:
- GSM8K: +8.6% (57.54% to 66.19%) (from abstract, Figure 1 shows improvement from ~57% to ~66% with 5 iterations).
- DeepSeek-R1-Distill-Llama-70B:
- GPQA: Achieved 68.69% (a +3.03% improvement), SOTA for 70B open-source models at the time.
- AIME24: +10.00% (63.33% to 73.33%).
- General Trend: SLOT consistently boosts performance on reasoning tasks (AIME24, Math500, GPQA, GSM8K) for both base models and specialized reasoning models. Improvements are often substantial (e.g., +10% on AIME24 for Qwen2.5-32B and DeepSeek-R1-Distill-Llama-70B).
Potential Applications:
- Improving Reasoning in LLMs: For tasks requiring complex multi-step reasoning (math problems, science Q&A), SLOT can make models more robust and accurate by helping them better adhere to the problem's specific constraints and reasoning structure.
- Enhancing Instruction Following: When prompts have strict formatting requirements or complex instructions not well-represented in general training data, SLOT can adapt the model at test-time to follow these instructions more faithfully. Figure 1 shows improved format accuracy alongside answer accuracy.
- Resource-Constrained Deployment: Since SLOT adds minimal computational overhead, it can be a practical way to boost the performance of smaller or existing LLMs without needing to retrain or use a much larger model.
- Specialized Domains: For applications in specialized domains where prompts might contain unique jargon or conventions, SLOT could help the model adapt quickly to these sample-specific elements.
Ablation Study Insights:
- SLOT is relatively insensitive to hyperparameters like the number of optimization iterations (T) and learning rate (η), with most configurations outperforming the baseline.
- Optimal performance in one test (DeepSeek-R1-Distill-Qwen-1.5B on AIME24) was found with (T=4, η=0.05) and (T=5, η=0.05).
- The number of optimization steps (T) does not significantly increase prompt processing time beyond T=1 due to feature caching.
Conclusion
SLOT is a parameter-efficient, test-time adaptation technique that optimizes a small, sample-specific vector δ by minimizing the cross-entropy loss on the input prompt itself. This allows LLMs to better align with and follow individual instructions, leading to significant performance gains on complex reasoning benchmarks with minimal computational overhead. Its efficiency stems from caching last-layer features during the per-sample optimization of δ. The method's ability to enhance reasoning-related token probabilities (via LMV) suggests it encourages deeper processing. SLOT offers a practical approach for improving LLM responses on-the-fly, particularly for challenging or unfamiliar prompts.