- The paper introduces TEA-GLM, a framework that aligns GNN representations with frozen LLM token embeddings via contrastive learning to achieve effective zero-shot graph learning.
- It employs a two-stage process using instance-wise and feature-wise contrastive losses followed by a linear projector to generate graph token embeddings for LLM prompting.
- Experimental results demonstrate robust cross-dataset and cross-task performance improvements over baselines, highlighting the efficiency and generalization of the proposed method.
This paper introduces the Token Embedding-Aligned Graph LLM (TEA-GLM), a novel framework designed to enable LLMs to perform zero-shot learning on graph-based tasks across different datasets and task types. The core problem addressed is the limited generalization capability of traditional Graph Neural Networks (GNNs) and existing GNN-LLM methods, which often require task-specific fine-tuning or struggle with unseen data.
TEA-GLM achieves zero-shot learning by aligning the representations learned by a GNN with the token embeddings of a frozen LLM. This alignment allows the LLM to effectively leverage its pre-trained knowledge for graph tasks without needing to be fine-tuned itself.
The methodology involves two main stages:
- Token Embedding-Aligned Graph Self-Supervised Learning:
  * The GNN is pre-trained with two contrastive objectives: an instance-wise loss over augmented node views, and a feature-wise loss that aligns the GNN's feature dimensions with the principal components of the frozen LLM's token embeddings.
  * The final pre-training loss is $\mathcal{L} = \frac{1}{2}(\mathcal{L}_{\text{ins}} + \mathcal{L}_{\text{fea}})$ (a sketch of this objective follows this list). The GNN is then frozen.
- Alignment Tuning (Training the Linear Projector):
- After pre-training the GNN, its parameters are fixed. A simple linear projector is then trained. This projector maps the GNN's output representation for a node (or a pooled representation for an edge/graph) into a fixed number (K) of "graph token embeddings."
  $H_{\text{token}} = f_{\text{Linear}}(u_i)$, where $H_{\text{token}} \in \mathbb{R}^{K \times F_L}$ and $F_L$ is the LLM's token-embedding dimension.
- These K graph token embeddings act as soft prompts for the LLM. They are inserted into a unified instruction template designed for various graph tasks (node classification, link prediction).
- Instruction Design:
- Graph Information Provision: Uses a placeholder ⟨graph⟩ (which is replaced by the K graph token embeddings) and minimal textual node information (e.g., only paper titles for citation graphs). The paper notes that reducing input text helps LLMs focus on graph structural information embedded in the graph tokens.
- Task Description: Includes the specific task query and a set of possible answers {ans}. This helps the model generalize by learning to reason from a given set of options rather than memorizing answers for a specific dataset.
- Crucially, during this stage, only the linear projector is trained; the LLM's parameters remain frozen. This makes the process efficient and leverages the LLM's inherent zero-shot capabilities.
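Below is a minimal PyTorch sketch of what the combined pre-training objective could look like. It assumes an InfoNCE-style instance-wise loss over two augmented views, and a feature-wise loss that contrasts feature columns against node features expressed in the top-P principal components of the LLM's token-embedding matrix; the function names, shapes, and temperature are illustrative assumptions, not the paper's code.

```python
import torch
import torch.nn.functional as F

def instance_wise_loss(z1, z2, tau=0.5):
    # InfoNCE over nodes: row i of z1 and row i of z2 are two augmented
    # views of the same node (positive pair); other rows are negatives.
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    logits = z1 @ z2.t() / tau                       # (N, N)
    labels = torch.arange(z1.size(0), device=z1.device)
    return F.cross_entropy(logits, labels)

def feature_wise_loss(z, pc_feats, tau=0.5):
    # Assumed formulation: column j of the node representations z (N, P)
    # is pulled toward column j of pc_feats (N, P) -- the node features
    # projected onto the top-P principal components of the LLM token
    # embeddings -- and pushed away from the other P-1 columns.
    zc, pc = F.normalize(z, dim=0), F.normalize(pc_feats, dim=0)
    logits = zc.t() @ pc / tau                       # (P, P)
    labels = torch.arange(z.size(1), device=z.device)
    return F.cross_entropy(logits, labels)

def pretrain_loss(z1, z2, pc_feats):
    # Combined objective from the paper: L = 1/2 (L_ins + L_fea)
    return 0.5 * (instance_wise_loss(z1, z2) + feature_wise_loss(z1, pc_feats))
```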
Implementation and Application:
- Node Feature Encoding: Raw text features of nodes are initially encoded using a pre-trained BERT model.
- GNN Architecture: GraphSAGE is used as the GNN encoder.
- LLM: Vicuna-7B-v1.5 is used as the base LLM.
- Training Process:
1. Pre-train the GNN (GraphSAGE) on a source dataset (e.g., Arxiv) using the combined instance-wise and feature-wise contrastive loss.
2. Fix the GNN parameters.
3. Train the linear projector on the same source dataset for specific tasks (e.g., node classification on Arxiv) using the designed instructions and the frozen LLM.
4. Evaluate the system (frozen GNN + trained projector + frozen LLM) on unseen target datasets and unseen tasks in a zero-shot manner.
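To make step 3 concrete, here is a hedged sketch of the linear projector and its training step. The dimensions, the `gnn` and `llm` handles, and the `splice` helper are hypothetical placeholders; only the projector's parameters receive gradients, matching the paper's setup.

```python
import torch
import torch.nn as nn

K, F_G, F_L = 3, 256, 4096   # assumed sizes: graph tokens, GNN dim, LLM dim

class GraphTokenProjector(nn.Module):
    """Maps one frozen-GNN node embedding to K graph token embeddings
    in the LLM's input-embedding space (H_token in R^{K x F_L})."""
    def __init__(self, f_g=F_G, f_l=F_L, k=K):
        super().__init__()
        self.proj = nn.Linear(f_g, k * f_l)
        self.k, self.f_l = k, f_l

    def forward(self, u):                                # u: (B, F_G)
        return self.proj(u).view(-1, self.k, self.f_l)   # (B, K, F_L)

projector = GraphTokenProjector()
optimizer = torch.optim.AdamW(projector.parameters(), lr=2e-3)

# Schematic training step: `gnn` and `llm` are frozen, pre-loaded models;
# `splice` inserts the graph tokens at the placeholder position (see the
# prompt-assembly sketch later in this section).
#
# with torch.no_grad():
#     u = gnn(graph)[target_nodes]                       # (B, F_G)
# graph_tokens = projector(u)                            # (B, K, F_L)
# inputs_embeds = splice(prompt_embeds, graph_tokens)
# loss = llm(inputs_embeds=inputs_embeds, labels=answer_labels).loss
# loss.backward(); optimizer.step(); optimizer.zero_grad()
```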
Example Application (Node Classification):
- For a target node in an unseen graph, get its features (e.g., title).
- Pass the graph structure and node features through the pre-trained (frozen) GNN to get the node embedding ui.
- Feed ui into the trained linear projector to get K graph token embeddings Htoken.
- Construct the prompt for the LLM:
```
Given the representation of a paper: <graph_token_1>...<graph_token_K>, with the following information:
Title: First Paper: {title_of_target_node}
Which arXiv CS sub-category does this paper belong to? Please directly give the most likely answer from the following sub-categories: {ans_1, ans_2, ..., ans_M}
```
- The LLM generates the predicted category.
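One plausible way to wire the K graph token embeddings into a frozen HuggingFace-style LLM is via `inputs_embeds`: tokenize the text on either side of the placeholder (written `<graph>` here), embed it, and splice in the projector's output. This is a sketch under that assumption, not the paper's published code.

```python
import torch

def build_inputs_embeds(llm, tokenizer, prompt, graph_tokens):
    # graph_tokens: (K, F_L) output of the trained projector; the prompt
    # is assumed to contain exactly one <graph> placeholder.
    before, after = prompt.split("<graph>")
    emb = llm.get_input_embeddings()
    ids_before = tokenizer(before, return_tensors="pt").input_ids
    ids_after = tokenizer(after, add_special_tokens=False,
                          return_tensors="pt").input_ids
    # [text embeddings | K graph token embeddings | text embeddings]
    return torch.cat([emb(ids_before),
                      graph_tokens.unsqueeze(0).to(emb.weight.dtype),
                      emb(ids_after)], dim=1)            # (1, T, F_L)

# Zero-shot inference (sketch):
# embeds = build_inputs_embeds(llm, tokenizer, prompt, h_token)
# out = llm.generate(inputs_embeds=embeds, max_new_tokens=16)
# print(tokenizer.decode(out[0], skip_special_tokens=True))
```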
Key Experimental Findings:
- Cross-Dataset Zero-Shot (RQ1): TEA-GLM significantly outperformed baseline methods (including other GNNs, LLMs alone, and GNN-LLM combinations like GraphGPT and LLaGA) on unseen datasets for node classification. This demonstrates its strong generalization ability. For example, when trained on Arxiv for node classification, it performed well on Pubmed and Cora without retraining.
- Cross-Task Zero-Shot (RQ2): When trained on a node classification task, TEA-GLM could effectively perform link prediction on unseen datasets without any task-specific fine-tuning, again outperforming baselines. This highlights the flexibility of the unified instruction and graph token embedding approach.
- Ablation Studies (RQ3):
- Removing feature-wise contrastive learning ("w/o FC") hurt performance on unseen datasets and especially on unseen tasks, indicating its importance for aligning GNN representations with the LLM's semantic space for better transferability.
- Removing graph token embeddings ("w/o GT", i.e., the LLM receives only text) caused a significant performance drop, confirming that the graph tokens convey crucial structural information from the GNN to the LLM.
Practical Implications and Considerations:
- Efficiency: By keeping the LLM frozen, TEA-GLM is more efficient to train and adapt than methods requiring full LLM fine-tuning. Only the GNN (once) and a small linear projector are trained.
- Generalization: The core strength is improved zero-shot generalization to new graph datasets and tasks.
- Unified Framework: The fixed number of graph tokens and unified instruction design simplify adaptation to different graph tasks (node, edge, potentially graph-level).
- Reduced Text Reliance: The method works well even with minimal textual input (e.g., only titles), allowing the LLM to focus on structural signals from the GNN.
- Computational Requirements: Requires GPUs for GNN pre-training and LLM inference. The paper used 2 NVIDIA A100 80GB GPUs. PCA on LLM token embeddings is a one-time offline cost.
- Hyperparameters: Key parameters include the number of graph token embeddings (K) and the number of principal components (P) for feature-wise contrastive learning. The paper found K=1 or K=3 often sufficient for unseen datasets and P=1000 (capturing 50% variance of LLM token embeddings) worked well.
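The one-time PCA step might look like the following sketch (scikit-learn; `llm` is assumed to be the loaded Vicuna model):

```python
from sklearn.decomposition import PCA

# Vocabulary matrix of the frozen LLM, e.g. roughly 32000 x 4096 for
# Vicuna-7B; computed once, offline.
token_matrix = llm.get_input_embeddings().weight.detach().cpu().numpy()

pca = PCA(n_components=1000)                  # P = 1000 as in the paper
components = pca.fit_transform(token_matrix)  # (V, 1000)
print(f"variance captured: {pca.explained_variance_ratio_.sum():.2%}")
```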
In summary, TEA-GLM provides a practical and effective framework for leveraging the zero-shot capabilities of LLMs for graph machine learning. It achieves this through a novel self-supervised GNN pre-training strategy that aligns GNN outputs with the LLM's token-embedding space, and a lightweight linear projector that converts graph representations into a fixed number of graph token embeddings consumed by the LLM via carefully designed instructions.