- The paper introduces a VAE-based NRI model that infers latent interaction graphs and predicts future states from observational data.
- It employs an encoder-decoder framework with MLP, CNN, and GNN components to capture complex inter-agent dependencies.
- Experimental results show near-perfect unsupervised recovery of interaction graphs and improved prediction accuracy on physics simulations, motion capture, and sports tracking data.
Neural Relational Inference for Interacting Systems
Interacting systems are pervasive across many domains, from physical and biological systems to complex societal dynamics. Neural Relational Inference (NRI) offers a framework for understanding such systems by inferring their implicit interactions from observational data alone. This paper introduces the NRI model, an unsupervised approach that combines a variational auto-encoder (VAE) with graph neural networks (GNNs) to jointly infer latent interactions and learn the system dynamics.
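Concretely, the latent variables are discrete edge types assigned to every pair of nodes, and the model is trained by maximizing the usual variational lower bound:

$$
\mathcal{L} = \mathbb{E}_{q_\phi(\mathbf{z}\mid\mathbf{x})}\!\left[\log p_\theta(\mathbf{x}\mid\mathbf{z})\right] - \mathrm{KL}\!\left[q_\phi(\mathbf{z}\mid\mathbf{x}) \,\|\, p_\theta(\mathbf{z})\right]
$$

Here $\mathbf{x}$ denotes the observed trajectories, $\mathbf{z} = \{z_{ij}\}$ the discrete edge types for every node pair, $q_\phi(\mathbf{z}\mid\mathbf{x})$ the GNN encoder, $p_\theta(\mathbf{x}\mid\mathbf{z})$ the decoder factorized over time steps, and $p_\theta(\mathbf{z})$ a factorized categorical prior over edge types (uniform in most of the paper's experiments).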
Model Architecture
The NRI model consists of two main components: an encoder and a decoder. The encoder infers the latent interaction graph based on observed trajectories, while the decoder predicts future states by leveraging the inferred graph. The encoder uses a multi-layer perceptron (MLP) or convolutional neural network (CNN) to transform node features into edge features, effectively capturing potential interactions between nodes by passing messages along the graph.
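The node-to-edge and edge-to-node operations that appear as self.node2edge and self.edge2node in the snippets below are not shown in the excerpts. Below is a minimal sketch of one common way to implement them with one-hot receiver/sender matrices (here called rel_rec and rel_send), assuming node features of shape [batch, num_nodes, feat] and a fully connected directed graph without self-loops; in the model code these would typically be methods on the encoder/decoder modules:

```python
import torch

def node2edge(x, rel_rec, rel_send):
    # x: [batch, num_nodes, feat]; rel_rec / rel_send: [num_edges, num_nodes] one-hot
    receivers = torch.matmul(rel_rec, x)   # receiver-node features per edge
    senders = torch.matmul(rel_send, x)    # sender-node features per edge
    return torch.cat([senders, receivers], dim=2)  # [batch, num_edges, 2 * feat]

def edge2node(x, rel_rec, rel_send):
    # x: [batch, num_edges, feat]; gather incoming edge features at each receiver node
    incoming = torch.matmul(rel_rec.t(), x)               # [batch, num_nodes, feat]
    return incoming / rel_rec.sum(dim=0).unsqueeze(-1)    # mean over incoming edges (a plain sum also works)
```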
Encoder
The encoder views the entire system as a fully connected graph and alternates node-to-edge and edge-to-node message passing, using MLP transformations (or, in the CNN variant, 1D convolutions with attentive pooling over time) to compute per-edge representations, as in this simplified forward pass of the MLP encoder:
```python
def forward(self, inputs, rel_rec, rel_send):
    # inputs: flattened trajectory features per node
    x = self.mlp1(inputs)                       # 2-layer ELU net per node
    x = self.node2edge(x, rel_rec, rel_send)    # concatenate sender/receiver embeddings
    x = self.mlp2(x)
    x_skip = x
    x = self.edge2node(x, rel_rec, rel_send)    # aggregate edge embeddings at nodes
    x = self.mlp3(x)
    x = self.node2edge(x, rel_rec, rel_send)    # back to edges
    x = torch.cat((x, x_skip), dim=2)           # skip connection
    x = self.mlp4(x)
    return self.fully_connected_out(x)          # per-edge logits over edge types
```
Because edge representations are recomputed after an edge-to-node-to-edge round of message passing, the predicted edge type for a pair of nodes can depend on the whole graph rather than on that pair's trajectories alone, which is essential for capturing non-trivial dependencies.
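The encoder output is a vector of logits over the edge types for every directed node pair. Since the latent edge types are discrete, the paper samples them with the Gumbel-softmax (concrete) relaxation so that sampling remains differentiable. A minimal sketch of that step follows; the function and argument names (sample_edge_types, tau, hard) are illustrative, and the logits are assumed to have shape [batch, num_edges, num_edge_types]:

```python
import torch
import torch.nn.functional as F

def sample_edge_types(logits, tau=0.5, hard=False):
    # Add Gumbel noise and take a temperature-controlled softmax.
    gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-10) + 1e-10)
    y = F.softmax((logits + gumbel) / tau, dim=-1)
    if hard:
        # Straight-through estimator: one-hot forward pass, soft gradients backward.
        index = y.argmax(dim=-1, keepdim=True)
        y_hard = torch.zeros_like(y).scatter_(-1, index, 1.0)
        y = (y_hard - y).detach() + y
    return y  # [batch, num_edges, num_edge_types], rows sum to 1
```

Recent PyTorch versions also provide torch.nn.functional.gumbel_softmax, which implements the same relaxation.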
Decoder
The decoder is either Markovian or recurrent. The Markovian (MLP) decoder predicts the next state from the current one alone, while the recurrent variant uses a gated recurrent unit (GRU) to carry hidden state across time for multi-step prediction; both condition on the inferred edge types. The simplified single-step message-passing core looks like this:
```python
def single_step_forward(self, inputs, rel_rec, rel_send, edge_type, start_idx=1):
    # inputs: node states; edge_type: (soft) one-hot edge-type samples, last dim indexes types
    pre_msg = self.node2edge(inputs, rel_rec, rel_send)
    all_msgs = 0.0  # accumulated messages over all edge types
    # One message-passing MLP per edge type; start_idx=1 skips the "no interaction" type.
    for i in range(start_idx, len(self.msg_fc2)):
        msg = F.relu(self.msg_fc1[i](pre_msg))
        msg = F.relu(self.msg_fc2[i](msg))
        msg = msg * edge_type[:, :, :, i:i + 1]  # gate each message by its edge type
        all_msgs += msg
    agg_msgs = self.edge2node(all_msgs, rel_rec, rel_send)  # aggregate at receiver nodes
    hidden = torch.cat([inputs, agg_msgs], dim=-1)
    pred = F.relu(self.out_fc1(hidden))
    pred = F.relu(self.out_fc2(pred))
    pred = self.out_fc3(pred)
    return inputs + pred  # the decoder predicts the change in state (residual)
```
Because each edge type has its own message-passing MLP, the decoder can model qualitatively different interaction effects (including an explicit "no interaction" type) for different node pairs.
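Putting the pieces together, training follows the variational objective from the introduction: a Gaussian reconstruction term on the predicted trajectories plus a KL term that pulls the edge-type posterior toward the prior. Below is a minimal sketch of that loss; the names (nri_elbo_loss, edge_probs, log_prior, output_var) are illustrative, preds and targets are assumed to have matching shape, edge_probs is the softmax of the encoder logits, and log_prior holds the log prior probability of each edge type:

```python
import torch

def nri_elbo_loss(preds, targets, edge_probs, log_prior, output_var=5e-5):
    # Reconstruction: Gaussian negative log-likelihood with fixed output variance,
    # averaged over the batch (constant terms dropped).
    nll = ((preds - targets) ** 2 / (2 * output_var)).sum() / preds.size(0)
    # KL between the encoder's edge-type distribution and the prior, averaged over the batch.
    kl = (edge_probs * (torch.log(edge_probs + 1e-16) - log_prior)).sum() / preds.size(0)
    # Minimizing nll + kl corresponds to maximizing the ELBO.
    return nll + kl
```

With a uniform prior over K edge types, log_prior is simply log(1/K) broadcast across the edge-type dimension; the paper also reports using a sparsity-favoring prior in some experiments.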
Experimental Results
The NRI model demonstrates efficacy across several simulated and real-world datasets:
- Physics Simulations: On particles connected by springs, charged particles, and phase-coupled (Kuramoto) oscillators, the NRI model achieves near-perfect unsupervised recovery of the true interaction graphs and robust future-state prediction.
- Motion Capture Data: When applied to human motion data, the model significantly outperforms predictive baselines, capturing dynamic dependencies among joints effectively. Moreover, dynamically re-estimating the latent graph improves prediction accuracy.
- Sports Tracking Data: On NBA player tracking data, NRI learns interpretable edge types relevant to basketball plays, for example separating interactions involving the ball handler from those among the remaining players.
Implications and Future Directions
The practical implications of NRI are significant, particularly in domains that require modeling dynamic systems without explicit interaction labels. The choice between MLP- and CNN-based encoders, coupled with Markovian or recurrent decoders, yields a framework adaptable to varying complexity in both latent interactions and dynamics.
Theoretically, NRI contributes to understanding how neural architectures can infer complex, multi-agent interactions. Future research can explore extensions to dynamic interaction graphs, improving long-term predictive accuracy and modeling even more intricate systems, such as those found in real-time traffic management or large-scale biological processes.
Conclusion
The NRI model marks a significant advancement in unsupervised learning of interacting systems. Its dual capability of inferring latent graphs and predicting dynamics positions it as a valuable tool for scientists and engineers working with complex multi-agent systems. With further refinement, particularly in handling dynamically evolving interactions, NRI holds promise for an even broader range of applications.