
SLOT: Sample-specific Language Model Optimization at Test-time (2505.12392v2)

Published 18 May 2025 in cs.CL, cs.AI, and cs.LG

Abstract: We propose SLOT (Sample-specific LLM Optimization at Test-time), a novel and parameter-efficient test-time inference approach that enhances an LLM's ability to respond more accurately to individual prompts. Existing LLMs often struggle with complex instructions, leading to poor performance on those not well represented among general samples. To address this, SLOT conducts a few optimization steps at test-time to update a light-weight sample-specific parameter vector. It is added to the final hidden layer before the output head, and enables efficient adaptation by caching the last layer features during per-sample optimization. By minimizing the cross-entropy loss on the input prompt only, SLOT helps the model better align with and follow each given instruction. In experiments, we demonstrate that our method outperforms the compared models across multiple benchmarks and LLMs. For example, Qwen2.5-7B with SLOT achieves an accuracy gain of 8.6% on GSM8K from 57.54% to 66.19%, while DeepSeek-R1-Distill-Llama-70B with SLOT achieves a SOTA accuracy of 68.69% on GPQA among 70B-level models. Our code is available at https://github.com/maple-research-lab/SLOT.

Summary

  • The paper introduces SLOT, a method that optimizes a sample-specific parameter vector at test-time to reduce cross-entropy loss on input prompts.
  • It employs a few iterations with an AdamW optimizer, efficiently reusing cached hidden features to adapt LLM responses for complex instructions.
  • SLOT boosts reasoning performance on benchmarks like GSM8K and AIME24, with improvements of up to 10 percentage points (on AIME24) at minimal computational overhead.

This paper introduces SLOT (Sample-specific LLM Optimization at Test-time), a method to enhance LLM performance on individual prompts, particularly those with complex or unfamiliar instructions. SLOT performs a few optimization steps at test-time to update a lightweight, sample-specific parameter vector, denoted $\delta$. This vector is added to the hidden features of the LLM's final layer, just before the output classification head. The optimization minimizes the cross-entropy loss on the input prompt itself, adapting the model to the specific instruction before it generates a response.

How SLOT Works

SLOT operates in two main stages during inference for each prompt:

  1. Prompt Stage (Optimization):
    • A sample-specific parameter vector $\delta \in \mathbb{R}^{1 \times d}$ (where $d$ is the hidden dimension of the LLM) is initialized, typically with zeros.
    • The LLM processes the input prompt $x = (x_1, \dots, x_n)$ to obtain the final hidden features $H \in \mathbb{R}^{n \times d}$. These features $H$ are cached.
    • For a small number of iterations $T$ (e.g., $T=3$):
      • The modified hidden features are computed as $H' = H + \delta$, with $\delta$ broadcast across the sequence length.
      • Logits are calculated: $\text{logits} = W_{\text{LM}} H'$, where $W_{\text{LM}}$ is the LM head's weight matrix.
      • The cross-entropy loss $\mathcal{L}(\delta) = -\sum_{i=1}^{n-1} \log p(x_{i+1} \mid x_{1:i}, \delta)$ is computed on the input prompt (i.e., predicting each next prompt token from the preceding tokens).
      • The gradient $\nabla_{\delta} \mathcal{L}(\delta)$ is computed.
      • $\delta$ is updated using an optimizer (e.g., AdamW): $\delta^{(t+1)} = \text{OptimizerStep}(\delta^{(t)}, \nabla_{\delta} \mathcal{L}(\delta^{(t)}))$.
    • The final optimized vector is denoted $\delta_{\text{opt}}$.
    • Efficiency Note: Because $\delta$ only modifies the final hidden features, the expensive forward pass through the main body of the LLM to obtain $H$ is done only once. Each optimization step for $\delta$ involves only the cached $H$ and the LM head, so it is fast.
  2. Generation Stage (Response Generation):
    • The optimized $\delta_{\text{opt}}$ is reused.
    • The LLM generates the response token by token, autoregressively.
    • For each new token to be generated:
      • The LLM computes the hidden features for the current last token, $H_{\text{last}}$.
      • The hidden features are modified: $H'_{\text{last}} = H_{\text{last}} + \delta_{\text{opt}}$.
      • The logits for the next token are computed: $\text{logits}_{\text{next}} = W_{\text{LM}} H'_{\text{last}}$.
      • The next token is selected from these logits (e.g., by greedy decoding or by sampling from the softmax).
    • This process continues until an end-of-sequence token is generated or the maximum length is reached.
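The efficiency note rests on the form of the gradient. The short derivation below makes it explicit. It treats each cached hidden state $h_i$ (row $i$ of $H$) and $\delta$ as column vectors, matching the $\text{logits} = W_{\text{LM}} H'$ convention in the prompt-stage steps. The symbols $z_i$, $p_i$ and the one-hot target vector $e_{x_{i+1}}$ are notation introduced here, not taken from the paper.

```latex
% Column-vector convention: h_i = row i of H, W_LM has shape |V| x d.
z_i(\delta) = W_{\text{LM}}\,(h_i + \delta), \qquad
p_i = \operatorname{softmax}\big(z_i(\delta)\big), \qquad i = 1, \dots, n-1

\mathcal{L}(\delta) = -\sum_{i=1}^{n-1} \log p_i[x_{i+1}]
\quad\Longrightarrow\quad
\frac{\partial \mathcal{L}}{\partial z_i} = p_i - e_{x_{i+1}}

\nabla_{\delta}\mathcal{L}(\delta)
  = \sum_{i=1}^{n-1} W_{\text{LM}}^{\top}\big(p_i - e_{x_{i+1}}\big)
  = W_{\text{LM}}^{\top} \sum_{i=1}^{n-1}\big(p_i - e_{x_{i+1}}\big)
```

The gradient depends only on the cached $h_i$ (through $p_i$) and on $W_{\text{LM}}$. None of the $T$ steps revisits the LLM body. Implementations that average the loss over positions, as PyTorch's default `nn.CrossEntropyLoss` does, simply scale this gradient by $1/(n-1)$.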

The core idea is that optimizing $\delta$ makes the input prompt more likely (lower cross-entropy) under the adapted model. This helps the LLM align better with the specific nuances and requirements of that individual prompt.

Implementation Details and Considerations

Key Parameters and Hyperparameters:

  • $\delta$ Vector: A small vector of size $1 \times d$. Its dimensionality matches the LLM's hidden size.
  • Initialization: $\delta$ is initialized to zeros ($\delta^{(0)} = \mathbf{0}$). This makes the base LLM's behavior the starting point.
  • Optimization Steps ($T$): Typically a small number, e.g., $T=3$ to $T=5$. The paper shows benefits even with few steps.
  • Optimizer: AdamW is used.
    • Learning Rate ($\eta$): e.g., $0.01$ (experiments also explore $0.05$, $0.1$, $0.2$).
    • Weight Decay: $1 \times 10^{-8}$.
    • Epsilon: $1 \times 10^{-5}$.
  • Gradient Clipping: Mentioned as potentially applicable but often unnecessary for few-step optimization.
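A single update shows concretely what these AdamW settings do. The pure-Python sketch below restates the standard AdamW rule (decoupled weight decay, bias-corrected moments) as documented for `torch.optim.AdamW`. `adamw_step` is a hypothetical helper. `beta1=0.9` and `beta2=0.999` are assumed PyTorch defaults, since the list above does not state betas. The example traces the first step from $\delta^{(0)} = \mathbf{0}$.

```python
import math


def adamw_step(p, g, m, v, t, lr=0.01, weight_decay=1e-8, eps=1e-5, beta1=0.9, beta2=0.999):
    """One AdamW update of a single coordinate; returns the new (p, m, v).

    betas are assumed PyTorch defaults; lr, weight_decay and eps follow the text above.
    """
    p = p * (1 - lr * weight_decay)        # decoupled weight decay, applied before the Adam step
    m = beta1 * m + (1 - beta1) * g        # running mean of gradients
    v = beta2 * v + (1 - beta2) * g * g    # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)           # bias corrections for the zero-initialized moments
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


# First step (t = 1) from SLOT's zero initialization, for gradients of very different scales.
grads = [3.0, -0.002, 1e3]
first = [adamw_step(0.0, g, 0.0, 0.0, t=1)[0] for g in grads]
print(first)
```

At $t=1$ the bias-corrected moments are exactly $g$ and $g^2$. Each coordinate therefore moves by $-\eta\, g/(|g| + \epsilon)$. That is almost exactly $\eta$ in magnitude whatever the gradient's scale, and noticeably less only when $|g|$ is comparable to $\epsilon = 10^{-5}$. The weight decay of $10^{-8}$ rescales $\delta$ by $1 - 10^{-10}$ per step, so over a few steps it is effectively inactive.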

Computational Overhead:

  • Prompt Stage: The main overhead comes from the $T$ optimization steps. However, since the features $H$ from the LLM body are cached, each step only involves:

    1. Adding $\delta$ to $H$.
    2. A forward pass through the LM head ($W_{\text{LM}}$).
    3. Loss computation.
    4. A backward pass through the LM head to get gradients for $\delta$.
    5. An optimizer update for $\delta$.

    This is significantly cheaper than full backpropagation through the entire LLM.
  • Generation Stage: The overhead is minimal: only an element-wise addition of $\delta_{\text{opt}}$ to the last hidden state for each generated token ($O(d)$ operations).

  • Overall Inference Time: The paper reports a modest increase in inference time. For example, with Qwen-2.5-7B on GSM8K, 5 SLOT iterations increased inference time by only 7.9% compared to the baseline. Ablation studies (Table 4) show that prompt processing throughput (SI) decreases by about 12-25% with $T=1$, depending on the learning rate. It does not worsen significantly with more iterations because the features are cached. Generation throughput (SO) is only slightly reduced.

Pseudocode (Algorithm 1 in the paper):

Algorithm: Sample-specific LLM Optimization at Test-time (SLOT)

Input: Pre-trained LLM M, input tokens x, optimization steps T, learning rate eta, optimizer params lambda
Output: Generated sequence y

// Phase 1: Optimize delta on the input prompt
delta = initialize_zeros(1, d)
optimizer_state = initialize_optimizer_state()
H_cached = M_pre_LM(x)  // Get features before output head, cache them

for t = 0 to T-1:
    H_prime = H_cached + delta  // Broadcast delta
    logits = W_LM * H_prime
    loss = CrossEntropyLoss(logits[1:n-1], x[2:n]) // 1-indexed: position i predicts next prompt token x_{i+1}
    gradients_delta = compute_gradients(loss, delta)
    delta = optimizer_step(delta, gradients_delta, eta, lambda) // e.g., AdamW

delta_opt = delta

// Phase 2: Generate using the optimized delta
y = ()
current_sequence = x

while length(y) < max_length:
    H_last = M_pre_LM(current_sequence)[-1, :] // Get last hidden state
    H_prime_last = H_last + delta_opt           // Reuse optimized delta
    logits_next = W_LM * H_prime_last
    next_token = sample_from_softmax(logits_next) // or argmax for greedy decoding
    append next_token to y
    append next_token to current_sequence
    if next_token == EOS: break

Return y
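A minimal, self-contained pure-Python sketch can make Algorithm 1 concrete end to end on a toy model. Everything model-specific is hypothetical. `make_toy_model` builds a stand-in "body" (a causal running mean of random embeddings, not a transformer) and a random LM head. The other helper names are also illustrative. The gradient uses the analytic form $W_{\text{LM}}^{\top}(p_i - e_{x_{i+1}})$, averaged over positions like the mean-reduced `nn.CrossEntropyLoss` in the appendix snippet. AdamW is written out by hand with assumed PyTorch-default betas, and decoding is greedy.

```python
import math
import random


def matvec(W, h):
    """W @ h for a |V| x d list-of-lists W and a length-d vector h."""
    return [sum(w * x for w, x in zip(row, h)) for row in W]


def softmax(z):
    mx = max(z)
    e = [math.exp(v - mx) for v in z]
    s = sum(e)
    return [v / s for v in e]


def make_toy_model(vocab, d, seed=0):
    """Hypothetical stand-in for an LLM: a causal 'body' plus a random LM head."""
    rng = random.Random(seed)
    emb = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(vocab)]
    w_lm = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(vocab)]  # |V| x d

    def body(tokens):
        # Plays the role of M_pre_LM: hidden vector i depends only on tokens[:i+1].
        hidden, acc = [], [0.0] * d
        for i, tok in enumerate(tokens):
            acc = [a + e for a, e in zip(acc, emb[tok])]
            hidden.append([a / (i + 1) for a in acc])
        return hidden

    return body, w_lm


def prompt_loss_and_grad(H, w_lm, tokens, delta):
    """Mean next-token cross-entropy on the prompt and its analytic gradient w.r.t. delta."""
    d, n_pred = len(delta), len(tokens) - 1
    loss, grad = 0.0, [0.0] * d
    for i in range(n_pred):
        p = softmax(matvec(w_lm, [h + dl for h, dl in zip(H[i], delta)]))
        target = tokens[i + 1]
        loss -= math.log(p[target])
        p[target] -= 1.0  # p - onehot(x_{i+1}) = dLoss/dlogits at position i
        for v, r in enumerate(p):  # accumulate W_LM^T (p - onehot)
            for j in range(d):
                grad[j] += r * w_lm[v][j]
    return loss / n_pred, [g / n_pred for g in grad]


def slot_optimize(H, w_lm, tokens, steps=3, lr=0.01, weight_decay=1e-8, eps=1e-5,
                  beta1=0.9, beta2=0.999):
    """Phase 1: AdamW on delta only; the cached H is never recomputed."""
    d = len(H[0])
    delta, m, v = [0.0] * d, [0.0] * d, [0.0] * d
    for t in range(1, steps + 1):
        _, g = prompt_loss_and_grad(H, w_lm, tokens, delta)
        for j in range(d):
            delta[j] *= 1 - lr * weight_decay
            m[j] = beta1 * m[j] + (1 - beta1) * g[j]
            v[j] = beta2 * v[j] + (1 - beta2) * g[j] ** 2
            m_hat, v_hat = m[j] / (1 - beta1 ** t), v[j] / (1 - beta2 ** t)
            delta[j] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return delta


def slot_generate(body, w_lm, prompt, delta_opt, max_new_tokens=5, eos=None):
    """Phase 2: greedy decoding with delta_opt added to the last hidden state each step."""
    seq, out = list(prompt), []
    for _ in range(max_new_tokens):
        h_last = body(seq)[-1]  # recomputed here for simplicity; real models use a KV cache
        logits = matvec(w_lm, [h + dl for h, dl in zip(h_last, delta_opt)])
        next_tok = max(range(len(logits)), key=logits.__getitem__)
        out.append(next_tok)
        seq.append(next_tok)
        if next_tok == eos:
            break
    return out


body, w_lm = make_toy_model(vocab=8, d=4)
prompt = [1, 3, 2, 5, 1, 3]
H = body(prompt)  # the expensive part: computed once and cached
delta_opt = slot_optimize(H, w_lm, prompt, steps=3, lr=0.01)
print(slot_generate(body, w_lm, prompt, delta_opt))
```

The structure mirrors the two phases. `body(prompt)` runs once, and all $T$ optimization steps touch only `H` and `w_lm`. The generation phase then adds the same `delta_opt` at every decoding step.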

Code Snippet Explanation (from Appendix):

The provided code snippet illustrates the core logic within the model's forward pass:

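# Imports this excerpt relies on (normally at module level)
import os
import torch
from torch import nn

prompt_only = os.environ.get("prompt_only", "False") == "True"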

if prompt_only:
    # --- Prompt Stage: Optimization ---
    with torch.enable_grad(): # Ensure gradients are computed for delta
        # A fresh zero-initialized delta for each prompt, shaped [1, 1, d] so it broadcasts
        # over batch and sequence; .to(hidden_states) matches its dtype and device.
        # (0.0 * torch.randn(...) is simply a tensor of zeros.)
        delta_param = nn.Parameter(0.0 * torch.randn([1, 1, hidden_states.shape[-1]]).to(hidden_states))
        optimizer = torch.optim.AdamW([delta_param], lr=0.01, weight_decay=1e-8, eps=1e-5)

        # Optimization loop (T=3 steps in snippet)
        for _ in range(3): # T = 3
            optimizer.zero_grad()
            # H' = H + delta
            transformed_hidden = hidden_states + delta_param # hidden_states is H_cached
            logits = self.lm_head(transformed_hidden)

            # Calculate LM loss on the prompt
            loss_fct = nn.CrossEntropyLoss()
            # Align logits and labels for next token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous() # input_ids is x
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            loss.backward() # Computes gradients w.r.t. delta_param
            optimizer.step() # Updates delta_param

        # Store the optimized delta for reuse in generation stage
        # This assumes self.delta is a persistent attribute of the model instance
        # or managed appropriately in the inference script
        self.delta = delta_param.detach() # Detach to prevent further gradient tracking if not needed

    # Apply the optimized delta in this forward pass as well, so the logits for the
    # first generated token already include it
    hidden_states = hidden_states + self.delta
    # Set flag to indicate optimization is done for this sample
    os.environ["prompt_only"] = "False" # So next forward pass (for generation) uses else branch
else:
    # --- Generation Stage: Reuse delta ---
    # self.delta should have been set by the prompt_only phase; getattr avoids an
    # AttributeError if generation runs before any optimization has happened
    delta = getattr(self, "delta", None)
    if delta is not None:
        hidden_states = hidden_states + delta

Key takeaways from the code:

  • An environment variable prompt_only is used to switch between the optimization phase and the generation phase. This is a practical way to control behavior within the forward method during a multi-step inference process.
  • hidden_states (representing $H$) is assumed to be the output of the LLM body before the LM head. It is cached implicitly by being passed into the optimization loop.
  • delta is initialized as an nn.Parameter so that torch.optim can update it.
  • The loss is the standard cross-entropy for language modeling, calculated on the input prompt.
  • After optimization, the optimized delta is stored (e.g., as self.delta) and then added to hidden_states for subsequent token generation.

Logit Modulation Vector (LMV)

SLOT's effect can be interpreted as a direct modulation of the output logits. The paper defines the Logit Modulation Vector (LMV) as:

$$\text{LMV} \triangleq W_{\text{LM}} \delta_{\text{opt}} \in \mathbb{R}^{|V|}$$

This vector represents an additive shift to the logits for every token in the vocabulary. Observations on GSM8K (Figure 3) show:

  • Increased probability: Tokens related to reasoning (e.g., "reasoning") are enhanced.
  • Decreased probability: Numerical tokens (e.g., "0", "1", "2") and common function words (e.g., "should", "will") are suppressed. Interestingly, the end-of-text (EOS) token is also suppressed, potentially encouraging longer, more detailed reasoning chains.
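Because the LM head is linear, adding $\delta_{\text{opt}}$ to the hidden state is the same as adding the LMV to the logits: $W_{\text{LM}}(h + \delta_{\text{opt}}) = W_{\text{LM}} h + \text{LMV}$. The tiny check below confirms this identity. It also ranks vocabulary entries by their LMV shift, which is how lists of enhanced and suppressed tokens can be read off. The values of `W_lm`, `delta_opt` and `h_last` are made up for illustration and do not come from the paper.

```python
import math


def matvec(W, h):
    """W @ h for a |V| x d list-of-lists W and a length-d vector h."""
    return [sum(w * x for w, x in zip(row, h)) for row in W]


def softmax(z):
    mx = max(z)
    e = [math.exp(v - mx) for v in z]
    s = sum(e)
    return [v / s for v in e]


# Made-up toy values: |V| = 4 tokens, hidden size d = 3.
W_lm = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5], [-2.0, 1.0, 1.0], [0.0, 0.25, -1.0]]
delta_opt = [0.1, -0.2, 0.05]
h_last = [1.0, 2.0, -1.0]

lmv = matvec(W_lm, delta_opt)  # LMV = W_LM delta_opt, computable once per prompt
logits_via_hidden = matvec(W_lm, [h + d for h, d in zip(h_last, delta_opt)])
logits_via_lmv = [z + s for z, s in zip(matvec(W_lm, h_last), lmv)]

# Token ids ordered from most enhanced to most suppressed by the shift.
ranking = sorted(range(len(lmv)), key=lambda tok: lmv[tok], reverse=True)

# Adding the same constant to every logit leaves the distribution unchanged.
probs = softmax(logits_via_lmv)
probs_shifted = softmax([z + 5.0 for z in logits_via_lmv])
print(lmv, ranking)
```

Softmax is invariant to adding the same constant to every logit. Only the differences between LMV entries therefore matter: a token with a larger entry gains probability relative to tokens with smaller ones.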

Experimental Results and Practical Applications

SLOT demonstrates significant improvements across various models and benchmarks:

  • Qwen-7B:
    • C-Eval: +8.55 on 'Hard' subset, +4.19 on 'STEM', +1.05 overall average.
    • GSM8K: +3.0 (51.2% to 54.2%).
    • HumanEval: +1.8 (29.9% to 31.7%).
  • Qwen2.5-7B:
    • GSM8K: +8.6% (57.54% to 66.19%), as reported in the abstract; Figure 1 shows an improvement from ~57% to ~66% with 5 iterations.
  • DeepSeek-R1-Distill-Llama-70B:
    • GPQA: Achieved 68.69% (a +3.03% improvement), SOTA for 70B open-source models at the time.
    • AIME24: +10.00% (63.33% to 73.33%).
  • General Trend: SLOT consistently boosts performance on reasoning tasks (AIME24, Math500, GPQA, GSM8K) for both base models and specialized reasoning models. Improvements are often substantial (e.g., +10% on AIME24 for Qwen2.5-32B and DeepSeek-R1-Distill-Llama-70B).
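The gains listed above are differences in accuracy, i.e., percentage points. Relative to each baseline they are larger. The arithmetic check below uses only the numbers from the list. The `reported` dictionary restates them, and `gains` is a hypothetical helper.

```python
# Baseline and SLOT accuracies (%) restated from the results list above.
reported = {
    "Qwen2.5-7B / GSM8K": (57.54, 66.19),
    "DeepSeek-R1-Distill-Llama-70B / AIME24": (63.33, 73.33),
    "Qwen-7B / GSM8K": (51.2, 54.2),
    "Qwen-7B / HumanEval": (29.9, 31.7),
}


def gains(before, after):
    """Return (absolute gain in percentage points, relative gain in percent)."""
    diff = after - before
    return diff, 100.0 * diff / before


for name, (before, after) in reported.items():
    pp, rel = gains(before, after)
    print(f"{name}: +{pp:.2f} pp, +{rel:.1f}% relative")
```

For example, the Qwen2.5-7B GSM8K gain is 8.65 points on a 57.54% baseline, roughly 15% relative. The DeepSeek-R1-Distill-Llama-70B AIME24 gain is 10 points, roughly 16% relative.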

Potential Applications:

  1. Improving Reasoning in LLMs: For tasks requiring complex multi-step reasoning (math problems, science Q&A), SLOT can make models more robust and accurate by helping them better adhere to the problem's specific constraints and reasoning structure.
  2. Enhancing Instruction Following: When prompts have strict formatting requirements or complex instructions not well-represented in general training data, SLOT can adapt the model at test-time to follow these instructions more faithfully. Figure 1 shows improved format accuracy alongside answer accuracy.
  3. Resource-Constrained Deployment: Since SLOT adds minimal computational overhead, it can be a practical way to boost the performance of smaller or existing LLMs without needing to retrain or use a much larger model.
  4. Specialized Domains: For applications in specialized domains where prompts might contain unique jargon or conventions, SLOT could help the model adapt quickly to these sample-specific elements.

Ablation Study Insights:

  • SLOT is relatively insensitive to hyperparameters such as the number of optimization iterations ($T$) and the learning rate ($\eta$), with most configurations outperforming the baseline.
  • Optimal performance in one test (DeepSeek-R1-Distill-Qwen-1.5B on AIME-24) was found with $(T=4, \eta=0.05)$ and $(T=5, \eta=0.05)$.
  • The number of optimization steps ($T$) does not significantly increase prompt processing time beyond $T=1$ due to feature caching.

Conclusion

SLOT is a parameter-efficient, test-time adaptation technique that optimizes a small, sample-specific vector $\delta$ by minimizing the cross-entropy loss on the input prompt itself. This allows LLMs to better align with and follow individual instructions, leading to significant performance gains on complex reasoning benchmarks with minimal computational overhead. Its efficiency stems from caching last-layer features during the per-sample optimization of $\delta$. The method's ability to enhance reasoning-related token probabilities (via LMV) suggests it encourages deeper processing. SLOT offers a practical approach for improving LLM responses on-the-fly, particularly for challenging or unfamiliar prompts.
