Criss-Cross Attention (CCA)
- Criss-Cross Attention (CCA) is an attention mechanism that restricts interactions to row and column axes in structured data, thereby reducing computational and memory overhead.
- Its one-pass design projects inputs into query, key, and value spaces to compute attention along criss-cross paths, with recurrent or dense stacking enabling effective global context aggregation.
- Empirical results in semantic segmentation and document-level relation extraction demonstrate state-of-the-art performance with significantly lower FLOPs and memory usage compared to full self-attention.
Criss-Cross Attention (CCA) is an attention mechanism designed to efficiently harvest contextual information along criss-cross paths within structured data, such as 2D spatial grids in vision or entity-pair matrices in document-level relation extraction. Distinct from dense self-attention, CCA reduces computational and memory costs by restricting interactions to spatially or semantically structured axes, and—when stacked or applied recurrently—enables effective context aggregation with sub-quadratic complexity.
1. Core Principles of Criss-Cross Attention
CCA operates by allowing each position in a structured data tensor to attend exclusively to elements that share either its row or column. In pixel-wise semantic segmentation, the input feature map is projected into query, key, and value subspaces via $1 \times 1$ convolutions ($Q$, $K$, $V$). For each position $u$, CCA computes affinities only among positions lying on the same criss-cross path—namely, those sharing $u$'s row or column—yielding a set of $H + W - 1$ connections, as opposed to all $H \times W$ possible pairings as in a non-local block (Huang et al., 2018).
In entity-pair-centric tasks, such as document-level relation extraction, the attention is analogously structured over an $N \times N$ entity-pair matrix, where rows index subject entities and columns index object entities. For each pair $(s, o)$, attention is computed only along (i) the entire subject row $s$ and (ii) the entire object column $o$ (Zhang et al., 2022).
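The criss-cross path is easy to enumerate explicitly. A minimal sketch (the helper name `criss_cross_path` is illustrative, not from the papers):

```python
# Enumerate the criss-cross path of a position on an H x W grid:
# all cells sharing u's row or column, counted once each.

def criss_cross_path(u, H, W):
    """Return the H + W - 1 positions attended to by u = (i, j)."""
    i, j = u
    row = [(i, c) for c in range(W)]             # same row (includes u itself)
    col = [(r, j) for r in range(H) if r != i]   # same column, avoid double-counting u
    return row + col

path = criss_cross_path((2, 3), H=5, W=7)
print(len(path))  # 5 + 7 - 1 = 11, versus 5 * 7 = 35 for a non-local block
```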
2. One-Pass Criss-Cross Attention Module
For a single forward pass, CCA proceeds as follows:
- Project inputs into query, key, and value feature spaces (commonly of reduced channel dimension for queries and keys).
- For each position $u$ (vision) or entity pair $(s, o)$ (NLP), gather key/value vectors along its row and column (vision: spatial axes; NLP: subject/object axes).
- Compute raw affinities as dot products between the query at $u$ and the corresponding keys, adding position-specific biases where used (in NLP).
- Normalize these affinities via softmax, separately for the row and column directions.
- Aggregate the attention-weighted values from the attended positions and add the original input as a residual connection.
Formally, on image grids, the output at pixel $u$ is

$$H'_u = \sum_{i \in \Omega_u} A_{i,u} \, \Phi_{i,u} + H_u,$$

where $A$ is the normalized attention map, $\Omega_u$ is the criss-cross path of $u$, $\Phi_{i,u}$ are the value vectors along that path, and $H_u$ is the input feature at $u$ (Huang et al., 2018).
In entity-pair grids, similar aggregation follows, with additional attention biases to encourage focus on likely related pairs (Zhang et al., 2022).
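The steps above can be sketched end to end. A minimal NumPy version (shapes, random initialization, and the helper name `cca_forward` are illustrative, not the published code; for simplicity, a single softmax is taken over the whole path rather than separately per direction):

```python
import numpy as np

# One-pass criss-cross attention on an H x W x C feature map (sketch).
rng = np.random.default_rng(0)
H, W, C, Cr = 4, 6, 8, 4            # Cr: reduced query/key dimension

X  = rng.standard_normal((H, W, C))
Wq = rng.standard_normal((C, Cr))   # stand-ins for the 1x1 conv projections
Wk = rng.standard_normal((C, Cr))
Wv = rng.standard_normal((C, C))

Q, K, V = X @ Wq, X @ Wk, X @ Wv

def cca_forward(Q, K, V, X):
    out = np.empty_like(X)
    for i in range(H):
        for j in range(W):
            # keys/values on the criss-cross path of (i, j): H + W - 1 entries
            ks = np.concatenate([K[i, :, :], np.delete(K[:, j, :], i, axis=0)])
            vs = np.concatenate([V[i, :, :], np.delete(V[:, j, :], i, axis=0)])
            logits = ks @ Q[i, j]               # raw affinities, shape (H + W - 1,)
            a = np.exp(logits - logits.max())
            a /= a.sum()                        # softmax over the path
            out[i, j] = a @ vs + X[i, j]        # weighted aggregation + residual
    return out

Y = cca_forward(Q, K, V, X)
print(Y.shape)  # (4, 6, 8)
```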
3. Recurrent or Densely Connected CCA for Global Context
A single CCA module grants each position access to all others in its row and column, but not to the entire input. To enable full global context, CCA is either:
- Recurrent (RCCA): For vision, $R$ sequential CCA modules are stacked with weight sharing. After $R = 2$ passes, any two positions are connected via at most two criss-cross hops (through the cells where their rows and columns intersect), covering all grid pairs. This extension suffices to propagate full-image dependencies with minimal added cost (Huang et al., 2018).
- Densely Connected Stacking: For entity-pair matrices, multiple CCA layers are stacked à la DenseNet, each layer's input formed by concatenating all previous outputs along the channel dimension, thus capturing both single-hop and multi-hop logical chains. Dense connectivity allows direct access to lower-level features, improving multi-hop reasoning and mitigating feature redundancy (Zhang et al., 2022).
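The two-hop claim for RCCA can be checked directly: squaring the boolean adjacency matrix of the criss-cross graph connects every pair of grid cells. A small verification sketch:

```python
import numpy as np

# Build the one-hop criss-cross adjacency over an H x W grid, then check
# that two hops reach every pair of cells.
H, W = 4, 5
N = H * W
A = np.zeros((N, N), dtype=bool)
for a in range(N):
    for b in range(N):
        same_row = a // W == b // W
        same_col = a %  W == b %  W
        A[a, b] = same_row or same_col        # one criss-cross hop (incl. self)

two_hop = (A.astype(int) @ A.astype(int)) > 0  # boolean reachability in 2 hops
print(A.all(), two_hop.all())                  # False True: 1 hop no, 2 hops yes
```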
A table recapping recursion strategies:
| Task Domain | Stacking Method | Context Connectivity | Empirical Layer Count |
|---|---|---|---|
| Vision | Recurrent CCA (RCCA) | Full image in $R$ passes | $R = 2$ or $3$ |
| NLP (RE) | Dense stacking | Up to $L$-hop reasoning chains | $L = 3$ |
4. Category/Clustering Regularization Mechanisms
In vision, over-smoothing from context aggregation is counteracted by an auxiliary category consistent loss. This loss, inspired by margin-based discriminative instance embedding, comprises:
- Intra-class variance reduction: Penalizes feature distances from the per-class mean that exceed a margin $\delta_v$.
- Inter-class separation: Encourages class means to be separated by at least $\delta_d$.
- Regularization: Penalizes large mean feature norms.
The cumulative loss is

$$\mathcal{L} = \mathcal{L}_{seg} + \alpha \, \mathcal{L}_{var} + \beta \, \mathcal{L}_{dis} + \gamma \, \mathcal{L}_{reg},$$

with scalar weights $\alpha$, $\beta$, $\gamma$ balancing the three terms. This regularization yields an observed further increase of +0.7% mIoU on Cityscapes (Huang et al., 2018).
In entity-pair reasoning, a clustering loss, leveraging cosine similarity and supervised “is_related”/“not_related” labels, further structures representations in embedding space (Zhang et al., 2022).
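A hedged sketch of such a margin-based loss, in the generic discriminative-embedding form (margins and the regularization weight below are placeholders, not the published hyperparameters):

```python
import numpy as np

# Illustrative category consistent loss: intra-class variance,
# inter-class separation, and mean-norm regularization.
def category_consistent_loss(feats, labels, delta_v=0.5, delta_d=1.5):
    classes = np.unique(labels)
    means = {c: feats[labels == c].mean(axis=0) for c in classes}
    # intra-class: penalize distances to the class mean beyond delta_v
    l_var = np.mean([
        np.maximum(np.linalg.norm(feats[labels == c] - means[c], axis=1)
                   - delta_v, 0.0).mean()
        for c in classes])
    # inter-class: penalize class means closer than delta_d
    l_dis, pairs = 0.0, 0
    for a in classes:
        for b in classes:
            if a < b:
                d = np.linalg.norm(means[a] - means[b])
                l_dis += max(delta_d - d, 0.0) ** 2
                pairs += 1
    l_dis /= max(pairs, 1)
    # regularization: penalize large mean feature norms
    l_reg = np.mean([np.linalg.norm(means[c]) for c in classes])
    return l_var + l_dis + 0.001 * l_reg

rng = np.random.default_rng(1)
feats = rng.standard_normal((20, 4))
labels = np.array([0] * 10 + [1] * 10)
loss = category_consistent_loss(feats, labels)
print(loss >= 0.0)  # True
```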
5. Computational Complexity and Efficiency
CCA’s main advantage over non-local or full self-attention mechanisms is sub-quadratic complexity.
- Vision: Standard non-local blocks require $\mathcal{O}((H \times W)^2)$ storage and computation for an $H \times W$ feature map. A single CCA layer attends to only $H + W - 1$ positions per cell, i.e., $\mathcal{O}(HW(H + W - 1))$ edges.
- Entity-Pair Grids: Full self-attention over the $N \times N$ entity-pair matrix would entail $\mathcal{O}(N^4)$ computations. CCA admits only $2N - 1$ attention links per cell, yielding $\mathcal{O}(N^3)$ overall.
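These counts are easy to make concrete. A back-of-envelope sketch (the feature-map size is illustrative):

```python
# Attention-edge counts for a hypothetical H x W feature map:
# non-local attends over every pair; one CCA pass over H + W - 1 per cell.
H, W = 97, 97
full  = (H * W) ** 2          # non-local: all position pairs
cca_1 = H * W * (H + W - 1)   # one criss-cross pass
print(full // cca_1)          # roughly HW / (H + W) ~ 48x fewer edges here
```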
Concrete measurements in Cityscapes segmentation (Huang et al., 2018):
| Method | Extra FLOPs | Extra GPU Memory | mIoU (val) |
|---|---|---|---|
| Non-local | ~108 GFLOPs | ~1,411 MB | 78.7% |
| RCCA ($R = 2$) | ~16.5 GFLOPs | ~127 MB | 80.5% |
CCA thus achieves ~85% fewer FLOPs and ~11× less memory, with improved accuracy.
6. Application Domains and Empirical Performance
Semantic Segmentation and Video Segmentation: CCNet, employing RCCA and category consistent loss, achieves state-of-the-art mean Intersection-over-Union (mIoU) across major benchmarks:
- Cityscapes: 80.5% (val), 81.9% (test)
- ADE20K: 45.76% (val)
- LIP: 55.47% (val)
- CamVid (3D-RCCA): 79.1% (test)
- COCO Instance Segmentation: Improves Mask AP by +1.3 points over baseline Mask-RCNN, outperforming comparable non-local alternatives (Huang et al., 2018).
Document-level Relation Extraction: Dense-CCNet, with densely stacked CCA, attains state-of-the-art performance on DocRED, CDR, and GDA by enabling direct reasoning over entity-pair matrices—an inference regime not accessible to mention- or entity-level graphs (Zhang et al., 2022).
CCA’s efficiency and modularity facilitate its deployment in both plug-and-play vision backbones and as core logical reasoners in NLP pipelines, scaling context aggregation without incurring the computational bottleneck of full self-attention.
7. Design Choices and Implementation Considerations
Key implementation details include:
- Projection Dimensions: Queries and keys are often of reduced channel dimension relative to the input, with output channels restored via pointwise convolutions.
- Normalization: Layer normalization follows each residual addition; batch normalization is not used within CCA blocks.
- Directional Extensions: In document-level RE, four L-shaped aggregation modes (covering row–row, column–column, and mirror paths) ensure comprehensive multi-hop information flow.
- Attention Biases: In Dense-CCNet, attention logits are augmented with learned pairwise biases, trained with a binary cross-entropy auxiliary loss.
- Multi-hop Layer Count: Empirically, two passes (vision) or three dense layers (NLP) suffice for near-optimal aggregation; deeper stacking does not yield commensurate gains and may risk overfitting.
- Activation: Softmax only in attention scoring; optional GeLU/ReLU in the transition modules.
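The densely connected stacking from Section 3 reduces to a simple channel-concatenation loop. A hypothetical sketch, in which `cca_layer` is a random stand-in for a real CCA pass:

```python
import numpy as np

# DenseNet-style stacking over an N x N entity-pair grid: each layer
# consumes the channel-wise concatenation of all previous outputs.
rng = np.random.default_rng(2)
N, C, L = 6, 8, 3                            # entities, channels, dense layers

def cca_layer(x, out_dim):
    Wp = rng.standard_normal((x.shape[-1], out_dim))
    return np.tanh(x @ Wp)                   # placeholder for a real CCA pass

feats = [rng.standard_normal((N, N, C))]
for _ in range(L):
    dense_in = np.concatenate(feats, axis=-1)  # concat all prior outputs
    feats.append(cca_layer(dense_in, C))

print(feats[-1].shape)  # (6, 6, 8): the last layer saw 3 * C input channels
```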
CCA modules are universally compatible with standard backbone architectures in vision (e.g., ResNet) and transformer-style text encoders (e.g., BERT), facilitating wide adoption in dense prediction and structural reasoning tasks (Huang et al., 2018, Zhang et al., 2022).