Task Graph Transformer (TGT)
- The paper introduces Task Graph Transformer (TGT), a unified framework combining GNN-based feature extraction with transformer pipelines for multi-task learning on WSIs.
- It employs task-aware knowledge injection and domain-driven pooling, achieving consistent accuracy boosts (up to 3.78%) over baseline models.
- The architecture processes WSIs by tiling images into patches, applying graph convolutions, and using specialized transformer modules for diagnostic predictions.
The Task Graph Transformer (TGT), introduced as the core of the MulGT framework, is a unified architecture for multi-task learning over graph-structured representations of Whole Slide Images (WSIs) in computational pathology. TGT integrates a shared Graph Neural Network (GNN) backbone and task-specific transformer pipelines with modules for task-aware knowledge injection and domain knowledge-driven graph pooling. It is designed to address clinically realistic scenarios where joint prediction of multiple diagnostic endpoints, such as tumor typing and staging, is required. Empirical results on The Cancer Genome Atlas (TCGA) datasets demonstrate that TGT yields consistent improvements over prior single-task and single-pooling Graph Transformer baselines in both accuracy and robustness (Zhao et al., 2023).
1. Architecture Overview
The TGT framework processes a WSI by tiling it into fixed-size, non-overlapping image patches, each of which is embedded via a frozen feature extractor (e.g., KimiaNet) into a $d$-dimensional feature vector. The resulting set of patch embeddings forms the vertex set of a spatial graph, where each node is connected to its eight spatially neighboring tiles. The adjacency matrix $A$ is augmented with self-loops to obtain $\tilde{A} = A + I$ and normalized as $\hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$, where $\tilde{D}$ is the diagonal degree matrix of $\tilde{A}$.
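The graph construction above can be sketched in a few lines of numpy. This is a minimal illustration, not the paper's implementation: it builds the 8-neighbour adjacency for a toy `rows x cols` tile layout and applies the standard GCN symmetric normalization.

```python
import numpy as np

def normalized_adjacency(rows: int, cols: int) -> np.ndarray:
    """Build the 8-neighbour grid adjacency for a rows x cols tile layout,
    add self-loops, and apply symmetric degree normalization:
    A_hat = D^{-1/2} (A + I) D^{-1/2}."""
    n = rows * cols
    A = np.zeros((n, n))
    for r in range(rows):
        for c in range(cols):
            i = r * cols + c
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    if dr == 0 and dc == 0:
                        continue          # skip the node itself
                    rr, cc = r + dr, c + dc
                    if 0 <= rr < rows and 0 <= cc < cols:
                        A[i, rr * cols + cc] = 1.0
    A_tilde = A + np.eye(n)               # self-loops
    d_inv_sqrt = np.diag(1.0 / np.sqrt(A_tilde.sum(axis=1)))
    return d_inv_sqrt @ A_tilde @ d_inv_sqrt
```

In a real pipeline the grid positions come from the tile coordinates on the slide; here a dense matrix suffices for exposition, while an implementation at WSI scale would use a sparse representation.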
A stack of shared graph convolutional layers encodes low-level, task-agnostic local features: $H^{(l+1)} = \sigma\big(\hat{A} H^{(l)} W^{(l)}\big)$, with $H^{(0)}$ as the initial patch embeddings and $W^{(l)}$ trainable.
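As a sketch, the shared backbone is just repeated application of that layer rule. The feature dimension and layer count below are toy values, not the paper's settings.

```python
import numpy as np

def gcn_forward(A_hat: np.ndarray, H0: np.ndarray, weights: list) -> np.ndarray:
    """Shared GCN stack: H^{(l+1)} = ReLU(A_hat @ H^{(l)} @ W^{(l)})."""
    H = H0
    for W in weights:
        H = np.maximum(A_hat @ H @ W, 0.0)   # ReLU non-linearity
    return H

rng = np.random.default_rng(0)
N, d = 9, 16                                 # toy node count and dim (assumed)
A_hat = np.eye(N)                            # placeholder normalized adjacency
H0 = rng.standard_normal((N, d))
weights = [0.1 * rng.standard_normal((d, d)) for _ in range(2)]
H = gcn_forward(A_hat, H0, weights)          # shared, task-agnostic features
```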
The network then branches into $T$ task-specific pipelines (e.g., $T = 2$ for typing and staging), each consisting of:
- Task-aware Knowledge Injection: Projects the shared node features $H$ into a task-specific space, producing $H_t$.
- Domain Knowledge-driven Pooling: Pools the node features $H_t$ to $M$ graph tokens using pooling strategies aligned with the respective clinical task.
- Transformer Module: Processes pooled tokens, prepending a learnable [CLS] token, using a stack of transformer layers without positional encoding.
- Prediction Head: A task-specific multilayer perceptron for slide-level classification.
2. Graph-Transformer Composition
At the core, TGT first applies GCN-based locality-aware feature extraction, as described above. After pooling, each task pipeline receives a token matrix $Z_t \in \mathbb{R}^{M \times d}$ with a prepended [CLS] token, forming the input to a standard transformer stack:
$$\mathrm{Attn}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d}}\right) V,$$
with $Q = Z W_Q$, $K = Z W_K$, and $V = Z W_V$ as linear projections.
The output [CLS] embedding is fed to a task-specific MLP for the final prediction. This modular design allows the architecture to efficiently share computational resources while exploiting task-specific patterns.
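A single-head, single-layer version of this stack can be sketched as follows; a zero-initialized [CLS] stands in for the learnable token, and the toy token count and dimension are assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def transformer_cls(tokens, Wq, Wk, Wv):
    """Prepend a [CLS] token, run one self-attention layer with no
    positional encoding, and return the output [CLS] embedding."""
    d = tokens.shape[1]
    cls = np.zeros((1, d))           # learnable in practice; zeros here
    Z = np.vstack([cls, tokens])
    Q, K, V = Z @ Wq, Z @ Wk, Z @ Wv
    out = softmax(Q @ K.T / np.sqrt(d)) @ V
    return out[0]                    # [CLS] row feeds the task MLP head

rng = np.random.default_rng(1)
M, d = 8, 16                         # pooled graph tokens, dim (assumed)
tokens = rng.standard_normal((M, d))
W = [0.1 * rng.standard_normal((d, d)) for _ in range(3)]
cls_emb = transformer_cls(tokens, *W)
```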
3. Task-aware Knowledge Injection
TGT incorporates Task-aware Knowledge Injection (TKI) to transfer the shared, task-agnostic graph features into task-specific feature spaces. For each task $t$, a set of trainable latent tokens $P_t$ is learned. The injection is formulated as a cross-attention block in which the shared node features $H$ attend to $P_t$:
$$\tilde{H}_t = \mathrm{softmax}\!\left(\frac{(H W_Q)(P_t W_K)^\top}{\sqrt{d}}\right) P_t W_V.$$
The output undergoes residual connection, layer normalization, and position-wise feed-forward transformation:
$$H_t' = \mathrm{LN}\big(\tilde{H}_t + H\big), \qquad H_t = \mathrm{LN}\big(\mathrm{FFN}(H_t') + H_t'\big).$$
This method disentangles latent task knowledge and minimizes feature interference, as each task has independent TKI parameters. Ablation studies demonstrate that cross-attention-based injection provides up to ~1% accuracy improvement over a linear projection or fully shared features (Zhao et al., 2023).
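The cross-attention step (omitting the feed-forward sublayer for brevity) can be sketched in numpy; the node count, latent-token count, and dimension are toy assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def task_knowledge_injection(H, P_t, Wq, Wk, Wv):
    """Cross-attention from shared node features H (queries) to the
    task-specific latent tokens P_t (keys/values), followed by a
    residual connection and per-node layer normalization."""
    d = H.shape[1]
    Q, K, V = H @ Wq, P_t @ Wk, P_t @ Wv
    injected = softmax(Q @ K.T / np.sqrt(d)) @ V
    out = H + injected                       # residual connection
    mu = out.mean(axis=1, keepdims=True)
    sigma = out.std(axis=1, keepdims=True)
    return (out - mu) / (sigma + 1e-6)       # layer norm (no affine params)

rng = np.random.default_rng(2)
N, m, d = 9, 4, 16                           # nodes, latent tokens, dim (assumed)
H = rng.standard_normal((N, d))
P_t = rng.standard_normal((m, d))            # one set of tokens per task
W = [0.1 * rng.standard_normal((d, d)) for _ in range(3)]
H_t = task_knowledge_injection(H, P_t, *W)
```

Because each task owns its `P_t` and projection matrices, gradients from one task head cannot overwrite another task's injection parameters, which is the mechanism behind the reduced feature interference.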
4. Domain Knowledge-driven Graph Pooling
Pooling is explicitly designed to reflect the diagnostic bias of each task:
- Tumor Typing: A node-drop pooling ("DropPool") is used, randomly dropping nodes during training to retain exactly $M$ nodes, aligning with the expectation that observing only a few distinctive patches suffices for typing.
- Tumor Staging: A soft spectral clustering mechanism ("GCMinCut") is employed, where an assignment matrix $S = \mathrm{softmax}(\mathrm{MLP}(H_t))$ soft-clusters the $N$ nodes into $M$ clusters:
$$H_t^{\mathrm{pool}} = S^\top H_t, \qquad A^{\mathrm{pool}} = S^\top \hat{A} S.$$
This method preserves region-level context and is regularized by an unsupervised MinCut loss to sharpen locality:
$$\mathcal{L}_{\mathrm{mc}} = -\frac{\mathrm{Tr}(S^\top \hat{A} S)}{\mathrm{Tr}(S^\top \hat{D} S)} + \left\lVert \frac{S^\top S}{\lVert S^\top S \rVert_F} - \frac{I_M}{\sqrt{M}} \right\rVert_F.$$
Ablation results indicate that matching pooling strategies to diagnostic tasks improves both accuracy and robustness by 1–3% (Zhao et al., 2023).
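Both pooling branches can be sketched compactly; the ring-graph adjacency and toy sizes below are illustrative stand-ins, not the paper's configuration.

```python
import numpy as np

def drop_pool(H, M, rng):
    """Typing branch ("DropPool"): randomly retain exactly M nodes."""
    idx = rng.choice(H.shape[0], size=M, replace=False)
    return H[idx]

def mincut_pool(H, A, S_logits):
    """Staging branch ("GCMinCut"): a row-softmax assignment S soft-clusters
    the N nodes into M clusters, yielding pooled features, a pooled
    adjacency, and the unsupervised MinCut + orthogonality regularizer."""
    e = np.exp(S_logits - S_logits.max(axis=1, keepdims=True))
    S = e / e.sum(axis=1, keepdims=True)        # row-stochastic assignment
    H_pool, A_pool = S.T @ H, S.T @ A @ S
    D = np.diag(A.sum(axis=1))                  # degree matrix
    cut = -np.trace(S.T @ A @ S) / np.trace(S.T @ D @ S)
    StS = S.T @ S
    M = S.shape[1]
    ortho = np.linalg.norm(StS / np.linalg.norm(StS) - np.eye(M) / np.sqrt(M))
    return H_pool, A_pool, cut + ortho

rng = np.random.default_rng(0)
N, d, M = 9, 16, 4                              # toy sizes (assumed)
H = rng.standard_normal((N, d))
A = np.zeros((N, N))                            # ring graph as toy adjacency
for i in range(N):
    A[i, (i + 1) % N] = A[(i + 1) % N, i] = 1.0
H_typing = drop_pool(H, M, rng)
H_staging, A_staging, loss_mc = mincut_pool(H, A, rng.standard_normal((N, M)))
```

The cut term rewards assignments that keep strongly connected regions together, while the orthogonality term pushes clusters toward balanced, near-one-hot assignments, which is why the staging branch retains region-level structure rather than isolated patches.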
5. Optimization Objective
TGT is optimized jointly on all tasks using the following objective:
$$\mathcal{L} = \sum_{t=1}^{T} \lambda_t \, \mathcal{L}_{\mathrm{CE}}^{(t)} + \lambda_{\mathrm{mc}} \, \mathcal{L}_{\mathrm{mc}},$$
where each task's cross-entropy loss is defined as
$$\mathcal{L}_{\mathrm{CE}}^{(t)} = -\sum_{c} y_c^{(t)} \log \hat{y}_c^{(t)}.$$
Weights are tuned so that losses have similar magnitudes. The multi-task setting yields small but consistent additional gains, particularly on challenging staging labels.
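A minimal sketch of the weighted joint objective, with toy logits and weights (the actual weight values are tuned per dataset and are not specified here):

```python
import numpy as np

def cross_entropy(logits, label):
    """Numerically stable softmax cross-entropy for one slide-level label."""
    z = logits - logits.max()
    log_probs = z - np.log(np.exp(z).sum())
    return -log_probs[label]

def multi_task_loss(task_logits, task_labels, weights):
    """Weighted sum of per-task cross-entropy losses; in practice the
    weights are tuned so the task losses have comparable magnitudes."""
    return sum(w * cross_entropy(l, y)
               for w, l, y in zip(weights, task_logits, task_labels))

typing_logits = np.array([2.0, 0.1])           # toy 2-class typing head
staging_logits = np.array([0.5, 1.2, -0.3])    # toy 3-class staging head
loss = multi_task_loss([typing_logits, staging_logits], [0, 1], [1.0, 1.0])
```

The pooling regularizer from the staging branch would be added to this sum with its own weight during training.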
6. Empirical Evaluation
TGT was evaluated on two WSI datasets from TCGA: KICA (kidney carcinoma) and ESCA (esophageal carcinoma). WSIs were tiled into fixed-size patches, filtered by tissue content, and embedded with KimiaNet. The graph pooling size and the number of latent tokens were fixed as hyperparameters, and the Adam optimizer was used for 40 epochs with a batch size of 8. Metrics (AUC, accuracy, F1) were reported as the result of 5-fold cross-validation averaged over three runs.
Key findings include:
| Task & Dataset | AUC | Accuracy | F1 | Comparative Gain (ACC) |
|---|---|---|---|---|
| Typing (KICA) | 98.44% | 93.89% | 93.89% | +1.58% over baseline |
| Staging (KICA) | 80.22% | 74.98% | 72.55% | +3.78% |
| Typing (ESCA) | 97.49% | 92.81% | 92.74% | +2.94% |
| Staging (ESCA) | 71.48% | 66.63% | 65.73% | +1.43% |
Ablations confirm the efficacy of both the task-aware injection and domain-driven pooling modules. The architecture achieves consistent gains in both typing and staging, substantiating the benefit of TGT for clinical-grade multi-task WSI analysis (Zhao et al., 2023).