Training SAEs
Some empirical tips and tricks for training sparse autoencoders (SAEs).
Last Updated: 11/17/2025
The best references, in my experience, are Scaling and Evaluating Sparse Autoencoders by OpenAI and Gemma Scope by Google.
Background
I assume an encoder-decoder SAE with input vectors \(x \in \mathbb{R}^{d}\), a sparse representation \(f(x) \in \mathbb{R}^{n}\) and a reconstructed input \(\hat{x} \in \mathbb{R}^d\). The SAE contains a linear encoder \(W_\text{enc} \in \mathbb{R}^{n \times d}\) and \(b_\text{enc} \in \mathbb{R}^n\), a linear decoder \(W_\text{dec} \in \mathbb{R}^{d \times n}\) and \(b_\text{dec} \in \mathbb{R}^d\) and a nonlinear activation function \(a : \mathbb{R}^n \rightarrow \mathbb{R}^n\). The sparse representation is \(f(x) = a(W_\text{enc} \cdot (x - b_\text{dec}) + b_\text{enc})\) and the reconstructed input is \(\hat{x} = W_\text{dec} \cdot f(x) + b_\text{dec}\). Subtracting \(b_\text{dec}\) is called pre-encoder bias.
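As a concrete reference, here is a minimal PyTorch sketch of this architecture. This is a sketch only: the class name, argument names, and the placeholder Kaiming initialization are mine, and the activation \(a\) is passed in (e.g. ReLU or the TopK variants discussed below).
import torch
import torch.nn as nn

class SAE(nn.Module):
    def __init__(self, d: int, n: int, activation: nn.Module):
        super().__init__()
        # W_enc is (n, d) and W_dec is (d, n), matching the notation above.
        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(n, d)))
        self.b_enc = nn.Parameter(torch.zeros(n))
        self.W_dec = nn.Parameter(self.W_enc.data.T.clone())
        self.b_dec = nn.Parameter(torch.zeros(d))
        self.activation = activation

    def forward(self, x: torch.Tensor):
        # Pre-encoder bias: subtract b_dec before encoding.
        f_x = self.activation((x - self.b_dec) @ self.W_enc.T + self.b_enc)
        x_hat = f_x @ self.W_dec.T + self.b_dec
        return f_x, x_hat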
Initialization
Anthropic and Pierre Peigne suggest using data-point initialization (a code sketch of the full recipe follows at the end of this section).
- Select \(n\) random data points from your training data.
- Compute the mean \(\mu\) and zero-center the data: \(x_0 = x - \mu\).
- Linearly blend each zero-centered datapoint with Kaiming initialization: \(w = p \cdot x_0 + (1 - p) \cdot r\), where \(p\) is the blend weight and \(r\) is a randomly sampled Kaiming initialization vector.
- Initialize \(W_\text{enc}\) by stacking the \(n\) blended vectors as its rows.
- Initialize \(W_\text{dec}\) as \(W_\text{enc}^T\).
Anthropic suggests \(p = 0.8\) for SAEs and \(p = 0.4\) for “weakly causal crosscoders”. I interpret this to mean that there is no universally appropriate \(p\).
The intuition behind why this is an effective way to seed dictionaries is that model activations are not isotropic, so we initialize the parameters to be in the higher density region of model activations. This might lead to an initial boost in both sparsity and reconstruction. More importantly, this works empirically.
Gao et al. also suggest initializing \(W_\text{enc}\) to the transpose of \(W_\text{dec}\), which helps with dead latents.
Initialize \(b_\text{dec}\) to all zeros.
Anthropic recommends initializing \(b_\text{enc}\) to “a constant per feature such that each feature activates \(\frac{10K}{m}\) of the time.” This means that “in aggregate roughly 10,000 features will fire per datapoint” and they “think this initialization is important for avoiding dead features.” I just use Kaiming initialization.
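Putting the above together, here is a hedged sketch of data-point initialization, reusing the SAE sketch from the Background section. Here acts is an \((N, d)\) sample of model activations with \(N \geq n\), the blend weight defaults to \(p = 0.8\), and \(b_\text{enc}\) is left to whichever scheme you prefer.
import torch

@torch.no_grad()
def datapoint_init(sae: SAE, acts: torch.Tensor, p: float = 0.8):
    n, d = sae.W_enc.shape
    # Select n random data points and zero-center them.
    idx = torch.randperm(acts.shape[0])[:n]
    x0 = acts[idx] - acts.mean(dim=0)
    # Blend with a Kaiming-initialized random matrix.
    r = torch.nn.init.kaiming_uniform_(torch.empty(n, d))
    blended = p * x0 + (1 - p) * r
    # Encoder rows are the blended vectors; the decoder is the transpose.
    sae.W_enc.copy_(blended)
    sae.W_dec.copy_(blended.T)
    # b_dec starts at zero; b_enc is left to your preferred scheme.
    sae.b_dec.zero_()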
Activation
The choice of activation function \(a\) is a hot area. It seems that TopK (introduced by Gao et al.) and its variants are very strong and possess useful properties: \[f(x) = \text{TopK}(W_\text{enc} \cdot x + b_\text{enc})\] where TopK zeroes out all but the \(k\) largest activations in the input. This is nice: you get to pick your L\(_0\) sparsity directly by choosing \(k\) instead of tuning some \(\lambda\) hyperparameter.
There is an important variant, BatchTopK, which keeps the top \(k \times \text{bsz}\) activations across the entire batch. This enables more flexibility: “BatchTopK adaptively allocates more or fewer latents depending on the sample, improving reconstruction without sacrificing average sparsity”, but you must learn a threshold value to use at inference time, since each sample’s sparsity otherwise depends on the rest of the batch.
Here is some example code in PyTorch. Gradients flow through the surviving activations unchanged, so no custom backward pass is needed.
import torch
import torch.nn as nn

class TopK(nn.Module):
    def __init__(self, k: int):
        super().__init__()
        self.k = k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, d_sae = x.shape
        k = min(self.k, d_sae)
        # Keep only the k largest activations per sample; zero out the rest.
        _, idxs = torch.topk(x, k, dim=-1, sorted=False)
        mask = torch.zeros_like(x).scatter(-1, idxs, 1.0)
        return torch.mul(mask, x)

Here’s some code for BatchTopK, which includes the learned inference threshold buffer.
class BatchTopK(nn.Module):
    def __init__(self, k: int, momentum: float):
        super().__init__()
        self.k = k
        self.momentum = momentum
        self.register_buffer("threshold", torch.tensor(0.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # Fallback: if the threshold is still 0 (e.g. never trained), just do ReLU.
            if self.threshold <= 0:
                return torch.where(x > 0, x, torch.zeros_like(x))
            return torch.where(x > self.threshold, x, torch.zeros_like(x))

        bsz, d_sae = x.shape
        x_flat = x.flatten()
        # Keep the top k * bsz activations across the whole batch.
        k = min(self.k * bsz, d_sae * bsz)
        _, idxs = torch.topk(x_flat, k, sorted=False)
        mask = torch.zeros_like(x_flat).scatter(-1, idxs, 1.0).reshape(x.shape)
        x = torch.mul(mask, x)
        with torch.no_grad():
            # Smallest positive surviving activation in this batch (the effective threshold).
            pos = x[x > 0]
            if pos.numel() > 0:
                # EMA update, like BatchNorm.
                self.threshold.mul_(1 - self.momentum).add_(self.momentum * pos.min())
        return x

Objective
With TopK, the main objective is just reconstruction and is quite straightforward: \[\mathcal{L} = ||x - \hat{x}||_2^2\] You do not need to minimize L\(_1\) sparsity.
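In code, averaging over the batch (with x and x_hat as above):
# Mean squared reconstruction error, averaged over the batch.
loss = torch.sum((x - x_hat) ** 2, dim=-1).mean()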
However, you should probably use the Matryoshka objective, which learns much better features in my experience.
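I won’t reproduce the full Matryoshka method here, but the core idea is to reconstruct \(x\) from nested prefixes of the latent vector and sum the reconstruction losses. Here is a rough sketch under the notation above; prefix_sizes is an assumed hyperparameter giving the nested group boundaries, with the last entry equal to \(n\).
def matryoshka_loss(f_x, W_dec, b_dec, x, prefix_sizes):
    # Sum of reconstruction losses using only the first m latents, for each prefix size m.
    loss = 0.0
    for m in prefix_sizes:
        x_hat_m = f_x[:, :m] @ W_dec[:, :m].T + b_dec
        loss = loss + torch.sum((x - x_hat_m) ** 2, dim=-1).mean()
    return loss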
There are also many different auxiliary losses that try to minimize dead features and dense features. I haven’t used any of them, so I cannot speak to them.
Data
Scale your data so that the average L2 norm is \(\sqrt{d}\).
The goal of this change is for the same value of \(k\) to mean the same thing across datasets generated by different size transformers.
Don’t worry about subtracting the mean; subtracting \(b_\text{dec}\) lets the model learn that.
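A sketch of the rescaling: estimate the mean L2 norm once over a large sample of activations, then apply the same scalar to every batch (the function name is mine).
import torch

@torch.no_grad()
def sqrt_d_scale(sample: torch.Tensor) -> float:
    # sample: an (N, d) batch of activations. Returns the scalar s such that
    # the average L2 norm of s * sample is sqrt(d).
    d = sample.shape[-1]
    return (d ** 0.5) / sample.norm(dim=-1).mean().item()

# Then, for every batch of activations: x = x * scale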
Optimization
Basically everyone uses Adam (no weight decay). There is some discussion of Adam betas, but no clear consensus.
Anthropic linearly decays the learning rate to 0 over the last 20% of training; OpenAI didn’t use learning rate decay; Google used cosine learning rate warmup over the first 1K training steps. I use linear learning rate warmup and cosine decay to 0 for the rest of training.
Use gradient clipping. Anthropic recommends clipping gradient norms to 1; I use that without any issues.
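Here is a sketch of that optimization setup; sae, the learning rate, and the step counts are placeholders, not recommendations.
import math
import torch

optimizer = torch.optim.Adam(sae.parameters(), lr=3e-4)  # no weight decay

warmup_steps, total_steps = 1_000, 100_000  # placeholder values

def lr_lambda(step: int) -> float:
    # Linear warmup, then cosine decay to 0.
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# In the training loop, before optimizer.step():
# torch.nn.utils.clip_grad_norm_(sae.parameters(), max_norm=1.0)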
There are a couple of SAE-specific tricks that help with training:
Constraining \(W_\text{dec}\) to have unit norm columns. At every step, simply rescale \(W_\text{dec}\):
with torch.no_grad():
    # W_dec is (d, n) as above; normalize each column (feature direction) to unit norm.
    W_dec.data /= torch.norm(W_dec.data, dim=0, keepdim=True)

This makes the values of \(f(x)\) comparable across features. If a particular column \(i\) of \(W_\text{dec}\) had a much smaller norm, \(f_i(x)\) would have to be much larger to compensate.
Removing parallel gradients. Because we normalize the columns of \(W_\text{dec}\), we also remove the component of each column’s gradient that is parallel to the column itself, so that Adam doesn’t do anything funny with its moment estimates. This was originally described in Towards Monosemanticity.
with torch.no_grad():
    # Per-column dot product between the gradient and the decoder column.
    parallel = torch.sum(W_dec.grad * W_dec.data, dim=0, keepdim=True)
    norm_sq = torch.sum(W_dec.data * W_dec.data, dim=0, keepdim=True)
    scales = torch.zeros_like(parallel)
    nonzero = norm_sq > 0
    scales[nonzero] = parallel[nonzero] / norm_sq[nonzero]
    # Subtract the component of the gradient parallel to each column.
    W_dec.grad -= scales * W_dec.data
Hardware
This section is mostly taken from How to Scale Your Model.
While SAEs are smaller than the foundation model, training faster is always an advantage. Furthermore, SAEs have different compute-bandwidth tradeoffs compared to most neural network training, so it’s worth thinking about how to train them efficiently.
Sam Stevens, 2024