- The paper introduces state tuning, a method that optimizes RWKV-7's internal state matrix without modifying its pre-trained weights.
- It details four approaches for adapting the model to specific tasks: standard state tuning, dynamic scaling with kernel methods, DBP-enhanced tuning, and test-time scaling guided by a larger model.
- Experiments show that DBP-enhanced tuning delivers the highest gains, significantly outpacing the baseline on benchmarks like MMLU and GSM8K.
This paper introduces "State Tuning," a set of techniques designed to enhance the performance of the RWKV-7 "Goose" LLM without modifying its pre-trained weights. The core idea is to optimize the model's internal state matrix $S_t$ to adapt it to specific tasks or improve its general capabilities during inference. This addresses the common challenge of needing large computational resources for training or fine-tuning LLMs, offering more efficient ways to boost smaller models like the 7B-parameter RWKV-7.
The paper proposes four distinct methods:
- Standard State Tuning: This baseline approach initializes a new state matrix $S_0$ (e.g., with zeros) and optimizes it directly via backpropagation on a target task dataset, while all other model parameters remain frozen. The state $S_t \in \mathbb{R}^{N \times N}$ is still updated at each time step by the standard RWKV recurrence relation, but only the state itself is treated as a trainable parameter. This lets the model adapt its internal memory dynamics to the specific task while retaining its pre-trained knowledge (a minimal sketch follows this item).
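A minimal PyTorch sketch of the idea, assuming a heavily simplified surrogate for the RWKV-7 recurrence; the projections, dimensions, and toy loss are illustrative placeholders, not the paper's implementation. The key point is that only the state matrix receives gradients.

```python
# Minimal sketch: tune only an initial state S0 while all model weights stay
# frozen. The recurrence below is a simplified surrogate for the RWKV-7 update.
import torch

N, T = 64, 128                     # state dimension and sequence length (illustrative)
torch.manual_seed(0)

# Frozen "model" pieces: projections producing w_t, k_t, v_t, r_t per step.
frozen_proj = {name: torch.randn(N, N) for name in ("w", "k", "v", "r")}

# The only trainable parameter: the initial state matrix S0.
S0 = torch.zeros(N, N, requires_grad=True)
optimizer = torch.optim.Adam([S0], lr=1e-2)

def run_sequence(x, S0):
    """Roll the simplified recurrence over a sequence, starting from S0."""
    S, outputs = S0, []
    for t in range(x.shape[0]):
        w = torch.sigmoid(frozen_proj["w"] @ x[t])   # per-channel decay in (0, 1)
        k = frozen_proj["k"] @ x[t]
        v = frozen_proj["v"] @ x[t]
        r = frozen_proj["r"] @ x[t]
        S = torch.diag(w) @ S + torch.outer(v, k)    # simplified state update
        outputs.append(S @ r)                        # read out with receptance
    return torch.stack(outputs)

# One tuning step on a toy regression target; gradients flow only into S0.
x, target = torch.randn(T, N), torch.randn(T, N)
loss = torch.nn.functional.mse_loss(run_sequence(x, S0), target)
loss.backward()
optimizer.step()
```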
- Dynamic Scaling with Kernel Method: To increase the model's expressive capacity beyond the original state dimension $N$, this method uses a kernel trick to effectively upscale the state matrix to a higher dimension $M \times M$ (where $M > N$).
- Support vectors $\{u_1, \ldots, u_M\}$ are chosen.
- A kernel function (e.g., the Gaussian kernel $K(u, v) = \exp(-\gamma \lVert u - v \rVert^2)$) is used to map the RWKV internal vectors ($w_t, k_t, a_t, v_t, r_t \in \mathbb{R}^N$) into higher-dimensional feature vectors $\phi(\cdot) \in \mathbb{R}^M$.
- The state update and output computation happen in this higher $M$-dimensional space using the transformed vectors $\phi(\cdot)$ and an upscaled state matrix $S_t \in \mathbb{R}^{M \times M}$.
- The final output is projected back to the original dimension $N$ using a fixed projection matrix $Q$.
- Only the upscaled state matrix $S_t \in \mathbb{R}^{M \times M}$ is tuned. This adds non-linearity and increases state capacity without changing weights (see the sketch after this list).
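A minimal sketch of the kernel upscaling, assuming a Gaussian feature map built from fixed support vectors and the same simplified surrogate recurrence as above; the bandwidth $\gamma$, the dimensions, and the form of $Q$ are illustrative choices rather than values from the paper.

```python
# Minimal sketch: upscale the state to M x M via Gaussian kernel features and
# project the readout back to dimension N with a fixed matrix Q.
import torch

N, M, gamma = 64, 256, 0.1         # M > N; gamma is an assumed kernel bandwidth
torch.manual_seed(0)

support = torch.randn(M, N)        # fixed support vectors u_1 .. u_M
Q = torch.randn(N, M) / M ** 0.5   # fixed projection back to R^N

def phi(u):
    """Gaussian kernel features: phi(u)_i = exp(-gamma * ||u - u_i||^2)."""
    d2 = ((u.unsqueeze(0) - support) ** 2).sum(dim=-1)
    return torch.exp(-gamma * d2)  # values in (0, 1], shape (M,)

# Only the upscaled state matrix is trainable.
S = torch.zeros(M, M, requires_grad=True)

def step(S, w, k, v, r):
    """One simplified state update carried out entirely in the M-dim space."""
    decay = torch.diag(phi(w))                        # kernel values reused as decay
    S_new = decay @ S + torch.outer(phi(v), phi(k))   # upscaled state update
    y = Q @ (S_new @ phi(r))                          # project readout back to R^N
    return S_new, y

w_t, k_t, v_t, r_t = (torch.randn(N) for _ in range(4))
S_next, y_t = step(S, w_t, k_t, v_t, r_t)
print(y_t.shape)  # torch.Size([64])
```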
- DBP-Enhanced Dynamic State Tuning: This method builds upon the dynamic scaling approach by incorporating Decorrelated Backpropagation (DBP). DBP aims to improve training efficiency and model expressivity by decorrelating the inputs to layers. Here, it's adapted to decorrelate the kernel-transformed vectors $\phi(\cdot)$ before they are used in the state update.
- A decorrelation matrix $R \in \mathbb{R}^{M \times M}$ is introduced and applied to the kernel features (e.g., $\phi(k_t)_{\text{decor}} = R\,\phi(k_t)$).
- The state update uses these decorrelated vectors.
- A decorrelation loss $L_{\text{decor}}$ is added to the task loss, penalizing correlations between components of the transformed vectors and encouraging unit variance.
- Both the upscaled state matrix $S_t$ and the decorrelation matrix $R$ are jointly optimized during training. DBP is expected to accelerate convergence and lead to more expressive state representations (a sketch follows this list).
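A minimal sketch of the decorrelation component, assuming $L_{\text{decor}}$ penalizes off-diagonal covariance and deviations from unit variance in the transformed features; the exact DBP update rule, loss weighting, and task loss used in the paper may differ and are placeholders here.

```python
# Minimal sketch: a trainable decorrelation matrix R applied to kernel features,
# trained jointly with the upscaled state S under task loss + decorrelation loss.
import torch

M = 256
torch.manual_seed(0)

R = torch.eye(M, requires_grad=True)        # decorrelation matrix (trainable)
S = torch.zeros(M, M, requires_grad=True)   # upscaled state matrix (trainable)
optimizer = torch.optim.Adam([R, S], lr=1e-3)

def decorrelation_loss(features):
    """features: (batch, M) decorrelated kernel features R @ phi(.)."""
    z = features - features.mean(dim=0, keepdim=True)
    cov = (z.T @ z) / (z.shape[0] - 1)
    off_diag = cov - torch.diag(torch.diag(cov))
    # Penalize cross-correlations and drive per-feature variance toward 1.
    return off_diag.pow(2).mean() + (torch.diag(cov) - 1.0).pow(2).mean()

# Toy batch of kernel features phi(k_t) collected over a sequence.
phi_k = torch.rand(32, M)
phi_k_decor = phi_k @ R.T                   # phi(k_t)_decor = R phi(k_t), batched

task_loss = (S @ phi_k_decor.T).pow(2).mean()              # stand-in for the task loss
loss = task_loss + 0.1 * decorrelation_loss(phi_k_decor)   # assumed loss weighting
loss.backward()
optimizer.step()                            # updates both R and S jointly
```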
- Test-Time Scaling with Larger Model Guidance: This technique performs state tuning during inference for each input sequence, guided by a larger, more capable LLM.
- For a given input sequence, the larger LLM generates a step-by-step Chain of Thought (COT) reasoning sequence.
- At each generation step $t$ of the RWKV-7 model, its current state $S_t$ is optimized using Reinforcement Learning (RL).
- A reward $R(S_t, x_{t+1})$ is defined based on how well the RWKV-7 model's next predicted token $x_{t+1}$ aligns with the corresponding step in the larger LLM's COT sequence (using log-probabilities from the larger LLM).
- The gradient of this reward with respect to $S_t$ is computed, and $S_t$ is updated via gradient ascent for a few iterations.
- The tuned state $S_t$ is then used to generate the next token $x_{t+1}$.
- This allows RWKV-7 to dynamically adapt its internal state at test time to perform more complex reasoning, guided by the larger model, without requiring prior training on similar reasoning tasks (see the sketch after this list).
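A minimal sketch of the per-step state optimization, assuming a differentiable surrogate reward (the expected teacher log-probability under the student's next-token distribution) in place of the paper's exact RL formulation; the readout head, learning rate, and iteration count are illustrative.

```python
# Minimal sketch: at inference time, nudge the current state S_t by gradient
# ascent so the student's next-token distribution better matches the teacher's
# chain-of-thought step, then generate with the tuned state.
import torch

N, V = 64, 1000                      # state dimension and vocab size (illustrative)
torch.manual_seed(0)

readout = torch.randn(V, N)          # frozen student readout head
r_t = torch.randn(N)                 # frozen receptance vector for this step
teacher_logprobs = torch.log_softmax(torch.randn(V), dim=-1)  # from the larger LLM

S = torch.zeros(N, N, requires_grad=True)   # current state S_t, tuned per step
lr, n_iters = 0.05, 5

for _ in range(n_iters):
    student_logits = readout @ (S @ r_t)
    student_probs = torch.softmax(student_logits, dim=-1)
    reward = (student_probs * teacher_logprobs).sum()    # surrogate for R(S_t, x_{t+1})
    grad, = torch.autograd.grad(reward, S)
    with torch.no_grad():
        S += lr * grad                                   # gradient ascent on S_t

next_token = int(torch.argmax(readout @ (S @ r_t)))      # decode with the tuned state
```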
Experiments and Results:
The methods were evaluated on the RWKV-7 "Goose" 7B model using benchmarks like MMLU (general knowledge), GSM8K (math reasoning), WinoGrande (commonsense), and ARC-Challenge (scientific reasoning).
- All proposed methods significantly outperformed the vanilla RWKV-7 baseline.
- Standard State Tuning provided a solid improvement, roughly a 7-8 point absolute gain on benchmarks like MMLU and GSM8K.
- Dynamic Scaling offered further improvements over standard tuning.
- DBP-Enhanced Dynamic State Tuning achieved the best results across all benchmarks (e.g., 79.0% MMLU, 89.0% GSM8K), demonstrating the benefits of decorrelating state inputs for enhanced expressivity and potentially faster convergence during tuning.
- Test-Time Scaling performed nearly as well as the DBP-enhanced method (e.g., 78.6% MMLU, 88.5% GSM8K), showcasing the effectiveness of dynamic, inference-time adaptation guided by a larger model.
Conclusion:
The paper successfully demonstrates that state tuning, in various forms, is an effective and computationally efficient strategy for enhancing the performance of the RWKV-7 model without altering its pre-trained weights. The DBP-enhanced method provides the highest performance gains among the training-based approaches, while the test-time scaling method offers a flexible way to leverage larger models for guidance during inference. These techniques present practical ways to improve smaller models for complex tasks, especially in resource-constrained settings.