- The paper introduces scaling laws that predict a model's loss from its size, training-token count, and domain weights, enabling systematic selection of the optimal data mixture.
- It extends traditional formulations with an additive and a joint variant, accurately forecasting losses for large language, large vision, and native multimodal models.
- The approach is validated by low prediction errors and practical mixture optimization, reducing reliance on trial and error in mixture selection.
Scaling Laws for Data Mixture Optimization
This paper introduces a novel approach to determine the optimal data mixture for training large foundation models across different modalities. Recognizing that the standard trial-and-error method for selecting data mixtures becomes impractical at scale, the authors propose a systematic method based on scaling laws. This method accurately predicts the loss of a model as a function of its size (N), the number of training tokens (D), and the domain weight vector (h). The key innovation is extending traditional scaling laws to explicitly model the impact of domain weights on model performance.
The authors validate the universality of their scaling laws in three distinct large-scale settings: LLMs, native multimodal models (NMMs), and large vision models (LVMs). They demonstrate that these scaling laws can extrapolate to new data mixtures and across scales, estimating parameters using a few small-scale training runs and predicting performance at larger scales and unseen domain weights. This approach offers a principled alternative to costly trial-and-error methods, enabling the derivation of optimal domain weights for any target domain under a given training budget (N, D).
The paper formulates the problem of training models with data from $k$ domains, aiming to predict the loss on a target domain $D_T$ after training a model of size $N$ on $D$ tokens with domain weights $h$. Two scaling law formulations are proposed:
- Additive Scaling Law: This law models only the bias term $E_h$ as a function of $h$, while the other parameters ($A_h$, $\alpha_h$, $B_h$, $\beta_h$) are held constant. The formula is:
$$L(N, D, h) = E + \frac{1}{\sum_{i=1}^{k} C_i\, h_i^{\gamma_i}} + \frac{A}{N^{\alpha}} + \frac{B}{D^{\beta}}$$
This law has $5 + 2k$ parameters, and the optimal domain weights are independent of model size N and the number of tokens D.
- Joint Scaling Law: This law models the terms $A_h$ and $B_h$ as functions of $h$, capturing the interaction between scale and mixture. The formula is:
$$L(N, D, h) = E + \frac{1}{\sum_{i=1}^{k} C_i\, h_i^{\gamma_i}} + \frac{A_h}{N^{\alpha}} + \frac{B_h}{D^{\beta}}, \quad \text{with } A_h = \left(\sum_{i=1}^{k} C_i^{A} h_i\right)^{\gamma_A} \text{ and } B_h = \left(\sum_{i=1}^{k} C_i^{B} h_i\right)^{\gamma_B}$$
This scaling law has $5 + 4k$ parameters and predicts that the contribution of $N$ and $D$ to the loss depends on the domain weights, making the optimal domain weights compute-dependent. A minimal code sketch of both laws follows this list.
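To make the two formulations concrete, here is a minimal sketch of both laws as prediction functions, assuming the reconstructed forms above; the function names, argument layout, and use of NumPy are illustrative choices, not the authors' implementation.

```python
import numpy as np

def additive_law(N, D, h, E, A, alpha, B, beta, C, gamma):
    """Additive scaling law: only the bias term depends on the domain weights h.

    C and gamma are length-k arrays (the 2k mixture parameters); E, A, alpha,
    B, beta are the 5 scalar parameters.
    """
    h = np.asarray(h, dtype=float)
    bias = 1.0 / np.sum(C * h ** gamma)      # h-dependent bias term
    return E + bias + A / N ** alpha + B / D ** beta

def joint_law(N, D, h, E, alpha, beta, C, gamma, C_A, gamma_A, C_B, gamma_B):
    """Joint scaling law: A_h and B_h also depend on h, coupling scale and mixture.

    C, gamma, C_A, C_B are length-k arrays (the 4k mixture parameters); E,
    alpha, beta, gamma_A, gamma_B are the 5 scalar parameters.
    """
    h = np.asarray(h, dtype=float)
    bias = 1.0 / np.sum(C * h ** gamma)
    A_h = np.sum(C_A * h) ** gamma_A         # scale term modulated by the mixture
    B_h = np.sum(C_B * h) ** gamma_B         # data term modulated by the mixture
    return E + bias + A_h / N ** alpha + B_h / D ** beta
```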
The authors use the Huber loss to fit the scaling laws, employing a random search and the Basin-hopping algorithm for optimization. The Mean Relative Error (MRE) is used to evaluate the scaling laws by comparing predicted losses against actual losses on a new set of runs with different (N,D,h) values.
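A rough sketch of this fitting and evaluation procedure is given below. It assumes the additive law from the previous snippet, fits in log-loss space with a hand-rolled Huber objective refined by SciPy's `basinhopping` under random restarts, and reports MRE on held-out runs; the parameter packing, exponential reparameterization, and Huber delta are illustrative assumptions rather than the paper's exact settings.

```python
import numpy as np
from scipy.optimize import basinhopping

def huber(r, delta=1e-3):
    """Huber penalty on residuals r: quadratic near zero, linear in the tails."""
    a = np.abs(r)
    return np.where(a <= delta, 0.5 * r ** 2, delta * (a - 0.5 * delta))

def predict_log_loss(theta, N, D, h):
    """Additive law in log space. theta packs (E, A, alpha, B, beta, C_1..k, gamma_1..k);
    exp() keeps every parameter positive (an illustrative reparameterization)."""
    k = h.shape[1]
    E, A, alpha, B, beta = np.exp(theta[:5])
    C, gamma = np.exp(theta[5:5 + k]), np.exp(theta[5 + k:5 + 2 * k])
    L = E + 1.0 / (C * h ** gamma).sum(axis=1) + A / N ** alpha + B / D ** beta
    return np.log(L)

def fit_additive_law(N, D, h, loss, n_restarts=32, seed=0):
    """Fit the 5 + 2k parameters on small-scale runs: random search over
    initializations, each refined with basin-hopping on the Huber objective."""
    y = np.log(loss)
    objective = lambda th: huber(predict_log_loss(th, N, D, h) - y).sum()
    rng, best = np.random.default_rng(seed), None
    for _ in range(n_restarts):
        x0 = rng.uniform(-2.0, 2.0, size=5 + 2 * h.shape[1])
        res = basinhopping(objective, x0, niter=20,
                           minimizer_kwargs={"method": "L-BFGS-B"})
        if best is None or res.fun < best.fun:
            best = res
    return best.x

def mean_relative_error(pred_loss, obs_loss):
    """MRE (%) between predicted and observed losses on held-out (N, D, h) runs."""
    return 100.0 * np.mean(np.abs(pred_loss - obs_loss) / obs_loss)
```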
The experimental setup involves pretraining LLMs, NMMs, and LVMs on diverse datasets. For LLMs, the authors use the k=7 domains from SlimPajama. For NMMs, they train on a mixture of text-only data, interleaved multimodal documents, and paired image-caption datasets (k=3). For LVMs, they use a mixture of paired image-caption datasets drawn from four domains (k=4).
The paper presents strong numerical results demonstrating the effectiveness of the proposed scaling laws. The key findings include:
- Accurate Extrapolation: The scaling laws fit the training runs closely and generalize to significantly larger values of $N$ and $D$. The observed-versus-predicted plots for the multimodal setting show close alignment between predicted and observed losses for both the joint and additive laws, with good extrapolation to larger model sizes. The reported MRE (%) is consistently low for both laws, with the joint law improving on the additive law.
- Optimal Domain Weights Estimation: The fitted scaling laws enable accurate estimation of optimal domain weights by solving an optimization problem over the simplex (a sketch follows this list). Models trained with these optimized mixtures consistently outperform alternatives, including uniform mixtures and the weights used in prior work.
- Practical Mixture Estimation: The authors demonstrate that the scaling laws can be accurately fitted with small-scale runs, and then used to solve for optimal domain weights, providing a principled approach to mixture estimation compared to ad-hoc methods.
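To illustrate the simplex optimization referenced above, the sketch below minimizes a fitted law's predicted loss using a softmax reparameterization of the weights; the numeric parameter values are placeholders for illustration, not fitted values from the paper.

```python
import numpy as np
from scipy.optimize import minimize

def softmax(z):
    """Map unconstrained logits to a point on the probability simplex."""
    e = np.exp(z - z.max())
    return e / e.sum()

def predicted_loss(h, N, D, E, A, alpha, B, beta, C, gamma):
    """Additive law from the earlier sketch, evaluated at a candidate mixture h."""
    return E + 1.0 / np.sum(C * h ** gamma) + A / N ** alpha + B / D ** beta

def optimal_weights(N, D, params, k, n_restarts=16, seed=0):
    """Solve min_h L(N, D, h) over the simplex by optimizing softmax logits,
    with random restarts to reduce sensitivity to initialization."""
    rng = np.random.default_rng(seed)
    best_h, best_val = None, np.inf
    for _ in range(n_restarts):
        res = minimize(lambda z: predicted_loss(softmax(z), N, D, **params),
                       rng.normal(size=k), method="Nelder-Mead")
        if res.fun < best_val:
            best_h, best_val = softmax(res.x), res.fun
    return best_h

# Placeholder (not fitted) parameters for a k = 3 mixture.
params = dict(E=1.7, A=400.0, alpha=0.34, B=2000.0, beta=0.28,
              C=np.array([2.0, 1.0, 0.5]), gamma=np.array([0.6, 0.4, 0.3]))
print(optimal_weights(N=1e9, D=2e10, params=params, k=3))
```

Under the additive law the resulting weights do not change with $(N, D)$, consistent with the compute-independence noted earlier; swapping in the joint law makes the solution budget-dependent.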
The paper also includes an analysis of the scaling laws, exploring aspects such as the number of runs needed for accurate fitting, the behavior of optimal domain weights when scaling FLOPs, and the validity of the laws with cosine learning rate schedules. They find that only 10-20 runs are needed to fit the scaling laws accurately and that the optimal mixture evolves as a function of the compute budget.
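The compute-dependence of the optimal mixture falls directly out of the joint law: as the $A_h/N^{\alpha}$ and $B_h/D^{\beta}$ terms shrink with scale, their trade-off against the mixture-dependent bias term shifts. The two-domain sketch below, with made-up coefficients, grid-searches $h_1$ for increasing budgets purely to show the argmin moving; none of the numbers come from the paper.

```python
import numpy as np

def joint_loss(h1, N, D):
    """Joint law for k = 2 domains with h = (h1, 1 - h1); all coefficients are
    made-up placeholders chosen only to illustrate the effect, not fitted values."""
    h = np.array([h1, 1.0 - h1])
    E, alpha, beta, gamma_A, gamma_B = 1.7, 0.34, 0.28, 1.0, 1.0
    C, gamma = np.array([2.0, 0.8]), np.array([0.5, 0.4])
    C_A, C_B = np.array([700.0, 300.0]), np.array([3500.0, 1500.0])
    bias = 1.0 / np.sum(C * h ** gamma)
    A_h = np.sum(C_A * h) ** gamma_A
    B_h = np.sum(C_B * h) ** gamma_B
    return E + bias + A_h / N ** alpha + B_h / D ** beta

grid = np.linspace(0.01, 0.99, 99)
for N, D in [(1e8, 2e9), (1e9, 2e10), (1e10, 2e11)]:  # growing compute budgets
    h1_star = grid[np.argmin([joint_loss(h1, N, D) for h1 in grid])]
    print(f"N={N:.0e}, D={D:.0e}: optimal h1 = {h1_star:.2f}")  # shifts with budget
```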
The presented work has significant implications for the field of AI, particularly in the training of large foundation models. By providing a systematic and scalable method for determining optimal data mixtures, this research addresses a critical bottleneck in model development. This approach has practical benefits, such as reduced computational costs and improved model performance, and theoretical implications, such as a deeper understanding of the relationship between data composition and model behavior.
Future developments in this area could explore:
- Extending the scaling laws to continual pretraining and finetuning scenarios.
- Predicting downstream task performance directly, rather than relying on a generic target loss.
- Accounting for data repetition and dynamic evolution of domain weights during training.
- Incorporating additional factors, such as data quality and diversity, into the scaling laws.