Multi-head Gaussian Decoder
- The paper introduces a novel multi-head Gaussian decoder that replaces standard cross-attention with a variant that multiplies attention scores by a Gaussian alignment prior, giving transformer models explicit alignment.
- The method predicts alignment centers and computes Gaussian priors that combine multiplicatively with soft attention scores, which makes decoding monotonic and suitable for streaming input.
- Reported benefits include a better balance of translation quality and latency within a single differentiable framework that needs no external alignment supervision.
A multi-head Gaussian decoder is a specialized architectural component for neural sequence-to-sequence models, used notably in Simultaneous Machine Translation (SiMT). It replaces the standard multi-head cross-attention in a Transformer decoder with a variant, Gaussian Multi-head Attention (GMA), that integrates explicit alignment prediction through a parameterized Gaussian prior centered on predicted source positions. Each decoder layer predicts alignment increments that define Gaussian priors over source positions; these priors are multiplied with the usual soft attention scores to yield the final context vectors. The design gives a unified, deterministic policy for deciding when to emit each target token in streaming translation, balancing translation quality and latency without additional loss terms or external alignment supervision (Zhang et al., 2022).
1. Architectural Overview
In the multi-head Gaussian decoder framework, the standard Transformer decoder is retained except for the cross-attention sublayer, which is redefined to incorporate a differentiable alignment model and a Gaussian prior. At decoding step $i$, each decoder layer:
- Predicts a scalar alignment center $p_i$ (shared by all attention heads within the layer) via an MLP over the previous target-side hidden state.
- Determines the number of source tokens $g_i$ it may attend to, i.e., how much of the streaming input is currently available.
- Computes the cross-attention output by combining dot-product attention with a Gaussian prior centered at $p_i$.
The encoder and the remaining parts of the decoder (self-attention, feed-forward networks, and normalization) are unchanged. The multi-head structure is preserved, with the constraint that all heads within a layer share the same alignment prediction, which yields $L$ alignment predictions per time step for an $L$-layer decoder (Zhang et al., 2022).
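The layer-level structure can be sketched in pure Python. This is a minimal illustration, not the authors' code: the function name `gma_cross_attention_step`, the list-of-lists tensors, and the prior width `delta / 2` are assumptions, and keys and values are taken as already projected per head. What it shows is that the Gaussian prior is built once per layer from the single shared center `p_i` and then reused by every head.

```python
import math
from typing import List, Sequence

Vector = List[float]


def softmax(xs: Sequence[float]) -> Vector:
    """Numerically stable softmax."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gma_cross_attention_step(
    head_queries: Sequence[Vector],           # one query vector per head
    head_keys: Sequence[Sequence[Vector]],    # per head: projected keys for all source tokens
    head_values: Sequence[Sequence[Vector]],  # per head: projected values for all source tokens
    p_i: float,                               # alignment center shared by every head
    g_i: int,                                 # number of source tokens available so far
    delta: float,                             # positive offset (assumed); prior width is delta / 2
) -> List[Vector]:
    """Return one context vector per head; all heads reuse the same Gaussian prior."""
    sigma = delta / 2.0
    # The prior depends only on (p_i, g_i, delta), so it is built once per layer.
    # It is left unnormalized: the per-head renormalization below cancels any constant.
    prior = [math.exp(-((j - p_i) ** 2) / (2.0 * sigma ** 2)) for j in range(1, g_i + 1)]
    contexts: List[Vector] = []
    for q, keys, values in zip(head_queries, head_keys, head_values):
        keys, values = keys[:g_i], values[:g_i]  # only the prefix that has arrived
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
        weighted = [a * g for a, g in zip(softmax(scores), prior)]
        z = sum(weighted)
        beta = [w / z for w in weighted]
        dim = len(values[0])
        contexts.append([sum(b * v[d] for b, v in zip(beta, values)) for d in range(dim)])
    return contexts
```

Concatenating the per-head context vectors and applying the usual output projection would complete the sublayer, as in standard multi-head attention.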
2. Alignment Center Prediction and Incremental Policy
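Rather than predicting the absolute aligned source position for each target token, the model outputs a positive, incremental step $\Delta p_i > 0$ with:

$$\Delta p_i = \exp\!\left(\mathbf{v}_p^{\top} \tanh\left(W_p\, \mathbf{q}_i\right)\right)$$

where $\mathbf{q}_i$ is a query projection of the previous decoder state, and $W_p$ and $\mathbf{v}_p$ are learned parameters. The alignment center is recursively computed as:

$$p_i = p_{i-1} + \Delta p_i$$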
This mechanism ensures monotonic progression suitable for streaming input: the decoder cannot "jump backward" over the input sequence, thus supporting online translation policies.
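A minimal sketch of the increment-and-accumulate rule follows, assuming the parameterization $\Delta p_i = \exp(\mathbf{v}_p^{\top}\tanh(W_p\,\mathbf{q}_i))$; the helper names and the starting value `p0 = 0.0` are hypothetical. Because the exponential is strictly positive, every center is strictly larger than the previous one, and because $\tanh$ lies in $(-1, 1)$, each step is also bounded above by $\exp(\sum_k |v_{p,k}|)$.

```python
import math
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def alignment_increment(q: Sequence[float], W_p: Matrix, v_p: Sequence[float]) -> float:
    """Delta p_i = exp(v_p^T tanh(W_p q_i)) under the assumed parameterization; always > 0."""
    hidden = [math.tanh(sum(w * x for w, x in zip(row, q))) for row in W_p]
    return math.exp(sum(v * h for v, h in zip(v_p, hidden)))


def alignment_centers(queries: Sequence[Sequence[float]], W_p: Matrix,
                      v_p: Sequence[float], p0: float = 0.0) -> List[float]:
    """Accumulate p_i = p_{i-1} + Delta p_i over decoding steps (p0 is a hypothetical start)."""
    centers: List[float] = []
    p = p0
    for q in queries:
        p += alignment_increment(q, W_p, v_p)
        centers.append(p)
    return centers


# Usage: three decoding steps with a toy 2x2 projection.
W = [[0.5, -0.2], [0.1, 0.3]]
print(alignment_centers([[1.0, 2.0], [0.0, 0.0], [-3.0, 1.0]], W, [1.0, -0.5]))
```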
3. Gaussian Alignment Prior and Posterior Attention Computation
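A discrete Gaussian prior is defined over the available source positions $j = 1, \dots, g_i$:

$$G_{ij} = \exp\!\left(-\frac{(j - p_i)^2}{2\sigma^2}\right)$$

with $\sigma = \delta / 2$ (the "two-sigma rule"), where $\delta$ is the relaxation offset introduced in Section 4. $G_{i\cdot}$ is renormalized so that $\sum_{j=1}^{g_i} G_{ij} = 1$.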
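The prior on its own can be sketched as follows; `gaussian_prior` is a hypothetical helper, positions are 1-indexed as in the formula, and `delta` must be positive so that the width `delta / 2` is well defined. With $\sigma = \delta/2$, the window $[p_i - \delta,\ p_i + \delta]$ spans two standard deviations on each side, which is where the name "two-sigma rule" comes from.

```python
import math
from typing import List


def gaussian_prior(p_i: float, g_i: int, delta: float) -> List[float]:
    """Renormalized discrete Gaussian over source positions j = 1..g_i."""
    if delta <= 0:
        raise ValueError("delta must be positive so that sigma = delta / 2 > 0")
    sigma = delta / 2.0  # two-sigma rule: [p_i - delta, p_i + delta] covers +/- 2 sigma
    raw = [math.exp(-((j - p_i) ** 2) / (2.0 * sigma ** 2)) for j in range(1, g_i + 1)]
    total = sum(raw)
    return [r / total for r in raw]  # sums to 1 over the available positions


# Usage: center at source position 3, six tokens available, offset 2.
print([round(w, 3) for w in gaussian_prior(p_i=3.0, g_i=6, delta=2.0)])
```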
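The model computes soft attention scores over the source encodings $\mathbf{h}_j$ as usual, with keys $\mathbf{k}_j = W^K \mathbf{h}_j$:

$$\alpha_{ij} = \frac{\exp\left(\mathbf{q}_i^{\top}\mathbf{k}_j / \sqrt{d_k}\right)}{\sum_{m=1}^{g_i} \exp\left(\mathbf{q}_i^{\top}\mathbf{k}_m / \sqrt{d_k}\right)}$$

The unnormalized posterior attention is then the pointwise product of prior and attention:

$$\tilde{\beta}_{ij} = \alpha_{ij}\, G_{ij}$$

The final attention weights are normalized over the available positions:

$$\beta_{ij} = \frac{\tilde{\beta}_{ij}}{\sum_{m=1}^{g_i} \tilde{\beta}_{im}}$$

The attended context vector is:

$$\mathbf{c}_i = \sum_{j=1}^{g_i} \beta_{ij}\, W^V \mathbf{h}_j$$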
This mechanism tightly integrates learned alignment prediction with translation via the attention mechanism, guiding the model's focus to the "most informative" source positions for each target token.
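The prior-times-attention step for a single head can be sketched as below. It is an illustration under stated assumptions: `gaussian_posterior_attention` is a hypothetical name, the inputs are plain lists, the keys and values are assumed to be already projected ($W^K\mathbf{h}_j$ and $W^V\mathbf{h}_j$), and the prior is passed in for the same $g_i$ positions.

```python
import math
from typing import List, Sequence, Tuple


def softmax(xs: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gaussian_posterior_attention(
    query: Sequence[float],
    keys: Sequence[Sequence[float]],    # projected keys k_j = W^K h_j, j = 1..g_i
    values: Sequence[Sequence[float]],  # projected values W^V h_j, j = 1..g_i
    prior: Sequence[float],             # Gaussian prior G_ij over the same g_i positions
) -> Tuple[List[float], List[float]]:
    """Return (beta_i, c_i): prior-weighted attention weights and the context vector."""
    d_k = len(query)
    scores = [sum(a * b for a, b in zip(query, k)) / math.sqrt(d_k) for k in keys]
    alpha = softmax(scores)                         # ordinary soft attention
    unnorm = [a * g for a, g in zip(alpha, prior)]  # pointwise prior x attention
    z = sum(unnorm)
    beta = [u / z for u in unnorm]                  # renormalize over positions
    dim = len(values[0])
    context = [sum(b * v[d] for b, v in zip(beta, values)) for d in range(dim)]
    return beta, context
```

With a uniform prior the weights reduce to ordinary soft attention, and with a one-hot prior they collapse onto a single source position; the Gaussian prior interpolates between these two extremes.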
4. Multi-head Extension and Layer Interdependency
For the $H$ attention heads in each decoder layer, the alignment center $p_i$ and the quantities derived from it ($\Delta p_i$ and $G_{ij}$) are shared rather than head-specific. Across layers, the predictions $p_i^{(l)}$, $l = 1, \dots, L$, are independent. The global read position for emitting the next target token is set to the maximum required across all layers:

$$g_i = \max_{1 \le l \le L} \left\lceil p_i^{(l)} \right\rceil + \delta$$

where $\delta$ is a user-tunable relaxation offset that accommodates minor misalignment or anticipation in practical settings. The decoder proceeds only when the stream has delivered at least $g_i$ source tokens. Because every $p_i^{(l)}$ only increases, $g_i$ never decreases, so reading is monotonic, and waiting for $g_i$ tokens ensures that all decoder states are computed over available input (Zhang et al., 2022).
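The read rule fits in a small hypothetical helper; `read_position` and its optional `src_len` cap (useful once the whole source has been streamed) are assumptions for illustration. For three layers with centers 1.2, 2.7, and 2.0 and an offset of 1, the second layer dictates the wait: $\lceil 2.7 \rceil + 1 = 4$ tokens.

```python
import math
from typing import Optional, Sequence


def read_position(layer_centers: Sequence[float], delta: int,
                  src_len: Optional[int] = None) -> int:
    """g_i = max over layers of ceil(p_i^(l)), plus the relaxation offset delta."""
    g = max(math.ceil(p) for p in layer_centers) + delta
    if src_len is not None:  # hypothetical cap once the full source length is known
        g = min(g, src_len)
    return g


print(read_position([1.2, 2.7, 2.0], delta=1))
```

Taking the maximum means the layer that needs the furthest lookahead sets the pace for the whole decoder, so no layer ever attends beyond what has been read.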
5. Simultaneous Translation Policy
This architecture directly operationalizes an alignment-guided, monotonic simultaneous translation policy. The procedure for each target token is:
- Predict $\Delta p_i^{(l)}$ and update $p_i^{(l)}$ in each layer $l$.
- Calculate the read position $g_i$.
- Wait until the streaming input has provided $g_i$ source tokens.
- Compute the Gaussian prior, combine it with the attention scores, normalize, form the context vector, and output the target token $y_i$.
- Repeat until the end-of-sequence symbol.
This deterministic policy removes the need for an auxiliary agent-style controller, integrating translation and input consumption within one unified, differentiable mechanism.
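To see how these steps turn into READ and WRITE actions, the hypothetical helper `simulate_policy` replays the policy from precomputed per-layer centers. In the real model the centers are produced by the decoder at each step; here they are toy numbers, and the cap at a known source length and the omission of end-of-sequence handling are simplifying assumptions.

```python
import math
from typing import List, Sequence


def simulate_policy(step_centers: Sequence[Sequence[float]], delta: int,
                    src_len: int) -> List[str]:
    """Return the READ/WRITE actions implied by per-step, per-layer alignment centers.

    step_centers[i][l] is a hypothetical center p_i^(l) for target step i and layer l.
    """
    actions: List[str] = []
    read = 0
    for centers in step_centers:
        g = min(max(math.ceil(p) for p in centers) + delta, src_len)
        while read < g:          # wait: consume source tokens until g_i have arrived
            actions.append("READ")
            read += 1
        actions.append("WRITE")  # emit y_i once its required prefix is available
    return actions


# Two decoder layers, three target steps, offset 1, five source tokens.
print(simulate_policy([[0.8, 1.1], [1.9, 2.4], [3.2, 2.9]], delta=1, src_len=5))
# -> ['READ', 'READ', 'READ', 'WRITE', 'READ', 'WRITE', 'READ', 'WRITE']
```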
6. Training Objective and Differentiability
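Training is conducted end-to-end using the standard cross-entropy loss:

$$\mathcal{L}_{\mathrm{ce}} = -\sum_{i=1}^{|\mathbf{y}|} \log p\left(y_i \mid \mathbf{y}_{<i}, \mathbf{x}_{\le g_i}\right)$$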
No additional loss terms for alignment or latency are used. Because the alignment centers, the Gaussian priors, and the final attention weights are all differentiable functions of the predicted increments, every parameter, including those of the alignment predictor, is trained by backpropagating the translation loss alone (Zhang et al., 2022). This gives the model a soft inductive bias toward meaningful alignments without requiring supervised alignments or reinforcement-style learning of emission timing.
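To make the differentiability argument concrete, the following worked derivation (an illustration added here, not taken from the paper) differentiates the final attention weight $\beta_{ij} = \alpha_{ij} G_{ij} / \sum_{m} \alpha_{im} G_{im}$ with respect to the alignment center $p_i$, where $\alpha_{ij}$ are the softmax attention scores and $G_{ij}$ is the Gaussian prior of width $\sigma$. It holds the read window $g_i$ and the scores $\alpha_{ij}$ fixed. Because $\beta_{ij}$ is unchanged when the prior is rescaled, the unnormalized Gaussian can be used directly:

$$\log \beta_{ij} = \log \alpha_{ij} - \frac{(j - p_i)^2}{2\sigma^2} - \log \sum_{m=1}^{g_i} \alpha_{im} \exp\!\left(-\frac{(m - p_i)^2}{2\sigma^2}\right)$$

$$\frac{\partial \beta_{ij}}{\partial p_i} = \beta_{ij}\,\frac{j - \bar{j}_i}{\sigma^2}, \qquad \bar{j}_i = \sum_{m=1}^{g_i} \beta_{im}\, m$$

Increasing $p_i$ therefore shifts attention mass toward positions to the right of the attention-weighted mean position $\bar{j}_i$. Assuming the increment is parameterized as $\Delta p_i = \exp(u_i)$ with $u_i = \mathbf{v}_p^{\top}\tanh(W_p\,\mathbf{q}_i)$ and $p_i = p_{i-1} + \Delta p_i$, the gradient reaches the alignment predictor through $\partial p_i / \partial \Delta p_k = 1$ for every $k \le i$ and $\partial \Delta p_i / \partial u_i = \Delta p_i$. The ceiling inside $g_i$ is piecewise constant and contributes no gradient, so the alignment parameters are learned through the prior.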
7. Context, Applications, and Further Implications
Gaussian multi-head attention was introduced to address limitations in SiMT, providing unified and explicit control over alignment and translation latency. Previous methods lacked continuous, differentiable modeling of alignment, often treating emission policies as external or relying on rigid synchrony. The GMA decoder integrates alignment directly into cross-attention, supporting deterministic, monotonic, and alignment-aware decoding essential for low-latency streaming translation.
A plausible implication is broader applicability in other contexts where explicit control of source-target alignment, streaming policies, or monotonic attention is desired—extending beyond translation to speech recognition, summarization, or real-time interactive systems.
The model's decomposition—explicit yet differentiable alignment, Gaussian soft priors, and per-layer shared predictions—supports both architectural interpretability and operational efficiency, as demonstrated empirically on English–Vietnamese and German–English translation benchmarks, where the approach outperforms strong baselines in balancing translation quality and latency (Zhang et al., 2022).