LLaMPPL: SMC Steering for LLMs
- LLaMPPL is a probabilistic programming framework that models constrained text generation as a Feynman-Kac process using SMC methods.
- It integrates a custom SMC inference engine and shared caching with LLaMA-family transformers for efficient, diversity-preserving decoding.
- Empirical results show that LLaMPPL achieves sampling-correct, constraint-respecting completions with runtime close to traditional beam search.
LLaMPPL is a probabilistic programming framework for controlling LLMs at inference time using sequential Monte Carlo (SMC) methods, with an explicit focus on enforcing syntactic and semantic constraints during text generation. Developed to address the limitations of prompt-based steering, LLaMPPL lets users represent complex constrained or structured generation tasks as probabilistic programs and then approximately sample strings from the posterior induced by those programs, with the approximation improving as the number of particles grows, rather than resorting to heuristic decoding mechanisms. The framework provides a lightweight Python API for rapid prototyping of such inference-time constraints, with particular emphasis on integration with LLaMA-family Transformers. The core methodology and implementation are described in detail in "Sequential Monte Carlo Steering of Large Language Models using Probabilistic Programs" (Lew et al., 2023).
1. Design Goals and Architecture
LLaMPPL is designed to let users "write down" a broad class of constrained text-generation tasks as concise probabilistic programs, which are then executed by an SMC algorithm whose output approximates the program's posterior over completions ever more closely as the number of particles grows. Unlike ad-hoc constraint satisfaction or beam-search masking strategies, LLaMPPL encodes generation requirements into a Feynman-Kac model composed of a Markov sequence model (typically autoregressive token selection) and non-negative potentials enforcing constraints.
The architecture consists of:
- A minimal Python probabilistic programming library: All generative tasks are encoded by subclassing the Model class and implementing a step() method, which interacts with five primitives: sample(dist[, proposal]), condition(boolean), observe(dist, value), transformer(context), and finish().
- Underlying Feynman-Kac Transformer model: Each probabilistic program defines an initial state $s_0$, a sequence of Markov kernels $M_t(s_t \mid s_{t-1})$, and potentials $G_t(s_{t-1}, s_t) \geq 0$.
- Custom SMC inference engine: Operates across a particle population, sharing a single CachedTransformer to minimize expensive repeated model invocations, and utilizing a "without-replacement" stratified resampling strategy to maintain particle diversity.
- Efficient transformer integration: Special support for the LLaMA transformer family, using prefix tries and shared key/value caches across particles to ensure competitive runtime relative to standard beam search.
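The five primitives can be illustrated with a small, self-contained sketch. It is not LLaMPPL's actual Model class: the base class, the fixed toy next-token distribution returned by transformer(), the AvoidSat program, and the run() driver are all hypothetical stand-ins. Each instance is one particle that carries a log-weight, and each primitive either draws a value or adjusts that weight.

```python
import math
import random

EOS = "<eos>"


def _log(p):
    # log that maps probability 0 to -inf instead of raising
    return math.log(p) if p > 0 else -math.inf


class Model:
    """Hypothetical base class mirroring the five primitives (not the library's code).
    One instance is one particle with a log-weight."""

    def __init__(self, rng=None):
        self.log_weight = 0.0
        self.finished = False
        self.rng = rng if rng is not None else random.Random(0)

    def transformer(self, context):
        # Toy stand-in for an LM: a fixed next-token distribution that ignores context.
        return {"the": 0.4, "cat": 0.3, "sat": 0.2, EOS: 0.1}

    def sample(self, dist, proposal=None):
        q = proposal if proposal is not None else dist
        token = self.rng.choices(list(q), weights=list(q.values()))[0]
        if proposal is not None:
            # Importance correction: target probability / proposal probability.
            self.log_weight += _log(dist.get(token, 0.0)) - _log(q[token])
        return token

    def observe(self, dist, value):
        # Condition on `value` by scaling the weight by its probability.
        self.log_weight += _log(dist.get(value, 0.0))
        return value

    def condition(self, ok):
        # A failed hard constraint zeroes the particle's weight.
        if not ok:
            self.log_weight = -math.inf

    def finish(self):
        self.finished = True

    def step(self):
        raise NotImplementedError


class AvoidSat(Model):
    """Example program: generate until EOS, rejecting particles that emit 'sat'."""

    def __init__(self, rng=None):
        super().__init__(rng)
        self.tokens = []

    def step(self):
        token = self.sample(self.transformer(self.tokens))
        self.condition(token != "sat")
        if token == EOS:
            self.finish()
        else:
            self.tokens.append(token)


def run(particle, max_steps=50):
    # Step a single particle until it finishes, dies, or hits the step limit.
    while not particle.finished and particle.log_weight > -math.inf and max_steps > 0:
        particle.step()
        max_steps -= 1
    return particle


particles = [run(AvoidSat(random.Random(seed))) for seed in range(200)]
survivors = [p for p in particles if p.log_weight > -math.inf]
```

Here every surviving particle avoids "sat" because condition() zeroed the weight of any particle that emitted it. An SMC engine would additionally resample between steps rather than run particles independently.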
2. Probabilistic Program Representation for LLM Steering
Every LLaMPPL program defines:
- State: The current partial output string $s_t$, typically initialized in __init__().
- Actions: At each step, the step() method may call:
  - sample(D): Draw a token from distribution $D$ (e.g., the transformer's next-token distribution).
  - sample(D, proposal=Q): Draw $x \sim Q$ for importance sampling, multiplying the particle's weight by $D(x)/Q(x)$.
  - observe(D, v): Condition on the next token being $v$, multiplying the weight by $D(v)$.
  - condition(bool): Set the particle's weight to zero if the Boolean is false.
  - transformer(ctx): Return the LM's next-token categorical distribution for context ctx.
  - finish(): Append EOS and terminate.
- Markov process: The chain $s_0, s_1, \ldots, s_T$ generates token sequences, with transitions governed by the kernels $M_t$ and weighted by the potentials $G_t$.
- Posterior sampling objective: The posterior over completed strings is proportional to $\prod_{t=1}^{T} M_t(s_t \mid s_{t-1})\, G_t(s_{t-1}, s_t)$, enabling explicit definition of desired constraints or structure.
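To make the posterior objective concrete, the following sketch enumerates it exactly for a tiny Feynman-Kac model. Everything here is a hypothetical toy: a context-free LM over {a, b, EOS} stands in for $M_t$, and the potential $G_t$ is 0 when a step creates two consecutive "b" tokens and 1 otherwise. Since infinitely many strings exist, the enumeration is truncated at a maximum length.

```python
from itertools import product

EOS = "<eos>"
VOCAB = ["a", "b"]
# Hypothetical context-free toy LM playing the role of the Markov kernel M_t.
P = {"a": 0.5, "b": 0.3, EOS: 0.2}


def potential(prev, new):
    """G_t(s_{t-1}, s_t): 0 if the new state ends in two consecutive 'b's, else 1."""
    return 0.0 if new[-2:] == ("b", "b") else 1.0


def unnormalized(seq):
    """Product over steps of M_t(s_t | s_{t-1}) * G_t(s_{t-1}, s_t) for seq + EOS."""
    state, weight = (), 1.0
    for tok in seq + (EOS,):
        new = state + (tok,)
        weight *= P[tok] * potential(state, new)
        state = new
    return weight


def posterior(max_len):
    # Enumerate every EOS-terminated string with at most max_len content tokens.
    scores = {
        seq: unnormalized(seq)
        for n in range(max_len + 1)
        for seq in product(VOCAB, repeat=n)
    }
    z = sum(scores.values())
    return {seq: s / z for seq, s in scores.items()}


post = posterior(6)
```

Strings that violate the constraint get probability 0, while the relative probabilities of allowed strings are exactly those of the LM (for example, "a" versus "b" stays at 0.5 : 0.3).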
3. Mathematical Framework: Feynman-Kac Models and SMC Steering
3.1 Feynman-Kac Sequence Models
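Let $\mathcal{V}$ be the LM vocabulary, $\mathcal{V}^*$ the set of all finite token sequences, $\textsf{EOS} \in \mathcal{V}$ the end-of-sequence token, and $\mathcal{F} \subset \mathcal{V}^*$ the set of EOS-terminated strings. The generation process is a stochastic chain with:
- Initial state $s_0 \in \mathcal{V}^*$ (typically the prompt or empty string).
- At each time step $t \geq 1$, a Markov kernel $M_t(s_t \mid s_{t-1})$ extends the state with new tokens.
- Non-negative potential functions $G_t(s_{t-1}, s_t) \geq 0$ encode arbitrary constraints.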
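The unnormalized posterior over completed strings $s \in \mathcal{F}$ is
$$\tilde{\mathbb{P}}(s) = \mathbb{E}_{s_{1:T} \sim M}\left[\mathbb{1}[s_T = s] \prod_{t=1}^{T} G_t(s_{t-1}, s_t)\right],$$
where $T$ is the first time step at which $s_T \in \mathcal{F}$, with normalizing constant
$$Z = \mathbb{E}_{s_{1:T} \sim M}\left[\prod_{t=1}^{T} G_t(s_{t-1}, s_t)\right], \qquad \mathbb{P}(s) = \frac{\tilde{\mathbb{P}}(s)}{Z}.$$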
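The expression for $Z$ suggests the simplest unbiased estimator: simulate chains from $M$ and average the product of potentials. The sketch below does this for a hypothetical toy model (a context-free LM over {a, b, EOS}, with a potential that vanishes on two consecutive "b" tokens) and compares the estimate to the exact value, obtained by solving a two-state recurrence specific to these toy probabilities. It is plain Monte Carlo with no resampling; SMC builds on the same weights.

```python
import random

EOS = "<eos>"
P = {"a": 0.5, "b": 0.3, EOS: 0.2}  # hypothetical toy LM playing the kernel M


def simulate_potential_product(rng):
    """Run the chain under M until EOS and return prod_t G_t (1 unless 'b b' appears)."""
    prev = None
    while True:
        tok = rng.choices(list(P), weights=list(P.values()))[0]
        if tok == EOS:
            return 1.0
        if prev == "b" and tok == "b":
            return 0.0  # one zero potential makes the whole product zero
        prev = tok


def exact_z():
    # A: prob. of finishing without 'bb' when the last token is not 'b';
    # B: the same when the last token is 'b' (values hard-coded from P above).
    #   A = 0.2 + 0.5*A + 0.3*B,   B = 0.2 + 0.5*A
    # Substituting B gives A = 0.26 + 0.65*A, so A = 0.26 / 0.35.
    return 0.26 / 0.35


rng = random.Random(0)
n = 20000
z_hat = sum(simulate_potential_product(rng) for _ in range(n)) / n
```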
3.2 Sequential Monte Carlo (SMC) Algorithm
LLaMPPL's SMC implementation proceeds as follows:
- Expansion: Each non-EOS particle $x^{(i)}$ is cloned $K$ times, and each clone is extended using the kernel $M_t(\cdot \mid x^{(i)})$; EOS-terminated particles are carried forward unchanged.
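- Weight Update: The weight of each clone $x^{(i,k)}$ of particle $x^{(i)}$ is updated via
$$w^{(i,k)} = w^{(i)} \cdot G_t\big(x^{(i)}, x^{(i,k)}\big) \cdot \frac{N'}{K N},$$
with $N$ the number of particles and $N'$ the current total number of particles after cloning.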
- Normalization: All particle weights normalized to sum to 1.
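- Without-replacement Resampling: Compute the minimal $c^* > 0$ such that $\sum_{i,k} \min\big(1, c^* \hat{w}^{(i,k)}\big) \geq N$, where $\hat{w}^{(i,k)}$ are the normalized weights; particles with $c^* \hat{w}^{(i,k)} \geq 1$ are kept deterministically, the remainder are selected stochastically by stratified sampling, and the resulting weights yield an unbiased estimator of $Z$.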
- Termination: Repeat until all particles have produced EOS-terminated strings.
This "without-replacement" mechanism preserves sample diversity better than multinomial SMC and yields more faithful posterior inference.
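A minimal sketch of a threshold-plus-stratified resampling step of this kind follows, written from the description above rather than from LLaMPPL's source; the function names are hypothetical, and the engine's extra bookkeeping for $N'$ and $Z$ is omitted. High-weight particles are kept once with their own weight. Each low-weight particle $i$ survives with probability $\hat{w}_i / \alpha$ and, if kept, receives weight $\alpha = 1/c^*$, so its expected weight is unchanged and no index is ever duplicated.

```python
import random


def find_threshold(w_hat, n):
    """Smallest c > 0 with sum(min(1, c * w)) == n.
    Assumes more than n of the normalized weights are nonzero."""
    ws = sorted(w_hat, reverse=True)
    for k in range(n):
        c = (n - k) / sum(ws[k:])  # the k largest weights are kept deterministically
        if c * ws[k] < 1:
            return c
    raise ValueError("needs more than n nonzero weights")


def resample_without_replacement(weights, n, rng):
    """Return up to n (index, new_weight) pairs; no index appears twice."""
    total = sum(weights)
    w_hat = [w / total for w in weights]
    nonzero = [i for i, w in enumerate(w_hat) if w > 0]
    if len(nonzero) <= n:
        return [(i, w_hat[i]) for i in nonzero]  # nothing needs to be dropped
    c = find_threshold(w_hat, n)
    det = [i for i in nonzero if c * w_hat[i] >= 1]
    stoch = [i for i in nonzero if c * w_hat[i] < 1]
    alpha = sum(w_hat[i] for i in stoch) / (n - len(det))  # equals 1 / c
    # Systematic (stratified) selection among the low-weight particles:
    # particle i is kept with probability w_hat[i] / alpha and then gets weight alpha.
    u = rng.uniform(0, alpha)
    chosen = []
    for i in stoch:
        u -= w_hat[i]
        if u < 0:
            chosen.append(i)
            u += alpha
    return [(i, w_hat[i]) for i in det] + [(i, alpha) for i in chosen]


res = resample_without_replacement([0.5, 0.2, 0.1, 0.1, 0.05, 0.05], 3, random.Random(1))
```

In this example, the particle with weight 0.5 is always kept, and the two remaining slots are shared among the five low-weight particles without ever copying one twice. Avoiding duplicates in this way is how the scheme preserves diversity.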
4. LLaMPPL Python API and Example Use Cases
The LLaMPPL API centers on subclassing Model and defining the step() method. Selected task examples:
A. Word Length Constraint
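A word-length constraint can be sketched as follows. This is not the original LLaMPPL example: the whole-word toy LM, the 5-letter limit, the ShortWords class, and the run_particles driver are hypothetical. Rather than sampling freely and zeroing particles with condition(), this version proposes from the LM restricted to allowed words. That is the sample(D, proposal=Q) pattern, and its importance weight $D(x)/Q(x)$ equals the probability mass the LM places on allowed tokens at each step.

```python
import math
import random

EOS = "<eos>"
MAX_LETTERS = 5


def lm_next(words):
    """Hypothetical toy LM over whole words; ignores the context."""
    return {"cats": 0.25, "elephants": 0.2, "sleep": 0.2,
            "quietly": 0.15, "now": 0.1, EOS: 0.1}


def allowed(tok):
    return tok == EOS or len(tok) <= MAX_LETTERS


class ShortWords:
    """Sketch of a word-length program: every generated word has <= 5 letters."""

    def __init__(self, rng):
        self.words, self.log_weight, self.done, self.rng = [], 0.0, False, rng

    def step(self):
        d = lm_next(self.words)
        # Proposal Q: the LM restricted to allowed tokens, renormalized.
        mass = sum(p for t, p in d.items() if allowed(t))
        q = {t: p / mass for t, p in d.items() if allowed(t)}
        tok = self.rng.choices(list(q), weights=list(q.values()))[0]
        # Weight D(tok) / Q(tok), which equals `mass`: the probability that the
        # unconstrained LM would have respected the constraint at this step.
        self.log_weight += math.log(d[tok] / q[tok])
        if tok == EOS:
            self.done = True
        else:
            self.words.append(tok)


def run_particles(n, seed=0, max_steps=100):
    # Independent weighted particles (importance sampling, no resampling).
    rng = random.Random(seed)
    particles = [ShortWords(rng) for _ in range(n)]
    for p in particles:
        steps = 0
        while not p.done and steps < max_steps:
            p.step()
            steps += 1
    return particles


ps = run_particles(50)
```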
B. Infilling between Fragments
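An infilling program can be sketched in the same style. The toy LM, the Infill class, and the "1 + Geometric(0.5)" prior on how many free words fill each gap are hypothetical choices made for illustration. Free words are proposed from the LM with EOS removed, weighted by $D(x)/Q(x) = 1 - D(\textsf{EOS})$. Each word of the next fixed fragment is then observed, which adds its LM log-probability, and a final observation of EOS closes the string.

```python
import math
import random

EOS = "<eos>"


def lm_next(words):
    """Hypothetical toy LM over words; ignores the context."""
    return {"the": 0.25, "cat": 0.15, "dog": 0.15, "sat": 0.15,
            "on": 0.1, "mat": 0.1, ".": 0.05, EOS: 0.05}


class Infill:
    """Sketch of infilling: free text between fixed fragments, scored by the LM."""

    def __init__(self, fragments, rng):
        self.words = list(fragments[0])
        self.remaining = [list(f) for f in fragments[1:]]
        self.log_weight, self.done, self.rng = 0.0, False, rng

    def step(self):
        # Hypothetical length prior: 1 + Geometric(0.5) free words per gap.
        n = 1
        while self.rng.random() < 0.5:
            n += 1
        for _ in range(n):
            d = lm_next(self.words)
            # Propose from the LM with EOS removed; D(tok) / Q(tok) = 1 - D(EOS).
            q = {t: p for t, p in d.items() if t != EOS}
            tok = self.rng.choices(list(q), weights=list(q.values()))[0]
            self.log_weight += math.log(1.0 - d[EOS])
            self.words.append(tok)
        # observe() each word of the next fragment: add its LM log-probability.
        for w in self.remaining.pop(0):
            self.log_weight += math.log(lm_next(self.words)[w])
            self.words.append(w)
        if not self.remaining:
            self.log_weight += math.log(lm_next(self.words)[EOS])  # observe EOS
            self.done = True


rng = random.Random(0)
fragments = [["the"], ["sat", "on"], ["mat"]]
particles = [Infill(fragments, rng) for _ in range(20)]
for p in particles:
    while not p.done:
        p.step()
```

Every particle contains the fragments in order. Particles whose free text makes the fragments likely under the LM end up with larger weights.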
C. Prompt Intersection (Product-of-Experts)
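Prompt intersection targets, at each step, a next-token distribution proportional to $p_1(w)\,p_2(w)$, where $p_1$ and $p_2$ are the LM's predictions under two different prompts. The one-step sketch below contrasts two ways to reach that target, using hypothetical toy distributions. In the first, one samples from $p_1$ and observes the token under $p_2$, giving weight $p_2(w)$. In the second, one proposes from the normalized product, giving a weight equal to the product's normalizer, which is the same for every token. The second form is one plausible reading of a "product-of-logits" proposal and is labeled here as an assumption.

```python
import random

# Hypothetical next-token distributions under two different prompts.
P1 = {"paris": 0.5, "france": 0.3, "rome": 0.2}
P2 = {"paris": 0.2, "france": 0.1, "rome": 0.7}


def product_target(p, q):
    """Per-step target: proportional to p(w) * q(w); also returns the normalizer."""
    unnorm = {w: p[w] * q.get(w, 0.0) for w in p}
    z = sum(unnorm.values())
    return {w: u / z for w, u in unnorm.items()}, z


def step_sample_then_observe(rng):
    # Sample from the first prompt's LM, then observe the token under the second.
    w = rng.choices(list(P1), weights=list(P1.values()))[0]
    return w, P2[w]


def step_product_proposal(rng):
    # Propose from the normalized product; the importance weight is its normalizer.
    target, z = product_target(P1, P2)
    w = rng.choices(list(target), weights=list(target.values()))[0]
    return w, z


rng = random.Random(0)
a = [step_sample_then_observe(rng) for _ in range(20000)]
est_rome = sum(wt for w, wt in a if w == "rome") / sum(wt for _, wt in a)
b = [step_product_proposal(rng) for _ in range(1000)]
```

Both procedures target the same distribution. The product proposal, however, produces constant weights, and lower weight variance is consistent with the higher $\log \hat{Z}$ estimates reported for the product-based model.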
These high-level abstractions make explicit a range of constraints through direct probabilistic modeling.
5. Integration with LLaMA-family Transformers
LLaMPPL supports efficient inference-time control of Meta's LLaMA family by introducing a CachedTransformer object:
- Prefix Trie: Tracks all prefixes seen by any particle, storing for each:
- Token sequence
- Cached logits from last expansion
- Full key/value (KV) caches from all transformer layers
- Cache Traversal: Logits for any context $s$ can be retrieved in $O(|s|)$ time if present, or else by running the transformer only on the difference from the longest shared prefix ("new suffix mode"), after which logits and KV activations are cached.
- Efficiency: All particles share this cache, so wall-clock time is comparable to that of beam search with width $N$ (the number of particles), rather than that of $N$ independent LM runs. Empirically, SMC decoding with $N$ particles and $K$ clones per particle costs only a small constant factor more than beam search of comparable width, but yields a diverse set of solutions faithful to the imposed constraints (Lew et al., 2023).
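The prefix-trie idea can be sketched without a real model. In the sketch below, TrieNode, FakeTransformer, and CachedLM are hypothetical stand-ins, not LLaMPPL's CachedTransformer. The fake model "processes" tokens one at a time on top of a stand-in KV cache and counts how many tokens it has run. A lookup walks the longest cached prefix, and on a miss it runs the model only on the remaining suffix, starting from that prefix's KV cache.

```python
class TrieNode:
    """One cached prefix: its next-token logits and (stand-in) KV cache."""

    def __init__(self, kv=(), logits=None):
        self.children = {}
        self.kv = kv
        self.logits = logits


class FakeTransformer:
    """Hypothetical stand-in for a forward pass: extends a fake KV cache one
    token at a time and counts how many tokens it has processed."""

    def __init__(self):
        self.tokens_processed = 0

    def run(self, kv, new_tokens):
        outputs, kv = [], list(kv)
        for tok in new_tokens:
            self.tokens_processed += 1
            kv.append(len(tok))                         # fake activation
            logits = [float(len(kv)), float(sum(kv))]   # fake next-token scores
            outputs.append((tuple(kv), logits))
        return outputs


class CachedLM:
    """Sketch of a prefix-trie cache shared by all particles."""

    def __init__(self, model):
        self.model = model
        self.root = TrieNode()

    def logits(self, context):
        if not context:
            raise ValueError("context must contain at least one token")
        node, i = self.root, 0
        # Walk the longest prefix of `context` that is already cached.
        while i < len(context) and context[i] in node.children:
            node = node.children[context[i]]
            i += 1
        if i == len(context):
            return node.logits  # cache hit
        # "New suffix mode": run the model only on the uncached suffix,
        # starting from the KV cache of the longest shared prefix.
        suffix = context[i:]
        for tok, (kv, logits) in zip(suffix, self.model.run(node.kv, suffix)):
            node.children[tok] = TrieNode(kv, logits)
            node = node.children[tok]
        return node.logits


lm = CachedLM(FakeTransformer())
lm.logits(["the", "cat", "sat"])   # runs 3 tokens
lm.logits(["the", "cat", "ran"])   # reuses "the cat", runs only "ran"
lm.logits(["the", "cat"])          # served entirely from the trie
```

Because particles that share a prefix hit the same trie nodes, the cost of expanding many related particles grows with the number of distinct new tokens, not with the number of particles times the context length.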
6. Experimental Results and Performance Evaluation
Empirical findings using LLaMPPL with LLaMA-7B demonstrate:
- Prompt Intersection: Comparing two Feynman-Kac model formulations for generating sentences likely under two prompts, Model B (using product-of-logits proposals) achieves higher average log normalizing-constant estimates $\log \hat{Z}$ (from the unbiased estimator $\hat{Z}$) and qualitatively better prompt intersections than Model A, across a range of particle counts $N$.
- Sample Diversity and Cost: SMC steering with LLaMPPL produces diverse samples fulfilling hard constraints, avoiding the mode collapse associated with greedy or beam search. Runtime remains within a small constant factor of standard beam search, due to shared data structures and KV cache optimization.
- Constraint Satisfaction: With properly specified probabilistic programs, SMC steering yields "sampling-correct" constrained completions, validating the approach's rigor and extensibility (Lew et al., 2023).
7. Implications, Limitations, and Extensions
LLaMPPL illustrates that the probabilistic programming paradigm, paired with optimized SMC inference, can deliver principled, practical solutions for imposing arbitrary constraints on LLM outputs without retraining the model or resorting to ad hoc search heuristics.
- Generality: The approach generalizes to any constraint expressible as a potential in a Feynman-Kac model, ranging from syntactic templates to semantic requirements and multi-prompt intersections.
- Limitations: The method's performance and memory costs scale with particle number and per-step expansion, constraining its use for extremely long or high-dimensional generation tasks. All constraints must be tractably encoded as Markov kernels and local potentials.
- Extension Potential: The shared cache and inference machinery could enable extension to other transformer models, or hybrid approaches incorporating human feedback or symbolic reasoning.
In conclusion, LLaMPPL formalizes constrained language generation as probabilistic inference and makes SMC-based steering practically viable for LLaMA-family LLMs, producing faithful, diverse, and constraint-respecting outputs at competitive inference cost (Lew et al., 2023).