Manifold Token Mixup for Text Classification
- Manifold Token Mixup is an interpolation-based augmentation technique that swaps token features to enhance deep text classifiers, especially in low-resource settings.
- It operates by applying a binary mask at a randomly selected layer to partially swap hidden representations between two input samples.
- Empirical evaluations show that MSMix variants, particularly MSMix-A and MSMix-B, improve classification accuracy by up to roughly 1.6 percentage points over baseline models in small-sample settings.
Manifold Token Mixup (MSMix) is an interpolation-based data augmentation technique for deep neural network text classifiers, designed to address performance degradation under limited data regimes. MSMix operates by inducing partial swaps of token-level hidden feature representations between two samples at a randomly selected layer within a neural architecture. It augments the sample space not through linear blending but via feature-wise replacement, offering distinct theoretical and empirical benefits for robust text classification tasks.
1. Formal Mathematical Framework
Let $f$ be an $L$-layer text classifier, typified by models such as a 12-layer BERT. The hidden-state tensor at layer $\ell$ for input $x$ is denoted $h^{\ell}(x) \in \mathbb{R}^{T \times d}$, where $T$ is the token length and $d$ is the hidden dimension. Given two randomly sampled pairs $(x_i, y_i)$ and $(x_j, y_j)$, draw a mixing coefficient $\lambda \sim \mathrm{Beta}(\alpha, \alpha)$. Traditional manifold mixup would produce:

$$\tilde{h} = \lambda\, h^{\ell}(x_i) + (1-\lambda)\, h^{\ell}(x_j) \quad \text{and} \quad \tilde{y} = \lambda\, y_i + (1-\lambda)\, y_j.$$

MSMix instead constructs a binary mask $M \in \{0,1\}^{T \times d}$ with exactly $k = \lfloor (1-\lambda)\, T d \rfloor$ zeros. MSMix's swap-mixup operation is:

$$\tilde{h} = M \odot h^{\ell}(x_i) + (1 - M) \odot h^{\ell}(x_j),$$

where $\odot$ denotes element-wise multiplication. For each zero position in $M$, the corresponding feature value in $h^{\ell}(x_i)$ is replaced by that of $h^{\ell}(x_j)$, preserving the remainder.
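To make the swap concrete, here is a toy PyTorch example with $T = 2$ tokens and $d = 3$ hidden dimensions (the tensor values are illustrative only):

```python
import torch

# Toy swap-mixup: T = 2 tokens, d = 3 hidden dimensions.
h_i = torch.tensor([[1., 2., 3.],
                    [4., 5., 6.]])
h_j = torch.tensor([[10., 20., 30.],
                    [40., 50., 60.]])

# Binary mask with k = 2 zeros: positions (0, 1) and (1, 2) are swapped in from h_j.
M = torch.tensor([[1., 0., 1.],
                  [1., 1., 0.]])

h_mix = M * h_i + (1 - M) * h_j
# tensor([[ 1., 20.,  3.],
#         [ 4.,  5., 60.]])  -> only the masked-out positions come from h_j
```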
2. Dimension-Selection Strategies
MSMix implements three dimension-selection strategies to govern which features are swapped:
- MSMix-Base (Random Selection): the $k$ positions within the $T \times d$ tensor are selected uniformly at random.
- MSMix-A (Correlated-Magnitude Selection): the element-wise product magnitude $\left| h^{\ell}(x_i) \odot h^{\ell}(x_j) \right|$ is computed and flattened, and the $k$ positions with the highest values are chosen for replacement.
- MSMix-B (Low-Importance to High-Importance Selection):
  - Identify the index set $S$ of the $m$ smallest values in $\left| h^{\ell}(x_i) \right|$, with $m \geq k$.
  - Among these, select the indices of the $k$ largest values of $\left| h^{\ell}(x_j) \right|$; call this set $S^{*}$.
  - Construct $M$ with zeros at the indices in $S^{*}$ and ones elsewhere.
These methods introduce structure into the mixup process, enhancing diversity of hidden representations relative to uniformly random swaps.
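The following PyTorch sketch implements the three mask constructions. It assumes hidden tensors of shape $(T, d)$, takes the zero count $k$ as an argument, and defaults MSMix-B's candidate-pool size to $m = 2k$ (an assumption, since the exact pool size is not given above); the helper name `build_mask` is illustrative.

```python
import torch

def build_mask(h_i, h_j, k, strategy="base", m=None):
    """Construct a binary mask of shape (T, d) with exactly k zeros.

    strategy:
      "base" - k positions chosen uniformly at random (MSMix-Base)
      "a"    - the k positions with the largest |h_i * h_j| (MSMix-A)
      "b"    - among the m lowest-magnitude positions of h_i, the k positions
               where |h_j| is largest (MSMix-B); m = 2 * k by default here,
               which is an assumption rather than a value from the paper
    """
    T, d = h_i.shape
    n = T * d
    mask = torch.ones(n, device=h_i.device)

    if strategy == "base":
        zero_idx = torch.randperm(n, device=h_i.device)[:k]
    elif strategy == "a":
        score = (h_i * h_j).abs().flatten()
        zero_idx = score.topk(k).indices
    elif strategy == "b":
        m = 2 * k if m is None else m
        low_i = h_i.abs().flatten().topk(m, largest=False).indices  # low-importance in x_i
        top_j = h_j.abs().flatten()[low_i].topk(k).indices          # high-importance in x_j
        zero_idx = low_i[top_j]
    else:
        raise ValueError(f"unknown strategy: {strategy}")

    mask[zero_idx] = 0.0
    return mask.view(T, d)
```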
3. Step-by-Step MSMix Training Procedure
The MSMix algorithm proceeds as follows for each training iteration:
- Sample two examples $(x_i, y_i)$ and $(x_j, y_j)$ from the dataset $\mathcal{D}$.
- Draw a mixing coefficient $\lambda \sim \mathrm{Beta}(\alpha, \alpha)$.
- Select a mixing layer $\ell \in \{1, \dots, L\}$ at random (random selection is empirically superior to a fixed layer).
- Execute the forward pass for $x_i$ and $x_j$ up to layer $\ell$, yielding $h^{\ell}(x_i)$ and $h^{\ell}(x_j)$.
- Generate the mask $M$ with $k$ zeros according to the desired strategy (Base/A/B).
- Form the mixed hidden tensor $\tilde{h} = M \odot h^{\ell}(x_i) + (1-M) \odot h^{\ell}(x_j)$ as above.
- Continue the forward pass from layer $\ell+1$ to $L$ with $\tilde{h}$ to obtain the model output $\hat{y}$.
- Compose the mixed label $\tilde{y} = \lambda\, y_i + (1-\lambda)\, y_j$.
- Calculate the cross-entropy loss $\mathcal{L}(\hat{y}, \tilde{y})$ and backpropagate through all layers.
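Putting these steps together, here is a minimal single-pair training-step sketch. Batching, attention masks, and the optimizer step are omitted; `embed`, `encoder_layers`, `classifier`, `alpha = 0.2`, and the `build_mask` helper from the previous sketch are illustrative assumptions, and `y_i`, `y_j` are one-hot label vectors.

```python
import random
import torch
import torch.nn.functional as F

def msmix_step(embed, encoder_layers, classifier, x_i, x_j, y_i, y_j,
               alpha=0.2, strategy="base"):
    """One MSMix training step for a single example pair (batching omitted)."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    mix_layer = random.randint(1, len(encoder_layers))   # random mixing layer

    h_i, h_j = embed(x_i), embed(x_j)                    # each of shape (T, d)
    for layer in encoder_layers[:mix_layer]:             # forward both inputs to the mix layer
        h_i, h_j = layer(h_i), layer(h_j)

    T, d = h_i.shape
    k = int(round((1 - lam) * T * d))                    # number of swapped positions
    M = build_mask(h_i, h_j, k, strategy)
    h = M * h_i + (1 - M) * h_j                          # swap-mixup

    for layer in encoder_layers[mix_layer:]:             # continue with the mixed states
        h = layer(h)

    logits = classifier(h[0])                            # e.g. a head over the [CLS] state
    y_mix = lam * y_i + (1 - lam) * y_j                  # soft mixed label
    loss = F.cross_entropy(logits, y_mix)                # soft-target cross-entropy
    loss.backward()
    return loss
```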
4. Theoretical Analysis and Regularization Perspective
Each hidden dimension acts as a virtual feature, and the swapping mechanism introduces correlated noise, distinct from additive Gaussian noise or stochastic dropout techniques. This approach can be interpreted as an instantiation of Dropout/DropConnect at the level of semantic token features rather than indiscriminate neuron or weight deactivation. MSMix leverages manifold mixup theory (cf. Verma et al. 2019), whereby linear interpolations within hidden space flatten decision boundaries and combat overfitting. By swapping finite, localized subsets of dimensions, MSMix preserves the discrete character of textual features, crucial for NLP tasks.
This methodology is particularly effective in low-resource regimes: the creation of 'in-between' hidden states impedes memorization of anomalous samples, providing strong regularization and increased generalization.
5. Empirical Evaluation and Key Findings
MSMix was evaluated on three Chinese intent classification datasets: YiwiseIC (12 classes), SMP2017-ECDT (31 classes), and CrossWOZ-IC (intent classification extracted from a task-oriented dialogue corpus), along with reduced-size splits of each. Full-dataset test accuracy (%):
| Model | YiwiseIC | SMP2017 | CrossWOZ |
|---|---|---|---|
| simBERT | 91.79 | 95.05 | 95.14 |
| EDA | 90.28 | 93.40 | 95.12 |
| Mixup-Transformer | 92.50 | 94.45 | 95.38 |
| TMix | 92.61 | 95.35 | 95.63 |
| MSMix-Base | 94.09 | 95.36 | 95.64 |
| MSMix-A | 93.53 | 95.80 | 95.87 |
| MSMix-B | 94.30 | 95.35 | 95.75 |
Under the small-sample ("_FS") splits, MSMix-A and MSMix-B deliver improvements of up to roughly 1.6 percentage points over the simBERT baseline (test accuracy, %):
| Model | YiwiseIC_FS | SMP2017_FS | CrossWOZ_FS |
|---|---|---|---|
| simBERT | 81.17 | 89.81 | 91.12 |
| EDA | 79.73 | 90.85 | 91.77 |
| Mixup-Transformer | 80.04 | 90.85 | 90.65 |
| TMix | 81.40 | 90.10 | 92.00 |
| MSMix-Base | 81.98 | 90.40 | 92.00 |
| MSMix-A | 82.80 | 90.85 | 92.62 |
| MSMix-B | 82.54 | 91.45 | 92.55 |
MSMix-B yields the highest full-data test accuracy on YiwiseIC (94.30%), and MSMix-A on SMP2017 (95.80%). On the full datasets, all MSMix variants match or surpass the mixup-at-output (Mixup-Transformer), TMix, and EDA baselines.
6. Architectural and Implementation Details
Experiments employed simbert-base-chinese (a 12-layer BERT) with hidden dimension $d = 768$ and a fixed maximum token length $T$. MSMix was compared against EDA (eight augmentations per sample), Mixup-Transformer, and TMix. A standard Adam optimizer with batch size 32 was used for three to five epochs of fine-tuning. The mixing coefficient $\lambda$ is drawn from the conventional $\mathrm{Beta}(\alpha, \alpha)$ distribution.
Selecting the mixing layer at random for each batch, rather than fixing a single layer, was empirically favored.
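For the layer-split forward pass, a minimal sketch using the Hugging Face transformers API is shown below; `bert-base-chinese` is used as a stand-in for the simbert-base-chinese checkpoint, the example sentence and `mix_layer` value are illustrative, and the attention mask is omitted for brevity.

```python
import torch
from transformers import BertModel, BertTokenizerFast

# "bert-base-chinese" stands in for the paper's simbert-base-chinese checkpoint;
# both are 12-layer BERT-base encoders with d = 768.
model = BertModel.from_pretrained("bert-base-chinese")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")

enc = tokenizer("打开空调", return_tensors="pt")
hidden = model.embeddings(input_ids=enc["input_ids"])     # shape (1, T, 768)

mix_layer = 5                                             # example per-batch random choice
for block in model.encoder.layer[:mix_layer]:             # run the first `mix_layer` blocks
    hidden = block(hidden)[0]                             # attention mask omitted for brevity
# `hidden` is h^{mix_layer}(x); apply the swap-mixup here, then finish with
# model.encoder.layer[mix_layer:] as in the training sketch above.
```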
7. Significance and Implications
MSMix provides an efficient, implementation-light method for augmenting text classification datasets at the hidden-state level, focusing on partial feature swaps rather than full-sample interpolation. This preserves local feature integrity, improves accuracy in both full-data and low-resource regimes, and establishes the viability of swap-based mixup methods for deep NLP models (Ye et al., 2023). A plausible implication is that MSMix strategies may generalize to other sequence encoding architectures and tasks where discrete token feature integrity is pivotal.