A technical note on bilinear layers for interpretability (2305.03452v1)

Published 5 May 2023 in cs.LG and cs.NE

Abstract: The ability of neural networks to represent more features than neurons makes interpreting them challenging. This phenomenon, known as superposition, has spurred efforts to find architectures that are more interpretable than standard multilayer perceptrons (MLPs) with elementwise activation functions. In this note, I examine bilinear layers, which are a type of MLP layer that are mathematically much easier to analyze while simultaneously performing better than standard MLPs. Although they are nonlinear functions of their input, I demonstrate that bilinear layers can be expressed using only linear operations and third order tensors. We can integrate this expression for bilinear layers into a mathematical framework for transformer circuits, which was previously limited to attention-only transformers. These results suggest that bilinear layers are easier to analyze mathematically than current architectures and thus may lend themselves to deeper safety insights by allowing us to talk more formally about circuits in neural networks. Additionally, bilinear layers may offer an alternative path for mechanistic interpretability through understanding the mechanisms of feature construction instead of enumerating a (potentially exponentially) large number of features in large models.

Summary

  • The paper demonstrates that replacing standard MLPs with bilinear layers improves the mathematical tractability of transformer circuits.
  • It introduces a tensor representation of the bilinear operation, integrating MLP pathways into established analytical frameworks for transformers.
  • The study reveals that bilinear layers enable clearer feature decomposition for mechanistic interpretability while highlighting practical computational challenges.

The paper "A technical note on bilinear layers for interpretability" (2305.03452) explores how replacing standard MLP layers with bilinear layers can significantly improve the mathematical tractability of analyzing neural networks, particularly within the context of transformer circuits. This improved analytical capability is presented as a potential pathway towards deeper mechanistic interpretability and safety insights for LLMs.

The core idea revolves around the mathematical form of a bilinear layer: $\mathrm{MLP}_{\mathrm{Bilinear}}(x) = (W_1 x) \odot (W_2 x)$, where $W_1$ and $W_2$ are weight matrices and $\odot$ denotes elementwise multiplication. Standard MLPs use an elementwise nonlinear activation function like ReLU or GeLU, $\mathrm{MLP}_{\mathrm{ReLU}}(x) = \sigma(W x)$, which makes formal analysis difficult. Bilinear layers, while still nonlinear functions of the input $x$, have a structure that can be expressed using only linear operations and a third-order tensor.

Specifically, the paper shows that $(W_1 x) \odot (W_2 x)$ can be written as $x \boldsymbol{\cdot}_{12} B \boldsymbol{\cdot}_{21} x$, where $B$ is a third-order tensor with elements $B_{ijk} = W_{1(ij)} W_{2(ik)}$. Here, $\boldsymbol{\cdot}_{jk}$ denotes a tensor inner product that sums over the $j$-th axis of the first tensor and the $k$-th axis of the second tensor. This tensor representation is key because linear operations and tensor algebra are much more amenable to formal analysis than elementwise nonlinearities.
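
To make the identity concrete, here is a small numerical check (an illustrative sketch, not code from the paper) that the elementwise product $(W_1 x) \odot (W_2 x)$ matches the double contraction of $x$ with the tensor $B$:

import torch

# Illustrative check: (W1 x) * (W2 x) equals the double contraction of x with
# the third-order tensor B[i, j, k] = W1[i, j] * W2[i, k].
d_in, d_out = 8, 5
W1 = torch.randn(d_out, d_in)
W2 = torch.randn(d_out, d_in)
x = torch.randn(d_in)

bilinear = (W1 @ x) * (W2 @ x)                      # shape (d_out,)
B = torch.einsum('ij,ik->ijk', W1, W2)              # shape (d_out, d_in, d_in)
contracted = torch.einsum('ijk,j,k->i', B, x, x)    # shape (d_out,)

print(torch.allclose(bilinear, contracted, atol=1e-5))  # Expected: True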

Extending the Mathematical Framework for Transformer Circuits

A significant practical implication is that this tensor representation allows integrating MLPs into existing analytical frameworks for transformers, which were previously limited to attention-only models [math_framework]. The paper demonstrates how to incorporate the bilinear layer expression into the path expansion for a one-layer transformer with both attention and MLPs.

The resulting expression for the output of a one-layer transformer $T(t)$ processing a token sequence $t$ involves the standard token embedding and unembedding pathways, attention pathways, and new pathways that explicitly show how information flows through the bilinear MLP. These MLP pathways involve terms like $W_E^\top W_{I_1}^{m\top} \boldsymbol{\cdot}_{12} Z \boldsymbol{\cdot}_{21} W_{I_2}^{m} W_E$, where $W_E$ is the embedding matrix, $W_{I_1}^m$ and $W_{I_2}^m$ are the MLP input weights, $Z$ is a specific third-order tensor related to the elementwise product, and $W_O^m$ is the MLP output weight matrix. The full expression reveals how token embeddings, attention heads, and the MLP input weights interact via tensor operations.

For practitioners, this means it becomes possible to analyze the contributions of different computational paths through the network end-to-end, including those involving the MLP layer, using linear algebraic and tensor tools. This could facilitate the identification and analysis of "circuits" – specific sets of neurons and weights that perform particular computations – which was a major outcome of applying the mathematical framework to attention layers [math_framework, olsson2022incontext, wang2022interpretability].
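
As a concrete illustration of this kind of path analysis, the sketch below collapses the direct, attention-free path through the bilinear MLP into a single third-order tensor acting on the token embedding. The shapes and the unembedding and output matrices (W_U, W_O) are my own illustrative assumptions; this is not the paper's full path expansion.

import torch

# Hypothetical shapes; W_U and W_O are illustrative names, not the paper's notation.
d_model, d_mlp, vocab = 16, 32, 50
W_E = torch.randn(d_model, vocab)   # embedding: one-hot token -> residual stream
W_U = torch.randn(vocab, d_model)   # unembedding: residual stream -> logits
W1  = torch.randn(d_mlp, d_model)   # bilinear MLP input weights (W_{I_1}^m)
W2  = torch.randn(d_mlp, d_model)   # bilinear MLP input weights (W_{I_2}^m)
W_O = torch.randn(d_model, d_mlp)   # MLP output weights (W_O^m)

def mlp_path_logits(token_id):
    e = W_E[:, token_id]            # token embedding
    h = (W1 @ e) * (W2 @ e)         # bilinear MLP activations
    return W_U @ (W_O @ h)          # logit contribution of this path

# The same map written as one third-order tensor over embedding dimensions:
# C[v, j, k] = sum_m (W_U W_O)[v, m] * W1[m, j] * W2[m, k]
C = torch.einsum('vm,mj,mk->vjk', W_U @ W_O, W1, W2)
e = W_E[:, 3]
print(torch.allclose(mlp_path_logits(3), torch.einsum('vjk,j,k->v', C, e, e), atol=1e-4))  # True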

Understanding Feature Construction

Beyond circuit analysis, the paper suggests bilinear layers offer a promising avenue for understanding feature construction, particularly in models exhibiting superposition. The difficulty with standard MLPs is that the elementwise non-linearity makes it hard to decompose how input features combine to form output features. The non-linearity acts as a complex, input-dependent "modifier" on the linear pre-activation.

In bilinear layers, the structure $(W_1 x) \odot (W_2 x)$ can be viewed as $W_1 x$ acting as a linear "modifier" on the linear transformation $W_2 x$. If we represent the input $x$ as a sparse linear combination of input features ($x = \sum_i a_i d_i^I$), the output of the bilinear layer can be expressed as $\sum_i \sum_j a_i a_j \, d_i^I \boldsymbol{\cdot}_{12} B \boldsymbol{\cdot}_{21} d_j^I$. This shows that the output is a sum of pairwise interactions between input features. This property, termed "additively pairwise nonlinear," is significantly simpler to analyze than the "fully nonlinear" interactions in standard MLPs.
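
A brief numerical check of this decomposition (an illustrative sketch, not from the paper) confirms that the bilinear output of a feature combination equals the sum of its pairwise feature interactions:

import torch

# If x = sum_i a_i d_i, then (W1 x) * (W2 x) = sum_{i,j} a_i a_j (W1 d_i) * (W2 d_j).
d_in, d_out, n_feats = 10, 6, 4
W1 = torch.randn(d_out, d_in)
W2 = torch.randn(d_out, d_in)
D = torch.randn(n_feats, d_in)   # rows are feature directions d_i^I
a = torch.randn(n_feats)         # feature activations (sparse in practice)

x = a @ D                        # x = sum_i a_i d_i
direct = (W1 @ x) * (W2 @ x)

pairwise = torch.zeros(d_out)
for i in range(n_feats):
    for j in range(n_feats):
        pairwise += a[i] * a[j] * (W1 @ D[i]) * (W2 @ D[j])

print(torch.allclose(direct, pairwise, atol=1e-5))  # Expected: True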

This decomposition suggests practical strategies for interpretability:

  1. Analyze the tensor B: Study the largest coefficients of the tensor $B$ to identify which pairs of input dimensions contribute most strongly to which output dimensions.
  2. Tensor Decomposition: Apply techniques like Higher Order Singular Value Decomposition (HOSVD) to $B$ to find dominant interaction patterns (a minimal sketch of strategies 1 and 2 follows this list).
  3. Study Modified Features: Identify interpretable bases (e.g., using ICA [ica]) for the input space and the linear transformations $W_1 x$ and $W_2 x$. Then, analyze how specific input features ($d_i^I$) modify the 'default' output features (e.g., those derived from $W_2$) via the modifier $W_1 x$. This could involve quantifying how much $d_i^I$ changes a default output feature direction or activation. Prioritize analyzing interactions between input features that co-occur frequently or cause the most significant modifications.
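
The following sketch illustrates strategies 1 and 2 in PyTorch, assuming a tensor B of shape (output_dim, input_dim, input_dim) as constructed later in this note; the HOSVD step is approximated here by SVDs of the mode unfoldings.

import torch

def top_interactions(B, k=10):
    # Strategy 1: the (output i, input j, input k) triples with the largest |B| coefficients.
    vals, flat_idx = torch.topk(B.abs().flatten(), k)
    jk = B.shape[1] * B.shape[2]
    i = flat_idx // jk
    j = (flat_idx % jk) // B.shape[2]
    kk = flat_idx % B.shape[2]
    return vals, torch.stack([i, j, kk], dim=1)  # each row is (i, j, k)

def mode_unfolding_svd(B, mode, rank=16):
    # Strategy 2 (HOSVD-style): the SVD of the mode unfolding gives the dominant
    # directions along that axis of the interaction tensor.
    unfolded = B.movedim(mode, 0).reshape(B.shape[mode], -1)
    U, S, _ = torch.linalg.svd(unfolded, full_matrices=False)
    return U[:, :rank], S[:rank]

Calling top_interactions(B_tensor, k=20) on the tensor constructed below would return the twenty strongest pairwise coefficients and their coordinates, giving a first rough map of which input pairs drive which output dimensions.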

Implementation Considerations

Implementing a bilinear layer in a neural network framework like PyTorch or TensorFlow is straightforward. It involves two linear layers (computing $W_1 x$ and $W_2 x$) followed by an elementwise multiplication.

import torch
import torch.nn as nn

class BilinearLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, output_dim, bias=False) # W1
        self.linear2 = nn.Linear(input_dim, output_dim, bias=False) # W2
        # Optional: could add biases here if needed, but paper omits for simplicity

    def forward(self, x):
        # x is input vector (batch_size, input_dim)
        # W1 x -> (batch_size, output_dim)
        # W2 x -> (batch_size, output_dim)
        out1 = self.linear1(x)
        out2 = self.linear2(x)
        # Elementwise multiplication
        output = out1 * out2 # equivalent to torch.mul(out1, out2)
        return output

input_dim = 768
output_dim = 3072 # Typical MLP dimensions in transformers
bilinear_mlp = BilinearLayer(input_dim, output_dim)
dummy_input = torch.randn(1, input_dim) # Batch size 1
output = bilinear_mlp(dummy_input)
print(output.shape) # Expected: torch.Size([1, 3072])

Replacing a standard MLP with a BilinearLayer in a transformer block requires modifying the model architecture definition. This might involve removing the original linear layers and activation function and inserting the BilinearLayer.
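
As a rough sketch of such a modification (the attention module, norm placement, and dimensions here are illustrative assumptions, not the paper's training configuration), a pre-norm transformer block could use the BilinearLayer defined above as its feedforward sublayer, followed by an output projection back to the residual stream:

import torch
import torch.nn as nn

class BilinearTransformerBlock(nn.Module):
    def __init__(self, d_model=768, d_mlp=3072, n_heads=12):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.bilinear = BilinearLayer(d_model, d_mlp)    # replaces Linear + activation
        self.proj_out = nn.Linear(d_mlp, d_model, bias=False)

    def forward(self, x):
        # Pre-norm attention sublayer with residual connection
        h = self.ln1(x)
        a, _ = self.attn(h, h, h)
        x = x + a
        # Bilinear feedforward sublayer with residual connection
        x = x + self.proj_out(self.bilinear(self.ln2(x)))
        return x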

For the analytical side, constructing the tensor $B$ from trained $W_1$ and $W_2$ weights is also feasible. If $W_1$ and $W_2$ are PyTorch tensors of shape (output_dim, input_dim), $B$ would be a tensor of shape (output_dim, input_dim, input_dim). This construction involves outer products and reshaping, which can be computationally intensive for large dimensions.

def construct_bilinear_tensor_B(W1, W2):
    output_dim, input_dim = W1.shape
    B = torch.zeros(output_dim, input_dim, input_dim)
    for i in range(output_dim):
        # Get the i-th row of W1 and W2 (corresponding to the i-th output dimension)
        w1_row_i = W1[i, :] # shape (input_dim,)
        w2_row_i = W2[i, :] # shape (input_dim,)
        # Outer product of the i-th row of W1 and the i-th row of W2
        # Resulting tensor element (j, k) is W1[i, j] * W2[i, k]
        outer_prod = torch.outer(w1_row_i, w2_row_i) # shape (input_dim, input_dim)
        B[i, :, :] = outer_prod
    return B

W1_weights = bilinear_mlp.linear1.weight.data # shape (output_dim, input_dim)
W2_weights = bilinear_mlp.linear2.weight.data # shape (output_dim, input_dim)

B_tensor = construct_bilinear_tensor_B(W1_weights, W2_weights)
print(B_tensor.shape) # Expected: torch.Size([3072, 768, 768])
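
The explicit loop above is easy to read; for completeness, the same tensor can be built in a single vectorized call (an equivalent sketch that still materializes every element):

# B[i, j, k] = W1[i, j] * W2[i, k], computed without an explicit Python loop
B_vectorized = torch.einsum('ij,ik->ijk', W1_weights, W2_weights)
print(torch.allclose(B_tensor, B_vectorized))  # Expected: True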

Analyzing this large tensor $B$ (e.g., using HOSVD) or implementing the proposed feature interaction analysis methods would require specialized tools and potentially significant computational resources, although likely less than enumerating features in a non-decomposable system.

Potential Limitations and Trade-offs

  • Performance at Scale: The paper notes that competitive performance was shown for models around the 120M parameter scale. Whether bilinear layers remain competitive at much larger scales (billions of parameters) compared to state-of-the-art alternatives like SwiGLU is an open empirical question. SwiGLU layers are closely related ($(W_1 x) \odot \sigma(W_2 x)$ or similar; a small sketch of the two forms follows this list), suggesting that bilinear structure may be part of effective large models, but the exact form matters.
  • Computational Cost: The tensor $B$ has dimensions (output_dim, input_dim, input_dim). For typical transformer dimensions (e.g., input_dim = 768, output_dim = 3072), $B$ would be $3072 \times 768 \times 768$, which is substantial (about $1.8 \times 10^9$ elements, roughly 7 GB in float32). Storing and analyzing this tensor requires significant memory and computation, although the paper suggests strategies to prioritize (e.g., focusing on large coefficients or dominant HOSVD components).
  • Interpretability is still Hard: While the mathematical structure is simpler, translating the analysis of the $B$ tensor or pairwise interactions into human-understandable explanations of model behavior remains a challenging task requiring further research and tool development. Identifying "interpretable bases" $D^I$ and $D^O$ (e.g., via ICA) is itself a difficult problem.
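
For reference, the sketch below (shapes and naming are my own, illustrative only) contrasts the pure bilinear form with a SiLU-gated variant in the SwiGLU family; only the pure bilinear form collapses exactly into the third-order tensor $B$ used above.

import torch
import torch.nn.functional as F

def bilinear_mlp(x, W1, W2):
    # Pure bilinear form: expressible exactly via the third-order tensor B.
    return (W1 @ x) * (W2 @ x)

def gated_mlp(x, W1, W2):
    # SwiGLU-style gated variant: the extra nonlinearity breaks the exact tensor form.
    return F.silu(W1 @ x) * (W2 @ x)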

Overall, the paper presents bilinear layers not just as a potentially performant architectural choice, but primarily as a research tool that makes the internal workings of neural networks more transparent to mathematical analysis, opening concrete avenues for mechanistic interpretability and safety research.