Reinforced Self-Attention Network (ReSAN)
- ReSAN is an attention-only network that synergizes hard attention via reinforcement learning with soft self-attention to capture sparse dependencies effectively.
- The model employs reinforced sequence sampling (RSS) to trim input sequences, significantly reducing computational cost while maintaining high accuracy on benchmarks like SNLI and SICK.
- Empirical results show that ReSAN outperforms traditional RNN and CNN approaches by achieving state-of-the-art accuracy with fewer parameters and improved scalability.
The Reinforced Self-Attention Network (ReSAN) is an attention-only sentence encoding architecture designed to efficiently capture sparse dependencies common in natural language, integrating both hard and soft attention in a hybridized, interactive manner. ReSAN eliminates the need for recurrent or convolutional layers, employing a sequence-trimming hard attention policy learned via reinforcement learning alongside a context-fusing soft self-attention module. This architecture delivers state-of-the-art empirical results on benchmarks such as SNLI and SICK, with improved efficiency, scalability, and parameter economy compared to traditional RNN- or CNN-based approaches (Shen et al., 2018).
1. Hybrid Attention Mechanisms
ReSAN’s architecture unites two complementary mechanisms: a soft, masked self-attention and a hard, reinforced sequence sampling (RSS) module. The soft self-attention component computes token-level contextual dependencies using a learned compatibility function, while the RSS module stochastically selects a task-relevant token subset in parallel, guiding the focus of soft attention. The hard attention module restricts computation to a curated set of “head” and “dependent” tokens, dramatically reducing the computational cost and improving the efficiency of modeling long or information-sparse sequences.
Soft Self-Attention
Given a sequence of $d_e$-dimensional token embeddings $x = [x_1, \dots, x_n]$, soft attention employs a masked compatibility function

$$f(x_i, x_j) = c \cdot \tanh\!\Big(\big[W^{(1)} x_i + W^{(2)} x_j + b^{(1)}\big] / c\Big) + M_{ij}\,\mathbf{1},$$

where $c$ is a scaling constant and the mask $M$ encodes context constraints, such as directionality. For each token $x_i$, the scores are normalized over $j$ separately for every feature dimension and used to aggregate context:

$$s_i = \sum_{j=1}^{n} \operatorname{softmax}_j\big(f(x_i, x_j)\big) \odot x_j,$$

where $\odot$ denotes element-wise multiplication, allowing for multi-dimensional attention.
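A minimal NumPy sketch of this masked, multi-dimensional soft attention is given below; the scaling constant `c = 5`, the helper names (`soft_self_attention`, `softmax`), and the parameter shapes are illustrative assumptions rather than details of the reference implementation.

```python
import numpy as np

def softmax(scores, axis):
    scores = scores - scores.max(axis=axis, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=axis, keepdims=True)

def soft_self_attention(x, W1, W2, b1, mask, c=5.0):
    """Masked multi-dimensional soft self-attention (illustrative sketch).

    x:      (n, d) token embeddings
    W1, W2: (d, d) projections; b1: (d,) bias
    mask:   (n, n) additive mask, 0 for allowed pairs and -inf for blocked ones
    c:      scaling constant of the tanh compatibility function
    """
    # Pairwise compatibility f(x_i, x_j), one score per feature dimension: (n, n, d)
    f = c * np.tanh(((x @ W1.T)[:, None, :] + (x @ W2.T)[None, :, :] + b1) / c)
    f = f + mask[:, :, None]                    # broadcast mask over the feature axis
    p = softmax(f, axis=1)                      # normalize over j for every dimension
    return (p * x[None, :, :]).sum(axis=1)      # s_i = sum_j p_ij (elementwise) x_j

rng = np.random.default_rng(0)
n, d = 6, 8
x = rng.normal(size=(n, d))
W1, W2 = rng.normal(size=(d, d)), rng.normal(size=(d, d))
mask = np.zeros((n, n))                         # open mask; directional masks use -inf entries
s = soft_self_attention(x, W1, W2, np.zeros(d), mask)
print(s.shape)                                  # (6, 8)
```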
Hard Attention: Reinforced Sequence Sampling (RSS)
RSS produces a binary vector $z = [z_1, \dots, z_n]$, $z_i \in \{0, 1\}$, representing parallel token selection:

$$p(z \mid x) = \prod_{i=1}^{n} p(z_i \mid x).$$

Each $p(z_i \mid x)$ is computed using a non-recurrent context aggregation

$$h_i = \sigma_h\big(W^{(r1)} x_i + W^{(r2)} \bar{x} + b^{(r)}\big),$$

with $\bar{x}$ a pooled summary of the full sequence, followed by a learned sigmoid gate

$$p(z_i = 1 \mid x) = \operatorname{sigmoid}\big(w^{\top} h_i + b\big),$$

where $\sigma_h$ is a nonlinearity.
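A hedged sketch of the RSS step follows, assuming mean pooling as the non-recurrent context aggregation, tanh as the nonlinearity, and independent Bernoulli sampling of each $z_i$; the names `rss_sample`, `Wr1`, `Wr2`, and `w` are hypothetical.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def rss_sample(x, Wr1, Wr2, br, w, b, rng):
    """Reinforced sequence sampling (illustrative sketch).

    x: (n, d) token embeddings. Returns binary selections z, their
    probabilities p, and log p(z|x), which is needed later for REINFORCE.
    """
    x_bar = x.mean(axis=0)                            # non-recurrent context aggregation
    h = np.tanh(x @ Wr1.T + x_bar @ Wr2.T + br)       # (n, d_h) per-token features
    p = sigmoid(h @ w + b)                            # (n,) selection probabilities
    z = (rng.random(p.shape) < p).astype(np.float64)  # parallel Bernoulli sampling
    log_prob = np.sum(z * np.log(p + 1e-9) + (1 - z) * np.log(1 - p + 1e-9))
    return z, p, log_prob
```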
2. Reinforced Self-Attention (ReSA) Module
The ReSA module employs two parallel RSS block instances: one selects "head" tokens ($z^{h}$) and one selects "dependent" tokens ($z^{d}$), both sampled in parallel. This induces a sparse mask

$$M^{rss}_{ij} = \begin{cases} 0, & z^{d}_i = 1,\ z^{h}_j = 1,\ i \neq j, \\ -\infty, & \text{otherwise.} \end{cases}$$

Self-attention is then confined to this mask by adding $M^{rss}_{ij}$ to the compatibility scores $f(x_i, x_j)$. A fusion gate integrates the attended context $s_i$ with the original token representation $x_i$:

$$F_i = \operatorname{sigmoid}\big(W^{(f)}[s_i; x_i] + b^{(f)}\big), \qquad u_i = F_i \odot s_i + (1 - F_i) \odot x_i.$$
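The sketch below combines the pieces: it builds the sparse mask from two RSS draws and applies the fusion gate. `resa_mask` and `fusion_gate` are hypothetical helpers, and the fallback for fully masked rows is a numerical safeguard for this sketch, not a detail from the paper.

```python
import numpy as np

def resa_mask(z_head, z_dep):
    """Sparse mask: dependent token i may attend to head token j (i != j)."""
    n = z_head.shape[0]
    allowed = np.outer(z_dep, z_head).astype(bool)   # z_dep[i] == 1 and z_head[j] == 1
    np.fill_diagonal(allowed, False)                 # no self-links
    idx = np.where(~allowed.any(axis=1))[0]          # safeguard: unselected dependents
    allowed[idx, idx] = True                         # attend to themselves (sketch only)
    return np.where(allowed, 0.0, -np.inf)

def fusion_gate(s, x, Wf, bf):
    """Gate F_i blends attended context s_i with the original embedding x_i."""
    F = 1.0 / (1.0 + np.exp(-(np.concatenate([s, x], axis=-1) @ Wf.T + bf)))
    return F * s + (1 - F) * x

# Usage, reusing the hypothetical helpers from the earlier sketches:
# z_head, _, _ = rss_sample(x, Wr1, Wr2, br, w, b, rng)    # "head" selector
# z_dep,  _, _ = rss_sample(x, Vr1, Vr2, cr, v, c0, rng)   # "dependent" selector (own params)
# s = soft_self_attention(x, W1, W2, b1, resa_mask(z_head, z_dep))
# u = fusion_gate(s, x, Wf, bf)
```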
3. Training Paradigm
Training ReSAN involves two sets of parameters: those of the soft attention, embeddings, and final classifier ($\theta_s$), and those of RSS (hard attention, $\theta_r$). $\theta_s$ is optimized using standard backpropagation on the cross-entropy loss, while $\theta_r$ is updated via policy gradient (REINFORCE), handling the non-differentiable sampling process:

$$\nabla_{\theta_r} J(\theta_r) = \mathbb{E}_{z \sim p(z \mid x;\, \theta_r)}\big[R(z, x)\, \nabla_{\theta_r} \log p(z \mid x;\, \theta_r)\big].$$

The reward $R$ penalizes excessive selection, promoting sparse yet informative hard attention. Training is staged: initially, hard attention is turned off (i.e., RSS selects all tokens), then gradually activated once the supervised components are stable.
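A minimal REINFORCE step for the sigmoid-gate weights of RSS, assuming a scalar task reward minus a sparsity penalty as described above; the names, the penalty coefficient, and the optional baseline are illustrative assumptions.

```python
import numpy as np

def reinforce_update(h, p, z, w, reward_task, lr=0.01, sparsity_coef=0.1, baseline=0.0):
    """One REINFORCE ascent step on the sigmoid-gate weights w (illustrative sketch).

    h: (n, d_h) aggregated features, p: (n,) selection probabilities,
    z: (n,) sampled selections, reward_task: scalar task signal
    (e.g., 1.0 for a correct prediction, 0.0 otherwise).
    """
    reward = reward_task - sparsity_coef * z.mean()      # penalize excessive selection
    # d log p(z|x) / d w for a Bernoulli-sigmoid policy: sum_i (z_i - p_i) * h_i
    grad_log_prob = ((z - p)[:, None] * h).sum(axis=0)
    # Gradient estimate: (R - baseline) * grad log p(z|x)
    return w + lr * (reward - baseline) * grad_log_prob
```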
4. Integration Dynamics and Efficiency
The architectural cooperation is bidirectional: RSS trims the input space for soft attention, streamlining computation and focusing learning on sparse dependencies. The soft attention module, in turn, stabilizes RSS learning by affording denser reward signals, as the hard attention's influence is reflected in the ultimate task objective. This interaction overcomes the inefficiencies of soft attention on long sequences and the training difficulties of hard attention alone. Each forward pass avoids recurrence, remaining fully parallelizable, and the model’s computational burden scales with the density of the token selection, not with the input sequence length.
5. Comparative Performance and Empirical Results
Empirical evaluation on sentence-level inference and semantic relatedness tasks underscores ReSAN’s effectiveness.
| Model | Parameters | SNLI Test Accuracy |
|---|---|---|
| 600D Gumbel TreeLSTM | 10m | 86.0% |
| 600D Residual stacked enc. | 29m | 86.0% |
| Bi-LSTM + intra-attention | 2.8m | 84.2% |
| Multi-head self-attention | 2.0m | 84.2% |
| DiSAN (directional SA) | 2.4m | 85.6% |
| ReSAN | 3.1m | 86.3% |
On the SNLI test set, ReSAN attains 86.3% accuracy, surpassing prior sentence-encoding models of both RNN/CNN and tree-based architectures, while requiring fewer parameters than deeper LSTM or convolutional networks. Inference speed is comparable to other attention models and substantially faster than RNN-based approaches. On SICK semantic relatedness, ReSAN reports state-of-the-art Pearson and Spearman correlations and an MSE of 0.2623, outperforming competitive attention and recursive models.
Ablation reveals that both hard and soft attention contribute to final performance, with the hybrid yielding the best results. Hard attention alone improves efficiency and accuracy to a degree, but the synergistic interplay in ReSAN is strictly superior.
6. Mathematical Characteristics and Operational Scalability
ReSAN’s primary operations are linear algebraic and highly parallelizable. The policy gradient objective for RSS is

$$J(\theta_r) = \mathbb{E}_{z \sim p(z \mid x;\, \theta_r)}\big[R(z, x)\big],$$

optimized with the REINFORCE estimator given above. Soft attention and hard attention employ distinct parameter sets, enabling independent scaling and modular adaptation. The hard attention’s sparseness directly controls the resource footprint: a plausible implication is that ReSAN can process substantially longer sequences than dense-attention architectures at constant or sublinear computational cost.
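A toy calculation (assumed dimensions, not figures from the paper) illustrating how the pairwise score tensor shrinks when attention is confined to the RSS-selected tokens:

```python
n, d = 512, 300        # hypothetical sequence length and embedding size
keep = 0.2             # assumed fraction of tokens kept by each RSS selector

dense_scores = n * n * d                              # entries in the full (n, n, d) score tensor
sparse_scores = int(keep * n) * int(keep * n) * d     # only selected dependents x selected heads
print(dense_scores, sparse_scores, dense_scores / sparse_scores)  # roughly 25x fewer entries
```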
7. Significance and Implications
By uniting parallel stochastic hard attention and masked self-attention, ReSAN introduces a modular paradigm for sentence representation that sidesteps the sequential limitations of RNNs/CNNs while directly addressing sparse dependency structures with efficient, scalable computation. Its learning paradigm leverages reinforcement learning rewards propagated through standard supervised objectives, facilitating robust and stable training of its discrete selection mechanism. ReSAN’s open-source implementation enables further exploration in both research and practical settings (Shen et al., 2018).