The paper "Accelerated Large Batch Optimization of BERT Pretraining in 54 minutes" (Zheng et al., 2020 ) addresses the challenge of reducing the lengthy training time required for large deep neural networks like BERT, which hinders rapid experimentation and development. While using large mini-batches in synchronous stochastic gradient methods can accelerate training by allowing larger learning rates and fewer iterations, this approach often faces limitations where increasing the batch size further leads to accuracy degradation or divergence. The paper introduces LANS (Large Batch Nesterov-style AdamW with Normalized Gradient) and a novel learning rate scheduler to overcome these limitations and achieve significantly faster BERT pretraining.
The existing state-of-the-art large-batch optimizer for BERT pretraining at the time was LAMB [you2020reducing]. LAMB combined the AdamW optimizer with a layer-wise adaptive scaling factor, normalizing the update for each parameter block by its norm and scaling it by the norm of the parameter block itself. This allowed LAMB to scale BERT pretraining up to a mini-batch size of 64K, reducing training time from 3 days to 76 minutes on 1024 TPUs. However, LAMB struggled to scale to even larger batch sizes without losing accuracy or diverging.
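As a rough illustration of LAMB's layer-wise scaling described above (not the paper's reference code), the snippet below applies the trust ratio to a single parameter block; the function name `lamb_scaled_update` is made up for this sketch, and `adam_update` is assumed to be the bias-corrected Adam direction for that block.

```python
import torch

def lamb_scaled_update(param, adam_update, weight_decay=0.01):
    """Sketch of LAMB's layer-wise trust-ratio scaling for one parameter block."""
    u = adam_update + weight_decay * param          # AdamW-style update with decoupled decay
    w_norm, u_norm = param.norm(), u.norm()
    # Normalize the update by its own norm and scale it by the block's parameter norm.
    trust_ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
    return trust_ratio * u
```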
The authors propose LANS and an accompanying learning rate strategy to enable stable and efficient training with larger batch sizes.
The LANS optimizer introduces two main modifications:
- Per-block Gradient Normalization: the gradient of each parameter block is normalized by its L2 norm before being used to update the first- and second-order moments ($m_t$ and $v_t$). Only the direction of each block's gradient is used; its magnitude is discarded. This technique, previously explored in other contexts, is shown to make optimization more robust to vanishing and exploding gradients and eliminates the need for explicit gradient clipping (see the code sketch after this list). The normalization step for block $i$ at iteration $t$ is:

$$
\hat{g}_t^{(i)} = \frac{g_t^{(i)}}{\left\lVert g_t^{(i)} \right\rVert_2}.
$$
- Incorporating Nesterov's Momentum: LANS modifies LAMB's update rule to incorporate a concept similar to Nesterov's accelerated gradient (NAG). While classic momentum updates parameters based on the current momentum, NAG effectively uses a "lookahead" gradient. The LANS update rule is formulated as a convex combination of a LAMB-style update using the first-order momentum $m_t$ and a LAMB-style update using the instantaneous normalized gradient. Specifically, the update for block $i$ at time step $t$ is:

$$
x_{t+1}^{(i)} = x_t^{(i)} - \eta_t \left[ \beta_1 \, \frac{\phi\!\left(\lVert x_t^{(i)} \rVert\right)}{\lVert r_t^{(i)} + \lambda x_t^{(i)} \rVert} \left( r_t^{(i)} + \lambda x_t^{(i)} \right) + (1 - \beta_1) \, \frac{\phi\!\left(\lVert x_t^{(i)} \rVert\right)}{\lVert c_t^{(i)} + \lambda x_t^{(i)} \rVert} \left( c_t^{(i)} + \lambda x_t^{(i)} \right) \right],
$$

where $r_t^{(i)} = \hat{m}_t^{(i)} / (\sqrt{\hat{v}_t^{(i)}} + \epsilon)$ is the adaptive momentum ratio, $c_t^{(i)} = \hat{g}_t^{(i)} / (\sqrt{\hat{v}_t^{(i)}} + \epsilon)$ is the adaptive instantaneous gradient ratio (using the normalized gradient $\hat{g}_t$), $\phi$ is the layer-wise scaling function (typically the identity), $\lambda$ is the weight decay parameter, and $\beta_1$ is the Adam momentum parameter. The terms $\hat{m}_t$ and $\hat{v}_t$ are bias-corrected Adam-style moments calculated from the normalized gradient. This structure lets LANS benefit from Nesterov-style acceleration while keeping the adaptive and normalized properties of LAMB. A minimal code sketch of this step follows the list.
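The following is a minimal, single-block PyTorch sketch of the two modifications above (per-block normalization plus the Nesterov-style combination), assuming $\phi$ is the identity; the function name `lans_block_update` and the default hyperparameters are illustrative, not the authors' reference implementation.

```python
import torch

def lans_block_update(param, grad, m, v, step, lr,
                      beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.01):
    """Illustrative LANS step for one parameter block (updates param, m, v in place).

    `step` starts at 1 so the bias correction is well defined.
    """
    # 1) Per-block gradient normalization: keep only the direction of the gradient.
    g_norm = grad.norm()
    g_hat = grad / g_norm if g_norm > 0 else grad

    # 2) Adam-style first/second moments computed on the normalized gradient.
    m.mul_(beta1).add_(g_hat, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(g_hat, g_hat, value=1 - beta2)
    m_hat = m / (1 - beta1 ** step)          # bias-corrected first moment
    v_hat = v / (1 - beta2 ** step)          # bias-corrected second moment

    r = m_hat / (v_hat.sqrt() + eps)         # adaptive momentum ratio
    c = g_hat / (v_hat.sqrt() + eps)         # adaptive instantaneous gradient ratio

    # 3) Two LAMB-style trust-ratio updates, combined Nesterov-style.
    w_norm = param.norm()                    # phi(.) taken as the identity here

    def trust_scaled(update):
        u = update + weight_decay * param
        u_norm = u.norm()
        scale = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
        return scale * u

    d = beta1 * trust_scaled(r) + (1 - beta1) * trust_scaled(c)
    param.add_(d, alpha=-lr)
    return param, m, v
```

A full optimizer would keep `m` and `v` as per-parameter optimizer state and loop over every parameter block at each step.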
Furthermore, the paper identifies that the standard linear warmup followed by linear decay learning rate schedule [goyal2017accurate] used by LAMB is insufficient for very large batch sizes. The maximum theoretically usable learning rate is bounded by the inverse of the Lipschitz constant, and increasing the batch size doesn't necessarily increase this bound indefinitely. If the required large learning rate exceeds this bound, training diverges. If a smaller learning rate is used to avoid divergence, the total "area under the curve" of the learning rate schedule might be too small for sufficient training progress within a reduced number of iterations. To address this, they propose a new learning rate scheduler with a constant phase after the linear warmup:
$$
\eta_t =
\begin{cases}
\eta_{\text{peak}} \cdot \dfrac{t}{t_{\text{warmup}}}, & t \le t_{\text{warmup}}, \\[4pt]
\eta_{\text{peak}}, & t > t_{\text{warmup}},
\end{cases}
$$

where $\eta_{\text{peak}}$ is the peak learning rate and $t_{\text{warmup}}$ is the number of warmup iterations. Holding the learning rate at its peak after warmup keeps the "area under the curve" large without exceeding the stable upper bound.
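A direct reading of this schedule in code might look like the sketch below; the function and argument names are illustrative.

```python
def warmup_constant_lr(step, peak_lr, warmup_steps):
    """Linear warmup to peak_lr, then hold the learning rate constant."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr
```

The same shape can also be expressed with PyTorch's built-in `torch.optim.lr_scheduler.LambdaLR` using the multiplier `min(step / warmup_steps, 1.0)`.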
For practical implementation in distributed training, the paper emphasizes the importance of data sharding with random sampling without replacement within each shard. This reduces the variance of the mini-batch gradient compared to sampling with replacement across the entire dataset, which is crucial for the stability and efficiency of large-batch training.
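In PyTorch, one way to approximate this sharded, without-replacement sampling is `DistributedSampler`, which partitions the shuffled index set across workers each epoch so that no example is repeated within an epoch. The sketch below uses a toy dataset and assumes `torch.distributed` has already been initialized; it is an illustration of the sharding pattern, not the paper's pipeline.

```python
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy stand-in for the real pretraining dataset; the sharding logic is the point here.
train_dataset = TensorDataset(torch.arange(1024).unsqueeze(1))

sampler = DistributedSampler(
    train_dataset,
    num_replicas=dist.get_world_size(),   # one shard per worker
    rank=dist.get_rank(),
    shuffle=True,                         # shuffle, without replacement, within the epoch
)
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)              # reshuffle so shard contents differ per epoch
    for (batch,) in loader:
        pass                              # each index appears at most once per epoch
```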
The empirical evaluation was conducted on pretraining BERT-Large on Wikipedia and BooksCorpus using 192 AWS EC2 P3dn.24xlarge instances (1536 NVIDIA V100 GPUs) with EFA for high-throughput networking. The training was split into two stages, similar to previous works:
- Stage 1: 3519 iterations, sequence length 128, mini-batch size 96K.
- Stage 2: 782 iterations, sequence length 512, mini-batch size 33K.

In total, training runs for 4301 iterations.
LANS with the proposed linear-warmup-constant-decay schedule and these batch sizes achieved a SQuAD v1.1 dev F1 score of 90.60 in 53.6 minutes. This significantly outperforms LAMB, which either took longer (76.2 minutes with 64K/32K batch size and 8599 steps) or diverged when attempted with larger batch sizes (96K/33K).
The implementation of LANS is provided online, enabling practitioners to integrate these techniques into their large-scale training workflows. The findings demonstrate that LANS, combined with the specific learning rate scheduler and data sharding strategy, effectively addresses the challenges of scaling BERT pretraining to extremely large batch sizes, achieving a new record for cloud-based BERT training speed while maintaining target accuracy.
Practical considerations for implementing LANS include:
- Implementing the per-block gradient normalization logic within the optimizer. This typically involves iterating over parameter groups/blocks, calculating the L2 norm of the gradient for each block, and dividing the gradient by its norm.
- Modifying the momentum update and parameter step logic to follow the LANS update formula, which combines the normalized momentum and normalized instantaneous gradient.
- Implementing the linear-warmup-constant-decay learning rate schedule, which requires tracking the current iteration number to determine the appropriate learning rate.
- Setting up distributed data loading with data sharding and sampling without replacement within shards across workers.
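Putting the pieces together, a hypothetical training-step loop might look like the following, reusing the illustrative helpers sketched earlier (`lans_block_update`, `warmup_constant_lr`, `loader`) and assuming `model`, `base_lr`, and `warmup_steps` are defined and that gradient all-reduce is handled by DDP or a similar wrapper.

```python
import torch

# Per-parameter moment state for the sketched LANS step (hypothetical glue code).
params = [p for p in model.parameters() if p.requires_grad]
opt_state = [{"m": torch.zeros_like(p), "v": torch.zeros_like(p)} for p in params]

for step, batch in enumerate(loader, start=1):
    lr = warmup_constant_lr(step, peak_lr=base_lr, warmup_steps=warmup_steps)
    loss = model(**batch).loss            # assumed HuggingFace-style model output
    loss.backward()                       # DDP is assumed to all-reduce gradients here
    with torch.no_grad():
        for p, state in zip(params, opt_state):
            if p.grad is not None:
                lans_block_update(p, p.grad, state["m"], state["v"],
                                  step=step, lr=lr)
            p.grad = None                 # clear gradients for the next step
```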
The computational requirements are high due to the large model size and the distributed nature of the training. A high-performance interconnect such as EFA is crucial for efficient gradient synchronization across a large number of GPUs. The per-block normalization adds a slight computational overhead per iteration compared to standard optimizers, but this is negligible relative to the overall gradient computation and communication costs in large-scale distributed training. The provided reference implementation can serve as a starting point for integrating LANS into deep learning frameworks.