BREP ReFT: Bias-Restrained Prefix Representation FineTuning
- The paper introduces BREP ReFT, which optimizes per-layer scaling and bias vectors to enhance mathematical reasoning accuracy with minimal parameter overhead.
- BREP ReFT employs prefix-truncated training and PID-controlled bias constraints to prevent degradation of numerical encoding and reasoning prefixes.
- Empirical results demonstrate that BREP ReFT outperforms standard PEFT methods on benchmarks like GSM8K, achieving superior accuracy with orders-of-magnitude fewer parameters.
Bias-Restrained Prefix Representation FineTuning (BREP ReFT) is a representation finetuning technique for LLMs designed to address limitations of standard ReFT on mathematical reasoning tasks. BREP ReFT achieves high parameter efficiency by freezing all pretrained model weights and optimizing only per-layer scaling and bias vectors that act directly on hidden states. The method introduces prefix-truncated training, early-stage intervention, and a PID-constrained bias magnitude objective to prevent degradation of mathematical inference accuracy commonly observed in conventional PEFT (Parameter-Efficient Finetuning) and unrestricted ReFT approaches. Empirical evidence demonstrates that BREP ReFT attains superior accuracy and generalization on chain-of-thought mathematical reasoning benchmarks, matching or exceeding the performance of weight-based PEFT methods with orders-of-magnitude fewer learned parameters (Liang et al., 13 Nov 2025).
1. Background and Motivation
Parameter-Efficient Finetuning (PEFT) schemes, such as LoRA and Prefix-tuning, update only a fraction of the model weights, enabling task adaptation without full weight updates. Representation Finetuning (ReFT) extends PEFT by solely learning lightweight transformations—elementwise scaling and bias—on intermediate hidden states while freezing pretrained weights. ReFT is highly parameter-efficient and succeeds on commonsense and instruction-following tasks. However, on mathematical benchmarks such as GSM8K, ReFT shows a notable accuracy drop (∼11.5% lower than PEFT).
Two principal failure modes are identified:
- Misleading Reasoning Prefixes: ReFT-finetuned models tend to generate poor initial chain-of-thought (CoT) tokens (the "reasoning prefix"), leading to erroneous subsequent inference.
- Numerical Encoding Degradation: The learned per-layer bias vectors introduce deviations in the model’s internal linear encoding of numbers, with the effect compounding across autoregressive token generation steps. Empirical projection of these biases onto the original number-encoding direction indicates frequent excursions beyond a critical threshold, correlating with elevated addition error rates.
These findings motivate BREP ReFT, which targets the initialization phase of mathematical reasoning and incorporates explicit constraints to control representational drift.
2. Mathematical Formulation
The BREP ReFT formalism combines a standard transformer architecture with the ReFT intervention, a prefix-focused training objective, and a bias-constrained optimization scheme.
- Transformer with ReFT:
For an input sequence $x$, let $h^{(0)}$ denote the initial token embeddings. At layer $l$, the transformer computes
$$h^{(l)} = f^{(l)}\big(h^{(l-1)}\big), \qquad l = 1, \dots, L.$$
Standard ReFT then applies a learned elementwise transformation to the hidden states,
$$\tilde{h}^{(l)} = W^{(l)} \odot h^{(l)} + b^{(l)},$$
where $W^{(l)}, b^{(l)} \in \mathbb{R}^{d}$ are the per-layer scaling and bias vectors and $\odot$ indicates elementwise multiplication (a code sketch follows at the end of this section).
- Prefix-Focused Objective:
The training objective focuses on the first $l_p$ tokens of each target response sequence $y = (y_1, \dots, y_{l_f})$. At token position $t$, the per-token prefix reward is
$$r_t = \mathbb{1}[t \le l_p]\,\log p_\theta\big(y_t \mid x, y_{<t}\big),$$
where $l_f$ is the full sequence length and $l_p < l_f$. The cumulative prefix reward $\sum_{t=1}^{l_f} r_t$ is maximized; equivalently, the loss minimized is the mean negative log-likelihood over the truncated prefix:
$$\mathcal{L}_{\mathrm{CE}} = -\frac{1}{l_p}\sum_{t=1}^{l_p}\log p_\theta\big(y_t \mid x, y_{<t}\big).$$
- Bias-Restraint via PID Control:
The average per-layer bias magnitude is
$$\bar{b} = \frac{1}{L}\sum_{l=1}^{L}\big\lVert b^{(l)}\big\rVert_2,$$
and the instantaneous PID error is $e(t) = \bar{b}(t) - b^{*}$, where $b^{*}$ is the target bias norm. A PID controller computes the penalty weight
$$w(t) = K_p\,e(t) + K_i\sum_{\tau \le t} e(\tau) + K_d\,\big(e(t) - e(t-1)\big).$$
The total loss is then
$$\mathcal{L}_{\mathrm{total}} = \mathcal{L}_{\mathrm{CE}} + w(t)\,\bar{b},$$
which pulls the bias magnitude back toward $b^{*}$ whenever it drifts.
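The following PyTorch sketch illustrates the first two ingredients above: a per-layer intervention module holding the learnable scaling vector $W^{(l)}$ and bias $b^{(l)}$, and the cross-entropy loss restricted to the first $l_p$ target tokens. Class and function names are illustrative, not the released implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ReFTIntervention(nn.Module):
    """Per-layer elementwise scaling and bias applied to hidden states.

    The pretrained transformer stays frozen; only W (scaling, init 1)
    and b (bias, init 0) are trained. Hypothetical minimal module.
    """
    def __init__(self, hidden_size: int):
        super().__init__()
        self.W = nn.Parameter(torch.ones(hidden_size))
        self.b = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: (batch, seq_len, hidden_size); computes W ⊙ h + b elementwise.
        return self.W * h + self.b


def prefix_truncated_ce(logits: torch.Tensor, targets: torch.Tensor, l_p: int) -> torch.Tensor:
    """Mean negative log-likelihood over the first l_p target tokens only.

    logits: (batch, seq_len, vocab_size); targets: (batch, seq_len).
    """
    logits, targets = logits[:, :l_p, :], targets[:, :l_p]
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
```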
3. Algorithmic Procedures
BREP ReFT integrates three main algorithmic components: prefix-truncated training, two-stage inference, and magnitude constraint via PID control.
- Prefix-Truncated Training: Each sample's target response is truncated to its first $l_p$ tokens. Optimization occurs only over these tokens, directly shaping the model's initial reasoning behavior and sharpening prefix accuracy. The prefix-truncation training loop proceeds as follows:
```
for each batch {x, full_response y[1..l_f]}:
    y_prefix = y[1..l_p]                                   # keep only the first l_p target tokens
    compute p_t = model(x, y_prefix[1..t-1]) for t = 1..l_p
    L_ce = -(1/l_p) * sum_{t=1}^{l_p} log p_t(y_t)         # prefix cross-entropy
    w_t  = PID(mean_l ||b^(l)||_2)                         # bias-penalty weight from the PID controller
    update W, b via ∇L_total = ∇L_ce + w_t · ∇ b̄           # only the scaling and bias vectors are trained
```
- Two-Stage Inference: During decoding, the ReFT transforms (scaling and bias) are applied only to the first $n$ generated tokens. For subsequent positions, the base (unfinetuned) model representation is used, preventing error propagation through the CoT.
```
for t in 1..T:
    if t <= n:
        apply ReFT transforms
    else:
        leave hidden states unchanged
    sample y_t ∼ p(· | x, y_{1:t-1})
    append y_t to output
return output sequence
```
- Bias Magnitude Constraint: The PID gains $(K_p, K_i, K_d)$ are fixed at initialization, and the target bias norm $b^{*}$ is set per model family; a minimal controller sketch follows.
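A minimal Python sketch of the PID-controlled weight update, assuming the composition $\mathcal{L}_{\mathrm{total}} = \mathcal{L}_{\mathrm{CE}} + w(t)\,\bar{b}$ given in Section 2; the class name and interface are illustrative rather than the authors' code.

```python
import torch

class PIDBiasWeight:
    """Produces the penalty weight w(t) that keeps the average per-layer
    bias norm near the target b_star (illustrative implementation)."""

    def __init__(self, k_p: float, k_i: float, k_d: float, b_star: float):
        self.k_p, self.k_i, self.k_d, self.b_star = k_p, k_i, k_d, b_star
        self.integral = 0.0
        self.prev_error = 0.0

    def step(self, bias_vectors: list[torch.Tensor]) -> float:
        # Average per-layer bias magnitude: mean of ||b^(l)||_2 over layers.
        b_bar = sum(b.norm(p=2).item() for b in bias_vectors) / len(bias_vectors)
        error = b_bar - self.b_star           # instantaneous PID error e(t)
        self.integral += error                # accumulated (integral) term
        derivative = error - self.prev_error  # difference (derivative) term
        self.prev_error = error
        return self.k_p * error + self.k_i * self.integral + self.k_d * derivative

# Assumed use inside the training loop above:
#   w_t = pid.step([layer.b for layer in reft_layers])
#   bias_penalty = torch.stack([layer.b.norm() for layer in reft_layers]).mean()
#   loss = ce_loss + w_t * bias_penalty
```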
4. Experimental Setup and Evaluation
Models:
- Llama3-8B-Instruct, Llama3.1-8B-Instruct
- Qwen2.5-Math-7B-Instruct, Qwen3-8B, Qwen3-14B
Training and inference:
- Typical prefix lengths: Llama (, ); Qwen2.5-7B (66,10); Qwen3-8B (67,11); Qwen3-14B (68,12).
- AdamW optimizer; learning rates set separately for the Llama and Qwen model families.
- Computation: Single NVIDIA A100 80 GB for 1 hour of training on 5K samples.
Datasets:
- Simple reasoning: a MATH10K subsample, evaluated on GSM8K, SVAMP, and MathQA
- Complex reasoning: a 5K-example PRM800K subset, evaluated on MATH500 and AMC23
Benchmarks: GSM8K, SVAMP, MathQA, MATH500, AMC23
Baselines: Base (frozen), LoRA, RED (RepEdit), LoReFT
Metrics: Answer correctness (chain-of-thought verification)
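The paper's exact verification script is not reproduced here; a typical heuristic, sketched below under that assumption, extracts the final boxed or last numeric answer from each generated chain of thought and compares it with the gold answer.

```python
import re

def extract_final_answer(generation: str) -> str | None:
    """Pull the final answer from a CoT generation.

    Prefers a \\boxed{...} span if present, otherwise falls back to the
    last number in the text. Illustrative heuristic, not the paper's script.
    """
    boxed = re.findall(r"\\boxed\{([^}]*)\}", generation)
    if boxed:
        return boxed[-1].strip()
    numbers = re.findall(r"-?\d+(?:\.\d+)?", generation.replace(",", ""))
    return numbers[-1] if numbers else None

def is_correct(generation: str, gold: str) -> bool:
    """Numeric comparison when possible, exact string match otherwise."""
    pred = extract_final_answer(generation)
    try:
        return pred is not None and abs(float(pred) - float(gold)) < 1e-6
    except ValueError:
        return pred == gold.strip()
```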
Results summary:
| Model | Method | GSM8K | SVAMP | MathQA | MATH500 | AMC23 |
|---|---|---|---|---|---|---|
| Llama3-8B | Base | 80.0 | 88.9 | 55.0 | 40.4 | 57.5 |
| Llama3-8B | LoRA | 81.1 | 90.0 | 54.0 | 39.3 | 53.8 |
| Llama3-8B | RED | 73.8 | 88.9 | 51.3 | 41.5 | 56.4 |
| Llama3-8B | LoReFT | 78.8 | 80.7 | 44.7 | 37.0 | 35.0 |
| Llama3-8B | BREP | 82.8 | 89.5 | 54.3 | 42.8 | 52.5 |
| Qwen3-8B | Base | 95.1 | 96.7 | 86.5 | 82.0 | 85.0 |
| Qwen3-8B | LoRA | 95.1 | 96.8 | 86.2 | 81.8 | 87.5 |
| Qwen3-8B | RED | 87.9 | 91.8 | 77.3 | 54.2 | 35.0 |
| Qwen3-8B | LoReFT | 87.1 | 96.3 | 72.8 | 72.4 | 80.0 |
| Qwen3-8B | BREP | 95.3 | 97.4 | 86.3 | 82.6 | 87.5 |
BREP improves GSM8K accuracy by +2.8 points over the base Llama3-8B and matches or exceeds the finetuning baselines on most benchmark-model combinations.
Efficiency:
BREP introduces only $2d$ parameters per layer (scaling plus bias), a negligible fraction of total model parameters. Training is rapid (about 1 hour for 5K examples), and inference time increases only marginally because the intervention is limited to the prefix.
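As an illustrative calculation (not from the paper), assuming Llama3-8B dimensions of hidden size $d = 4096$, 32 intervened layers, and roughly $8 \times 10^{9}$ total parameters:

$$2d \times 32 = 2 \times 4096 \times 32 = 262{,}144, \qquad \frac{262{,}144}{8 \times 10^{9}} \approx 0.003\%$$

of the model's parameters are trained.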
5. Analysis and Ablation
Ablation studies confirm the contributions of BREP’s distinct components:
| Model | Variant | GSM8K | MATH500 |
|---|---|---|---|
| Llama3-8B | Full BREP | 82.8 | 42.8 |
| Llama3-8B | w/o Prefix Truncation | 81.0 | 40.2 |
| Llama3-8B | w/o Bias Constraint | 80.0 | 39.4 |
| Llama3-8B | w/o Early Intervention | 80.4 | 37.6 |
| Qwen3-8B | Full BREP | 95.3 | 82.0 |
| Qwen3-8B | w/o Prefix Truncation | 95.5 | 79.4 |
| Qwen3-8B | w/o Bias Constraint | 94.9 | 79.8 |
| Qwen3-8B | w/o Early Intervention | 95.1 | 81.6 |
Removing any single component (prefix truncation, bias constraint, or early-stage intervention) degrades mathematical reasoning accuracy in most settings, with the largest drops on the longer-CoT MATH500 benchmark.
Probing internal representations demonstrates preservation or improvement of linear number encoding with BREP, in contrast to the degradation observed in unconstrained ReFT.
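The probing protocol is not detailed in this summary; the sketch below shows one standard way to run such a check, assuming a linear probe fit from hidden states at number-token positions to the numbers they denote, with the learned biases then projected onto the probe direction. Function names and the choice of ridge regression are assumptions.

```python
import numpy as np
from sklearn.linear_model import Ridge

def number_encoding_probe(hidden_states: np.ndarray, values: np.ndarray):
    """Fit a linear probe h -> numeric value and return its unit direction.

    hidden_states: (num_tokens, d) activations at number-token positions.
    values:        (num_tokens,) the numbers those tokens denote.
    Illustrative protocol; the paper's exact probe may differ.
    """
    probe = Ridge(alpha=1.0).fit(hidden_states, values)
    direction = probe.coef_ / np.linalg.norm(probe.coef_)
    return probe, direction

def bias_projection(bias: np.ndarray, direction: np.ndarray) -> float:
    """Signed projection of a learned layer bias onto the number-encoding
    direction; large magnitudes indicate drift of the numerical encoding."""
    return float(bias @ direction)
```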
6. Implementation Guidelines
To adopt BREP ReFT:
- Freeze the pretrained LLM and insert per-layer learnable scaling ($W^{(l)}$) and bias ($b^{(l)}$) vectors.
- Prepare the dataset and truncate each example's target response to the first $l_p$ tokens, with $l_p$ chosen for the target model.
- Implement PID control to keep the average bias magnitude $\bar{b}$ near the target $b^{*}$ by updating the loss weight $w(t)$.
- Train on the ~5K-example dataset on a single high-memory GPU, optimizing the total loss $\mathcal{L}_{\mathrm{total}}$ defined above.
- During inference, apply the ReFT transformations only for the first $n$ generated tokens, then revert to the base model's hidden states.
- Decoding may use greedy search or any preferred CoT sampling policy; a minimal decoding sketch follows.
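A minimal sketch of the two-stage decoding loop, assuming a HuggingFace-style causal LM output and an `enabled` flag on each intervention module that gates the $W \odot h + b$ transform; the names and the flag are illustrative, and the released repository below is authoritative.

```python
import torch

@torch.no_grad()
def two_stage_generate(model, reft_layers, input_ids, n_intervene: int, max_new_tokens: int):
    """Greedy decoding with ReFT interventions active only for the first
    n_intervene generated tokens; later tokens use the frozen base model."""
    output = input_ids
    for t in range(max_new_tokens):
        # Stage 1: ReFT transforms on; Stage 2: hidden states left unchanged.
        active = t < n_intervene
        for layer in reft_layers:
            layer.enabled = active  # assumed flag consulted inside the module's forward()
        logits = model(output).logits[:, -1, :]   # full re-forward; no KV cache for brevity
        next_token = logits.argmax(dim=-1, keepdim=True)
        output = torch.cat([output, next_token], dim=-1)
    return output
```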
Reference code and scripts are publicly available at https://github.com/LiangThree/BREP.
7. Context, Significance, and Limitations
BREP ReFT establishes new accuracy and generalization standards for representation-efficient finetuning on mathematical reasoning tasks, mitigating the prefix misalignment and numerical encoding drift present in prior ReFT methods. By concentrating adaptation on the initial reasoning prefix and tightly regulating representational drift, BREP enables robust mathematical inference with minimal parameter overhead. For out-of-domain commonsense tasks (BoolQ, PIQA, GPQA), BREP maintains or improves generalization relative to baselines.
A plausible implication is that BREP’s separation of early-stage reasoning intervention from subsequent unperturbed token generation is applicable to other tasks with critical prefix dependencies. However, potential scaling to multi-turn dialogue or highly compositional mathematical contexts may warrant further investigation, as may optimal selection of prefix and intervention lengths for arbitrary architectures.
BREP ReFT provides a reproducible procedure and open-source codebase for investigation and deployment in high-stakes mathematical language modeling, offering a systematic framework to balance adaptation capacity and representational stability for rigorous downstream reasoning applications.