Differentiable Recursive Transformers (R2D2)

Updated 23 April 2026
  • The paper presents a differentiable recursive Transformer that integrates CKY-style tree induction with standard Transformer composition to build interpretable binary trees.
  • It introduces efficient pruned tree induction and Fast-R2D2 to scale tree construction from O(n³) to O(n), achieving competitive performance in language modeling, parsing, and classification.
  • Empirical results highlight strong linguistic alignment, improved scalability, and enhanced interpretability compared to vanilla Transformers and traditional recursive approaches.

Differentiable Recursive Transformers (R2D2) are architectures designed to fuse explicit hierarchical structure induction with the flexible expressivity of Transformer networks. Unlike conventional deep models that employ stacked layers without modeling explicit hierarchical composition, R2D2 leverages a differentiable chart-based recursive algorithm—generalizing CKY parsing—to recursively construct and compose latent binary trees over input sequences. This approach encodes interpretable and adaptive tree structures while maintaining full end-to-end differentiability and large-scale pretraining capability (Hu et al., 2021, Hu et al., 2022, Chowdhury et al., 2024).

1. Model Architecture: Differentiable Recursive Composition

At the core of R2D2 is a triangular CKY-style chart $\mathcal{T}$ constructed over a token sequence $S = (s_1, \ldots, s_n)$. Each cell $\mathcal{T}_{i,j}$ encodes:

  • $e_{i,j} \in \mathbb{R}^d$: span representation for $[i,j]$
  • $p_{i,j}$: probability of a local binary merge
  • $\tilde{p}_{i,j}$: marginal subtree probability covering $[i,j]$

Leaf cells are initialized with token embeddings ($p_{i,i} = \tilde{p}_{i,i} = 1$). Each non-terminal cell $\mathcal{T}_{i,j}$ (for $i < j$) considers all split points $k$ with $i \le k < j$, applying a Transformer-based merge function $f$ to the two child spans:

$$e^{k}_{i,j},\ \bar{p}^{k}_{i,j} = f\left(e_{i,k},\ e_{k+1,j}\right),$$

with $e^{k}_{i,j}$ the candidate representation of $[i,j]$ under split $k$ and $\bar{p}^{k}_{i,j}$ the merge score.

Subtree probabilities propagate as

$$\tilde{p}^{k}_{i,j} = \bar{p}^{k}_{i,j}\,\tilde{p}_{i,k}\,\tilde{p}_{k+1,j}.$$

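A Straight-Through Gumbel-Softmax over the log subtree probabilities of the candidate splits, $\log \tilde{\mathbf{p}}_{i,j} = (\log \tilde{p}^{i}_{i,j}, \ldots, \log \tilde{p}^{j-1}_{i,j})$, produces a sparse mixing vector $\boldsymbol{\alpha}_{i,j}$, yielding the mixtures

$$e_{i,j} = \sum_{k=i}^{j-1} \alpha^{k}_{i,j}\, e^{k}_{i,j}, \qquad \left[p_{i,j},\ \tilde{p}_{i,j}\right] = \sum_{k=i}^{j-1} \alpha^{k}_{i,j} \left[\bar{p}^{k}_{i,j},\ \tilde{p}^{k}_{i,j}\right].$$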
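The chart recursion above can be sketched in plain Python. This is a minimal, hypothetical illustration rather than the authors' code: `toy_merge` stands in for the Transformer merge function (it averages the two children and scores the merge with a sigmoid of their dot product), spans are 0-based, and only the forward pass of the Straight-Through Gumbel-Softmax is shown, since there is no autograd here.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def toy_merge(left, right):
    """Hypothetical stand-in for the Transformer merge f.

    Averages the child vectors and scores the merge with a sigmoid of their
    dot product, returning (e^k_ij, p_bar^k_ij).
    """
    e = [(a + b) / 2.0 for a, b in zip(left, right)]
    p_bar = sigmoid(sum(a * b for a, b in zip(left, right)))
    return e, p_bar


def st_gumbel_forward(logits, rng=None):
    """Forward pass of Straight-Through Gumbel-Softmax: a one-hot vector.

    With rng=None no noise is added (plain argmax). The straight-through
    backward pass is omitted because this sketch has no autograd.
    """
    if rng is not None:
        logits = [l - math.log(-math.log(max(rng.random(), 1e-12))) for l in logits]
    best = max(range(len(logits)), key=lambda k: logits[k])
    return [1.0 if k == best else 0.0 for k in range(len(logits))]


def fill_chart(embeddings, rng=None):
    """Fill a CKY-style chart bottom-up; keys are 0-based inclusive spans (i, j)."""
    n = len(embeddings)
    # Leaf cells: token embedding with p_ii = p_tilde_ii = 1.
    chart = {(i, i): (list(embeddings[i]), 1.0, 1.0) for i in range(n)}
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            cands = []
            for k in range(i, j):  # split into [i, k] and [k + 1, j]
                e_l, _, pt_l = chart[(i, k)]
                e_r, _, pt_r = chart[(k + 1, j)]
                e_k, p_bar_k = toy_merge(e_l, e_r)
                # Subtree probability: merge score times both child subtrees.
                cands.append((e_k, p_bar_k, p_bar_k * pt_l * pt_r))
            alpha = st_gumbel_forward([math.log(c[2]) for c in cands], rng)
            d = len(cands[0][0])
            e_ij = [sum(a * c[0][t] for a, c in zip(alpha, cands)) for t in range(d)]
            p_ij = sum(a * c[1] for a, c in zip(alpha, cands))
            pt_ij = sum(a * c[2] for a, c in zip(alpha, cands))
            chart[(i, j)] = (e_ij, p_ij, pt_ij)
    return chart


tokens = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.8], [0.2, 0.2]]
chart = fill_chart(tokens)                          # deterministic argmax
noisy = fill_chart(tokens, rng=random.Random(0))    # with Gumbel noise
root_e, root_p, root_pt = chart[(0, len(tokens) - 1)]
```

Because every merge score lies in $(0, 1)$ and leaves start at probability 1, each $\tilde{p}_{i,j}$ in this toy stays in $(0, 1]$, so the logarithm in the Gumbel step is always defined.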

The merge operation concatenates [SUM], [CLS], left, and right span embeddings with role augmentations and passes them through stacked Transformer layers; heads on the outputs produce the merge score and softmax-normalized compositional weights, which gate the combination of the left and right outputs into the candidate span representation (Hu et al., 2021, Hu et al., 2022).
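A toy version of the merge operation makes the role tokens concrete. In this sketch (all names and parameter values are hypothetical), a single parameter-free self-attention layer replaces the stacked Transformer layers, and we assume that the merge score is read from the [SUM] output and the left/right gate weights from the [CLS] output; the real model's heads and dimensions may differ.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def add(u, v):
    return [a + b for a, b in zip(u, v)]


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [x / total for x in exps]


def self_attention(xs):
    """One parameter-free dot-product self-attention layer with a residual
    connection; a stand-in for the stacked Transformer layers."""
    d = len(xs[0])
    out = []
    for q in xs:
        w = softmax([dot(q, k) / math.sqrt(d) for k in xs])
        ctx = [sum(wi * x[t] for wi, x in zip(w, xs)) for t in range(d)]
        out.append(add(q, ctx))
    return out


def merge(left, right, params):
    """Hypothetical merge: returns (e^k_ij, p_bar^k_ij, gate weights)."""
    # Role augmentation: add a role vector to each child before encoding.
    seq = [params["sum"], params["cls"],
           add(left, params["role_left"]), add(right, params["role_right"])]
    h_sum, h_cls, h_left, h_right = self_attention(seq)
    # Assumption: merge score read from the [SUM] output ...
    p_bar = 1.0 / (1.0 + math.exp(-dot(params["w_p"], h_sum)))
    # ... and softmax gate weights over (left, right) from the [CLS] output.
    w = softmax([dot(params["w_left"], h_cls), dot(params["w_right"], h_cls)])
    e = [w[0] * a + w[1] * b for a, b in zip(h_left, h_right)]
    return e, p_bar, w


params = {
    "sum": [0.1, 0.0], "cls": [0.0, 0.1],
    "role_left": [0.05, -0.05], "role_right": [-0.05, 0.05],
    "w_p": [1.0, -0.5], "w_left": [0.3, 0.2], "w_right": [-0.2, 0.4],
}
e, p_bar, w = merge([0.5, -0.2], [0.1, 0.3], params)
```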

2. Efficient Pruned Tree Induction

Plain CKY-style filling requires $O(n^3)$ computation. R2D2 introduces pruning to reduce this to $O(n)$:

  • A fixed pruning window $m$ is set, and all spans of length at most $m$ are filled.
  • For longer spans, the highest-confidence binary merge (the "lock-in") is selected with an ambiguity-scoring criterion and fixed, and cells that overlap it are pruned in a Tetris-like fashion.
  • The chart is re-indexed and the process repeats until the entire tree is filled.
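The lock-in loop can be illustrated with a deliberately simplified sketch. It omits the window $m$ and the chart-based ambiguity criterion, replacing them with a hypothetical `score` function over adjacent nodes, so it shows only the fix, merge, and re-index cycle. The helper `naive_split_evaluations` counts the (span, split) pairs a full chart evaluates, which grows as $(n^3 - n)/6$.

```python
def naive_split_evaluations(n):
    """(span, split point) pairs a full CKY chart evaluates: (n^3 - n) / 6."""
    return sum((n - length + 1) * (length - 1) for length in range(2, n + 1))


def greedy_lock_in(tokens, score):
    """Simplified lock-in loop: repeatedly fix the best adjacent merge.

    `score(a, b)` is a hypothetical confidence for merging adjacent nodes a, b;
    nodes are tokens or nested tuples (already locked-in subtrees).
    """
    nodes = list(tokens)
    merges = []
    while len(nodes) > 1:
        best = max(range(len(nodes) - 1),
                   key=lambda k: score(nodes[k], nodes[k + 1]))
        merged = (nodes[best], nodes[best + 1])
        merges.append(merged)
        nodes[best:best + 2] = [merged]  # lock in, then re-index the sequence
    return nodes[0], merges


pair_scores = {("the", "cat"): 0.9, ("sat", "down"): 0.7}
root, merges = greedy_lock_in(
    ["the", "cat", "sat", "down"],
    lambda a, b: pair_scores.get((a, b), 0.0))
# root == (("the", "cat"), ("sat", "down"))
```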

This scheme yields linear scaling in $n$, enabling pretraining on long sequences (Hu et al., 2021, Hu et al., 2022). Fast-R2D2 further accelerates induction by learning a top-down split-point scoring parser (BiLSTM+MLP) that predicts a global merge order, permitting parallel encoding over independent tree levels and achieving a 30–50× speedup over the heuristic R2D2 pruner (Hu et al., 2022).
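Given per-position split scores, a top-down parser of the kind described for Fast-R2D2 can derive both a binary tree and a bottom-up merge order. In this sketch the scores are supplied directly (the BiLSTM+MLP scorer is not modeled), and the observation that sorting splits by ascending score yields a valid merge order is our own reasoning: every split inside a span scores no higher than the split chosen for that span. The sort costs $O(n \log n)$; the recursive tree builder here is a simple version that is quadratic in the worst case.

```python
def tree_from_split_scores(scores):
    """Top-down parse: split each span at its highest-scoring split point.

    scores[k] rates splitting between tokens k and k + 1 (0-based), so a
    sentence of n tokens has n - 1 scores. Leaves are token indices.
    """
    def build(i, j):
        if i == j:
            return i
        # max() returns the leftmost split among ties.
        k = max(range(i, j), key=lambda t: scores[t])
        return (build(i, k), build(k + 1, j))
    return build(0, len(scores))


def merge_order(scores):
    """Bottom-up merge order: splits sorted by ascending score.

    Ties are ordered right-to-left to match the leftmost-max rule above.
    """
    return sorted(range(len(scores)), key=lambda t: (scores[t], -t))


def merges_by_level(tree):
    """Group internal nodes by height; nodes of equal height are independent."""
    levels = {}

    def height(node):
        if isinstance(node, int):
            return 0
        h = 1 + max(height(node[0]), height(node[1]))
        levels.setdefault(h, []).append(node)
        return h

    height(tree)
    return [levels[h] for h in sorted(levels)]


scores = [0.2, 0.9, 0.1]                  # hypothetical scorer output, 4 tokens
tree = tree_from_split_scores(scores)     # ((0, 1), (2, 3))
order = merge_order(scores)               # [2, 0, 1]
levels = merges_by_level(tree)
```

Internal nodes of equal height do not depend on each other, which is what makes level-by-level parallel encoding possible.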

3. Training Objectives and Optimization

The core pretraining objective in R2D2 is bidirectional language modeling:

$$\mathcal{L}_{\text{bilm}} = -\frac{1}{n}\sum_{i=1}^{n} \log p\left(s_i \mid s_{1:i-1},\ s_{i+1:n}\right)$$

Here, for each token $s_i$, the left and right context representations ($e_{1,i-1}$, $e_{i+1,n}$) are recursively computed as the root nodes of the chart's respective subtrees. The prediction head feeds a [MASK] token and these two representations into the same merge function and uses the hidden state at the [MASK] position, $h_{\text{[MASK]}}$, to classify $s_i$.
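The bookkeeping behind the bidirectional objective is simple: for each position, the left and right contexts are the chart cells covering everything before and after the token. The sketch below uses 1-based spans to match the notation; representing an empty context as `None` is our assumption, and `bilm_loss` just averages negative log-probabilities that a prediction head would supply.

```python
import math


def context_spans(n):
    """Left/right context cells for each 1-based position i in a length-n input.

    The left context is the cell for [1, i - 1] and the right context the cell
    for [i + 1, n]; None marks an empty side (how the model handles it is an
    assumption here).
    """
    spans = []
    for i in range(1, n + 1):
        left = (1, i - 1) if i > 1 else None
        right = (i + 1, n) if i < n else None
        spans.append((left, right))
    return spans


def bilm_loss(token_probs):
    """L_bilm = -(1/n) * sum_i log p(s_i | s_{1:i-1}, s_{i+1:n})."""
    return -sum(math.log(p) for p in token_probs) / len(token_probs)


print(context_spans(3))
# [(None, (2, 3)), ((1, 1), (3, 3)), ((1, 2), None)]
```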

In Fast-R2D2, an additional KL divergence term aligns the parser's induced tree distribution $p_\theta(z \mid S)$ with the R2D2 chart's tree distribution $q(z \mid S)$:

$$\mathcal{L}_{\text{KL}} = D_{\mathrm{KL}}\left(q(z \mid S)\ \|\ p_\theta(z \mid S)\right).$$

The gradient is estimated stochastically via REINFORCE, using trees sampled from the R2D2 chart (Hu et al., 2022).
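The KL alignment can be illustrated on a toy set of two trees. This sketch assumes the divergence is $D_{\mathrm{KL}}(q \,\|\, p_\theta)$ with $q$ the chart's distribution and $p_\theta$ a softmax parser over whole trees (both are assumptions; Fast-R2D2's parser scores split points rather than enumerated trees). Sampling trees from $q$ gives the score-function estimate $-\mathbb{E}_{z \sim q}[\nabla_\theta \log p_\theta(z)]$, which for softmax logits reduces to $p_\theta(z) - \hat{q}(z)$, with $\hat{q}$ the sample frequency.

```python
import math
import random


def kl(q, p):
    """Exact KL(q || p) over a finite set of trees."""
    return sum(qz * math.log(qz / p[z]) for z, qz in q.items() if qz > 0)


def softmax(logits):
    m = max(logits.values())
    ex = {z: math.exp(v - m) for z, v in logits.items()}
    total = sum(ex.values())
    return {z: v / total for z, v in ex.items()}


def train_parser(q, steps=500, samples=64, lr=0.5, seed=0):
    """Fit parser logits to the chart distribution q using sampled trees."""
    rng = random.Random(seed)
    trees = list(q)
    logits = {z: 0.0 for z in trees}
    for _ in range(steps):
        p = softmax(logits)
        drawn = rng.choices(trees, weights=[q[z] for z in trees], k=samples)
        freq = {z: drawn.count(z) / samples for z in trees}
        # Monte Carlo estimate of d KL(q || p_theta) / d logit_z = p(z) - q(z),
        # with q(z) replaced by the sample frequency.
        for z in trees:
            logits[z] -= lr * (p[z] - freq[z])
    return softmax(logits)


q_chart = {"((a b) c)": 0.8, "(a (b c))": 0.2}   # hypothetical chart distribution
p_parser = train_parser(q_chart)
```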

For downstream tasks, the loss combines the bidirectional LM loss, the parser–encoder KL term, and, when fine-tuning for classification, a cross-entropy term:

$$\mathcal{L} = \mathcal{L}_{\text{bilm}} + \mathcal{L}_{\text{KL}} + \mathcal{L}_{\text{cls}}.$$

4. Design Space and Relationship to Other Architectures

R2D2 occupies an intermediate regime in the spectrum between Recursive Neural Networks (RvNNs) and vanilla Transformers. Recent work (Chowdhury et al., 2024) formalizes this landscape through:

  • Continuous Recursive Neural Networks (CRvNN): These models use soft "existence" masks and continuous gating, relaxing discrete tree induction for fully differentiable recursion with dynamic halting. Each recursion step softly selects neighbor merges, updating states and existence scores until a halting threshold is met.
  • Neural Data Routers (NDR): Constrain Transformers with geometric, nearest-neighbor self-attention, parameter sharing (as in Universal Transformers), and strong local composition inductive bias.
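NDR-style geometric attention lets each query attend preferentially to the nearest matching key: with $P_{i,j} = \sigma(\text{score}_{i,j})$, the weight is $A_{i,j} = P_{i,j} \prod_{k \in \mathbb{S}_{i,j}} (1 - P_{i,k})$, where $\mathbb{S}_{i,j}$ holds the positions closer to $i$ than $j$. The single-head sketch below follows that formulation as we understand it; breaking distance ties toward the left and taking raw logits as input are simplifying assumptions.

```python
import math
import random


def sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def geometric_attention(logits):
    """Single-head geometric attention weights from raw logits[i][j].

    A[i][j] = P[i][j] * prod over closer keys k of (1 - P[i][k]);
    attention from position i to itself is excluded.
    """
    n = len(logits)
    A = [[0.0] * n for _ in range(n)]
    for i in range(n):
        # Keys from nearest to farthest; equal distances go left first (assumption).
        order = sorted((j for j in range(n) if j != i),
                       key=lambda j: (abs(i - j), j > i))
        remaining = 1.0  # probability that no closer key was selected
        for j in order:
            p = sigmoid(logits[i][j])
            A[i][j] = remaining * p
            remaining *= 1.0 - p
    return A


rng = random.Random(0)
A = geometric_attention([[rng.uniform(-3, 3) for _ in range(5)] for _ in range(5)])
```

Each row sums to $1 - \prod_j (1 - P_{i,j}) \le 1$, so a query may also attend to nothing, and a confident nearby match suppresses all farther keys.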

The R2D2 proposal in this context fuses multi-head geometric attention with the existence-based halting of CRvNNs. At each recursive step $t$:

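  1. Retrieve: Multi-head geometric attention, masked by the existence scores, gathers context for each position.
  2. Gating: A gate value is computed for each position.
  3. Compose: Each position's state is updated from its retrieved context, weighted by the gate.
  4. Existence update: The existence scores are updated; recursion terminates when the mean existence drops below a halting threshold.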
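The four-step loop can be sketched with scalar states to show the control flow. Everything concrete here is hypothetical: a distance-decayed, existence-weighted average stands in for multi-head geometric attention, the gate is a sigmoid of the state-context difference, composed positions lose existence in proportion to their gate, and `tau` and `max_steps` are illustrative values.

```python
import math


def sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def recursive_halting(states, gate_w=2.0, gate_b=0.0, tau=0.5, max_steps=20):
    """Toy recursion with soft existence scores and mean-existence halting.

    Returns (final states, existence scores, number of steps run).
    """
    n = len(states)
    h = list(states)
    exist = [1.0] * n
    for step in range(1, max_steps + 1):
        # 1. Retrieve: existence-weighted, distance-decayed average of the other
        #    positions (a stand-in for multi-head geometric attention).
        ctx = []
        for i in range(n):
            ws = [exist[j] * 0.5 ** abs(i - j) for j in range(n) if j != i]
            vs = [h[j] for j in range(n) if j != i]
            total = sum(ws)
            ctx.append(sum(w * v for w, v in zip(ws, vs)) / total if total > 0 else h[i])
        # 2. Gating: one gate per position from its state and context.
        gates = [sigmoid(gate_w * (c - x) + gate_b) for x, c in zip(h, ctx)]
        # 3. Compose: gated interpolation between the old state and the context.
        h = [g * c + (1.0 - g) * x for g, c, x in zip(gates, ctx, h)]
        # 4. Existence update: strongly gated positions fade out.
        exist = [e * (1.0 - g) for e, g in zip(exist, gates)]
        if sum(exist) / n < tau:  # halt once mean existence drops below tau
            return h, exist, step
    return h, exist, max_steps


h, exist, steps = recursive_halting([0.1, 0.9, 0.4, 0.7])
```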

This setup interpolates between strict tree-structured recursion (projective, local merges) and standard Transformer-style global context mixing, but with adaptive depth and halting.

5. Empirical Results and Scaling

R2D2 and Fast-R2D2 demonstrate strong empirical performance:

  • Language Modeling: On WikiText-2 (3 layers, 10–60 epochs), R2D2 achieves pseudo-perplexity (PPPL) scores of 83.10 (m=4) and 57.40 (m=8), outperforming XLNet (PPPL=301.87) and BERT (PPPL=441.42) baselines at comparable scale. After extensive training (60 epochs), R2D2 matches larger BERT/XLNet models (PPPL ≈ 55) (Hu et al., 2021).
  • Unsupervised Parsing: On the Penn Treebank WSJ (word-piece input), R2D2 achieves F1 = 52.28, exceeding DIORA and matching C-PCFG. On the Chinese Treebank (CTB8), R2D2 achieves F1 = 63.94. Fast-R2D2 with model-based pruning attains F1 = 57.2 (WSJ, word input) and 67.7 (CTB, word-piece), surpassing the baseline (Hu et al., 2022).
  • Downstream Classification: Fast-R2D2 with ~62M parameters achieves SST-2 = 90.7, CoLA = 40.1, QQP F1 = 84.3, and MNLI = 69.6/69.6, matching BERT-12L (116M params). R2D2 outperforms 4-layer BERTs on classification, suggesting the benefit of a recursive, compositional inductive bias.
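Pseudo-perplexity is conventionally the exponentiated average negative log-probability of each token given all other tokens in its sentence, $\mathrm{PPPL} = \exp\left(-\frac{1}{N}\sum_{t} \log p(w_t \mid W_{\setminus t})\right)$, where lower is better. Assuming the reported numbers follow this definition, R2D2's bidirectional objective supplies exactly these conditionals. The sketch takes precomputed natural-log probabilities as input.

```python
import math


def pseudo_perplexity(sentence_logprobs):
    """PPPL over a corpus.

    sentence_logprobs: one list per sentence, holding log p(w_t | rest of the
    sentence) for every token t (natural log).
    """
    total = sum(sum(sent) for sent in sentence_logprobs)
    count = sum(len(sent) for sent in sentence_logprobs)
    return math.exp(-total / count)


# Every token predicted with probability 0.1 gives a PPPL of 10.
uniform = [[math.log(0.1)] * 3, [math.log(0.1)] * 2]
print(round(pseudo_perplexity(uniform), 6))  # 10.0
```

For scale, a PPPL of 57.40 corresponds to an average of about 4.05 nats per token, since $\ln 57.40 \approx 4.05$.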

Pruning reduces R2D2's training time from thousands of hours per epoch (naive CKY) to about 7 hours per epoch on a single V100 GPU, and Fast-R2D2* achieves 30–50× faster inference with efficient parallelization (Hu et al., 2021, Hu et al., 2022).

6. Interpretability, Linguistic Analysis, and Structural Properties

R2D2 yields explicit, recoverable binary parse trees whose structure closely aligns with linguistic constituents. Empirical analysis shows:

  • Nearly perfect recovery of word chunks (99.24% recall on WSJ) and proper-noun spans (86.76% recall), indicating strong alignment with morphological and named-entity boundaries.
  • For higher-level spans (NP, VP, SBAR), constituent recall matches or exceeds unsupervised baselines (e.g., C-PCFG, DIORA).
  • R2D2-induced spans have higher compatibility with dependencies (i.e., induced subtrees form connected subgraphs in gold dependency parses), particularly in longer sentences. This suggests robust semantic coherence in induced units.
  • Qualitative evaluation demonstrates accurate grouping of complex subphrases, enhancing transparency and interpretability relative to standard Transformer encoders (Hu et al., 2021).
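The two span-level analyses can be sketched as follows: recall of gold spans among the spans of an induced binary tree, and a check of whether a span forms a connected subgraph of a dependency tree. The connectivity test relies on a standard tree property: a set of nodes is connected exactly when one of them has its head outside the set. The encodings (nested tuples of 0-based indices for binary trees, head `-1` for the dependency root) are our assumptions, not the paper's evaluation code.

```python
def tree_spans(tree):
    """Spans (start, end) of all internal nodes of a nested-tuple binary tree."""
    spans = set()

    def walk(node):
        if isinstance(node, int):
            return node, node
        start, _ = walk(node[0])
        _, end = walk(node[1])
        spans.add((start, end))
        return start, end

    walk(tree)
    return spans


def span_recall(gold_spans, induced_spans):
    """Fraction of gold spans that also appear among the induced spans."""
    induced = set(induced_spans)
    return sum(1 for s in gold_spans if s in induced) / len(gold_spans)


def is_dependency_connected(heads, start, end):
    """heads[t] is the 0-based head of token t, or -1 for the root.

    Tokens start..end form a connected subgraph of the dependency tree iff
    exactly one of them has its head outside the span.
    """
    outside = sum(1 for t in range(start, end + 1) if not start <= heads[t] <= end)
    return outside == 1


induced = tree_spans(((0, 1), (2, 3)))       # {(0, 1), (2, 3), (0, 3)}
recall = span_recall([(0, 1), (1, 2)], induced)
heads = [1, 2, -1, 2]                        # "the cat sat down"
```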

7. Limitations, Open Problems, and Future Directions

Several methodological and expressivity constraints are identified:

  • The fully projective bias in R2D2/CRvNN restricts their ability to model non-projective or non-binary structures, which may be limiting for certain languages or tasks (Chowdhury et al., 2024).
  • The sequential nature of original R2D2 induction limits parallel depth and scalability, partially addressed by Fast-R2D2’s parser-guided pruning and forced encoding.
  • Dynamic halting relies on robust gating; brittle gate predictions can prevent the model from reaching an appropriate adaptive depth (Chowdhury et al., 2024).
  • R2D2’s TreeInduction and pruning are not part of the learned model; only the Transformer composition and heads are end-to-end trainable.

A plausible implication is that future designs combining flexible, non-projective geometric attention, adaptive gating, and efficient learned induction could further advance recursive Transformer architectures. Continued investigation into CRvNN–NDR–R2D2 bridges promises new directions in length-generalization, grammar induction, and interpretable neural composition (Chowdhury et al., 2024).


References:
