
Efficient attention kernel for CAT training

Develop an efficient self-attention kernel for training the Compress and Attend Transformer (CAT) that supports its custom decoder attention mask, in which each token in chunk c_i attends only to previous tokens in c_i and to the past compressed chunk representations f_θ(c_{i−1}), …, f_θ(c_1). The goal is to enable scalable training with reduced attention compute and improved wall-clock throughput relative to standard dense-transformer kernels.


Background

The paper implements CAT's custom attention mask with PyTorch FlexAttention, which compiles the mask into a fused attention kernel, but observes only limited training speedups compared to FlashAttention applied to a dense transformer.
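For illustration, the following is a minimal sketch of how such a mask can be expressed with FlexAttention (PyTorch 2.5+, CUDA GPU assumed). The chunk size, sequence length, and key/value layout are illustrative assumptions, not the paper's exact implementation: here the KV sequence is assumed to hold one compressed summary vector per chunk, followed by the raw tokens.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

CHUNK = 128                    # assumed chunk size C
SEQ_LEN = 2048                 # assumed number of raw tokens N
NUM_CHUNKS = SEQ_LEN // CHUNK  # one compressed summary per chunk

def cat_mask_mod(b, h, q_idx, kv_idx):
    # q_idx indexes raw tokens; kv_idx < NUM_CHUNKS indexes compressed
    # chunk summaries f_θ(c_j), kv_idx >= NUM_CHUNKS indexes raw tokens.
    q_chunk = q_idx // CHUNK
    is_summary = kv_idx < NUM_CHUNKS
    # Attend to summaries of strictly earlier chunks ...
    summary_ok = is_summary & (kv_idx < q_chunk)
    # ... and causally to raw tokens within the same chunk.
    raw_pos = kv_idx - NUM_CHUNKS
    raw_ok = (~is_summary) & (raw_pos // CHUNK == q_chunk) & (raw_pos <= q_idx)
    return summary_ok | raw_ok

block_mask = create_block_mask(
    cat_mask_mod, B=None, H=None,
    Q_LEN=SEQ_LEN, KV_LEN=NUM_CHUNKS + SEQ_LEN, device="cuda",
)

# q: (B, H, N, D); k, v: (B, H, NUM_CHUNKS + N, D)
q = torch.randn(1, 8, SEQ_LEN, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, NUM_CHUNKS + SEQ_LEN, 64, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)
out = flex_attention(q, k, v, block_mask=block_mask)

FlexAttention compiles this mask into a fused kernel automatically, but, as the authors report, the resulting kernel does not yet match the training throughput of hand-tuned dense kernels such as FlashAttention.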

Because an efficient training kernel is unavailable, the theoretical reduction in attention FLOPs, from O(N²) for dense attention to O(N²/C), does not translate into practical wall-clock gains during pretraining. The authors note that training currently takes up to ~2.35× longer under their setup and suggest that custom kernels could mitigate this.
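A rough back-of-the-envelope comparison of attended query-key pairs (illustrative numbers only, constants ignored) shows the scale of the theoretical saving:

N, C = 8192, 128                       # assumed sequence length and chunk size
dense_pairs = N * N // 2               # ~N²/2 causal query-key pairs
cat_pairs = N * C // 2 + N * (N // C)  # in-chunk causal pairs + past summaries
print(f"dense ~{dense_pairs:,} pairs, CAT ~{cat_pairs:,} pairs "
      f"({dense_pairs / cat_pairs:.0f}x fewer)")

With these assumed values the CAT mask touches roughly 32× fewer query-key pairs, which is the gap between theoretical compute and the observed wall-clock behavior that an efficient kernel would need to close.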

References

Developing an efficient attention kernel for training CAT is left as future work.

Attention and Compression is all you need for Controllably Efficient Language Models (Prakash et al., arXiv:2511.05313, 7 Nov 2025), Appendix, Section "Training throughput analysis" (app:cat_training_throughput)