Heterogeneous Mini-Batch Sampling (HGSampling)
- HGSampling is a family of mini-batch construction techniques that prioritize diversity by targeting under-represented or hard data points to reduce variance in stochastic gradients.
- It employs methods like stratified clustering, kernel-based repulsive processes, and Poisson Disk Sampling to ensure balanced representation across heterogeneous classes and modalities.
- Practically, HGSampling leads to faster convergence and improved model performance in both CNN and GNN applications, while reducing computational costs and communication overhead.
Heterogeneous Mini-Batch Sampling (HGSampling) refers to a family of mini-batch construction schemes for stochastic optimization that, instead of sampling uniformly at random, favor diversity or explicitly target under-represented, rare, or otherwise "hard" structures in the data. The main technical aims are to reduce the variance of the stochastic gradient estimator, to ensure balanced representation across heterogeneous data components (classes, clusters, modalities, types), or both. HGSampling generalizes classical stratified sampling and encompasses recent repulsive point process approaches, such as Determinantal Point Processes (DPPs) and Poisson Disk Sampling (PDS), as well as budgeted importance sampling for heterogeneous graphs. Together these offer theoretically grounded and practically efficient routes to scalable, accelerated training when the data cannot be assumed homogeneous.
1. Core HGSampling Formulations: From Stratified to Repulsive Processes
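Classical HGSampling arose in the context of stochastic gradient descent (SGD) as stratified or cluster-based mini-batch selection. Given a dataset of $n$ points partitioned into clusters $C_1, C_2, \dots$ (e.g., by label or by unsupervised clustering) with sizes $n_i = |C_i|$, the total batch size $b$ is split across clusters according to a proxy for within-cluster gradient variability, giving the variance-minimizing allocations $b_i \propto n_i \sigma_i$, where $\sigma_i$ quantifies the variability of per-example gradients within $C_i$. Writing $g_j = \nabla \ell_j(w)$ for the gradient of example $j$ and $B_i \subset C_i$ for the $b_i$ points drawn from cluster $i$, the resulting stratified mini-batch estimator

$$\hat g(w) = \sum_i \frac{n_i}{n} \cdot \frac{1}{b_i} \sum_{j \in B_i} g_j$$

is unbiased, and its variance is smaller than that of uniform sampling when clusters are well formed, i.e., when within-cluster variability is small relative to the differences between cluster mean gradients (Zhao et al., 2014).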
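The allocation rule and estimator above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the implementation of Zhao et al. (2014): the helper names `neyman_allocation` and `stratified_gradient` are hypothetical, every cluster is first given one slot so the estimator is defined, and because the toy per-example "gradients" are fully available, their true per-cluster standard deviation is used as the proxy $\sigma_i$. A real training loop would compute gradients only for the sampled indices.

```python
import numpy as np

def neyman_allocation(sizes, sigmas, b):
    """Split a batch of size b so that b_i is roughly proportional to n_i * sigma_i.

    Assumption: each cluster first receives one slot; the remaining b - J slots
    are split proportionally, using largest-remainder rounding to hit b exactly.
    """
    w = np.asarray(sizes, dtype=float) * np.asarray(sigmas, dtype=float)
    J = len(w)
    if b < J:
        raise ValueError("need at least one sample per cluster")
    raw = (b - J) * w / w.sum()
    alloc = np.floor(raw).astype(int)
    leftover = (b - J) - alloc.sum()
    order = np.argsort(raw - alloc)[::-1]  # largest fractional parts first
    alloc[order[:leftover]] += 1
    return alloc + 1

def stratified_gradient(grads_by_cluster, alloc, rng):
    """Unbiased estimate: sum_i (n_i / n) * mean of b_i gradients drawn from cluster i."""
    n = sum(len(G) for G in grads_by_cluster)
    est = 0.0
    for G, b_i in zip(grads_by_cluster, alloc):
        idx = rng.choice(len(G), size=b_i, replace=True)
        est = est + (len(G) / n) * G[idx].mean(axis=0)
    return est

# Toy check: 1-D "gradients" in two tight, well-separated clusters.
rng = np.random.default_rng(0)
clusters = [rng.normal(0.0, 0.1, size=50), rng.normal(10.0, 0.1, size=50)]
all_grads = np.concatenate(clusters)
b = 4
alloc = neyman_allocation([50, 50], [G.std() for G in clusters], b)
strat = np.array([stratified_gradient(clusters, alloc, rng) for _ in range(2000)])
unif = np.array([all_grads[rng.choice(100, size=b)].mean() for _ in range(2000)])
print(alloc, strat.var(), unif.var())
```

In this toy setting the stratified draws concentrate tightly around the full-data mean gradient, whereas uniform draws of the same size swing between the two clusters. That gap is exactly the between-cluster term that stratification removes.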
Repulsive point processes generalize this logic. A Determinantal Point Process (DPP) over $N$ data points assigns each subset $B$ the probability $P(B) \propto \det(L_B)$, where $L$ is an $N \times N$ positive semidefinite kernel encoding data similarity and $L_B$ is its submatrix indexed by $B$. Because $\det(L_B)$ is the squared volume spanned by the (kernel-space) feature vectors of $B$, subsets containing near-duplicate points receive probability close to zero, which favors diversity. In k-DPP sampling the batch size is fixed at $|B| = k$, and the induced negative correlations further reduce estimator variance by actively discouraging the co-occurrence of similar points (Zhang et al., 2017).
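To make the determinant intuition and the k-DPP step concrete, here is a minimal NumPy sketch. It assumes a dense kernel small enough to eigendecompose and implements the standard two-phase spectral k-DPP sampler: first choose exactly $k$ eigenvectors using elementary symmetric polynomials, then draw items one at a time while projecting each chosen coordinate out of the remaining basis. The helper names `rbf_kernel`, `elementary_symmetric`, and `sample_k_dpp` are hypothetical. This is an illustration, not the code used by Zhang et al. (2017).

```python
import numpy as np

def rbf_kernel(X, sigma):
    """L_ij = exp(-||x_i - x_j||^2 / (2 sigma^2))."""
    sq = np.sum(X**2, axis=1)
    d2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * X @ X.T, 0.0)
    return np.exp(-d2 / (2.0 * sigma**2))

def elementary_symmetric(lams, k):
    """E[l, n] = e_l(lams[0], ..., lams[n-1]) for l = 0..k and n = 0..N."""
    N = len(lams)
    E = np.zeros((k + 1, N + 1))
    E[0, :] = 1.0
    for l in range(1, k + 1):
        for n in range(1, N + 1):
            E[l, n] = E[l, n - 1] + lams[n - 1] * E[l - 1, n - 1]
    return E

def sample_k_dpp(L, k, rng):
    """Exact k-DPP sample as sorted item indices; assumes rank(L) >= k."""
    lams, vecs = np.linalg.eigh(L)
    lams = np.clip(lams, 0.0, None)  # remove tiny negative round-off
    E = elementary_symmetric(lams, k)
    # Phase 1: pick exactly k eigenvectors (when n == l, all remaining are forced).
    chosen, l = [], k
    for n in range(len(lams), 0, -1):
        if l == 0:
            break
        if n == l or rng.random() < lams[n - 1] * E[l - 1, n - 1] / E[l, n]:
            chosen.append(n - 1)
            l -= 1
    V = vecs[:, chosen]
    # Phase 2: draw one item per basis vector, then remove that item's direction.
    items = []
    while V.shape[1] > 0:
        p = np.sum(V**2, axis=1)
        i = int(rng.choice(len(p), p=p / p.sum()))
        items.append(i)
        j = int(np.argmax(np.abs(V[i])))  # column with the largest entry in row i
        Vj = V[:, j].copy()
        V = np.delete(V - np.outer(Vj, V[i] / Vj[i]), j, axis=1)  # row i becomes zero
        if V.shape[1] > 0:
            V, _ = np.linalg.qr(V)  # re-orthonormalize the remaining basis
    return sorted(items)

# A near-duplicate pair has a far smaller determinant than a spread-out pair.
Y = np.array([[0.0, 0.0], [0.01, 0.0], [3.0, 0.0]])
LY = rbf_kernel(Y, sigma=1.0)
det_redundant = np.linalg.det(LY[np.ix_([0, 1], [0, 1])])
det_diverse = np.linalg.det(LY[np.ix_([0, 2], [0, 2])])

rng = np.random.default_rng(0)
X = rng.normal(size=(12, 2))
batch = sample_k_dpp(rbf_kernel(X, sigma=1.0), k=4, rng=rng)
print(det_redundant, det_diverse, batch)
```

The eigendecomposition inside `sample_k_dpp` is repeated on every call here for simplicity. With a fixed kernel it would normally be computed once and reused across iterations.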
The PDS (Poisson Disk Sampling) algorithm builds mini-batches by enforcing a minimum pairwise distance in feature (or another metric) space, so that no two samples in a batch lie too close together. More generally, any repulsive point process whose pair correlation function falls below 1 for nearby points is guaranteed to reduce the variance of the SGD gradient estimator, provided that the gradients of nearby data points are positively correlated (Zhang et al., 2018).
2. Construction Principles and Algorithms
The implementation of HGSampling depends on the chosen diversification mechanism:
- Stratified and Clustered Sampling: Partition data into clusters (e.g., via $k$-means, conditioned on labels or feature similarity), then allocate mini-batch slots to each cluster according to its size and estimated within-cluster variability. This requires a fixed partition prior to training, after which batch allocation and sampling are straightforward (Zhao et al., 2014).
- Kernel-based Repulsive Process Sampling: Construct a similarity kernel $L$ from feature representations (raw features, learned embeddings, labels, concatenations), select a batch size $k$, and at each iteration sample $B \sim \text{k-DPP}(L)$. The DPP kernel may be linear, RBF, or adapted to the data structure, potentially with low-rank approximations for large $N$. For example:
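  - $L_{ij} = x_i^\top x_j$ (linear)
  - $L_{ij} = \exp\left(-\|x_i - x_j\|^2 / (2\sigma^2)\right)$ (RBF)
  - $L_{ij} = \tilde{x}_i^\top \tilde{x}_j$ with $\tilde{x}_i = [x_i;\, w\, y_i]$, i.e., features concatenated with a one-hot label $y_i$ scaled by a weight $w$, for mixed features/labels (Zhang et al., 2017).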
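The standard two-phase k-DPP sampler requires $O(Nk^3)$ time per batch after a one-time eigendecomposition of $L$ (which costs $O(N^3)$ but can be reused while the kernel is fixed); approximate algorithms (e.g., Nyström) reduce the cost to $O(Nm^2)$ with $m \ll N$.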
- Poisson Disk Sampling (PDS): For a mini-batch of size $k$, iteratively draw random candidates, adding each only if its minimum distance to the points already in the batch exceeds a pre-specified radius $r$. The cost is $O(k^2)$ distance evaluations per batch (plus those spent on rejected candidates), substantially less than exact k-DPP sampling and independent of $N$. Extensions introduce active-bias radii or adapt the radius to local data “difficulty” (e.g., using a mingling index in classification to upsample points near decision boundaries) (Zhang et al., 2018).
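The PDS rule above can be sketched with plain dart throwing, which is one simple way to realize a minimum-distance constraint. The sketch assumes Euclidean distance on raw features. The function `poisson_disk_batch` and its `max_tries` cap are hypothetical, and the active-bias and mingling-index extensions are not included.

```python
import numpy as np

def poisson_disk_batch(X, k, r, rng, max_tries=10_000):
    """Dart throwing: draw random candidates and keep one only if it lies farther
    than r from every point already in the batch. May return fewer than k indices
    if r is too large for the data or max_tries is exhausted."""
    batch = []
    for _ in range(max_tries):
        if len(batch) == k:
            break
        i = int(rng.integers(len(X)))
        # Each test compares against at most k accepted points, so the total work
        # is O(k^2) distance evaluations when few candidates are rejected.
        if all(np.linalg.norm(X[i] - X[j]) > r for j in batch):
            batch.append(i)
    return batch

rng = np.random.default_rng(0)
X = rng.uniform(size=(500, 2))
batch = poisson_disk_batch(X, k=20, r=0.1, rng=rng)
min_gap = min(np.linalg.norm(X[a] - X[b]) for a in batch for b in batch if a != b)
print(len(batch), round(min_gap, 3))
```

Unlike the k-DPP sketch, nothing here depends on $N$ beyond indexing into `X`, which is why PDS scales so easily. The price is a hard, hand-chosen radius `r` instead of a soft kernel.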
3. Theoretical Analysis: Variance Reduction and Generalization
All HGSampling schemes aim to reduce the variance of the stochastic gradient estimator $\hat g$:
- In stratified settings, the variance is upper-bounded by a sum weighted by per-cluster within-group variability, which, under appropriate batch allocation, can dramatically outperform the uniform baseline (Zhao et al., 2014).
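- For DPPs and general repulsive processes, let $s_i \in \{0,1\}$ indicate whether point $i$ is in the batch, with inclusion probabilities $\pi_i = P(s_i = 1)$ and $\pi_{ij} = P(s_i = 1, s_j = 1)$, and consider the reweighted estimator $\hat g = \frac{1}{N}\sum_{i=1}^{N} \frac{s_i}{\pi_i} g_i$ of the full gradient $\bar g = \frac{1}{N}\sum_{i=1}^{N} g_i$, where $g_i$ is the gradient of example $i$. Its variance decomposes as

$$\mathbb{E}\,\|\hat g - \bar g\|^2 = \frac{1}{N^2}\sum_{i=1}^{N} \frac{1-\pi_i}{\pi_i}\,\|g_i\|^2 + \frac{1}{N^2}\sum_{i \neq j} \frac{\pi_{ij} - \pi_i \pi_j}{\pi_i \pi_j}\, g_i^\top g_j .$$

The first term is what independent (Poisson) sampling with the same marginals would give. The second, covariance term shows that negative sampling correlations ($\pi_{ij} < \pi_i \pi_j$ for $i \neq j$) lower the variance for every pair whose gradients are positively aligned ($g_i^\top g_j > 0$), as is typical for similar points (Zhang et al., 2017, Zhang et al., 2018). When the sufficient condition $(\pi_{ij} - \pi_i \pi_j)\, g_i^\top g_j \le 0$ holds for all $i \neq j$, variance reduction relative to independent sampling is guaranteed.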
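The variance decomposition for the reweighted estimator can be checked numerically by enumerating every subset of a tiny sampling design. The sketch below compares independent Bernoulli(0.5) sampling with a hand-built repulsive design that never selects two similar points together. Both designs give every point $\pi_i = 0.5$, so only the covariance term differs. These are toy designs chosen for illustration (the repulsive one is not a DPP), and the helper names are hypothetical.

```python
import itertools
import numpy as np

def inclusion_probs(design, N):
    """pi_i and pi_ij (diagonal = pi_i) for a design given as {frozenset: probability}."""
    pi, pij = np.zeros(N), np.zeros((N, N))
    for S, p in design.items():
        for i in S:
            pi[i] += p
            for j in S:
                pij[i, j] += p
    return pi, pij

def variance_formula(g, design):
    """(1/N^2) sum_{i,j} (pi_ij - pi_i pi_j) / (pi_i pi_j) g_i.g_j; the i == j terms
    reduce to (1 - pi_i) / pi_i * ||g_i||^2."""
    N = g.shape[0]
    pi, pij = inclusion_probs(design, N)
    W = (pij - np.outer(pi, pi)) / np.outer(pi, pi)
    return float(np.sum(W * (g @ g.T))) / N**2

def variance_brute_force(g, design):
    """E||g_hat - g_bar||^2 computed by enumerating every subset the design can produce."""
    N = g.shape[0]
    g_bar = g.mean(axis=0)
    pi, _ = inclusion_probs(design, N)
    total = 0.0
    for S, p in design.items():
        g_hat = sum((g[i] / pi[i] for i in S), np.zeros_like(g_bar)) / N
        total += p * float(np.sum((g_hat - g_bar) ** 2))
    return total

# Items 0 and 1 have similar gradients, as do items 2 and 3.
g = np.array([[1.0, 0.2], [0.9, 0.3], [-0.5, 1.0], [-0.4, 1.1]])
independent = {frozenset(S): 1 / 16
               for r in range(5) for S in itertools.combinations(range(4), r)}
repulsive = {frozenset(S): 1 / 4 for S in [(0, 2), (0, 3), (1, 2), (1, 3)]}
results = {name: (variance_formula(g, d), variance_brute_force(g, d))
           for name, d in [("independent", independent), ("repulsive", repulsive)]}
print(results)
```

In the repulsive design $\pi_{01} = \pi_{23} = 0 < \pi_i \pi_j$, and those pairs have positively aligned gradients. The covariance term therefore cancels almost all of the diagonal term.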
For DPPs built on orthogonal polynomials, the variance of the DPP-based mini-batch estimator has been shown to decay asymptotically as $O(k^{-(1+1/d)})$ in the batch size $k$, where $d$ is the data dimension, versus the $O(k^{-1})$ rate for uniform Poisson sampling; this yields provably improved finite-time SGD convergence on convex objectives (Bardenet et al., 2021).
In distributed asynchronous (Hogwild) settings with heterogeneous local datasets, the mini-batch size $s_i$ used in communication round $i$ can be increased linearly, $s_i = a\,i + b$, with work allocated proportionally across heterogeneous clients. The resulting scheme reaches the same test accuracy as constant-batch schemes while needing only $O(\sqrt{K})$ communication rounds for strongly convex losses, where $K$ is the total number of gradient computations. This is a significant reduction in communication under data heterogeneity (Dijk et al., 2020).
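The $O(\sqrt{K})$ round count follows from simple arithmetic. With $s_i = a\,i + b$, the first $T$ rounds perform $aT(T+1)/2 + bT$ gradient computations, so reaching $K$ takes roughly $\sqrt{2K/a}$ rounds. The sketch below only counts rounds, and its constants ($a = 1$, $b = 0$, and the constant batch size 16 used for comparison) are arbitrary assumptions. It says nothing about accuracy, asynchrony, or how work is split across clients.

```python
import math

def rounds_linear(K, a, b):
    """Rounds T until sum_{i=1}^{T} (a*i + b) >= K gradient computations."""
    total, T = 0, 0
    while total < K:
        T += 1
        total += a * T + b
    return T

def rounds_constant(K, s):
    """Rounds needed with a fixed mini-batch size s."""
    return math.ceil(K / s)

for K in (10**6, 4 * 10**6):
    print(K, rounds_linear(K, a=1, b=0), rounds_constant(K, s=16), round(math.sqrt(2 * K)))
```

Quadrupling $K$ doubles the number of linear-schedule rounds but quadruples the constant-batch count, which is the square-root behavior claimed above.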
4. HGSampling in Graph-Structured and Large-Scale Data
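Heterogeneous Graph Transformer (HGT) applications require node- and edge-type-aware mini-batch subgraph sampling for web-scale heterogeneous graphs. The HGSampling algorithm used there maintains a separate candidate budget $B[\tau]$ for each node type $\tau$. Whenever a node is added to the sample, each of its not-yet-sampled neighbors $t$ of type $\tau$ accumulates a normalized-degree contribution in $B[\tau][t]$. Candidates of type $\tau$ are then drawn by importance sampling with probability proportional to the squared accumulated normalized degree:

$$p_\tau(t) = \frac{B[\tau][t]^2}{\sum_{t'} B[\tau][t']^2}.$$

This enables controlled frontier growth and scalable layer-wise subgraph construction. The procedure is memory- and computation-efficient, since its per-batch cost is set by the sampling budget and depth rather than by the total graph size. Under assumptions inherited from the LADIES sampler, it provably lowers the variance of GNN message aggregation (Hu et al., 2020).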
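A simplified, single-process sketch of the type-wise budget and squared-probability sampling described above follows. It rests on several assumptions. The graph is a plain adjacency dict. The normalized-degree increment is taken as $1/\deg(s)$ over all relations, whereas the original algorithm may normalize per relation type. Reconstructing the sampled subgraph's edges is omitted. The names `add_to_budget` and `hg_sample` are hypothetical.

```python
from collections import defaultdict
import numpy as np

def add_to_budget(s, graph, node_type, budget, sampled):
    """Each not-yet-sampled neighbour t of s gains 1/deg(s) in the budget of t's type."""
    nbrs = graph[s]
    for t in nbrs:
        if t not in sampled:
            budget[node_type[t]][t] += 1.0 / len(nbrs)

def hg_sample(graph, node_type, seeds, n_per_type, depth, rng):
    """Layer-wise typed sampling with p_tau(t) proportional to B[tau][t]^2."""
    sampled = set(seeds)
    budget = defaultdict(lambda: defaultdict(float))
    for s in seeds:
        add_to_budget(s, graph, node_type, budget, sampled)
    for _ in range(depth):
        new_nodes = []
        for B in budget.values():  # one draw per node type
            cands = list(B)
            if not cands:
                continue
            w = np.array([B[t] for t in cands]) ** 2
            picks = rng.choice(len(cands), size=min(n_per_type, len(cands)),
                               replace=False, p=w / w.sum())
            new_nodes.extend(cands[i] for i in picks)
        for t in new_nodes:  # move picks out of the budget into the sample
            sampled.add(t)
            budget[node_type[t]].pop(t, None)
        for t in new_nodes:  # their unsampled neighbours become candidates
            add_to_budget(t, graph, node_type, budget, sampled)
    return sampled

graph = {"p1": ["a1", "a2", "v1"], "p2": ["a2", "v1"], "p3": ["a3", "v2"],
         "a1": ["p1"], "a2": ["p1", "p2"], "a3": ["p3"],
         "v1": ["p1", "p2"], "v2": ["p3"]}
node_type = {v: {"p": "paper", "a": "author", "v": "venue"}[v[0]] for v in graph}
sub = hg_sample(graph, node_type, seeds=["p1"], n_per_type=1, depth=2,
                rng=np.random.default_rng(0))
print(sorted(sub))
```

Because each type has its own budget, a single high-degree type cannot crowd out the others, and the sample size per layer is capped at `n_per_type` times the number of types regardless of graph size.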
5. Empirical Results and Observed Advantages
HGSampling schemes have been empirically validated on a variety of supervised and unsupervised learning tasks:
- Topic modeling (unsupervised): DM-SVI (DPP-based) recovers rare topics ignored by standard SVI and achieves higher per-class accuracies on imbalanced datasets (e.g., per-class gains of about 5pp on Reuters-R8) (Zhang et al., 2017).
- Supervised tasks: On Oxford-102 Flowers (imbalanced), DM-SGD with appropriately weighted kernels improves test accuracy by ≈2pp and reduces the number of epochs needed to converge. On MNIST, even with balanced classes, DM-SGD reduces gradient variance and converges faster, with the largest gains at small batch sizes.
- PDS and DPPs in deep models: PDS and k-DPP sampling significantly outperform uniform sampling on fine-grained and boundary-sensitive tasks (e.g., boundary classification in synthetic 2D data, MNIST, and speech commands), achieving lower test error (∼0.7–0.8% vs ∼1.0%) and faster convergence. PDS matches k-DPP in accuracy but is orders of magnitude faster per batch (Zhang et al., 2018).
- Large-Scale GNNs: In HGT experiments on the Open Academic Graph (179 million nodes, 2 billion edges), HGSampling is the only evaluated method that keeps the memory footprint feasible while maintaining high accuracy (batch times of ∼1–2 seconds), and the resulting models outperform baselines by 9–21% on classification tasks (Hu et al., 2020).
- Distributed asynchronous SGD: Linearly increasing mini-batch HGSampling achieves competitive accuracy with dramatically reduced communication, robust to strong data heterogeneity (Dijk et al., 2020).
- Variance decay acceleration: For orthogonal polynomial DPPs, the experimentally observed variance decay (estimated from the slope of a log-log fit) is consistent with the theoretical rate and faster than that of uniform strategies on both synthetic and real data (Bardenet et al., 2021).
- Classical stratified schemes: Stratified mini-batch SGD shows order-of-magnitude reductions in gradient estimator variance and improved convergence on convex objectives across standard UCI and vision datasets (Zhao et al., 2014).
6. Canonical Algorithms and Implementation Trade-offs
| Method | Main Mechanism | Complexity per Batch | Empirical Strength |
|---|---|---|---|
| Stratified | Pre-clustered strata | Cluster select + O(b) | Large variance reduction |
| k-DPP | Gram kernel | O(Nk³), fast: O(Nm²) | Maximum diversity, general |
| PDS | Distance exclusion | O(k²) | Fastest sampler, flexible |
| Graph HGSampling | Layer-wise budgeted | O((m+LTn)d̄) | Web-scale graph feasible |
Stratified schemes require clustering upfront and are limited by the granularity of the clusters and possibly by label availability. DPPs provide a soft, metric-based generalization that recovers stratified and uniform sampling as special cases and can be extended to combinations of modalities. Approximate algorithms, such as Nyström DPP or fast projective samplers, enable scaling to large datasets. PDS is a highly efficient alternative for spatial or metric-based data; it is widely used in computer graphics because of its simplicity and scalability. HGSampling for graphs addresses the multi-type and memory-control issues that arise with naive neighborhood expansion in GNNs.
7. Limitations, Extensions, and Open Problems
Several practical and theoretical considerations arise:
- Clustering cost and granularity: For stratified/clustered HGSampling, clustering can be prohibitively expensive for high-dimensional or unlabeled data, and a good cluster granularity is hard to choose. A static clustering also ignores how the model, and hence the per-example gradients, evolve during training (Zhao et al., 2014).
- Kernel adaptivity: The choice of kernel (feature, label, or learned representation) heavily influences DPP performance but may be nontrivial to tune.
- Computational scaling: While DPPs are powerful, exact sampling remains costly for very large $N$; approximations (e.g., core-set methods) mitigate this at a possible tradeoff in diversity optimality (Zhang et al., 2017).
- Extending repulsion to representation/gradient space: Current repulsion mechanisms typically act in input space; direct enforcement in gradient or deep representation space could further improve variance but at higher computational complexity (Zhang et al., 2018).
- Applicability to non-convex and deep architectures: While performance gains persist in CNNs and deep tasks, most theoretical guarantees concern convex objectives.
- Graph applications: HGSampling is well justified for multi-type GNNs but requires careful management of budget, per-type fan-out, and cross-relation sampling (Hu et al., 2020).
- Hybrid and higher-order processes: Multi-scale or higher-order repulsive processes, or hybrid approaches (combining PDS and DPP), offer prospects for further variance reductions.
In sum, HGSampling via stratified clustering, DPPs, PDS, and graph-specific subgraph construction offers a unified, theoretically grounded toolkit for variance-reduced, balanced, and scalable mini-batch selection in heterogeneous data environments, generalizing and improving upon uniform SGD sampling (Zhang et al., 2017, Zhang et al., 2018, Dijk et al., 2020, Hu et al., 2020, Bardenet et al., 2021, Zhao et al., 2014).