Differentiable Recursive Transformers (R2D2)
- The paper presents a differentiable recursive Transformer that integrates CKY-style tree induction with standard Transformer composition to build interpretable binary trees.
- It introduces efficient pruned tree induction and Fast-R2D2 to scale tree construction from O(n³) to O(n), achieving competitive performance in language modeling, parsing, and classification.
- Empirical results highlight strong linguistic alignment, improved scalability, and enhanced interpretability compared to vanilla Transformers and traditional recursive approaches.
Differentiable Recursive Transformers (R2D2) are architectures designed to fuse explicit hierarchical structure induction with the flexible expressivity of Transformer networks. Unlike conventional deep models that employ stacked layers without modeling explicit hierarchical composition, R2D2 leverages a differentiable chart-based recursive algorithm—generalizing CKY parsing—to recursively construct and compose latent binary trees over input sequences. This approach encodes interpretable and adaptive tree structures while maintaining full end-to-end differentiability and large-scale pretraining capability (Hu et al., 2021, Hu et al., 2022, Chowdhury et al., 2024).
1. Model Architecture: Differentiable Recursive Composition
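At the core of R2D2 is a triangular CKY-style chart constructed over a token sequence $S = (s_1, \dots, s_n)$. Each cell $T_{i,j}$ ($1 \le i \le j \le n$) encodes:
- $e_{i,j}$: span representation for $s_{i:j}$
- $p_{i,j}$: probability of a local binary merge
- $\tilde{p}_{i,j}$: marginal subtree probability covering $s_{i:j}$
Leaf cells are initialized with token embeddings ($e_{i,i} = \mathrm{emb}(s_i)$, $p_{i,i} = \tilde{p}_{i,i} = 1$). Non-terminal cells (for $j > i$) consider all split points $k$ with $i \le k < j$, applying a Transformer-based merge function
$$e^{k}_{i,j},\; p^{k}_{i,j} = f\big(e_{i,k},\, e_{k+1,j}\big)$$
with $e^{k}_{i,j}$ the candidate span representation and $p^{k}_{i,j}$ the merge score.
Subtree probabilities propagate as
$$\tilde{p}^{k}_{i,j} = p^{k}_{i,j}\,\tilde{p}_{i,k}\,\tilde{p}_{k+1,j}$$
A Straight-Through Gumbel-Softmax over $\log \tilde{p}^{k}_{i,j}$ (for $k = i, \dots, j-1$) produces a sparse mixing vector $\alpha_{i,j}$, yielding soft mixtures:
$$e_{i,j} = \sum_{k=i}^{j-1} \alpha^{k}_{i,j}\, e^{k}_{i,j}, \qquad p_{i,j} = \sum_{k=i}^{j-1} \alpha^{k}_{i,j}\, p^{k}_{i,j}, \qquad \tilde{p}_{i,j} = \sum_{k=i}^{j-1} \alpha^{k}_{i,j}\, \tilde{p}^{k}_{i,j}$$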
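The chart recursion can be sketched in plain Python. This is a minimal, forward-only sketch under stated assumptions: `compose` is a hypothetical stand-in for the Transformer merge function, returning a candidate vector and a merge probability in (0, 1], and the Straight-Through Gumbel-Softmax is replaced by a deterministic softmax over the log subtree probabilities (no Gumbel noise, no gradients), so the mixing weights stay soft rather than sparse.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [v / total for v in exps]


def fill_chart(leaves, compose):
    """Bottom-up chart filling with soft mixtures over split points.

    leaves: list of token vectors (lists of floats).
    compose: hypothetical merge function (left_vec, right_vec) -> (vec, merge_prob).
    Returns dicts e, p, p_tilde keyed by 0-indexed inclusive spans (i, j).
    """
    n = len(leaves)
    e, p, p_tilde = {}, {}, {}
    for i in range(n):
        e[(i, i)], p[(i, i)], p_tilde[(i, i)] = list(leaves[i]), 1.0, 1.0
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            vecs, merge_probs, subtree_probs = [], [], []
            for k in range(i, j):  # left part [i, k], right part [k + 1, j]
                vec, pk = compose(e[(i, k)], e[(k + 1, j)])
                vecs.append(vec)
                merge_probs.append(pk)
                subtree_probs.append(pk * p_tilde[(i, k)] * p_tilde[(k + 1, j)])
            # Deterministic softmax over log subtree probabilities
            # (the Gumbel noise and straight-through estimator are omitted).
            alpha = softmax([math.log(q) for q in subtree_probs])
            e[(i, j)] = [sum(a * v[d] for a, v in zip(alpha, vecs)) for d in range(len(vecs[0]))]
            p[(i, j)] = sum(a * q for a, q in zip(alpha, merge_probs))
            p_tilde[(i, j)] = sum(a * q for a, q in zip(alpha, subtree_probs))
    return e, p, p_tilde
```

For three one-dimensional leaves [1.0], [2.0], [3.0] and a toy `compose` that averages its inputs and always returns 0.5, both splits of the full span have subtree probability 0.25 and tie, so the root representation is [2.0].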
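The merge operation concatenates [SUM], [CLS], left, and right span embeddings, augmented with role embeddings, and passes them through stacked Transformer layers. From the outputs $h_{[\mathrm{SUM}]}$, $h_{[\mathrm{CLS}]}$, $h_{\mathrm{left}}$, $h_{\mathrm{right}}$, it computes the merge score $p^{k}_{i,j} = \sigma\big(\mathrm{MLP}_p(h_{[\mathrm{SUM}]})\big)$ and the compositional weights $w = \mathrm{softmax}\big(\mathrm{MLP}_w(h_{[\mathrm{CLS}]})\big)$, which gate the candidate representation $e^{k}_{i,j} = [h_{[\mathrm{CLS}]}, h_{\mathrm{left}}, h_{\mathrm{right}}]\, w$ (Hu et al., 2021, Hu et al., 2022).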
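The scoring and gating step of the merge head can be illustrated without the Transformer layers. In this sketch the Transformer outputs are taken as given, and `w_p` and `w_gate` are hypothetical single linear maps standing in for the two MLPs, chosen only to keep the example short.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [v / total for v in exps]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def merge_head(h_sum, h_cls, h_left, h_right, w_p, w_gate):
    """Merge score and softmax-gated composition from (assumed) Transformer outputs.

    w_p: hypothetical weight vector replacing MLP_p with one linear map.
    w_gate: hypothetical 3 x d matrix replacing MLP_w with one linear map.
    """
    merge_score = sigmoid(dot(w_p, h_sum))               # score from the [SUM] output
    gate = softmax([dot(row, h_cls) for row in w_gate])  # weights over ([CLS], left, right)
    parts = (h_cls, h_left, h_right)
    composed = [sum(g * part[d] for g, part in zip(gate, parts)) for d in range(len(h_cls))]
    return composed, merge_score
```

With all-zero weights the merge score is 0.5 and the three gate weights are equal, so the composed vector is the plain average of the [CLS], left, and right outputs.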
2. Efficient Pruned Tree Induction
Plain CKY-style filling requires $O(n^3)$ computation. R2D2 introduces pruning to reduce this to $O(n)$:
- A fixed pruning window $m$ is set. All spans of length $\le m$ are filled.
- For longer spans, the highest-confidence binary merge ("lock-in") is found via an ambiguity scoring criterion, fixed, and overlapping cells are pruned (Tetris-like).
- The chart is re-indexed and the process repeats until the entire tree is filled.
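The lock-in and re-index loop can be sketched as a simplified greedy procedure. This is not R2D2's exact pruner: the window-$m$ chart and the ambiguity criterion are replaced by a hypothetical pairwise scorer `merge_score` applied to adjacent nodes only. For comparison, `full_cky_compositions` counts the merge calls an unpruned chart needs, $\sum_{\ell=2}^{n} (n-\ell+1)(\ell-1) = (n^3 - n)/6$, which is the cubic cost that pruning avoids.

```python
def greedy_lock_in(tokens, merge_score):
    """Simplified bottom-up induction: repeatedly lock in the best adjacent merge.

    merge_score(left, right) -> float is a hypothetical scorer standing in for
    R2D2's chart-based ambiguity criterion. Nodes are tokens or nested
    (left, right) tuples.
    """
    nodes = list(tokens)
    while len(nodes) > 1:
        # Lock in the highest-scoring adjacent pair (ties go to the leftmost pair).
        best = max(range(len(nodes) - 1),
                   key=lambda i: merge_score(nodes[i], nodes[i + 1]))
        # Re-index: the merged pair becomes one node and the sequence shrinks by one.
        nodes[best:best + 2] = [(nodes[best], nodes[best + 1])]
    return nodes[0]


def full_cky_compositions(n):
    """Merge calls needed by an unpruned chart over n tokens (one per span and split)."""
    return sum((n - length + 1) * (length - 1) for length in range(2, n + 1))
```

Rescanning every adjacent pair makes this sketch quadratic in $n$; keeping scores in a priority queue and rescoring only the neighbors of each newly merged node is one way to reduce that bookkeeping.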
This scheme yields linear scaling in $n$, enabling pretraining on long sequences (Hu et al., 2021, Hu et al., 2022). Fast-R2D2 further accelerates induction with a learned top-down parser (BiLSTM+MLP) that scores split points and ranks them into a global merge order, permitting parallel encoding over independent tree levels and achieving a 30–50× speedup over the heuristic R2D2 pruner (Hu et al., 2022).
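A parser that scores split points can be turned into a tree and a bottom-up merge order as sketched below. The split scores are hypothetical inputs standing in for the BiLSTM+MLP outputs, and each span is split at its highest-scoring gap. Because a span's split gap scores at least as high as every gap inside it, merging gaps in ascending score order is a valid bottom-up order that rebuilds the same tree (assuming no ties), and computing it is a single sort.

```python
def parse_top_down(tokens, split_scores):
    """Split each span at its highest-scoring gap; return (tree, merge_order).

    split_scores[g] is a hypothetical score for the gap between tokens[g] and
    tokens[g + 1]. Trees are nested (left, right) tuples.
    """
    def build(i, j):
        if i == j:
            return tokens[i]
        g = max(range(i, j), key=lambda k: split_scores[k])
        return (build(i, g), build(g + 1, j))

    tree = build(0, len(tokens) - 1)
    # Lowest-scoring gaps are merged first: a valid bottom-up order for this tree.
    merge_order = sorted(range(len(tokens) - 1), key=lambda k: split_scores[k])
    return tree, merge_order
```

For tokens a, b, c, d with gap scores [0.1, 0.9, 0.3], the sentence splits between b and c, giving ((a, b), (c, d)) with merge order [0, 2, 1].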
3. Training Objectives and Optimization
The core pretraining objective in R2D2 is bidirectional language modeling:
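$$\mathcal{L}_{\mathrm{bilm}} = -\frac{1}{n}\sum_{i=1}^{n} \log p_\theta\big(s_i \mid s_{1:i-1},\, s_{i+1:n}\big)$$
Here, for each token $s_i$, the left and right context representations ($e_{1,i-1}$ and $e_{i+1,n}$) are recursively computed as the root nodes of the chart's respective subtrees. The prediction head feeds a [MASK] token and these two representations to the merge function $f$, and uses the resulting hidden state $h_{[\mathrm{MASK}]}$ to classify $s_i$.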
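Given the probability the prediction head assigns to each true token, the objective reduces to a mean negative log-likelihood, and exponentiating it gives the pseudo-perplexity (PPPL) used for the language-modeling results. The probabilities here are hypothetical inputs; the sketch only shows the arithmetic.

```python
import math


def bilm_loss(target_probs):
    """Mean negative log-likelihood of the true tokens.

    target_probs[i]: hypothetical probability the prediction head assigns to the
    true token s_i given its left and right contexts.
    """
    return -sum(math.log(q) for q in target_probs) / len(target_probs)


def pseudo_perplexity(target_probs):
    """Exponentiated mean negative pseudo-log-likelihood."""
    return math.exp(bilm_loss(target_probs))
```

A model that assigns probability 0.25 to every true token has a PPPL of 4, i.e. it is as uncertain as a uniform choice among four tokens at each position.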
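In Fast-R2D2, an additional KL divergence term aligns the parser's induced tree distribution $p_\phi(z \mid S)$ with the R2D2 chart's tree distribution $p_\theta(z \mid S)$: $$\mathcal{L}_{\mathrm{KL}} = D_{\mathrm{KL}}\big(p_\theta(z \mid S)\,\|\,p_\phi(z \mid S)\big)$$ The gradient is estimated stochastically via REINFORCE, using trees sampled from the R2D2 chart (Hu et al., 2022).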
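For a toy example with a handful of trees, the KL term can be computed exactly or estimated from trees sampled from the chart distribution. This sketch assumes trees are hashable labels with explicitly listed probabilities and takes the divergence as KL(chart ‖ parser), which matches sampling from the chart; it only computes values and does not implement the REINFORCE gradient estimator.

```python
import math
import random


def kl_exact(p_chart, p_parser):
    """KL(p_chart || p_parser) for {tree: probability} dicts."""
    return sum(p * math.log(p / p_parser[t]) for t, p in p_chart.items() if p > 0)


def kl_sampled(p_chart, p_parser, num_samples, rng):
    """Monte Carlo estimate of the same KL using trees sampled from p_chart."""
    trees = list(p_chart)
    samples = rng.choices(trees, weights=[p_chart[t] for t in trees], k=num_samples)
    return sum(math.log(p_chart[t] / p_parser[t]) for t in samples) / num_samples


chart = {"((a b) c)": 0.5, "(a (b c))": 0.5}
parser = {"((a b) c)": 0.25, "(a (b c))": 0.75}
print(kl_exact(chart, parser), kl_sampled(chart, parser, 20000, random.Random(0)))
```

Here the exact value is ½ ln(4/3) ≈ 0.144, and the sampled estimate approaches it as the number of samples grows.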
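For downstream tasks, the fine-tuning loss combines the bidirectional LM loss, the parser–encoder KL term, and a cross-entropy classification loss: $$\mathcal{L} = \mathcal{L}_{\mathrm{bilm}} + \mathcal{L}_{\mathrm{KL}} + \mathcal{L}_{\mathrm{cls}}$$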
4. Design Space and Relationship to Other Architectures
R2D2 occupies an intermediate regime in the spectrum between Recursive Neural Networks (RvNNs) and vanilla Transformers. Recent work (Chowdhury et al., 2024) formalizes this landscape through:
- Continuous Recursive Neural Networks (CRvNN): These models use soft "existence" masks and continuous gating, relaxing discrete tree induction for fully differentiable recursion with dynamic halting. Each recursion step softly selects neighbor merges, updating states and existence scores until a halting threshold is met.
- Neural Data Routers (NDR): Constrain Transformers with geometric, nearest-neighbor self-attention, parameter sharing (as in Universal Transformers), and strong local composition inductive bias.
The R2D2 proposal in this context fuses multi-head geometric attention with existential halting from CRvNNs. At each recursive step $t$:
- Retrieve: Multi-head geometric attention with existence mask,
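- Gating: a learned soft gate per position, deciding how strongly it composes with the retrieved information,
- Compose: a gated mixture of each position's previous state and its composed representation,
- Existence update: existence scores of composed positions are softly reduced; recursion terminates when mean existence drops below a threshold $\tau$.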
This setup interpolates between strict tree-structured recursion (projective, local merges) and standard Transformer-style global context mixing, but with adaptive depth and halting.
5. Empirical Results and Scaling
R2D2 and Fast-R2D2 demonstrate strong empirical performance:
- Language Modeling: On WikiText-2 (3 layers, 10–60 epochs), R2D2 achieves pseudo-perplexity (PPPL) scores of 83.10 (m=4) and 57.40 (m=8), outperforming XLNet (PPPL=301.87) and BERT (PPPL=441.42) baselines at comparable scale. After extensive training (60 epochs), R2D2 matches larger BERT/XLNet models (PPPL ≈ 55) (Hu et al., 2021).
- Unsupervised Parsing: On the Penn Treebank WSJ (word-piece input), R2D2 achieves F1 = 52.28, exceeding DIORA and matching C-PCFG. On the Chinese Treebank (CTB8), R2D2 achieves F1 = 63.94. Fast-R2D2 with model-based pruning attains F1 = 57.2 (WSJ, word input) and 67.7 (CTB, word-piece), surpassing the baseline (Hu et al., 2022).
- Downstream Classification: Fast-R2D2 with ~62M parameters achieves SST-2 = 90.7, CoLA = 40.1, QQP F1 = 84.3, and MNLI = 69.6/69.6, matching BERT-12L (116M params). R2D2 outperforms 4-layer BERTs on classification, indicating the efficacy of recursive, compositional inductive bias.
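Unsupervised parsing scores like these are unlabeled bracketing F1 values between induced and gold spans. The sketch below computes a sentence-level version for a nested-tuple tree; discarding the whole-sentence span and excluding single-token spans are assumed conventions here, and details such as punctuation handling and corpus- versus sentence-level averaging vary between evaluations and may differ from the cited papers.

```python
def tree_spans(tree, start=0):
    """Return (spans of internal nodes as (i, j) pairs, leaf count) for a nested-tuple tree."""
    if not isinstance(tree, tuple):
        return set(), 1
    left, n_left = tree_spans(tree[0], start)
    right, n_right = tree_spans(tree[1], start + n_left)
    n = n_left + n_right
    return left | right | {(start, start + n - 1)}, n


def bracket_f1(pred_tree, gold_spans):
    """Sentence-level unlabeled F1, ignoring the whole-sentence span.

    gold_spans: (i, j) pairs, 0-indexed and inclusive, assumed to exclude
    single-token spans (the predicted spans never contain them).
    """
    pred, n = tree_spans(pred_tree)
    gold = set(gold_spans)
    pred.discard((0, n - 1))
    gold.discard((0, n - 1))
    tp = len(pred & gold)
    if tp == 0:
        return 0.0
    precision, recall = tp / len(pred), tp / len(gold)
    return 2 * precision * recall / (precision + recall)
```

For the tree ((a, b), (c, d)) against gold spans (0, 1), (1, 3), (2, 3), both predicted spans are correct but one gold span is missed, so precision is 1, recall is 2/3, and F1 is 0.8.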
Pruned R2D2 reduces per-epoch training time from thousands of hours (naive CKY) to ≈7 h on a single V100 GPU, with Fast-R2D2* achieving 30–50× faster inference and efficient parallelization (Hu et al., 2021, Hu et al., 2022).
6. Interpretability, Linguistic Analysis, and Structural Properties
R2D2 yields explicit, recoverable binary parse trees whose structure closely aligns with linguistic constituents. Empirical analysis shows:
- Nearly perfect recovery of word chunks (99.24% recall on WSJ) and proper noun spans (86.76% recall), indicating strong alignment with morphological and named-entity boundaries.
- For higher-level spans (NP, VP, SBAR), constituent recall matches or exceeds unsupervised baselines (e.g., C-PCFG, DIORA).
- R2D2-induced spans have higher compatibility with dependencies (i.e., induced subtrees form connected subgraphs in gold dependency parses), particularly in longer sentences. This suggests robust semantic coherence in induced units.
- Qualitative evaluation demonstrates accurate grouping of complex subphrases, enhancing transparency and interpretability relative to standard Transformer encoders (Hu et al., 2021).
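The dependency-compatibility criterion can be checked directly: in a tree, a set of tokens forms a connected subgraph exactly when one of them has its head outside the set (the root's head counts as outside). The head-index encoding below, with -1 marking the root, is an assumption for illustration.

```python
def is_connected_span(i, j, heads):
    """True if tokens i..j (0-indexed, inclusive) form a connected dependency subgraph.

    heads[t] is the index of token t's head, or -1 for the root.
    """
    exits = sum(1 for t in range(i, j + 1) if not (i <= heads[t] <= j))
    return exits == 1


def compatibility_rate(spans, heads):
    """Fraction of induced spans that are connected in the gold dependency tree."""
    if not spans:
        return 0.0
    return sum(is_connected_span(i, j, heads) for i, j in spans) / len(spans)
```

For "John saw the dog" with heads [1, -1, 3, 1], the span "the dog" is connected, while "saw the" is not, because both tokens attach outside the span.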
7. Limitations, Open Problems, and Future Directions
Several methodological and expressivity constraints are identified:
- The fully projective bias in R2D2/CRvNN restricts its ability to model non-projective or non-binary structures, which may be limiting for certain languages or tasks (Chowdhury et al., 2024).
- The sequential nature of original R2D2 induction limits parallel depth and scalability, partially addressed by Fast-R2D2’s parser-guided pruning and forced encoding.
- Dynamic halting is reliant on robust gating; brittle gate predictions can impede optimal adaptive depth (Chowdhury et al., 2024).
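- R2D2’s TreeInduction procedure and pruning heuristic are not part of the learned model; only the Transformer composition function and prediction heads are trained end-to-end. Fast-R2D2’s learned parser partially replaces this heuristic.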
A plausible implication is that future designs combining flexible, non-projective geometric attention, adaptive gating, and efficient learned induction could further advance recursive Transformer architectures. Continued investigation into CRvNN–NDR–R2D2 bridges promises new directions in length-generalization, grammar induction, and interpretable neural composition (Chowdhury et al., 2024).
References:
- Hu et al. (2021). R2D2: Recursive Transformer based on Differentiable Tree for Interpretable Hierarchical Language Modeling.
- Hu et al. (2022). Fast-R2D2: A Pretrained Recursive Neural Network based on Pruned CKY for Grammar Induction and Text Representation.
- Chowdhury et al. (2024). On the Design Space Between Transformers and Recursive Neural Nets.