Stochastic Optimization of Sorting Networks
- The paper introduces NeuralSort, a novel differentiable surrogate that enables gradient-based training by continuously relaxing permutation matrices.
- It leverages softmax-based relaxations and Gumbel reparameterization to approximate traditional sorting operations while maintaining differentiability within neural networks.
- Empirical results in semantic sorting, quantile regression, and differentiable k-nearest neighbors demonstrate significant performance improvements and efficient computation.
Stochastic optimization of sorting networks addresses the fundamental challenge of making sorting operations differentiable and amenable to gradient-based learning. Classical sorting is non-differentiable, which traditionally prohibits its integration within end-to-end trainable neural architectures. "NeuralSort: Stochastic Optimization of Sorting Networks via Continuous Relaxations" introduces NeuralSort, a differentiable surrogate for sorting based on continuous relaxations of permutation matrices. This framework enables stochastic optimization over permutations by combining NeuralSort with a reparameterized estimator for the Plackett–Luce distribution using Gumbel perturbations, thereby making sorting networks directly tractable within deep learning pipelines (Grover et al., 2019).
1. Non-Differentiability of Sorting Operations
The standard sort function takes a vector $s \in \mathbb{R}^n$ and produces a permutation that orders the elements (in descending order, following the paper's convention), typically represented by a permutation matrix $P_{\text{sort}(s)} \in \{0, 1\}^{n \times n}$ with exactly one "1" per row and column. The sorting operator is piecewise constant: infinitesimal changes in $s$ do not alter the ranking unless a tie occurs, resulting in a Jacobian that is zero almost everywhere and undefined at ties. This non-differentiability means that embedding a sort operation inside a computational graph will zero out or remove all gradient information with respect to the inputs, presenting a fundamental barrier to direct gradient-based optimization for objectives that depend on the output ordering (Grover et al., 2019).
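The piecewise-constant behaviour can be checked numerically. The following is a minimal sketch in plain Python; the helper name `sort_permutation`, the example vector, and the step size `eps` are illustrative assumptions, not part of the paper.

```python
def sort_permutation(s):
    """Indices that sort s in descending order (ties broken by index)."""
    return sorted(range(len(s)), key=lambda i: -s[i])

s = [0.3, 2.0, -1.1, 0.9]   # hypothetical input scores
eps = 1e-6                  # small perturbation, far below the gaps between entries
base = sort_permutation(s)

# Nudging any single coordinate by eps leaves the ranking unchanged, so every
# finite-difference "derivative" of the permutation with respect to s is zero.
unchanged = all(
    sort_permutation(s[:i] + [s[i] + eps] + s[i + 1:]) == base
    for i in range(len(s))
)
```

Because the output permutation does not move under small perturbations, any loss computed from it receives no gradient signal from `s`.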
2. NeuralSort: Continuous Relaxation of Permutations
2.1 Unimodal Row-Stochastic Matrix Relaxation
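Permutation matrices are relaxed to unimodal row-stochastic matrices $U \in \mathbb{R}^{n \times n}$ with the following properties:
- $U[i, j] \geq 0$ for all $i, j$
- $\sum_{j=1}^{n} U[i, j] = 1$ for all $i$
- Each row $i$ has a unique arg max at column $u_i = \arg\max_j U[i, j]$, and $u = (u_1, \dots, u_n)$ forms a permutation of $\{1, \dots, n\}$.
The pairwise-difference matrix $A_s[i, j] = |s_i - s_j|$ enables the exact construction of $P_{\text{sort}(s)}$:
$$P_{\text{sort}(s)}[i, j] = 1 \quad \text{if } j = \arg\max\big[(n + 1 - 2i)\, s - A_s \mathbf{1}\big],$$
otherwise, $P_{\text{sort}(s)}[i, j] = 0$ (Grover et al., 2019).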
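The arg max construction above can be sketched in a few lines. This is an illustrative plain-Python version, assuming descending order and no ties; the function names and the example vector are hypothetical, and a practical implementation would vectorize the same arithmetic with an array library.

```python
def pairwise_abs_diff(s):
    """A_s[i][j] = |s_i - s_j|."""
    return [[abs(si - sj) for sj in s] for si in s]

def hard_sort_matrix(s):
    """Permutation matrix whose row i selects the i-th largest entry of s."""
    n = len(s)
    A = pairwise_abs_diff(s)
    row_sums = [sum(row) for row in A]          # entries of A_s 1
    P = []
    for i in range(1, n + 1):                   # rows are 1-indexed in the formula
        scores = [(n + 1 - 2 * i) * s[j] - row_sums[j] for j in range(n)]
        j_star = max(range(n), key=lambda j: scores[j])
        P.append([1 if j == j_star else 0 for j in range(n)])
    return P

s = [0.3, 2.0, -1.1, 0.9]
P = hard_sort_matrix(s)
sorted_s = [sum(P[i][j] * s[j] for j in range(len(s))) for i in range(len(s))]
```

Multiplying `P` by `s` returns the entries of `s` in descending order, which is the property the relaxation in the next subsection preserves in the low-temperature limit.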
2.2 Softmax-Based Relaxation
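NeuralSort replaces the row-wise hard arg max with a softmax, yielding the continuous relaxation:
$$\hat{P}_{\text{sort}(s)}[i, :](\tau) = \operatorname{softmax}\big[\big((n + 1 - 2i)\, s - A_s \mathbf{1}\big) / \tau\big]$$
with temperature parameter $\tau > 0$. Written elementwise, the same relaxation is:
$$\hat{P}_{\text{sort}(s)}[i, j](\tau) = \frac{\exp\big(\big((n + 1 - 2i)\, s_j - \sum_{k=1}^{n} |s_j - s_k|\big) / \tau\big)}{\sum_{l=1}^{n} \exp\big(\big((n + 1 - 2i)\, s_l - \sum_{k=1}^{n} |s_l - s_k|\big) / \tau\big)}$$
These relaxations yield differentiable matrices for any $\tau > 0$, recovering hard permutation matrices in the $\tau \to 0^+$ limit (in the absence of ties).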
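A minimal sketch of the softmax relaxation follows, written in plain Python for self-containment. The names `softmax` and `neural_sort` and the example temperatures are assumptions made for illustration; this is not the authors' reference implementation.

```python
import math

def softmax(x):
    """Numerically stable softmax over a list of floats."""
    m = max(x)
    e = [math.exp(v - m) for v in x]
    z = sum(e)
    return [v / z for v in e]

def neural_sort(s, tau=1.0):
    """Relaxed sort matrix: row i is softmax(((n + 1 - 2i) s - A_s 1) / tau)."""
    n = len(s)
    row_sums = [sum(abs(sj - sk) for sk in s) for sj in s]   # entries of A_s 1
    return [
        softmax([((n + 1 - 2 * i) * s[j] - row_sums[j]) / tau for j in range(n)])
        for i in range(1, n + 1)
    ]

s = [0.3, 2.0, -1.1, 0.9]
P_soft = neural_sort(s, tau=1.0)     # smooth, every entry strictly positive
P_sharp = neural_sort(s, tau=0.01)   # close to a hard permutation matrix
argmax_rows = [max(range(len(s)), key=lambda j: row[j]) for row in P_soft]
```

Each row sums to one and its arg max matches the hard sort, so the row-wise arg max of the relaxed matrix still recovers the descending ordering of `s`. Multiplying `P_soft` by `s` gives a soft version of the sorted vector.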
3. Stochastic Gradient Estimation via Plackett–Luce and Gumbel Tricks
3.1 Plackett–Luce Permutation Distribution
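The Plackett–Luce (PL) distribution models random permutations $\pi = (\pi_1, \dots, \pi_n)$ parameterized by positive scores $s \in \mathbb{R}_{+}^{n}$:
$$q(\pi \mid s) = \prod_{i=1}^{n} \frac{s_{\pi_i}}{\sum_{k=i}^{n} s_{\pi_k}}$$
This reflects a sequential draw without replacement: at each step, an item is chosen from those remaining with probability proportional to its score.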
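The sequential-draw interpretation translates directly into code. The sketch below evaluates the PL probability of a permutation and confirms that the probabilities of all orderings sum to one; the function name and the example scores are hypothetical.

```python
from itertools import permutations

def plackett_luce_prob(perm, s):
    """Probability of drawing the items in the order given by perm."""
    prob = 1.0
    remaining = sum(s[i] for i in perm)   # total score of items not yet drawn
    for i in perm:
        prob *= s[i] / remaining
        remaining -= s[i]
    return prob

scores = [3.0, 1.0, 2.0]
p_example = plackett_luce_prob((0, 2, 1), scores)   # (3/6) * (2/3) * (1/1)
total = sum(plackett_luce_prob(p, scores) for p in permutations(range(3)))
```

Enumerating all permutations is only feasible for very small $n$, which is one reason sampling-based estimators are used in practice.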
3.2 Gumbel Reparameterization for Sampling and Gradients
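Sampling from PL$(s)$ is enabled using the Gumbel-max trick with i.i.d. Gumbel$(0, 1)$ noise $g_1, \dots, g_n$:
\begin{align*}
\tilde{s}_i &= \log s_i + g_i \\
\pi &= \operatorname{sort\_indices}(\tilde{s})
\end{align*}
where the indices are returned in descending order of $\tilde{s}$. This renders permutation sampling as a deterministic (but non-differentiable) function of $s$ and $g$. The expectation of interest is:
$$L(s) = \mathbb{E}_{\pi \sim q(\cdot \mid s)}\big[f(P_\pi)\big] = \mathbb{E}_{g}\big[f\big(P_{\text{sort}(\log s + g)}\big)\big]$$
Approximating the discrete $P_{\text{sort}(\log s + g)}$ with NeuralSort yields:
$$\hat{L}(s) = \mathbb{E}_{g}\big[f\big(\hat{P}_{\text{sort}(\log s + g)}(\tau)\big)\big]$$
Gradients w.r.t. $s$ are then given by:
$$\nabla_s \hat{L}(s) = \mathbb{E}_{g}\big[\nabla_s f\big(\hat{P}_{\text{sort}(\log s + g)}(\tau)\big)\big]$$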
which can be efficiently approximated using Monte Carlo sampling.
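The sampling step can be sketched as follows. This illustrative version perturbs log-scores with Gumbel noise, sorts in descending order, and compares the empirical frequency of the first-ranked item with the PL prediction $s_i / \sum_k s_k$. The seed, the scores, and the number of draws are arbitrary assumptions.

```python
import math
import random

def sample_gumbel(rng):
    """One Gumbel(0, 1) draw via inverse transform sampling."""
    u = max(rng.random(), 1e-12)   # guard against log(0)
    return -math.log(-math.log(u))

def sample_pl(s, rng):
    """Draw a permutation from PL(s): argsort of log-scores plus Gumbel noise."""
    perturbed = [math.log(si) + sample_gumbel(rng) for si in s]
    return sorted(range(len(s)), key=lambda i: -perturbed[i])

rng = random.Random(0)
scores = [3.0, 1.0, 2.0]
draws = [sample_pl(scores, rng) for _ in range(20000)]

# Under PL, item i is ranked first with probability s_i / sum(s).
freq_first = [sum(d[0] == i for d in draws) / len(draws) for i in range(len(scores))]
```

Replacing the final hard sort with the relaxed matrix evaluated at the same perturbed scores gives a differentiable function of `s` for fixed noise, which is what makes the Monte Carlo gradient estimate possible.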
4. Stochastic Optimization Workflow
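The NeuralSort stochastic optimization loop proceeds as follows: draw Gumbel noise $g$, form the perturbed log-scores $\log s + g$, compute the relaxed permutation matrix $\hat{P}_{\text{sort}(\log s + g)}(\tau)$, evaluate the downstream loss on it, and backpropagate to update $s$ and any model parameters. In the limit $\tau \to 0^+$, row-wise arg max can be applied to recover hard permutations, supporting straight-through optimization (Grover et al., 2019).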
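A toy version of such a loop is sketched below. It is not the paper's training code: the objective (squared error to a fixed target permutation matrix), the hyperparameters, and all function names are hypothetical. Central finite differences with shared Gumbel noise stand in for automatic differentiation so that the example runs with the standard library alone.

```python
import math
import random

def softmax(x):
    m = max(x)
    e = [math.exp(v - m) for v in x]
    z = sum(e)
    return [v / z for v in e]

def neural_sort(s, tau):
    """Relaxed sort matrix: row i is softmax(((n + 1 - 2i) s - A_s 1) / tau)."""
    n = len(s)
    row_sums = [sum(abs(sj - sk) for sk in s) for sj in s]
    return [softmax([((n + 1 - 2 * i) * s[j] - row_sums[j]) / tau
                     for j in range(n)]) for i in range(1, n + 1)]

def loss(theta, g, target, tau):
    """Squared error between the relaxed permutation and a target matrix."""
    P = neural_sort([t + gi for t, gi in zip(theta, g)], tau)
    return sum((p - t) ** 2 for Pr, Tr in zip(P, target) for p, t in zip(Pr, Tr))

def train(target, steps=300, lr=0.2, tau=1.0, samples=8, h=1e-4, seed=0):
    rng = random.Random(seed)
    n = len(target)
    theta = [0.0] * n                       # theta = log s keeps scores positive
    for _ in range(steps):
        grad = [0.0] * n
        for _ in range(samples):
            # Reparameterization: noise is drawn once, then held fixed while
            # the loss is differentiated with respect to theta.
            g = [-math.log(-math.log(max(rng.random(), 1e-12))) for _ in range(n)]
            for k in range(n):
                up = theta[:k] + [theta[k] + h] + theta[k + 1:]
                dn = theta[:k] + [theta[k] - h] + theta[k + 1:]
                diff = loss(up, g, target, tau) - loss(dn, g, target, tau)
                grad[k] += diff / (2 * h * samples)
        theta = [t - lr * gk for t, gk in zip(theta, grad)]
    return theta

target = [[0, 0, 1], [1, 0, 0], [0, 1, 0]]  # desired order: item 2, item 0, item 1
theta = train(target)
hard_perm = sorted(range(len(theta)), key=lambda i: -theta[i])  # hard projection
```

In an autodiff framework the inner finite-difference loop would be replaced by a single backward pass, and the hard projection would typically be used only for evaluation or in a straight-through forward pass.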
5. Complexity Analysis and Computational Characteristics
The construction of the pairwise-difference matrix $A_s$ requires $O(n^2)$ operations, fully parallelizable on GPUs. Each softmax operation per row costs $O(n)$, yielding $O(n^2)$ complexity per forward pass with no iterative normalization. Memory requirements are $O(n^2)$ per relaxed permutation per sample.
Comparatively:
| Approach | Forward Pass Complexity | GPU Parallelism | Differentiability |
|---|---|---|---|
| Standard sorting | $O(n \log n)$ | Limited | No |
| Sinkhorn-based relaxations | $O(n^2)$ per iteration | Good | Yes (doubly-stochastic) |
| NeuralSort | $O(n^2)$ one-shot | High | Yes (unimodal row-stochastic) |
NeuralSort’s single-pass $O(n^2)$ computation is competitive with, and often faster than, iterative Sinkhorn methods for practical $n$ (Grover et al., 2019).
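For contrast with the one-shot computation, the iterative structure of a Sinkhorn-style relaxation can be sketched as below. This is a generic illustration of alternating row and column normalization, not the specific baseline implementation evaluated in the paper; the input matrix, temperature, and iteration count are assumptions.

```python
import math

def sinkhorn(X, tau=1.0, iters=100):
    """Alternately normalize rows and columns of exp(X / tau)."""
    n = len(X)
    M = [[math.exp(x / tau) for x in row] for row in X]
    for _ in range(iters):                  # each iteration costs O(n^2)
        M = [[v / rs for v in row] for row, rs in ((r, sum(r)) for r in M)]
        col = [sum(M[i][j] for i in range(n)) for j in range(n)]
        M = [[M[i][j] / col[j] for j in range(n)] for i in range(n)]
    return M

X = [[0.2, -0.5, 1.0], [0.7, 0.1, -0.3], [-0.9, 0.4, 0.6]]
D = sinkhorn(X)
row_sums = [sum(row) for row in D]
col_sums = [sum(D[i][j] for i in range(3)) for j in range(3)]
```

The result is approximately doubly stochastic, but reaching it takes repeated $O(n^2)$ passes, whereas the NeuralSort relaxation needs a single pass and constrains only the rows.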
6. Empirical Performance and Impact
NeuralSort and its stochastic extension (via PL reparameterization) were evaluated on several tasks:
- Semantic Sorting (large-MNIST): Deterministic NeuralSort achieves approximately 84% exact permutation accuracy, outperforming Sinkhorn baselines (3–9%) and a naïve row-stochastic predictor (9%). Individual-rank accuracy (the fraction of correctly placed elements) improves from 60% (Sinkhorn) to 92% (NeuralSort).
- Quantile Regression (median estimation): NeuralSort attains a lower mean squared error than Sinkhorn, with the coefficient of determination $R^2$ improving from 0.25 to 0.94.
- Differentiable k-Nearest Neighbors (top-$k$ selection):
  - MNIST: 99.5% (NeuralSort) vs. 97.2% (standard kNN), 99.4% (CNN)
  - Fashion-MNIST: 93.5% (NeuralSort) vs. 85.8% (kNN), 93.4% (CNN)
  - CIFAR-10: 90.7% (NeuralSort) vs. 35.4% (kNN), 95.1% (CNN)
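The two semantic-sorting metrics mentioned in the list above, exact permutation accuracy and individual-rank accuracy, can be computed as in the following sketch. The function names and the toy predictions are hypothetical and are not data from the paper.

```python
def exact_match_accuracy(pred, true):
    """Fraction of sequences whose predicted permutation is entirely correct."""
    return sum(p == t for p, t in zip(pred, true)) / len(true)

def elementwise_accuracy(pred, true):
    """Fraction of individual positions that hold the correct element."""
    correct = sum(pi == ti for p, t in zip(pred, true) for pi, ti in zip(p, t))
    total = sum(len(t) for t in true)
    return correct / total

pred = [[2, 0, 1], [0, 1, 2], [1, 0, 2]]   # toy predicted orderings
true = [[2, 0, 1], [0, 2, 1], [1, 0, 2]]   # toy ground-truth orderings
exact = exact_match_accuracy(pred, true)
elementwise = elementwise_accuracy(pred, true)
```

Exact-match accuracy is the stricter of the two, since a single misplaced element makes the whole sequence count as wrong.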
Across all tasks, stochastic NeuralSort offers accuracy comparable to its deterministic version while enabling principled uncertainty estimation for permutations. This framework supports a one-shot, differentiable surrogate for sorting, hard permutation projection for metrics, and a reparameterized estimator for optimizing over permutation distributions in deep learning pipelines (Grover et al., 2019).