ReLU Networks as Random Functions: Their Distribution in Probability Space (2503.22082v1)
Abstract: This paper presents a novel framework for understanding trained ReLU networks as random, affine functions, where the randomness is induced by the distribution over the inputs. By characterizing the probability distribution of the network's activation patterns, we derive the discrete probability distribution over the affine functions realizable by the network. We extend this analysis to describe the probability distribution of the network's outputs. Our approach provides explicit, numerically tractable expressions for these distributions in terms of Gaussian orthant probabilities. Additionally, we develop approximation techniques to identify the support of affine functions a trained ReLU network can realize for a given distribution of inputs. Our work provides a framework for understanding the behavior and performance of ReLU networks corresponding to stochastic inputs, paving the way for more interpretable and reliable models.
Summary
- The paper models trained ReLU networks as random functions by considering input uncertainty, specifically assuming inputs follow a Gaussian Mixture Model (GMM).
- It derives exact, computationally tractable expressions for the distribution over network activation patterns and outputs by reducing the problem to computing Gaussian orthant probabilities.
- A sample-free algorithm is introduced to efficiently identify and approximate the support of high-probability activation patterns for large networks, addressing computational complexity.
The paper "ReLU Networks as Random Functions: Their Distribution in Probability Space" (2503.22082) introduces a framework for analyzing trained feed-forward neural networks with Rectified Linear Unit (ReLU) or other piecewise linear activations by modeling their behavior under input uncertainty. The core idea is to treat the network not as a fixed deterministic function, but as a random function where the randomness is induced by the probability distribution governing the input data x
. This approach leverages the piecewise affine nature of ReLU networks to derive exact, computationally tractable expressions for the probability distributions over the network's internal states (activation patterns) and its final output.
Modeling Trained Networks as Random Functions
A key property of networks composed of ReLU activations (or similar piecewise linear functions) is that they partition the input space $\mathbb{R}^{n_0}$ into a finite set of disjoint convex polytopes $\{K(\zeta')\}$. Within each polytope $K(\zeta')$, the network function $f:\mathbb{R}^{n_0}\to\mathbb{R}^{n_L}$ behaves precisely as a single affine transformation:

$$f(x) = C_{\zeta'}\,x + d_{\zeta'} \qquad \text{for } x \in K(\zeta').$$

Here, $\zeta'$ represents a specific activation pattern, which is a binary vector indicating the activation status (output $>0$ or output $\leq 0$) of every neuron in the hidden layers. Formally, $\zeta' = [z_1^T, \ldots, z_{L-1}^T]^T$, where $z_l$ is the binary vector for hidden layer $l$. Each unique activation pattern $\zeta'$ corresponds to a unique polytope $K(\zeta')$ and a unique affine function $(C_{\zeta'}, d_{\zeta'})$.
When the input $x$ is drawn from a probability distribution $p_x(x)$, the specific polytope $K(\zeta')$ containing $x$, and thus the activation pattern $\zeta'$ induced by $x$, becomes a random variable. Consequently, the trained network itself can be viewed as selecting a random affine function from the finite set $\{(C_{\zeta'}, d_{\zeta'})\}$ it can realize, based on the input distribution. The paper focuses on characterizing the probability distribution governing this selection process and the resulting distribution of the network's output $y = f(x)$. The analysis primarily assumes the input $x$ follows a Gaussian Mixture Model (GMM), $p_x(x) = \sum_{m=1}^{M} w_m\, \mathcal{N}(x;\, \mu_m, \Sigma_m)$.
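To make the pattern-to-affine-function correspondence concrete, here is a minimal NumPy sketch (the function name and the `weights`/`biases` layer lists are illustrative, not the paper's notation) that runs an input through a ReLU network, records the activation pattern, and composes the affine map $(C_{\zeta'}, d_{\zeta'})$ that the network applies on the polytope containing that input.

```python
import numpy as np

def relu_forward_with_pattern(weights, biases, x):
    """Forward pass that also returns the activation pattern and the
    affine map (C, d) with f(x) = C x + d on the polytope containing x."""
    a = np.asarray(x, dtype=float)
    C, d = np.eye(len(a)), np.zeros(len(a))
    pattern = []
    for l, (W, b) in enumerate(zip(weights, biases)):
        h = W @ a + b                      # pre-activation of layer l
        C, d = W @ C, W @ d + b            # compose the affine map so far
        if l < len(weights) - 1:           # hidden layer: apply ReLU
            z = (h > 0).astype(int)        # activation bits z_l
            pattern.append(z)
            a = z * h
            C, d = z[:, None] * C, z * d   # zero out rows the ReLU switches off
        else:                              # linear output layer
            a = h
    return a, pattern, (C, d)

# Tiny usage check on a random 2-16-16-3 network: the composed affine map
# must reproduce the network output at x exactly.
rng = np.random.default_rng(0)
dims = [2, 16, 16, 3]
weights = [rng.standard_normal((m, n)) for n, m in zip(dims[:-1], dims[1:])]
biases = [rng.standard_normal(m) for m in dims[1:]]
x = rng.standard_normal(dims[0])
y, pattern, (C, d) = relu_forward_with_pattern(weights, biases, x)
assert np.allclose(y, C @ x + d)
```

The concatenated bits in `pattern` form $\zeta'$; any two inputs inducing the same pattern share the same $(C_{\zeta'}, d_{\zeta'})$.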
Distribution Over Activation Patterns
The central goal is to determine the Probability Mass Function (PMF) of the activation-pattern random variable $\zeta$, denoted $P(\zeta=\zeta')$. This PMF quantifies the probability that, for a randomly drawn input $x \sim p_x(x)$, the network operates according to the specific affine function $(C_{\zeta'}, d_{\zeta'})$.
Initially, this probability can be expressed as the integral of the input density over the corresponding polytope:

$$P(\zeta=\zeta') = \int_{K(\zeta')} p_x(x)\, dx.$$

Directly computing this integral is generally intractable due to the complex geometry of the polytope $K(\zeta')$, which is defined by the intersection of half-spaces related to neuron pre-activations across all layers.
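To make the "intersection of half-spaces" description concrete, the sketch below (a hypothetical helper, not from the paper) assembles the H-representation $K(\zeta') = \{x : Gx \le g\}$: each hidden neuron contributes one inequality on its pre-activation, with the ReLUs of earlier layers fixed to the 0/1 masks dictated by $\zeta'$.

```python
import numpy as np

def polytope_halfspaces(weights, biases, zeta):
    """H-representation (G, g) of K(zeta'), i.e. {x : G x <= g}.
    zeta is a list of 0/1 vectors, one per hidden layer. Active neurons
    (bit 1) require h_{l,j} > 0, inactive ones (bit 0) require h_{l,j} <= 0,
    where h_l = A_l x + c_l is the pre-activation map obtained by fixing
    the earlier ReLUs according to zeta (strictness on the boundary is
    ignored, since it has zero probability under a continuous density)."""
    G_rows, g_rows = [], []
    A_prev = np.eye(weights[0].shape[1])    # a_{l-1} = A_prev x + c_prev
    c_prev = np.zeros(weights[0].shape[1])
    for W, b, z in zip(weights[:-1], biases[:-1], zeta):
        A_l, c_l = W @ A_prev, W @ c_prev + b
        s = np.where(np.asarray(z) == 1, -1.0, 1.0)   # flip sign for active rows
        G_rows.append(s[:, None] * A_l)               # s * (A_l x + c_l) <= 0
        g_rows.append(-s * c_l)
        mask = np.diag(z)
        A_prev, c_prev = mask @ A_l, mask @ c_l       # fixed-mask ReLU
    return np.vstack(G_rows), np.concatenate(g_rows)
```

Integrating $p_x$ over such a general polytope has no closed form, which is what motivates the orthant reformulation described next.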
The paper introduces a crucial transformation (Proposition 1) that recasts this problem in terms of Gaussian orthant probabilities. The condition $x \in K(\zeta')$ is equivalent to a set of sign constraints on the pre-activation values $h_l = W_l a_{l-1} + b_l$ for all hidden layers $l = 1, \ldots, L-1$. Let $h = [h_1^T, \ldots, h_{L-1}^T]^T$ be the concatenated vector of all hidden-layer pre-activations. The activation pattern $\zeta'$ corresponds to $h$ lying within a specific orthant $O(\zeta')$ in the space of $h$.
Under the assumption that $x$ follows a GMM, the vector $h$ also follows a GMM, as it is obtained through a sequence of affine transformations and piecewise linear activations (which maintain the mixture structure, albeit modifying the components). Proposition 1 shows that the PMF can be computed as

$$P(\zeta=\zeta') = \sum_{m=1}^{M} w_m\, P_{h \mid k=m}\!\left(h \in O(\zeta')\right),$$

where $P_{h \mid k=m}(h \in O(\zeta'))$ is the probability that the vector $h$, conditioned on the input $x$ being drawn from the $m$-th Gaussian component, falls into the orthant $O(\zeta')$. This is precisely a Gaussian orthant probability—the integral of a multivariate Gaussian distribution over a region defined by sign constraints ($\mathbb{R}_{\geq 0}$ or $\mathbb{R}_{<0}$ for each dimension). These probabilities are more amenable to numerical computation using established algorithms (e.g., methods developed by Genz and Bretz). This formulation replaces the complex polytope integration with a sum of lower-dimensional (if the total number of hidden units is smaller than the input dimension) orthant probability calculations.
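One way to realize this computation numerically is sketched below (function names are ours, not the paper's). For a fixed pattern $\zeta'$, each ReLU is replaced by the 0/1 mask the pattern dictates, so the stacked pre-activations become an affine function of $x$ and hence Gaussian per mixture component; inside the orthant the surrogate and the true pre-activations coincide, so the event probability is unchanged. The orthant probability itself is evaluated with SciPy's multivariate normal CDF (a Genz-style routine), under the assumption that the stacked covariance is non-singular (e.g., total hidden width not exceeding the input dimension).

```python
import numpy as np
from scipy.stats import multivariate_normal

def preactivation_gaussian(weights, biases, zeta, mu, Sigma):
    """Mean/covariance of the stacked hidden pre-activations h = A x + c
    when x ~ N(mu, Sigma) and every ReLU is fixed to the 0/1 mask in zeta."""
    A_blocks, c_blocks = [], []
    A_prev, c_prev = np.eye(len(mu)), np.zeros(len(mu))
    for W, b, z in zip(weights[:-1], biases[:-1], zeta):
        A_l, c_l = W @ A_prev, W @ c_prev + b
        A_blocks.append(A_l)
        c_blocks.append(c_l)
        A_prev, c_prev = np.diag(z) @ A_l, np.diag(z) @ c_l
    A, c = np.vstack(A_blocks), np.concatenate(c_blocks)
    return A @ mu + c, A @ Sigma @ A.T

def orthant_probability(mean, cov, signs):
    """P(h in the orthant given by signs) for h ~ N(mean, cov).
    signs[i] = +1 requires h_i > 0, -1 requires h_i <= 0; flipping the
    positive coordinates turns the event into {g <= 0}, an ordinary CDF.
    Assumes cov is non-singular (see the dimensionality remark above)."""
    D = np.diag(-np.asarray(signs, dtype=float))
    g = multivariate_normal(D @ mean, D @ cov @ D.T)
    return g.cdf(np.zeros(len(mean)))

def pattern_probability(weights, biases, zeta, mix_w, means, covs):
    """Proposition-1-style estimate of P(zeta = zeta') under a GMM input."""
    signs = np.where(np.concatenate(zeta) == 1, 1.0, -1.0)
    total = 0.0
    for w, mu, Sigma in zip(mix_w, means, covs):
        mean, cov = preactivation_gaussian(weights, biases, zeta, mu, Sigma)
        total += w * orthant_probability(mean, cov, signs)
    return total
```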
Distribution Over Network Outputs
The framework is extended to characterize the probability density function (PDF) $p_y(y)$ of the network's output vector $y = f(x)$. This is achieved by marginalizing the joint density $p(x, y)$ over the input $x$, with the marginalization decomposed according to the activation patterns:

$$p_y(y) = \sum_{\zeta' \in \mathrm{supp}(\zeta)} \int_{K(\zeta')} p(x, y \mid \zeta=\zeta')\, P(\zeta=\zeta')\, dx,$$

where the sum is over all activation patterns $\zeta'$ that have non-zero probability ($P(\zeta=\zeta') > 0$).
A significant theoretical result (Proposition 2) states that if the input $x$ follows a GMM, the output $y$ is distributed as a mixture of truncated multivariate Gaussian distributions. This arises because, for a fixed pattern $\zeta'$, the output $y = C_{\zeta'} x + d_{\zeta'}$ is an affine transformation of $x$. When $x$ is restricted to the polytope $K(\zeta')$, the underlying Gaussian components of the input GMM become truncated, and the affine transformation of these truncated Gaussians results in truncated Gaussians in the output space. Summing these conditional densities (weighted by $P(\zeta=\zeta')$) over all possible $\zeta'$ yields the final mixture distribution for $y$. This finding contrasts with simpler assumptions sometimes made in related literature, highlighting that intermediate and output layers of ReLU networks do not necessarily maintain simple Gaussian distributions even with Gaussian inputs, due to the input domain partitioning.
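A quick sampling-based sanity check of this structure (illustrative only, and deliberately not part of the paper's sample-free machinery): draw inputs from the GMM, run the network, and bucket the outputs by the activation pattern they induce. Each bucket then collects draws from one affine image of a truncated Gaussian, i.e., one component of the output mixture, and the relative bucket sizes estimate $P(\zeta=\zeta')$.

```python
import numpy as np

rng = np.random.default_rng(0)

def forward(weights, biases, x):
    """Forward pass returning the output and the flattened 0/1 pattern."""
    a, bits = np.asarray(x, dtype=float), []
    for l, (W, b) in enumerate(zip(weights, biases)):
        h = W @ a + b
        if l < len(weights) - 1:
            z = (h > 0).astype(int)
            bits.append(z)
            a = z * h
        else:
            a = h
    return a, tuple(np.concatenate(bits))

def outputs_by_pattern(weights, biases, mix_w, means, covs, n=50_000):
    """Sample x ~ GMM, run the network, and group outputs by pattern.
    Each bucket holds draws from one truncated-Gaussian piece of p_y(y)."""
    buckets = {}
    comps = rng.choice(len(mix_w), size=n, p=list(mix_w))
    for k in comps:
        x = rng.multivariate_normal(means[k], covs[k])
        y, pattern = forward(weights, biases, x)
        buckets.setdefault(pattern, []).append(y)
    return {p: np.asarray(v) for p, v in buckets.items()}
```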
Computationally, Proposition 3 provides an expression for $p_y(y)$ that again relies on Gaussian orthant probabilities, avoiding direct integration over the complex polytopes or explicit handling of truncated Gaussians. It considers the joint distribution of the hidden pre-activations $h$ and the output $y$: let $\tilde{h} = [h_1^T, \ldots, h_{L-1}^T, y^T]^T$. Under GMM inputs, this combined vector is also a GMM. The output PDF is then expressed by integrating a related Gaussian density over the orthant $O(\zeta')$ corresponding to the hidden activations, effectively linking the output value $y$ to the probability of the activation patterns consistent with producing that output. The resulting formula sums contributions over each $\zeta'$ and each input GMM component, where each term requires a Gaussian orthant probability computed from the joint distribution.
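The per-pattern, per-component quantities entering such an expression can be assembled in the same way as the hidden pre-activations above. The short sketch below (our notation, under the same fixed-mask reading as before) stacks the output layer on top of the hidden pre-activations to obtain the mean and covariance of $\tilde{h} = [h^T, y^T]^T$ for one Gaussian input component.

```python
import numpy as np

def joint_preactivation_output_gaussian(weights, biases, zeta, mu, Sigma):
    """Mean/covariance of tilde_h = [h_1; ...; h_{L-1}; y] = A x + c when
    x ~ N(mu, Sigma) and every hidden ReLU is fixed to the mask in zeta."""
    A_blocks, c_blocks = [], []
    A_prev, c_prev = np.eye(len(mu)), np.zeros(len(mu))
    for l, (W, b) in enumerate(zip(weights, biases)):
        A_l, c_l = W @ A_prev, W @ c_prev + b
        A_blocks.append(A_l)
        c_blocks.append(c_l)
        if l < len(weights) - 1:                 # hidden layer: fixed mask
            mask = np.diag(zeta[l])
            A_prev, c_prev = mask @ A_l, mask @ c_l
    A, c = np.vstack(A_blocks), np.concatenate(c_blocks)
    return A @ mu + c, A @ Sigma @ A.T           # last n_L rows/cols are y
```

Conditioning this joint Gaussian on the hidden block lying in $O(\zeta')$ and keeping the $y$ block is what gives rise to the truncated-Gaussian mixture components described above.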
Computation via Gaussian Orthant Probabilities
Gaussian orthant probabilities are fundamental to the computational tractability of the proposed framework. Their emergence is natural: a ReLU activation $\max(0, z)$ acts as a selector based on the sign of its input $z$. An activation pattern $\zeta'$ is determined by the signs of all hidden pre-activations $h_{l,j}$, so the probability $P(\zeta=\zeta')$ is the probability that the random vector $h$ falls into the specific orthant $O(\zeta')$ defined by the required signs. When $x$ (and consequently $h$) follows a GMM, this probability calculation becomes a sum of standard Gaussian orthant probability problems.
Computing $P(X \in O)$ for $X \sim \mathcal{N}(\mu, \Sigma)$ and an orthant $O$ is a non-trivial task, especially in high dimensions. However, specialized numerical algorithms exist, such as quasi-Monte Carlo methods (e.g., Genz's algorithm) or analytical approximations, which can estimate these probabilities with controlled accuracy; the paper leverages these existing methods. The dimensionality of the orthant probability calculation corresponds to the total number of hidden neurons, which can be large but potentially smaller than the input dimension $n_0$.
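For intuition about this basic primitive, the toy check below compares SciPy's multivariate normal CDF (which wraps a Genz-type routine) against brute-force sampling on a random 4-dimensional Gaussian and the orthant with signs (+, −, +, −). The sign flip turns the orthant event into an ordinary lower-tail CDF evaluation at the origin; the setup is purely illustrative.

```python
import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(0)

n = 4
A = rng.standard_normal((n, n))
mu, Sigma = rng.standard_normal(n), A @ A.T + n * np.eye(n)   # well-conditioned
signs = np.array([1, -1, 1, -1])          # +1: coordinate > 0, -1: coordinate <= 0

D = np.diag(-signs.astype(float))         # g = D h turns the orthant into {g <= 0}
dist = multivariate_normal(D @ mu, D @ Sigma @ D.T)
p_cdf = dist.cdf(np.zeros(n))             # Genz-style quadrature

samples = rng.multivariate_normal(mu, Sigma, size=200_000)
p_mc = np.mean(np.all(samples * signs > 0, axis=1))
print(p_cdf, p_mc)                        # should agree to roughly 2-3 digits
```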
Approximating the Support of Activation Patterns
A major practical challenge is that the number of possible activation patterns $|\{\zeta'\}|$ can grow exponentially with the number of neurons, potentially reaching $2^N$ where $N$ is the total number of hidden neurons. Calculating $P(\zeta=\zeta')$ for every possible $\zeta'$ to compute the full PMF or output PDF is often computationally infeasible. However, empirical observations suggest that for a given input distribution, only a small subset of all possible activation patterns typically has significant probability mass.
To address this, the paper proposes a sample-free algorithm (Algorithm 1) to efficiently identify a reduced set of high-probability activation patterns, effectively approximating the support $\mathrm{supp}(\zeta)$. This algorithm operates layer by layer:
- Initialization: Start with the input distribution (GMM).
- Layer Propagation: For layer $l$, given the distribution of the activations $a_{l-1}$ from the previous layer (represented as a mixture derived from high-probability patterns of preceding layers), compute the distribution of the pre-activations $h_l = W_l a_{l-1} + b_l$. This will also be a mixture model.
- Neuron Pruning: For each neuron $j$ in layer $l$, calculate the marginal probability $P(h_{l,j} > 0)$. This is computed using the mixture distribution of $h_l$.
- Branching Decision: If $P(h_{l,j} > 0)$ is very close to 0 or 1 (i.e., the neuron's activation state has low entropy), fix the corresponding bit $z_{l,j}$ in the pattern $\zeta'$. If the probability is far from 0 or 1 (high entropy), the algorithm branches, exploring both possibilities ($z_{l,j}=0$ and $z_{l,j}=1$).
- Pattern Generation: Recursively build partial activation patterns layer by layer, pruning branches where the accumulated probability (product of conditional probabilities along the path) falls below a threshold or where neurons are deemed deterministic.
- Output: The algorithm returns a list (curList in the paper's pseudocode) of activation patterns $\zeta'$ estimated to have high probability, along with their approximate probabilities.
This procedure acts like a guided search over the tree of possible activation patterns, focusing computational effort on the regions of the pattern space that are most likely given the input distribution. The resulting list of high-probability patterns can then be used to approximate the full PMF $P(\zeta)$ and the output PDF $p_y(y)$ by summing only over this reduced set, significantly reducing computational cost while capturing most of the probability mass.
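The sketch below is a hypothetical implementation in the spirit of Algorithm 1, not a transcription of the paper's pseudocode: it handles a single Gaussian input component, computes per-neuron marginals under the fixed-mask affine surrogate, fixes bits whose marginal is within a tolerance of 0 or 1, branches otherwise, and prunes branches whose accumulated weight falls below a threshold. The branch weights multiply per-bit marginals, which is itself an approximation; the surviving patterns could then be passed to the orthant-probability computation above to refine their probabilities.

```python
import numpy as np
from scipy.stats import norm

def approximate_pattern_support(weights, biases, mu, Sigma,
                                tol=1e-3, prune_below=1e-6):
    """Layer-wise branch-and-prune search for high-probability activation
    patterns when x ~ N(mu, Sigma) (a single GMM component).
    Returns a list of (pattern, approximate weight) pairs."""
    # Each search state: (bits so far, affine map of a_{l-1}, offset, weight).
    states = [([], np.eye(len(mu)), np.zeros(len(mu)), 1.0)]
    for W, b in zip(weights[:-1], biases[:-1]):          # hidden layers only
        new_states = []
        for bits, A, c, wgt in states:
            A_l, c_l = W @ A, W @ c + b                  # h_l = A_l x + c_l
            m = A_l @ mu + c_l
            s = np.sqrt(np.maximum(np.diag(A_l @ Sigma @ A_l.T), 1e-12))
            p_on = norm.sf(0.0, loc=m, scale=s)          # P(h_{l,j} > 0)
            branches = [(np.zeros(0, dtype=int), wgt)]
            for pj in p_on:
                grown = []
                for z, w in branches:
                    if pj > 1.0 - tol:                   # effectively always on
                        grown.append((np.append(z, 1), w))
                    elif pj < tol:                       # effectively always off
                        grown.append((np.append(z, 0), w))
                    else:                                # uncertain: branch
                        grown.append((np.append(z, 1), w * pj))
                        grown.append((np.append(z, 0), w * (1.0 - pj)))
                branches = [(z, w) for z, w in grown if w > prune_below]
            for z, w in branches:
                mask = np.diag(z)
                new_states.append((bits + [z], mask @ A_l, mask @ c_l, w))
        states = new_states
    return [(bits, w) for bits, _, _, w in states]
```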
Implementation Considerations and Applications
Implementing this framework requires several components:
- GMM Representation: The ability to represent and manipulate GMMs through affine transformations is needed.
- Orthant Probability Solver: An efficient and accurate implementation for computing multivariate Gaussian orthant probabilities is essential. The computational cost of these solvers typically scales poorly with the dimension (number of hidden neurons).
- Support Approximation Algorithm: Implementation of the layer-wise pruning algorithm (Algorithm 1) is needed for practical application to large networks. Its effectiveness depends on the distribution of neuron activation probabilities – it works best when many neurons are strongly biased towards being active or inactive for the given input distribution.
Limitations and Trade-offs:
- GMM Input Assumption: The exact analytical results rely on the input distribution being a GMM. While GMMs can approximate many distributions, this assumption might not hold in all practical scenarios. Extensions to other input distributions might require different analytical tools or rely on approximations.
- Computational Cost: Calculating orthant probabilities can be computationally intensive, especially for networks with many hidden neurons. The support approximation algorithm mitigates this but introduces an approximation error. There is a trade-off between computational cost and the accuracy of the approximated distributions.
- Scalability: Both the exact computation and the approximation algorithm's complexity scale with the network size (number of neurons and layers). Applying this framework to very deep or wide networks might still be challenging.
Potential Applications:
- Robustness Analysis: Understanding the distribution $P(\zeta=\zeta')$ can reveal how sensitive the network's functional behavior is to input perturbations characterized by $p_x(x)$. A highly peaked distribution suggests functional stability, while a flatter distribution indicates sensitivity.
- Interpretability: Identifying the most probable affine functions $(C_{\zeta'}, d_{\zeta'})$ the network uses for a class of inputs provides insight into the dominant modes of operation.
- Output Uncertainty Quantification: The derived output PDF $p_y(y)$ provides a principled way to quantify the uncertainty in the network's predictions arising from input uncertainty, potentially more accurately than Monte Carlo sampling, especially for rare events.
- Model Comparison: The framework could potentially be used to compare different trained models based on their probabilistic behavior under input uncertainty.
Conclusion
The paper provides a rigorous mathematical framework for analyzing trained ReLU networks as random functions induced by stochastic inputs, specifically assuming a Gaussian Mixture Model for the input distribution. By leveraging the piecewise affine structure and transforming the problem into the computation of Gaussian orthant probabilities, it derives exact expressions for the distribution over the network's realized affine functions (via activation patterns) and the distribution of its outputs. Recognizing the computational barrier of evaluating all possible functions, it also presents a practical, sample-free algorithm to approximate the set of most probable functions. This work offers valuable tools for deeper understanding, analysis, and potentially enhancing the reliability of ReLU networks in applications involving uncertain inputs.