PredFT: Dual-Network fMRI Decoding
- PredFT is a dual-network model that reconstructs natural language from spatial–temporal fMRI signals by jointly modeling neural decoding and brain predictive coding.
- It incorporates a Main Decoding Network with 3D-CNN, FIR compensation, and Transformer layers alongside a Side Network that extracts ROI predictive coding signals.
- Experiments show significant gains over previous models using BLEU and ROUGE metrics, with ablation analyses highlighting the essential role of the predictive coding pathway.
PredFT is a dual-network model designed for reconstructing natural language from spatial–temporal fMRI signals, jointly modeling neural decoding and brain predictive coding. It leverages a Main Decoding Network to generate language from fMRI activity and a Side Network to extract predictive representations from specific brain regions of interest (ROIs) and integrate these into the decoding process via cross-attention. PredFT achieves state-of-the-art performance on the Narratives fMRI dataset, with ablation studies demonstrating the contribution of predictive coding representations to decoding accuracy (Yin et al., 2024).
1. Architectural Overview
PredFT implements a multi-component architecture:
- Main Decoding Network: Processes full-volume fMRI data through a sequence of modules: (i) 3D convolutional neural network (3D‐CNN) for spatial feature extraction, (ii) finite impulse response (FIR) model for temporal alignment (accounting for BOLD delay), (iii) Transformer encoder for contextual processing, and (iv) Transformer decoder with masked self-attention and fMRI encoder-decoder attention, augmented by cross-attention over the predictive-coding signal.
- Predictive Coding Side Network: Extracts activations from six predefined ROIs, applies FIR delay compensation, encodes fused ROI activity with a Transformer encoder, and decodes this representation to predict future word tokens.
Both networks process input in temporally aligned windows and share architectural motifs, with integration occurring via a cross-attention mechanism applied in the main decoder.
2. Input Representation and Preprocessing
The input to PredFT consists of fMRI volume sequences X_{s,i,t}, where s indexes the subject, i the segment, and t the time point. Voxel-wise normalization is applied across time. For the predictive coding pathway, mean activation vectors are extracted at each TR from six ROIs: the left superior temporal sulcus, left angular gyrus, left supramarginal gyrus, and the opercular, triangular, and orbital parts of the left inferior frontal gyrus.
Textual targets are tokenized using the BART tokenizer, with alignment to the fMRI segment timing, and segments are processed to match the temporal resolution provided by the fMRI acquisition protocol.
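As an illustration, the voxel-wise normalization and ROI averaging described above can be sketched with NumPy. The shapes, ROI names, and random masks below are hypothetical stand-ins, not the paper's actual parcellation:

```python
import numpy as np

T, V = 40, 64 * 64 * 27           # e.g. 40 TRs of flattened 64x64x27 volumes
rng = np.random.default_rng(0)
volumes = rng.normal(size=(T, V)) # synthetic data standing in for real fMRI

# Voxel-wise normalization: z-score each voxel's time series across TRs.
mu = volumes.mean(axis=0, keepdims=True)
sd = volumes.std(axis=0, keepdims=True) + 1e-8
volumes_z = (volumes - mu) / sd

# Side Network input: mean activation per ROI at each TR.
# `roi_masks` maps an ROI name to a hypothetical boolean voxel mask.
roi_masks = {"lSTS": rng.random(V) < 0.01, "lAG": rng.random(V) < 0.01}
roi_series = np.stack(
    [volumes_z[:, m].mean(axis=1) for m in roi_masks.values()], axis=1
)
# roi_series: shape (T, n_rois), one mean activation per ROI per TR
```

Each ROI time series is then delay-compensated and fed to the Side Network's Transformer encoder.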
3. Network Operations and Fusion Mechanisms
Within the main network, the fMRI encoder chain is:
- Multi-layer 3D-CNN with GroupNorm, ReLU, and residual pathways, outputting a spatial feature tensor.
- The FIR temporal compensation module concatenates feature vectors across temporally adjacent frames to mitigate BOLD-induced lag, with a learned linear transformation to restore dimensionality.
- Transformer layers contextualize the temporally compensated representations.
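The FIR-style temporal compensation in the second bullet (concatenate temporally adjacent frames, then linearly project back to the original width) can be sketched as follows. The frame count, feature width, and tap count are illustrative, and the random matrix stands in for the learned linear map:

```python
import numpy as np

T, D, k = 10, 16, 3               # hypothetical: 10 frames, 16-dim features, 3 taps
rng = np.random.default_rng(1)
feats = rng.normal(size=(T, D))   # per-frame spatial features from the 3D-CNN

# Zero-pad at the start so every frame has k - 1 preceding frames available.
padded = np.concatenate([np.zeros((k - 1, D)), feats], axis=0)
# For each frame t, concatenate frames t-(k-1) .. t into one long vector.
stacked = np.stack([padded[t:t + k].reshape(-1) for t in range(T)])  # (T, k*D)

# A random matrix stands in for the learned linear map restoring dimensionality.
W = rng.normal(size=(k * D, D)) * 0.02
compensated = stacked @ W         # (T, D): back to the original feature width
```

Pooling over adjacent frames in this way lets each time step see the hemodynamically delayed signal from its neighbors before Transformer contextualization.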
The decoder takes shifted-right text embeddings and, at each layer, computes:
- Masked self-attention for sequential language modeling,
- Encoder-decoder attention over the main network's encoded fMRI state,
- Cross-attention over the Side Network's encoded predictive-coding signal, using a masked attention matrix that constrains each token's attention to predictive-coding representations at equal or later times.
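A minimal sketch of the mask in the last bullet, assuming a boolean convention where True marks allowed positions: each query position t attends only to predictive-coding positions >= t, i.e. the upper triangle of the attention matrix (the reverse of a standard causal mask):

```python
import numpy as np

L = 5                                     # hypothetical sequence length
# True where attention is allowed: query at position t may attend to
# predictive-coding keys at positions >= t.
mask = np.triu(np.ones((L, L), dtype=bool))

# Apply the mask to attention logits before the softmax.
logits = np.zeros((L, L))                 # placeholder logits
logits[~mask] = -np.inf
weights = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
# Each row sums to 1, with exactly zero weight on earlier time steps.
```

This orientation matches the predictive-coding intent: a token being decoded draws on representations of what the brain is anticipating, not on already-elapsed predictions.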
The Side Network independently learns a word-prediction task from ROI time series, with identical temporal alignment schemes and Transformer-based sequence modeling.
4. Training Objectives and Joint Optimization
PredFT is optimized under a multi-objective paradigm:

L_total = L_dec + α · L_pred

where
- L_dec: negative log-likelihood for reconstructing the actual text conditioned on fMRI signals,
- L_pred: negative log-likelihood for predicting future tokens from the predictive coding pathway,
- α: a balancing scalar for joint training (set to 1.0 for 10-frame and 0.5 for 20/40-frame experiments).
Weight regularization and standard Transformer dropout (rate 0.1) are applied. Optimization uses Adam with a learning rate that decays over 20 epochs.
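The joint objective can be expressed as a one-line helper; the alpha values follow the settings reported above, while the example loss values are arbitrary:

```python
def joint_loss(nll_decode: float, nll_predict: float, alpha: float = 1.0) -> float:
    """Total loss: decoding NLL plus alpha-weighted predictive-coding NLL."""
    return nll_decode + alpha * nll_predict

# e.g. the 20/40-frame settings use alpha = 0.5
total = joint_loss(2.0, 1.5, alpha=0.5)   # 2.0 + 0.5 * 1.5 = 2.75
```

Because both terms are token-level negative log-likelihoods, the scalar alpha is the only knob balancing decoding against future-word prediction.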
5. Dataset, Evaluation Protocol, and Metrics
PredFT is evaluated on the Narratives dataset: 230 subjects, TR = 1.5 s, 64×64×27 voxel volumes collected while subjects listened to spoken stories (60 min per subject). Preprocessing involves the fMRIPrep pipeline, voxel z-scoring, and Destrieux parcellation. Text is tokenized with BART and aligned to the fMRI time course.
Model selection and reporting use cross-subject splits to prevent data leakage. Evaluation uses BLEU-N (geometric mean of n-gram precisions with brevity penalty) and ROUGE-1 (unigram precision, recall, F1).
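A self-contained sketch of BLEU-N as described (geometric mean of clipped n-gram precisions with a brevity penalty); this is a textbook single-reference variant, not necessarily the exact scorer used in the paper:

```python
import math
from collections import Counter

def bleu_n(candidate, reference, max_n=2):
    """Geometric mean of clipped n-gram precisions with a brevity penalty
    (single-reference variant)."""
    precisions = []
    for n in range(1, max_n + 1):
        cand = Counter(tuple(candidate[i:i + n]) for i in range(len(candidate) - n + 1))
        ref = Counter(tuple(reference[i:i + n]) for i in range(len(reference) - n + 1))
        # Clip each candidate n-gram count to its count in the reference.
        clipped = sum(min(c, ref[g]) for g, c in cand.items())
        precisions.append(clipped / max(sum(cand.values()), 1))
    if min(precisions) == 0:
        return 0.0
    # Brevity penalty: penalize candidates shorter than the reference.
    if len(candidate) >= len(reference):
        bp = 1.0
    else:
        bp = math.exp(1 - len(reference) / max(len(candidate), 1))
    return bp * math.exp(sum(math.log(p) for p in precisions) / max_n)
```

ROUGE-1 is the recall-oriented counterpart: unigram overlap divided by the reference length, reported alongside precision and F1.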
6. Quantitative Performance and Ablation Analysis
With 40-frame (60s) windows:
- BLEU-1: 27.8%
- BLEU-2: 8.3%
- ROUGE-1-F1: 25.96%
Relative to the previous state-of-the-art UniCoRN model (BLEU-1: 21.8%, BLEU-2: 5.4%, ROUGE-1-F1: 25.30%), PredFT achieves substantial gains.
Ablations show that removing the Side Network drops BLEU-1 to 18.0%. Substituting the predictive-coding ROIs with random or whole-cortex features underperforms relative to the six specifically selected ROIs. Varying the balancing scalar α shows that joint training is optimal near α = 0.5 for most window settings.
7. Predictive Coding Effects and Error Analysis
BLEU-1 as a function of the future-prediction distance in the Side Network peaks at a nonzero distance, consistent with encoding-side findings. Error analyses reveal that PredFT substantially reduces the elevated error rates typically observed for "last-heard" words per TR, supporting the role of predictive coding in mitigating TR-related information loss. This suggests that incorporating representations from predictive-coding ROIs provides temporally anticipatory signals that enhance linguistic decoding beyond what is attainable from conventional fMRI data pathways.
In summary, PredFT combines spatial–temporal fMRI-to-text decoding with predictive-coding representations extracted from preselected ROIs using a dual-Transformer architecture and cross-attention fusion, establishing a new empirical benchmark for naturalistic language reconstruction from brain activity (Yin et al., 2024).