This paper introduces MedM-PLM, a Medical Multimodal Pre-trained Language Model designed to process both structured and unstructured data from Electronic Health Records (EHRs). The core idea is to move beyond treating these data sources independently or simply concatenating them, and instead to explicitly model the interactions and complementary information between clinical codes (structured) and clinical narratives (unstructured). The authors argue that these interactions are crucial for a comprehensive understanding of patient status and for improving performance on various clinical prediction tasks.
MedM-PLM operates in two main phases: pre-training and fine-tuning. The architecture consists of two primary modules:
- Unimodal Module: This module is responsible for learning representations specific to each modality.
- Structured Data Component: Inspired by G-BERT [25], this component processes sequences of clinical codes (like diagnosis or medication codes). It uses an ontology-aware embedding layer that incorporates hierarchical information from medical code ontologies (like ICD-9 and ATC-3) using Graph Attention Networks (GATs) [61]. This allows the model to capture relationships between codes based on their position in the hierarchy. The embedded code sequences are then processed by Transformer encoder layers.
- Unstructured Data Component: Similar to BERT [60] and ClinicalBERT [22], this component handles clinical narratives. It uses standard token, segment, and position embeddings for the text sequence, which are then fed into Transformer encoder layers. The parameters for this component can be initialized with pre-trained weights from models like ClinicalBERT to leverage general clinical language understanding.
- Cross-modal Module: This module integrates the representations learned by the unimodal components to model interactions between the two modalities. It uses a cross-attention mechanism in which representations from one modality (e.g., the visit-level [CLS] token from structured data) serve as queries attending over the token-level representations of the other modality (unstructured data), and vice versa. The outputs of this cross-attention are combined with the original unimodal representations through residual connections to produce augmented visit-level representations, denoted RCode and RText. These augmented representations are designed to capture an overview of both modalities; a minimal sketch of such a block appears after this list.
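A minimal sketch of such a cross-modal block, assuming PyTorch, a shared hidden size across modalities, and standard multi-head attention with layer normalization (details the paper may implement differently):

```python
import torch
import torch.nn as nn


class CrossModalModule(nn.Module):
    """Sketch of a cross-attention block: each modality's visit-level [CLS]
    vector queries the other modality's token-level states, and the result
    is added back to the original [CLS] vector via a residual connection."""

    def __init__(self, hidden_dim: int, num_heads: int = 8):
        super().__init__()
        # Structured [CLS] attends over text tokens, and vice versa.
        self.code_to_text_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.text_to_code_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.norm_code = nn.LayerNorm(hidden_dim)  # normalization is an assumption
        self.norm_text = nn.LayerNorm(hidden_dim)

    def forward(self, Zc, Zw, Zc_cls, Zw_cls):
        # Zc: (B, Lc, H) code token states, Zw: (B, Lw, H) text token states
        # Zc_cls, Zw_cls: (B, H) visit-level representations
        q_code = Zc_cls.unsqueeze(1)  # (B, 1, H) query from the structured side
        q_text = Zw_cls.unsqueeze(1)  # (B, 1, H) query from the unstructured side
        code_ctx, _ = self.code_to_text_attn(q_code, Zw, Zw)  # code query attends over text
        text_ctx, _ = self.text_to_code_attn(q_text, Zc, Zc)  # text query attends over codes
        # Residual connection with the original visit-level representations
        RCode = self.norm_code(Zc_cls + code_ctx.squeeze(1))
        RText = self.norm_text(Zw_cls + text_ctx.squeeze(1))
        return RCode, RText
```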
The model is pre-trained on a large dataset of paired structured and unstructured EHR records (derived from MIMIC-III (Liu et al., 2022)). The pre-training employs two masked prediction tasks specifically designed to foster cross-modal understanding:
- Text-to-Code Prediction: Given the unstructured text and a masked structured code sequence, the model is trained to predict the masked codes using the augmented text representation (RText). This encourages the model to identify textual evidence that supports the presence of specific codes.
- Code-to-Code Prediction: Given the structured code sequence (with masked codes) and the unstructured text, the model is trained to predict the masked codes using the augmented structured code representation (RCode). This helps the model understand the relationships and dependencies between different codes within a visit, potentially informed by the context from the text.
The overall pre-training objective minimizes the combined cross-entropy loss for these two tasks.
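Spelled out, and assuming the two losses are combined as an unweighted sum:

$$\mathcal{L}_{\text{pre}} = \mathcal{L}_{t2c} + \mathcal{L}_{c2c}$$

where $\mathcal{L}_{t2c}$ and $\mathcal{L}_{c2c}$ are the cross-entropy losses of the Text-to-Code and Code-to-Code tasks, respectively.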
After pre-training, MedM-PLM can be fine-tuned for various downstream clinical prediction tasks. This typically involves taking the learned augmented visit-level representations (RCode and RText), or combinations thereof, and feeding them into a task-specific classification layer, usually a Multi-Layer Perceptron (MLP). The fine-tuning process adjusts the model parameters (including the pre-trained ones, potentially with different learning rates) to optimize performance on the target task.
The paper demonstrates MedM-PLM's effectiveness on three common clinical tasks using the MIMIC-III dataset:
- Medication Recommendation: Predicting the medications prescribed during a visit, framed as a multi-label classification task. For fine-tuning, the paper mentions combining mean historical representations with the current visit's representations, suggesting the model can leverage sequential patient history when it is available (see the sketch after this list).
- 30-day Readmission Prediction: Predicting whether a patient will be readmitted within 30 days of discharge, a binary classification task. Fine-tuning uses the concatenated RCode and RText of the current visit.
- ICD Coding: Assigning relevant diagnosis codes to a patient record, a multi-label classification task. Fine-tuning concatenates RText with a representation derived from the structured data (specifically, the medication codes).
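A rough sketch of how the medication-recommendation head could combine the mean of historical visit representations with the current visit's representation, assuming each visit contributes a concatenated (RCode, RText) vector of size 2*hidden_dim; the pooling, layer sizes, and output dimension are illustrative assumptions rather than the paper's exact setup:

```python
import torch
import torch.nn as nn


class MedicationRecommendationHead(nn.Module):
    """Multi-label head: combines the mean of historical visit representations
    with the current visit's representation, then scores each medication code."""

    def __init__(self, hidden_dim: int, num_medication_codes: int):
        super().__init__()
        # Input: [mean historical rep ; current visit rep], each of size 2*hidden_dim
        # (RCode and RText concatenated per visit) -> 4*hidden_dim in total.
        self.mlp = nn.Sequential(
            nn.Linear(4 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_medication_codes),
        )

    def forward(self, historical_reps: torch.Tensor, current_rep: torch.Tensor):
        # historical_reps: (B, num_past_visits, 2*hidden_dim)
        # current_rep:     (B, 2*hidden_dim)
        history_mean = historical_reps.mean(dim=1)
        features = torch.cat([history_mean, current_rep], dim=-1)
        return torch.sigmoid(self.mlp(features))  # one probability per medication code
```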
The experimental results show that MedM-PLM consistently outperforms baseline methods, including unimodal PLMs (G-BERT, ClinicalBERT, Med-BERT) and straightforward concatenation approaches (G-BERT+ClinicalBERT, Med-BERT+ClinicalBERT). This highlights the benefit of explicitly modeling cross-modal interactions. The model also shows robust performance even with limited training data for fine-tuning, suggesting its pre-trained knowledge is valuable for few-shot learning scenarios.
Practical Implementation Details and Considerations:
- Data Preparation: The success relies heavily on having paired structured and unstructured data for the same patient visits. Preprocessing involves converting clinical text to lowercase, removing noise (line breaks, carriage returns), tokenizing text, and mapping clinical codes to a standardized vocabulary. For structured data, representing codes using hierarchical ontologies is a key step. Sequence lengths need to be managed (e.g., truncating text to 512 tokens, limiting codes to 61 per visit as done in the paper).
- Ontology Embedding: Implementing the GAT-based ontology embedding for structured codes requires defining the hierarchy (adjacency matrices) and initializing node embeddings. This component adds complexity compared to simple token embeddings; a minimal sketch is given after this list.
- Model Architecture: A deep learning framework like PyTorch or TensorFlow is necessary to build the Transformer-based unimodal encoders and the cross-attention module. Initializing the unstructured component with pre-trained weights from ClinicalBERT (or a similar model) requires loading and potentially adapting the checkpoint. Freezing early layers during pre-training or fine-tuning can help stabilize training and transfer knowledge.
- Pre-training: This is computationally intensive, requiring significant GPU resources (the paper used two RTX 3090 GPUs). The masking strategy and objectives need careful implementation. The prediction heads for masked codes likely involve projecting the augmented visit representations (RText and RCode) to the size of the code vocabulary.
- Fine-tuning: Adapting the pre-trained model involves adding simple linear layers (MLPs) for the specific downstream tasks. Different tasks may require different input combinations to the final MLP (e.g., historical averages + current visit for medication recommendation, current visit for readmission prediction). Learning rates should typically be lower for fine-tuning compared to pre-training to avoid disrupting the learned representations.
- Computational Resources: Both pre-training and fine-tuning large Transformer models require substantial memory and computation. Deployment for real-time inference would require hardware capable of running these models efficiently.
- Limitations: The paper notes several limitations relevant to real-world application:
- Using only diagnosis and medication codes; incorporating other data types (labs, vitals, procedures) could enhance the model but requires more complex data integration.
- Focusing primarily on single-visit records during pre-training limits the model's ability to capture temporal patterns across multiple visits for the same patient.
- Fixed sequence length truncation might lose information from very long clinical notes.
- The model's performance is evaluated on MIMIC-III (intensive care unit data), and generalizability to data from different hospital settings, specialties, or countries (with different coding systems or clinical practices) might require further adaptation or pre-training on more diverse datasets.
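As a rough illustration of the ontology-embedding step mentioned above, the sketch below refines each code embedding by attending over its ancestors in the code hierarchy with a single multi-head attention layer. It is a simplified stand-in for the GAT component; the `ancestor_ids` lookup, padding convention, and residual/normalization details are assumptions, not the authors' implementation:

```python
import torch
import torch.nn as nn


class OntologyCodeEmbedding(nn.Module):
    """Sketch of an ontology-aware code embedding: each leaf code's vector is
    refined by attending over the embeddings of its ancestors in the code
    hierarchy (a simplified stand-in for the GAT used in G-BERT / MedM-PLM)."""

    def __init__(self, num_nodes: int, hidden_dim: int, num_heads: int = 4):
        super().__init__()
        # One embedding per ontology node (leaf codes + internal nodes); index 0 is padding.
        self.node_embedding = nn.Embedding(num_nodes, hidden_dim, padding_idx=0)
        self.graph_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, code_ids: torch.Tensor, ancestor_ids: torch.Tensor):
        # code_ids: (B, L) leaf code indices for the visit
        # ancestor_ids: (B, L, A) indices of each code's ancestors, 0-padded.
        # Assumes every code has at least one ancestor (true for ICD-9/ATC hierarchies).
        leaf = self.node_embedding(code_ids)           # (B, L, H)
        ancestors = self.node_embedding(ancestor_ids)  # (B, L, A, H)
        B, L, A, H = ancestors.shape
        query = leaf.reshape(B * L, 1, H)              # each code queries its own ancestors
        keys = ancestors.reshape(B * L, A, H)
        pad_mask = ancestor_ids.reshape(B * L, A) == 0  # ignore padded ancestor slots
        ctx, _ = self.graph_attn(query, keys, keys, key_padding_mask=pad_mask)
        # Residual: ontology context added to the raw code embedding
        return self.norm(leaf + ctx.reshape(B, L, H))
```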
Conceptual Pseudocode Snippet:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MedM_PLM(nn.Module):
    def __init__(self, unimodal_structured_encoder, unimodal_unstructured_encoder,
                 cross_modal_module, hidden_dim, code_vocab_size):
        super().__init__()
        self.structured_encoder = unimodal_structured_encoder      # e.g., G-BERT component
        self.unstructured_encoder = unimodal_unstructured_encoder  # e.g., ClinicalBERT component
        self.cross_modal = cross_modal_module                      # cross-attention + residual
        self.code_prediction_head_text = nn.Linear(hidden_dim, code_vocab_size)  # for Text-to-Code
        self.code_prediction_head_code = nn.Linear(hidden_dim, code_vocab_size)  # for Code-to-Code

    def forward(self, structured_input, unstructured_input):
        # Unimodal encoding
        # Zc: token-level, Zc_cls: visit-level
        Zc, Zc_cls = self.structured_encoder(structured_input)
        # Zw: token-level, Zw_cls: visit-level
        Zw, Zw_cls = self.unstructured_encoder(unstructured_input)

        # Cross-modal integration
        # RCode, RText are the augmented visit-level representations
        RCode, RText = self.cross_modal(Zc, Zw, Zc_cls, Zw_cls)

        # Return everything needed for pre-training or fine-tuning
        return Zc, Zw, RCode, RText


def compute_pretraining_loss(model, structured_batch, unstructured_batch,
                             masked_code_labels_text, masked_code_labels_code):
    # structured_batch contains masked codes; unstructured_batch is the unmasked text.
    # masked_code_labels_* hold the ground-truth codes for the masked positions,
    # already aligned with the corresponding augmented representations.
    Zc, Zw, RCode, RText = model(structured_batch, unstructured_batch)

    # Predict the masked codes from the augmented visit representations.
    masked_code_logits_from_text = model.code_prediction_head_text(RText)
    masked_code_logits_from_code = model.code_prediction_head_code(RCode)

    # Compute the loss only for masked codes
    # (a real implementation extracts logits/labels for the masked positions).
    lt2c = F.cross_entropy(masked_code_logits_from_text, masked_code_labels_text)
    lc2c = F.cross_entropy(masked_code_logits_from_code, masked_code_labels_code)
    return lt2c + lc2c


class ReadmissionClassifier(nn.Module):
    def __init__(self, pre_trained_medm_plm, hidden_dim):
        super().__init__()
        self.medm_plm = pre_trained_medm_plm
        # Optionally freeze the pre-trained weights:
        # for param in self.medm_plm.parameters():
        #     param.requires_grad = False
        self.classifier = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),  # concatenation of RCode and RText
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, structured_input, unstructured_input):
        # Obtain the augmented visit representations from the pre-trained model
        # (run it in eval mode if its weights are frozen).
        _, _, RCode, RText = self.medm_plm(structured_input, unstructured_input)
        combined_rep = torch.cat([RText, RCode], dim=-1)
        return self.classifier(combined_rep)
```
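As a usage note, fine-tuning the ReadmissionClassifier sketched above could look like the following, with a smaller learning rate on the pre-trained backbone than on the new classification head, in line with the fine-tuning considerations listed earlier; `pre_trained_medm_plm`, `train_loader`, the learning rates, and the loss are illustrative assumptions:

```python
import torch

# pre_trained_medm_plm: an already pre-trained MedM_PLM instance (assumed available)
model = ReadmissionClassifier(pre_trained_medm_plm, hidden_dim=768)

optimizer = torch.optim.AdamW([
    {"params": model.medm_plm.parameters(), "lr": 1e-5},    # pre-trained backbone: small LR
    {"params": model.classifier.parameters(), "lr": 1e-4},  # new task head: larger LR
])
criterion = torch.nn.BCELoss()  # the classifier already ends in a sigmoid

for structured_input, unstructured_input, labels in train_loader:  # hypothetical DataLoader
    optimizer.zero_grad()
    probs = model(structured_input, unstructured_input).squeeze(-1)
    loss = criterion(probs, labels.float())
    loss.backward()
    optimizer.step()
```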
In summary, MedM-PLM offers a principled approach to leveraging multimodal EHR data by explicitly modeling interactions. Its pre-training strategy and cross-modal architecture demonstrate strong performance across diverse clinical prediction tasks, making it a valuable tool for developing data-driven healthcare applications, particularly in scenarios with limited task-specific labeled data. However, deploying such a model requires careful handling of complex EHR data, significant computational resources for training, and consideration of its generalizability across different healthcare contexts.