MedUnifier: Unifying Vision-and-Language Pre-training on Medical Data with Vision Generation Task using Discrete Visual Representations
(2503.01019v3)
Published 2 Mar 2025 in cs.CV and cs.AI
Abstract: Despite significant progress in Vision-Language Pre-training (VLP), current approaches predominantly emphasize feature extraction and cross-modal comprehension, with limited attention to generating or transforming visual content. This gap hinders the model's ability to synthesize coherent and novel visual representations from textual prompts, thereby reducing the effectiveness of multi-modal learning. In this work, we propose MedUnifier, a unified VLP framework tailored for medical data. MedUnifier seamlessly integrates text-grounded image generation capabilities with multi-modal learning strategies, including image-text contrastive alignment, image-text matching and image-grounded text generation. Unlike traditional methods that rely on continuous visual representations, our approach employs visual vector quantization, which not only facilitates a more cohesive learning strategy for cross-modal understanding but also enhances multi-modal generation quality by effectively leveraging discrete representations. Our framework's effectiveness is evidenced by the experiments on established benchmarks, including uni-modal tasks (supervised fine-tuning), cross-modal tasks (image-text retrieval and zero-shot image classification), and multi-modal tasks (medical report generation, image synthesis), where it achieves state-of-the-art performance across various tasks. MedUnifier also offers a highly adaptable tool for a wide range of language and vision tasks in healthcare, marking an advancement toward the development of a generalizable AI model for medical applications.
Summary
The paper introduces MedUnifier, a unified vision-language pre-training framework for medical data that integrates text-grounded image generation using discrete visual representations.
MedUnifier employs discrete visual representations and a text-grounded image generation task to enable high-quality medical image synthesis and enhance multi-modal understanding.
The model achieves state-of-the-art performance across uni-modal, cross-modal, and multi-modal medical tasks, demonstrating adaptability for report generation and dataset augmentation.
The paper introduces MedUnifier, a unified vision-language pre-training (VLP) framework tailored for medical data, integrating text-grounded image generation capabilities with multi-modal learning strategies. The framework employs visual vector quantization to leverage discrete representations, enhancing both cross-modal understanding and multi-modal generation quality. The model achieves SOTA performance across uni-modal, cross-modal, and multi-modal tasks.
The MedUnifier framework incorporates learnable embeddings within a Transformer model, drawing inspiration from BLIP-2, and introduces a text-grounded image generation (TIG) loss, leveraging vector quantization for discrete visual representation learning. A novel latent adapter connects the base model with the image generation module, enabling co-training with image-text contrastive (ITC), image-text matching (ITM), and image-grounded text generation (ITG) losses.
The main contributions include:
The MedUnifier framework which unifies the VLP paradigm with language-guided visual generation
Discrete visual representation learning with a bridging design to guide visual outputs and enhance data interpretation
Performance improvements on chest X-rays across uni-modal, cross-modal, and multi-modal tasks
Adaptability in generating realistic medical images and reports, augmenting out-of-distribution datasets
A TIG module to capture fine-grained details by recovering pixel-level information from hierarchical multi-modal representations
Related Work
The paper discusses related work in two primary areas: VLP and text-to-image (T2I) generation. Existing VLP models are categorized into those using uni-modal encoders and those using fusion encoder-based structures. The paper argues that current VLP approaches largely overlook the generation of visual information and the exploration of fine-grained visual content. T2I approaches based on GANs and auto-regressive transformers, variational auto-encoders (VAEs), vector-quantized VAEs (VQ-VAEs), and diffusion models are also discussed. The paper adopts VQ-VAEs to learn robust representations, improving the quality and efficiency of medical image generation.
Method
The MedUnifier framework aggregates four key learning objectives for Med-VLP. The overall pre-training objective is defined as
$$\mathcal{L}_{total} = \sum_{m=1}^{M} \lambda_m \mathcal{L}_m\left(H_m\left(F(X_I, X_T)\right)\right)$$
where
$F$ represents the backbone, which takes the paired $[x_i, x_t]$ as input.
$H_m$ denotes the task-specific modules that further encode visual and textual features.
$\mathcal{L}_m$ and $\lambda_m$ are the individual loss functions and their weights, with $M$ the total number of loss functions.
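Read procedurally, the objective runs the backbone $F$ once and lets each of the $M$ task heads score the shared features, with each loss scaled by its $\lambda_m$. The Python sketch below shows only this composition; `pretrain_objective` and the toy backbone, heads, and losses are hypothetical placeholders, not MedUnifier's actual modules.

```python
from typing import Any, Callable, Dict


def pretrain_objective(
    backbone: Callable[[Any, Any], Any],
    heads: Dict[str, Callable[[Any], Any]],
    losses: Dict[str, Callable[[Any], float]],
    weights: Dict[str, float],
    x_img: Any,
    x_txt: Any,
) -> float:
    """L_total = sum_m lambda_m * L_m(H_m(F(X_I, X_T)))."""
    features = backbone(x_img, x_txt)  # F(X_I, X_T): computed once, shared by every task head
    return sum(weights[m] * losses[m](heads[m](features)) for m in heads)


# Toy usage: scalar "features", placeholder heads and losses, all weights set to 1.
total = pretrain_objective(
    backbone=lambda img, txt: img + txt,
    heads={"itc": lambda f: f, "itm": lambda f: 2 * f},
    losses={"itc": lambda h: h, "itm": lambda h: h},
    weights={"itc": 1.0, "itm": 1.0},
    x_img=1.0,
    x_txt=2.0,
)
print(total)  # 9.0
```

Sharing a single backbone pass is what lets the four objectives in MedUnifier shape the same representation. The paper sets every $\lambda_m$ to 1.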
The model consists of an image-text encoder, a text generator, and an image generator, which are combined through cross-attention layers, masking strategies, and vector discretization.
Image-text encoder
A BERT-style Transformer is used as the image-text encoder network. The input consists of learnable embeddings and word-tokenized clinical reports. Input images are processed into a set of patch embeddings using a pre-trained, frozen Vision Transformer (ViT). The initial visual embeddings interact with the image-text encoder network through cross-attention layers.
Text generator
The text encoder of the image-text encoder is duplicated as a language-generative network with shared weights. A decoding head is added to map each word token embedding to the vocabulary dictionary.
Image generator
A vector-quantized variational auto-encoder (VQ-VAE) is integrated within a cross-modal interactive fusion framework to generate high-quality synthetic visual content.
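Given an image $x_i \in \mathbb{R}^{C\times H\times W}$, the image is divided into $L_v$ patches of spatial size $(h, w)$, and learnable positional encodings are added:
$$X_i = [\boldsymbol{p}_{[CLS]}, \boldsymbol{p}_1, \boldsymbol{p}_2, \ldots, \boldsymbol{p}_{L_v}] + E^{v}_{pos}$$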
where
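$\boldsymbol{p} \in \mathbb{R}^{d_v}$ is an input patch embedding
$E^{v}_{pos} \in \mathbb{R}^{(1+L_v)\times d_v}$ is the learnable positional encoding
$L_v = \frac{H}{h} \cdot \frac{W}{w}$ is the number of patches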
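The patch count follows directly from the image and patch sizes. The sketch below uses the 224×224 input resolution reported in the implementation details. It also assumes a 14×14 patch size, which matches the ViT-g/14 configuration used by BLIP-2 but is not stated here. `embed` applies a hypothetical linear projection, and plain nested lists stand in for tensors purely for illustration.

```python
from typing import List

Vec = List[float]
Image = List[List[List[float]]]  # nested lists indexed as [channel][row][col]


def num_patches(H: int, W: int, h: int, w: int) -> int:
    """L_v = (H / h) * (W / w); H and W must be divisible by the patch size."""
    if H % h or W % w:
        raise ValueError("image size must be divisible by the patch size")
    return (H // h) * (W // w)


def patchify(img: Image, h: int, w: int) -> List[Vec]:
    """Split a C x H x W image into L_v flattened patches of length C*h*w, in row-major patch order."""
    C, H, W = len(img), len(img[0]), len(img[0][0])
    patches = []
    for top in range(0, H, h):
        for left in range(0, W, w):
            patches.append([img[c][top + r][left + s]
                            for c in range(C) for r in range(h) for s in range(w)])
    return patches


def embed(patches: List[Vec], proj: List[Vec], cls_token: Vec, pos: List[Vec]) -> List[Vec]:
    """X_i = [p_CLS, p_1, ..., p_Lv] + E_pos, with p_j = proj @ patch_j (hypothetical linear projection)."""
    tokens = [cls_token] + [[sum(a * x for a, x in zip(row, patch)) for row in proj]
                            for patch in patches]
    return [[t + e for t, e in zip(tok, pe)] for tok, pe in zip(tokens, pos)]


print(num_patches(224, 224, 14, 14))  # 256

img = [[[float(4 * r + c) for c in range(4)] for r in range(4)]]  # toy 1 x 4 x 4 image
patches = patchify(img, 2, 2)
print(len(patches), patches[0])  # 4 [0.0, 1.0, 4.0, 5.0]
```

Under the assumed patch size, ViT-g would produce 256 patch embeddings plus the $[CLS]$ token, giving 257 rows in $\boldsymbol{f}^v$.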
These patch embeddings are passed through a standard pre-trained ViT-g, denoted $E_I$, to obtain preliminary visual embeddings $\boldsymbol{f}^{v} = E_I(X_i) \in \mathbb{R}^{(L_v+1)\times d_v}$, where:
$\boldsymbol{f}^{v}_{[CLS]}$ is the global visual feature
the outputs of the patch embeddings, $\boldsymbol{f}^{v}_{local} \in \mathbb{R}^{L_v\times d_v}$, represent local visual features
For the corresponding textual input, the text is tokenized into word token embeddings, to which learnable positional encodings are added:
$$X_t = [\boldsymbol{w}_{[SPE]}, \boldsymbol{w}_1, \boldsymbol{w}_2, \ldots, \boldsymbol{w}_{L_t}] + E^{t}_{pos}$$
where
$\boldsymbol{w} \in \mathbb{R}^{d_t}$ represents a word token embedding
$E^{t}_{pos} \in \mathbb{R}^{(1+L_t)\times d_t}$ is the learnable positional encoding
To enable interaction between word token embeddings and preliminary visual embeddings, a set of learnable embeddings $\boldsymbol{Q} = [\boldsymbol{q}_1, \boldsymbol{q}_2, \dots, \boldsymbol{q}_{L_q}] \in \mathbb{R}^{L_q\times d_q}$ is constructed. The word token embeddings and learnable embeddings share the same feature dimension, i.e., $d_t = d_q$. $\boldsymbol{Q}$ and $X_t$ are then concatenated to form the input of the image-text encoder, denoted $E_Q$, which produces the output embeddings:
$$E_Q([\boldsymbol{Q}, X_t]) = [\boldsymbol{f}^q, \boldsymbol{f}^t] = [\boldsymbol{f}^q, \boldsymbol{f}^t_{[SPE]}, \boldsymbol{f}^t_{local}]$$
where
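$\boldsymbol{f}^q \in \mathbb{R}^{L_q\times d_q}$ are the learned embeddings
$\boldsymbol{f}^t_{[SPE]} \in \mathbb{R}^{d_t}$ and $\boldsymbol{f}^t_{local} \in \mathbb{R}^{L_t\times d_t}$ represent the special text representation and all word token embeddings, respectively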
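The interaction between the learnable embeddings and the visual features, and the positional split of the encoder output, can be sketched as below. This is a simplified illustration that rests on several assumptions:

- It uses a single attention head with identity query, key, and value projections.
- Visual and query features have equal width. In practice $d_v$ and $d_q$ differ and are bridged by learned projections.
- As in BLIP-2, only the learnable embeddings cross-attend to $\boldsymbol{f}^v$.

`cross_attend` and `split_outputs` are hypothetical helper names.

```python
import math
from typing import List, Tuple

Vec = List[float]


def softmax(xs: Vec) -> Vec:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attend(queries: List[Vec], visual: List[Vec]) -> List[Vec]:
    """Each learnable query attends over the visual embeddings f^v (identity Q/K/V projections assumed)."""
    d = len(queries[0])
    out = []
    for q in queries:
        weights = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in visual])
        out.append([sum(w * v[j] for w, v in zip(weights, visual)) for j in range(d)])
    return out


def split_outputs(outputs: List[Vec], L_q: int) -> Tuple[List[Vec], Vec, List[Vec]]:
    """Split E_Q([Q, X_t]) by position into f^q, f^t_[SPE], and f^t_local."""
    return outputs[:L_q], outputs[L_q], outputs[L_q + 1:]


# Toy encoder output: 2 query outputs, 1 [SPE] output, 3 word-token outputs.
outputs = [[float(i)] for i in range(6)]
f_q, f_spe, f_local = split_outputs(outputs, L_q=2)
print(len(f_q), f_spe, len(f_local))  # 2 [2.0] 3
```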
Image-text contrastive learning (ITC)
This task aligns visual and textual representations by maximizing their mutual information through a contrastive objective. The pairwise similarity between each visual representation in $\boldsymbol{g}^q$ and the textual representation $\boldsymbol{g}^t$ is computed. The highest value is taken as the image-text similarity used in the bi-directional contrastive losses $\mathcal{L}_{itc}(q|t)$ and $\mathcal{L}_{itc}(t|q)$, in which:
$\tau \in \mathbb{R}$ is a scaling temperature parameter initialized to 0.07
$N$ is the mini-batch size and $\langle \cdot, \cdot \rangle$ denotes the cosine similarity
The overall ITC loss is defined as:
$$\mathcal{L}_{itc} = \frac{1}{2}\left(\mathcal{L}_{itc}(q|t) + \mathcal{L}_{itc}(t|q)\right)$$
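Because the per-direction formula is not reproduced above, the sketch below assumes the standard InfoNCE form used by BLIP-2-style models. Image-text similarity is the maximum cosine similarity over the $L_q$ query outputs, logits are scaled by $1/\tau$, and matched pairs lie on the diagonal of the $N \times N$ similarity matrix. $\tau$ is held fixed at its initial value 0.07 here, and all function names are hypothetical.

```python
import math
from typing import List

Vec = List[float]


def cosine(a: Vec, b: Vec) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def similarity_matrix(g_q: List[List[Vec]], g_t: List[Vec]) -> List[List[float]]:
    """s[k][j] = max over the L_q query outputs of image k of cos(g^q_{k,i}, g^t_j)."""
    return [[max(cosine(q, t) for q in queries) for t in g_t] for queries in g_q]


def info_nce(sim: List[List[float]], tau: float) -> float:
    """Mean over rows k of -log softmax(sim[k] / tau)[k]; matched pairs sit on the diagonal."""
    total = 0.0
    for k, row in enumerate(sim):
        logits = [s / tau for s in row]
        m = max(logits)
        log_z = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_z - logits[k]
    return total / len(sim)


def itc_loss(g_q: List[List[Vec]], g_t: List[Vec], tau: float = 0.07) -> float:
    """Average of the image-to-text and text-to-image directions."""
    sim = similarity_matrix(g_q, g_t)
    sim_transposed = [list(col) for col in zip(*sim)]
    return 0.5 * (info_nce(sim, tau) + info_nce(sim_transposed, tau))


# Two images with one query output each, perfectly aligned with their own reports.
g_q = [[[1.0, 0.0]], [[0.0, 1.0]]]
g_t = [[1.0, 0.0], [0.0, 1.0]]
print(itc_loss(g_q, g_t) < 1e-3)  # True
```

When every feature is the same vector, all logits are equal and the loss reduces to $\log N$. This makes a convenient sanity check.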
Image-text matching (ITM)
This task learns a precise alignment between visual and textual representations by training a model to classify image-text pairs as either positive or negative in a binary classification framework. The Image-Text Matching (ITM) loss is computed as:
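$$\mathcal{L}_{itm} = \frac{1}{N}\sum_{k=1}^{N} -\log p(Y_k \mid \hat{Y}_k)$$
where $\hat{Y}$ is defined as:
$$\hat{Y} = \frac{1}{L_q}\sum_{i=1}^{L_q} H_{itm}(\boldsymbol{f}^q_i)$$
$Y$ represents the ground-truth labels within the mini-batch, which is constructed with hard negative mining.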
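A minimal sketch of the ITM computation follows. It assumes $H_{itm}$ is a linear head producing two logits, with index 1 meaning "matched" (an arbitrary convention here). It also assumes $\hat{Y}$ averages those logits over the $L_q$ learnable embeddings. BLIP-style models typically sample hard negatives in proportion to their contrastive similarity. `hardest_negative` instead takes the most similar mismatched candidate as a deterministic simplification. All names are hypothetical.

```python
import math
from typing import List

Vec = List[float]


def itm_logits(f_q: List[Vec], W: List[Vec], b: Vec) -> Vec:
    """Y_hat = (1 / L_q) * sum_i H_itm(f_i^q), with H_itm a linear head giving 2 logits."""
    per_query = [[sum(w * x for w, x in zip(row, f)) + bias for row, bias in zip(W, b)]
                 for f in f_q]
    return [sum(logits[c] for logits in per_query) / len(f_q) for c in range(len(b))]


def itm_loss(batch_logits: List[Vec], labels: List[int]) -> float:
    """(1 / N) * sum_k -log p(Y_k | Y_hat_k), using a softmax over the two classes."""
    total = 0.0
    for logits, y in zip(batch_logits, labels):
        m = max(logits)
        log_z = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_z - logits[y]
    return total / len(labels)


def hardest_negative(sim_row: List[float], positive: int) -> int:
    """Most similar non-matching candidate; a deterministic stand-in for hard-negative sampling."""
    return max((j for j in range(len(sim_row)) if j != positive), key=lambda j: sim_row[j])


print(itm_logits([[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]))  # [2.0, 1.0]
print(hardest_negative([0.9, 0.1, 0.8], positive=0))  # 2
```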
Image-grounded text generation (ITG)
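This task trains the model to generate text conditioned on paired images using causal language modeling (CLM). The learning objective is formalized as:
$$\mathcal{L}_{itg} = \frac{1}{N L_t}\sum_{k=1}^{N}\sum_{i=1}^{L_t} -\log p_i$$
where $p_i$ is defined as:
$$p_i = \mathrm{Softmax}\left(H_{itg}(\boldsymbol{f}^t_{local})\right) = p(w_i \mid \boldsymbol{Q}, \ldots, w_{i-1})$$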
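The conditioning in $p(w_i \mid \boldsymbol{Q}, \ldots, w_{i-1})$ is usually enforced with an attention mask. The sketch below assumes a BLIP-2-style multimodal causal mask: queries see only queries, and each text token sees all queries plus earlier text. It then computes $\mathcal{L}_{itg}$ from per-token logits that are assumed to already predict $w_i$. Padding is ignored, and the helper names are hypothetical.

```python
import math
from typing import List

Vec = List[float]


def multimodal_causal_mask(L_q: int, L_t: int) -> List[List[bool]]:
    """mask[i][j] is True if position i may attend to position j.

    Positions 0..L_q-1 are learnable queries, which see only each other; text positions see
    every query plus the text tokens up to and including themselves (BLIP-2-style, assumed).
    """
    n = L_q + L_t
    return [[(j < L_q) if i < L_q else (j <= i) for j in range(n)] for i in range(n)]


def itg_loss(batch_logits: List[List[Vec]], batch_targets: List[List[int]]) -> float:
    """(1 / (N * L_t)) * sum_k sum_i -log p_i, where p_i = softmax(logits_i)[w_i]."""
    total, count = 0.0, 0
    for logits_seq, targets in zip(batch_logits, batch_targets):
        for logits, w in zip(logits_seq, targets):
            m = max(logits)
            log_z = m + math.log(sum(math.exp(z - m) for z in logits))
            total += log_z - logits[w]
            count += 1
    return total / count


for row in multimodal_causal_mask(L_q=1, L_t=2):
    print(row)
# [True, False, False]
# [True, True, False]
# [True, True, True]
```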
Text-grounded image generation (TIG)
The TIG module, designed for the text-grounded image generation task, is integrated with the image-text encoder and text generator. At the top level, a latent adapter, denoted $Z_{top}$, transforms $z_{top}$ into a spatial feature map $z^{e}_{top}$:
$$z^{e}_{top} = Z_{top}(z_{top})$$
where $Z_{top}$ consists of a nonlinear transformation, a spatial positional-encoding summer, and a residual block. A vector quantization layer with latent embedding space $e_{top}$ then produces a discrete feature map $z^{q}_{top}$:
$$z^{q}_{top} = \mathrm{quantizer}_{top}(z^{e}_{top})$$
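The quantizer replaces each spatial vector of $z^e_{top}$ with its nearest entry in $e_{top}$. The sketch below shows this nearest-neighbour lookup, assuming squared Euclidean distance as in the original VQ-VAE and a flattened feature map. It omits the straight-through gradient copy used during training, because plain Python has no autograd. The codebook values are toy numbers.

```python
from typing import List, Tuple

Vec = List[float]


def quantize(z_e: List[Vec], codebook: List[Vec]) -> Tuple[List[int], List[Vec]]:
    """Map each spatial feature vector of z^e to its nearest codebook entry (squared L2 distance)."""
    indices, z_q = [], []
    for v in z_e:
        k = min(range(len(codebook)),
                key=lambda j: sum((a - b) ** 2 for a, b in zip(v, codebook[j])))
        indices.append(k)
        z_q.append(list(codebook[k]))
    return indices, z_q


codebook = [[0.0, 0.0], [1.0, 1.0], [0.0, 1.0]]  # toy latent embedding space e_top with 3 entries
z_e = [[0.1, 0.2], [0.9, 0.7], [0.2, 0.8]]       # flattened toy 1 x 3 spatial feature map
idx, z_q = quantize(z_e, codebook)
print(idx)  # [0, 1, 2]
```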
At the bottom level, a latent adapter $Z_{bottom}$ and a vector quantizer with latent embedding space $e_{bottom}$ are deployed to obtain a discrete feature map:
$$z^{e}_{bottom} = Z_{bottom}(z_{bottom})$$
$$z^{q}_{bottom} = \mathrm{quantizer}_{bottom}(z^{e}_{bottom}, z^{q}_{top})$$
A hierarchical decoder $D$ with deconvolutional layers is built to recover the raw visual input from the discrete multi-modal representations:
$$\hat{x}_i = D(z^{q}_{top}, z^{q}_{bottom})$$
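How the two discrete maps are combined inside $D$ is not detailed here. One plausible reading, following VQ-VAE-2, is to upsample $z^q_{top}$ to the bottom resolution and concatenate it channel-wise with $z^q_{bottom}$ before the deconvolutional layers. The sketch below shows only that assumed input assembly. Nearest-neighbour upsampling stands in for learned deconvolutions, and `decoder_input` is a hypothetical name.

```python
from typing import List

Vec = List[float]
FeatureMap = List[List[Vec]]  # indexed as [row][col] -> channel vector


def upsample_nearest(fmap: FeatureMap, factor: int) -> FeatureMap:
    """Nearest-neighbour upsampling, a stand-in for the learned deconvolutions in the real decoder."""
    return [[list(fmap[r // factor][c // factor]) for c in range(len(fmap[0]) * factor)]
            for r in range(len(fmap) * factor)]


def decoder_input(z_q_top: FeatureMap, z_q_bottom: FeatureMap) -> FeatureMap:
    """Bring z^q_top to the bottom resolution and concatenate the two codes channel-wise."""
    factor = len(z_q_bottom) // len(z_q_top)
    top_up = upsample_nearest(z_q_top, factor)
    return [[t + b for t, b in zip(top_row, bottom_row)]  # list + list concatenates channels
            for top_row, bottom_row in zip(top_up, z_q_bottom)]


top = [[[1.0]]]                            # toy 1 x 1 top code with 1 channel
bottom = [[[0.0], [0.1]], [[0.2], [0.3]]]  # toy 2 x 2 bottom code with 1 channel
print(decoder_input(top, bottom))  # [[[1.0, 0.0], [1.0, 0.1]], [[1.0, 0.2], [1.0, 0.3]]]
```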
In the text-grounded image generation (TIG) loss:
the negative log-likelihood term can be written as the mean squared error (MSE) $\|x^{i}_{k} - \hat{x}^{i}_{k}\|^2$
$\mathrm{sg}[\cdot]$ is the stop-gradient operation
the hyper-parameters $\beta_1, \beta_2$ are both set to 0.5
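Since the full TIG formula is not reproduced above, the sketch below assumes the standard two-level VQ-VAE-2 form. This consists of an MSE reconstruction term plus, at each level, a codebook term $\|\mathrm{sg}[z^e] - e\|^2$ and a commitment term. The commitment term is weighted by $\beta_1$ at the top level and $\beta_2$ at the bottom level. This weighting assignment is an assumption. `tig_loss` is a hypothetical helper that computes forward values only.

```python
from typing import List, Tuple

Vec = List[float]


def mse(x: Vec, x_hat: Vec) -> float:
    """Reconstruction term: mean squared error between the pixels of x_i and x_hat_i."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


def vq_terms(z_e: List[Vec], z_q: List[Vec], beta: float) -> float:
    """||sg[z_e] - e||^2 + beta * ||z_e - sg[e]||^2, averaged over spatial positions.

    sg[.] only blocks gradients, so in a forward-only computation both terms share one value;
    with autograd, the first would update the codebook and the second the latent adapter.
    """
    dist = sum(sum((a - b) ** 2 for a, b in zip(u, v)) for u, v in zip(z_e, z_q)) / len(z_e)
    codebook_term = dist
    commitment_term = dist
    return codebook_term + beta * commitment_term


def tig_loss(x: Vec, x_hat: Vec,
             top: Tuple[List[Vec], List[Vec]], bottom: Tuple[List[Vec], List[Vec]],
             beta1: float = 0.5, beta2: float = 0.5) -> float:
    """Assumed form: MSE + top-level VQ terms (beta1) + bottom-level VQ terms (beta2)."""
    return mse(x, x_hat) + vq_terms(*top, beta1) + vq_terms(*bottom, beta2)


loss = tig_loss(x=[0.0, 0.0], x_hat=[1.0, 1.0],
                top=([[0.0]], [[1.0]]), bottom=([[0.0]], [[1.0]]))
print(loss)  # 4.0
```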
Total learning objectives
The ultimate objective function is:
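$$\mathcal{L}_{total} = \lambda_1\mathcal{L}_{itc} + \lambda_2\mathcal{L}_{itm} + \lambda_3\mathcal{L}_{itg} + \lambda_4\mathcal{L}_{tig}$$
All four loss weights $\lambda_1, \ldots, \lambda_4$ were set to 1 in the experiments.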
Experiments
The pre-training is performed on the MIMIC-CXR v2.0.0 dataset, and the model is evaluated on various downstream tasks.
Implementation details
A BERT model is used as the primary network for the image-text encoder, and ViT-g is used as the pre-trained ViT. The input image resolution is set to 224×224, with a maximum text length of 95 tokens and 32 learnable embeddings. For optimization, the AdamW optimizer is applied with specific parameters.
Medical Vision-and-Language Benchmark
The effectiveness of the proposed method is assessed across uni-modal, cross-modal, and multi-modal tasks. The experiments are conducted on MedUnifier with/without TIG loss and compared with previous studies.
Results and Analyses
The model outperforms prior studies on uni-modal tasks across various downstream datasets. For cross-modal retrieval, MedUnifier achieves the highest performance. MedUnifier also performs better on zero-shot classification on both the MIMIC 5x200 and RSNA datasets. The models both with and without the TIG module surpass previous methods on image-grounded medical report generation. The Med-VLP framework also gains substantial advantages from incorporating causal language modeling. The reconstructed visual samples are nearly indistinguishable from authentic radiographs, and synthetic samples generated from multi-modal priors show high diversity.
Ablation study
An ablation study across the learning objectives shows that using only the ITC loss yields the lowest performance and that the model with TIG surpasses the one with ITG. Integrating all four objectives gives the best performance.
Conclusion
The paper introduces MedUnifier, a unified Med-VLP model, which optimizes four distinct learning objectives simultaneously. The framework circumvents the need to learn visual embeddings from scratch and reconstructs pixel-level visual details from both image and report. The proposed method effectively complements existing Med-VLP frameworks and achieves SOTA performance on various downstream tasks.