- The paper demonstrates that replacing standard MLPs with bilinear layers improves the mathematical tractability of transformer circuits.
- It introduces a tensor representation of the bilinear operation, integrating MLP pathways into established analytical frameworks for transformers.
- The study reveals that bilinear layers enable clearer feature decomposition for mechanistic interpretability while highlighting practical computational challenges.
The paper "A technical note on bilinear layers for interpretability" (2305.03452) explores how replacing standard MLP layers with bilinear layers can significantly improve the mathematical tractability of analyzing neural networks, particularly within the context of transformer circuits. This improved analytical capability is presented as a potential pathway towards deeper mechanistic interpretability and safety insights for LLMs.
The core idea revolves around the mathematical form of a bilinear layer: MLP_Bilinear(x) = (W1 x) ⊙ (W2 x), where W1 and W2 are weight matrices and ⊙ denotes elementwise multiplication. Standard MLPs apply an elementwise non-linear activation function such as ReLU or GELU, MLP_ReLU(x) = σ(W x), which makes formal analysis difficult. Bilinear layers, while still nonlinear functions of the input x, have a structure that can be expressed using only linear operations and a third-order tensor.
Specifically, the paper shows that (W1 x) ⊙ (W2 x) can be written as x ·_12 B ·_21 x, where B is a third-order tensor with elements B_ijk = W1_ij · W2_ik. Here, ·_jk denotes a tensor inner product that sums over the j-th axis of the first tensor and the k-th axis of the second tensor. This tensor representation is key because linear operations and tensor algebra are much more amenable to formal analysis than elementwise non-linearities.
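As a quick sanity check of this equivalence (a minimal sketch, not code from the paper; the dimensions are arbitrary), the elementwise-product form and the tensor contraction can be compared numerically with torch.einsum:

```python
import torch

d_in, d_out = 16, 8
W1 = torch.randn(d_out, d_in)
W2 = torch.randn(d_out, d_in)
x = torch.randn(d_in)

# Elementwise product form: (W1 x) ⊙ (W2 x)
out_elementwise = (W1 @ x) * (W2 @ x)

# Third-order tensor B with B_ijk = W1_ij * W2_ik
B = torch.einsum('ij,ik->ijk', W1, W2)

# Double contraction x ·_12 B ·_21 x = sum_jk B_ijk x_j x_k
out_tensor = torch.einsum('ijk,j,k->i', B, x, x)

assert torch.allclose(out_elementwise, out_tensor, atol=1e-5)
```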
Extending the Mathematical Framework for Transformer Circuits
A significant practical implication is that this tensor representation allows integrating MLPs into existing analytical frameworks for transformers, which were previously limited to attention-only models [math_framework]. The paper demonstrates how to incorporate the bilinear layer expression into the path expansion for a one-layer transformer with both attention and MLPs.
The resulting expression for the output of a one-layer transformer T(t) processing a token sequence t involves the standard token embedding and unembedding pathways, attention pathways, and new pathways that explicitly show how information flows through the bilinear MLP. These MLP pathways involve terms like W_E⊤ W_I1^m⊤ ·_12 Z ·_21 W_I2^m W_E, where W_E is the embedding matrix, W_I1^m and W_I2^m are the MLP input weights, and Z is a specific third-order tensor related to the elementwise product; the MLP output weight matrix W_O^m then carries the result of these pathways back into the residual stream. The full expression reveals how token embeddings, attention heads, and the MLP input weights interact via tensor operations.
For practitioners, this means it becomes possible to analyze the contributions of different computational paths through the network end-to-end, including those involving the MLP layer, using linear algebraic and tensor tools. This could facilitate the identification and analysis of "circuits" – specific sets of neurons and weights that perform particular computations – which was a major outcome of applying the mathematical framework to attention layers [math_framework, olsson2022incontext, wang2022interpretability].
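As a rough illustration of this kind of path analysis (a sketch, not the paper's exact path expansion; the matrix names, shapes, and layout conventions here are assumptions), one can contract the two MLP input weight matrices with the embedding to obtain, for each MLP output dimension, a pairwise interaction score between token-embedding directions on the direct embedding→MLP path:

```python
import torch

d_model, d_mlp, vocab = 64, 256, 100   # small illustrative sizes
W_E = torch.randn(d_model, vocab)      # embedding, assumed shape (d_model, vocab)
W_I1 = torch.randn(d_mlp, d_model)     # first MLP input weight matrix
W_I2 = torch.randn(d_mlp, d_model)     # second MLP input weight matrix

# Effective token-level interaction tensor for the direct embed -> bilinear MLP path:
# B_tok[i, s, t] = sum_{j,k} W_I1[i, j] * W_E[j, s] * W_I2[i, k] * W_E[k, t]
#               = (W_I1 W_E)[i, s] * (W_I2 W_E)[i, t]
A1 = W_I1 @ W_E                              # (d_mlp, vocab)
A2 = W_I2 @ W_E                              # (d_mlp, vocab)
B_tok = torch.einsum('is,it->ist', A1, A2)   # (d_mlp, vocab, vocab)
print(B_tok.shape)                           # torch.Size([256, 100, 100])
```

Roughly speaking, attention pathways would enter the same contraction pattern with the embedding replaced by head-specific virtual weight matrices from the path expansion.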
Understanding Feature Construction
Beyond circuit analysis, the paper suggests bilinear layers offer a promising avenue for understanding feature construction, particularly in models exhibiting superposition. The difficulty with standard MLPs is that the elementwise non-linearity makes it hard to decompose how input features combine to form output features. The non-linearity acts as a complex, input-dependent "modifier" on the linear pre-activation.
In bilinear layers, the structure (W1 x) ⊙ (W2 x) can be viewed as W1 x acting as a linear "modifier" on the linear transformation W2 x. If we represent the input x as a sparse linear combination of input features (x = ∑_i a_i d_i^I), the output of the bilinear layer can be expressed as ∑_i ∑_j a_i a_j (d_i^I ·_12 B ·_21 d_j^I). This shows that the output is a sum of pairwise interactions between input features. This property, termed "additively pairwise nonlinear," is significantly simpler to analyze than the "fully nonlinear" interactions in standard MLPs.
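A minimal sketch of this pairwise decomposition (the feature dictionary here is random and purely illustrative, not from the paper): each term d_i^I ·_12 B ·_21 d_j^I equals (W1 d_i^I) ⊙ (W2 d_j^I), and the weighted sum over all pairs recovers the full layer output.

```python
import torch

d_in, d_out, n_feats = 16, 8, 5
W1 = torch.randn(d_out, d_in)
W2 = torch.randn(d_out, d_in)

# Hypothetical sparse feature dictionary and activations
D = torch.randn(n_feats, d_in)   # rows are input feature directions d_i
a = torch.randn(n_feats)         # feature activations a_i
x = a @ D                        # x = sum_i a_i d_i

# Full bilinear output
full = (W1 @ x) * (W2 @ x)

# Sum of pairwise interaction terms: a_i a_j * (W1 d_i) ⊙ (W2 d_j)
P1 = D @ W1.T    # (n_feats, d_out): row i is W1 d_i
P2 = D @ W2.T    # (n_feats, d_out): row j is W2 d_j
pairwise = torch.einsum('i,j,id,jd->d', a, a, P1, P2)

assert torch.allclose(full, pairwise, atol=1e-4)
```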
This decomposition suggests practical strategies for interpretability:
- Analyze the tensor B: Study the largest coefficients of the tensor B to identify which pairs of input dimensions contribute most strongly to which output dimensions.
- Tensor Decomposition: Apply techniques like Higher-Order Singular Value Decomposition (HOSVD) to B to find dominant interaction patterns (a sketch follows this list).
- Study Modified Features: Identify interpretable bases (e.g., using ICA [ica]) for the input space and the linear transformations W1 x and W2 x. Then, analyze how specific input features (d_i^I) modify the 'default' output features (e.g., those derived from W2) via the modifier W1 x. This could involve quantifying how much d_i^I changes a default output feature direction or activation. Prioritize analyzing interactions between input features that co-occur frequently or cause the most significant modifications.
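For the tensor-decomposition strategy, here is a minimal HOSVD-style sketch (standard mode-unfolding SVDs applied to a small random example; not code from the paper):

```python
import torch

def hosvd_factors(B, ranks=(4, 4, 4)):
    """Truncated HOSVD of a third-order tensor via SVDs of its mode unfoldings."""
    factors = []
    for mode in range(3):
        # Unfold B along `mode`: shape (dim_mode, product of the other dims)
        unfolding = B.movedim(mode, 0).reshape(B.shape[mode], -1)
        U, S, _ = torch.linalg.svd(unfolding, full_matrices=False)
        factors.append(U[:, :ranks[mode]])  # leading left singular vectors
    # Core tensor: contract B with the factor matrices on each mode
    core = torch.einsum('ijk,ia,jb,kc->abc', B, factors[0], factors[1], factors[2])
    return core, factors

# Small random bilinear tensor B_ijk = W1_ij * W2_ik
W1, W2 = torch.randn(32, 16), torch.randn(32, 16)
B = torch.einsum('ij,ik->ijk', W1, W2)
core, factors = hosvd_factors(B, ranks=(4, 4, 4))
print(core.shape)  # torch.Size([4, 4, 4])
```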
Implementation Considerations
Implementing a bilinear layer in a neural network framework like PyTorch or TensorFlow is straightforward. It involves two linear layers (W1x and W2x) followed by an elementwise multiplication.
```python
import torch
import torch.nn as nn

class BilinearLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, output_dim, bias=False)  # W1
        self.linear2 = nn.Linear(input_dim, output_dim, bias=False)  # W2
        # Optional: biases could be added here, but the paper omits them for simplicity

    def forward(self, x):
        # x: (batch_size, input_dim)
        out1 = self.linear1(x)  # W1 x -> (batch_size, output_dim)
        out2 = self.linear2(x)  # W2 x -> (batch_size, output_dim)
        # Elementwise multiplication
        output = out1 * out2    # equivalent to torch.mul(out1, out2)
        return output

input_dim = 768
output_dim = 3072  # typical MLP dimensions in transformers
bilinear_mlp = BilinearLayer(input_dim, output_dim)
dummy_input = torch.randn(1, input_dim)  # batch size 1
output = bilinear_mlp(dummy_input)
print(output.shape)  # Expected: torch.Size([1, 3072])
```
Replacing a standard MLP with a BilinearLayer in a transformer block requires modifying the model architecture definition. This might involve removing the original linear layers and activation function and inserting the BilinearLayer in their place.
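A rough sketch of what such a swap could look like (a simplified pre-norm block with assumed dimensions, reusing the BilinearLayer class above; this is not the paper's training setup):

```python
import torch
import torch.nn as nn

class TransformerBlockWithBilinearMLP(nn.Module):
    """Simplified pre-norm transformer block using BilinearLayer in place of the MLP."""
    def __init__(self, d_model=768, n_heads=12, d_mlp=3072):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp_in = BilinearLayer(d_model, d_mlp)            # replaces Linear + activation
        self.mlp_out = nn.Linear(d_mlp, d_model, bias=False)   # W_O^m, back to the residual stream

    def forward(self, x):
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp_out(self.mlp_in(self.ln2(x)))
        return x

block = TransformerBlockWithBilinearMLP()
tokens = torch.randn(1, 10, 768)   # (batch, seq_len, d_model)
print(block(tokens).shape)         # torch.Size([1, 10, 768])
```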
For the analytical side, constructing the tensor B from trained W1 and W2 weights is also feasible. If W1 and W2 are PyTorch tensors of shape (output_dim, input_dim), B would be a tensor of shape (output_dim, input_dim, input_dim). This construction involves outer products and reshaping, which can be computationally intensive for large dimensions.
```python
def construct_bilinear_tensor_B(W1, W2):
    output_dim, input_dim = W1.shape
    B = torch.zeros(output_dim, input_dim, input_dim)
    for i in range(output_dim):
        # Get the i-th row of W1 and W2 (corresponding to the i-th output dimension)
        w1_row_i = W1[i, :]  # shape (input_dim,)
        w2_row_i = W2[i, :]  # shape (input_dim,)
        # Outer product of the i-th rows: element (j, k) is W1[i, j] * W2[i, k]
        outer_prod = torch.outer(w1_row_i, w2_row_i)  # shape (input_dim, input_dim)
        B[i, :, :] = outer_prod
    return B

W1_weights = bilinear_mlp.linear1.weight.data  # shape (output_dim, input_dim)
W2_weights = bilinear_mlp.linear2.weight.data  # shape (output_dim, input_dim)
B_tensor = construct_bilinear_tensor_B(W1_weights, W2_weights)
print(B_tensor.shape)  # Expected: torch.Size([3072, 768, 768])
```
Analyzing this large tensor B (e.g., using HOSVD) or implementing the proposed feature interaction analysis methods would require specialized tools and potentially significant computational resources, although likely less than enumerating features in a non-decomposable system.
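One way to keep the memory footprint manageable for the "largest coefficients" strategy (a sketch; the function name and chunking scheme are hypothetical, and it assumes per-chunk slices of B suffice for the analysis): build B lazily, one block of output dimensions at a time, and keep only a running set of top entries.

```python
import torch

def top_coefficients_of_B(W1, W2, k=10, chunk_size=64):
    """Find the k largest-magnitude entries B[i, j, l] = W1[i, j] * W2[i, l]
    without materializing the full (output_dim, input_dim, input_dim) tensor."""
    output_dim, input_dim = W1.shape
    best_vals = torch.empty(0)
    best_idx = torch.empty(0, 3, dtype=torch.long)
    for start in range(0, output_dim, chunk_size):
        end = min(start + chunk_size, output_dim)
        # Materialize only a chunk of B: shape (end - start, input_dim, input_dim)
        chunk = torch.einsum('ij,il->ijl', W1[start:end], W2[start:end])
        vals, flat = chunk.abs().flatten().topk(k)
        # Convert flat indices back to (output, input_j, input_l) coordinates
        i = flat // (input_dim * input_dim) + start
        j = (flat // input_dim) % input_dim
        l = flat % input_dim
        best_vals = torch.cat([best_vals, vals])
        best_idx = torch.cat([best_idx, torch.stack([i, j, l], dim=1)])
        # Retain only the global top k seen so far
        keep = best_vals.topk(min(k, best_vals.numel())).indices
        best_vals, best_idx = best_vals[keep], best_idx[keep]
    return best_vals, best_idx

vals, idx = top_coefficients_of_B(W1_weights, W2_weights, k=10)
print(idx[:3])  # (output, input_j, input_l) triples with the largest |B| entries
```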
Potential Limitations and Trade-offs
- Performance at Scale: The paper notes that competitive performance was shown for a 120M parameter model size. Whether bilinear layers maintain performance competitiveness at much larger scales (billions of parameters) compared to state-of-the-art alternatives like SwiGLU is an open empirical question. SwiGLU layers are closely related (σ(W1 x) ⊙ (W2 x), with σ the Swish/SiLU activation), suggesting that a bilinear-like gated structure is already part of effective large models, but the exact form matters.
- Computational Cost: The tensor B has dimensions (output_dim, input_dim, input_dim). For typical transformer dimensions (e.g., input_dim=768, output_dim=3072), B would be 3072×768×768, which is substantial (≈1.8×10⁹ elements). Storing and analyzing this tensor requires significant memory and computation, although the paper suggests strategies to prioritize (e.g., focusing on large coefficients or dominant HOSVD components).
- Interpretability is still hard: While the mathematical structure is simpler, translating the analysis of the B tensor or pairwise interactions into human-understandable explanations of model behavior remains a challenging task requiring further research and tool development. Identifying the interpretable bases D^I and D^O for the input and output spaces (e.g., via ICA) is itself a difficult problem.
Overall, the paper presents bilinear layers not just as a potentially performant architectural choice, but primarily as a research tool that makes the internal workings of neural networks more transparent to mathematical analysis, opening concrete avenues for mechanistic interpretability and safety research.