
CrossWKV Architecture: Hybrid State & Sparse-Attention

Updated 2 April 2026
  • CrossWKV is a hybrid architecture that fuses state-based recurrence with sparse cross-attention for scalable long-context and multimodal modeling.
  • It employs a generalized delta rule and Top‑k Chunk Sparse Attention to compress historical context efficiently while ensuring strong expressivity.
  • Empirical results demonstrate state-of-the-art performance in ultra-long-context tasks and text-to-image generation with constant per-token decoding cost.

CrossWKV (“Cross Weighted Key–Value”) describes a family of architectures—most prominently, the hybrid LLM RWKV-X and multimodal cross-attention models built around RWKV-7's Weighted Key–Value recurrence—which combine or fuse state-based sequence modeling with cross-modal or sparse-attention mechanisms. These architectures maintain linear or constant-complexity scaling for both memory usage and computation, while enabling strong expressivity for both intra- and cross-modal tasks. Notable applications include ultra-long-context language modeling and efficient text-to-image generation (Hou et al., 30 Apr 2025, Xiao et al., 19 Apr 2025).

1. Foundational Mechanisms and Model Structure

The CrossWKV approach builds on the RWKV-7 state-based paradigm, where each block maintains a compact channel-wise state matrix $S_t \in \mathbb{R}^{N \times N}$ that synthesizes historical context without full attention over past tokens. The core state update, called the generalized delta rule, is:

$$S_t = S_{t-1}\left( \mathrm{diag}(w_t) - (a_t \otimes k_t)\,k_t^T \right) + k_t v_t^T$$

where, at each timestep $t$:

  • $k_t, v_t \in \mathbb{R}^N$: key and value vectors,
  • $w_t, a_t, r_t \in \mathbb{R}^N$: data-dependent decay, learning-rate, and readout gates (computed from the input via Linear+LoRA projections),
  • $r_t$ gates the final output,
  • $(a_t \otimes k_t)\,k_t^T$ provides a non-diagonal, input-adaptive “forgetting” term.

The output at $t$ is

$$y_t = r_t \odot (S_t k_t) + \left(r_t \odot (p \otimes k_t)\right)^T v_t$$

where $p$ is a trainable scalar. This recurrence compresses all past key–value history into the fixed-size state $S_t$, so memory stays constant and computation scales linearly with sequence length.
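
A minimal NumPy sketch of this recurrence, transcribing the two equations above directly (single head, with the gate vectors supplied as inputs rather than computed through the Linear+LoRA projections), may help make the update concrete:

```python
import numpy as np

def crosswkv_recurrence(k, v, w, a, r, p=1.0):
    """Generalized delta rule over a sequence (single head).

    k, v, w, a, r: arrays of shape (T, N) holding the per-step key/value
    vectors and the data-dependent decay, learning-rate, and readout gates.
    p is the trainable scalar from the output equation (a constant here).
    """
    T, N = k.shape
    S = np.zeros((N, N))                       # compact state matrix S_t
    ys = np.zeros((T, N))
    for t in range(T):
        # transition: diag(w_t) - (a_t ⊗ k_t) k_t^T  (non-diagonal forgetting)
        transition = np.diag(w[t]) - np.outer(a[t] * k[t], k[t])
        S = S @ transition + np.outer(k[t], v[t])          # state update
        # readout: r_t ⊙ (S_t k_t) plus the scalar bonus term (broadcast over N)
        ys[t] = r[t] * (S @ k[t]) + (r[t] * (p * k[t])) @ v[t]
    return ys
```

Each step touches only the $N \times N$ state, so per-token compute and memory are independent of how many tokens precede it.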

For extended context and cross-modal modeling, RWKV-X (also referred to as CrossWKV in the language context) employs an alternating stack of standard RWKV-7 blocks and Top-$k$ Chunk Sparse Attention blocks, allowing the architecture to scale efficiently to million-token sequences or fuse heterogeneous modalities (Hou et al., 30 Apr 2025).

2. CrossWKV Cross-Attention and Multimodal Fusion

The CrossWKV cross-attention module generalizes the above recurrence to fuse features from disparate modalities, such as combining CLIP-style text embeddings with image features. The fusion is achieved in one unidirectional pass by:

  1. Computing temporal differences on image features.
  2. Applying fused projections and LoRA-injected linear layers to the text and image features to obtain gates, keys, values, and readout vectors.
  3. Normalizing and adjusting keys and splitting them into multiple heads.
  4. Running the RWKV-7 WKV recurrence over the resulting keys, values, and gates as a sequence.
  5. Applying GroupNorm and linear projections, with optional value blending for initial layers.

This mechanism preserves the efficient state-based computation while enabling tight cross-modal coordination. Empirical evaluation within the DIR-7 framework on datasets like ImageNet confirms that CrossWKV achieves state-of-the-art Fréchet Inception Distance (FID) and CLIP scores for text-to-image generation (e.g., FID=2.88, CLIP=0.33 for DIR-7-H on ImageNet 256×256) (Xiao et al., 19 Apr 2025).
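
A condensed PyTorch sketch of this five-step pipeline is given below. It is a minimal illustration, not the DIR-7 implementation: the module name `CrossWKVFusion`, the pooled text embedding broadcast to every image token, the sigmoid gate parameterization, and the single-head recurrence are all simplifying assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossWKVFusion(nn.Module):
    """Illustrative CrossWKV-style fusion of text and image features."""
    def __init__(self, dim: int, n_heads: int = 4):
        super().__init__()
        self.gate_proj = nn.Linear(dim, 3 * dim)   # decay w, learning rate a, readout r
        self.kv_proj = nn.Linear(dim, 2 * dim)     # keys and values
        self.norm = nn.GroupNorm(n_heads, dim)     # step 5
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, text: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
        # text: (dim,) pooled CLIP-style embedding; image: (T, dim) features
        T, dim = image.shape
        # step 1: temporal differences on the image features (token-shift style)
        delta = image - F.pad(image, (0, 0, 1, 0))[:-1]
        # step 2: projections conditioned on both modalities
        w, a, r = torch.sigmoid(self.gate_proj(delta + text)).chunk(3, dim=-1)
        k, v = self.kv_proj(delta + text).chunk(2, dim=-1)
        # step 3: normalize keys (single head kept for brevity)
        k = F.normalize(k, dim=-1)
        # step 4: WKV recurrence (generalized delta rule) over the fused sequence
        S = image.new_zeros(dim, dim)
        ys = []
        for t in range(T):
            S = S @ (torch.diag(w[t]) - torch.outer(a[t] * k[t], k[t])) \
                + torch.outer(k[t], v[t])
            ys.append(r[t] * (S @ k[t]))
        y = torch.stack(ys)                        # (T, dim)
        # step 5: GroupNorm + output projection
        y = self.norm(y.T.unsqueeze(0)).squeeze(0).T
        return self.out_proj(y)
```

In the actual model, LoRA-injected projections and multi-head splitting replace the plain linear layers used here, and the keys are run per head rather than over the full channel dimension.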

3. Top-$k$ Chunk Sparse Attention for Long-Context Modeling

To mitigate the quadratic scaling bottleneck of traditional Transformers in long-context settings, RWKV-X adopts the Top-$k$ Chunk Sparse Attention mechanism. For an input sequence:

  • The sequence is partitioned into fixed-size chunks.
  • For each query, a relevance score is computed for every chunk, and the top-$k$ most relevant chunks are selected.
  • Attention is restricted to keys/values within these top-$k$ chunks, reducing computation to the selected chunks rather than the full sequence.
  • For autoregressive decoding, a recency-aware cache management policy is used: the cache is split into a sliding observation window and an older, dynamically-compressed region; cumulative importance scores determine which past keys/values are retained.

This design yields training cost that is linear in sequence length and per-token decoding memory that stays constant, even for million-token sequences (Hou et al., 30 Apr 2025).
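
The chunk-selection logic can be sketched in a few lines of NumPy. The mean-key chunk score used below is an illustrative choice, not necessarily the paper's exact relevance function:

```python
import numpy as np

def topk_chunk_attention(q, K, V, chunk_size=256, k_chunks=4):
    """Sparse attention for one query: score chunks, attend only to the top-k.

    q: (d,) query; K, V: (L, d) cached keys and values.
    """
    L, d = K.shape
    n_chunks = (L + chunk_size - 1) // chunk_size
    # relevance score per chunk, here the query against the chunk's mean key
    scores = np.array([
        q @ K[i * chunk_size:(i + 1) * chunk_size].mean(axis=0)
        for i in range(n_chunks)
    ])
    top = np.argsort(scores)[-min(k_chunks, n_chunks):]      # top-k chunk ids
    idx = np.concatenate([
        np.arange(i * chunk_size, min((i + 1) * chunk_size, L)) for i in top
    ])
    # dense attention restricted to tokens inside the selected chunks
    att = np.exp(q @ K[idx].T / np.sqrt(d))
    att /= att.sum()
    return att @ V[idx]
```

Because each query only ever touches the $k$ selected chunks, the cost per query is bounded by the chunk size and $k$ rather than by the full sequence length.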

4. Model Integration, Training Paradigm, and Inference

RWKV-X integrates these modules in an interleaved stack: typically, every three standard RWKV blocks are followed by one sparse attention block, composing a highly efficient backbone. Integration employs residual connections, layer normalization, and feed-forward adapters at each boundary to ensure stable depth-wise communication.

Training proceeds in two logical phases:

  1. Alignment Stage: Only new sparse attention layers are trained on 4K-token context, with all pre-existing RWKV weights frozen.
  2. Long-Context Continual Pretraining: All parameters are jointly trained on up to 64K-token contexts, favoring long-range dependencies via dynamic token weighting.
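
In PyTorch terms, the two phases amount to toggling which parameter groups receive gradients. The sketch below assumes the hybrid model exposes `rwkv_blocks` and `sparse_blocks` submodules and uses placeholder learning rates; neither detail is taken from the released code:

```python
import torch

def configure_alignment_stage(model: torch.nn.Module):
    """Alignment stage: train only the newly added sparse-attention layers."""
    for p in model.rwkv_blocks.parameters():
        p.requires_grad = False          # freeze pretrained RWKV-7 weights
    for p in model.sparse_blocks.parameters():
        p.requires_grad = True           # new sparse attention layers
    return torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad), lr=1e-4)

def configure_long_context_stage(model: torch.nn.Module):
    """Continual pretraining: unfreeze everything for joint long-context training."""
    for p in model.parameters():
        p.requires_grad = True
    return torch.optim.AdamW(model.parameters(), lr=3e-5)
```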

Inference is fully autoregressive: each token update proceeds by recurrently updating the state $S_t$, applying sparse cross-block attention, and managing the KV cache under a fixed budget constraint. This results in decoding throughput and memory footprint that are constant with respect to sequence length, verified up to 1 million tokens (Hou et al., 30 Apr 2025).
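
A simplified sketch of the recency-aware cache policy is shown below; the importance bookkeeping and eviction rule are assumptions made for illustration, not the released RWKV-X implementation:

```python
import numpy as np

class RecencyAwareCache:
    """Fixed-budget KV cache: keep the newest `window` tokens verbatim,
    let older entries compete for the remaining budget via accumulated
    importance scores."""
    def __init__(self, budget=65536, window=1024):
        self.budget, self.window = budget, window
        self.keys, self.values, self.importance = [], [], []

    def append(self, k, v, attn_weights=None):
        # credit previously cached entries that the new token attended to
        if attn_weights is not None:
            for i, w in enumerate(attn_weights[:len(self.importance)]):
                self.importance[i] += float(w)
        self.keys.append(k)
        self.values.append(v)
        self.importance.append(0.0)
        if len(self.keys) > self.budget:
            self._evict()

    def _evict(self):
        # never evict from the sliding observation window (the newest tokens)
        old = len(self.keys) - self.window
        victim = int(np.argmin(self.importance[:old]))
        for buf in (self.keys, self.values, self.importance):
            del buf[victim]
```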

5. Expressivity, Complexity, and Empirical Performance

CrossWKV's non-diagonal, input-dependent transition matrix $\mathrm{diag}(w_t) - (a_t \otimes k_t)\,k_t^T$ grants greater expressivity than classical diagonal SSMs (such as those corresponding to $\mathsf{TC}^0$ circuits). Specifically, CrossWKV can represent arbitrary regular languages and model complex finite-state transitions, demonstrated by successful $S_5$ permutation tracking (Xiao et al., 19 Apr 2025).
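
A small worked example illustrates why the non-diagonal transition matters: with $w_t = \mathbf{1}$ and $a_t = 2 \cdot \mathbf{1}$, the transition reduces to the Householder reflection $I - 2 k_t k_t^T$, which can swap state dimensions, a building block for permutation tracking that no diagonal transition can realize. The specific gate values here are chosen purely for illustration:

```python
import numpy as np

# With w = 1 and a = 2·1, the transition diag(w) - (a ⊗ k) k^T reduces to the
# Householder reflection I - 2 k k^T.  Choosing k ∝ (e1 - e2) swaps the first
# two state dimensions; no diagonal transition can do this.
k = np.zeros(4)
k[0], k[1] = 1.0, -1.0
k /= np.linalg.norm(k)
transition = np.eye(4) - np.outer(2 * k, k)
state = np.array([3.0, 7.0, 1.0, 5.0])
print(state @ transition)   # -> [7. 3. 1. 5.]  (dimensions 0 and 1 swapped)
```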

Complexity profiles:

Component | Memory Usage | Computational Cost
RWKV-7 recurrence | Constant (fixed-size state) | Linear in sequence length
Top-$k$ Chunk Sparse Attention | Fixed cache budget | Linear (constant $k$ chunks per query)
Transformer (baseline) | Grows with sequence length (full KV cache) | Quadratic in sequence length

Empirical results demonstrate:

  • Long-context recall: 100% accuracy on the 64K passkey retrieval benchmark (S-NIAH-1) with RWKV-X-3.6B; prior RWKV-7 degrades beyond 28K.
  • Decoding efficiency: At 128K context, RWKV-X is 1.37× faster per token than FlashAttention-based full Transformers; for 1M tokens, per-token latency remains constant.
  • Generalization: CrossWKV achieves parity with state-of-the-art generative models on text-to-image benchmarks, with competitive robustness on out-of-distribution prompts (Hou et al., 30 Apr 2025, Xiao et al., 19 Apr 2025).

6. Key Hyperparameters and Implementation Details

Critical hyperparameters defining CrossWKV and RWKV-X behavior include:

Hyperparameter | Typical Value | Description
Chunk size | 256–512 tokens | Size for sparse attention chunk partitioning
Top-$k$ chunks selected | 4–8 | Number of chunks attended per query
Observation window | ~1024 tokens | Size of recency buffer in cache
Cache budget | 64K tokens | Long-term memory retention limit
Sparse layers (%) | ~25% | Fraction of layers replaced with sparse attention
Context lengths | 4K (alignment), 64K (pretrain), 1M (inference) | Sequence lengths for training and evaluation
LoRA ranks | 16–128 (various gates) | Low-rank adaptation for gate/value projections
Heads, head dims | Model-dependent | Parallel multi-head structure
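
For reference, these settings can be collected into a small configuration object; the field names below are illustrative and do not mirror the released config files:

```python
from dataclasses import dataclass

@dataclass
class RWKVXConfig:
    """Typical CrossWKV / RWKV-X hyperparameters from the table above."""
    chunk_size: int = 512             # sparse-attention chunk partitioning
    top_k_chunks: int = 8             # chunks attended per query
    observation_window: int = 1024    # recency buffer in the KV cache
    cache_budget: int = 65536         # long-term retention limit (64K tokens)
    sparse_layer_fraction: float = 0.25   # share of layers using sparse attention
    alignment_context: int = 4096     # alignment stage context length
    pretrain_context: int = 65536     # long-context continual pretraining
    lora_rank: int = 64               # low-rank adaptation for gate/value projections
```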

Implementation leverages chunked or fused-recurrent kernels, group normalization, and optional LoRA modules for efficient parameterization. Example codebases are available at https://github.com/howard-hou/RWKV-X and https://github.com/TorchRWKV/flash-linear-attention (Hou et al., 30 Apr 2025, Xiao et al., 19 Apr 2025).

7. Applications, Limitations, and Outlook

CrossWKV serves as an efficient backbone for ultra-long-context LLMs, large-context sequence learning, and cross-modal generative tasks, especially where Transformer architectures become unsustainable due to quadratic resource requirements.

Limitations include the $O(N^2)$ per-token cost of the pure RWKV-7 recurrence (from the $N \times N$ state update) and the expressivity bound set by the state matrix dimensionality. Nevertheless, the hybridization with sparse attention and cross-modal fusion extends the model's reach. Empirical scaling studies confirm superior or comparable performance to equivalent-parameter Transformer models (e.g., a 786M-param RWKV-X outperforms GPT-2 774M by 0.16 perplexity after 10B-token pretraining) (Hou et al., 30 Apr 2025).

A plausible implication is that CrossWKV, by fusing powerful state-based recurrence with scalable cross-attention, presents a general pattern for the future of foundation models required to manage immense context windows and multimodal integration without incurring the prohibitive costs of full self-attention.

References (2)
