Causal Sinkhorn Balancing in Transformers
- Causal Sinkhorn Balancing is a technique that adapts the Sinkhorn normalization to enforce lower-triangular, causal constraints for autoregressive models.
- It employs masked row and column normalization to restrict computations to present and past blocks, enabling efficient quasi-global attention.
- Its integration within Sparse Sinkhorn Attention improves memory efficiency and model performance, as demonstrated by reduced perplexity in language modeling benchmarks.
Causal Sinkhorn Balancing is a modification of the Sinkhorn–Knopp normalization procedure that enables memory-efficient, autoregressive sorting of sequence blocks within the Sparse Sinkhorn Attention framework. By enforcing causality constraints during normalization, it ensures that permutations draw only on information from present and past blocks, preventing any “peeking into the future”, and it keeps the soft permutation matrix lower-triangular (diagonal included). This causally masked balancing is central to enabling quasi-global attention with local computational patterns in sequence models, especially within autoregressive transformer decoders (Tay et al., 2020).
1. Motivation and Context
Sparse Sinkhorn Attention uses a small, block-wise, doubly stochastic matrix to permute or “softly sort” blocks (of size $b$) of a full sequence before local attention is applied. This improves attention memory efficiency over vanilla self-attention, giving each token indirect access to a broader context while keeping computational complexity manageable.
In autoregressive transformer architectures, enforcing causality in the sorting stage is critical: only past and present inputs may influence the prediction for the current timestep. Standard Sinkhorn normalization, which alternates row and column normalization on $\exp(R)$ (where $R$ encodes sorting scores), inherently mixes information from the entire sequence, allowing future (unavailable) blocks to contribute to the normalization constants.
Causal Sinkhorn Balancing remedies this by restricting each step of the iterative normalization to the blocks that are currently available (past and present). Concretely, it forces the balanced matrix $S$ to be lower-triangular, so that block $i$ can only be mapped to positions $j \le i$. This causal masking is essential for correct autoregressive decoding in block-reordered transformers (Tay et al., 2020).
2. Mathematical Formulation and Notation
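Let $\ell$ denote the token sequence length, $N_B$ the number of blocks, and $X \in \mathbb{R}^{\ell \times d}$ the token embeddings. The block-pooling function produces a pooled representation $X' \in \mathbb{R}^{N_B \times d}$, one vector per block, which a feedforward scoring network then maps to the matrix $R \in \mathbb{R}^{N_B \times N_B}$ of sorting scores, where $N_B = \ell / b$ for block size $b$.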
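A minimal sketch of this pooling-and-scoring step, assuming sum pooling within each block and, for the scoring network, a single hypothetical linear layer $R = X'W$ with $W \in \mathbb{R}^{d \times N_B}$. The network in Tay et al. (2020) may use more layers and nonlinearities; the helper names `block_sum_pool` and `sorting_scores` are illustrative only.

```python
import random

def block_sum_pool(x, b):
    """Split token vectors into consecutive blocks of b tokens and sum-pool each block."""
    if len(x) % b != 0:
        raise ValueError("sequence length must be a multiple of the block size b")
    d = len(x[0])
    return [
        [sum(tok[k] for tok in x[start:start + b]) for k in range(d)]
        for start in range(0, len(x), b)
    ]

def sorting_scores(pooled, w):
    """Hypothetical one-layer scoring network: R = X' W, with W of shape d x N_B."""
    n_b, d = len(pooled), len(pooled[0])
    return [[sum(p[k] * w[k][j] for k in range(d)) for j in range(n_b)] for p in pooled]

# Usage: ell = 8 tokens of dimension d = 3, block size b = 2, so N_B = 4.
rng = random.Random(0)
ell, d, b = 8, 3, 2
x = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(ell)]
n_b = ell // b
w = [[rng.gauss(0.0, 1.0) for _ in range(n_b)] for _ in range(d)]
R = sorting_scores(block_sum_pool(x, b), w)
print(len(R), len(R[0]))  # 4 4
```

The scores matrix is square in the number of blocks, which is what keeps the later balancing step cheap relative to token-level attention.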
The iterative normalization proceeds as follows:
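- Initialization: $S^{(0)} = \exp\big((R + \epsilon)/\tau\big)$, with temperature parameter $\tau > 0$ and i.i.d. Gumbel noise $\epsilon$.
- At each iteration $t = 1, \dots, k$:
  - Standard row normalization: $S_{ij} \leftarrow S_{ij} \big/ \sum_{j'=1}^{N_B} S_{ij'}$
  - Standard column normalization: $S_{ij} \leftarrow S_{ij} \big/ \sum_{i'=1}^{N_B} S_{i'j}$
- For numerical stability, normalization is often performed in the log domain.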
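The standard (unmasked) updates can be sketched in pure Python in the log domain, where each normalization becomes a subtraction of a log-sum-exp. The names `gumbel_sinkhorn`, `gumbel`, and `logsumexp`, and the defaults `tau=0.75` and `iters=10`, are illustrative assumptions rather than a reference implementation.

```python
import math
import random

def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a list of finite floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def gumbel(rng):
    """Standard Gumbel(0, 1) sample via -log(-log(U)), U ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:  # random() can return 0.0, which would make log(u) fail
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_sinkhorn(R, tau=0.75, iters=10, rng=None):
    """Unmasked Sinkhorn balancing in the log domain (illustrative defaults)."""
    rng = rng or random.Random(0)
    n = len(R)
    log_s = [[(R[i][j] + gumbel(rng)) / tau for j in range(n)] for i in range(n)]
    for _ in range(iters):
        # Row normalization: subtract each row's log-sum-exp.
        row_lse = [logsumexp(row) for row in log_s]
        log_s = [[v - row_lse[i] for v in row] for i, row in enumerate(log_s)]
        # Column normalization: subtract each column's log-sum-exp.
        col_lse = [logsumexp([log_s[i][j] for i in range(n)]) for j in range(n)]
        log_s = [[log_s[i][j] - col_lse[j] for j in range(n)] for i in range(n)]
    return [[math.exp(v) for v in row] for row in log_s]

S = gumbel_sinkhorn([[0.0] * 4 for _ in range(4)], tau=1.0, iters=1000)
print([round(sum(row), 6) for row in S])  # row sums, expected to be close to 1
```

Column sums are exactly $1$ (up to rounding) after each column step; row sums approach $1$ as the iterations converge.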
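For Causal Sinkhorn Balancing, a mask $M \in \{0, 1\}^{N_B \times N_B}$ is used, where $M_{ij} = 1$ if $j \le i$, and $0$ otherwise. The normalization updates become:
- Masked row normalization: $S \leftarrow (M \odot S) \oslash \big((M \odot S)\,\mathbf{1}_{N_B}\mathbf{1}_{N_B}^{\top}\big)$
- Masked column normalization: $S \leftarrow (M \odot S) \oslash \big(\mathbf{1}_{N_B}\mathbf{1}_{N_B}^{\top}(M \odot S)\big)$
where $\odot$ and $\oslash$ denote element-wise multiplication and division, and $\mathbf{1}_{N_B}$ is the length-$N_B$ all-ones vector.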
Row normalization of row $i$ therefore sums only over columns $j \le i$, and column normalization of column $j$ sums only over rows $i \ge j$, so the resulting $S$ has lower-triangular support. The Sinkhorn iterations thus alternate between row and column balancing within the lower-triangular cone defined by $M$.
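As a small illustration under the definitions above, for $N_B = 3$ the mask and the sums used to normalize row 2 and column 2 are:

$$
M = \begin{pmatrix} 1 & 0 & 0 \\ 1 & 1 & 0 \\ 1 & 1 & 1 \end{pmatrix},
\qquad
\sum_{j \le 2} S_{2j} = S_{21} + S_{22},
\qquad
\sum_{i \ge 2} S_{i2} = S_{22} + S_{32}.
$$

Row 1 contains only $S_{11}$, so every masked row normalization sets $S_{11} = 1$.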
3. Algorithmic Details
The causal Sinkhorn algorithm proceeds as follows:
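- Initialization: in the log domain, set $\log S \leftarrow (R + \mathrm{GumbelNoise}())/\tau$.
- Iterative normalization (for $t$ in $1, \dots, k$):
  - Row normalization, masked:
    - For each row $i$, and for $j$ in $1, \dots, N_B$:
      - If $j \le i$, set $\log S_{ij} \leftarrow \log S_{ij} - \log \sum_{j' \le i} \exp(\log S_{ij'})$
      - Else, set $\log S_{ij} \leftarrow -\infty$ (zero out future mass)
  - Column normalization, masked:
    - For each column $j$, and for $i$ in $1, \dots, N_B$:
      - If $i \ge j$, set $\log S_{ij} \leftarrow \log S_{ij} - \log \sum_{i' \ge j} \exp(\log S_{i'j})$
      - Else, set $\log S_{ij} \leftarrow -\infty$
- Finalization: $S \leftarrow \exp(\log S)$, so masked entries become exactly $0$.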
Convergence is declared when all row and column sums, computed under the mask $M$, are within a small tolerance of $1$.
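The masked log-domain steps can be sketched in pure Python as follows. The function and parameter names (`causal_sinkhorn`, `tau`, `iters`, `tol`) and their defaults are hypothetical, and indices are 0-based in code. The optional `tol` check tests row sums only, since column sums are exactly $1$ right after each column normalization.

```python
import math
import random

NEG_INF = float("-inf")  # log of zero probability mass

def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a non-empty list of finite floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def gumbel(rng):
    """Standard Gumbel(0, 1) sample via -log(-log(U)), U ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return -math.log(-math.log(u))

def causal_sinkhorn(R, tau=0.75, iters=10, tol=None, rng=None):
    """Causal Sinkhorn balancing in the log domain with mask M_ij = 1 iff j <= i.

    Future entries (j > i) are pinned to -inf, so exp() maps them to exactly 0.
    Defaults are illustrative assumptions, not recommended settings.
    """
    rng = rng or random.Random(0)
    n = len(R)
    log_s = [[(R[i][j] + gumbel(rng)) / tau if j <= i else NEG_INF
              for j in range(n)] for i in range(n)]
    for _ in range(iters):
        # Masked row normalization: row i sums only over columns j <= i.
        for i in range(n):
            lse = logsumexp(log_s[i][: i + 1])
            log_s[i] = [log_s[i][j] - lse if j <= i else NEG_INF for j in range(n)]
        # Masked column normalization: column j sums only over rows i >= j.
        for j in range(n):
            lse = logsumexp([log_s[i][j] for i in range(j, n)])
            for i in range(n):
                log_s[i][j] = log_s[i][j] - lse if i >= j else NEG_INF
        if tol is not None:
            # Optional early stop once every masked row sum is within tol of 1.
            row_sums = [sum(math.exp(v) for v in row) for row in log_s]
            if all(abs(rs - 1.0) < tol for rs in row_sums):
                break
    return [[math.exp(v) for v in row] for row in log_s]

S = causal_sinkhorn([[0.0] * 4 for _ in range(4)], tau=1.0, iters=5)
for row in S:
    print([round(v, 3) for v in row])  # upper triangle prints as 0.0
```

With a zero score matrix the output depends only on the Gumbel noise. Keeping masked entries at $-\infty$ in the log domain, instead of multiplying by $M$ after exponentiation, avoids ever creating future mass that would later need to be removed.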
4. Complexity and Theoretical Properties
The computation cost per iteration is $O(N_B^2)$, arising from evaluating the masked row and column sums and updating each entry. Total runtime is therefore $O(k N_B^2)$ over $k$ iterations. Memory demand is $O(N_B^2)$ for $S$, and $O(\ell d)$ for the sequence itself.
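Convergence follows from the Sinkhorn–Knopp theorem: alternating normalization converges whenever the matrix has support (a positive diagonal), and the masked matrix always has one, because its diagonal entries are exponentials and hence strictly positive. The limit, however, is degenerate. The identity is the only doubly stochastic matrix with lower-triangular support (row 1 forces $S_{11} = 1$, column 1 then forces $S_{i1} = 0$ for $i > 1$, and so on by induction), so off-diagonal mass decays toward zero as the iterations continue. The non-trivial causal orderings used in practice therefore come from running only a small, fixed number $k$ of iterations, which yields an approximately balanced lower-triangular matrix rather than the exact fixed point.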
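A $2 \times 2$ example shows where the masked iteration is headed. Starting from masked entries $a, c, d > 0$ and writing $r_t$ for the off-diagonal entry after the $t$-th masked row normalization:

$$
\begin{aligned}
\text{row norm: } & \begin{pmatrix} a & 0 \\ c & d \end{pmatrix} \to \begin{pmatrix} 1 & 0 \\ r_1 & 1 - r_1 \end{pmatrix}, \qquad r_1 = \frac{c}{c + d}, \\
\text{column norm: } & \begin{pmatrix} 1 & 0 \\ r_t & 1 - r_t \end{pmatrix} \to \begin{pmatrix} \frac{1}{1 + r_t} & 0 \\ \frac{r_t}{1 + r_t} & 1 \end{pmatrix}, \\
\text{row norm: } & r_{t+1} = \frac{r_t / (1 + r_t)}{r_t / (1 + r_t) + 1} = \frac{r_t}{1 + 2 r_t}
\;\Longrightarrow\; \frac{1}{r_{t+1}} = \frac{1}{r_t} + 2, \\
\text{hence: } & r_t = \frac{r_1}{1 + 2(t - 1)\, r_1} \xrightarrow[t \to \infty]{} 0 .
\end{aligned}
$$

The off-diagonal mass shrinks only like $1/(2t)$ and the iteration tends to the identity matrix, so a small, fixed number of iterations still leaves a non-trivial lower-triangular matrix.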
5. Integration within Sparse Sinkhorn Attention
After $k$ causal Sinkhorn iterations, the resulting $S$ softly permutes the blocks of $X$. The sorted representation is $X_S = U(S\,B(X))$, where $B(\cdot)$ reassembles tokens into blockwise sequences and $U(\cdot)$ flattens the mixed blocks back into a token sequence. Local attention is then applied within each block, operating on tokens that are now quasi-globally reordered.
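A toy sketch of the block sorting step, assuming the sorted representation takes the form $X_S = U(S\,B(X))$ with $B$ splitting the sequence into consecutive blocks and $U$ concatenating them back. `split_blocks` and `soft_sort_blocks` are hypothetical names, tokens are plain lists of floats, and the sketch does not model which attention inputs are sorted.

```python
def split_blocks(x, b):
    """B(.): split a list of ell token vectors into N_B = ell // b blocks of b tokens."""
    if len(x) % b != 0:
        raise ValueError("sequence length must be a multiple of the block size b")
    return [x[s:s + b] for s in range(0, len(x), b)]

def soft_sort_blocks(S, x, b):
    """U(S B(X)): output block i is sum_j S[i][j] * block_j, flattened back to tokens."""
    blocks = split_blocks(x, b)
    n_b, d = len(blocks), len(x[0])
    sorted_blocks = [
        [[sum(S[i][j] * blocks[j][t][k] for j in range(n_b)) for k in range(d)]
         for t in range(b)]
        for i in range(n_b)
    ]
    # U(.): concatenate the (soft-)sorted blocks back into a token sequence.
    return [tok for block in sorted_blocks for tok in block]

x = [[float(t)] for t in range(6)]  # 6 one-dimensional tokens; b = 2 gives 3 blocks
S = [[1.0, 0.0, 0.0],
     [0.5, 0.5, 0.0],               # block 1 becomes the average of blocks 0 and 1
     [0.0, 0.0, 1.0]]
print(soft_sort_blocks(S, x, 2))
# [[0.0], [1.0], [1.0], [2.0], [4.0], [5.0]]
```

Because this $S$ is lower-triangular, each output block mixes only its own and earlier input blocks.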
During decoding, $S$ is recomputed at each timestep using cumulative-sum pooling, so that the scores for block $i$ depend only on blocks $j \le i$, preserving causality. For encoding, blockwise sum-pooling is used. In all configurations, $Q$, $K$, and $V$ are block-permuted by $S$ prior to attention.
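To see why cumulative pooling preserves causality, here is a minimal sketch assuming that “cumulative-sum pooling” means block $i$'s summary is the running sum of all tokens in blocks $0, \dots, i$ (an assumption; the exact pooling in Tay et al., 2020, may be normalized differently). `cumulative_block_pool` is a hypothetical helper.

```python
def cumulative_block_pool(x, b):
    """Causal pooling sketch: pooled vector i is the running sum of all tokens in blocks 0..i."""
    if len(x) % b != 0:
        raise ValueError("sequence length must be a multiple of the block size b")
    running = [0.0] * len(x[0])
    pooled = []
    for start in range(0, len(x), b):
        for tok in x[start:start + b]:
            running = [r + v for r, v in zip(running, tok)]
        pooled.append(list(running))  # snapshot after finishing block i
    return pooled

x = [[1.0], [2.0], [3.0], [4.0]]
print(cumulative_block_pool(x, 2))                  # [[3.0], [10.0]]
x_future_changed = [[1.0], [2.0], [30.0], [40.0]]
print(cumulative_block_pool(x_future_changed, 2))   # [[3.0], [73.0]]
```

Changing the tokens of the last block alters only the last pooled vector. Plain blockwise sum pooling, as used for encoding, would instead give `[[3.0], [7.0]]` for the first input.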
6. Empirical Performance and Observed Effects
Empirical ablation shows that omitting Sinkhorn normalization entirely ($k = 0$) significantly hurts performance, raising perplexity by $10$–$11$ on LM1B (see Table 9 in Tay et al., 2020). The computational and memory overhead of causal Sinkhorn is low: each layer incurs only $O(k N_B^2)$ additional cost on top of standard local attention (which has $O(\ell b)$ complexity). With $N_B \ll \ell$ and at most about $10$ iterations, this overhead is minor.
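To make the overhead concrete, take illustrative values (not from the paper) $\ell = 1024$, $b = 64$, and $k = 10$, and count entry-level operations per head and layer:

$$
\begin{aligned}
N_B &= \ell / b = 1024 / 64 = 16, \\
k\,N_B^2 &= 10 \cdot 16^2 = 2{,}560 \quad \text{(causal Sinkhorn balancing)}, \\
\ell\, b &= 1024 \cdot 64 = 65{,}536 \quad \text{(block-local attention scores)}, \\
\ell^2 &= 1024^2 = 1{,}048{,}576 \quad \text{(full self-attention scores)}.
\end{aligned}
$$

Under these values the balancing step adds under 4% to the local-attention cost, and the local scores are $1/16$ of the full $\ell^2$ matrix.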
Compared with vanilla Transformer and Sparse Transformer baselines, applying causal Sinkhorn in the decoder matches or exceeds their accuracy on language modeling, sorting, and generation benchmarks, while using a much smaller attention-memory footprint (total $O(\ell b + N_B^2)$ vs. $O(\ell^2)$). Recommended settings are a moderate temperature (up to about $\tau = 0.75$) and a moderate number of iterations (up to about $k = 10$); excessively low temperatures or high iteration counts can moderately degrade perplexity (cf. Figures 4–5 in Tay et al., 2020).