transformer-lens-interpretability

Provides guidance for mechanistic interpretability research using TransformerLens to inspect and manipulate transformer internals via HookPoints and activation caching. Use when reverse-engineering model algorithms, studying attention patterns, or performing activation patching experiments.

Tags: General, Mechanistic Interpretability, TransformerLens, Activation Patching, Circuit Analysis


# TransformerLens: Mechanistic Interpretability for Transformers

TransformerLens is the de facto standard library for mechanistic interpretability research on GPT-style language models. Created by Neel Nanda and maintained by Bryce Meyer, it provides clean interfaces to inspect and manipulate model internals via HookPoints on every activation.

**GitHub**: [TransformerLensOrg/TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) (2,900+ stars)

## When to Use TransformerLens

**Use TransformerLens when you need to:**
- Reverse-engineer algorithms learned during training
- Perform activation patching / causal tracing experiments
- Study attention patterns and information flow
- Analyze circuits (e.g., induction heads, IOI circuit)
- Cache and inspect intermediate activations
- Apply direct logit attribution

**Consider alternatives when:**
- You need to work with non-transformer architectures → Use **nnsight** or **pyvene**
- You want to train/analyze Sparse Autoencoders → Use **SAELens**
- You need remote execution on massive models → Use **nnsight** with NDIF
- You want higher-level causal intervention abstractions → Use **pyvene**

## Installation

```bash
pip install transformer-lens
```

For the development version:
```bash
pip install git+https://github.com/TransformerLensOrg/TransformerLens
```

## Core Concepts

### HookedTransformer

The main class that wraps transformer models with HookPoints on every activation:

```python
from transformer_lens import HookedTransformer

# Load a model
model = HookedTransformer.from_pretrained("gpt2-small")

# For gated models (LLaMA, Mistral)
import os
os.environ["HF_TOKEN"] = "your_token"
model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-hf")
```

### Supported Models (50+)

| Family | Models |
|--------|--------|
| GPT-2 | gpt2, gpt2-medium, gpt2-large, gpt2-xl |
| LLaMA | llama-7b, llama-13b, llama-2-7b, llama-2-13b |
| EleutherAI | pythia-70m to pythia-12b, gpt-neo, gpt-j-6b |
| Mistral | mistral-7b, mixtral-8x7b |
| Others | phi, qwen, opt, gemma |

### Activation Caching

Run the model and cache all intermediate activations:

```python
# Get all activations
tokens = model.to_tokens("The Eiffel Tower is in")
logits, cache = model.run_with_cache(tokens)

# Access specific activations
residual = cache["resid_post", 5]  # Layer 5 residual stream
attn_pattern = cache["pattern", 3]  # Layer 3 attention pattern
mlp_out = cache["mlp_out", 7]  # Layer 7 MLP output

# Filter which activations to cache (saves memory)
logits, cache = model.run_with_cache(
    tokens,
    names_filter=lambda name: "resid_post" in name
)
```
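
To make the caching mechanism concrete, the sketch below mimics it with plain PyTorch forward hooks on a toy network. It is not how TransformerLens is implemented internally: `TinyNet`, `run_with_cache_sketch`, and the module names are hypothetical and exist only for illustration. The point is the pattern of attaching hooks, running once, storing outputs by name, and removing the hooks afterwards.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical two-layer MLP standing in for a transformer."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


def run_with_cache_sketch(model, x, names_filter=lambda name: True):
    """Run the model once and return (output, {module_name: activation})."""
    cache, handles = {}, []
    for name, module in model.named_modules():
        if name and names_filter(name):  # skip the root module (empty name)
            def save(mod, inputs, output, name=name):
                cache[name] = output.detach()
            handles.append(module.register_forward_hook(save))
    try:
        with torch.no_grad():
            out = model(x)
    finally:
        for handle in handles:  # always remove hooks so they do not leak
            handle.remove()
    return out, cache


torch.manual_seed(0)
net = TinyNet()
x = torch.randn(3, 4)
out, cache = run_with_cache_sketch(net, x)
_, fc_only = run_with_cache_sketch(net, x, names_filter=lambda n: n.startswith("fc"))
print(sorted(cache), sorted(fc_only))
```

As with `names_filter` in `run_with_cache`, filtering by name keeps only the activations you need, which matters once the cached tensors are large.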

### ActivationCache Keys

| Key Pattern | Shape | Description |
|-------------|-------|-------------|
| `resid_pre, layer` | [batch, pos, d_model] | Residual before attention |
| `resid_mid, layer` | [batch, pos, d_model] | Residual after attention |
| `resid_post, layer` | [batch, pos, d_model] | Residual after MLP |
| `attn_out, layer` | [batch, pos, d_model] | Attention output |
| `mlp_out, layer` | [batch, pos, d_model] | MLP output |
| `pattern, layer` | [batch, head, q_pos, k_pos] | Attention pattern (post-softmax) |
| `q, layer` | [batch, pos, head, d_head] | Query vectors |
| `k, layer` | [batch, pos, head, d_head] | Key vectors |
| `v, layer` | [batch, pos, head, d_head] | Value vectors |
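
The short keys in the table expand to full hook names such as `blocks.5.hook_resid_post`, which is the form `run_with_hooks` expects. The helper below is a hypothetical re-implementation of that naming convention for the keys listed above, written out only to show the pattern. In practice TransformerLens ships `transformer_lens.utils.get_act_name` for this, and the library's function should be treated as authoritative (this sketch does not cover MLP-internal or embedding activations).

```python
# Activations that live inside the attention sub-module of a block.
ATTN_ACTIVATIONS = {"q", "k", "v", "z", "pattern", "attn_scores"}


def full_hook_name(short_name: str, layer: int) -> str:
    """Hypothetical helper: expand ("pattern", 3) to "blocks.3.attn.hook_pattern"."""
    if short_name in ATTN_ACTIVATIONS:
        return f"blocks.{layer}.attn.hook_{short_name}"
    return f"blocks.{layer}.hook_{short_name}"


print(full_hook_name("resid_post", 5))  # blocks.5.hook_resid_post
print(full_hook_name("pattern", 3))     # blocks.3.attn.hook_pattern
```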

## Workflow 1: Activation Patching (Causal Tracing)

Identify which activations causally affect model output by patching clean activations into corrupted runs.

### Step-by-Step

```python
from transformer_lens import HookedTransformer, patching
import torch

model = HookedTransformer.from_pretrained("gpt2-small")

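# 1. Define clean and corrupted prompts
# Position-wise patching only makes sense if both prompts tokenize to the same length
clean_prompt = "The Eiffel Tower is in the city of"
corrupted_prompt = "The Colosseum is in the city of"

clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)
assert clean_tokens.shape == corrupted_tokens.shape, "Prompts must have equal token length"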

# 2. Get clean activations
_, clean_cache = model.run_with_cache(clean_tokens)

# 3. Define metric (e.g., logit difference)
paris_token = model.to_single_token(" Paris")
rome_token = model.to_single_token(" Rome")

def metric(logits):
    return logits[0, -1, paris_token] - logits[0, -1, rome_token]

# 4. Patch each position and layer
results = torch.zeros(model.cfg.n_layers, clean_tokens.shape[1])

for layer in range(model.cfg.n_layers):
    for pos in range(clean_tokens.shape[1]):
        def patch_hook(activation, hook):
            activation[0, pos] = clean_cache[hook.name][0, pos]
            return activation

        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(f"blocks.{layer}.hook_resid_post", patch_hook)]
        )
        # .item() stores a plain float so autograd graphs are not kept alive
        results[layer, pos] = metric(patched_logits).item()

# 5. Visualize results (layer x position heatmap)
```
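
Raw logit differences are hard to compare across prompts, so patching results are often rescaled so that the corrupted run maps to 0 and the clean run maps to 1. The following is a minimal sketch of that rescaling; `normalize_patching_metric` is a hypothetical helper name, and the numbers are made up for illustration rather than measured from any model.

```python
def normalize_patching_metric(patched: float, clean_baseline: float,
                              corrupted_baseline: float) -> float:
    """Rescale so 0.0 = corrupted-run behavior and 1.0 = clean-run behavior."""
    denom = clean_baseline - corrupted_baseline
    if denom == 0:
        raise ValueError("Clean and corrupted baselines are identical")
    return (patched - corrupted_baseline) / denom


# Illustrative numbers only, not measured from any model.
clean_baseline = 3.0       # metric on the clean prompt
corrupted_baseline = -1.0  # metric on the corrupted prompt
print(normalize_patching_metric(1.0, clean_baseline, corrupted_baseline))  # 0.5
```

A value near 1 at some (layer, position) then reads as "patching this activation alone restores most of the clean behavior".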

### Checklist
- [ ] Define clean and corrupted inputs that differ minimally
- [ ] Choose metric that captures behavior difference
- [ ] Cache clean activations
- [ ] Systematically patch each (layer, position) combination
- [ ] Visualize results as heatmap
- [ ] Identify causal hotspots
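
For the last checklist item, one simple way to pull hotspots out of a layer-by-position grid is to take the top-k entries and convert the flat indices back to coordinates. This is a sketch under the assumption that the grid is a 2D tensor like `results` above; `top_hotspots` is a hypothetical helper and the demo grid is made up.

```python
import torch


def top_hotspots(results: torch.Tensor, k: int = 3):
    """Return [(layer, pos, value), ...] for the k largest entries of a 2D grid."""
    n_pos = results.shape[1]
    top = torch.topk(results.flatten(), k=k)
    return [(idx // n_pos, idx % n_pos, val)
            for idx, val in zip(top.indices.tolist(), top.values.tolist())]


# Made-up patching grid: 3 layers x 4 positions.
demo = torch.tensor([[0.0, 0.1, 0.0, 0.2],
                     [0.0, 0.9, 0.0, 0.3],
                     [0.0, 0.0, 0.0, 0.8]])
hotspots = top_hotspots(demo, k=2)
for layer, pos, value in hotspots:
    print(f"layer {layer}, position {pos}: {value:.2f}")
```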

## Workflow 2: Circuit Analysis (Indirect Object Identification)

Replicate the IOI circuit discovery from "Interpretability in the Wild".

### Step-by-Step

```python
from transformer_lens import HookedTransformer
import torch

model = HookedTransformer.from_pretrained("gpt2-small")

# IOI task: "When John and Mary went to the store, Mary gave a bottle to"
# Model should predict "John" (indirect object)

prompt = "When John and Mary went to the store, Mary gave a bottle to"
tokens = model.to_tokens(prompt)

# 1. Get baseline logits
logits, cache = model.run_with_cache(tokens)

john_token = model.to_single_token(" John")
mary_token = model.to_single_token(" Mary")

# 2. Compute logit difference (IO - S)
logit_diff = logits[0, -1, john_token] - logits[0, -1, mary_token]
print(f"Logit difference: {logit_diff.item():.3f}")

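# 3. Direct logit attribution by head
# Note: this simplified version skips the final LayerNorm scaling,
# so the values approximate each head's contribution to the logits
def get_head_contribution(layer, head):
    # Project head output to logits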
    head_out = cache["z", layer][0, :, head, :]  # [pos, d_head]
    W_O = model.W_O[layer, head]  # [d_head, d_model]
    W_U = model.W_U  # [d_model, vocab]

    # Head contribution to logits at final position
    contribution = head_out[-1] @ W_O @ W_U
    return contribution[john_token] - contribution[mary_token]

# 4. Map all heads
head_contributions = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        head_contributions[layer, head] = get_head_contribution(layer, head)

# 5. Identify top contributing heads (name movers, backup name movers)
```
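
Because the attribution is linear, projecting onto the full vocabulary and then subtracting two entries gives the same number as projecting onto the single direction `W_U[:, io] - W_U[:, s]`. The sketch below checks this on random tensors; the sizes and token ids are arbitrary stand-ins rather than values from GPT-2, and it ignores the final LayerNorm just like any unnormalized direct logit attribution.

```python
import torch

torch.manual_seed(0)
d_head, d_model, d_vocab = 4, 16, 50  # arbitrary toy sizes
dtype = torch.float64                 # double precision keeps the comparison tight

head_out = torch.randn(d_head, dtype=dtype)       # one head's output at the final position
W_O = torch.randn(d_head, d_model, dtype=dtype)   # that head's output projection
W_U = torch.randn(d_model, d_vocab, dtype=dtype)  # unembedding matrix
io_token, s_token = 7, 11                         # hypothetical token ids

# Route 1: project to every vocabulary entry, then take the difference.
full_logits = head_out @ W_O @ W_U
diff_full = full_logits[io_token] - full_logits[s_token]

# Route 2: project onto the single "IO minus S" direction in residual space.
logit_diff_direction = W_U[:, io_token] - W_U[:, s_token]  # [d_model]
diff_direct = (head_out @ W_O) @ logit_diff_direction

print(torch.allclose(diff_full, diff_direct))
```

The second route avoids materializing a vocabulary-sized vector per head, which adds up when looping over every layer and head.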

### Checklist
- [ ] Set up task with clear IO/S tokens
- [ ] Compute baseline logit difference
- [ ] Decompose by attention head contributions
- [ ] Identify key circuit components (name movers, S-inhibition, induction)
- [ ] Validate with ablation experiments
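
For the ablation step, a hook that zeroes one head's output is the simplest starting point. The function below is a sketch: `zero_ablate_head` is a hypothetical name, it assumes the `[batch, pos, head, d_head]` layout of the `z` activation, and it is exercised here on a random tensor rather than a real model.

```python
from functools import partial

import torch


def zero_ablate_head(z: torch.Tensor, hook=None, head: int = 0) -> torch.Tensor:
    """Zero one head's output. `z` has shape [batch, pos, head, d_head]."""
    z = z.clone()  # avoid mutating the caller's tensor
    z[:, :, head, :] = 0.0
    return z


# Stand-alone check on a random tensor (toy sizes: 12 heads, d_head 64).
z = torch.randn(1, 5, 12, 64)
ablated = partial(zero_ablate_head, head=9)(z, hook=None)
print(ablated[:, :, 9].abs().sum().item())
```

With TransformerLens, the same function could be attached through `model.run_with_hooks(tokens, fwd_hooks=[(f"blocks.{layer}.attn.hook_z", partial(zero_ablate_head, head=head))])`, and the resulting logit difference compared against the baseline. Zero ablation is the bluntest option; mean ablation over a reference distribution is a common alternative.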

## Workflow 3: Induction Head Detection

Find induction heads that implement the [A][B]...[A] → [B] pattern.

```python
from transformer_lens import HookedTransformer
import torch

model = HookedTransformer.from_pretrained("gpt2-small")

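# Repeat a random token sequence twice: [BOS] x_1..x_n x_1..x_n.
# A bare [A][B][A] input cannot separate induction heads from previous-token
# heads, because the induction target is also the previous position.
seq_len = 50
torch.manual_seed(0)
rand_tokens = torch.randint(1000, 10000, (1, seq_len))  # Arbitrary tokens
bos = torch.tensor([[model.tokenizer.bos_token_id]])
repeated_tokens = torch.cat([bos, rand_tokens, rand_tokens], dim=1)

_, cache = model.run_with_cache(repeated_tokens)

# In the second copy, an induction head at query position q attends to the
# token after the earlier occurrence, i.e. key position q - (seq_len - 1)
induction_scores = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)

for layer in range(model.cfg.n_layers):
    pattern = cache["pattern", layer][0]  # [head, q_pos, k_pos]
    stripe = pattern.diagonal(offset=-(seq_len - 1), dim1=-2, dim2=-1)
    # Drop the first two entries, whose queries lie in the first copy
    induction_scores[layer] = stripe[:, 2:].mean(dim=-1)

# Heads with high scores are candidate induction heads
top = torch.topk(induction_scores.flatten(), k=5)
top_heads = [divmod(i, model.cfg.n_heads) for i in top.indices.tolist()]  # (layer, head)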
```
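
One way to sanity-check an induction score is to apply it to synthetic attention patterns where the answer is known. The sketch below assumes an input of the form `[BOS] + seq + seq` and builds two hand-made heads: an ideal induction head and a previous-token head. `induction_score` is a hypothetical helper, and the patterns are constructed by hand, not taken from a model.

```python
import torch


def induction_score(pattern: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Mean attention from each second-copy token to the token that followed
    its first occurrence. `pattern` is [head, q_pos, k_pos] for the input
    [BOS] + seq + seq, where seq has `seq_len` tokens."""
    stripe = pattern.diagonal(offset=-(seq_len - 1), dim1=-2, dim2=-1)
    return stripe[:, 2:].mean(dim=-1)  # keep only queries in the second copy


seq_len = 4
total = 1 + 2 * seq_len
pattern = torch.zeros(2, total, total)
for q in range(total):
    in_second_copy = q > seq_len
    # Head 0: ideal induction head (rests on BOS until the repeat starts).
    induction_target = q - (seq_len - 1) if in_second_copy else 0
    pattern[0, q, induction_target] = 1.0
    # Head 1: previous-token head.
    pattern[1, q, max(q - 1, 0)] = 1.0

scores = induction_score(pattern, seq_len)
print(scores.tolist())  # [1.0, 0.0]
```

With `seq_len = 2` the induction target `q - 1` coincides with the previous token, so the two heads would score identically; longer repeated sequences avoid that confound.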

## Common Issues & Solutions

### Issue: Hooks persist after debugging
```python
# WRONG: Old hooks remain active
model.run_with_hooks(tokens, fwd_hooks=[...])  # Deb
```
