
Task Graph Transformer (TGT)

Updated 9 March 2026
  • The paper introduces Task Graph Transformer (TGT), a unified framework combining GNN-based feature extraction with transformer pipelines for multi-task learning on WSIs.
  • It employs task-aware knowledge injection and domain-driven pooling, achieving consistent accuracy boosts (up to 3.78%) over baseline models.
  • The architecture processes WSIs by tiling images into patches, applying graph convolutions, and using specialized transformer modules for diagnostic predictions.

The Task Graph Transformer (TGT), introduced as the core of the MulGT framework, is a unified architecture for multi-task learning over graph-structured representations of Whole Slide Images (WSIs) in computational pathology. TGT integrates a shared Graph Neural Network (GNN) backbone and task-specific transformer pipelines with modules for task-aware knowledge injection and domain knowledge-driven graph pooling. It is designed to address clinically realistic scenarios where joint prediction of multiple diagnostic endpoints, such as tumor typing and staging, is required. Empirical results on The Cancer Genome Atlas (TCGA) datasets demonstrate that TGT yields consistent improvements over prior single-task and single-pooling Graph Transformer baselines in both accuracy and robustness (Zhao et al., 2023).

1. Architecture Overview

The TGT framework processes a WSI by tiling it into fixed-size, non-overlapping image patches, each of which is embedded via a frozen feature extractor (e.g., KimiaNet) into a $d$-dimensional feature vector. The resulting set of patch embeddings forms the vertex set $\mathcal V$ of a spatial graph, where each node is connected to its eight spatially neighboring tiles. The adjacency matrix $A$ is augmented with self-loops to obtain $\tilde A = A + I$ and normalized as $\hat A = \tilde D^{-1/2} \tilde A \tilde D^{-1/2}$, where $\tilde D$ is the diagonal degree matrix of $\tilde A$.
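As an illustration, the graph construction above can be sketched for a regular grid of tiles (real tilings are irregular after background filtering, so this is a toy sketch; the function name is our own):

```python
import numpy as np

def normalized_adjacency(rows, cols):
    """Build the 8-neighbor grid adjacency for a rows x cols tiling and
    return the symmetrically normalized matrix D^{-1/2} (A + I) D^{-1/2}."""
    n = rows * cols
    A = np.zeros((n, n))
    for r in range(rows):
        for c in range(cols):
            i = r * cols + c
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    if dr == 0 and dc == 0:
                        continue  # no self edge here; self-loops added below
                    rr, cc = r + dr, c + dc
                    if 0 <= rr < rows and 0 <= cc < cols:
                        A[i, rr * cols + cc] = 1.0
    A_tilde = A + np.eye(n)                   # self-loops: A~ = A + I
    d = A_tilde.sum(axis=1)                   # degree of A~
    D_inv_sqrt = np.diag(1.0 / np.sqrt(d))
    return D_inv_sqrt @ A_tilde @ D_inv_sqrt  # normalized adjacency
```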

A stack of $L$ shared graph convolutional layers encodes low-level, task-agnostic local features:

$$H_{l+1} = \mathrm{ReLU}(\hat A\, H_l\, W_l)$$

with $H_0$ the initial patch embeddings and $W_l \in \mathbb R^{d \times d}$ trainable.
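A minimal sketch of this shared encoder, assuming the normalized adjacency and patch embeddings are already available (the function name is our own):

```python
import numpy as np

def gcn_forward(A_hat, H0, weights):
    """Shared GCN encoder: H_{l+1} = ReLU(A_hat @ H_l @ W_l).
    A_hat: (n, n) normalized adjacency; H0: (n, d) patch embeddings;
    weights: list of (d, d) matrices, one per layer."""
    H = H0
    for W in weights:
        H = np.maximum(A_hat @ H @ W, 0.0)  # graph convolution + ReLU
    return H
```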

The network then branches into $T$ task-specific pipelines (e.g., for typing and staging), each consisting of:

  • Task-aware Knowledge Injection: projects the shared $H_L$ into a task-specific space, producing $\hat H^{(t)}$.
  • Domain Knowledge-driven Pooling: pools $\hat H^{(t)}$ to $p$ graph tokens using a pooling strategy aligned with the clinical task.
  • Transformer Module: processes the pooled tokens, prepending a learnable [CLS] token, through a stack of transformer layers without positional encoding.
  • Prediction Head: a task-specific multilayer perceptron for slide-level classification.
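The four stages above compose per task roughly as follows; this is a structural sketch only, and the module names (`inject`, `pool`, `transformer`, `head`) are hypothetical placeholders, not the paper's API:

```python
def tgt_pipeline(H_shared, tasks):
    """Skeleton of the per-task branch of TGT: each task t owns its
    injection, pooling, transformer, and prediction head; only the GCN
    features H_shared are shared across tasks."""
    logits = {}
    for t, modules in tasks.items():
        H_t = modules["inject"](H_shared)     # task-aware knowledge injection
        tokens = modules["pool"](H_t)         # domain-driven pooling -> p tokens
        cls = modules["transformer"](tokens)  # [CLS]-prepended transformer
        logits[t] = modules["head"](cls)      # task-specific MLP
    return logits
```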

2. Graph-Transformer Composition

At its core, TGT first applies the GCN-based, locality-aware feature extraction described above. After pooling, each task pipeline receives a $p \times d$ matrix $\hat H^{(t),\mathrm{pool}}$ with a prepended [CLS] token, forming the input to a standard transformer stack:

$$\mathrm{MultiHead}(Q, K, V) = [\mathrm{head}_1; \dots; \mathrm{head}_h]\, W^O$$

$$\mathrm{head}_i = \mathrm{softmax}\!\left(Q W^Q_i\, (K W^K_i)^\top\right) V W^V_i$$

with $Q, K, V \in \mathbb R^{(p+1) \times d}$ obtained as linear projections of the token sequence.

The output [CLS] embedding is fed to a task-specific MLP for the final prediction. This modular design allows the architecture to efficiently share computational resources while exploiting task-specific patterns.
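The multi-head attention equations above can be sketched as follows (a self-contained toy implementation, not the paper's code; note that many implementations additionally scale the logits by $1/\sqrt{d_k}$, which the equations above omit):

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(Q, K, V, Wq, Wk, Wv, Wo):
    """Multi-head attention: per-head projections Wq/Wk/Wv (lists of
    (d, d_head) matrices), heads concatenated and mixed by Wo."""
    heads = []
    for WQ, WK, WV in zip(Wq, Wk, Wv):
        scores = softmax((Q @ WQ) @ (K @ WK).T)  # (n, n) attention weights
        heads.append(scores @ (V @ WV))          # (n, d_head)
    return np.concatenate(heads, axis=-1) @ Wo   # (n, d)
```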

3. Task-aware Knowledge Injection

TGT incorporates Task-aware Knowledge Injection (TKI) to transfer the shared, task-agnostic graph features into task-specific feature spaces. For each task $t$, a set of $m$ trainable latent tokens $T^{(t)} \in \mathbb R^{m \times d}$ is learned. The injection is formulated as a cross-attention block:

$$\mathrm{TKI}^{(t)}(H_L) = \mathrm{MultiHead}\big(Q = H_L,\ K = T^{(t)},\ V = T^{(t)}\big) \in \mathbb R^{|\mathcal V| \times d}$$

The output undergoes a residual connection, layer normalization, and a position-wise feed-forward transformation:

$$Z^{(t)} = \mathrm{LayerNorm}\big(H_L + \mathrm{TKI}^{(t)}(H_L)\big)$$

$$\hat H^{(t)} = \mathrm{LayerNorm}\big(Z^{(t)} + \mathrm{rFF}(Z^{(t)})\big)$$

This design disentangles latent task knowledge and minimizes feature interference, since each task has independent TKI parameters. Ablation studies show that cross-attention-based injection provides up to ~1% accuracy improvement over linear projection or parameter sharing (Zhao et al., 2023).
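The TKI equations can be sketched as below; for brevity this uses single-head cross-attention with identity projections and takes the feed-forward network `rff` as a caller-supplied function (both simplifications are ours):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    """Per-row layer normalization (no learned scale/shift for brevity)."""
    return (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + eps)

def tki_block(H_L, T_t, rff):
    """Task-aware Knowledge Injection: cross-attention from shared node
    features H_L (queries, |V| x d) to m task tokens T_t (keys/values,
    m x d), then residual + LayerNorm + position-wise feed-forward."""
    attn = softmax(H_L @ T_t.T) @ T_t    # cross-attention output, |V| x d
    Z = layer_norm(H_L + attn)           # residual + LayerNorm
    return layer_norm(Z + rff(Z))        # rFF + residual + LayerNorm
```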

4. Domain Knowledge-driven Graph Pooling

Pooling is explicitly designed to reflect the diagnostic bias of each task:

  • Tumor Typing: a node-drop pooling ("DropPool") randomly drops nodes during training to retain exactly $p$ nodes, reflecting the expectation that observing only a few distinctive patches suffices.
  • Tumor Staging: a soft spectral clustering mechanism ("GCMinCut") is employed, where an assignment matrix $S \in \mathbb R^{|\mathcal V| \times p}$ soft-clusters the nodes:

$$S = \mathrm{ReLU}\big(\hat A\, \hat H^{(\mathrm{stage})} W_{\mathrm{pool}}\big)$$

$$\hat H^{(\mathrm{stage}),\,\mathrm{pool}} = S^\top \hat H^{(\mathrm{stage})}$$

This method preserves region-level context and is regularized by an unsupervised MinCut loss to sharpen locality:

$$\mathcal L_{\mathrm{mincut}} = -\frac{\mathrm{Tr}(S^\top \tilde A\, S)}{\mathrm{Tr}(S^\top \tilde D\, S)} + \left\| \frac{S^\top S}{\| S^\top S \|_F} - \frac{I_p}{\sqrt p} \right\|_F$$

Ablation results indicate that matching pooling strategies to diagnostic tasks improves both accuracy and robustness by 1–3% (Zhao et al., 2023).
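The two pooling strategies can be sketched as follows (an illustrative NumPy sketch of the equations above, with function names of our own; `np.linalg.norm` on a matrix defaults to the Frobenius norm):

```python
import numpy as np

def drop_pool(H, p, rng):
    """DropPool for typing: keep a random subset of p nodes (training-time)."""
    idx = rng.choice(H.shape[0], size=p, replace=False)
    return H[idx]

def mincut_pool(A_tilde, D_tilde, H, W_pool, A_hat):
    """GCMinCut for staging: soft-assign |V| nodes to p clusters and
    compute the unsupervised MinCut regularizer."""
    S = np.maximum(A_hat @ H @ W_pool, 0.0)   # (|V|, p) soft assignments
    H_pool = S.T @ H                          # (p, d) pooled tokens
    p = S.shape[1]
    cut = -np.trace(S.T @ A_tilde @ S) / np.trace(S.T @ D_tilde @ S)
    StS = S.T @ S
    ortho = np.linalg.norm(StS / np.linalg.norm(StS) - np.eye(p) / np.sqrt(p))
    return H_pool, cut + ortho
```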

5. Optimization Objective

TGT is optimized jointly on all tasks using the objective

$$\mathcal L_{\mathrm{total}} = w_{\mathrm{type}} \mathcal L_{\mathrm{type}} + w_{\mathrm{stage}} \mathcal L_{\mathrm{stage}} + w_{\mathrm{mincut}} \mathcal L_{\mathrm{mincut}}$$

where each task's cross-entropy loss is defined as

$$\mathcal L_t = -\frac{1}{N} \sum_{i=1}^N \sum_{c=1}^{C_t} Y_i^{(t, c)} \log \hat Y_i^{(t, c)}$$

Weights are tuned so that losses have similar magnitudes. The multi-task setting yields small but consistent additional gains, particularly on challenging staging labels.
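The joint objective can be sketched as a weighted sum of per-task cross-entropy losses plus the pooling regularizer (function names are our own):

```python
import numpy as np

def cross_entropy(Y, Y_hat, eps=1e-12):
    """Mean cross-entropy over N slides; Y is one-hot labels (N, C_t),
    Y_hat is predicted class probabilities (N, C_t)."""
    return -np.mean(np.sum(Y * np.log(Y_hat + eps), axis=1))

def total_loss(losses, weights):
    """L_total = w_type * L_type + w_stage * L_stage + w_mincut * L_mincut."""
    return sum(weights[k] * losses[k] for k in losses)
```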

6. Empirical Evaluation

TGT was evaluated on two WSI datasets from TCGA: KICA (kidney carcinoma) and ESCA (esophageal carcinoma). WSIs were tiled into $512 \times 512$ patches, filtered by tissue content, and embedded into $d = 1024$-dimensional vectors with KimiaNet. The graph pooling size was set to $p = 100$, the number of latent tokens to $m = 150$, and the Adam optimizer was used for 40 epochs with a batch size of 8. Metrics (AUC, accuracy, F1) were reported as 5-fold cross-validation results averaged over three runs.

Key findings include:

Task (Dataset)     AUC       Accuracy   F1        Accuracy gain over baseline
Typing (KICA)      98.44%    93.89%     93.89%    +1.58%
Staging (KICA)     80.22%    74.98%     72.55%    +3.78%
Typing (ESCA)      97.49%    92.81%     92.74%    +2.94%
Staging (ESCA)     71.48%    66.63%     65.73%    +1.43%

Ablations confirm the efficacy of both the task-aware injection and domain-driven pooling modules. The architecture achieves consistent gains in both typing and staging, substantiating the benefit of TGT for clinical-grade multi-task WSI analysis (Zhao et al., 2023).

References

Zhao, W., Wang, S., Yeung, M., Niu, T., & Yu, L. (2023). MulGT: Multi-task Graph-Transformer with Task-aware Knowledge Injection and Domain Knowledge-driven Pooling for Whole Slide Image Analysis. AAAI 2023.
