- The paper introduces Discrete Spatial Diffusion (DSD), a framework that preserves total intensity while modeling discrete spatial data using a continuous-time jump process.
- It employs a neural network to predict per-pixel reverse transition rates, trained with an SSIM-based noise schedule and either a rate-matching or a process-likelihood loss to learn the reverse dynamics.
- Empirical results demonstrate DSD's capability in image synthesis, inpainting, and generating realistic microstructures for materials science applications.
This paper introduces Discrete Spatial Diffusion (DSD), a novel generative modeling framework designed to address limitations of traditional diffusion models when applied to data that is inherently discrete and governed by strict conservation laws, such as mass preservation. Unlike standard diffusion models that often operate in continuous intensity spaces and apply noise independently per pixel (like Gaussian diffusion or even discrete-state models that diffuse intensity independently), DSD operates directly on discrete units of intensity (treated as particles) on a spatial lattice.
The core idea of DSD is based on a continuous-time, discrete-state jump stochastic process. The forward diffusion process involves these discrete intensity units performing independent random walks on a 2D spatial lattice (though the framework generalizes to higher dimensions). Each particle can jump to one of its nearest neighbors at a defined rate. Crucially, this jump process exactly conserves the total number of particles (total intensity) within the system for each color channel independently, both in the forward (corruption) and reverse (generation) processes. Boundary conditions (no-flux or periodic) determine how particles behave at the edges of the domain.
The forward process naturally introduces spatial correlation into the noise, contrasting with methods that add uncorrelated Gaussian noise. The paper provides the mathematical formulation of the Master Equation governing this process and shows how the transition probabilities can be computed efficiently using the Fast Fourier Transform (FFT) under periodic boundary conditions.
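In generic form (the paper gives the exact statement; indexing and boundary handling may differ), the single-particle master equation for a symmetric nearest-neighbor walk with per-direction jump rate r reads

$$
\frac{\partial p_t(x, y)}{\partial t} \;=\; r \sum_{(\nu_x, \nu_y) \,\in\, \{(\pm 1, 0),\, (0, \pm 1)\}} \big[\, p_t(x + \nu_x,\, y + \nu_y) - p_t(x, y) \,\big].
$$

Summing both sides over all lattice sites shows that total probability, and hence total particle count, is conserved. Under periodic boundary conditions the jump generator is diagonalized by the discrete Fourier transform, so the full transition kernel follows from a single FFT pair. A minimal NumPy sketch (function name and conventions are our assumptions, not the paper's code):

```python
import numpy as np

def transition_kernel(W: int, H: int, r: float, t: float) -> np.ndarray:
    """p_t(x', y' | 0, 0) for a single particle on a W x H periodic lattice,
    jumping to each of its four nearest neighbors at rate r."""
    kx = 2 * np.pi * np.fft.fftfreq(W)
    ky = 2 * np.pi * np.fft.fftfreq(H)
    KX, KY = np.meshgrid(kx, ky, indexing="ij")
    # Fourier eigenvalues of the nearest-neighbor jump generator (all <= 0)
    eigvals = 2 * r * (np.cos(KX) + np.cos(KY) - 2)
    # exp(t * generator) applied to a delta function at the origin
    p_t = np.fft.ifft2(np.exp(t * eigvals)).real
    return np.clip(p_t, 0.0, None)  # clip tiny negative FFT round-off

# By translation invariance on the torus:
# p_t(x', y' | x, y) = transition_kernel(W, H, r, t)[(x' - x) % W, (y' - y) % H]
```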
To train a neural network (NN) to reverse this process, the paper leverages the theoretical framework for reverse-time dynamics of continuous-time Markov systems. The reverse transition rates for a particle depend on its current location and initial condition, as well as the forward solution p_t(· | ·). Since the initial conditions of individual particles are unknown during inference, the NN is trained to model the per-pixel reverse transition rates: for a pixel (x, y, c) with n particles, the total reverse rate for jumping to a neighbor (x + ν̄_x, y + ν̄_y, c) is the sum of the individual reverse rates of the n particles.
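For a single particle started at site x₀, the time reversal of a continuous-time Markov chain gives reverse rates of the standard form (a sketch in generic notation; the paper's may differ):

$$
\bar{r}_t(y \to x) \;=\; r(x \to y)\,\frac{p_t(x \mid x_0)}{p_t(y \mid x_0)}.
$$

The per-pixel quantity the NN learns is the sum of these single-particle rates over the n particles currently at a pixel; during training the origins x₀ are known from the forward simulation, while at inference the NN predicts the summed rates directly from (I_t, t).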
The NN is trained using samples corrupted by the forward process at discrete time steps. The time steps are chosen using a noise schedule designed to produce an even degradation of image quality (measured by SSIM) throughout the diffusion process. Two loss functions are explored: an L1 rate-matching loss between the NN predicted rates and the true rates, and a process likelihood loss derived from the negative log-likelihood of the NN-induced process matching the true reverse process. The authors empirically find no significant difference between these loss functions.
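In generic form (weighting and notation are our assumptions), the L1 rate-matching loss compares the NN's per-pixel rates R̂_θ against the exact reverse rates R̄ computed from the known forward trajectories,

$$
\mathcal{L}_{1} \;=\; \mathbb{E}_{t,\,I_0,\,I_t}\,\big\|\hat{R}_\theta(I_t, t) - \bar{R}(I_t, t \mid I_0)\big\|_1,
$$

while the process-likelihood loss follows the standard negative log-likelihood of a jump process, which integrates the total predicted rate along the trajectory and rewards high predicted rates at the observed jumps:

$$
-\log \mathcal{L} \;=\; \int_0^T \sum_{\nu,x,y,c} \hat{R}_\theta\,\mathrm{d}t \;-\; \sum_{\text{observed jumps}} \log \hat{R}_\theta.
$$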
For sampling (generating data from noise), the paper adapts τ-leaping, a common numerical integration method for discrete-state stochastic processes. To improve efficiency and avoid the issue of negative particle counts that can occur with standard Poisson τ-leaping, an adaptive step-size selection strategy is proposed. This strategy combines ideas from binomial τ-leaping and the Courant-Friedrichs-Lewy (CFL) condition to conservatively determine the time step τ, ensuring that the probability of a particle moving out of a pixel within that step remains bounded.
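In the notation of the sampling pseudocode below (ε is a small safety tolerance, our symbol), the bound amounts to

$$
\tau \;=\; \min\!\left(t,\;\frac{\epsilon}{\max_{\nu,x,y,c} \hat{R}_\theta(\nu,x,y,c)}\right),
$$

so that even the fastest pixel is unlikely to lose a large fraction of its particles within a single step.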
The DSD framework is implemented using a modified Noise Conditional Score Network (NCSN++) architecture, where the output layer predicts the rates for four directions (up, down, left, right) for each pixel and channel, and a SoftPlus activation ensures non-negativity of the rates.
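As a hedged illustration (our sketch, not the paper's code; only the four-direction output and the SoftPlus constraint come from the text), such an output head could look like:

```python
import torch
import torch.nn as nn

class RateHead(nn.Module):
    """Maps backbone features to non-negative reverse rates,
    one rate per direction (up/down/left/right) per image channel."""
    def __init__(self, feature_channels: int, image_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(feature_channels, 4 * image_channels,
                              kernel_size=3, padding=1)
        self.softplus = nn.Softplus()  # guarantees rates >= 0

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        B, _, H, W = h.shape
        rates = self.softplus(self.conv(h))   # (B, 4*C, H, W)
        return rates.view(B, 4, -1, H, W)     # (B, direction, C, H, W)
```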
Computational experiments demonstrate DSD's capabilities:
- Image Synthesis: On standard datasets like MNIST and CelebA, DSD achieves reasonable generative performance, showing it can capture complex spatial structures (Figure 3a, 3b, Appendix Figure 2).
- Image Inpainting: DSD can perform inpainting while preserving mass within the active region. By varying the total mass in the masked area, different digits can be generated from the same initial image (Figure 3c, Appendix Figure 3).
- Class Conditioning with Intensity Control: DSD can be conditioned on class labels and simultaneously allows for precise control over the total intensity (particle count) of the generated samples. This enables generating digits that are "bolder" or "lighter" while maintaining the correct class identity; conventional continuous diffusion models cannot enforce such intensity constraints exactly (Figure 3d, Appendix Figure 4; see the sketch after this list).
- Scientific Applications (Materials Microstructure):
- Subsurface Rocks: DSD was trained on microtomography data of different rock types (Berea Sandstone, Savonnières Carbonate, Massangis Limestone). The ability to exactly condition on total intensity (porosity/volume fraction) is highly valuable in geosciences for generating synthetic microstructures that match field measurements. Generated samples successfully replicate key statistical properties like spatial correlation and pore size distribution (Figure 5, Appendix Figure 6).
- Lithium-ion Electrodes: Trained on tomography data of NMC cathodes, DSD can generate samples with precisely tuned phase volume fractions (active material, carbon binder, pore space). This enables systematic computational studies of how microstructure impacts electrode performance by controlling phase composition exactly (Figure 7, Appendix Figures 11 & 12). Metrics like interface length, triple-phase boundary, and relative diffusivity are used to compare generated samples to real data.
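Across these experiments, exact intensity (or volume-fraction) conditioning reduces to initializing the reverse process from noise with a fixed particle count, which the dynamics then conserves exactly. A minimal single-channel sketch (our function, not the paper's):

```python
import numpy as np

def noise_with_total_intensity(W: int, H: int, N: int, rng=None) -> np.ndarray:
    """Place exactly N particles uniformly at random on a W x H lattice:
    a maximally mixed state whose total intensity is exactly N."""
    rng = rng or np.random.default_rng()
    counts = rng.multinomial(N, np.full(W * H, 1.0 / (W * H)))
    return counts.reshape(W, H)
```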
A limitation of DSD is that the computational cost of forward sampling for training and reverse sampling for inference scales linearly with the total intensity of the image. This makes it efficient for low-bit-depth or binary data but potentially less so than alternative methods for high-resolution, high-intensity images. Additionally, implementing DSD requires custom code for the forward process and sampling due to its discrete, spatial nature, presenting a steeper learning curve compared to standard continuous diffusion models.
In conclusion, DSD provides a principled approach to generative modeling for discrete spatial data, ensuring exact conservation laws are upheld. Its application to scientific domains like materials science demonstrates its potential to generate complex microstructures conditioned on critical global properties, bridging the gap between advanced generative models and the requirements of scientific data analysis and simulation.
The training step can be summarized in pseudocode (a sketch; `initialize_image`, `sample_from_transition_prob`, and `compute_ground_truth_rates` are assumed helpers):

```python
# One training step: corrupt I_0 to time t_k with the forward process, then
# regress the NN onto the exact per-pixel reverse rates.
# W, H, C are the image width, height, and channel count; p_t_k is the
# forward transition probability; r is the forward jump rate.
I_tk = initialize_image(I_0.shape, dtype=int)  # corrupted image at time t_k
particle_origins = []  # (origin pixel, current pixel) for every particle
for x in range(W):
    for y in range(H):
        for c in range(C):
            num_particles = I_0[x, y, c]
            if num_particles > 0:
                # Each particle at (x, y, c) in I_0 performs an independent
                # random walk: sample its position at time t_k from
                # p_t_k(x', y' | x, y). Channels never mix, so pos_c == c.
                sampled_positions = sample_from_transition_prob(
                    x, y, c, t_k, num_particles, p_t_k)
                for pos_x, pos_y, pos_c in sampled_positions:
                    I_tk[pos_x, pos_y, pos_c] += 1
                    # Record origins so the exact reverse rates can be
                    # computed below (no reconstruction step is needed).
                    particle_origins.append(((x, y, c), (pos_x, pos_y, pos_c)))

# Exact per-pixel reverse rates: sum the single-particle reverse rates of the
# particles currently at each pixel (these depend on the recorded origins).
ground_truth_reverse_rates = compute_ground_truth_rates(
    I_tk, t_k, r, particle_origins, p_t_k)

predicted_reverse_rates = NN(I_tk, t_k)
# L1 rate-matching loss between predicted and exact per-pixel reverse rates
loss = mean(abs(predicted_reverse_rates - ground_truth_reverse_rates))
optimizer.zero_grad()
loss.backward()
optimizer.step()
```
Sampling can likewise be sketched in pseudocode (helpers as above; `epsilon` is the small CFL safety factor):

```python
import numpy as np

# Start from maximal noise with exactly the target per-channel intensity:
# forward-diffuse an initial image all the way to t = 1.
I_0_initial = initialize_image_with_intensity(shape, target_total_intensities)
I_t = initialize_image(shape, dtype=int)
for x in range(W):
    for y in range(H):
        for c in range(C):
            num_particles = I_0_initial[x, y, c]
            if num_particles > 0:
                sampled_positions = sample_from_transition_prob(
                    x, y, c, 1.0, num_particles, p_1)
                for pos_x, pos_y, pos_c in sampled_positions:
                    I_t[pos_x, pos_y, pos_c] += 1

# Integrate the learned reverse process from t = 1 back to t = 0
# with adaptive tau-leaping.
current_time = 1.0
while current_time > 0:
    # NN-predicted per-pixel reverse rates; shape: 4 (directions) x W x H x C.
    # The SoftPlus output layer already guarantees non-negativity.
    predicted_rates = NN(I_t, current_time)

    # Adaptive time step from the CFL-style bound: max_rate * tau <= epsilon.
    max_rate = predicted_rates.max()
    if max_rate < 1e-9:
        tau = current_time  # process has effectively stopped; step to t = 0
    else:
        tau = min(current_time, epsilon / max_rate)  # never step past t = 0

    # Tau-leap: for each pixel, draw how many particles leave and where they
    # go. Arrivals are staged in I_arrivals so that a particle moves at most
    # once per step.
    I_arrivals = initialize_image(I_t.shape, dtype=int)
    directions = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # order matches predicted_rates
    for x in range(W):
        for y in range(H):
            for c in range(C):
                n = I_t[x, y, c]
                if n == 0:
                    continue
                total_rate_out = predicted_rates[:, x, y, c].sum()
                if total_rate_out <= 0:
                    continue
                # predicted_rates is the pixel total over its n particles, so
                # each particle exits at rate total_rate_out / n. Binomial
                # tau-leaping (rather than Poisson) keeps counts non-negative.
                p_exit = 1.0 - np.exp(-(total_rate_out / n) * tau)
                num_jumps = np.random.binomial(n, p_exit)
                if num_jumps == 0:
                    continue
                # Split the exits over the four directions in proportion to
                # the predicted directional rates.
                direction_probs = predicted_rates[:, x, y, c] / total_rate_out
                jumps_per_dir = np.random.multinomial(num_jumps, direction_probs)
                I_t[x, y, c] -= num_jumps
                for (dx, dy), moved in zip(directions, jumps_per_dir):
                    # Boundary conditions: periodic wraps around; no-flux
                    # reflects the jump back into the domain.
                    dest_x, dest_y = apply_boundary_conditions(
                        x + dx, y + dy, shape, boundary_condition_type)
                    I_arrivals[dest_x, dest_y, c] += moved

    I_t += I_arrivals
    current_time -= tau
```