
FlashSampling: Fast and Memory-Efficient Exact Sampling

Published 16 Mar 2026 in cs.LG, cs.AI, and cs.CL | (2603.15854v1)

Abstract: Sampling from a categorical distribution is mathematically simple, but in large-vocabulary decoding, it often triggers extra memory traffic and extra kernels after the LM head. We present FlashSampling, an exact sampling primitive that fuses sampling into the LM-head matmul and never materializes the logits tensor in HBM. The method is simple: compute logits tile-by-tile on chip, add Gumbel noise, keep only one maximizer per row and per vocabulary tile, and finish with a small reduction over tiles. The fused tiled kernel is exact because $\arg\max$ decomposes over a partition; grouped variants for online and tensor-parallel settings are exact by hierarchical factorization of the categorical distribution. Across H100, H200, B200, and B300 GPUs, FlashSampling speeds up kernel-level decode workloads, and in end-to-end vLLM experiments, it reduces time per output token by up to 19% on the models we test. These results show that exact sampling, with no approximation, can be integrated into the matmul itself, turning a bandwidth-bound postprocessing step into a lightweight epilogue. Project Page: https://github.com/FlashSampling/FlashSampling.

Summary

  • The paper introduces a fused matmul-epilogue technique that integrates exact Gumbel-max sampling directly into LLM inference, eliminating full-logits materialization.
  • It employs a hierarchical group-Gumbel-max decomposition to enable exact sampling in tensor-parallel and streaming regimes while ensuring strict distributional correctness.
  • Empirical results on NVIDIA GPUs demonstrate up to 2.52× speedup and significant reductions in Time Per Output Token, especially for low-batch decoding.

FlashSampling: Fused Matmul-Epilogue Exact Sampling for Large-Vocabulary LLM Decoding

Motivation and Problem Formulation

Autoregressive LLM decoding repeatedly samples categorical distributions over large vocabularies, incurring significant memory and kernel overhead. Conventional pipelines materialize the $[B, V]$ logits tensor in HBM, followed by normalization and sampling steps that detract from hardware efficiency and increase latency. While the mathematical operation, drawing a sample from $\mathrm{Cat}(\mathrm{softmax}(\ell))$, is trivial, the systems bottleneck arises from unnecessary memory round-trips and kernel launches at low batch sizes, especially in the decode regime. Large-vocabulary models ($V \gg 10^5$) further exacerbate HBM and synchronization costs.

Figure 1: Conventional multinomial sampling writes full logits to HBM and launches a separate sampling kernel, while FlashSampling fuses sampling into the matmul epilogue, avoiding logits materialization and reducing HBM traffic.

FlashSampling Algorithmic Design

FlashSampling deploys two core innovations:

  1. Fused Matmul-Epilogue Sampling: Sampling is fused directly into the LM-head matmul kernel. Logits are computed tile-by-tile in on-chip memory, perturbed via i.i.d. Gumbel noise, and only the tile-local maximizer per row is retained. The epilogue avoids writing the $[B, V]$ logits tensor, instead performing a lightweight reduction over vocabulary tiles to select the global maximizer.
  2. Hierarchical Group-Gumbel-Max Decomposition: Exact sampling remains feasible in tensor-parallel and streaming regimes through log-mass factorization. Each shard/group computes a local sample and log-mass, and a secondary reduction samples over these summarized states.

Algorithmically, FlashSampling bypasses explicit probability formation and normalization. The fused kernel’s exactness is pathwise via $\arg\max$ decomposition; distributed and online variants achieve exact distributional correctness via hierarchical composition of categorical distributions and binary merge rules. Random number generation is deterministic via counter-based indexing for reproducibility.
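The tile-by-tile procedure described above can be sketched in plain Python for a single row. This is a minimal illustration, not the paper's Triton kernel: the function names, the `tile` parameter, and the SHA-256-based `counter_uniform` (a stand-in for a counter-based generator such as Philox) are all assumptions made for the example.

```python
import hashlib
import math
import struct


def counter_uniform(seed, row, col):
    """Stateless uniform in (0, 1) keyed by (seed, row, col).

    Hypothetical stand-in for a counter-based RNG: the value depends only on
    the counter, never on the order in which tiles are visited.
    """
    digest = hashlib.sha256(struct.pack("<QQQ", seed, row, col)).digest()
    bits = int.from_bytes(digest[:8], "little") >> 12  # keep 52 bits
    return (bits + 0.5) / (1 << 52)  # strictly inside (0, 1)


def gumbel(seed, row, col):
    return -math.log(-math.log(counter_uniform(seed, row, col)))


def fused_tiled_sample(hidden, weight, seed, row=0, tile=4):
    """Sample one token id for one row without storing the full logits row.

    hidden: list of D floats; weight: V rows of D floats (the LM head).
    """
    best_score, best_index = -math.inf, -1
    for start in range(0, len(weight), tile):
        # Per-tile work: compute logits for this tile only and perturb them.
        tile_best, tile_arg = -math.inf, -1
        for i in range(start, min(start + tile, len(weight))):
            logit = sum(h * w for h, w in zip(hidden, weight[i]))
            score = logit + gumbel(seed, row, i)
            if score > tile_best:
                tile_best, tile_arg = score, i
        # Reduction over tiles: only one (score, index) pair per tile survives.
        if tile_best > best_score:
            best_score, best_index = tile_best, tile_arg
    return best_index


def reference_sample(hidden, weight, seed, row=0):
    """Baseline: materialize every perturbed logit, then take one argmax."""
    scores = [
        sum(h * w for h, w in zip(hidden, weight[i])) + gumbel(seed, row, i)
        for i in range(len(weight))
    ]
    return max(range(len(scores)), key=scores.__getitem__)
```

Because the noise for vocabulary index `i` is a pure function of `(seed, row, i)`, every tile size yields the same token as the materialized baseline, which is the pathwise sense of exactness the text refers to.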

Theoretical Guarantees

FlashSampling delivers strict exactness with respect to the target $\mathrm{Cat}(\mathrm{softmax}(\ell))$ distribution. The fusion’s correctness arises from the decomposition of the maximum over a partition: $\max_{i\in[V]} x_i = \max_t \max_{i\in\mathcal{T}_t} x_i$ for vocabulary tiles $\mathcal{T}_t$ that partition $[V]$. For online and tensor-parallel variants, the hierarchical log-mass factorization ensures that structured sampling and binary merge regimes remain exact. No approximation is introduced: the method always returns the $\arg\max$ of perturbed logits according to the classical Gumbel-Max trick.
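For reference, the identities behind these guarantees can be written out as follows. This is a standard restatement rather than the paper's own derivation; the symbols $g_i$ (i.i.d. standard Gumbel noise) and $L_t$ (the log-mass of group $t$) are notation introduced here for illustration.

```latex
% g_i: i.i.d. standard Gumbel noise; the tiles \mathcal{T}_t partition [V].
\begin{align*}
\Pr\Big[\arg\max_{i \in [V]} (\ell_i + g_i) = k\Big]
  &= \frac{e^{\ell_k}}{\sum_{j \in [V]} e^{\ell_j}}
  && \text{(Gumbel-Max)} \\
\max_{i \in [V]} (\ell_i + g_i)
  &= \max_{t} \max_{i \in \mathcal{T}_t} (\ell_i + g_i)
  && \text{(tile decomposition, pathwise)} \\
L_t &= \log \sum_{i \in \mathcal{T}_t} e^{\ell_i}
  && \text{(log-mass of group } t\text{)} \\
\Pr[i] &= \frac{e^{L_t}}{\sum_{s} e^{L_s}} \cdot \frac{e^{\ell_i}}{e^{L_t}}
  = \frac{e^{\ell_i}}{\sum_{j \in [V]} e^{\ell_j}},
  \quad i \in \mathcal{T}_t
  && \text{(hierarchical factorization)}
\end{align*}
```

The second line holds for every realization of the noise, which is why the fused kernel is exact pathwise. The last line is an identity between distributions, which is why the grouped variants are exact in distribution.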

Practical Integration: Systems and Kernel Fusion

FlashSampling leverages existing trends in IO-aware kernel fusion, as seen in FlashAttention and similar efforts applied to cross-entropy. By retaining all expensive computation within the matmul epilogue, the method eliminates additional memory-bound passes for normalization and sampling. HBM savings due to avoiding logits writes/reads are minimal in isolation, but the holistic removal of kernel launches and synchronizations yields substantial real-world speedups. Triton implementation allows platform-agnostic deployment, albeit with some tradeoffs versus cuBLAS GEMM efficiency at high batch sizes.
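To see why the logits round-trip is small in isolation, a back-of-the-envelope byte count helps. The sketch below assumes 2-byte weights, 4-byte logits, and a conventional path that writes the logits once and reads them once; the shapes are illustrative and are not taken from the paper's tables.

```python
def lm_head_traffic(batch, vocab, hidden, weight_bytes=2, logit_bytes=4):
    """Return (weight_read_bytes, logits_round_trip_bytes) for one decode step.

    Assumed cost model: the LM-head weights are read once, and the
    conventional path writes the [B, V] logits once and reads them back once.
    """
    weight_read = vocab * hidden * weight_bytes
    logits_round_trip = 2 * batch * vocab * logit_bytes
    return weight_read, logits_round_trip


# Illustrative shapes only (assumptions, not measurements).
for batch in (1, 8, 64):
    w, l = lm_head_traffic(batch, vocab=150_000, hidden=4096)
    share = 100 * l / w
    print(f"B={batch}: weights {w / 1e6:.0f} MB, "
          f"logits round-trip {l / 1e6:.1f} MB ({share:.2f}% of weight traffic)")
```

Under these assumptions the ratio is `4 * batch / hidden`, so at small batch sizes the weight read dominates. This is consistent with the text's point that the gains come mainly from removing kernels and synchronizations rather than from the saved bytes alone.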

Empirical Evaluation and Results

FlashSampling was benchmarked across NVIDIA H100, H200, B200, and B300 GPUs. Microbenchmarks and end-to-end vLLM integration were performed on models such as Qwen3-1.7B, Qwen3-8B, Qwen3-32B, and gpt-oss-120b. Relative performance is summarized below.

Figure 2: FlashSampling outperforms Multinomial Sampling and FlashInfer baselines across all batch sizes on B300, achieving fastest decode in the low-batch regime.

FlashSampling exhibits up to 1.84× speedup vs. Multinomial Sampling and up to 2.52× speedup vs. FlashInfer FI1 at $B \le 64$; speedups narrow as batch grows due to baseline GEMM efficiency improvements. Sampling runtime grows steeply with batch size for baselines, while FlashSampling absorbs sampling into the matmul with negligible overhead.

Figure 3: Sampling and matmul runtimes in μs vs. batch size show FlashSampling’s minimal epilogue overhead and baseline kernel launch costs.

Roofline analysis demonstrates that FlashSampling achieves higher HBM bandwidth utilization, tracking the memory-bound slope and outperforming baselines in decode-dominated workloads.

Figure 4: FlashSampling sits above baselines on the roofline plot and achieves superior bandwidth utilization on H100 for small batch sizes.

End-to-end vLLM evaluation reveals reductions in Time Per Output Token (TPOT) of up to 19% on Qwen3-1.7B. Speedups scale proportionally with the fraction of decode time spent on the LM head, diminishing for larger models where attention and FFN dominate.

Figure 5: TPOT reductions vs. concurrency for four models show FlashSampling’s strongest gains on small architectures; larger models are dominated by attention and FFN.
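The dependence on the LM-head share of decode time follows an Amdahl-style relation, sketched below. The function name and the input values are hypothetical and serve only to show the shape of the relationship; they are not measurements from the paper.

```python
def tpot_reduction(lm_head_fraction, kernel_speedup):
    """Fractional TPOT reduction when only the LM-head + sampling share speeds up.

    new_tpot = (1 - f) + f / s  (relative to the old TPOT), so the reduction
    is f * (1 - 1 / s).
    """
    return lm_head_fraction * (1.0 - 1.0 / kernel_speedup)


# Hypothetical inputs: a small model where the LM head is a large share of
# each decode step, and a large model where attention and FFN dominate.
small = tpot_reduction(lm_head_fraction=0.40, kernel_speedup=1.8)
large = tpot_reduction(lm_head_fraction=0.05, kernel_speedup=1.8)
print(f"small model: {small:.1%}, large model: {large:.1%}")
```

For a fixed kernel-level speedup, the end-to-end reduction is linear in the LM-head fraction, which matches the trend described for larger models.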

Correctness verification at the kernel and inference levels establishes strict statistical fidelity to reference sampling implementations; GSM8K accuracy is statistically indistinguishable between FlashSampling and conventional methods.

Extension to Top-k, Nucleus, and Masking

The tile-based structure allows efficient integration of top-k sampling and localized soft/truncate operations; Gumbel-Top-k extensions are immediate. Nucleus sampling (top-p) is less amenable to tile-local reduction but can be combined sequentially after top-k. Masking is supported by setting logits to −∞ pre-perturbation.
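One possible arrangement of tile-local top-k with masking is sketched below. It relies on the fact that the global top-k is always contained in the union of the per-tile top-k sets. The function name, the `banned` argument, and the use of Python's `random` module are assumptions for illustration, not the paper's kernel.

```python
import heapq
import math
import random


def topk_masked_sample(logits, k, banned=(), tile=4, rng=random):
    """Top-k sampling over tiles, with banned token ids masked to -inf."""
    banned = set(banned)
    candidates = []  # (logit, index) pairs surviving the per-tile filter
    for start in range(0, len(logits), tile):
        tile_items = [
            (-math.inf if i in banned else logits[i], i)  # mask before noise
            for i in range(start, min(start + tile, len(logits)))
        ]
        # Keep at most k entries per tile; the rest cannot be in the global top-k.
        candidates.extend(heapq.nlargest(k, tile_items))
    survivors = [(l, i) for l, i in heapq.nlargest(k, candidates) if l > -math.inf]
    if not survivors:
        raise ValueError("every token is masked")

    def perturbed(item):
        u = rng.random() or 5e-324  # guard against u == 0.0
        return item[0] - math.log(-math.log(u))

    # Gumbel-Max restricted to the k survivors.
    return max(survivors, key=perturbed)[1]
```

Masked tokens are set to −∞ before any noise is added, so they can never win the final comparison. A top-p filter, if needed, would run on the small `survivors` list rather than over the full vocabulary.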

FlashSampling builds on statistical and systems foundations in the Gumbel-Max trick, IO-aware kernel fusion, and large-vocabulary LLM inference. Prior works such as FlashInfer, Qrita, and SIMPLE address sampling via pre-materialized logits, introducing various levels of approximation or hardware offloading. FlashSampling’s contribution is to achieve strict exactness with zero approximation and maximal memory locality via kernel fusion.

Implications and Future Directions

By fusing sampling into the LM-head matmul, FlashSampling removes a memory-bound post-processing step and yields demonstrable efficiency improvements in practical LLM inference, with particular benefit to low-batch, memory-bound workloads. The method’s hierarchical factorization allows for scalable, distributed, and online variants suitable for both CPU-GPU heterogeneous inference and large-cluster serving. As Transformer vocabulary sizes and serving workloads continue to scale, systems-level kernel fusion for sampling will become a critical optimization. Future work may incorporate approximate softmax/normalizers as optional features, extend to more complex sampling strategies (beam search, stochastic sequence sampling), or explore fully integrated whole-model epilogues.

Conclusion

FlashSampling provides a fused, tile-wise matmul epilogue for exact categorical sampling, eliminating unnecessary memory-bound post-processing and kernel overhead. The method’s simplicity and exactness are theoretically validated via max-decomposition and log-mass factorization. Empirical results across multiple hardware platforms and LLM architectures demonstrate consistent kernel and inference time reductions in the decode regime. FlashSampling sets a new standard for efficient large-vocabulary sampling in neural LLM inference, with substantial implications for future AI deployment at scale.

Explain it Like I'm 14

What is this paper about?

This paper introduces FlashSampling, a new way to pick the next word from a large list of choices when an LLM (like ChatGPT) is generating text. It makes this “picking” step both faster and more memory‑friendly by folding it directly into the model’s last math step, instead of doing it afterward as a separate job. Importantly, it still picks words with exactly the same probabilities as the usual method, with no shortcuts or approximations.

What questions are the researchers trying to answer?

  • Can we sample (pick) the next token from a huge vocabulary without first writing a huge table of scores (“logits”) to slow, off‑chip memory?
  • Can we fuse sampling into the model’s final matrix multiplication (the LM head) so there’s no extra pass over those scores?
  • Can we do this in a way that stays exact (i.e., gives the same results as traditional methods), works on multiple GPUs, and speeds up real systems?

How does it work? (Simple explanation with analogies)

Imagine you’re choosing a candy from a giant shelf with thousands of candies. Each candy has a score (how much the model “likes” it). Traditionally, you would:

  1. Write down all the scores on a big sheet (slow),
  2. Turn scores into probabilities,
  3. Roll a random number to pick a candy.

That big sheet is called the logits tensor, and writing/reading it is slow because it lives in large but slower memory (HBM on a GPU).

FlashSampling does something smarter:

  • Instead of writing the whole sheet, it scans the shelf in small chunks (called tiles), right where the math is happening on the chip (fast, on‑chip memory).
  • For each candy’s score, it adds a tiny dose of randomness (called Gumbel noise). Here’s the neat trick: if you add this specially chosen random noise to each score and then pick the biggest one, you get exactly the same result as sampling from the correct probability distribution. This is known as the “Gumbel‑Max trick.”
  • In each chunk (tile), it keeps only the best candy (the local winner) and throws away the rest.
  • At the very end, it looks at the few tile winners and picks the overall best. Because “the biggest of all numbers is also the biggest among the chunk winners,” this gives the same winner you’d have gotten by scanning everything at once.

Why this is faster:

  • On‑chip memory is like a small, super‑fast notepad on your desk. Off‑chip HBM is like a big binder across the room. FlashSampling keeps almost everything on the notepad and only writes a tiny summary (one winner per chunk) to the binder.
  • Fusing the sampling into the main math step avoids launching extra “sampler” programs and avoids multiple trips to the binder.

What about multiple GPUs?

  • If the vocabulary is split across different GPUs, each GPU picks its local winner and a small “group score.” Then a tiny final step picks which GPU’s winner is the global winner. This avoids sending huge score lists between GPUs.
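The multi-GPU idea can be tried out in a few lines of Python, with plain lists standing in for GPUs. This is a sketch under assumptions: the function names are made up for the example, and the “group score” is taken to be the log of the shard's total probability mass (its log-sum-exp).

```python
import math
import random


def logsumexp(values):
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def gumbel_argmax(scores, rng):
    """Gumbel-Max: returns index i with probability softmax(scores)[i]."""
    def perturbed(i):
        u = rng.random() or 5e-324  # guard against u == 0.0
        return scores[i] - math.log(-math.log(u))
    return max(range(len(scores)), key=perturbed)


def shard_summary(shard_logits, offset, rng):
    """What one shard sends: (global id of its local sample, its log-mass)."""
    return offset + gumbel_argmax(shard_logits, rng), logsumexp(shard_logits)


def sample_across_shards(shards, rng):
    """shards: per-GPU logit lists covering disjoint vocabulary ranges."""
    summaries, offset = [], 0
    for shard_logits in shards:
        summaries.append(shard_summary(shard_logits, offset, rng))
        offset += len(shard_logits)
    # The final step sees one (id, log-mass) pair per shard, not V logits.
    winner = gumbel_argmax([mass for _, mass in summaries], rng)
    return summaries[winner][0]


def factorized_probability(shards, shard_id, local_id):
    """P(pick shard) * P(pick token | shard), computed from log-masses."""
    masses = [logsumexp(s) for s in shards]
    p_shard = math.exp(masses[shard_id] - logsumexp(masses))
    p_local = math.exp(shards[shard_id][local_id] - masses[shard_id])
    return p_shard * p_local
```

`factorized_probability` multiplies the chance of choosing a shard by the chance of choosing a token inside it. The product equals the ordinary softmax probability of that token, so the two-stage pick matches the one-stage pick in distribution.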

What did they find, and why is it important?

  • It’s exact: FlashSampling produces the same distribution of choices as the standard method—no approximations, just a different (and clever) way of doing the same thing.
  • It’s faster in practice:
    • On NVIDIA H100, H200, B200, and B300 GPUs, FlashSampling sped up the decoding kernel (the low‑level “pick the next token” work) compared to strong baselines.
    • Integrated into vLLM (a popular LLM serving stack), it reduced the time per output token by up to 19% on some models, especially when generating text one or a few tokens at a time (the common case in interactive use).
  • Why it’s faster: The biggest gains don’t just come from writing fewer bytes. They mainly come from removing extra steps (kernels) and memory round‑trips that add overhead in the most time‑sensitive part of generation.

In short: FlashSampling turns sampling from a large, memory‑heavy afterthought into a light add‑on to the main computation, saving time without changing the results.

What could this change in the future?

  • Faster responses from LLMs: When you chat with an AI, it picks one token at a time. Making that picking step faster speeds up the whole conversation.
  • Better scaling: For very large vocabularies and multi‑GPU setups, FlashSampling cuts communication and memory costs by only sending tiny summaries, not massive score tables.
  • A template for other speedups: Like FlashAttention did for attention layers, FlashSampling shows how “don’t write giant intermediate results to slow memory” can be a powerful idea. This could inspire similar fusions for other parts of AI models.

Overall, FlashSampling is a simple, elegant change—add random noise and keep only the best per chunk—that preserves exact behavior and delivers real speedups in modern LLM systems.

Knowledge Gaps

Knowledge gaps, limitations, and open questions

Below is a single, concrete list of what remains uncertain, missing, or unexplored, to guide future research and engineering work:

  • Top-k/Top-p integration is not implemented or benchmarked end-to-end:
    • While theoretical compatibility with Gumbel-Top-k is stated and a top-k-then-top-p strategy is suggested, there is no fused implementation, accuracy analysis vs. standard top-p, or performance measurement across realistic decode settings.
  • Distributed/tensor-parallel variant lacks empirical validation:
    • The paper gives a theoretical “grouped”/sharded design that avoids O(V) all-gathers, but reports no end-to-end multi-GPU/multi-node benchmarks, NCCL communication traces, or determinism tests under tensor parallelism.
  • Determinism across different partitions/shardings is not fully specified:
    • Pathwise reproducibility with counter-based RNG is described for fused tiles, but it is unclear how to guarantee identical outcomes across different tilings/shardings (grouped variants are exact in distribution, not necessarily pathwise). A concrete, implementable scheme for shard-agnostic determinism is missing.
  • cuBLAS/cuBLASLt/CUTLASS epilogue integration is absent:
    • The fused Triton GEMM shows lower matmul efficiency at larger batch sizes; there is no prototype using vendor libraries (e.g., cuBLASLt epilogues or EVT CUTLASS fusion) to preserve cuBLAS matmul performance while keeping sampling fused.
  • Quantization and low-precision support is untested:
    • No experiments with FP8/INT8 (or mixed-precision) LM-heads, nor analysis of numerical stability and exactness under quantization, structured sparsity, or Blackwell FP8 pathways.
  • RNG generation cost and quality are not characterized:
    • The overhead of on-chip Gumbel generation, vectorization strategies, and its impact on occupancy/throughput are not profiled; there is no cross-platform (e.g., different GPU vendors) assessment of RNG quality and reproducibility.
  • Extreme-logit and masking edge cases are not stress-tested:
    • Numerical stability under very large/small logits, heavy masking (including rows with nearly all −∞), or dynamic per-step logit processors (e.g., presence/frequency penalties, temperature schedules) is not quantified.
  • Candidate buffer sizing and memory pressure are not analyzed:
    • The second-stage buffer stores one (score, index) per row per tile; its size, register/shared-memory trade-offs, occupancy impacts, and tuning of tile sizes across B, V, and D are not systematically explored.
  • Portability beyond NVIDIA GPUs is untested:
    • Although Triton is platform-agnostic, there are no results on AMD (ROCm) or Intel GPUs, nor discussion of required kernel adaptations or expected performance deltas due to differing memory hierarchies.
  • End-to-end multi-GPU systems results are missing:
    • No evaluation with tensor- and pipeline-parallel deployments (single-node multi-GPU or multi-node), continuous batching under heterogeneous sequence lengths, or interaction with NCCL scheduling and overlap.
  • Latency distribution and tail behavior are not reported:
    • Results focus on medians; P95/P99 latency, variance under bursty loads, and sensitivity to dynamic batching policies (e.g., vLLM’s continuous batching) are not provided.
  • Interaction with speculative decoding is unaddressed:
    • Speculative methods often need logits for accept/reject or reranking; it is unclear how to reuse/avoid rematerialization in verification steps or integrate fused sampling into speculative pipelines.
  • Multi-sample and beam search are not supported or evaluated:
    • Generating n>1 candidates per step (for reranking/diverse sampling/beam search) is not implemented; the fused kernel changes and performance/communication impacts remain open.
  • Throughput-centric metrics under high concurrency are limited:
    • The paper emphasizes time per output token (TPOT); aggregate tokens/s throughput, GPU utilization under high-QPS serving, and interactions with KV-cache/attention kernels are not comprehensively studied.
  • Incomplete ablations of where speedups come from:
    • The cost model argues kernel fusion (not raw HBM savings) drives gains, but there is no controlled ablation isolating contributions from (i) avoiding logits write/read, (ii) removing kernel launches, (iii) better cache locality, and (iv) candidate-buffer design.
  • Integration with other fused kernels and schedulers is open:
    • How FlashSampling composes with other IO-aware fusions (e.g., fused MLP/activation or persistent kernels), Triton/TVM scheduling, and model-level megakernel strategies is not explored.
  • Energy efficiency and power draw are not measured:
    • No reporting on energy/token or performance-per-watt, which is relevant for deployment trade-offs.
  • Handling of failure modes and fallbacks is unspecified:
    • Behavior when all tokens are masked (no finite logits), or when numerical issues produce NaNs/Infs during Gumbel transforms, lacks a documented fallback policy.
  • Auto-tuning and adaptivity are not provided:
    • There is no mechanism to adapt tile sizes, precision, or to fall back to materialized sampling when batch sizes grow and cuBLAS is preferable; a policy for switching paths dynamically is missing.
  • Vocabulary scaling beyond ~150k is not shown:
    • While the approach should scale, there are no measurements at 500k–1M vocab sizes, nor analysis of tile-reduction overheads and kernel occupancy at extreme V.
  • Grammar/structured constraints support is only sketched:
    • The paper notes masking via −∞, but does not benchmark complex grammar constraints (e.g., incremental automata-based restrictions) or their per-step overhead inside the epilogue.
  • Distributional fidelity tests are limited:
    • Correctness evaluation includes a chi-squared test on 5,000 draws and one downstream task; more rigorous, high-sample statistical tests across varied V, temperature, and masking regimes are not reported.
  • KV-cache and LM-head weight streaming interplay is not detailed:
    • The effect of LM-head weight layout (streaming, prefetching, compression) and cache behaviors (L2/L1/register pressure) on FlashSampling performance isn’t quantified.
  • Software packaging and integration paths are undeveloped:
    • There is no guidance or artifacts for integrating into cuBLASLt epilogues, TensorRT-LLM, vLLM at scale (multi-node), or autotuning frameworks, nor API contracts for RNG seeds, masks, and custom logits processors.
  • Security/robustness considerations for on-chip RNG are unexamined:
    • While likely low risk, potential side effects (e.g., cross-stream RNG state contamination, reproducibility under preemption, or MIG partitioning) are not discussed.
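For the distributional-fidelity item in the list above, a goodness-of-fit check of the kind described can be sketched with the standard library alone. The logits, seed, and sampler below are illustrative assumptions rather than the paper's test harness; the critical value 16.266 is the chi-squared quantile for 3 degrees of freedom at significance level 0.001.

```python
import math
import random


def gumbel_max_sample(logits, rng):
    def perturbed(i):
        u = rng.random() or 5e-324  # guard against u == 0.0
        return logits[i] - math.log(-math.log(u))
    return max(range(len(logits)), key=perturbed)


def chi_squared_statistic(logits, draws, rng):
    """Pearson statistic of sampled counts against softmax(logits)."""
    z = sum(math.exp(l) for l in logits)
    expected = [draws * math.exp(l) / z for l in logits]
    counts = [0] * len(logits)
    for _ in range(draws):
        counts[gumbel_max_sample(logits, rng)] += 1
    return sum((c - e) ** 2 / e for c, e in zip(counts, expected))


logits = [0.0, 0.5, 1.0, 1.5]  # 4 categories, so 3 degrees of freedom
stat = chi_squared_statistic(logits, draws=5000, rng=random.Random(0))
CRITICAL_0_001 = 16.266  # chi-squared quantile, df=3, alpha=0.001
print("statistic:", round(stat, 3), "reject:", stat > CRITICAL_0_001)
```

A more rigorous study along the lines the list calls for would repeat this across vocabulary sizes, temperatures, and masking patterns, and would correct for multiple comparisons.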

Practical Applications

Immediate Applications

Below is a set of actionable use cases that can be deployed with today’s tooling and hardware, leveraging the paper’s fused, exact Gumbel-Max sampling in the LM-head epilogue, online/grouped variants, and reduced inter-GPU communication.

  • Faster LLM serving in production APIs and chatbots
    • Sector: software/cloud inference, developer tools, customer support, search, creative assistants
    • What: Replace materialized-logits sampling with FlashSampling in serving stacks (e.g., vLLM, TensorRT-LLM, custom Triton/CUTLASS epilogues) to reduce time per output token (TPOT) and end-to-end latency (reported up to ~19% on some models).
    • Tools/workflows:
    • Integrate Triton FlashSampling kernels as LM-head epilogues.
    • Add a “FlashSampling” path to vLLM, OpenAI-style inference servers, and in-house inference engines.
    • Use counter-based RNG (e.g., Philox) for deterministic/reproducible sampling in CI/regression tests.
    • Assumptions/dependencies:
    • GPU inference on Hopper/Blackwell-class hardware or similar with sufficient on-chip memory and support for custom kernels.
    • Models with standard linear LM-heads; greatest gains for small batch sizes (decode regime).
    • For top-p, adopt top-k-first-then-top-p logic on reduced candidates (as widely used in practice).
  • Reduced cost and higher throughput for cloud providers
    • Sector: cloud/infra, MLOps, FinOps
    • What: Increase requests-per-GPU and reduce serving costs by eliminating post-GEMM sampling kernels and logits round-trips; improve bandwidth utilization in memory-bound decode paths.
    • Tools/workflows:
    • Benchmark with roofline and TPOT metrics before/after FlashSampling integration.
    • Roll out canary deployments with continuous batching to validate p50/p95 latency improvements.
    • Assumptions/dependencies:
    • Benefits scale best when the LM head is a detectable bottleneck; attention/FFN-dominated models will see smaller gains.
    • Requires ops readiness for upgrading inference containers and kernel dependencies (Triton/CUDA versions).
  • Lower interconnect traffic in multi-GPU tensor-parallel setups
    • Sector: cloud/infra, HPC, large-model serving
    • What: Use the grouped/distributed variant to avoid all-gather of full logits; communicate only small per-shard summaries (local sample + log-mass), reducing communication to O(#shards) vs O(vocab).
    • Tools/workflows:
    • Integrate shard-local FlashSampling into TP pipelines; finalize selection with a small cross-rank reduction.
    • Pair with NCCL collectives over small payloads; instrument network counters to confirm reductions.
    • Assumptions/dependencies:
    • Tensor-parallel LM-head sharding in place; frameworks (vLLM, DeepSpeed, Megatron-like) must expose sharded logits hooks.
    • Correct masking/bias/temperature handling per shard (zero-mass shards are skipped exactly).
  • On-device and edge inference with tighter memory budgets
    • Sector: mobile/edge AI, embedded devices, robotics
    • What: Reduce HBM traffic by fusing sampling into the matmul epilogue and keeping logits on-chip; improve token latency for local assistants, real-time agents, and embedded copilots.
    • Tools/workflows:
    • Deploy Triton or vendor-provided fused epilogues in TensorRT or ONNX Runtime EPs where available.
    • Align RNG indexing with device constraints for reproducibility.
    • Assumptions/dependencies:
    • Edge accelerators must support custom epilogues and have sufficient SRAM/register capacity for tiled processing.
    • Model sizes and vocabularies compatible with device memory and latency targets.
  • Real-time decision systems with large categorical outputs
    • Sector: recommender systems, ads ranking, routing, operations research, some RL inference
    • What: Embed FlashSampling into the final linear layer to sample from very large catalogs/action spaces (exact Gumbel-Max) without materializing logits, reducing latency in live serving.
    • Tools/workflows:
    • Replace softmax-prefix pipelines with online Gumbel-Max reductions in the epilogue of the last projection.
    • Use the online (streaming) variant when memory is limited, maintaining running log-masses and a single candidate.
    • Assumptions/dependencies:
    • The action or item distribution is derived from a (possibly masked/bias/temperature-transformed) linear projection.
    • If policies require differentiable sampling during training, Gumbel-Softmax (relaxation) is separate; FlashSampling targets inference-time exact sampling.
  • Safety- and compliance-sensitive deployments requiring exactness
    • Sector: healthcare, finance, legal, public sector
    • What: Retain exact sampling semantics (no approximations) to ensure statistical fidelity of outputs while reducing latency and memory traffic.
    • Tools/workflows:
    • Use deterministic RNG seeding for auditability and reproducible runs.
    • Validate with chi-squared tests or bootstrap accuracy checks as in the paper’s correctness section.
    • Assumptions/dependencies:
    • Sampling policy must remain unchanged (e.g., same temperature/masking); nucleus sampling should be applied after a top-k reduction when needed.
  • Library and framework enhancements
    • Sector: AI software ecosystem
    • What: Upstream a fused “sample epilogue” to CUTLASS/cuBLASLt, Triton templates, PyTorch custom ops, and FlashInfer plugins so users adopt via version upgrades.
    • Tools/workflows:
    • Provide drop-in APIs mirroring existing sampler signatures; expose toggles for grouped/distributed/online modes.
    • Add kernel tests for exactness and RNG determinism; include perf CI on target GPUs.
    • Assumptions/dependencies:
    • Vendor libraries must accept new epilogue hooks; maintenance and ABI stability across CUDA versions.
  • Energy and cost reporting for “Green AI” KPIs
    • Sector: policy/compliance within organizations, sustainability reporting
    • What: Use measured reductions in kernel time and HBM traffic to report lower energy per token and cost per request.
    • Tools/workflows:
    • Integrate energy telemetry (e.g., NVML, data center meters) with inference runs pre-/post-FlashSampling.
    • Assumptions/dependencies:
    • Benefits depend on workload mix (decode-bound vs compute-bound); energy savings must be measured at system level, not inferred solely from kernel timing.
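The online (streaming) variant mentioned among the use cases above keeps only a running log-mass and a single candidate. One standard way to realize such a binary merge is sketched below; it uses a Bernoulli draw on the relative masses, which is an assumption chosen for clarity and may differ from the merge rule the paper implements.

```python
import math
import random


def logaddexp(a, b):
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def merge(state, group, rng):
    """Binary merge of two (candidate, log_mass) summaries."""
    (cand_a, mass_a), (cand_b, mass_b) = state, group
    total = logaddexp(mass_a, mass_b)
    if total == -math.inf:
        return (cand_a, total)  # both sides carry zero mass
    # Keep the incoming candidate with probability exp(mass_b - total),
    # i.e. its share of the combined mass.
    keep_b = rng.random() < math.exp(mass_b - total)
    return (cand_b if keep_b else cand_a, total)


def online_sample(logit_stream, rng):
    """One pass over (index, logit) pairs with O(1) state."""
    state = (None, -math.inf)
    for index, logit in logit_stream:
        state = merge(state, (index, logit), rng)
    return state[0]
```

Each incoming item is treated as a group of size one. By induction, after every merge the candidate is distributed according to the softmax over the items seen so far, and items with logit −∞ are never selected.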

Long-Term Applications

These opportunities require further research, ecosystem changes, or hardware/software co-design to reach production scale.

  • Vendor-supported GEMM epilogues for exact sampling
    • Sector: semiconductor/hardware vendors, systems software
    • What: Add first-class Gumbel-Max epilogues to cuBLASLt/CUTLASS, TensorRT-LLM, and compiler stacks (e.g., TVM/Mosaic) for portable performance with minimal code changes.
    • Potential tools/products: “GEMM+Sample” epilogue APIs; auto-tuning for tile sizes; RNG hardware acceleration.
    • Assumptions/dependencies:
    • Vendor prioritization and ABI support; extensive cross-architecture validation.
    • Persistent kernels and scheduling integration with continuous batching.
  • Hardware co-design: RNG and Gumbel units in tensor cores
    • Sector: hardware/accelerator design
    • What: Introduce microarchitectural support for counter-based RNG and Gumbel transforms in GEMM epilogues to minimize latency and register pressure.
    • Potential products: Next-gen GPU blocks offering “sample-ready” tensor cores; firmware-configurable epilogues.
    • Assumptions/dependencies:
    • ISA and microcode changes; clear security model for RNG; silicon lead times.
  • Fully fused top-p (nucleus) and advanced decoding within matmul
    • Sector: software ecosystems, LLM research
    • What: Move beyond top-k-then-top-p to a single-pass nucleus selection or other adaptive truncation inside or immediately after the epilogue without materializing logits.
    • Potential workflows: Epilogue-level approximate CDF estimation + exact finalization over a small candidate set; layered selection algorithms co-designed with FlashSampling.
    • Assumptions/dependencies:
    • New algorithms that avoid global sorts/prefix sums over V while retaining exactness or provable bounds.
    • Extensive correctness testing to match existing decoding semantics.
  • Generalization to beam search, contrastive decoding, and constrained grammars
    • Sector: LLM decoding research, applied AI
    • What: Extend tile-wise reduction and grouped-factorization to support beam candidates, contrastive objectives, and complex grammar constraints with minimal HBM writes.
    • Potential tools: Beam-aware epilogues maintaining per-beam tile maxima; grammar-aware masking fused pre-perturbation.
    • Assumptions/dependencies:
    • Maintaining exactness for multi-candidate objectives; managing state growth in on-chip memory.
  • Split- and federated-inference communication minimization
    • Sector: edge/cloud, telecom
    • What: Use grouped/distributed FlashSampling to reduce cross-device communication when vocabularies are sharded across heterogeneous nodes (e.g., edge-cloud split inference).
    • Potential workflows: Shard-local sampling with small-metadata uplinks; policy selection at aggregator; privacy-preserving sampling without exposing logits.
    • Assumptions/dependencies:
    • Consistent RNG/indexing across devices; robust handling of zero-mass shards and intermittent links.
  • Large-categorical sampling beyond LLMs
    • Sector: vision/audio tokenizers, generative compression, discrete diffusion, structured prediction
    • What: Apply fused Gumbel-Max to other models with very large discrete output spaces (e.g., codebook tokenizers, discrete diffusion steps), avoiding full logits materialization.
    • Potential tools: Library primitives for “final-layer fused sampling” in JAX/PyTorch; plugin heads for VQ-VAE/VQGAN-like models.
    • Assumptions/dependencies:
    • Existence of linear final projections or amenable transforms; domain-specific masking/truncation needs.
  • Training-time synergies: fast sampling for RL/IL and negative sampling
    • Sector: RL systems, large-scale training
    • What: While FlashSampling targets inference, similar epilogues could accelerate frequent sampling in rollout or negative-sampling steps (exact at inference, relaxed for backprop where needed).
    • Potential workflows: Action selection via exact Gumbel-Max in the forward pass; Gumbel-Softmax relaxation for gradient flow; hybrid training/inference kernels.
    • Assumptions/dependencies:
    • Clear separation of exact vs differentiable paths; compiler support for mixing epilogues in training graphs.
  • Speculative decoding + vocabulary pruning co-design
    • Sector: LLM optimization research, production inference
    • What: Combine speculative decoding, vocab trimming, and FlashSampling to shrink candidate sets earlier and fuse exact sampling for remaining tokens, maximizing end-to-end gains.
    • Potential workflows: Draft model proposes restricted vocab; FlashSampling runs exact selection on remaining tiles; dynamic group sizes per token.
    • Assumptions/dependencies:
    • Robust coordination of draft/target models; correctness under dynamic support changes; scheduling overheads kept low.
  • Security and privacy hardening via reduced memory footprint
    • Sector: security, compliance
    • What: Explore whether avoiding HBM materialization of logits mitigates certain memory-snooping or snapshot risks in multi-tenant settings; formalize threat models.
    • Potential tools: Memory forensics and side-channel analyses comparing materialized vs fused paths.
    • Assumptions/dependencies:
    • Real security benefits must be empirically demonstrated; GPU isolation and virtualization layers vary by cloud.
  • Standardization and benchmarking for “exact fused sampling”
    • Sector: benchmarking consortia, open-source
    • What: Establish benchmarks, conformance tests (exactness under masks/temperature), and API standards for fused sampling operators across frameworks.
    • Potential tools: MLPerf-like microbenchmarks for decode TPOT; conformance harness with chi-squared tests and paired bootstrap comparisons.
    • Assumptions/dependencies:
    • Community alignment and maintenance; coverage across diverse hardware/backends.
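
The conformance-harness idea in the last item can be sketched as follows. This is a minimal illustration, not the paper's test suite: the function names (`gumbel_max_sample`, `target_probs`, `run_conformance`), the toy logits, the mask, and the temperature value are all assumptions made for the example. It draws 5,000 samples (the sample count the paper reports for its chi-squared check) from a plain Gumbel-Max sampler with a temperature and a token mask, then computes the chi-squared statistic against the masked, temperature-scaled softmax.

```python
import math
import random


def gumbel(rng):
    """Standard Gumbel(0,1) draw via the inverse CDF; u is kept in (0, 1)."""
    u = rng.random()
    while u == 0.0:  # random() may return 0.0, which would break log()
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_max_sample(logits, rng, temperature=1.0, mask=None):
    """Sampler under test: argmax of (logit / T + Gumbel noise) over allowed tokens."""
    best_idx, best_val = -1, -math.inf
    for i, logit in enumerate(logits):
        if mask is not None and not mask[i]:
            continue  # masked tokens must have zero probability
        val = logit / temperature + gumbel(rng)
        if val > best_val:
            best_idx, best_val = i, val
    return best_idx


def target_probs(logits, temperature=1.0, mask=None):
    """Reference distribution: softmax(logits / T) restricted to allowed tokens."""
    scaled = [
        logit / temperature if (mask is None or mask[i]) else -math.inf
        for i, logit in enumerate(logits)
    ]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]  # exp(-inf) == 0.0 for masked tokens
    total = sum(weights)
    return [w / total for w in weights]


def chi_squared_stat(counts, probs, n):
    """Pearson chi-squared statistic over categories with nonzero expected mass."""
    return sum((c - n * p) ** 2 / (n * p) for c, p in zip(counts, probs) if p > 0)


def run_conformance(logits, n=5000, temperature=1.0, mask=None, seed=0):
    rng = random.Random(seed)
    counts = [0] * len(logits)
    for _ in range(n):
        counts[gumbel_max_sample(logits, rng, temperature, mask)] += 1
    probs = target_probs(logits, temperature, mask)
    return counts, chi_squared_stat(counts, probs, n)


# Hypothetical toy case: six tokens, token 2 masked out, temperature 0.7.
logits = [2.0, 1.0, 0.5, 0.0, -1.0, -2.0]
mask = [True, True, False, True, True, True]
counts, stat = run_conformance(logits, temperature=0.7, mask=mask)
print(counts, round(stat, 2))
```

A real harness would compare the statistic to a chi-squared critical value with (number of allowed tokens minus 1) degrees of freedom and would swap the reference sampler for the fused kernel under test.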

Glossary

  • all-gather: A collective communication operation that concatenates data from all processes/ranks and distributes the result to each. "an all-gather concatenates the full $V$ logits before sampling,"
  • arithmetic intensity: The ratio of computational operations to data movement (FLOPs per byte), used to analyze performance bounds. "An optimistic lower bound on arithmetic intensity is therefore"
  • argmax: The index of the maximum value in a set, often used to select the most likely outcome. "the fused tiled kernel is exact because $\argmax$ decomposes over a partition;"
  • autoregressive decoding: Sequential token generation where each step conditions on previously generated tokens. "Modern LLM serving stacks invoke sampling repeatedly during autoregressive decoding,"
  • BF16: Brain floating point (bfloat16), a 16-bit floating-point format used for efficient ML computation. "For a BF16 baseline that materializes logits,"
  • categorical distribution: A discrete probability distribution over a finite set of outcomes. "Sampling from a categorical distribution is mathematically simple,"
  • chi-squared goodness-of-fit test: A statistical test comparing observed frequencies to expected frequencies under a target distribution. "we compare samples from FlashSampling to the reference PyTorch implementation using a chi-squared goodness-of-fit test on 5,000 samples,"
  • counter-based RNG: A random number generator where outputs are a deterministic function of a counter and key, enabling reproducibility and parallelism. "using a counter-based RNG (e.g. Philox)"
  • cuBLAS: NVIDIA’s CUDA-optimized BLAS library for high-performance linear algebra on GPUs. "Logits are also materialized using cuBLAS."
  • decode regime: The inference setting emphasizing token-by-token generation, typically with small batch sizes. "Decode regime."
  • epilogue (matmul epilogue): The post-matrix-multiplication phase where fused operations (e.g., sampling) are applied before writing results to memory. "fuses sampling into the matmul epilogue,"
  • GEMM: General Matrix-Matrix Multiplication, a core linear algebra operation. "GEMM: compute logits and write to HBM"
  • Group-Gumbel-Max: A hierarchical sampling method that factors the categorical distribution into groups, sampling a group then within-group exactly. "Group-Gumbel-Max: Hierarchical Exact Sampling"
  • GSM8K: A benchmark dataset of grade-school math problems used to evaluate reasoning. "We run FlashSampling on 1,319 questions from the GSM8K dataset"
  • Gumbel-Max trick: An exact sampling method that selects the argmax of logits perturbed by i.i.d. Gumbel noise. "The classical Gumbel-Max trick states that exact categorical sampling can be performed by adding i.i.d. Gumbel noise and taking an $\argmax$:"
  • Gumbel-Top-k: An extension enabling sampling without replacement by taking the top-k of Gumbel-perturbed logits. "The trick extends to sampling without replacement via the Gumbel-Top-$k$ method"
  • Gumbel(0,1): The standard Gumbel distribution with location 0 and scale 1, used for perturbation in Gumbel-Max. "i.i.d. $\mathrm{Gumbel}(0,1)$."
  • HBM (High Bandwidth Memory): A high-throughput off-chip memory used in GPUs; access is slower than on-chip memory. "never materializes the logits tensor in HBM."
  • hierarchical factorization: Decomposing a distribution into nested components (e.g., group selection then within-group sampling) for exact sampling. "hierarchical factorization of the categorical distribution."
  • i.i.d.: Independent and identically distributed, a key assumption for random variables in many algorithms. "We denote i.i.d. standard Gumbel variables by $g_i \sim \mathrm{Gumbel}(0,1)$."
  • LM-head: The final linear projection layer mapping hidden states to vocabulary logits in LLMs. "At decode time, the LM-head projection already streams a large $[V,D]$ weight matrix from HBM."
  • log-mass: The log of the total probability mass within a group, i.e., log-sum-exp over group logits. "its group log-mass $L_k$."
  • logits: Unnormalized scores (pre-softmax) output by a model. "Materializing the resulting $[B,V]$ logits tensor,"
  • matmul: Matrix multiplication; in this context, the LM-head matrix-vector or matrix-matrix computation. "can be integrated into the matmul itself,"
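  • max-stability: The property that the maximum of independent Gumbel variables with a common scale is itself Gumbel-distributed, with location equal to the log-sum-exp of the individual locations; this is why a group's maximum perturbed logit behaves like its log-mass plus fresh Gumbel noise. "Gumbel max-stability under grouping"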
  • memory-bandwidth bound: A performance regime where data movement, not computation, limits throughput. "this projection is typically memory-bandwidth bound."
  • Mixture-of-Experts (MoE): An architecture that routes inputs to a subset of specialized expert networks. "Qwen3-235B-A22B MoE."
  • multinomial sampling: Drawing an outcome from a categorical distribution based on its probability mass function. "Conventional multinomial sampling (left) materializes the full $[B,V]$ logits tensor in HBM"
  • nucleus sampling (top-p): A truncation strategy that samples from the smallest set of tokens with cumulative probability at least p. "Top-$p$ (nucleus):"
  • on-chip memory: Fast, small-capacity memory on the GPU (registers and SRAM) used to avoid costly HBM traffic. "On-chip memory (registers, SRAM) is orders of magnitude faster than HBM"
  • ops:byte ratio: A hardware metric (peak operations per second divided by memory bandwidth) indicating the compute vs. bandwidth balance. "The ops:byte ratio (peak compute / bandwidth) contextualizes the crossover between bandwidth- and compute-limited regimes,"
  • paired bootstrap test: A statistical method for comparing two systems by resampling paired outcomes to estimate significance. "This difference is not statistically significant (p=0.776), according to a paired bootstrap test."
  • Philox: A counter-based GPU-friendly pseudo-random number generator ensuring reproducible parallel streams. "using a counter-based RNG (e.g. Philox)"
  • prefix sum: The cumulative sum of a sequence, used in sampling-by-search over cumulative probabilities. "No softmax, no normalization constant, and no prefix sum are required"
  • prefix-sum sampling: Sampling by computing cumulative probabilities (CDF) and searching for a threshold. "softmax followed by prefix-sum sampling:"
  • roofline: A performance model/plot relating arithmetic intensity to achieved performance relative to hardware ceilings. "Roofline Analysis and Bandwidth Utilization"
  • sharded (vocabulary sharding): Partitioning the vocabulary across multiple devices/ranks for tensor parallelism. "the vocabulary is sharded across ranks,"
  • softmax: The exponential normalization transforming logits into probabilities. "exact sampling does not require an explicit softmax."
  • SRAM: Static random-access memory on chip, faster but smaller than HBM. "On-chip memory (registers, SRAM) is orders of magnitude faster than HBM"
  • tensor-parallel: A model-parallelism strategy that splits tensors (e.g., vocabulary) across devices to distribute computation. "tensor-parallel settings"
  • tile (vocabulary tile): A contiguous chunk of the vocabulary processed locally before a final reduction. "vocabulary tile"
  • top-k: Truncation to the k highest-scoring tokens prior to sampling. "Top-$k$:"
  • TPOT (time per output token): A latency metric measuring average time to generate each token during inference. "We benchmark TPOT"
  • Triton: A GPU programming system for writing high-performance kernels in Python-like code. "All benchmarks use the open-source FlashSampling Triton implementation"
  • vLLM: A high-performance LLM inference engine with continuous batching and optimized kernels. "end-to-end vLLM experiments,"
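
Several glossary entries (Gumbel-Max trick, tile, argmax, log-mass, Group-Gumbel-Max) fit together in a way that a small example may make clearer. The sketch below is a pure-Python illustration under stated assumptions, not the paper's Triton kernel: all function names, the tile and group sizes, and the toy logits are hypothetical. It shows (1) that taking a per-tile argmax followed by a reduction over tile winners selects the same index as a flat argmax when the same noise is used, and (2) one way to realize hierarchical factorization, by choosing a group via Gumbel-Max over the group log-masses and then choosing a token within that group.

```python
import math
import random


def gumbel(rng):
    """Standard Gumbel(0,1) draw via the inverse CDF; u is kept in (0, 1)."""
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return -math.log(-math.log(u))


def logsumexp(xs):
    """Numerically stable log(sum(exp(x))) for a non-empty list."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def flat_gumbel_max(logits, noise):
    """Gumbel-Max: index of the largest perturbed logit."""
    return max(range(len(logits)), key=lambda i: logits[i] + noise[i])


def tiled_gumbel_max(logits, noise, tile):
    """Two-stage argmax: one (score, index) winner per tile, then a small reduction."""
    partials = []
    for start in range(0, len(logits), tile):
        stop = min(start + tile, len(logits))
        best = max(range(start, stop), key=lambda i: logits[i] + noise[i])
        partials.append((logits[best] + noise[best], best))
    return max(partials)[1]  # reduce over tile winners by score


def group_gumbel_max(logits, group_size, rng):
    """Hierarchical sampling: pick a group by its log-mass, then a token inside it."""
    groups = [logits[s:s + group_size] for s in range(0, len(logits), group_size)]
    log_mass = [logsumexp(g) for g in groups]  # L_k for each group k
    k = flat_gumbel_max(log_mass, [gumbel(rng) for _ in log_mass])
    j = flat_gumbel_max(groups[k], [gumbel(rng) for _ in groups[k]])
    return k * group_size + j


rng = random.Random(0)
logits = [rng.uniform(-3.0, 3.0) for _ in range(37)]  # toy "vocabulary" of 37 tokens
noise = [gumbel(rng) for _ in logits]

same = flat_gumbel_max(logits, noise) == tiled_gumbel_max(logits, noise, tile=8)
token = group_gumbel_max(logits, group_size=8, rng=rng)
print(same, token)
```

The tiled variant is exact because the maximum over a partition equals the maximum of the per-part maxima. The grouped variant is exact because the probability of group k is exp(L_k) divided by the total mass and the within-group probability of a token is exp(logit) divided by exp(L_k), so their product is the softmax probability.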

Open Problems

We found no open problems mentioned in this paper.
