LLaMA-X Decoder for Visual Recognition
- LLaMA-X Decoder is a decoder-only Transformer adapted from LLaMA for visual recognition, employing a post-sequence [CLS] token to aggregate global image features under causal masking.
- It integrates a soft-mask warmup schedule that smoothly transitions from bidirectional to causal masking, stabilizing training and mitigating early optimization issues.
- Empirical results show competitive accuracy on ImageNet benchmarks along with improved calibration, higher attention-map rank, and efficient computation compared to encoder-only models.
LLaMA-X Decoder refers to a family of decoder-only Transformer architectures adapted from LLaMA, originally designed for language modeling, and repurposed for visual recognition through architectural and training strategies that overcome the inherent limitations of causal masking when applied to images. The canonical implementation, denoted image LLaMA (iLLaMA), integrates a post-sequence class token (PS [CLS]) mechanism and a soft-mask warmup schedule, enabling causal self-attention to function effectively in vision settings while delivering computational advantages and strong empirical performance across multiple benchmarks.
1. LLaMA Decoder-Only Transformer Architecture
The LLaMA decoder-only architecture is constructed exclusively from decoder blocks and features the following core components for an input token sequence $X = (x_1, \ldots, x_N)$, $x_i \in \mathbb{R}^d$ (a minimal sketch follows the list):
- Causal Self-Attention Layer:
- Query, Key, Value projections: $Q = XW_Q$, $K = XW_K$, $V = XW_V$
- Scaled dot-product attention: $\mathrm{Attn}(Q, K, V) = \mathrm{softmax}\!\big(QK^{\top}/\sqrt{d_k}\big)V$
- Causal masking: output $O = \mathrm{softmax}\!\big(QK^{\top}/\sqrt{d_k} + M\big)V$, where $M \in \mathbb{R}^{N \times N}$, with $M_{ij} = 0$ for $j \le i$, $M_{ij} = -\infty$ for $j > i$
- Feed-Forward Network (FFN):
- Gated structure using SwiGLU activation: $\mathrm{SiLU}(xW_1) \odot (xW_3)$
- Followed by a linear projection $W_2$, giving $\mathrm{FFN}(x) = \big(\mathrm{SiLU}(xW_1) \odot xW_3\big)W_2$
- Root-Mean-Square Layer Norm (RMSNorm): applied before each sublayer (pre-normalization) in place of standard LayerNorm.
- Rotary Positional Embeddings (RoPE): Incorporated within the attention mechanism.
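The following is a minimal PyTorch sketch of such a decoder block (pre-norm RMSNorm, causal self-attention, SwiGLU FFN). RoPE is omitted for brevity, and the class names and hidden-width choice are illustrative assumptions rather than a reference implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Root-mean-square normalization: rescale by 1/RMS(x), no mean subtraction or bias."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

class SwiGLUFFN(nn.Module):
    """Gated FFN: SiLU(x W1) * (x W3), followed by a linear projection W2."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class DecoderBlock(nn.Module):
    """Pre-norm block: RMSNorm -> causal self-attention -> RMSNorm -> SwiGLU FFN (RoPE omitted)."""
    def __init__(self, dim, num_heads):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = RMSNorm(dim)
        self.ffn = SwiGLUFFN(dim, 4 * dim)  # hidden width chosen for illustration

    def forward(self, x):                                   # x: (batch, N, dim)
        n = x.size(1)
        causal = torch.triu(torch.full((n, n), float("-inf"), device=x.device), diagonal=1)
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=causal)  # lower-triangular attention
        x = x + attn_out
        return x + self.ffn(self.norm2(x))
```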
While this architecture excels at autoregressive text generation, applying it naively to images (patches as tokens, with a class ([CLS]) token in the first position) causes catastrophic training collapse: causal masking prevents the class token from attending to any subsequent patch tokens, so no useful gradient flows through it and training fails.
2. Post-Sequence Class Token Strategy
To circumvent attention collapse under causal masking, the post-sequence class token (PS [CLS]) strategy positions the [CLS] token at the end of the patch sequence, not at the beginning:
- Input sequence: $X = (x_1, \ldots, x_N, x_{\mathrm{CLS}})$, with the [CLS] token placed at index $N+1$.
- Under causal masking, the [CLS] token at index $N+1$ has attention access to all preceding patch tokens ($x_1, \ldots, x_N$), permitting it to aggregate global image representations.
- No custom or hybrid masks are required—standard lower-triangular masks suffice.
Token Preparation Example:
```python
def prepare_tokens(image, cls_token):
    # Split the image into patch tokens, then append the [CLS] token
    # at the END of the sequence (post-sequence class token).
    patches = split_image_into_patches(image)   # N patch tokens
    sequence = concat(patches, cls_token)       # [CLS] placed last, at index N + 1
    return embed(sequence)
```
Positioning the class token at the head of the sequence under a causal mask restricts its receptive field to itself alone, whereas a PS [CLS] resolves this without architectural or masking exceptions.
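A quick illustration of why the standard lower-triangular mask suffices (the sequence length here is illustrative):

```python
import torch

N = 4  # patch tokens; the PS [CLS] token occupies the last index, N
mask = torch.triu(torch.full((N + 1, N + 1), float("-inf")), diagonal=1)

print(mask[-1])  # all zeros: the post-sequence [CLS] attends to every patch
print(mask[0])   # only the first entry is zero: a head-of-sequence [CLS] would see itself alone
```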
3. Causal Self-Attention and Soft-Mask Warmup
Causal attention employs the additive mask $M^{\mathrm{causal}}$ defined above ($M^{\mathrm{causal}}_{ij} = 0$ for $j \le i$, $-\infty$ for $j > i$).
To stabilize early-stage training, a soft mask interpolates between the bidirectional mask ($M^{\mathrm{bi}} = \mathbf{0}$) and the strict causal mask: $M^{\mathrm{soft}} = \alpha\, M^{\mathrm{bi}} + (1 - \alpha)\, M^{\mathrm{causal}}$, with the warmup scalar $\alpha$ decreasing from $1$ (bidirectional) to $0$ (strictly causal) over a schedule.
Alternatively, in the attention-weight domain: $A^{\mathrm{soft}} = \alpha\, A^{\mathrm{bi}} + (1 - \alpha)\, A^{\mathrm{causal}}$, with attention output $O = A^{\mathrm{soft}} V$, where $A^{\mathrm{bi}}$ and $A^{\mathrm{causal}}$ denote the softmax attention weights computed without and with the causal mask, respectively.
This mechanism smooths optimization, mitigating underfitting and preventing collapse during the convergence process of decoder-only architectures on visual data.
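A minimal sketch of the attention-weight-domain formulation above, together with a hypothetical `warmup_alpha` helper for the two schedule variants (function names and signatures are assumptions, not released code):

```python
import torch
import torch.nn.functional as F

def soft_masked_attention(q, k, v, alpha):
    """A_soft = alpha * A_bi + (1 - alpha) * A_causal, output O = A_soft @ V.
    alpha = 1 gives bidirectional attention, alpha = 0 strictly causal."""
    n, d = q.shape[-2], q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d ** 0.5
    causal = torch.triu(torch.full((n, n), float("-inf"), device=q.device), diagonal=1)
    a_bi = F.softmax(scores, dim=-1)               # bidirectional weights
    a_causal = F.softmax(scores + causal, dim=-1)  # causal weights
    return (alpha * a_bi + (1 - alpha) * a_causal) @ v

def warmup_alpha(epoch, cutoff_epoch, schedule="linear"):
    """Warmup scalar: linear decay to 0 by the cutoff epoch, or constant 1 until the cutoff."""
    if schedule == "linear":
        return max(0.0, 1.0 - epoch / cutoff_epoch)
    return 1.0 if epoch < cutoff_epoch else 0.0
```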
4. Training Protocols and Architectural Variants
The supervised training procedure with soft-mask warmup follows; a condensed training-loop sketch appears after the architectural variants below.
- Algorithm Steps:
- Total epochs $E$, soft-mask cutoff epoch $E_c$, base learning rate $\eta$, warmup epochs $E_w$, schedule type (linear/constant).
- At each epoch $e$:
- Compute $\alpha = \max\!\big(0,\, 1 - e/E_c\big)$ (linear), or $\alpha = 1$ for $e < E_c$, else $0$ (constant).
- Set $M^{\mathrm{soft}} = \alpha\, M^{\mathrm{bi}} + (1 - \alpha)\, M^{\mathrm{causal}}$.
- Forward: causal self-attention with $M^{\mathrm{soft}}$ and PS [CLS].
- Backward: update parameters via AdamW with weight decay $0.05$.
- Learning rate: cosine schedule from $\eta$ to $0$, with linear warmup for $E_w$ epochs.
- Training Hyperparameters for ImageNet-1K:
- Total epochs, batch size, base learning rate, and warmup settings are configured per model scale, with one recipe for the tiny/S/B models and a separate one for large-model pretraining.
- Data augmentations: RandAugment, Mixup (0.1–0.95), CutMix (0.1–1.0), label smoothing 0.1
- Architectural Variants: Four isotropic iLLaMA models are instantiated, mirroring ViT scaling.
| Model | Depth | Embedding Dim | Heads | #Params | MACs |
|---|---|---|---|---|---|
| Tiny (T) | 12 | 192 | 3 | 5.7 M | 1.3 G |
| Small (S) | 12 | 384 | 6 | 21.9 M | 4.6 G |
| Base (B) | 12 | 768 | 12 | 86.3 M | 17.6 G |
| Large (L) | 24 | 1024 | 16 | 310.2 M | 62.8 G |
Notable modifications relative to ViT include: SwiGLU FFN, RMSNorm layers, causal self-attention with PS [CLS] and rotary embeddings, and retention of learnable 2D positional embeddings.
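As a rough sketch of the training loop described in the algorithm steps above (the model, data loader, and `set_soft_mask_alpha` hook are hypothetical placeholders; only the weight decay and label smoothing values are taken from the recipe above):

```python
import math
import torch

def train(model, train_loader, epochs, cutoff_epoch, base_lr, warmup_epochs, device="cuda"):
    criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.05)

    for epoch in range(epochs):
        # Linear LR warmup, then cosine decay toward 0.
        if epoch < warmup_epochs:
            lr = base_lr * (epoch + 1) / warmup_epochs
        else:
            progress = (epoch - warmup_epochs) / max(1, epochs - warmup_epochs)
            lr = 0.5 * base_lr * (1 + math.cos(math.pi * progress))
        for group in optimizer.param_groups:
            group["lr"] = lr

        # Linear soft-mask warmup: alpha goes from 1 (bidirectional) to 0 (causal).
        alpha = max(0.0, 1.0 - epoch / cutoff_epoch)
        model.set_soft_mask_alpha(alpha)   # hypothetical hook exposing the soft mask

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            loss = criterion(model(images), labels)   # forward: causal attention + PS [CLS]
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```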
5. Empirical Performance across Tasks
iLLaMA exhibits competitive results compared to encoder-only ViTs:
- ImageNet-1K (224×224, supervised):
- iLLaMA-T: 75.0% top-1 (vs. DeiT-Ti: 72.2%)
- iLLaMA-S: 79.9% (vs. DeiT-S: 79.8%)
- iLLaMA-B: 81.6% (vs. DeiT-B: 81.8%)
- Fine-tuned at 384×384: iLLaMA-B → 83.0%
- ImageNet-21K Pretraining + 1K Finetuning:
- iLLaMA-B: 83.6% @224, 85.0% @384
- iLLaMA-L: 84.8% @224, 86.0% @384
- Model Calibration (Expected Calibration Error):
- ConvNeXt-B: 0.0281
- DeiT3-B: 0.0415
- iLLaMA-B: 0.0335
- Shape–Texture Bias (shape-preference, higher is better):
- ConvNeXt-B: 33.3%
- DeiT3-B: 39.9%
- iLLaMA-B: 41.5%
- Attention Map Rank (Layer 1, Head 1):
- ViT-T: rank ≈ 81; iLLaMA-T: rank ≈ 129
- The more uniform singular-value distribution of iLLaMA attention maps suggests elevated representational capacity.
- Task Transfer:
- ADE20K semantic segmentation (UperNet): iLLaMA-T: 37.7 mIoU (vs. ViT-T: 39.8); iLLaMA-B: 45.1 (vs. ViT-B: 47.3)
- CIFAR10/100: iLLaMA-T: 97.9%/84.8%; +soft mask: 97.9%/85.5%
- Quantization Robustness: 8-bit weights/activations yield iLLaMA-T at 72.4% top-1, matching DeiT-Ti (32-bit).
6. Computational Efficiency and Representational Properties
Computational analysis for self-attention (sequence length $N$, dimension $d$):
- Bidirectional Attention: the attention map and value aggregation are computed over all $N^2$ token pairs, i.e. $\mathcal{O}(N^2 d)$ FLOPs.
- Causal Attention: only the $N(N+1)/2$ lower-triangular pairs contribute, saving roughly half of the pairwise attention computation relative to the bidirectional case (see the sketch below).
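A back-of-the-envelope check of the savings from computing only the lower-triangular token pairs (the 14×14 patch grid size is illustrative):

```python
def causal_attention_savings(n):
    """Fraction of N x N token-pair computations skipped when only the
    lower triangle of the attention map is evaluated."""
    causal_pairs = n * (n + 1) // 2
    return 1 - causal_pairs / (n * n)

print(causal_attention_savings(196))  # 196 = 14 x 14 patches -> ~0.497, i.e. ~50% saved
```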
Elevated attention-map rank in iLLaMA indicates richer cross-token relationships. The soft-mask warmup smooths the optimization landscape, mitigating underfitting during supervised training of the decoder-only design on visual data.
7. Implications and Outlook
The iLLaMA architecture demonstrates that a decoder-only Transformer, originally developed for textual modalities, can function effectively as a vision backbone with minimal adaptation. The post-sequence class token addresses the collapse caused by naive causal masking. The soft-mask warmup improves training dynamics, yielding models that rival encoder-only ViTs in classification, calibration, and transfer, while securing computational advantages and higher attention-map rank without bespoke architectural exceptions. These findings indicate a viable pathway toward unified multimodal decoders in which both images and text are processed within a common LLaMA-style architecture (Wang et al., 10 Apr 2024).