- The paper introduces a graph-based offline RL framework that leverages dynamic heterogeneous patient graphs to improve sepsis treatment decisions.
- It compares GNN-based encoders, notably GraphSAGE and GATv2, with a baseline autoencoder; GraphSAGE yields the best policy, reaching a WIS score of 0.75 versus 0.68 for the autoencoder.
- The study highlights trade-offs between increased model complexity and computational efficiency, emphasizing the need for optimized graph modeling in clinical RL applications.
Graph-Based Offline Reinforcement Learning for Sepsis Treatment: Technical Analysis
Introduction
This paper presents a graph-based methodology for offline reinforcement learning (RL) in the context of sepsis treatment, leveraging dynamic heterogeneous graphs and graph neural networks (GNNs) to encode patient trajectories from the MIMIC-III dataset. The study decouples representation learning from policy learning, enabling a controlled comparison between GNN-based encoders (GraphSAGE and GATv2) and a baseline autoencoder (AE) approach. The RL agent is trained using the discrete Batch-Constrained Q-Learning (dBCQ) algorithm, with policy evaluation performed via Weighted Importance Sampling (WIS).
Graph Modeling of Patient Trajectories
The core innovation is the transformation of tabular patient data into dynamic heterogeneous graphs. Each patient trajectory is represented as a graph with a central Patient node (encoding time-invariant features) and a sequence of Timestep nodes (encoding time-variant features). Edges between Timestep nodes capture treatment actions, and a Terminal node encodes the final reward (survival outcome).
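The structure described above can be sketched in plain Python. This is a minimal illustration, not the paper's implementation: the class `TrajectoryGraph`, the helper names, and the edge-type labels (`has_timestep`, `action`, `ends_in`) are hypothetical, and the sketch assumes the Patient node is linked to every Timestep node.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class TrajectoryGraph:
    """Illustrative container for one patient's heterogeneous graph."""
    patient: list                                   # time-invariant features
    timesteps: list = field(default_factory=list)   # one feature vector per Timestep node
    edges: dict = field(default_factory=dict)       # edge type -> [(src, dst, attr), ...]
    terminal_reward: Optional[float] = None         # set only for complete trajectories

    def add_edge(self, edge_type, src, dst, attr=None):
        self.edges.setdefault(edge_type, []).append((src, dst, attr))


def build_trajectory_graph(static, states, actions, reward=None):
    """Build a graph from T states and the T-1 actions taken between them."""
    if len(actions) != len(states) - 1:
        raise ValueError("expected exactly one action between consecutive states")
    graph = TrajectoryGraph(patient=list(static))
    for t, state in enumerate(states):
        graph.timesteps.append(list(state))
        # Assumed link from the central Patient node to each Timestep node.
        graph.add_edge("has_timestep", "patient", ("timestep", t))
        if t > 0:
            # The treatment action is stored on the edge between consecutive steps.
            graph.add_edge("action", ("timestep", t - 1), ("timestep", t), actions[t - 1])
    if reward is not None:
        # The Terminal node carries the final reward (survival outcome).
        graph.terminal_reward = reward
        graph.add_edge("ends_in", ("timestep", len(states) - 1), "terminal")
    return graph


def snapshot(static, states, actions, t):
    """Graph as it would look at time step t (no Terminal node yet)."""
    return build_trajectory_graph(static, states[: t + 1], actions[:t])
```

Calling `snapshot(static, states, actions, 0)` yields a graph with a single Timestep node and no action edges, which is one way to picture how the graph grows as the trajectory unfolds.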
Figure 1: Type graph for patient trajectory graph, illustrating node and edge types for encoding patient features and treatment history.
Figure 2: Patient trajectory graph with 7 time steps, showing the evolution of patient state and administered actions.
This graph structure enables the explicit modeling of temporal dependencies and treatment history, which are critical for safe and effective RL-based clinical decision support.
GNN-Based Representation Learning
Two GNN architectures are evaluated for encoding graph snapshots into latent state representations:
- GraphSAGE (SAGEConv): Aggregates node features via mean pooling per edge type, followed by summation across node types and a final linear projection.
- GATv2Conv: Employs attention mechanisms to weight neighbor contributions, leveraging edge weights and learnable attention scores.
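To make the GraphSAGE-style aggregation concrete, the following sketch computes a mean over neighbor features for each edge type, sums the per-type means, and applies a linear projection. It is a simplified, dependency-free illustration: the function names are hypothetical, all neighbors are assumed to share one feature dimension, and the per-relation weights and root-node term of a real SAGEConv layer are omitted.

```python
def mean_vectors(vectors):
    """Element-wise mean of a non-empty list of equal-length vectors."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


def aggregate_by_edge_type(neighbors_by_edge_type, dim):
    """Mean-pool neighbors within each edge type, then sum the per-type means."""
    total = [0.0] * dim
    for features in neighbors_by_edge_type.values():
        if not features:
            continue  # an edge type with no neighbors contributes nothing
        pooled = mean_vectors(features)
        total = [a + b for a, b in zip(total, pooled)]
    return total


def linear(x, weight, bias):
    """Dense projection: one output per (weight row, bias) pair."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


# Toy example: two neighbors via one edge type, one neighbor via another.
neighbors = {
    "action": [[1.0, 2.0], [3.0, 4.0]],
    "has_timestep": [[10.0, 20.0]],
}
pooled = aggregate_by_edge_type(neighbors, dim=2)
latent = linear(pooled, weight=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], bias=[0.0, 0.0, 0.5])
```

Because each edge type is averaged separately before summation, an edge type with many neighbors does not dominate one with few, which is the practical difference from pooling all neighbors together.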
Both encoders are trained in an autoencoder framework to predict the next patient state, with the decoder architecture held constant across experiments.
Figure 3: Component diagram of the proposed approach, highlighting the graph modeling and GNN encoder as the main contributions.
Training and Hyperparameter Selection
Extensive hyperparameter tuning is performed for both GNN encoders. For SAGEConv, the optimal configuration is two convolutional layers with 64 output features. For GATv2Conv, a single convolutional layer with 64 output features yields the best results, likely due to oversmoothing in deeper architectures given the small graph size.
Training is performed on high-performance hardware (AMD EPYC 9554, NVIDIA L40 GPUs), with GNN autoencoder training requiring significantly more computational resources than the AE baseline.
Policy Learning and Evaluation
Latent representations from the trained encoders are used as observations for dBCQ policy learning. The action space is discretized into 25 bins (combinations of vasopressor and IV fluid dosages), and the reward is terminal (survival outcome). Policy evaluation is performed using WIS, with behavioral cloning used to estimate clinician policies for importance weighting.
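The action-selection rule of discrete BCQ can be sketched as follows: actions whose estimated behavior probability is small relative to the most likely action are filtered out, and the Q-maximizing action is chosen among the rest. The threshold value, the function names, and the 5 x 5 layout of the 25 action bins are assumptions made for illustration; the summary states only that the 25 bins combine vasopressor and IV fluid dosages.

```python
def bcq_select_action(q_values, behavior_probs, threshold=0.3):
    """Pick the highest-Q action among those the behavior policy finds plausible.

    An action is kept when its probability, relative to the most likely
    action, exceeds `threshold` (a hypothetical value here).
    """
    max_prob = max(behavior_probs)
    allowed = [a for a, p in enumerate(behavior_probs) if p / max_prob > threshold]
    return max(allowed, key=lambda a: q_values[a])


def decode_action(action, bins_per_drug=5):
    """Split a flat action index into (iv_fluid_bin, vasopressor_bin).

    Assumes a 5 x 5 grid and this particular index ordering; both are
    illustrative choices, not taken from the paper.
    """
    return divmod(action, bins_per_drug)


# Action 1 has the highest Q-value but is rarely taken by clinicians,
# so it is filtered out and action 2 is selected instead.
chosen = bcq_select_action(q_values=[1.0, 5.0, 3.0], behavior_probs=[0.6, 0.05, 0.35])
```

The filter is what keeps the learned policy close to clinician behavior in the data, which matters in offline settings where rarely observed actions have unreliable Q-estimates.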
Figure 4: WIS over 1e6 training iterations for GNN-SAGEConv and GNN-GATv2Conv, comparing policy learning efficiency and final scores.
Figure 5: WIS over 500K training iterations for AE, GNN-SAGEConv, and GNN-GATv2Conv, illustrating early learning dynamics.
Key findings:
- AE representations yield rapid initial policy improvement but plateau at a WIS score of 0.68.
- GNN-SAGEConv representations require more iterations to converge but ultimately reach a higher WIS score (0.75).
- GNN-GATv2Conv representations underperform both AE and SAGEConv, with slow learning and lower final scores.
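For readers unfamiliar with the metric behind these scores, a trajectory-wise Weighted Importance Sampling estimate can be sketched as below. This is the standard textbook form rather than the paper's exact evaluation code: the function name, the input layout, and the default discount factor of 1.0 are assumptions.

```python
def wis_estimate(trajectories, gamma=1.0):
    """Trajectory-wise WIS estimate of the evaluation policy's value.

    Each trajectory is a list of (pi_e_prob, pi_b_prob, reward) tuples, where
    pi_e_prob is the evaluation policy's probability of the logged action and
    pi_b_prob is the estimated clinician (behavior) probability of it.
    """
    weighted_return_sum = 0.0
    weight_sum = 0.0
    for trajectory in trajectories:
        rho = 1.0          # cumulative importance ratio for this trajectory
        ret = 0.0          # discounted return of this trajectory
        for t, (pi_e, pi_b, reward) in enumerate(trajectory):
            rho *= pi_e / pi_b
            ret += (gamma ** t) * reward
        weighted_return_sum += rho * ret
        weight_sum += rho
    if weight_sum == 0.0:
        raise ValueError("all importance weights are zero; estimate undefined")
    # Normalizing by the sum of weights (not the trajectory count) is what
    # distinguishes WIS from ordinary importance sampling.
    return weighted_return_sum / weight_sum


# Two toy trajectories with a terminal reward of 1 (survival) or 0.
example = [
    [(0.5, 0.5, 0.0), (1.0, 0.5, 1.0)],    # rho = 2.0, return = 1.0
    [(0.5, 0.5, 0.0), (0.25, 0.5, 0.0)],   # rho = 0.5, return = 0.0
]
score = wis_estimate(example)
```

With terminal rewards in {0, 1}, the estimate stays within [0, 1] because it is a weighted average of the observed returns.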
Autoencoder training loss and validation loss are tracked for all encoder types. GNN-GATv2Conv achieves the lowest validation loss in the representation learning phase, but this does not translate to superior policy learning performance.

Figure 6: Training loss for GNN-SAGEConv and GNN-GATv2Conv autoencoders, showing convergence behavior.
Figure 7: Validation loss for GNN-SAGEConv and GNN-GATv2Conv, highlighting generalization performance.
Architectural Trade-Offs and Limitations
The decoupling of representation and policy learning isolates the effect of encoder architecture. The results indicate that access to historical data via graph structure is necessary but not sufficient for improved policy learning; architectural choices (e.g., attention mechanisms) can introduce complexity that may hinder learning efficiency.
GNN-GATv2Conv's additional attention weights appear to mislead the policy learner, consistent with prior findings that increased model complexity does not guarantee better RL performance. Training efficiency is a significant concern: GNN autoencoder training is an order of magnitude slower than AE, and GPU acceleration does not yield expected speedups due to small graph sizes.
Threats to Validity
The study discusses threats to construct, internal, and external validity, as well as reliability. Notably, the use of a single dataset (MIMIC-III) limits generalizability, and the reliance on WIS as a quantitative metric may not fully capture clinical relevance. The difference in input information between AE and GNN encoders complicates direct architectural comparisons.
Implications and Future Directions
The results demonstrate that graph-based representations, specifically those learned via GraphSAGE, can match or exceed the policy performance (as measured by WIS) of the autoencoder baseline for offline RL in sepsis treatment, albeit with greater computational cost and slower convergence. The findings suggest that further improvements may be achieved by hybridizing GNNs with recurrent architectures to better capture temporal dependencies.
Figure 8: Snapshot at the first time step, illustrating the initial graph structure for a patient trajectory.
The approach is extensible to other clinical decision-making tasks where temporal and relational dependencies are critical. Future work should explore alternative graph modeling strategies (e.g., distributing features across multiple nodes), qualitative policy evaluation, and dynamic GNN architectures incorporating RNNs.
Conclusion
This paper provides a rigorous evaluation of graph-based representation learning for offline RL in sepsis treatment. The graph modeling approach enables richer encoding of patient trajectories, and GraphSAGE-based encoders yield competitive policy learning performance. However, increased model complexity does not guarantee improved results, and computational efficiency remains a challenge. The study lays the groundwork for further exploration of GNNs in clinical RL, with implications for personalized medicine and safe, interpretable AI-driven treatment recommendations.