Parallel Scaling Law for Language Models (2505.10475v1)
Abstract: It is commonly believed that scaling LLMs should commit a significant space or time cost, by increasing the parameters (parameter scaling) or output tokens (inference-time scaling). We introduce the third and more inference-efficient scaling paradigm: increasing the model's parallel computation during both training and inference time. We apply $P$ diverse and learnable transformations to the input, execute forward passes of the model in parallel, and dynamically aggregate the $P$ outputs. This method, namely parallel scaling (ParScale), scales parallel computation by reusing existing parameters and can be applied to any model structure, optimization procedure, data, or task. We theoretically propose a new scaling law and validate it through large-scale pre-training, which shows that a model with $P$ parallel streams is similar to scaling the parameters by $O(\log P)$ while showing superior inference efficiency. For example, ParScale can use up to 22$\times$ less memory increase and 6$\times$ less latency increase compared to parameter scaling that achieves the same performance improvement. It can also recycle an off-the-shelf pre-trained model into a parallelly scaled one by post-training on a small amount of tokens, further reducing the training budget. The new scaling law we discovered potentially facilitates the deployment of more powerful models in low-resource scenarios, and provides an alternative perspective for the role of computation in machine learning.
Summary
- The paper introduces ParScale, which scales language models by employing parallel computational streams as an alternative to traditional parameter scaling.
- The methodology leverages learnable input transformations and a dynamic aggregation mechanism to process multiple streams concurrently and efficiently.
- Empirical results indicate that ParScale achieves performance gains comparable to parameter scaling with up to 22× less memory increase and 6× less latency increase.
Parallel Scaling Paradigm for LLMs
The paper "Parallel Scaling Law for LLMs" (2505.10475) introduces ParScale, a novel scaling paradigm for LLMs distinct from conventional parameter scaling or inference-time output token scaling. This approach focuses on increasing the computational parallelism during both training and inference phases without necessarily increasing model parameters or extending sequence generation length. The core mechanism involves generating P parallel computational streams derived from the initial input, processing these streams concurrently through the base model's architecture, and subsequently aggregating the resulting P outputs to produce the final result.
ParScale Mechanism and Implementation
The implementation of ParScale involves several key components. Given an input token sequence $X$, the process begins with $P$ distinct, learnable transformations applied to $X$. These transformations, denoted $T_1, T_2, \ldots, T_P$, map the input into $P$ potentially diverse representations $X_i = T_i(X)$ for $i = 1, \ldots, P$. Each $X_i$ is then processed by the base LLM $M$, whose parameters are shared across streams, in one of $P$ parallel forward passes, yielding $P$ output representations $O_i = M(X_i)$.
The critical aspect of ParScale lies in the dynamic aggregation of these $P$ parallel outputs. A learnable aggregation function $A$ combines the $O_i$ into a single final output $\hat{Y} = A(O_1, O_2, \ldots, O_P)$. Both the transformations $T_i$ and the aggregation function $A$ contain parameters that are optimized end-to-end during training. The design of the $T_i$ encourages diversity among the parallel streams, enabling the model to explore different computational paths or extract distinct features concurrently, while $A$ is responsible for integrating the information from these diverse streams effectively.
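To make the notion of dynamic aggregation concrete, one possible instantiation (a sketch only; the module name and dimensions are illustrative and not necessarily the paper's exact design) predicts per-token combination weights from the stream outputs themselves:

```python
import torch
import torch.nn as nn


class DynamicAggregation(nn.Module):
    """Illustrative dynamic aggregation: per-token mixture weights over the P
    streams are predicted from the stream outputs themselves (an assumption,
    not the paper's confirmed design)."""

    def __init__(self, hidden_size, num_streams):
        super().__init__()
        self.gate = nn.Linear(hidden_size * num_streams, num_streams)

    def forward(self, stream_outputs):
        # stream_outputs: (P, batch, seq, hidden)
        P, batch, seq, hidden = stream_outputs.shape
        # Concatenate streams along the hidden dimension: (batch, seq, P * hidden)
        concat = stream_outputs.permute(1, 2, 0, 3).reshape(batch, seq, P * hidden)
        # Per-token weights over the P streams: (batch, seq, P)
        weights = torch.softmax(self.gate(concat), dim=-1)
        # Weighted sum over streams: (batch, seq, hidden)
        weights = weights.permute(2, 0, 1).unsqueeze(-1)  # (P, batch, seq, 1)
        return (stream_outputs * weights).sum(dim=0)
```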
From an implementation perspective, this requires an execution framework capable of running $P$ simultaneous forward passes of the base model. This can be achieved with standard data-parallel techniques, where the input batch is conceptually expanded to $P$ copies per original instance, or by leveraging model parallelism within a single inference instance if the infrastructure supports it. The learnable transformations $T_i$ can range from simple linear layers to more complex context-dependent modules. The aggregation $A$ can involve mechanisms such as weighted summation, attention-based pooling, or gating networks that dynamically adjust the contribution of each parallel stream based on the input context.
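The sketch below puts these pieces together, using a simple linear layer as each per-stream transformation and a static softmax-weighted average as the aggregation. These specific choices, and the sequential loop over streams, are illustrative rather than the paper's exact implementation.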
```python
import torch
import torch.nn as nn


class ParScaleModel(nn.Module):
    def __init__(self, base_model, num_parallel_streams):
        super().__init__()
        self.base_model = base_model
        self.num_parallel_streams = num_parallel_streams

        # Learnable per-stream transformations (example: a simple linear layer each)
        self.transformations = nn.ModuleList([
            nn.Linear(base_model.config.hidden_size, base_model.config.hidden_size)
            for _ in range(num_parallel_streams)
        ])

        # Learnable aggregation mechanism (example: a softmax-weighted average);
        # more complex aggregation such as attention or gating can be used instead.
        self.aggregation_weights = nn.Parameter(torch.ones(num_parallel_streams))

    def forward(self, inputs, **kwargs):
        # Assumes `inputs` are hidden states of shape (batch, seq, hidden), e.g. the
        # output of the base model's embedding layer, and that `base_model` maps
        # (batch, seq, hidden) -> (batch, seq, hidden).
        batch_size, seq_len, hidden_size = inputs.shape

        # Apply the P learnable transformations to the shared input
        transformed_inputs = [
            self.transformations[i](inputs) for i in range(self.num_parallel_streams)
        ]

        # Execute the base model's forward pass once per stream. This loop is
        # sequential for clarity; a real implementation would run the P passes in
        # parallel (e.g. across devices, or by folding the streams into the batch).
        parallel_outputs = []
        for i in range(self.num_parallel_streams):
            output_i = self.base_model(transformed_inputs[i], **kwargs)
            parallel_outputs.append(output_i)

        # Stack outputs: (num_parallel_streams, batch, seq, hidden)
        stacked_outputs = torch.stack(parallel_outputs, dim=0)

        # Aggregate with a softmax-normalized weighted average over streams
        normalized_weights = torch.softmax(self.aggregation_weights, dim=0)
        expanded_weights = normalized_weights.view(self.num_parallel_streams, 1, 1, 1)
        aggregated_output = torch.sum(stacked_outputs * expanded_weights, dim=0)

        return aggregated_output  # (batch, seq, hidden)
```
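As noted above, the sequential loop can be avoided by folding the $P$ streams into the batch dimension so that the base model runs a single, wider forward pass. The following is a minimal sketch under the same assumptions as the class above; `parscale_batched_forward` is a hypothetical helper, not an API from the paper.

```python
import torch


def parscale_batched_forward(base_model, transformations, aggregation_weights, inputs, **kwargs):
    # inputs: (batch, seq, hidden); assumes base_model maps
    # (batch, seq, hidden) -> (batch, seq, hidden).
    P = len(transformations)
    batch, seq, hidden = inputs.shape

    # Build the P transformed views and fold them into the batch dimension:
    # (P, batch, seq, hidden) -> (P * batch, seq, hidden)
    streams = torch.stack([t(inputs) for t in transformations], dim=0)
    folded = streams.reshape(P * batch, seq, hidden)

    # One forward pass over the expanded batch; concurrency comes from the
    # hardware's ability to process the larger batch in parallel.
    outputs = base_model(folded, **kwargs).reshape(P, batch, seq, hidden)

    # Static softmax-weighted aggregation over streams (as in ParScaleModel above)
    weights = torch.softmax(aggregation_weights, dim=0).view(P, 1, 1, 1)
    return (outputs * weights).sum(dim=0)
```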
Scaling Law and Performance Characteristics
The paper posits and empirically validates a new scaling law for LLMs under the ParScale paradigm. While traditional scaling laws relate performance to parameter count or data scale, the ParScale law relates performance to the number of parallel streams $P$. The theoretical analysis suggests that scaling computation via $P$ parallel streams is roughly analogous to scaling parameters by $O(\log P)$ in terms of performance improvement. This implies that relatively modest increases in parallel computation can yield performance gains comparable to significant increases in model parameter count.
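One way to read the $O(\log P)$ claim is to fold the parallel streams into an effective parameter count inside a Chinchilla-style loss. This is an illustrative rendering under that assumption, not necessarily the exact functional form fitted in the paper:

$$
N_{\text{eff}} \approx N\,(1 + k \log P), \qquad \mathcal{L}(N, P) \approx \left(\frac{A}{N_{\text{eff}}}\right)^{\alpha} + E,
$$

where $A$, $\alpha$, and $E$ are the fitted constants of the underlying parameter scaling law, $N$ is the parameter count, and $k$ measures how much each additional stream is worth. Under this reading, the marginal benefit of adding streams shrinks as $P$ grows, consistent with a logarithmic effective-parameter gain.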
A key finding highlighted is the efficiency profile of ParScale compared to parameter scaling. Achieving a similar level of performance improvement through ParScale requires substantially less additional memory and incurs significantly lower latency increases during inference than scaling the model size. Specifically, the paper reports that ParScale can achieve comparable performance improvements to parameter scaling with up to 22× less memory increase and 6× less latency increase. This is a crucial practical advantage, as memory capacity and inference latency are often bottlenecks in deploying LLMs, particularly in resource-constrained environments.
The empirical validation through large-scale pre-training supports this scaling law and the reported efficiency gains. The ability of ParScale to leverage existing model architectures and parallel computation resources effectively translates computational throughput into model capability gains with a favorable trade-off in terms of memory footprint and response time.
Post-Training and Deployment Implications
A notable practical advantage of ParScale is that it can be applied to off-the-shelf pre-trained models. The paper demonstrates that an existing pre-trained model can be recycled into a parallelly scaled one by post-training it with the ParScale mechanism on a relatively small number of tokens. This significantly reduces the training budget required to upgrade or adapt a model, since it avoids extensive pre-training from scratch for each desired performance level achieved via parallel scaling.
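A minimal sketch of such a post-training setup, assuming the ParScaleModel wrapper above: the pre-trained backbone is optionally frozen so that only the newly introduced per-stream transformations and aggregation parameters are trained. Whether the backbone is frozen, the choice of optimizer, and the learning rate are assumptions here, not the paper's prescribed recipe.

```python
import torch


def parscale_post_training_setup(parscale_model, lr=1e-4, freeze_backbone=True):
    # Optionally freeze the pre-trained backbone so that only the ParScale-specific
    # parameters (per-stream transformations, aggregation weights) are updated.
    if freeze_backbone:
        for p in parscale_model.base_model.parameters():
            p.requires_grad = False

    trainable = [p for p in parscale_model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=lr)
```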
The inference efficiency benefits make ParScale particularly relevant for deploying powerful LLMs in low-resource scenarios or applications requiring low latency. By decoupling performance gains from massive parameter counts, ParScale enables the use of more capable models on hardware with limited memory or computational power per core, provided parallel execution resources are available. This could broaden the applicability of LLMs in edge devices, mobile applications, or environments where distributing computation across multiple processing units is feasible, but increasing the size of individual model instances is prohibitive.
Conclusion
The paper introduces ParScale as a viable and efficient method for enhancing model performance by leveraging parallel computation. This paradigm, built on learnable input transformations and dynamic output aggregation across $P$ parallel streams, offers a path for scaling that is complementary to parameter scaling. The favorable trade-off between performance improvement and resource cost, particularly in memory and latency, together with the ability to apply it via post-training, positions ParScale as a promising technique for developing and deploying more capable LLMs in diverse computational settings, especially those with stringent resource constraints.