Exploring a Graph-based Approach to Offline Reinforcement Learning for Sepsis Treatment

Published 3 Sep 2025 in cs.LG | (2509.03393v1)

Abstract: Sepsis is a serious, life-threatening condition. When treating sepsis, it is challenging to determine the correct amount of intravenous fluids and vasopressors for a given patient. While automated reinforcement learning (RL)-based methods have been used to support these decisions with promising results, previous studies have relied on relational data. Given the complexity of modern healthcare data, representing data as a graph may provide a more natural and effective approach. This study models patient data from the well-known MIMIC-III dataset as a heterogeneous graph that evolves over time. Subsequently, we explore two Graph Neural Network architectures - GraphSAGE and GATv2 - for learning patient state representations, adopting the approach of decoupling representation learning from policy learning. The encoders are trained to produce latent state representations, jointly with decoders that predict the next patient state. These representations are then used for policy learning with the dBCQ algorithm. The results of our experimental evaluation confirm the potential of a graph-based approach, while highlighting the complexity of representation learning in this domain.

Summary

  • The paper introduces a graph-based offline RL framework that leverages dynamic heterogeneous patient graphs to improve sepsis treatment decisions.
  • It compares GNN-based encoders (GraphSAGE and GATv2) with a baseline autoencoder. GraphSAGE yields the best policy performance (WIS score of 0.75 vs. 0.68 for the autoencoder), while GATv2 underperforms both.
  • The study highlights trade-offs between increased model complexity and computational efficiency, emphasizing the need for optimized graph modeling in clinical RL applications.

Graph-Based Offline Reinforcement Learning for Sepsis Treatment: Technical Analysis

Introduction

This paper presents a graph-based methodology for offline reinforcement learning (RL) in the context of sepsis treatment, leveraging dynamic heterogeneous graphs and graph neural networks (GNNs) to encode patient trajectories from the MIMIC-III dataset. The study decouples representation learning from policy learning, enabling a controlled comparison between GNN-based encoders (GraphSAGE and GATv2) and a baseline autoencoder (AE) approach. The RL agent is trained using the discrete Batch-Constrained Q-Learning (dBCQ) algorithm, with policy evaluation performed via Weighted Importance Sampling (WIS).
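
dBCQ differs from standard Q-learning in how it picks actions: the Q-function is only maximized over actions that a behavior-cloning model of the logged (clinician) policy considers sufficiently likely. The sketch below shows that selection rule for a single state, assuming the Q-values and behavior-cloning probabilities are already computed. The function name and the default threshold of 0.3 are illustrative assumptions, not values reported by the paper.

```python
from typing import Sequence


def dbcq_select_action(
    q_values: Sequence[float],
    bc_probs: Sequence[float],
    threshold: float = 0.3,  # hypothetical default
) -> int:
    """Return the highest-Q action among those the behavior model finds plausible.

    An action a is eligible only if bc_probs[a] / max(bc_probs) > threshold
    (the discrete BCQ filter); Q is then maximized over eligible actions only.
    """
    if len(q_values) != len(bc_probs):
        raise ValueError("q_values and bc_probs must have the same length")
    if not 0.0 <= threshold < 1.0:
        raise ValueError("threshold must lie in [0, 1)")
    max_prob = max(bc_probs)
    if max_prob <= 0.0:
        raise ValueError("bc_probs must contain a positive probability")
    eligible = [a for a, p in enumerate(bc_probs) if p / max_prob > threshold]
    # The most likely action has ratio 1.0, so `eligible` is never empty.
    return max(eligible, key=lambda a: q_values[a])


# Toy state with 3 actions: action 1 has the highest Q-value but is rarely
# chosen by clinicians, so the filter removes it from consideration.
q = [1.0, 5.0, 3.0]
bc = [0.6, 0.1, 0.3]
print(dbcq_select_action(q, bc))                 # -> 2
print(dbcq_select_action(q, bc, threshold=0.0))  # -> 1
```

A threshold of 0 recovers ordinary greedy Q-learning, while values close to 1 make the agent imitate the most common clinician action.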

Graph Modeling of Patient Trajectories

The core innovation is the transformation of tabular patient data into dynamic heterogeneous graphs. Each patient trajectory is represented as a graph with a central Patient node (encoding time-invariant features) and a sequence of Timestep nodes (encoding time-variant features). Edges between Timestep nodes capture treatment actions, and a Terminal node encodes the final reward (survival outcome) (Figure 1).

Figure 1: Type graph for patient trajectory graph, illustrating node and edge types for encoding patient features and treatment history.

Figure 2: Patient trajectory graph with 7 time steps, showing the evolution of patient state and administered actions.

This graph structure enables the explicit modeling of temporal dependencies and treatment history, which are critical for safe and effective RL-based clinical decision support.
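
To make the structure concrete, here is a minimal, library-free sketch of one trajectory graph and of the per-time-step snapshot an encoder would consume. All names are hypothetical. The choices that each action edge links consecutive Timestep nodes, that survival maps to a reward of +1 (otherwise -1), that the Patient node is implicitly linked to every Timestep node, and that snapshots exclude the Terminal node are illustrative assumptions, not details confirmed by the paper.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Tuple


@dataclass
class TrajectoryGraph:
    """Hypothetical container mirroring the node and edge types described above."""
    patient: List[float]            # Patient node: time-invariant features
    timesteps: List[List[float]]    # Timestep nodes: time-variant features
    # Action edges between consecutive Timestep nodes: (src, dst, action index)
    action_edges: List[Tuple[int, int, int]] = field(default_factory=list)
    terminal_reward: Optional[float] = None  # Terminal node: survival outcome


def build_trajectory_graph(
    static: Sequence[float],
    per_step: Sequence[Sequence[float]],
    actions: Sequence[int],
    survived: bool,
) -> TrajectoryGraph:
    # Assumption: exactly one action per transition t -> t+1.
    if len(actions) != len(per_step) - 1:
        raise ValueError("expected one action per transition between time steps")
    return TrajectoryGraph(
        patient=list(static),
        timesteps=[list(x) for x in per_step],
        action_edges=[(t, t + 1, a) for t, a in enumerate(actions)],
        # Assumption: +1 for survival, -1 otherwise.
        terminal_reward=1.0 if survived else -1.0,
    )


def snapshot(graph: TrajectoryGraph, t: int) -> TrajectoryGraph:
    """Sub-graph visible at time step t: the Patient node, Timesteps 0..t and
    the action edges among them. The Terminal node is left out (an assumption)."""
    return TrajectoryGraph(
        patient=graph.patient,
        timesteps=graph.timesteps[: t + 1],
        action_edges=[e for e in graph.action_edges if e[1] <= t],
    )


g = build_trajectory_graph(
    static=[65.0, 1.0],                  # illustrative time-invariant features
    per_step=[[90.0], [85.0], [95.0]],   # illustrative: one feature per time step
    actions=[4, 7],
    survived=True,
)
s1 = snapshot(g, 1)
print(len(s1.timesteps), s1.action_edges)  # -> 2 [(0, 1, 4)]
```

A practical implementation would more likely store this in a heterogeneous graph object from a GNN library; the plain dataclass only makes the node and edge types explicit.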

GNN-Based Representation Learning

Two GNN architectures are evaluated for encoding graph snapshots into latent state representations:

  • GraphSAGE (SAGEConv): Aggregates node features via mean pooling per edge type, followed by summation across node types and a final linear projection.
  • GATv2Conv: Employs attention mechanisms to weight neighbor contributions, leveraging edge weights and learnable attention scores.
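
The two update rules above can be written out for a single target node in plain Python. This is a didactic re-implementation under simplifying assumptions (one layer, no bias, no output activation, a single attention head, hypothetical relation names and weights); it is not the SAGEConv or GATv2Conv code used in the paper, and the library layers add details such as self-loops and multi-head attention.

```python
import math
from typing import Dict, List, Optional, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Sequence[float]) -> Vector:
    """Dense matrix-vector product."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def add(*vectors: Sequence[float]) -> Vector:
    """Element-wise sum of equally sized vectors."""
    return [sum(parts) for parts in zip(*vectors)]


def mean(vectors: Sequence[Sequence[float]]) -> Vector:
    """Element-wise mean of a non-empty list of vectors."""
    return [sum(parts) / len(vectors) for parts in zip(*vectors)]


def sage_update(
    x_i: Sequence[float],
    neighbors_by_relation: Dict[str, Sequence[Sequence[float]]],
    w_root: Dict[str, Matrix],
    w_neigh: Dict[str, Matrix],
) -> Vector:
    """GraphSAGE-style update of one target node over several edge types:
    W_root[r] x_i + W_neigh[r] mean_j(x_j) per relation r, then summed."""
    per_relation = [
        add(matvec(w_root[r], x_i), matvec(w_neigh[r], mean(nbrs)))
        for r, nbrs in neighbors_by_relation.items()
        if nbrs  # skip relations without neighbors
    ]
    return add(*per_relation)


def gatv2_update(
    x_i: Sequence[float],
    neighbors: Sequence[Sequence[float]],
    w_s: Matrix,
    w_t: Matrix,
    a: Sequence[float],
    edge_attrs: Optional[Sequence[Sequence[float]]] = None,
    w_e: Optional[Matrix] = None,
    slope: float = 0.2,
) -> Tuple[Vector, Vector]:
    """GATv2-style attention for one target node:
    e_ij = a . LeakyReLU(W_s x_i + W_t x_j [+ W_e edge_ij]), alpha = softmax_j(e_ij),
    output = sum_j alpha_ij W_t x_j."""
    messages = [matvec(w_t, x_j) for x_j in neighbors]
    scores = []
    for j, m_j in enumerate(messages):
        h = add(matvec(w_s, x_i), m_j)
        if edge_attrs is not None and w_e is not None:
            h = add(h, matvec(w_e, edge_attrs[j]))  # optional edge-feature term
        h = [v if v > 0 else slope * v for v in h]  # LeakyReLU
        scores.append(sum(ai * hi for ai, hi in zip(a, h)))
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # numerically stable softmax
    total = sum(exps)
    alpha = [e / total for e in exps]
    out = [sum(alpha[j] * messages[j][k] for j in range(len(messages)))
           for k in range(len(messages[0]))]
    return out, alpha


I2 = [[1.0, 0.0], [0.0, 1.0]]  # identity weights keep the arithmetic readable
h = sage_update(
    [1.0, 0.0],
    {"timestep": [[0.0, 1.0], [2.0, 1.0]], "patient": [[1.0, 1.0]]},
    {"timestep": I2, "patient": I2},
    {"timestep": I2, "patient": I2},
)
print(h)  # -> [4.0, 2.0]
out, alpha = gatv2_update([1.0, 0.0], [[0.0, 1.0], [2.0, 1.0]], I2, I2, [1.0, 1.0])
print([round(x, 3) for x in alpha])  # -> [0.119, 0.881]
```

The contrast is visible in the weights: GraphSAGE gives every neighbor of a relation the same weight, whereas GATv2 learns input-dependent weights, which is the extra flexibility the later sections suspect may mislead the policy learner.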

Both encoders are trained in an encoder-decoder (autoencoder-style) setup in which the decoder predicts the next patient state; the decoder architecture is held constant across experiments (Figure 3).

Figure 3: Component diagram of the proposed approach, highlighting the graph modeling and GNN encoder as the main contributions.
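
The decoupled training described above can be summarized in two phases: the encoder is trained jointly with a decoder to predict the next patient state, then frozen so that its outputs serve as observations for dBCQ. The sketch below assumes a mean-squared-error objective and uses trivial stand-in encoder and decoder functions; the loss choice and all names are assumptions for illustration.

```python
from typing import Callable, List, Sequence

Snapshot = List[List[float]]  # hypothetical: Timestep feature rows 0..t


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error between two equally sized vectors."""
    return sum((p - y) ** 2 for p, y in zip(pred, target)) / len(target)


def next_state_loss(
    encode: Callable[[Snapshot], List[float]],
    decode: Callable[[List[float]], List[float]],
    snapshots: Sequence[Snapshot],
    next_states: Sequence[Sequence[float]],
) -> float:
    """Phase 1 objective: predict x_{t+1} from the latent code of the snapshot at t."""
    losses = [mse(decode(encode(s)), y) for s, y in zip(snapshots, next_states)]
    return sum(losses) / len(losses)


def latent_observations(
    encode: Callable[[Snapshot], List[float]], snapshots: Sequence[Snapshot]
) -> List[List[float]]:
    """Phase 2 input: the frozen encoder maps each snapshot to an observation for dBCQ."""
    return [encode(s) for s in snapshots]


def toy_encode(snap: Snapshot) -> List[float]:
    return list(snap[-1])  # stand-in encoder: keep the latest Timestep features


def toy_decode(z: List[float]) -> List[float]:
    return list(z)  # stand-in decoder: identity, i.e. a persistence forecast


snaps = [[[0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]]]
targets = [[1.0, 1.0], [1.0, 3.0]]
print(next_state_loss(toy_encode, toy_decode, snaps, targets))  # -> 1.5
```

Because the policy learner only ever sees the frozen encoder's outputs, swapping GraphSAGE, GATv2 or the AE changes the observations but not the dBCQ setup, which is what makes the comparison controlled.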

Training and Hyperparameter Selection

Extensive hyperparameter tuning is performed for both GNN encoders. For SAGEConv, the best configuration uses two convolutional layers with 64 output features. For GATv2Conv, a single convolutional layer with 64 output features yields the best results; deeper GATv2Conv stacks likely suffer from oversmoothing given the small graph size.
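
The oversmoothing explanation can be illustrated with a toy calculation: repeated neighborhood averaging pulls node features toward a common value, and on a graph with only a handful of nodes this happens after few layers. This sketch uses parameter-free mean aggregation on a hypothetical four-node chain; real GATv2Conv layers learn weights and attention scores, so it only illustrates the tendency and does not reproduce the paper's experiment.

```python
from typing import Dict, List


def mean_aggregate(features: List[float], adjacency: Dict[int, List[int]]) -> List[float]:
    """One parameter-free propagation step: each node averages itself and its neighbors."""
    return [
        (features[i] + sum(features[j] for j in adjacency[i])) / (1 + len(adjacency[i]))
        for i in range(len(features))
    ]


def spread(features: List[float]) -> float:
    """Gap between the largest and smallest node feature."""
    return max(features) - min(features)


# Hypothetical chain of four Timestep nodes with one scalar feature each.
adjacency = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
x = [0.0, 0.0, 0.0, 1.0]
spreads = [spread(x)]
for layer in range(6):
    x = mean_aggregate(x, adjacency)
    spreads.append(spread(x))
print([round(s, 3) for s in spreads])  # shrinks layer after layer
```

Each step replaces a node's value with a convex combination of values, so the spread can never grow; with few nodes, a few layers are enough to make the time steps hard to tell apart.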

Training is performed on high-performance hardware (AMD EPYC 9554, NVIDIA L40 GPUs), with GNN autoencoder training requiring significantly more computational resources than the AE baseline.

Policy Learning and Evaluation

Latent representations from the trained encoders are used as observations for dBCQ policy learning. The action space is discretized into 25 actions, one for each combination of vasopressor and IV fluid dose bins (a 5 × 5 grid), and the reward is given only at the terminal step, reflecting the survival outcome. Policy evaluation is performed using WIS, with behavioral cloning used to estimate the clinician policy for importance weighting (Figures 4 and 5).
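
The 25-action space can be pictured as a 5 × 5 grid indexed by (vasopressor bin, IV fluid bin). The sketch below assumes a common convention in which bin 0 means no drug and the remaining four bins split non-zero doses at three cut-offs; the cut-off values, the row-major ordering and the function names are hypothetical, since the summary does not state how the paper chose its bins.

```python
from bisect import bisect_right
from typing import Sequence, Tuple

N_BINS = 5  # 5 vasopressor bins x 5 IV fluid bins = 25 discrete actions


def dose_to_bin(dose: float, cutoffs: Sequence[float]) -> int:
    """Bin 0 = no drug given; bins 1-4 split non-zero doses at 3 ascending cut-offs."""
    if len(cutoffs) != N_BINS - 2:
        raise ValueError("expected 3 cut-offs for the 4 non-zero bins")
    if dose <= 0.0:
        return 0
    return 1 + bisect_right(cutoffs, dose)


def to_action(vaso_bin: int, iv_bin: int) -> int:
    """Row-major index into the 5 x 5 grid (an assumed ordering)."""
    if not (0 <= vaso_bin < N_BINS and 0 <= iv_bin < N_BINS):
        raise ValueError("bin index out of range")
    return vaso_bin * N_BINS + iv_bin


def from_action(action: int) -> Tuple[int, int]:
    """Inverse of to_action: (vasopressor bin, IV fluid bin)."""
    if not 0 <= action < N_BINS * N_BINS:
        raise ValueError("action index out of range")
    return divmod(action, N_BINS)


# Hypothetical cut-offs in arbitrary units; real ones would be derived from the data.
vaso_cutoffs = [0.1, 0.2, 0.4]
iv_cutoffs = [100.0, 250.0, 500.0]
a = to_action(dose_to_bin(0.15, vaso_cutoffs), dose_to_bin(0.0, iv_cutoffs))
print(a, from_action(a))  # -> 10 (2, 0)
```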

Figure 4: WIS over 1e6 training iterations for GNN-SAGEConv and GNN-GATv2Conv, comparing policy learning efficiency and final scores.

Figure 5

Figure 5: WIS over 500K training iterations for AE, GNN-SAGEConv and GNN-GATv2Conv, illustrating early learning dynamics.

Key findings:

  • AE representations yield rapid initial policy improvement but plateau at a WIS score of 0.68.
  • GNN-SAGEConv representations require more iterations to converge but ultimately reach a higher WIS score (0.75).
  • GNN-GATv2Conv representations underperform both AE and SAGEConv, with slow learning and lower final scores.
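
For reference, the WIS scores above come from an estimator of the following kind: each logged trajectory is reweighted by the product of per-step probability ratios between the evaluated policy and the behavioral-cloning estimate of the clinician policy, and the result is normalized by the sum of the weights. This is a minimal trajectory-wise sketch; whether the paper uses this form or a per-decision variant, and how its rewards are scaled, is not stated in this summary, so the 1/0 returns in the example are assumptions.

```python
from typing import Sequence, Tuple

# One logged trajectory: per-step (pi_e(a_t|s_t), pi_b(a_t|s_t)) pairs for the
# actions clinicians actually took, plus the trajectory's return.
Trajectory = Tuple[Sequence[Tuple[float, float]], float]


def wis_estimate(trajectories: Sequence[Trajectory]) -> float:
    """Trajectory-wise weighted importance sampling estimate of a policy's value."""
    weights, returns = [], []
    for step_probs, episode_return in trajectories:
        rho = 1.0
        for p_eval, p_behavior in step_probs:
            rho *= p_eval / p_behavior  # p_behavior comes from behavioral cloning
        weights.append(rho)
        returns.append(episode_return)
    total = sum(weights)
    if total == 0.0:
        raise ValueError("all importance weights are zero")
    # Dividing by the sum of weights (not the trajectory count) makes it "weighted".
    return sum(w * g for w, g in zip(weights, returns)) / total


logged = [
    ([(0.5, 0.5), (1.0, 0.5)], 1.0),  # weight 2.0, patient survived (assumed return 1)
    ([(0.25, 0.5)], 0.0),             # weight 0.5, patient died (assumed return 0)
]
print(wis_estimate(logged))  # -> 0.8
```

The normalization trades a small bias for much lower variance than ordinary importance sampling, which matters when long trajectories produce extreme weight products.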

Representation Learning Performance

Autoencoder training loss and validation loss are tracked for all encoder types. GNN-GATv2Conv achieves the lowest validation loss in the representation learning phase, but this does not translate to superior policy learning performance (Figures 6 and 7).

Figure 6: Training loss for GNN-SAGEConv and GNN-GATv2Conv autoencoders, showing convergence behavior.

Figure 7: Validation loss for GNN-SAGEConv and GNN-GATv2Conv, highlighting generalization performance.

Architectural Trade-Offs and Limitations

The decoupling of representation and policy learning isolates the effect of encoder architecture. The results indicate that access to historical data via graph structure is necessary but not sufficient for improved policy learning; architectural choices (e.g., attention mechanisms) can introduce complexity that may hinder learning efficiency.

GNN-GATv2Conv's additional attention weights appear to mislead the policy learner, consistent with prior findings that increased model complexity does not guarantee better RL performance. Training efficiency is a significant concern: GNN autoencoder training is an order of magnitude slower than AE, and GPU acceleration does not yield expected speedups due to small graph sizes.

Threats to Validity

The study discusses threats to construct, internal, and external validity, as well as reliability. Notably, the use of a single dataset (MIMIC-III) limits generalizability, and the reliance on WIS as a quantitative metric may not fully capture clinical relevance. The difference in input information between AE and GNN encoders complicates direct architectural comparisons.

Implications and Future Directions

The results demonstrate that graph-based representations, specifically those learned via GraphSAGE, can match or exceed the policy performance (as measured by WIS) obtained with traditional relational encodings for offline RL in sepsis treatment, albeit with greater computational cost and slower convergence. The findings suggest that further improvements may be achieved by hybridizing GNNs with recurrent architectures to better capture temporal dependencies.

Figure 8: Snapshot at the first time step, illustrating the initial graph structure for a patient trajectory.

The approach is extensible to other clinical decision-making tasks where temporal and relational dependencies are critical. Future work should explore alternative graph modeling strategies (e.g., distributing features across multiple nodes), qualitative policy evaluation, and dynamic GNN architectures incorporating RNNs.

Conclusion

This paper provides a rigorous evaluation of graph-based representation learning for offline RL in sepsis treatment. The graph modeling approach enables richer encoding of patient trajectories, and GraphSAGE-based encoders yield competitive policy learning performance. However, increased model complexity does not guarantee improved results, and computational efficiency remains a challenge. The study lays the groundwork for further exploration of GNNs in clinical RL, with implications for personalized medicine and safe, interpretable AI-driven treatment recommendations.
