This paper introduces "CheckFree" and "CheckFree+", two novel methods for recovering LLM training from stage failures in a distributed pipeline parallelism setup, particularly relevant for decentralized and "wimpy" computation nodes like spot instances. These methods aim to reduce the significant communication, storage, and computation overhead associated with traditional recovery techniques like checkpointing and redundant computation.
The core problem addressed is the loss of a training stage (a partition of the model's layers) due to node failures. Traditional checkpointing requires periodically saving the entire model, incurring high overhead, especially for large models. Redundant computation involves each node performing extra work for a subsequent stage, increasing computational load.
CheckFree: Memoryless Recovery for Intermediate Stages
CheckFree proposes reinitializing a failed intermediate stage by taking a weighted average of the weights of its two neighboring stages.
The motivation stems from observations that:
- LLMs are resilient to layer omission, suggesting redundancy between layers.
- Layer stacking, where new layers are initialized from neighbors, can improve training.
- Implementation:
- When a stage Si fails, its new node receives the weights W_{s,i-1} and W_{s,i+1} from the neighboring stages Si−1 and Si+1, respectively.
- It also receives the squared L2 norm of the last gradients from these neighbors: ω_{i-1} = ||∇W_{s,i-1}||² and ω_{i+1} = ||∇W_{s,i+1}||².
The weights of the failed stage W_{s,i} are initialized as the gradient-norm-weighted average:
W_{s,i} ← (ω_{i-1} W_{s,i-1} + ω_{i+1} W_{s,i+1}) / (ω_{i-1} + ω_{i+1})
This gives more importance to stages that are less converged (higher gradient norm), effectively offloading some of their learning to the new stage.
After reinitialization, the learning rate is scaled up by a factor of 1.1 to help the new stage adapt.
The storage and communication overhead for the gradient norms is negligible (a single scalar per stage).
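To make that overhead concrete, here is a minimal PyTorch-style sketch of how a stage could compute this scalar after each backward pass; the function name and the training-loop placement are illustrative, not taken from the paper:

```python
import torch

def grad_norm_sq(stage: torch.nn.Module) -> float:
    """Squared L2 norm of the last gradients of one stage's weights.

    This single scalar (the stage's ω) is all that CheckFree needs a stage
    to store and share with its neighbors.
    """
    total = 0.0
    for p in stage.parameters():
        if p.grad is not None:
            total += p.grad.detach().pow(2).sum().item()
    return total

# Illustrative placement in the training loop:
#   loss.backward()
#   omega_i = grad_norm_sq(stage_i)   # one float, sent to the neighboring stages
#   optimizer.step()
```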
The recovery process for a failed stage i is outlined in Algorithm 1:
Algorithm 1: Recovery algorithm for stage i
1: REQUIRE: new node assigned to stage i, λ learning rate
2: Receive W_{s,i-1} and ω_{i-1} of stage i-1 where ω_{i-1} = ||∇W_{s,i-1}||²
3: Receive W_{s,i+1} and ω_{i+1} of stage i+1 where ω_{i+1} = ||∇W_{s,i+1}||²
4: Initialize the weights of the failed stage W_{s,i} ← (ω_{i-1}*W_{s,i-1} + ω_{i+1}*W_{s,i+1}) / (ω_{i-1} + ω_{i+1})
5: Update learning rate λ ← 1.1λ
6: Continue training from the current batch
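As a rough illustration, the reinitialization in Algorithm 1 amounts to a parameter-wise weighted average over the neighbors' state dicts. The sketch below assumes all stages share the same architecture so their parameters align name-by-name; the function name is illustrative:

```python
import torch

def checkfree_recover(w_prev: dict[str, torch.Tensor], omega_prev: float,
                      w_next: dict[str, torch.Tensor], omega_next: float,
                      lr: float) -> tuple[dict[str, torch.Tensor], float]:
    """Reinitialize a failed intermediate stage i (Algorithm 1).

    w_prev / w_next are the state dicts W_{s,i-1} and W_{s,i+1};
    omega_prev / omega_next are their squared gradient norms.
    Returns the new weights for stage i and the scaled-up learning rate.
    """
    denom = omega_prev + omega_next
    w_new = {
        name: (omega_prev * w_prev[name] + omega_next * w_next[name]) / denom
        for name in w_prev
    }                            # step 4: gradient-norm-weighted average
    return w_new, 1.1 * lr       # step 5: scale the learning rate by 1.1
```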
CheckFree, however, cannot recover the first and last stages of the pipeline as they only have one neighbor, and simply copying from that single neighbor leads to a significant performance drop.
CheckFree+: First and Last Stage Recovery
CheckFree+ extends CheckFree to handle failures of the first and last stages.
- Implementation for First/Last Transformer Stages:
- It utilizes out-of-order pipeline execution. For half of the microbatches, the standard pipeline order (S0, S1, S2, …, SL−1, SL, S0) is used.
- For the other half, the order of the first two and the last two transformer stages is swapped (e.g., S0, S2, S1, …, SL, SL−1, S0). S0 typically holds the embedding/de-embedding layers.
- This makes stage S2 learn the behavior of S1, and SL−1 learn the behavior of SL, without additional computation, since each effectively takes the other's place in the pipeline for half the time (see the sketch after this list).
- If S1 (or SL) fails, it can be recovered by simply copying the weights from S2 (or SL−1).
- Implementation for Embedding/De-embedding Layers:
- The (de)embedding layers (typically in S0) are critical. CheckFree+ handles their recovery by copying these layers to the neighboring stages (S1 and SL).
- This requires only a small storage overhead, O(∣E∣), where ∣E∣ is the size of the embedding layers, which is significantly smaller than the full model size. If S0 fails, its weights can be recovered exactly from these copies.
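A rough sketch of the per-microbatch stage ordering under CheckFree+'s out-of-order execution is given below; the function and its indexing convention are assumptions for illustration, with S0 holding the (de)embedding layers and S1…SL the transformer stages:

```python
def stage_order(num_transformer_stages: int, microbatch_idx: int) -> list[int]:
    """Pipeline order of stage indices for one microbatch under CheckFree+.

    Even microbatches use the normal order S0, S1, ..., S_{L-1}, S_L, S0;
    odd microbatches swap the first two and the last two transformer stages,
    so S2 learns to stand in for S1 and S_{L-1} for S_L.
    """
    L = num_transformer_stages
    order = [0] + list(range(1, L + 1)) + [0]   # S0 also de-embeds at the end
    if microbatch_idx % 2 == 1 and L >= 4:
        order[1], order[2] = order[2], order[1]          # swap S1 and S2
        order[L - 1], order[L] = order[L], order[L - 1]  # swap S_{L-1} and S_L
    return order

# Example with L = 6 transformer stages:
#   stage_order(6, 0) -> [0, 1, 2, 3, 4, 5, 6, 0]
#   stage_order(6, 1) -> [0, 2, 1, 3, 4, 6, 5, 0]
# If S1 later fails, it can be recovered by copying the weights of S2
# (and SL by copying from SL-1), as described above.
```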
Comparison of Recovery Strategies:
| Feature | Checkpointing | Redundant Comp. | CheckFree | CheckFree+ |
|---|---|---|---|---|
| Additional Memory | O(∣F∣) | O(∣F∣) | 0 | O(∣E∣) |
| Additional Comm. | O(∣F∣) | O(∣F∣) | 0 | O(∣E∣) |
| Additional Comp. | 0 | Forward pass | 0 | 0 |
| Non-faulty storage | Yes | No | No | No |
| Recovery of stages | All stages | Non-consecutive | Non-consecutive intermediate stages | Non-consecutive stages |
Evaluation:
The methods were evaluated on LLaMa models (124M, 500M, 1.5B parameters) with varying stage failure rates (5%, 10%, 16% per hour).
- Baselines: Checkpointing (saving the full model roughly every 3 hours) and Redundant Computation (each stage also computes the forward pass of the next stage).
- Performance Metrics: Iteration time and total training time to reach a target validation loss.
- Results:
- CheckFree and CheckFree+ significantly outperformed checkpointing in terms of wall-clock training time due to avoiding rollback delays.
- At low to medium failure rates (5-10%), CheckFree and CheckFree+ were over 12% faster in wall-clock time to convergence than redundant computation. This is because while redundant computation might converge in fewer iterations, its iteration time is much higher.
- CheckFree+ demonstrated robustness across different failure frequencies, with only slight degradation in validation loss even when failure rates tripled.
- The out-of-order swapping in CheckFree+ incurs a convergence slowdown in a no-failure setting, but is beneficial when failures are present.
- Large models trained with CheckFree+ (at 16% failure rate) achieved perplexity comparable to models trained without faults, despite different resulting weights.
Practical Implications and Limitations:
- CheckFree and CheckFree+ offer a lightweight recovery mechanism without needing external storage or redundant computations, making them suitable for cost-sensitive decentralized training.
- The recovery time for a failed stage is around 30 seconds.
- A key limitation is that neither method can recover from consecutive stage failures.
- The out-of-order execution in CheckFree+ can slightly slow down convergence in scenarios with very few or no failures.
Conclusion:
CheckFree and CheckFree+ present efficient alternatives for fault tolerance in distributed LLM training. By leveraging LLM properties such as layer redundancy and resilience to layer omission, they achieve faster training times under failure conditions than existing methods, with minimal overhead. The code is available at https://github.com/gensyn-ai/CheckFree. Future work includes addressing consecutive failures and improving CheckFree+'s convergence in non-faulty settings.