This paper, "BERT Rediscovers the Classical NLP Pipeline" (Tenney et al., 2019), investigates how linguistic information is organized within the layers of the pre-trained BERT model. The central goal is to understand whether BERT learns representations that correspond to the processing steps of a traditional NLP pipeline (POS tagging, parsing, semantic role labeling, and so on) and where in the network these representations are located. The findings offer practical insights into leveraging BERT's internal states for various downstream tasks.
The core methodology employed is edge probing, an approach designed to quantify how well information about linguistic structure can be extracted from a pre-trained encoder. The idea is to train simple, task-specific classifiers ("probes") on top of the frozen representations from the pre-trained model. These probes take vectors corresponding to one or two token spans and predict a linguistic label (e.g., the POS tag for a single token span, the dependency relation between two token spans, the entity type for a multi-token span). By keeping the probe simple and freezing the encoder, the experiment aims to measure what information is already present and accessible in the pre-trained representations, rather than what the model can learn to encode during task-specific training.
The paper explores several linguistic tasks covering different levels of analysis:
- Syntax: Part-of-speech (POS) tagging, Constituents (phrase structure), Dependencies (syntactic relations).
- Entities: Named Entity Recognition (NER).
- Semantics: Semantic Role Labeling (SRL), Semantic Proto-roles (SPR), Relation Classification.
- Discourse: Coreference Resolution.
To understand how information emerges across the layers of BERT (which has multiple Transformer layers), the authors introduce two key metrics:
- Scalar Mixing Weights: Inspired by ELMo, this approach trains the probing classifier on a weighted sum of the vectors from all layers, with the per-layer weights learned during training. Higher weights indicate the layers the probing classifier finds most useful for a given task. The weighted sum for token $i$ and task $\tau$ is $\mathbf{h}_{i,\tau} = \gamma_\tau \sum_{\ell=0}^{L} s_\tau^{(\ell)} \mathbf{h}_i^{(\ell)}$, where the $s_\tau^{(\ell)}$ (a softmax over learned scalars) are the learned, task-specific weights, $\gamma_\tau$ is a learned scaling factor, and $\mathbf{h}_i^{(\ell)}$ is the vector for token $i$ at layer $\ell$. A summary statistic, the "center-of-gravity" $\sum_{\ell} \ell \cdot s_\tau^{(\ell)}$, is the average layer index weighted by these learned coefficients, indicating the overall layer preference for a task.
- Cumulative Scoring: This metric tracks the performance gain as more layers are included. A series of probing classifiers $P_\tau^{(\ell)}$ is trained, where classifier $P_\tau^{(\ell)}$ uses representations from layer $\ell$ and all preceding layers (via scalar mixing over layers $0$ through $\ell$). The "differential score" $\Delta_\tau^{(\ell)} = \mathrm{Score}\big(P_\tau^{(\ell)}\big) - \mathrm{Score}\big(P_\tau^{(\ell-1)}\big)$ measures the performance increase from including layer $\ell$, indicating at which point in the network a particular type of information starts to improve task performance. The "expected layer", $\mathbb{E}_\Delta[\ell] = \sum_{\ell} \ell \cdot \Delta_\tau^{(\ell)} \big/ \sum_{\ell} \Delta_\tau^{(\ell)}$, is a summary statistic over these differential scores, indicating the layer where performance gains are typically concentrated (a small computational sketch of both summary statistics follows below).
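To make the two summary statistics concrete, here is a minimal illustrative sketch (not the paper's code) that computes the center-of-gravity from softmaxed mixing weights and the expected layer from a vector of per-layer cumulative scores; the `cumulative_scores` values are made-up placeholders.

```python
import torch

num_layers = 13  # BERT-base: embedding layer + 12 Transformer layers

# Center-of-gravity: average layer index under the learned (softmaxed) mixing weights.
raw_weights = torch.randn(num_layers)            # stand-in for the learned scalars
mix_weights = torch.softmax(raw_weights, dim=0)  # s_tau
layer_idx = torch.arange(num_layers, dtype=torch.float)
center_of_gravity = (layer_idx * mix_weights).sum().item()

# Expected layer: average layer index weighted by the differential scores
# Delta^(l) = Score(P^(l)) - Score(P^(l-1)). Scores below are placeholders.
cumulative_scores = torch.tensor([0.70, 0.78, 0.84, 0.88, 0.90, 0.91, 0.92,
                                  0.925, 0.927, 0.928, 0.929, 0.930, 0.930])
deltas = cumulative_scores[1:] - cumulative_scores[:-1]          # Delta^(l), l >= 1
expected_layer = ((torch.arange(1, num_layers, dtype=torch.float) * deltas).sum()
                  / deltas.sum()).item()

print(f"center-of-gravity: {center_of_gravity:.2f}, expected layer: {expected_layer:.2f}")
```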
Key Findings and Practical Implications:
- The Classical Pipeline Emerges Hierarchically: Both metrics consistently show that linguistic information is represented in a layered structure that mirrors the traditional NLP pipeline order: POS tagging appears earliest, followed by constituents, dependencies, entities, semantic roles, and coreference in higher layers. This means basic syntactic structure is primarily captured in the lower layers, while more complex semantic and discourse relations are represented in the higher layers.
- Implementation Implication: This finding is highly practical for fine-tuning or for using BERT as a feature extractor. If your downstream task relies heavily on syntax (e.g., parsing, POS tagging), representations from the lower layers may be sufficient or even preferable. For tasks requiring deeper semantic understanding (e.g., question answering, coreference resolution), vectors from the higher layers, or a combination that includes them, are likely more effective. When designing a feature-based model on top of BERT, selecting representations from layers identified as important by probes like those in this paper is a principled approach. For example, for a dependency parsing task one might extract vectors from layers 3-7 of BERT-large (based on Figure 2) and use them as features for a parser; a sketch of extracting per-layer hidden states follows this list.
- Syntactic vs. Semantic Localization: The paper found that syntactic information tends to be more localized to specific layers (high KL divergence of weight/score distributions), while semantic information is often distributed across many layers.
- Implementation Implication: This suggests that for syntactic tasks, focusing on a few key layers might be efficient and effective. For semantic tasks, combining information from a wider range of layers (as done with scalar mixing) is likely crucial. This justifies using techniques like scalar mixing or concatenating representations from multiple layers in downstream models targeting semantics.
- Comparison of Metrics: Cumulative scores show that many predictions (especially for syntactic tasks) can be resolved early, potentially using heuristic or local patterns. However, scalar mixing weights concentrate on later layers, suggesting that the richer, more discriminative features for successful performance are developed higher in the network.
- Implementation Implication: This reinforces the idea that while lower layers capture basic features, higher layers perform complex transformations that lead to better overall representations for challenging cases. Fine-tuning typically affects these higher layers more, aligning with their importance for achieving state-of-the-art results on many tasks. For performance-critical applications, using the full model or focusing on higher layers is generally advisable, despite some information being present early.
- Dynamic, Inter-Dependent Processing (Per-Example Analysis): While the aggregate trend follows the pipeline, a qualitative analysis of individual examples shows that BERT can dynamically revise earlier decisions based on information from later, higher-level processing. For instance, in the example "he smoked toronto...", an initial incorrect NER tag (GPE) for "Toronto" is corrected to ORG after the model resolves the semantic role that "Toronto" is the thing being "smoked" (implying a sports team).
- Implementation Implication: This highlights that BERT is not just a static pipeline but performs complex, non-linear interactions across layers. This explains its power in handling ambiguity. For tasks requiring fine-grained disambiguation, the ability of higher layers to influence decisions related to lower-level features is a critical factor and a key advantage over traditional, strictly sequential pipelines. This dynamic suggests that approaches that allow interaction between representations from different layers (like standard fine-tuning or multi-layer feature extraction) are important for exploiting BERT's full capabilities.
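To ground the layer-selection advice above, here is a short sketch using the Hugging Face `transformers` library (an assumption for illustration; the paper itself predates this API) that obtains hidden states from every layer and keeps only a mid-layer band. The specific band of layers 3-7 is purely illustrative.

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Load a frozen BERT encoder and request hidden states from every layer.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased", output_hidden_states=True)
model.eval()

sentence = "he smoked toronto in the playoffs with six hits"
inputs = tokenizer(sentence, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# hidden_states is a tuple of (num_layers + 1) tensors of shape
# [batch_size, seq_len, hidden_size]; index 0 is the embedding layer.
hidden_states = outputs.hidden_states

# Example: keep only a mid-layer band (layers 3-7 here, purely illustrative)
# as features for a syntax-oriented downstream model.
syntax_features = torch.stack(hidden_states[3:8], dim=0).mean(dim=0)
print(syntax_features.shape)  # [1, seq_len, hidden_size]
```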
Implementing Concepts:
The "edge probing" method provides a template for analyzing any pre-trained transformer model. To implement it:
- Choose a pre-trained model: Load BERT, RoBERTa, etc. Freeze its weights.
- Define Probing Tasks: Select linguistic tasks relevant to your analysis or application (e.g., POS, NER, Dependency arcs).
- Prepare Data: Obtain datasets annotated with the chosen linguistic structures (e.g., OntoNotes, Penn Treebank).
- Design Probing Classifiers: For each task, build a simple feed-forward network that takes representations from the pre-trained model corresponding to the target span(s).
- For single-span tasks (POS, NER), the input might be an average pooling of the span's token vectors or the vector of the span's first token.
- For two-span tasks (Dependencies, SRL, Coreference, Relations), the input could be a concatenation of the two span representations, their element-wise products or differences, or the vector of a token connecting the spans in a tree structure, in the spirit of the original edge probing paper (Tenney et al., 2019, "What do you learn from context?"); a simple two-span featurization is sketched after the code block below.
- The paper's probing architecture projects the token vectors and uses learned self-attentive pooling over each span before a small classifier, which is slightly more complex than a simple average but still limited compared to full fine-tuning. A basic implementation might use average pooling of the span's token vectors.
```python
import torch
import torch.nn as nn

class SpanClassifier(nn.Module):
    """A minimal probing classifier over span representations."""

    def __init__(self, input_dim, num_labels):
        super().__init__()
        self.classifier = nn.Linear(input_dim, num_labels)

    def forward(self, span_vectors):
        # span_vectors could be the averaged representations of tokens in a span,
        # or concatenated representations for multi-span tasks
        return self.classifier(span_vectors)

# Example usage within a probing framework.
# `bert_output` stands in for the per-layer outputs of a frozen BERT
# (a list of [batch_size, seq_len, hidden_size] tensors); random tensors
# are used here so the snippet is self-contained.
batch_size, seq_len, hidden_size, num_layers = 2, 16, 768, 13
bert_output = [torch.randn(batch_size, seq_len, hidden_size) for _ in range(num_layers)]
span_indices = [(1, 4), (5, 9)]  # (start, end) token indices for one span per example
num_pos_tags = 45

# Select layer representations (e.g., from layer 6 for a syntactic task)
layer_idx = 6
layer_output = bert_output[layer_idx]  # [batch_size, seq_len, hidden_size]

# Extract one span vector per example by average pooling over the span's tokens
span_vectors = []
for batch_idx in range(batch_size):
    start, end = span_indices[batch_idx]
    span_vector = torch.mean(layer_output[batch_idx, start:end], dim=0)
    span_vectors.append(span_vector)
span_vectors = torch.stack(span_vectors)  # [batch_size, hidden_size]

# Classify the span
pos_classifier = SpanClassifier(input_dim=hidden_size, num_labels=num_pos_tags)
logits = pos_classifier(span_vectors)  # [batch_size, num_pos_tags]
```
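For two-span tasks, one simple option (illustrative only, not the paper's exact featurization) is to pool each span separately and concatenate the two vectors along with their element-wise product. The sketch below continues from the snippet above, reusing `layer_output`, `batch_size`, `hidden_size`, and `SpanClassifier`; the span index lists are placeholders.

```python
def pool_span(layer_output, batch_idx, span):
    # Average pooling over the tokens of one span
    start, end = span
    return torch.mean(layer_output[batch_idx, start:end], dim=0)

# Illustrative two-span featurization: [span1; span2; span1 * span2]
span1_indices = [(1, 4), (5, 9)]
span2_indices = [(6, 8), (10, 12)]
pair_vectors = []
for batch_idx in range(batch_size):
    v1 = pool_span(layer_output, batch_idx, span1_indices[batch_idx])
    v2 = pool_span(layer_output, batch_idx, span2_indices[batch_idx])
    pair_vectors.append(torch.cat([v1, v2, v1 * v2], dim=-1))
pair_vectors = torch.stack(pair_vectors)  # [batch_size, 3 * hidden_size]

# A probe for, e.g., the dependency relation between the two spans
dep_classifier = SpanClassifier(input_dim=3 * hidden_size, num_labels=40)
dep_logits = dep_classifier(pair_vectors)
```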
- Implement Layer Combination (Optional but Recommended): To use scalar mixing, add learned weights for combining layer outputs before feeding to the probe.
```python
# Example of a scalar mixing layer (continuing from the snippet above)
class ScalarMix(nn.Module):
    def __init__(self, num_layers):
        super().__init__()
        # Learned per-layer scalar weights
        self.scalar_weights = nn.Parameter(torch.ones(num_layers))
        # Learned global scaling parameter (gamma)
        self.gamma = nn.Parameter(torch.ones(1))

    def forward(self, layer_outputs):
        # layer_outputs is a list of tensors [batch_size, seq_len, hidden_size]
        # Normalize the weights with a softmax
        normed_weights = torch.softmax(self.scalar_weights, dim=0)
        # Weighted sum over layers, scaled by gamma
        mixed_output = torch.sum(
            torch.stack([w * layer_outputs[i] for i, w in enumerate(normed_weights)], dim=0),
            dim=0,
        ) * self.gamma
        return mixed_output

# Example usage
num_layers = len(bert_output)
scalar_mix_layer = ScalarMix(num_layers)
mixed_representations = scalar_mix_layer(bert_output)  # [batch_size, seq_len, hidden_size]
# Now extract span vectors from mixed_representations and feed them to the classifier
```
- Train Probes: Train the probing classifiers for each task, optimizing the probe parameters (and the scalar mixing weights, if used) while keeping the BERT weights frozen. Use a standard classification loss (e.g., cross-entropy); a minimal training-loop sketch follows this list.
- Evaluate and Analyze: Evaluate probes using F1 score or other relevant metrics. Analyze learned mixing weights and cumulative scores across layers to understand the information flow.
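Putting the pieces together, the following sketch reuses the hypothetical `scalar_mix_layer`, `pos_classifier`, `bert_output`, and `span_indices` from the snippets above (with placeholder gold labels): only the probe and the mixing weights receive gradients while BERT stays frozen, and the learned layer weights can be inspected afterwards.

```python
import torch.optim as optim

# Only probe + scalar-mix parameters are trainable; BERT itself stays frozen.
params = list(pos_classifier.parameters()) + list(scalar_mix_layer.parameters())
optimizer = optim.Adam(params, lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

gold_labels = torch.randint(0, num_pos_tags, (batch_size,))  # placeholder labels

for step in range(100):
    optimizer.zero_grad()
    # Mix the (precomputed, frozen) layer outputs, then pool each span
    mixed = scalar_mix_layer(bert_output)  # [batch_size, seq_len, hidden_size]
    span_vecs = torch.stack([
        torch.mean(mixed[b, s:e], dim=0) for b, (s, e) in enumerate(span_indices)
    ])
    logits = pos_classifier(span_vecs)
    loss = loss_fn(logits, gold_labels)
    loss.backward()
    optimizer.step()

# After training, the normalized mixing weights show which layers the probe relied on.
learned_weights = torch.softmax(scalar_mix_layer.scalar_weights, dim=0)
print(learned_weights)
```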
Computational Requirements and Limitations:
- Training probes is computationally much cheaper than fine-tuning the full BERT model, as only a small number of parameters (in the probes and scalar mixing layer) are trained.
- However, extracting representations from all layers of BERT can still be memory-intensive, especially for BERT-large or long sequences.
- The edge probing method itself has limitations, as acknowledged by the authors: the inability to extract information via a simple probe doesn't definitively mean the information isn't present, and observing information doesn't guarantee the model actually uses it in downstream tasks in the way the probe does. This motivates combining probing with behavioral analysis.
In essence, this paper provides a valuable map of the linguistic landscape within BERT's layers. For developers and engineers, it translates theoretical model internals into practical guidance on how to effectively access and utilize BERT's learned representations for a range of NLP tasks, suggesting that task characteristics (syntax vs. semantics) should inform which layers are prioritized.