Accelerated Large Batch Optimization of BERT Pretraining in 54 minutes
(2006.13484v2)
Published 24 Jun 2020 in cs.LG, cs.CL, cs.DC, and stat.ML
Abstract: BERT has recently attracted a lot of attention in natural language understanding (NLU) and achieved state-of-the-art results in various NLU tasks. However, its success requires large deep neural networks and huge amount of data, which result in long training time and impede development progress. Using stochastic gradient methods with large mini-batch has been advocated as an efficient tool to reduce the training time. Along this line of research, LAMB is a prominent example that reduces the training time of BERT from 3 days to 76 minutes on a TPUv3 Pod. In this paper, we propose an accelerated gradient method called LANS to improve the efficiency of using large mini-batches for training. As the learning rate is theoretically upper bounded by the inverse of the Lipschitz constant of the function, one cannot always reduce the number of optimization iterations by selecting a larger learning rate. In order to use larger mini-batch size without accuracy loss, we develop a new learning rate scheduler that overcomes the difficulty of using large learning rate. Using the proposed LANS method and the learning rate scheme, we scaled up the mini-batch sizes to 96K and 33K in phases 1 and 2 of BERT pretraining, respectively. It takes 54 minutes on 192 AWS EC2 P3dn.24xlarge instances to achieve a target F1 score of 90.5 or higher on SQuAD v1.1, achieving the fastest BERT training time in the cloud.
The paper "Accelerated Large Batch Optimization of BERT Pretraining in 54 minutes" (Zheng et al., 2020) addresses the challenge of reducing the lengthy training time required for large deep neural networks like BERT, which hinders rapid experimentation and development. While using large mini-batches in synchronous stochastic gradient methods can accelerate training by allowing larger learning rates and fewer iterations, this approach often faces limitations where increasing the batch size further leads to accuracy degradation or divergence. The paper introduces LANS (Large Batch Nesterov-style AdamW with Normalized Gradient) and a novel learning rate scheduler to overcome these limitations and achieve significantly faster BERT pretraining.
The existing state-of-the-art large-batch optimizer for BERT pretraining at the time was LAMB (You et al., 2020). LAMB combined the AdamW optimizer with a layer-wise adaptive scaling factor, normalizing the update for each parameter block by its ℓ2 norm and scaling it by the ℓ2 norm of the parameter block itself. This allowed LAMB to scale BERT pretraining up to a mini-batch size of 64K, reducing training time from 3 days to 76 minutes on 1024 TPUs. However, LAMB struggled to scale to even larger batch sizes without losing accuracy or diverging.
The authors propose LANS and an accompanying learning rate strategy to enable stable and efficient training with larger batch sizes.
The LANS optimizer introduces two main modifications:
Per-block Gradient Normalization: The gradient of each parameter block, $g_{t,G_b}$, is divided by its ℓ2 norm before it is used to update the first and second moments ($m_t$ and $v_t$). Only the direction of each block's gradient is used; its magnitude is discarded. This technique, previously explored in other contexts, is shown to make the optimization more robust to vanishing and exploding gradients and eliminates the need for explicit gradient clipping. The normalization step is:
$$\tilde{g}_{t,G_b} = \frac{g_{t,G_b}}{\|g_{t,G_b}\|_2}$$
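Incorporating Nesterov's Momentum: LANS modifies LAMB's update rule to incorporate an idea similar to Nesterov's accelerated gradient (NAG). While classic momentum updates parameters based on the current momentum, NAG effectively uses a "lookahead" gradient. The LANS update is a convex combination of a LAMB-style update using the first-order momentum ($m_t$) and a LAMB-style update using the instantaneous normalized gradient. Specifically, the update direction $d_{t,G_b}$ for block $G_b$ at time step $t$ is:
$$d_{t,G_b} = \beta_1 \, \frac{\phi(\|x_{t,G_b}\|)}{\|r_{t,G_b} + \lambda x_{t,G_b}\|} \left(r_{t,G_b} + \lambda x_{t,G_b}\right) + (1 - \beta_1) \, \frac{\phi(\|x_{t,G_b}\|)}{\|c_{t,G_b} + \lambda x_{t,G_b}\|} \left(c_{t,G_b} + \lambda x_{t,G_b}\right)$$
where $x_{t,G_b}$ denotes the parameters of block $G_b$, $\tilde{g}_{t,G_b}$ is the block's ℓ2-normalized gradient, $r_{t,G_b} = \frac{m_{t,G_b}}{\sqrt{v_{t,G_b}} + \epsilon}$ is the adaptive momentum ratio, $c_{t,G_b} = \frac{\tilde{g}_{t,G_b}}{\sqrt{v_{t,G_b}} + \epsilon}$ is the adaptive instantaneous gradient ratio, $\phi$ is the layer-wise scaling function (typically the identity), $\lambda$ is the weight decay parameter, and $\beta_1$ is the Adam momentum parameter. The terms $m_{t,G_b}$ and $v_{t,G_b}$ are bias-corrected Adam-style moments computed from the normalized gradient. The parameters are then updated as $x_{t+1,G_b} = x_{t,G_b} - \eta_t \, d_{t,G_b}$, where $\eta_t$ is the learning rate at step $t$. This structure allows LANS to benefit from Nesterov-style acceleration while maintaining the adaptive and normalized properties of LAMB.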
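The two modifications above can be sketched for a single parameter block in plain Python. This is a minimal illustration under stated assumptions, not the authors' reference implementation: the function names (`l2_norm`, `scaled_direction`, `lans_step`) are hypothetical, the default hyperparameters are illustrative rather than taken from the paper, and the scaling function ϕ is assumed to be the identity.

```python
import math

def l2_norm(vec):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(x * x for x in vec))

def scaled_direction(ratio, x, weight_decay, phi=lambda z: z):
    """LAMB-style direction: phi(||x||) * (ratio + wd * x) / ||ratio + wd * x||."""
    u = [r + weight_decay * xi for r, xi in zip(ratio, x)]
    u_norm = l2_norm(u)
    if u_norm == 0.0:
        return [0.0] * len(x)
    scale = phi(l2_norm(x)) / u_norm
    return [scale * ui for ui in u]

def lans_step(x, grad, m, v, t, lr,
              beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.01):
    """One LANS-style update for one block; t is 1-indexed.

    Returns (new_x, new_m, new_v). Hyperparameter defaults are assumptions.
    """
    # 1) Per-block gradient normalization: keep the direction only.
    g_norm = l2_norm(grad)
    g = [gi / g_norm for gi in grad] if g_norm > 0.0 else list(grad)

    # 2) Adam-style moments computed from the normalized gradient.
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]
    m_hat = [mi / (1 - beta1 ** t) for mi in m]
    v_hat = [vi / (1 - beta2 ** t) for vi in v]

    # 3) Adaptive ratios for the momentum and the instantaneous gradient.
    r = [mh / (math.sqrt(vh) + eps) for mh, vh in zip(m_hat, v_hat)]
    c = [gi / (math.sqrt(vh) + eps) for gi, vh in zip(g, v_hat)]

    # 4) Convex combination of the two layer-wise scaled directions.
    d_m = scaled_direction(r, x, weight_decay)
    d_g = scaled_direction(c, x, weight_decay)
    new_x = [xi - lr * (beta1 * a + (1 - beta1) * b)
             for xi, a, b in zip(x, d_m, d_g)]
    return new_x, m, v
```

Because the gradient is normalized before anything else, rescaling `grad` by a positive constant leaves the step unchanged, and with the identity ϕ the step length never exceeds `lr` times the norm of the block's parameters. In a real optimizer this function would be applied to every parameter block with its own `m` and `v` state.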
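Furthermore, the paper identifies that the standard schedule of linear warmup followed by linear decay (Goyal et al., 2017) used by LAMB is insufficient for very large batch sizes. The maximum theoretically usable learning rate is bounded by the inverse of the Lipschitz constant, and increasing the batch size does not raise this bound indefinitely. If the learning rate that a large batch calls for exceeds this bound, training diverges. If a smaller learning rate is used to avoid divergence, the total "area under the curve" of the learning rate schedule may be too small to make sufficient progress within the reduced number of iterations. To address this, they propose a new learning rate scheduler with a constant phase after the linear warmup:
$$\eta_t = \begin{cases} \eta \, \dfrac{t}{T_{\text{warmup}}}, & t \le T_{\text{warmup}} \\ \eta, & T_{\text{warmup}} < t \le T_{\text{warmup}} + T_{\text{const}} \\ \eta \left(1 - \dfrac{t - T_{\text{warmup}} - T_{\text{const}}}{T - T_{\text{warmup}} - T_{\text{const}}}\right), & T_{\text{warmup}} + T_{\text{const}} < t \le T \end{cases}$$
Here $T_{\text{warmup}}$ is the number of warmup iterations, $T_{\text{const}}$ the number of constant-rate iterations, and $T$ the total number of iterations. This allows the optimizer to stay at the maximum feasible learning rate $\eta$ for a longer duration ($T_{\text{const}}$ iterations), ensuring sufficient total progress even if $\eta$ cannot be scaled up as far as the usual large-batch scaling heuristics would suggest.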
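A warmup, constant, then decay schedule of this kind can be sketched as a small function. The name `lr_at`, its parameter names, and the linear shape of the decay phase are assumptions made for illustration; the step counts in the usage note are arbitrary and not the paper's settings.

```python
def lr_at(t, peak_lr, t_warmup, t_const, t_total):
    """Learning rate at 1-indexed step t for a warmup/constant/decay schedule."""
    if t_warmup + t_const >= t_total:
        raise ValueError("warmup and constant phases must leave room for decay")
    if t <= t_warmup:
        # Linear warmup from peak_lr / t_warmup up to peak_lr.
        return peak_lr * t / t_warmup
    if t <= t_warmup + t_const:
        # Hold the peak learning rate.
        return peak_lr
    # Linear decay to zero at t_total (assumed decay shape).
    decay_steps = t_total - t_warmup - t_const
    return peak_lr * max(0.0, 1.0 - (t - t_warmup - t_const) / decay_steps)

def schedule_area(peak_lr, t_warmup, t_const, t_total):
    """Sum of learning rates over all steps: the 'area under the curve'."""
    return sum(lr_at(t, peak_lr, t_warmup, t_const, t_total)
               for t in range(1, t_total + 1))
```

With the same peak learning rate and the same total number of steps, `schedule_area(1.0, 100, 300, 1000)` is larger than `schedule_area(1.0, 100, 0, 1000)`, which is the effect the constant phase is meant to achieve when the peak rate itself cannot be raised.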
For practical implementation in distributed training, the paper emphasizes the importance of data sharding with random sampling without replacement within each shard. This reduces the variance of the mini-batch gradient compared to sampling with replacement across the entire dataset, which is crucial for the stability and efficiency of large-batch training.
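The sharding idea can be illustrated with a few lines of standard-library Python. This is a sketch under assumptions, not the paper's data pipeline: `shard_indices` and `epoch_batches` are hypothetical helpers, the strided shard assignment is one of several possible partitionings, and a real setup would read samples from files rather than index lists.

```python
import random

def shard_indices(num_samples, num_workers, rank):
    """Static, disjoint shard: worker `rank` owns every num_workers-th sample."""
    return list(range(rank, num_samples, num_workers))

def epoch_batches(shard, batch_size, epoch, rank, seed=0):
    """Yield local mini-batches that visit each shard element once per epoch.

    Shuffling a copy of the shard and slicing it is sampling without
    replacement: no sample repeats until the shard is exhausted.
    """
    # A string seed gives a reproducible, per-worker, per-epoch ordering.
    rng = random.Random(f"{seed}-{epoch}-{rank}")
    order = list(shard)
    rng.shuffle(order)
    for start in range(0, len(order), batch_size):
        yield order[start:start + batch_size]
```

Each worker calls `shard_indices` once with its own `rank` and then iterates `epoch_batches` every epoch; the global mini-batch is the union of the local batches across workers, so no sample appears twice within an epoch.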
The empirical evaluation was conducted on pretraining BERT-Large on Wikipedia and BooksCorpus using 192 AWS EC2 P3dn.24xlarge instances (1536 NVIDIA V100 GPUs) with EFA for high-throughput networking. The training was split into two phases, as in previous work, with mini-batch sizes of 96K in phase 1 and 33K in phase 2.
LANS with the proposed linear-warmup-constant-decay schedule and these batch sizes achieved a SQuAD v1.1 dev F1 score of 90.60 in 53.6 minutes. By comparison, LAMB either took longer (76.2 minutes with 64K/32K batch sizes and 8599 steps) or diverged when run with the larger batch sizes (96K/33K).
The implementation of LANS is provided online, enabling practitioners to integrate these techniques into their large-scale training workflows. The findings demonstrate that LANS, combined with the specific learning rate scheduler and data sharding strategy, effectively addresses the challenges of scaling BERT pretraining to extremely large batch sizes, achieving a new record for cloud-based BERT training speed while maintaining target accuracy.
Practical considerations for implementing LANS include:
Implementing the per-block gradient normalization logic within the optimizer. This typically involves iterating over parameter groups/blocks, calculating the L2 norm of the gradient for each block, and dividing the gradient by its norm.
Modifying the momentum update and parameter step logic to follow the LANS update formula, which combines the normalized momentum and normalized instantaneous gradient.
Implementing the linear-warmup-constant-decay learning rate schedule, which requires tracking the current iteration number to determine the appropriate learning rate.
Setting up distributed data loading with data sharding and sampling without replacement within shards across workers.
The computational requirements are high due to the large model size and the distributed nature of the training. A high-performance interconnect such as EFA is crucial for efficient gradient synchronization across a large number of GPUs. The per-block normalization adds a small computational overhead per iteration compared to standard optimizers, but this is negligible relative to the overall gradient computation and communication costs in large-scale distributed training. The provided reference implementation can serve as a starting point for integrating LANS into deep learning frameworks.