ZAYA1 Model Architecture

Updated 25 November 2025
  • ZAYA1 model architecture is a mixture-of-experts transformer that integrates MI300X-aware tuning, custom convolutional attention, and expert routing to optimize large-scale training.
  • The design incorporates per-layer residual scaling, rotary embeddings, and specialized AMD-specific kernels to maximize throughput and minimize latency.
  • It achieves a competitive balance between dense and MoE components, yielding strong evaluation results across tasks like reasoning, mathematics, and coding.

The ZAYA1 model architecture is a mixture-of-experts (MoE) transformer designed for large-scale training on AMD MI300X GPUs with Pollara interconnect. ZAYA1-base incorporates a suite of systems and modeling innovations tailored to the AMD hardware stack, including MI300X-aware dimensioning, custom convolutional attention mechanisms, per-layer residual scaling, and expert routing. The architecture achieves a competitive balance of training throughput and inference latency with strong evaluation results across tasks, establishing the maturity of AMD’s distributed compute environment for state-of-the-art pretraining (Anthony et al., 21 Nov 2025).

1. Overall Model Structure

ZAYA1-base is built with $L = 40$ transformer layers and an embedding dimension of $h = 2048$. The vocabulary size is $v = 262{,}272$, chosen to be divisible by 64 for optimized device throughput. Each transformer layer contains an MoE block comprising $E = 16$ experts, with a top-$k = 1$ expert selected per token at each routing step. This yields $8.3$ billion total parameters (counting all experts) but an "active" parameter count of $760$ million (the dense backbone plus the one expert on each token's path).
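For concreteness, these dimensions can be gathered into one configuration object. The sketch below is illustrative only; the class and field names (Zaya1Config and its attributes) are assumptions, not identifiers from the released code.

```python
# Hypothetical config sketch collecting the ZAYA1-base dimensions reported above.
from dataclasses import dataclass

@dataclass
class Zaya1Config:
    n_layers: int = 40          # L
    d_model: int = 2048         # h
    vocab_size: int = 262_272   # v, divisible by 64
    n_experts: int = 16         # E
    top_k: int = 1              # experts activated per token
    n_query_heads: int = 8      # a_q
    n_kv_heads: int = 2         # g
    head_dim: int = 128         # d_h = h / a
    router_dim: int = 256       # D
    expert_ff: int = 4096       # f
    expert_ff_out: int = 2048   # f_o = f / 2

cfg = Zaya1Config()
assert cfg.vocab_size % 64 == 0 and cfg.d_model % 64 == 0
```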

The forward path through each transformer layer $\ell$ follows this sequence:

  1. Residual-scaled RMSNorm $\rightarrow$ Compressed Convolutional Attention (CCA) $\rightarrow$ residual add
  2. Residual-scaled RMSNorm $\rightarrow$ ZAYA1 Router gating $\rightarrow$ expert MLP (MoE) $\rightarrow$ residual add
  3. Final RMSNorm

Residual scaling is implemented on every residual path via per-channel learnable gates.
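The layer wiring above can be sketched in a few lines of PyTorch. This is a minimal, hedged illustration, assuming PyTorch 2.4 or newer for nn.RMSNorm; the module name Zaya1Layer is invented, the attention and MoE blocks are passed in as opaque modules, and the placement of the per-channel gate on the skip path is one plausible reading of the description, not a confirmed detail.

```python
# Minimal sketch of the per-layer forward path described above: pre-norm blocks
# with learnable per-channel residual gates. The bias term of the residual
# scaling (Section 6) is omitted here for brevity.
import torch
import torch.nn as nn

class Zaya1Layer(nn.Module):
    def __init__(self, h: int, attn: nn.Module, moe: nn.Module):
        super().__init__()
        self.norm_attn = nn.RMSNorm(h, eps=1e-5)
        self.norm_moe = nn.RMSNorm(h, eps=1e-5)
        self.attn, self.moe = attn, moe
        # Per-channel residual gates (see "Residual Scaling" in Section 6).
        self.gate_attn = nn.Parameter(torch.ones(h))
        self.gate_moe = nn.Parameter(torch.ones(h))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One plausible gate placement: scale the skip connection, then add
        # the block output. The paper's exact placement may differ.
        x = self.gate_attn * x + self.attn(self.norm_attn(x))   # step 1: CCA
        x = self.gate_moe * x + self.moe(self.norm_moe(x))      # step 2: MoE
        return x
```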

2. Transformer Layer Components

Attention and Token Path

CCA attention receives input $x_\ell \in \mathbb{R}^{B \times S \times h}$ and projects it to queries, keys, and values with the following details:

  • $a = 16$ total attention heads, each with head dimension $d_h = h / a = 128$
  • Query heads: $a_q = 8$ ($c_q = 1/2$)
  • Key/value heads: $g = 2$ ($c_{kv} = 1/8$)

Projections:

  • $W_Q \in \mathbb{R}^{h \times (a_q \cdot d_h)}$
  • $W_K, W_{V1}, W_{V2}$ with analogous dimensions (with $V_1$ and $V_2$ each handling half of the key/value width)

CCA then applies a convolutional stage:

  • Depthwise conv1d ($k_0 = 2$) plus grouped conv1d (groups $= a_q + g$, $k_1 = 2$) along the sequence dimension
  • FlashAttention operates in a compressed latent space of size $a_q \cdot d_h$
  • Rotary position embeddings (RoPE) are applied to half of each head's channels, supporting 4k–1M context extension

Outputs are projected back via $W_O \in \mathbb{R}^{(a_q \cdot d_h) \times h}$, followed by RMSNorm (with $\epsilon = 10^{-5}$) and a per-head key temperature.
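A shape-level sketch of this token path is given below, using standard PyTorch building blocks (nn.Conv1d, F.scaled_dot_product_attention) in place of the fused MI300X CCA kernel. The grouped conv1d, the V1/V2 split, RoPE, and the per-head key temperature are omitted, and the depthwise convolution is applied to the queries only; treat this as an illustration of the dimensions, not the actual implementation.

```python
# Shape-level sketch of the CCA projections (queries compressed 2x, keys/values
# 8x) followed by grouped-query attention in the compressed latent space.
import torch
import torch.nn as nn
import torch.nn.functional as F

B, S, h = 2, 1024, 2048
a_q, g, d_h = 8, 2, 128                        # query heads, kv heads, head dim

w_q = nn.Linear(h, a_q * d_h, bias=False)      # 2048 -> 1024
w_k = nn.Linear(h, g * d_h, bias=False)        # 2048 -> 256
w_v = nn.Linear(h, g * d_h, bias=False)        # 2048 -> 256
w_o = nn.Linear(a_q * d_h, h, bias=False)      # 1024 -> 2048

# Depthwise conv1d with kernel size 2 along the sequence (causal left padding).
conv_q = nn.Conv1d(a_q * d_h, a_q * d_h, kernel_size=2, groups=a_q * d_h)

x = torch.randn(B, S, h)
q = conv_q(F.pad(w_q(x).transpose(1, 2), (1, 0))).transpose(1, 2)
k, v = w_k(x), w_v(x)

q = q.view(B, S, a_q, d_h).transpose(1, 2)     # (B, a_q, S, d_h)
k = k.view(B, S, g, d_h).transpose(1, 2)       # (B, g, S, d_h)
v = v.view(B, S, g, d_h).transpose(1, 2)
k = k.repeat_interleave(a_q // g, dim=1)       # grouped-query sharing of K/V
v = v.repeat_interleave(a_q // g, dim=1)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
y = w_o(out.transpose(1, 2).reshape(B, S, a_q * d_h))   # back to (B, S, h)
```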

MoE and Routing

MoE routing in ZAYA1 involves the following operations for each token:

  • Down-projection: $W_{\text{down}} \in \mathbb{R}^{h \times D}$, where $D = 256$
  • Exponential Depth Averaging (EDA): $r_\ell = W_{\text{down}} x_\ell + \gamma \cdot r_{\ell-1}$ (with a learned scalar $\gamma$)
  • The result feeds a 3-layer MLP (GeLU activations) whose final projection $W_{\text{gate},\ell} \in \mathbb{R}^{D \times E}$ yields the expert logits
  • Post-softmax, each token's expert is selected as $e_{\text{idx}} = \arg\max_j (s_\ell + b_\ell)_j$, with a per-expert bias vector $b_\ell \in \mathbb{R}^E$ (see the routing sketch after this list)
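The routing path can be sketched as follows. The LayerNorm-before-GeLU placement follows the normalization note in Section 6; the module names, the initialization of $\gamma$, and how the bias $b_\ell$ is updated are illustrative assumptions rather than details from the paper.

```python
# Hedged sketch of the ZAYA1 routing path: down-projection, exponential depth
# averaging (EDA) across layers, a small 3-layer MLP producing expert logits,
# and biased top-1 selection.
import torch
import torch.nn as nn

h, D, E = 2048, 256, 16

class Zaya1Router(nn.Module):
    def __init__(self):
        super().__init__()
        self.down = nn.Linear(h, D, bias=False)
        self.gamma = nn.Parameter(torch.tensor(0.5))   # learned EDA scalar (init is a placeholder)
        self.mlp = nn.Sequential(                      # 3-layer router MLP, LayerNorm before GeLU
            nn.Linear(D, D), nn.LayerNorm(D), nn.GELU(),
            nn.Linear(D, D), nn.LayerNorm(D), nn.GELU(),
            nn.Linear(D, E),                            # expert logits (W_gate)
        )
        self.bias = nn.Parameter(torch.zeros(E))        # per-expert bias b_l (update rule not specified here)

    def forward(self, x, r_prev=None):
        r = self.down(x)
        if r_prev is not None:
            r = r + self.gamma * r_prev                 # EDA recurrence over layers
        scores = torch.softmax(self.mlp(r), dim=-1)
        expert_idx = torch.argmax(scores + self.bias, dim=-1)   # top-1 selection
        return expert_idx, scores, r                    # r is passed to the next layer's router
```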

The chosen expert’s MLP has the following weights (a sketch of the expert block follows this list):

  • First FC: $W_{e1} \in \mathbb{R}^{h \times f}$, with $f = 4096$ (hidden expansion factor $\alpha = 2$)
  • Activation: SwiGLU across the pre-activation width $f$
  • Second FC: $W_{e2} \in \mathbb{R}^{f_o \times h}$, where $f_o = f / 2 = 2048$
  • Followed by residual addition and RMSNorm
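A minimal PyTorch rendering of one expert, following the dimensions above ($h = 2048$, $f = 4096$, $f_o = 2048$); bias-free linear layers are an assumption, and the residual add and RMSNorm live outside this module.

```python
# Sketch of a single expert MLP: one fused up-projection of width f = 4096,
# a SwiGLU gate that halves the width to f_o = 2048, and a down-projection
# back to h. Illustrative PyTorch, not the MI300X kernel.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Zaya1Expert(nn.Module):
    def __init__(self, h: int = 2048, f: int = 4096):
        super().__init__()
        self.fc1 = nn.Linear(h, f, bias=False)        # W_e1: h x f
        self.fc2 = nn.Linear(f // 2, h, bias=False)   # W_e2: f_o x h, f_o = f / 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = self.fc1(x).chunk(2, dim=-1)    # SwiGLU over the width-f pre-activation
        return self.fc2(F.silu(gate) * value)
```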

3. MI300X-Aware Sizing Principles

The architecture’s sizing rules and GEMM shapes are directly informed by MI300X hardware characteristics:

  • All core dimensions ($h, d_h, D, f, f_o$) are set as multiples of 64, maximizing rocBLAS/hipBLASLt performance
  • The microbatch product $b \cdot s \cdot h$ is divisible by 64, and $(b \cdot a) / t$ is an integer, avoiding padding overhead
  • The MLP expansion factor is fixed ($f = 2h$, $f_o = h$)
  • MoE per-layer parameter count: $h \cdot f + f_o \cdot h$
  • Convolutional and attention kernel/GEMM shapes, e.g., $2048 \times 1024$, are chosen from MI300X TFLOPs heatmaps to maximize utilization

These practices are derived from explicit MI300X benchmarking, targeting “hot” performance regions for compute and memory transfers.
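These sizing rules reduce to simple divisibility checks, as in the plain-Python sketch below; the microbatch values ($b$, $s$, $t$) are made-up examples, not configurations reported in the paper.

```python
# Quick check that the core ZAYA1 dimensions satisfy the multiple-of-64 rule
# and that example microbatch settings avoid padding overhead.
dims = {"h": 2048, "d_h": 128, "D": 256, "f": 4096, "f_o": 2048, "v": 262_272}
assert all(d % 64 == 0 for d in dims.values()), "all core dims should be multiples of 64"

b, s, a, t = 4, 4096, 16, 8        # example microbatch, sequence, heads, tensor-parallel degree
assert (b * s * dims["h"]) % 64 == 0   # microbatch product divisible by 64
assert (b * a) % t == 0                # (b * a) / t must be an integer
```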

4. AMD-Specific Kernels and Communication

The model stack incorporates several AMD-specific optimizations:

| Component | Optimization / Detail |
|---|---|
| CCA conv kernels | Tuned for MI300X HBM3 bandwidth and warp size |
| Custom HIP kernels | Multi-tensor Muon optimizer kernels; fused residual-add + RMSNorm kernels (two-stage) |
| Communication | Gradient-fusion buffer sizes chosen to saturate Pollara 400 Gbps at break-even; ZeRO-1/context-parallel worlds aligned to xGMI hardware node boundaries |

The optimization of collective communication primitives (all-reduce, reduce-scatter, all-gather, broadcast) as well as kernel fusion is critical for training throughput on MI300X + Pollara platforms.
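To illustrate why fusion-buffer sizing matters, the sketch below applies a standard alpha-beta communication model to a 400 Gbps link. The fixed per-collective overhead (ALPHA_S) is an assumed placeholder rather than a measured Pollara value, and the paper does not prescribe this model; the point is only that buckets must be large enough for bandwidth, not launch overhead, to dominate.

```python
# Back-of-the-envelope alpha-beta estimate of gradient-fusion bucket efficiency
# on a 400 Gbps link. ALPHA_S is an assumption, not a value from the paper.
LINK_GBPS = 400
BYTES_PER_SEC = LINK_GBPS * 1e9 / 8
ALPHA_S = 20e-6                      # assumed fixed per-collective overhead (20 us)

def efficiency(bucket_bytes: float) -> float:
    """Fraction of peak link bandwidth achieved for one fused bucket."""
    transfer = bucket_bytes / BYTES_PER_SEC
    return transfer / (ALPHA_S + transfer)

for mb in (1, 4, 16, 64, 256):
    print(f"{mb:4d} MiB bucket -> {efficiency(mb * 2**20):.1%} of peak")
```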

5. Parameter and Compute Profile

Per-layer parameter and FLOPs breakdown, with $t = 1$ (no tensor/data parallelism) (Anthony et al., 21 Nov 2025):

| Component | Parameter Count (per layer) | FLOPs per token (approx.) |
|---|---|---|
| Attention Q, K, V, O | ≈ 5.2 M | ≈ 9.5k |
| CCA convs + RoPE | | ≈ 0.02 GFLOPs |
| Router down-proj | ≈ 0.52 M | ≈ 1.05k |
| Router MLP (2 layers) | ≈ 0.13 M | ≈ 0.13k |
| Router logits | ≈ 0.004 M | ≈ 4.1k |
| Expert FC1 | ≈ 8.39 M | ≈ 16.8k |
| Expert FC2 | ≈ 4.19 M | ≈ 8.4k |
| Residual scaling | ~0.004 M (negligible) | ~0.1k |

Total per-layer parameters: ≈ 18.4 M. Total per-layer FLOPs per token: ≈ 36k. A forward pass over $S = 1024$ tokens with $b = 1$ totals ≈ 37 M FLOPs per layer; across all 40 layers this gives ≈ 1.5 G FLOPs per sample. Inference latency is dominated by the expert MLPs (60%), attention kernels (30%), and routing/norms (10%).
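The parameter figures in the table can be reproduced directly from the dimensions in Sections 1 and 2, as in the plain-Python sketch below; the small CCA convolution and residual-scaling terms are omitted, so the computed total lands at roughly 18.5 M rather than exactly the reported ≈ 18.4 M.

```python
# Reproduce the per-layer parameter counts from h=2048, a_q=8, g=2, d_h=128,
# D=256, E=16, f=4096, f_o=2048.
h, a_q, g, d_h = 2048, 8, 2, 128
D, E, f, f_o = 256, 16, 4096, 2048

params = {
    "attention Q,K,V,O": h * (a_q * d_h) + h * (g * d_h) + h * (g * d_h) + (a_q * d_h) * h,
    "router down-proj":  h * D,
    "router MLP (2)":    2 * D * D,
    "router logits":     D * E,
    "expert FC1":        h * f,
    "expert FC2":        f_o * h,
}
for name, p in params.items():
    print(f"{name:18s} {p / 1e6:6.2f} M")
print(f"{'total (approx.)':18s} {sum(params.values()) / 1e6:6.2f} M")   # ~18.5 M
```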

6. Special Architectural Components

  • Embeddings: Token embeddings $E_{\text{tok}} \in \mathbb{R}^{v \times h}$, tied with the LM head.
  • Normalization: All RMSNorm (no learnable bias); router MLP uses standard LayerNorm before GeLU.
  • Activation Functions: GeLU in router blocks; SwiGLU within expert MLPs.
  • Rotary Embeddings: RoPE is applied to half of each head’s channels only, supporting long-context extrapolation.
  • Residual Scaling: Per-layer, parameterized by $\alpha, \beta \in \mathbb{R}^h$ and a bias $b_\ell$ (a minimal sketch follows this list):

$\text{ResScale}_\alpha(x) = \alpha \odot x + b_\ell$

  • CCA Compression: Query compression $2\times$, key/value compression $8\times$, denoted "CCGQA" in model documentation.
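The sketch below gives minimal reference implementations of the residual-scaling gate and of RoPE restricted to half of each head's channels, as listed above. The rotary helper is a generic rotate-half implementation, not the ZAYA1 kernel, and the $\beta$ parameter mentioned for residual scaling is omitted since the formula above uses only $\alpha$ and $b_\ell$.

```python
# Minimal sketches of per-channel residual scaling and half-channel RoPE.
import torch
import torch.nn as nn

class ResScale(nn.Module):
    """ResScale(x) = alpha * x + b, with per-channel alpha and bias."""
    def __init__(self, h: int):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(h))
        self.bias = nn.Parameter(torch.zeros(h))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.alpha * x + self.bias

def rope_half(x: torch.Tensor, base: float = 10_000.0) -> torch.Tensor:
    """Apply rotary embeddings to the first half of each head's channels only.

    Expects x of shape (batch, heads, seq, head_dim); a generic sketch, not the
    ZAYA1 long-context scheme.
    """
    _, _, S, d_h = x.shape
    rot, keep = x[..., : d_h // 2], x[..., d_h // 2 :]
    half = rot.shape[-1] // 2
    freqs = base ** (-torch.arange(0, half, dtype=x.dtype) / half)
    angles = torch.arange(S, dtype=x.dtype)[:, None] * freqs[None, :]   # (S, half)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = rot[..., :half], rot[..., half:]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, keep], dim=-1)   # untouched half passes through
```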

7. Comparative Performance and Context

ZAYA1-base achieves performance at or above leading models of similar and larger active scale (Qwen3-4B, Gemma3-12B) and outperforms Llama-3-8B and OLMoE on benchmarks targeting reasoning, mathematics, and coding. The empirical findings suggest that the combination of tailored architecture and hardware-aware engineering enables the AMD stack to match or exceed the competitiveness of established foundation model pretraining environments (Anthony et al., 21 Nov 2025).
