Module saev
===========

saev is a Python package for training sparse autoencoders (SAEs) on vision transformers (ViTs) in PyTorch.

The main entrypoint to the package is in `__main__`; use `python -m saev --help` to see the options and documentation for the script.

# Guide to Training SAEs on Vision Models

1. Record ViT activations and save them to disk.
2. Train SAEs on the activations.
3. Visualize the learned features from the trained SAEs.
4. (your job) Propose trends and patterns in the visualized features.
5. (your job, supported by code) Construct datasets to test your hypothesized trends.
6. Confirm/reject hypotheses using the `probing` package.

`saev` helps with steps 1, 2 and 3.

.. note:: `saev` assumes you are running on NVIDIA GPUs. On a multi-GPU system, prefix your commands with `CUDA_VISIBLE_DEVICES=X` to run on GPU X.

## Record ViT Activations to Disk

To save activations to disk, we need to specify:

1. Which model we would like to use.
2. Which layers we would like to save.
3. Where on disk and how we would like to save activations.
4. Which images we want to save activations for.

The `saev.activations` module does all of this for us. Run `uv run python -m saev activations --help` to see all the configuration options. In practice, you might run:

```sh
uv run python -m saev activations \
  --model-family clip \
  --model-ckpt ViT-B-32/openai \
  --d-vit 768 \
  --n-patches-per-img 49 \
  --layers -2 \
  --dump-to /local/scratch/$USER/cache/saev \
  --n-patches-per-shard 2_400_000 \
  data:imagenet-dataset
```

This will save activations for the CLIP-pretrained model ViT-B/32, which has a residual stream dimension of 768 and 49 patches per image (224 / 32 = 7; 7 x 7 = 49). It will save the second-to-last layer (`--layers -2`), write 2.4M patches per shard, and save shards to a new directory under `/local/scratch/$USER/cache/saev`.

.. note:: A note on storage space: a ViT-B/16 will save 1.2M images x 197 activations/layer/image (196 patches plus the [CLS] token) x 1 layer = ~240M activations, each of which takes up 768 floats x 4 bytes/float = 3072 bytes, for a **total of ~723GB** for the entire dataset. As you scale to larger models (ViT-L has 1024 dimensions and, at 224px with 14x14-pixel patches, 256 patches plus the [CLS] token per layer per image), recorded activations will grow even larger.

This script will also save a `metadata.json` file that records the relevant metadata for these activations, which will be read by future steps. The activations will be in `.bin` files, numbered starting from 000000.

To add your own models, see the guide to extending in `saev.activations`.
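Before training anything, it is worth a quick check that the dump is readable. Below is a minimal sketch using `saev.config.DataLoad` and `saev.activations.Dataset` (both documented later in this reference). The shard directory is the content-hash folder created under `--dump-to`; the hash shown is just this guide's running example, and `me` is a placeholder username.

```python
import saev.activations
import saev.config

cfg = saev.config.DataLoad(
    # Replace with the hash directory your own activations job created.
    shard_root="/local/scratch/me/cache/saev/ac89246f1934b45e2f0487298aebe36ad998b6bd252d880c0c9ec5de78d793c8",
    layer=-2,
    patches="patches",
    scale_mean=False,
    scale_norm=False,
)
dataset = saev.activations.Dataset(cfg)
print(len(dataset), dataset.d_vit, dataset.metadata.n_imgs)
```

If the dump succeeded, `len(dataset)` should be on the order of `n_imgs x n_patches_per_img` for patch-level activations from a single layer, and `dataset.d_vit` should match the `--d-vit` you recorded.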
## Train SAEs on Activations

To train an SAE, we need to specify:

1. Which activations to use as input.
2. SAE architectural stuff.
3. Optimization-related stuff.

The `saev.training` module handles this. Run `uv run python -m saev train --help` to see all the configuration options.

Continuing on from our example before, you might want to run something like:

```sh
uv run python -m saev train \
  --data.shard-root /local/scratch/$USER/cache/saev/ac89246f1934b45e2f0487298aebe36ad998b6bd252d880c0c9ec5de78d793c8 \
  --data.layer -2 \
  --data.patches patches \
  --data.no-scale-mean \
  --data.no-scale-norm \
  --sae.d-vit 768 \
  --lr 5e-4
```

The `--data.*` flags describe which activations to use. `--data.shard-root` should point to a directory with `*.bin` files and the `metadata.json` file. `--data.layer` specifies the layer, and `--data.patches` says that we want to train on individual patch activations, rather than the [CLS] token activation. `--data.no-scale-mean` and `--data.no-scale-norm` mean we do not scale the activation mean or L2 norm. Anthropic's and OpenAI's papers suggest normalizing these factors, but `saev` still has a bug with this, so I suggest not scaling them.

The `--sae.*` flags are about the SAE itself. `--sae.d-vit` is the only one you need to change; the dimension of our ViT was 768 for a ViT-B, rather than the default of 1024 for a ViT-L.

Finally, choose a slightly larger learning rate than the default with `--lr 5e-4`.

This will train one (1) sparse autoencoder on the data. See the section on sweeps to learn how to train multiple SAEs in parallel using only a single GPU.

## Visualize the Learned Features

Now that you've trained an SAE, you probably want to look at its learned features. One way to visualize an individual learned feature \(f\) is by picking out images that maximize the activation of feature \(f\). Since we train SAEs on patch-level activations, we try to find the top *patches* for each feature \(f\). Then, we pick out the images those patches correspond to and create a heatmap based on SAE activation values.

.. note:: More advanced forms of visualization are possible (and valuable!), but should not be included in `saev` unless they can be applied to every SAE/dataset combination. If you have specific visualizations, please add them to `contrib/` or another location.

`saev.visuals` records these maximally activating images for us. You can see all the options with `uv run python -m saev visuals --help`. So you might run:

```sh
uv run python -m saev visuals \
  --ckpt checkpoints/abcdefg/sae.pt \
  --dump-to /nfs/$USER/saev/webapp/abcdefg \
  --data.shard-root /local/scratch/$USER/cache/saev/ac89246f1934b45e2f0487298aebe36ad998b6bd252d880c0c9ec5de78d793c8 \
  --data.layer -2 \
  --data.patches patches \
  --data.no-scale-mean \
  --data.no-scale-norm \
  images:imagenet-dataset
```

This will record the top 128 patches, and then save the unique images among those top 128 patches for each feature in the trained SAE. It will cache these best activations to disk, then start saving images to visualize later on.

`saev.interactive.features` is a small web application based on [marimo](https://marimo.io/) to interactively look at these images. You can run it with `uv run marimo edit saev/interactive/features.py`.

## Sweeps

.. todo:: Explain how to run grid sweeps.

## Training Metrics and Visualizations

.. todo:: Explain how to use the `saev.interactive.metrics` notebook.

# Related Work

Various papers and internet posts on training SAEs for vision.

## Preprints

[An X-Ray Is Worth 15 Features: Sparse Autoencoders for Interpretable Radiology Report Generation](https://arxiv.org/pdf/2410.03334)

* Haven't read this yet, but Hugo Fry is an author.

## LessWrong

[Towards Multimodal Interpretability: Learning Sparse Interpretable Features in Vision Transformers](https://www.lesswrong.com/posts/bCtbuWraqYTDtuARg/towards-multimodal-interpretability-learning-sparse-2)

* Trains a sparse autoencoder on the 22nd layer of a CLIP ViT-L/14. First public work training an SAE on a ViT. Finds interesting features, demonstrating that SAEs work with ViTs.

[Interpreting and Steering Features in Images](https://www.lesswrong.com/posts/Quqekpvx8BGMMcaem/interpreting-and-steering-features-in-images)

* Haven't read it yet.
[Case Study: Interpreting, Manipulating, and Controlling CLIP With Sparse Autoencoders](https://www.lesswrong.com/posts/iYFuZo9BMvr6GgMs5/case-study-interpreting-manipulating-and-controlling-clip) * Followup to the above work; haven't read it yet. [A Suite of Vision Sparse Autoencoders](https://www.lesswrong.com/posts/wrznNDMRmbQABAEMH/a-suite-of-vision-sparse-autoencoders) * Train a sparse autoencoder on various layers using the TopK with k=32 on a CLIP ViT-L/14 trained on LAION-2B. The SAE is trained on 1.2B tokens including patch (not just [CLS]). Limited evaluation. Sub-modules ----------- * saev.activations * saev.config * saev.helpers * saev.imaging * saev.interactive * saev.nn * saev.test_activations * saev.test_config * saev.test_nn * saev.test_training * saev.test_visuals * saev.training * saev.visuals Module saev.activations ======================= To save lots of activations, we want to do things in parallel, with lots of slurm jobs, and save multiple files, rather than just one. This module handles that additional complexity. Conceptually, activations are either thought of as 1. A single [n_imgs x n_layers x (n_patches + 1), d_vit] tensor. This is a *dataset* 2. Multiple [n_imgs_per_shard, n_layers, (n_patches + 1), d_vit] tensors. This is a set of sharded activations. Functions --------- `get_acts_dir(cfg: saev.config.Activations) ‑> str` : Return the activations directory based on the relevant values of a config. Also saves a metadata.json file to that directory for human reference. Args: cfg: Config for experiment. Returns: Directory to where activations should be dumped/loaded from. `get_dataloader(cfg: saev.config.Activations, *, img_transform=None)` : Gets the dataloader for the current experiment; delegates dataloader construction to dataset-specific functions. Args: cfg: Experiment config. img_transform: Image transform to be applied to each image. Returns: A PyTorch Dataloader that yields dictionaries with `'image'` keys containing image batches. `get_dataset(cfg: saev.config.ImagenetDataset | saev.config.ImageFolderDataset | saev.config.Ade20kDataset, *, img_transform)` : Gets the dataset for the current experiment; delegates construction to dataset-specific functions. Args: cfg: Experiment config. img_transform: Image transform to be applied to each image. Returns: A dataset that has dictionaries with `'image'`, `'index'`, `'target'`, and `'label'` keys containing examples. `get_default_dataloader(cfg: saev.config.Activations, *, img_transform: ) ‑> torch.utils.data.dataloader.DataLoader` : Get a dataloader for a default map-style dataset. Args: cfg: Config. img_transform: Image transform to be applied to each image. Returns: A PyTorch Dataloader that yields dictionaries with `'image'` keys containing image batches, `'index'` keys containing original dataset indices and `'label'` keys containing label batches. `main(cfg: saev.config.Activations)` : Args: cfg: Config for activations. `make_img_transform(model_family: str, model_ckpt: str) ‑> ` : `make_vit(cfg: saev.config.Activations)` : `setup(cfg: saev.config.Activations)` : Run dataset-specific setup. These setup functions can assume they are the only job running, but they should be idempotent; they should be safe (and ideally cheap) to run multiple times in a row. `setup_ade20k(cfg: saev.config.Activations)` : `setup_imagefolder(cfg: saev.config.Activations)` : `setup_imagenet(cfg: saev.config.Activations)` : `worker_fn(cfg: saev.config.Activations)` : Args: cfg: Config for activations. 
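The functions above can also be driven from Python rather than the CLI. A minimal sketch, roughly equivalent to the `python -m saev activations ...` example in the guide; it assumes the defaults documented in `saev.config.Activations` (notably the ImageNet-1K data config), and the `dump_to` path is a placeholder. Whether the job runs inline or is submitted to a Slurm cluster is controlled by the `slurm` option.

```python
import saev.activations
import saev.config

cfg = saev.config.Activations(
    model_family="clip",
    model_ckpt="ViT-B-32/openai",
    d_vit=768,
    n_patches_per_img=49,
    layers=[-2],
    dump_to="/local/scratch/me/cache/saev",  # placeholder path
    n_patches_per_shard=2_400_000,
)
saev.activations.main(cfg)
```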
Classes ------- `Ade20k(cfg: saev.config.Ade20kDataset, *, img_transform: collections.abc.Callable | None = None, seg_transform: collections.abc.Callable | None = >)` : An abstract class representing a :class:`Dataset`. All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:`__len__`, which is expected to return the size of the dataset by many :class:`~torch.utils.data.Sampler` implementations and the default options of :class:`~torch.utils.data.DataLoader`. Subclasses could also optionally implement :meth:`__getitems__`, for speedup batched samples loading. This method accepts list of indices of samples of batch and returns list of samples. .. note:: :class:`~torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided. ### Ancestors (in MRO) * torch.utils.data.dataset.Dataset * typing.Generic ### Class variables `Sample` : `samples: list[saev.activations.Ade20k.Sample]` : `Clip(cfg: saev.config.Activations)` : Base class for all neural network modules. Your models should also subclass this class. Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes:: import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 20, 5) def forward(self, x): x = F.relu(self.conv1(x)) return F.relu(self.conv2(x)) Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:`to`, etc. .. note:: As per the example above, an ``__init__()`` call to the parent class must be made before assignment on the child. :ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool Initialize internal Module state, shared by both nn.Module and ScriptModule. ### Ancestors (in MRO) * torch.nn.modules.module.Module ### Methods `forward(self, batch: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> Callable[..., Any]` : Define the computation performed at every call. Should be overridden by all subclasses. .. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:`Module` instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them. `Dataset(cfg: saev.config.DataLoad)` : Dataset of activations from disk. ### Ancestors (in MRO) * torch.utils.data.dataset.Dataset * typing.Generic ### Class variables `Example` : Individual example. `act_mean: jaxtyping.Float[Tensor, 'd_vit']` : Mean activation. `cfg: saev.config.DataLoad` : Configuration; set via CLI args. `layer_index: int` : Layer index into the shards if we are choosing a specific layer. `metadata: saev.activations.Metadata` : Activations metadata; automatically loaded from disk. `scalar: float` : Normalizing scalar such that ||x / scalar ||_2 ~= sqrt(d_vit). ### Instance variables `d_vit: int` : Dimension of the underlying vision transformer's embedding space. 
### Methods `get_img_patches(self, i: int) ‑> jaxtyping.Float[ndarray, 'n_layers all_patches d_vit']` : `get_shard_patches(self)` : `transform(self, act: jaxtyping.Float[ndarray, 'd_vit']) ‑> jaxtyping.Float[Tensor, 'd_vit']` : Apply a scalar normalization so the mean squared L2 norm is same as d_vit. This is from 'Scaling Monosemanticity': > As a preprocessing step we apply a scalar normalization to the model activations so their average squared L2 norm is the residual stream dimension So we divide by self.scalar which is the datasets (approximate) L2 mean before normalization divided by sqrt(d_vit). `DinoV2(cfg: saev.config.Activations)` : Base class for all neural network modules. Your models should also subclass this class. Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes:: import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 20, 5) def forward(self, x): x = F.relu(self.conv1(x)) return F.relu(self.conv2(x)) Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:`to`, etc. .. note:: As per the example above, an ``__init__()`` call to the parent class must be made before assignment on the child. :ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool Initialize internal Module state, shared by both nn.Module and ScriptModule. ### Ancestors (in MRO) * torch.nn.modules.module.Module ### Methods `forward(self, batch: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> Callable[..., Any]` : Define the computation performed at every call. Should be overridden by all subclasses. .. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:`Module` instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them. `ImageFolder(root: str | pathlib.Path, transform: Callable | None = None, target_transform: Callable | None = None, loader: Callable[[str], Any] = , is_valid_file: Callable[[str], bool] | None = None, allow_empty: bool = False)` : A generic data loader where the images are arranged in this way by default: :: root/dog/xxx.png root/dog/xxy.png root/dog/[...]/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/[...]/asd932_.png This class inherits from :class:`~torchvision.datasets.DatasetFolder` so the same methods can be overridden to customize the dataset. Args: root (str or ``pathlib.Path``): Root directory path. transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. loader (callable, optional): A function to load an image given its path. is_valid_file (callable, optional): A function that takes path of an Image file and check if the file is a valid file (used to check of corrupt files) allow_empty(bool, optional): If True, empty folders are considered to be valid classes. An error is raised on empty folders if False (default). Attributes: classes (list): List of the class names sorted alphabetically. class_to_idx (dict): Dict with items (class_name, class_index). 
imgs (list): List of (image path, class_index) tuples ### Ancestors (in MRO) * torchvision.datasets.folder.ImageFolder * torchvision.datasets.folder.DatasetFolder * torchvision.datasets.vision.VisionDataset * torch.utils.data.dataset.Dataset * typing.Generic `Imagenet(cfg: saev.config.ImagenetDataset, *, img_transform=None)` : An abstract class representing a :class:`Dataset`. All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:`__len__`, which is expected to return the size of the dataset by many :class:`~torch.utils.data.Sampler` implementations and the default options of :class:`~torch.utils.data.DataLoader`. Subclasses could also optionally implement :meth:`__getitems__`, for speedup batched samples loading. This method accepts list of indices of samples of batch and returns list of samples. .. note:: :class:`~torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided. ### Ancestors (in MRO) * torch.utils.data.dataset.Dataset * typing.Generic `MaskedAutoencoder(cfg: saev.config.Activations)` : Base class for all neural network modules. Your models should also subclass this class. Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes:: import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 20, 5) def forward(self, x): x = F.relu(self.conv1(x)) return F.relu(self.conv2(x)) Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:`to`, etc. .. note:: As per the example above, an ``__init__()`` call to the parent class must be made before assignment on the child. :ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool Initialize internal Module state, shared by both nn.Module and ScriptModule. ### Ancestors (in MRO) * torch.nn.modules.module.Module ### Methods `forward(self, batch: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> Callable[..., Any]` : Define the computation performed at every call. Should be overridden by all subclasses. .. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:`Module` instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them. 
`Metadata(model_family: str, model_ckpt: str, layers: tuple[int, ...], n_patches_per_img: int, cls_token: bool, d_vit: int, seed: int, n_imgs: int, n_patches_per_shard: int, data: str)` : Metadata(model_family: str, model_ckpt: str, layers: tuple[int, ...], n_patches_per_img: int, cls_token: bool, d_vit: int, seed: int, n_imgs: int, n_patches_per_shard: int, data: str) ### Class variables `cls_token: bool` : `d_vit: int` : `data: str` : `layers: tuple[int, ...]` : `model_ckpt: str` : `model_family: str` : `n_imgs: int` : `n_patches_per_img: int` : `n_patches_per_shard: int` : `seed: int` : ### Static methods `from_cfg(cls, cfg: saev.config.Activations) ‑> saev.activations.Metadata` : `load(cls, fpath) ‑> saev.activations.Metadata` : ### Instance variables `hash: str` : ### Methods `dump(self, fpath)` : `ShardWriter(cfg: saev.config.Activations)` : ShardWriter is a stateful object that handles sharded activation writing to disk. ### Class variables `acts: jaxtyping.Float[ndarray, 'n_imgs_per_shard n_layers all_patches d_vit'] | None` : `acts_path: str` : `filled: int` : `root: str` : `shape: tuple[int, int, int, int]` : `shard: int` : ### Methods `flush(self) ‑> None` : `next_shard(self) ‑> None` : `Siglip(cfg: saev.config.Activations)` : Base class for all neural network modules. Your models should also subclass this class. Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes:: import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 20, 5) def forward(self, x): x = F.relu(self.conv1(x)) return F.relu(self.conv2(x)) Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:`to`, etc. .. note:: As per the example above, an ``__init__()`` call to the parent class must be made before assignment on the child. :ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool Initialize internal Module state, shared by both nn.Module and ScriptModule. ### Ancestors (in MRO) * torch.nn.modules.module.Module ### Methods `forward(self, batch: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> Callable[..., Any]` : Define the computation performed at every call. Should be overridden by all subclasses. .. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:`Module` instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them. `VitRecorder(cfg: saev.config.Activations, patches: slice = slice(None, None, None))` : Base class for all neural network modules. Your models should also subclass this class. Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes:: import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 20, 5) def forward(self, x): x = F.relu(self.conv1(x)) return F.relu(self.conv2(x)) Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:`to`, etc. .. note:: As per the example above, an ``__init__()`` call to the parent class must be made before assignment on the child. 
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool Initialize internal Module state, shared by both nn.Module and ScriptModule. ### Ancestors (in MRO) * torch.nn.modules.module.Module ### Class variables `cfg: saev.config.Activations` : ### Instance variables `activations: jaxtyping.Float[Tensor, 'batch n_layers all_patches dim']` : ### Methods `hook(self, module, args: tuple, output: jaxtyping.Float[Tensor, 'batch n_layers dim']) ‑> None` : `register(self, modules: list[torch.nn.modules.module.Module])` : `reset(self)` : Module saev.config ================== All configs for all saev jobs. ## Import Times This module should be very fast to import so that `python main.py --help` is fast. This means that the top-level imports should not include big packages like numpy, torch, etc. For example, `TreeOfLife.n_imgs` imports numpy when it's needed, rather than importing it at the top level. Also contains code for expanding configs with lists into lists of configs (grid search). Might be expanded in the future to support pseudo-random sampling from distributions to support random hyperparameter search, as in [this file](https://github.com/samuelstevens/sax/blob/main/sax/sweep.py). Functions --------- `expand(config: dict[str, object]) ‑> Iterator[dict[str, object]]` : Expands dicts with (nested) lists into a list of (nested) dicts. `grid(cfg: saev.config.Train, sweep_dct: dict[str, object]) ‑> tuple[list[saev.config.Train], list[str]]` : Classes ------- `Activations(data: saev.config.ImagenetDataset | saev.config.ImageFolderDataset | saev.config.Ade20kDataset = , dump_to: str = './shards', model_family: Literal['clip', 'siglip', 'dinov2', 'mae'] = 'clip', model_ckpt: str = 'ViT-L-14/openai', vit_batch_size: int = 1024, n_workers: int = 8, d_vit: int = 1024, layers: list[int] = , n_patches_per_img: int = 256, cls_token: bool = True, n_patches_per_shard: int = 2400000, seed: int = 42, ssl: bool = True, device: str = 'cuda', slurm: bool = False, slurm_acct: str = 'PAS2136', log_to: str = './logs')` : Configuration for calculating and saving ViT activations. ### Class variables `cls_token: bool` : Whether the model has a [CLS] token. `d_vit: int` : Dimension of the ViT activations (depends on model). `data: saev.config.ImagenetDataset | saev.config.ImageFolderDataset | saev.config.Ade20kDataset` : Which dataset to use. `device: str` : Which device to use. `dump_to: str` : Where to write shards. `layers: list[int]` : Which layers to save. By default, the second-to-last layer. `log_to: str` : Where to log Slurm job stdout/stderr. `model_ckpt: str` : Specific model checkpoint. `model_family: Literal['clip', 'siglip', 'dinov2', 'mae']` : Which model family. `n_patches_per_img: int` : Number of ViT patches per image (depends on model). `n_patches_per_shard: int` : Number of activations per shard; 2.4M is approximately 10GB for 1024-dimensional 4-byte activations. `n_workers: int` : Number of dataloader workers. `seed: int` : Random seed. `slurm: bool` : Whether to use `submitit` to run jobs on a Slurm cluster. `slurm_acct: str` : Slurm account string. `ssl: bool` : Whether to use SSL. `vit_batch_size: int` : Batch size for ViT inference. `Ade20kDataset(root: str = './data/ade20k', split: Literal['training', 'validation'] = 'training')` : ### Class variables `root: str` : Where the class folders with images are stored. `split: Literal['training', 'validation']` : Data split. 
### Instance variables `n_imgs: int` : `DataLoad(shard_root: str = './shards', patches: Literal['cls', 'patches', 'meanpool'] = 'patches', layer: int | Literal['all', 'meanpool'] = -2, clamp: float = 100000.0, n_random_samples: int = 524288, scale_mean: bool = True, scale_norm: bool = True)` : Configuration for loading activation data from disk. ### Class variables `clamp: float` : Maximum value for activations; activations will be clamped to within [-clamp, clamp]`. `layer: int | Literal['all', 'meanpool']` : .. todo: document this field. `n_random_samples: int` : Number of random samples used to calculate approximate dataset means at startup. `patches: Literal['cls', 'patches', 'meanpool']` : Which kinds of patches to use. 'cls' indicates just the [CLS] token (if any). 'patches' indicates it will return all patches. 'meanpool' returns the mean of all image patches. `scale_mean: bool` : Whether to subtract approximate dataset means from examples. `scale_norm: bool` : Whether to scale average dataset norm to sqrt(d_vit). `shard_root: str` : Directory with .bin shards and a metadata.json file. `ImageFolderDataset(root: str = './data/split')` : Configuration for a generic image folder dataset. ### Class variables `root: str` : Where the class folders with images are stored. ### Instance variables `n_imgs: int` : Number of images in the dataset. Calculated on the fly, but is non-trivial to calculate because it requires walking the directory structure. If you need to reference this number very often, cache it in a local variable. `ImagenetDataset(name: str = 'ILSVRC/imagenet-1k', split: str = 'train')` : Configuration for HuggingFace Imagenet. ### Class variables `name: str` : Dataset name on HuggingFace. Don't need to change this.. `split: str` : Dataset split. For the default ImageNet-1K dataset, can either be 'train', 'validation' or 'test'. ### Instance variables `n_imgs: int` : Number of images in the dataset. Calculated on the fly, but is non-trivial to calculate because it requires loading the dataset. If you need to reference this number very often, cache it in a local variable. `SparseAutoencoder(d_vit: int = 1024, exp_factor: int = 16, sparsity_coeff: float = 0.0004, n_reinit_samples: int = 524288, ghost_grads: bool = False, remove_parallel_grads: bool = True, normalize_w_dec: bool = True, seed: int = 0)` : SparseAutoencoder(d_vit: int = 1024, exp_factor: int = 16, sparsity_coeff: float = 0.0004, n_reinit_samples: int = 524288, ghost_grads: bool = False, remove_parallel_grads: bool = True, normalize_w_dec: bool = True, seed: int = 0) ### Class variables `d_vit: int` : `exp_factor: int` : Expansion factor for SAE. `ghost_grads: bool` : Whether to use ghost grads. `n_reinit_samples: int` : Number of samples to use for SAE re-init. Anthropic proposes initializing b_dec to the geometric median of the dataset here: https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-bias. We use the regular mean. `normalize_w_dec: bool` : Whether to make sure W_dec has unit norm columns. See https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder for original citation. `remove_parallel_grads: bool` : Whether to remove gradients parallel to W_dec columns (which will be ignored because we force the columns to have unit norm). See https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-optimization for the original discussion from Anthropic. `seed: int` : Random seed. 
`sparsity_coeff: float` : How much to weight sparsity loss term. ### Instance variables `d_sae: int` : `Train(data: saev.config.DataLoad = , n_workers: int = 32, n_patches: int = 100000000, sae: saev.config.SparseAutoencoder = , n_sparsity_warmup: int = 0, lr: float = 0.0004, n_lr_warmup: int = 500, sae_batch_size: int = 16384, track: bool = True, wandb_project: str = 'saev', tag: str = '', log_every: int = 25, ckpt_path: str = './checkpoints', device: Literal['cuda', 'cpu'] = 'cuda', seed: int = 42, slurm: bool = False, slurm_acct: str = 'PAS2136', log_to: str = './logs')` : Configuration for training a sparse autoencoder on a vision transformer. ### Class variables `ckpt_path: str` : Where to save checkpoints. `data: saev.config.DataLoad` : Data configuration `device: Literal['cuda', 'cpu']` : Hardware device. `log_every: int` : How often to log to WandB. `log_to: str` : Where to log Slurm job stdout/stderr. `lr: float` : Learning rate. `n_lr_warmup: int` : Number of learning rate warmup steps. `n_patches: int` : Number of SAE training examples. `n_sparsity_warmup: int` : Number of sparsity coefficient warmup steps. `n_workers: int` : Number of dataloader workers. `sae: saev.config.SparseAutoencoder` : SAE configuration. `sae_batch_size: int` : Batch size for SAE training. `seed: int` : Random seed. `slurm: bool` : Whether to use `submitit` to run jobs on a Slurm cluster. `slurm_acct: str` : Slurm account string. `tag: str` : Tag to add to WandB run. `track: bool` : Whether to track with WandB. `wandb_project: str` : WandB project name. `Visuals(ckpt: str = './checkpoints/sae.pt', data: saev.config.DataLoad = , images: saev.config.ImagenetDataset | saev.config.ImageFolderDataset | saev.config.Ade20kDataset = , top_k: int = 128, n_workers: int = 16, topk_batch_size: int = 16384, sae_batch_size: int = 16384, epsilon: float = 1e-09, sort_by: Literal['cls', 'img', 'patch'] = 'patch', device: str = 'cuda', dump_to: str = './data', log_freq_range: tuple[float, float] = (-6.0, -2.0), log_value_range: tuple[float, float] = (-1.0, 1.0), include_latents: list[int] = , n_distributions: int = 25, percentile: int = 99, n_latents: int = 400, seed: int = 42)` : Configuration for generating visuals from trained SAEs. ### Class variables `ckpt: str` : Path to the sae.pt file. `data: saev.config.DataLoad` : Data configuration. `device: str` : Which accelerator to use. `dump_to: str` : Where to save data. `epsilon: float` : Value to add to avoid log(0). `images: saev.config.ImagenetDataset | saev.config.ImageFolderDataset | saev.config.Ade20kDataset` : Which images to use. `include_latents: list[int]` : Latents to always include, no matter what. `log_freq_range: tuple[float, float]` : Log10 frequency range for which to save images. `log_value_range: tuple[float, float]` : Log10 frequency range for which to save images. `n_distributions: int` : Number of features to save distributions for. `n_latents: int` : Maximum number of latents to save images for. `n_workers: int` : Number of dataloader workers. `percentile: int` : Percentile to estimate for outlier detection. `sae_batch_size: int` : Batch size for SAE inference. `seed: int` : Random seed. `sort_by: Literal['cls', 'img', 'patch']` : How to find the top k images. 'cls' picks images where the SAE latents of the ViT's [CLS] token are maximized without any patch highligting. 'img' picks images that maximize the sum of an SAE latent over all patches in the image, highlighting the patches. 
'patch' pickes images that maximize an SAE latent over all patches (not summed), highlighting the patches and only showing unique images. `top_k: int` : How many images per SAE feature to store. `topk_batch_size: int` : Number of examples to apply top-k op to. ### Instance variables `distributions_fpath: str` : `mean_values_fpath: str` : `percentiles_fpath: str` : `root: str` : `sparsity_fpath: str` : `top_img_i_fpath: str` : `top_patch_i_fpath: str` : `top_values_fpath: str` : Module saev.helpers =================== Useful helpers for `saev`. Functions --------- `flattened(dct: dict[str, object], *, sep: str = '.') ‑> dict[str, str | int | float | bool | None]` : Flatten a potentially nested dict to a single-level dict with `.`-separated keys. `get(dct: dict[str, object], key: str, *, sep: str = '.') ‑> object` : `get_cache_dir() ‑> str` : Get cache directory from environment variables, defaulting to the current working directory (.) Returns: A path to a cache directory (might not exist yet). Classes ------- `progress(it, *, every: int = 10, desc: str = 'progress', total: int = 0)` : Wraps an iterable with a logger like tqdm but doesn't use any control codes to manipulate a progress bar, which doesn't work well when your output is redirected to a file. Instead, simple logging statements are used, but it includes quality-of-life features like iteration speed and predicted time to finish. Args: it: Iterable to wrap. every: How many iterations between logging progress. desc: What to name the logger. total: If non-zero, how long the iterable is. Module saev.imaging =================== Functions --------- `add_highlights(img: PIL.Image.Image, patches: jaxtyping.Float[ndarray, 'n_patches'], *, upper: float | None = None) ‑> PIL.Image.Image` : Namespace saev.interactive ========================== Sub-modules ----------- * saev.interactive.features * saev.interactive.metrics Module saev.interactive.features ================================ Module saev.interactive.metrics =============================== Module saev.nn ============== Neural network architectures for sparse autoencoders. Functions --------- `dump(fpath: str, sae: saev.nn.SparseAutoencoder)` : Save an SAE checkpoint to disk along with configuration, using the [trick from equinox](https://docs.kidger.site/equinox/examples/serialisation). Arguments: fpath: filepath to save checkpoint to. sae: sparse autoencoder checkpoint to save. `load(fpath: str, *, device: str = 'cpu') ‑> saev.nn.SparseAutoencoder` : Loads a sparse autoencoder from disk. `ref_mse(x_hat: jaxtyping.Float[Tensor, '*d'], x: jaxtyping.Float[Tensor, '*d'], norm: bool = True) ‑> jaxtyping.Float[Tensor, '*d']` : `safe_mse(x_hat: jaxtyping.Float[Tensor, '*batch d'], x: jaxtyping.Float[Tensor, '*batch d'], norm: bool = False) ‑> jaxtyping.Float[Tensor, '*batch d']` : Classes ------- `Loss(mse: jaxtyping.Float[Tensor, ''], sparsity: jaxtyping.Float[Tensor, ''], ghost_grad: jaxtyping.Float[Tensor, ''], l0: jaxtyping.Float[Tensor, ''], l1: jaxtyping.Float[Tensor, ''])` : The composite loss terms for an autoencoder training batch. ### Ancestors (in MRO) * builtins.tuple ### Instance variables `ghost_grad: jaxtyping.Float[Tensor, '']` : Ghost gradient loss, if any. `l0: jaxtyping.Float[Tensor, '']` : L0 magnitude of hidden activations. `l1: jaxtyping.Float[Tensor, '']` : L1 magnitude of hidden activations. `loss: jaxtyping.Float[Tensor, '']` : Total loss. `mse: jaxtyping.Float[Tensor, '']` : Reconstruction loss (mean squared error). 
`sparsity: jaxtyping.Float[Tensor, '']` : Sparsity loss, typically lambda * L1. `SparseAutoencoder(cfg: saev.config.SparseAutoencoder)` : Sparse auto-encoder (SAE) using L1 sparsity penalty. Initialize internal Module state, shared by both nn.Module and ScriptModule. ### Ancestors (in MRO) * torch.nn.modules.module.Module ### Class variables `cfg: saev.config.SparseAutoencoder` : ### Methods `forward(self, x: jaxtyping.Float[Tensor, 'batch d_model']) ‑> tuple[jaxtyping.Float[Tensor, 'batch d_model'], jaxtyping.Float[Tensor, 'batch d_sae'], saev.nn.Loss]` : Given x, calculates the reconstructed x_hat, the intermediate activations f_x, and the loss. Arguments: x: a batch of ViT activations. `init_b_dec(self, vit_acts: jaxtyping.Float[Tensor, 'n d_vit'])` : `normalize_w_dec(self)` : Set W_dec to unit-norm columns. `remove_parallel_grads(self)` : Update grads so that they remove the parallel component (d_sae, d_vit) shape Module saev.test_activations ============================ Test that the cached activations are actually correct. These tests are quite slow Functions --------- `test_dataloader_batches()` : `test_shard_writer_and_dataset_e2e()` : Module saev.test_config ======================= Functions --------- `test_expand()` : `test_expand_multiple()` : `test_expand_nested()` : `test_expand_nested_and_unnested()` : `test_expand_nested_and_unnested_backwards()` : `test_expand_two_fields()` : Module saev.test_nn =================== Uses [hypothesis]() and [hypothesis-torch](https://hypothesis-torch.readthedocs.io/en/stable/compatability/) to generate test cases to compare our normalized MSE implementation to a reference MSE implementation. Functions --------- `test_safe_mse_hypothesis() ‑> None` : `test_safe_mse_large_x()` : `test_safe_mse_nonzero()` : `test_safe_mse_same()` : `test_safe_mse_zero_x_hat()` : Module saev.test_training ========================= Functions --------- `test_split_cfgs_no_bad_keys()` : `test_split_cfgs_on_multiple_keys_with_multiple_per_key()` : `test_split_cfgs_on_single_key()` : `test_split_cfgs_on_single_key_with_multiple_per_key()` : Module saev.test_visuals ======================== Functions --------- `test_gather_batched_small()` : Module saev.training ==================== Trains many SAEs in parallel to amortize the cost of loading a single batch of data over many SAE training runs. Functions --------- `evaluate(cfgs: list[saev.config.Train], saes: torch.nn.modules.container.ModuleList) ‑> list[saev.training.EvalMetrics]` : Evaluates SAE quality by counting the number of dead features and the number of dense features. Also makes histogram plots to help human qualitative comparison. .. todo:: Develop automatic methods to use histogram and feature frequencies to evaluate quality with a single number. `init_b_dec_batched(saes: torch.nn.modules.container.ModuleList, dataset: saev.activations.Dataset)` : `main(cfgs: list[saev.config.Train]) ‑> list[str]` : `make_hashable(obj)` : `make_saes(cfgs: list[saev.config.SparseAutoencoder]) ‑> tuple[torch.nn.modules.container.ModuleList, list[dict[str, object]]]` : `split_cfgs(cfgs: list[saev.config.Train]) ‑> list[list[saev.config.Train]]` : Splits configs into groups that can be parallelized. Arguments: A list of configs from a sweep file. Returns: A list of lists, where the configs in each sublist do not differ in any keys that are in `CANNOT_PARALLELIZE`. This means that each sublist is a valid "parallel" set of configs for `train`. 
`train(cfgs: list[saev.config.Train]) ‑> tuple[torch.nn.modules.container.ModuleList, saev.training.ParallelWandbRun, int]` : Explicitly declare the optimizer, schedulers, dataloader, etc outside of `main` so that all the variables are dropped from scope and can be garbage collected. Classes ------- `BatchLimiter(dataloader: torch.utils.data.dataloader.DataLoader, n_samples: int)` : Limits the number of batches to only return `n_samples` total samples. `EvalMetrics(l0: float, l1: float, mse: float, n_dead: int, n_almost_dead: int, n_dense: int, freqs: jaxtyping.Float[Tensor, 'd_sae'], mean_values: jaxtyping.Float[Tensor, 'd_sae'], almost_dead_threshold: float, dense_threshold: float)` : Results of evaluating a trained SAE on a datset. ### Class variables `almost_dead_threshold: float` : Threshold for an "almost dead" neuron. `dense_threshold: float` : Threshold for a dense neuron. `freqs: jaxtyping.Float[Tensor, 'd_sae']` : How often each feature fired. `l0: float` : Mean L0 across all examples. `l1: float` : Mean L1 across all examples. `mean_values: jaxtyping.Float[Tensor, 'd_sae']` : The mean value for each feature when it did fire. `mse: float` : Mean MSE across all examples. `n_almost_dead: int` : Number of neurons that fired on fewer than `almost_dead_threshold` of examples. `n_dead: int` : Number of neurons that never fired on any example. `n_dense: int` : Number of neurons that fired on more than `dense_threshold` of examples. ### Methods `for_wandb(self) ‑> dict[str, int | float]` : `ParallelWandbRun(project: str, cfgs: list[saev.config.Train], mode: str, tags: list[str])` : Inspired by https://community.wandb.ai/t/is-it-possible-to-log-to-multiple-runs-simultaneously/4387/3. ### Methods `finish(self) ‑> list[str]` : `log(self, metrics: list[dict[str, object]], *, step: int)` : `Scheduler()` : ### Descendants * saev.training.Warmup ### Methods `step(self) ‑> float` : `Warmup(init: float, final: float, n_steps: int)` : Linearly increases from `init` to `final` over `n_warmup_steps` steps. ### Ancestors (in MRO) * saev.training.Scheduler ### Methods `step(self) ‑> float` : Module saev.visuals =================== There is some important notation used only in this file to dramatically shorten variable names. Variables suffixed with `_im` refer to entire images, and variables suffixed with `_p` refer to patches. Functions --------- `batched_idx(total_size: int, batch_size: int) ‑> Iterator[tuple[int, int]]` : Iterate over (start, end) indices for total_size examples, where end - start is at most batch_size. Args: total_size: total number of examples batch_size: maximum distance between the generated indices. Returns: A generator of (int, int) tuples that can slice up a list or a tensor. `dump_activations(cfg: saev.config.Visuals)` : For each SAE latent, we want to know which images have the most total "activation". That is, we keep track of each patch `gather_batched(value: jaxtyping.Float[Tensor, 'batch n dim'], i: jaxtyping.Int[Tensor, 'batch k']) ‑> jaxtyping.Float[Tensor, 'batch k dim']` : `get_new_topk(val1: jaxtyping.Float[Tensor, 'd_sae k'], i1: jaxtyping.Int[Tensor, 'd_sae k'], val2: jaxtyping.Float[Tensor, 'd_sae k'], i2: jaxtyping.Int[Tensor, 'd_sae k'], k: int) ‑> tuple[jaxtyping.Float[Tensor, 'd_sae k'], jaxtyping.Int[Tensor, 'd_sae k']]` : Picks out the new top k values among val1 and val2. Also keeps track of i1 and i2, then indices of the values in the original dataset. Args: val1: top k original SAE values. i1: the patch indices of those original top k values. 
val2: top k incoming SAE values. i2: the patch indices of those incoming top k values. k: k. Returns: The new top k values and their patch indices. `get_sae_acts(vit_acts: jaxtyping.Float[Tensor, 'n d_vit'], sae: saev.nn.SparseAutoencoder, cfg: saev.config.Visuals) ‑> jaxtyping.Float[Tensor, 'n d_sae']` : Get SAE hidden layer activations for a batch of ViT activations. Args: vit_acts: Batch of ViT activations sae: Sparse autoencder. cfg: Experimental config. `get_topk_img(cfg: saev.config.Visuals) ‑> saev.visuals.TopKImg` : Gets the top k images for each latent in the SAE. The top k images are for latent i are sorted by max over all images: f_x(cls)[i] Thus, we will never have duplicate images for a given latent. But we also will not have patch-level activations (a nice heatmap). Args: cfg: Config. Returns: A tuple of TopKImg and the first m features' activation distributions. `get_topk_patch(cfg: saev.config.Visuals) ‑> saev.visuals.TopKPatch` : Gets the top k images for each latent in the SAE. The top k images are for latent i are sorted by max over all patches: f_x(patch)[i] Thus, we could end up with duplicate images in the top k, if an image has more than one patch that maximally activates an SAE latent. Args: cfg: Config. Returns: A tuple of TopKPatch and m randomly sampled activation distributions. `main(cfg: saev.config.Visuals)` : .. todo:: document this function. Dump top-k images to a directory. Args: cfg: Configuration object. `make_img(elem: saev.visuals.GridElement, *, upper: float | None = None) ‑> PIL.Image.Image` : `plot_activation_distributions(cfg: saev.config.Visuals, distributions: jaxtyping.Float[Tensor, 'm n'])` : `safe_load(path: str) ‑> object` : `test_online_quantile_estimation(true: float, percentile: float)` : Classes ------- `GridElement(img: PIL.Image.Image, label: str, patches: jaxtyping.Float[Tensor, 'n_patches'])` : GridElement(img: PIL.Image.Image, label: str, patches: jaxtyping.Float[Tensor, 'n_patches']) ### Class variables `img: PIL.Image.Image` : `label: str` : `patches: jaxtyping.Float[Tensor, 'n_patches']` : `PercentileEstimator(percentile: float | int, total: int, lr: float = 0.001, shape: tuple[int, ...] = ())` : ### Instance variables `estimate` : ### Methods `update(self, x)` : Update the estimator with a new value. This method maintains the marker positions using the P2 algorithm rules. When a new value arrives, it's placed in the appropriate position relative to existing markers, and marker positions are adjusted to maintain their desired percentile positions. Arguments: x: The new value to incorporate into the estimation `TopKImg(top_values: jaxtyping.Float[Tensor, 'd_sae k'], top_i: jaxtyping.Int[Tensor, 'd_sae k'], mean_values: jaxtyping.Float[Tensor, 'd_sae'], sparsity: jaxtyping.Float[Tensor, 'd_sae'], distributions: jaxtyping.Float[Tensor, 'm n'], percentiles: jaxtyping.Float[Tensor, 'd_sae'])` : .. todo:: Document this class. ### Class variables `distributions: jaxtyping.Float[Tensor, 'm n']` : `mean_values: jaxtyping.Float[Tensor, 'd_sae']` : `percentiles: jaxtyping.Float[Tensor, 'd_sae']` : `sparsity: jaxtyping.Float[Tensor, 'd_sae']` : `top_i: jaxtyping.Int[Tensor, 'd_sae k']` : `top_values: jaxtyping.Float[Tensor, 'd_sae k']` : `TopKPatch(top_values: jaxtyping.Float[Tensor, 'd_sae k n_patches_per_img'], top_i: jaxtyping.Int[Tensor, 'd_sae k'], mean_values: jaxtyping.Float[Tensor, 'd_sae'], sparsity: jaxtyping.Float[Tensor, 'd_sae'], distributions: jaxtyping.Float[Tensor, 'm n'], percentiles: jaxtyping.Float[Tensor, 'd_sae'])` : .. 
todo:: Document this class. ### Class variables `distributions: jaxtyping.Float[Tensor, 'm n']` : `mean_values: jaxtyping.Float[Tensor, 'd_sae']` : `percentiles: jaxtyping.Float[Tensor, 'd_sae']` : `sparsity: jaxtyping.Float[Tensor, 'd_sae']` : `top_i: jaxtyping.Int[Tensor, 'd_sae k']` : `top_values: jaxtyping.Float[Tensor, 'd_sae k n_patches_per_img']` :

Namespace contrib
=================

Sub-modules
-----------
* contrib.classification
* contrib.mae
* contrib.semseg

Module contrib.classification
=============================

# Reproduce

You can reproduce our classification control experiments from our preprint by following these instructions.

The big overview (as described in our paper) is:

1. Train an SAE on the ImageNet-1K [CLS] token activations from a CLIP ViT-B/16, from the 11th (second-to-last) layer.
2. Show that you get meaningful features, through visualizations.
3. Train a linear probe on the [CLS] token activations from a CLIP ViT-B/16, from the 11th layer, on the Caltech-101 dataset. We use an arbitrary random train/test split.
4. Show that we get good accuracy.
5. Manipulate the activations using the proposed SAE features.
6. Be amazed. :)

To do these steps:

## Record ImageNet-1K activations

## Train an SAE on [CLS] Activations

```sh
uv run python -m saev train \
  --sweep configs/preprint/classification.toml \
  --data.shard-root /local/scratch/$USER/cache/saev/ac89246f1934b45e2f0487298aebe36ad998b6bd252d880c0c9ec5de78d793c8/ \
  --data.patches cls \
  --sae.d-vit 768
```

## Visualize the SAE Features

`bd97z80b` was the best checkpoint from my sweep.

```sh
uv run python -m saev visuals \
  --ckpt checkpoints/bd97z80b/sae.pt \
  --dump-to /research/nfs_su_809/workspace/stevens.994/saev/features/bd97z80b \
  --sort-by cls \
  --data.shard-root /local/scratch/stevens.994/cache/saev/ac89246f1934b45e2f0487298aebe36ad998b6bd252d880c0c9ec5de78d793c8/ \
  --data.layer -2 \
  --data.patches cls \
  --log-freq-range -2.5 -1.5 \
  --log-value-range 0.0 1.0 \
  images:imagenet-dataset
```

You can see some neat features in here by using `saev.interactive.features` with `marimo`.

## Record Caltech-101 Activations

For each `$SPLIT` in "train" and "test":

```sh
uv run python -m saev activations \
  --model-family clip \
  --model-ckpt ViT-B-16/openai \
  --d-vit 768 \
  --n-patches-per-img 196 \
  --layers -2 \
  --dump-to /local/scratch/$USER/cache/saev \
  --n-patches-per-shard 2_400_000 \
  data:image-folder-dataset \
  --data.root /nfs/datasets/caltech-101/$SPLIT
```

## Train a Linear Probe

```sh
uv run python -m contrib.classification train \
  --n-workers 32 \
  --train-acts.shard-root /local/scratch/$USER/cache/saev/$TRAIN \
  --val-acts.shard-root /local/scratch/$USER/cache/saev/$TEST \
  --train-imgs.root /nfs/$USER/datasets/flowers102/train \
  --val-imgs.root /nfs/$USER/datasets/flowers102/val \
  --sweep contrib/classification/sweep.toml
```

Then look at `logs/contrib/classification/hparam-sweeps.png`. It probably works for any of the learning rates above 1e-5 or so.
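After training, a checkpoint can be evaluated outside the training loop with the helpers in `contrib.classification.training`. A minimal sketch; the checkpoint filename and the shard hash are placeholders, since the real names depend on your training run and activation dump.

```python
import torch

import saev.config
from contrib.classification import training

# Placeholder checkpoint path; dump_model writes checkpoints under cfg.ckpt_path.
probe = training.load_model("./checkpoints/contrib/classification/model.pt", device="cpu")

val_acts = training.load_acts(
    saev.config.DataLoad(
        shard_root="/local/scratch/me/cache/saev/CALTECH_TEST_HASH",  # placeholder
        layer=-2,
        patches="cls",
        scale_mean=False,
        scale_norm=False,
    )
)
val_targets = training.load_targets(
    saev.config.ImageFolderDataset(root="/nfs/me/datasets/caltech-101/test")
)

with torch.no_grad():
    preds = probe(val_acts).argmax(dim=-1)
print("val accuracy:", (preds == val_targets).float().mean().item())
```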
## Manipulate

Now we will manipulate the inputs to the probe by using the directions proposed by the SAE trained on ImageNet-1K and observe the changes in the linear model's predictions.

There are two ways to do this:

1. The marimo dashboard, which requires that you run your own inference.
2. The online dashboard, which is more polished but offers less control.

Since you have gone through the effort of training all this stuff, you probably want more control and have the hardware for inference.

Run the marimo dashboard with:

```sh
uv run marimo edit contrib/classification/interactive.py
```

These screenshots show the kinds of findings you can uncover with this dashboard.

First, when you open the dashboard and configure the options, you will eventually see something like this:

![Default dashboard view of a sunflower example.](/assets/contrib/classification/sunflower-unchanged.png)

The main parts of the dashboard:

1. Example selector: choose which test image to classify. The image is shown on the bottom left.
2. The top SAE latents for the test image's class (in purple below). The latent values of $h$ are also shown. Many will be 0 because SAE latents fire very rarely (*sparse* autoencoder).
3. The top SAE latents for another, user-selected class (in orange below). Choose the class with the dropdown on the top right.
4. The top classes as predicted by the pre-trained classification model (a linear probe; shown in green below).
5. The top classes as predicted by the *same* pre-trained classification model, *after* modifying the dense vector representation with the SAE's vectors. These predictions are updated as you change the sliders on the screen.

![Annotated dashboard view of a sunflower example.](/saev/assets/contrib/classification/sunflower-unchanged-annotated.png)

As an example, you can scale *up* the top bonsai features. As you do, the most likely class will be a bonsai. See below.

![A sunflower changed to look like a bonsai.](/saev/assets/contrib/classification/class-manipulation.png)

Here's another example. With another sunflower, you can turn up the SAE feature that fires strongly on pagodas and other traditionally Asian architectural structures. If you do, the most likely classification is a lotus, which is popular in Japanese and other Asian cultures.

![A sunflower changed to be a lotus (a culturally Asian flower).](/saev/assets/contrib/classification/japanese-culture.png)

Only once you turn up the SAE feature that fires strongly on potted plants does the classification change to bonsai (which are typically potted).

![A sunflower changed to "bonsai".](/saev/assets/contrib/classification/bonsai.png)

I encourage you to look at other test images and manipulate the predictions!
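If you want the quantities the dashboard surfaces (which SAE latents fire for a given [CLS] activation) from a script instead, here is a minimal sketch. It assumes the `bd97z80b` checkpoint from the sweep above and recorded Caltech-101 [CLS] activations; the shard path is a placeholder.

```python
import torch

import saev.config
import saev.nn
from contrib.classification import training

sae = saev.nn.load("checkpoints/bd97z80b/sae.pt", device="cpu")
acts = training.load_acts(
    saev.config.DataLoad(
        shard_root="/local/scratch/me/cache/saev/CALTECH_TEST_HASH",  # placeholder
        patches="cls",
        scale_mean=False,
        scale_norm=False,
    )
)

with torch.no_grad():
    # forward returns (reconstruction, latent activations, loss terms)
    x_hat, f_x, loss = sae(acts[:1])
values, latents = f_x[0].topk(k=5)  # the handful of latents that fire on this image
print(latents.tolist(), values.tolist())
```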
Sub-modules ----------- * contrib.classification.config * contrib.classification.download_caltech101 * contrib.classification.download_flowers * contrib.classification.interactive * contrib.classification.plot_logits * contrib.classification.training Module contrib.classification.config ==================================== Functions --------- `grid(cfg: contrib.classification.config.Train, sweep_dct: dict[str, object]) ‑> tuple[list[contrib.classification.config.Train], list[str]]` : Classes ------- `Train(learning_rate: float = 0.0001, weight_decay: float = 0.001, n_steps: int = 400, batch_size: int = 1024, n_workers: int = 32, train_acts: saev.config.DataLoad = , val_acts: saev.config.DataLoad = , train_imgs: saev.config.ImageFolderDataset = , val_imgs: saev.config.ImageFolderDataset = , eval_every: int = 100, device: str = 'cuda', ckpt_path: str = './checkpoints/contrib/classification', seed: int = 42, log_to: str = './logs/contrib/classification')` : Train(learning_rate: float = 0.0001, weight_decay: float = 0.001, n_steps: int = 400, batch_size: int = 1024, n_workers: int = 32, train_acts: saev.config.DataLoad = , val_acts: saev.config.DataLoad = , train_imgs: saev.config.ImageFolderDataset = , val_imgs: saev.config.ImageFolderDataset = , eval_every: int = 100, device: str = 'cuda', ckpt_path: str = './checkpoints/contrib/classification', seed: int = 42, log_to: str = './logs/contrib/classification') ### Class variables `batch_size: int` : Training batch size for linear layer. `ckpt_path: str` : `device: str` : Hardware to train on. `eval_every: int` : How many epochs between evaluations. `learning_rate: float` : Linear layer learning rate. `log_to: str` : `n_steps: int` : Number of training steps for linear layer. `n_workers: int` : Number of dataloader workers. `seed: int` : Random seed. `train_acts: saev.config.DataLoad` : Configuration for the saved Flowers102 training ViT activations. `train_imgs: saev.config.ImageFolderDataset` : Configuration for the Flowers102 training images. `val_acts: saev.config.DataLoad` : Configuration for the saved Flowers102 validation ViT activations. `val_imgs: saev.config.ImageFolderDataset` : Configuration for the Flowers102 validation images. `weight_decay: float` : Weight decay for AdamW. Module contrib.classification.download_caltech101 ================================================= A script to download the Caltech101 dataset for use as an saev.activations.ImageFolderDataset. ```sh uv run contrib/classification/download_flowers.py --help ``` Functions --------- `main(args: contrib.classification.download_caltech101.Args)` : Download NeWT. Classes ------- `Args(dir: str = '.', chunk_size_kb: int = 1, seed: int = 42)` : Configure download options. ### Class variables `chunk_size_kb: int` : How many KB to download at a time before writing to file. `dir: str` : Where to save data. `seed: int` : Random seed used to generate split. Module contrib.classification.download_flowers ============================================== A script to download the Flowers102 dataset. ```sh uv run contrib/classification/download_flowers.py --help ``` Functions --------- `main(args: contrib.classification.download_flowers.Args)` : Download NeWT. Classes ------- `Args(dir: str = '.', chunk_size_kb: int = 1)` : Configure download options. ### Class variables `chunk_size_kb: int` : How many KB to download at a time before writing to file. `dir: str` : Where to save data. 
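The `Train` config documented above can also be built in Python and handed directly to `contrib.classification.training.main` (documented below) instead of going through the CLI. A minimal sketch; the shard hashes and image-folder roots are placeholders standing in for the paths used in the reproduction guide.

```python
import saev.config
from contrib.classification import config, training

cfg = config.Train(
    learning_rate=1e-4,
    n_steps=400,
    train_acts=saev.config.DataLoad(
        shard_root="/local/scratch/me/cache/saev/CALTECH_TRAIN_HASH",  # placeholder
        patches="cls",
    ),
    val_acts=saev.config.DataLoad(
        shard_root="/local/scratch/me/cache/saev/CALTECH_TEST_HASH",  # placeholder
        patches="cls",
    ),
    train_imgs=saev.config.ImageFolderDataset(root="/nfs/me/datasets/caltech-101/train"),
    val_imgs=saev.config.ImageFolderDataset(root="/nfs/me/datasets/caltech-101/test"),
)
training.main([cfg])  # trains and checkpoints one linear probe
```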
Module contrib.classification.interactive ========================================= Module contrib.classification.plot_logits ========================================= Generates plots demonstrating SAE feature specificity in image classification. This module creates visualizations showing how different feature interventions affect class logits in a controlled manner. It plots the relationship between intervention magnitudes and their effects on class predictions, demonstrating that features are semantically meaningful and independent. The main plotting function generates a three-panel figure: - Left panel: Effect of feature A intervention on classes A, B and C - Middle panel: Effect of feature B intervention on classes A, B and C - Right panel: Effect of feature C intervention on classes A, B and C Classes ------- `Config(magnitude_range: tuple[float, float] = (-10.0, 10.0), n_points: int = 50, figsize: tuple[float, float] = (18.0, 5.0), class_colors: dict[str, str] = , show_confidence: bool = True, dpi: int = 300)` : Config(magnitude_range: tuple[float, float] = (-10.0, 10.0), n_points: int = 50, figsize: tuple[float, float] = (18.0, 5.0), class_colors: dict[str, str] = , show_confidence: bool = True, dpi: int = 300) ### Class variables `class_colors: dict[str, str]` : Color mapping for different classes. Uses hex color codes for consistency across plotting backends. `dpi: int` : Dots per inch for saved figures. Higher values create larger files but better resolution. `figsize: tuple[float, float]` : Figure dimensions in inches (width, height). Default size is optimized for a three-panel figure. `magnitude_range: tuple[float, float]` : Range for intervention magnitudes, from minimum to maximum value. Usually kept within [-10, 10] following Anthropic's work. Values outside this range may create artifacts. `n_points: int` : Number of evenly spaced points to sample within magnitude range. Higher values create smoother plots but increase computation time. `show_confidence: bool` : Whether to show confidence intervals around trend lines. Module contrib.classification.training ====================================== Train a linear probe on [CLS] activations from a ViT. This assumes the training and evaluation dataset are very small and can fit in GPU memory. If this is not true, look at contrib/semseg/training.py for some inspiration. Functions --------- `check_cfgs(cfgs: list[contrib.classification.config.Train])` : `dump_model(cfg: contrib.classification.config.Train, model: torch.nn.modules.module.Module)` : Save a model checkpoint to disk along with configuration, using the [trick from equinox](https://docs.kidger.site/equinox/examples/serialisation). `load_acts(cfg: saev.config.DataLoad) ‑> jaxtyping.Float[Tensor, 'n d_vit']` : `load_class_headers(cfg: saev.config.ImageFolderDataset) ‑> list[str]` : `load_model(fpath: str, *, device: str = 'cpu') ‑> torch.nn.modules.module.Module` : Loads a linear layer from disk. `load_targets(cfg: saev.config.ImageFolderDataset) ‑> jaxtyping.Int[Tensor, 'n']` : `main(cfgs: list[contrib.classification.config.Train])` : `make_models(cfgs: list[contrib.classification.config.Train], d_in: int, d_out: int) ‑> tuple[torch.nn.modules.container.ModuleList, list[dict[str, object]]]` : Classes ------- `Dataset(acts_cfg: saev.config.DataLoad, imgs_cfg: saev.config.ImageFolderDataset)` : An abstract class representing a :class:`Dataset`. All datasets that represent a map from keys to data samples should subclass it. 
All subclasses should overwrite :meth:`__getitem__`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:`__len__`, which is expected to return the size of the dataset by many :class:`~torch.utils.data.Sampler` implementations and the default options of :class:`~torch.utils.data.DataLoader`. Subclasses could also optionally implement :meth:`__getitems__`, for speedup batched samples loading. This method accepts list of indices of samples of batch and returns list of samples. .. note:: :class:`~torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided. ### Ancestors (in MRO) * torch.utils.data.dataset.Dataset * typing.Generic ### Instance variables `d_vit: int` : `n_classes: int` :
Module contrib.mae ================== Sub-modules ----------- * contrib.mae.example * contrib.mae.modeling Functions --------- `load_ckpt(ckpt: str, *, chunk_size_kb: int = 1024) ‑> contrib.mae.modeling.MaskedAutoencoder` : Loads a pre-trained MAE ViT from disk. If it's not on disk, downloads the checkpoint from huggingface and then loads it into the `MaskedAutoencoder` class.
Module contrib.mae.example ==========================
Module contrib.mae.modeling =========================== Functions --------- `load_ckpt(ckpt: str, *, chunk_size_kb: int = 1024) ‑> contrib.mae.modeling.MaskedAutoencoder` : Loads a pre-trained MAE ViT from disk. If it's not on disk, downloads the checkpoint from huggingface and then loads it into the `MaskedAutoencoder` class. `random_masking(x_BND: jaxtyping.Float[Tensor, 'batch n d'], mask_ratio: float, noise_BN: jaxtyping.Float[Tensor, 'batch n'] | None = None) ‑> tuple[jaxtyping.Float[Tensor, 'batch m d'], jaxtyping.Float[Tensor, 'batch n'], jaxtyping.Int[Tensor, 'batch n']]` : Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsorting random noise. Classes ------- `Attention(*, d: int, n_heads: int)` : ### Ancestors (in MRO) * torch.nn.modules.module.Module ### Methods `forward(self, x_BND: jaxtyping.Float[Tensor, 'batch n d']) ‑> jaxtyping.Float[Tensor, 'batch n d']` : `split(self, x_BND: jaxtyping.Float[Tensor, 'batch n d']) ‑> jaxtyping.Float[Tensor, 'batch n_heads n d_head']` : `Decoder(*, d_in: int, d: int, d_hidden: int, n_layers: int, n_heads: int, patch_size_px: tuple[int, int], image_size_px: tuple[int, int], ln_eps: float)` : ### Ancestors (in MRO) * torch.nn.modules.module.Module ### Methods `forward(self, x_BMD: jaxtyping.Float[Tensor, 'batch m d_in'], ids_restore_BN: jaxtyping.Int[Tensor, 'batch n']) ‑> jaxtyping.Float[Tensor, 'batch n patch_pixels']` : `Embeddings(*, d: int, image_size_px: tuple[int, int], patch_size_px: tuple[int, int], mask_ratio: float)` : ### Ancestors (in MRO) * torch.nn.modules.module.Module ### Class variables `Output` : ### Methods `forward(self, x_BCWH: jaxtyping.Float[Tensor, 'batch 3 height width'], noise_BN: jaxtyping.Float[Tensor, 'batch n'] | None = None) ‑> contrib.mae.modeling.Embeddings.Output` : `Encoder(*, d: int, d_hidden: int, n_heads: int, n_layers: int, ln_eps: float)` : ### Ancestors (in MRO) * torch.nn.modules.module.Module ### Methods `forward(self, x_BMD: jaxtyping.Float[Tensor, 'batch m d']) ‑> jaxtyping.Float[Tensor, 'batch m d']` : `Feedforward(*, d: int, d_hidden: int)` : ### Ancestors (in MRO) * torch.nn.modules.module.Module ### Methods `forward(self, x_BND: jaxtyping.Float[Tensor, 'batch n d']) ‑> jaxtyping.Float[Tensor, 'batch n d']` : `MaskedAutoencoder(*, d_encoder: int, d_hidden_encoder: int, n_heads_encoder: int, n_layers_encoder: int, d_decoder: int, d_hidden_decoder: int, n_heads_decoder: int, n_layers_decoder: int, image_size_px: tuple[int, int], patch_size_px: tuple[int, int], mask_ratio: float, ln_eps: float)` : ### Ancestors (in MRO) * torch.nn.modules.module.Module ### Class variables `Output` : ### Methods `forward(self, x_B3WH: jaxtyping.Float[Tensor, 'batch 3 width height'], noise_BN: jaxtyping.Float[Tensor, 'batch n'] | None = None) ‑> contrib.mae.modeling.MaskedAutoencoder.Output` : `PatchEmbeddings(d: int, patch_size_px: tuple[int, int])` : ### Ancestors (in MRO) * torch.nn.modules.module.Module ### Methods `forward(self, x_BCWH: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> jaxtyping.Float[Tensor, 'batch n_patches d']` : `TransformerBlock(*, d: int, d_hidden: int, n_heads: int, ln_eps: float)` : ### Ancestors (in MRO) * torch.nn.modules.module.Module ### Methods `forward(self, x: jaxtyping.Float[Tensor, 'batch n d']) ‑> jaxtyping.Float[Tensor, 'batch n d']` : `VisionTransformer(*, d: int, d_hidden: int, n_heads: int, n_layers: int, image_size_px: tuple[int, int], patch_size_px: tuple[int, int], mask_ratio: float, ln_eps: float)` :
### Ancestors (in MRO) * torch.nn.modules.module.Module ### Class variables `Output` : ### Methods `forward(self, x_B3WH: jaxtyping.Float[Tensor, 'batch 3 width height'], noise_BN: jaxtyping.Float[Tensor, 'batch n'] | None = None) ‑> jaxtyping.Float[Tensor, 'batch ...']` :

Module contrib.semseg
=====================

You can reproduce our semantic segmentation control experiments from our preprint by following these instructions. As an overview:

1. Record ViT activations for ADE20K.
2. Train a linear probe on the semantic segmentation task using ADE20K.
3. Establish baseline metrics for the linear probe.
4. Automatically identify feature vectors in the SAE's \(W_{dec}\) matrix for each class in ADE20K.
5. Suppress those features in the vision transformer's activations before applying the linear probe.
6. Record class-specific metrics before and after suppression.

Details can be found below.

# Record ViT Activations for Linear Probe Training and SAE Inference

# Train a Linear Probe on Semantic Segmentation

Now train a linear probe on the activations.

```sh
uv run python -m contrib.semseg train \
  --train-acts.shard-root $TRAIN_SHARDS \
  --train-acts.layer -1 \
  --val-acts.shard-root $VAL_SHARDS \
  --val-acts.layer -1 \
  --imgs.root /nfs/$USER/datasets/ade20k/ \
  --sweep contrib/semseg/sweep.toml
```

# Establish Linear Probe Baseline Metrics

```sh
uv run python -m contrib.semseg validate \
  --imgs.root /nfs/$USER/datasets/ade20k/ \
  --acts.shard-root $VAL_SHARDS
```

Then you can look in `./logs/contrib/semseg` for `hparam-sweeps.png` to see which learning rate/weight decay combination is best.
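Steps 4, 5 and 6 are handled by `contrib.semseg.visuals`, `contrib.semseg.manipulation`, and `contrib.semseg.validation` (documented below). Conceptually, suppressing a feature means zeroing its SAE latent in every patch activation before the linear probe sees it. A rough sketch, assuming an SAE exposing `encode`/`decode` and a per-patch linear probe `probe` (illustrative names, not the exact API):

```python
import torch

@torch.no_grad()
def suppress(acts_ND: torch.Tensor, sae, probe, latents: list[int]) -> torch.Tensor:
    """Zero the chosen SAE latents in ViT patch activations, then re-run the probe."""
    h_NS = sae.encode(acts_ND)            # sparse codes, one row per patch
    err_ND = acts_ND - sae.decode(h_NS)   # reconstruction error is left untouched
    h_NS[:, latents] = 0.0                # suppress the class-linked features
    return probe(sae.decode(h_NS) + err_ND)  # per-patch class logits after suppression
```

Comparing class-specific metrics (e.g., IoU) before and after this suppression is what step 6 records.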
Sub-modules ----------- * contrib.semseg.config * contrib.semseg.dashboard * contrib.semseg.dashboard2 * contrib.semseg.interactive * contrib.semseg.manipulation * contrib.semseg.training * contrib.semseg.validation * contrib.semseg.visuals Module contrib.semseg.config ============================ Functions --------- `grid(cfg: contrib.semseg.config.Train, sweep_dct: dict[str, object]) ‑> tuple[list[contrib.semseg.config.Train], list[str]]` : Classes ------- `Manipulation(probe_ckpt: str = './checkpoints/semseg/lr_0_001__wd_0_1/model.pt', sae_ckpt: str = './checkpoints/abcdef/sae.pt', ade20k_classes: list[int] = , sae_latents: list[int] = , acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , batch_size: int = 128, n_workers: int = 32, device: str = 'cuda')` : Manipulation(probe_ckpt: str = './checkpoints/semseg/lr_0_001__wd_0_1/model.pt', sae_ckpt: str = './checkpoints/abcdef/sae.pt', ade20k_classes: list[int] = , sae_latents: list[int] = , acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , batch_size: int = 128, n_workers: int = 32, device: str = 'cuda') ### Class variables `acts: saev.config.DataLoad` : Configuration for the saved ADE20K validation ViT activations. `ade20k_classes: list[int]` : One or more ADE20K classes to track. `batch_size: int` : Batch size for both linear probe and SAE. `device: str` : Hardware for linear probe and SAE inference. `imgs: saev.config.Ade20kDataset` : Configuration for the ADE20K validation dataset. `n_workers: int` : Number of dataloader workers. `probe_ckpt: str` : Linear probe checkpoint. `sae_ckpt: str` : SAE checkpoint. `sae_latents: list[int]` : one or more SAE latents to manipulate. `Train(learning_rate: float = 0.0001, weight_decay: float = 0.001, n_epochs: int = 400, batch_size: int = 1024, n_workers: int = 32, train_acts: saev.config.DataLoad = , val_acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , patch_size_px: tuple[int, int] = (14, 14), eval_every: int = 100, device: str = 'cuda', ckpt_path: str = './checkpoints/contrib/semseg', seed: int = 42, log_to: str = './logs/contrib/semseg')` : Train(learning_rate: float = 0.0001, weight_decay: float = 0.001, n_epochs: int = 400, batch_size: int = 1024, n_workers: int = 32, train_acts: saev.config.DataLoad = , val_acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , patch_size_px: tuple[int, int] = (14, 14), eval_every: int = 100, device: str = 'cuda', ckpt_path: str = './checkpoints/contrib/semseg', seed: int = 42, log_to: str = './logs/contrib/semseg') ### Class variables `batch_size: int` : Training batch size for linear layer. `ckpt_path: str` : `device: str` : Hardware to train on. `eval_every: int` : How many epochs between evaluations. `imgs: saev.config.Ade20kDataset` : Configuration for the ADE20K dataset. `learning_rate: float` : Linear layer learning rate. `log_to: str` : `n_epochs: int` : Number of training epochs for linear layer. `n_workers: int` : Number of dataloader workers. `patch_size_px: tuple[int, int]` : Patch size in pixels. `seed: int` : Random seed. `train_acts: saev.config.DataLoad` : Configuration for the saved ADE20K training ViT activations. `val_acts: saev.config.DataLoad` : Configuration for the saved ADE20K validation ViT activations. `weight_decay: float` : Weight decay for AdamW. 
`Validation(ckpt_root: str = './checkpoints/contrib/semseg', dump_to: str = './logs/contrib/semseg', acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , patch_size_px: tuple[int, int] = (14, 14), batch_size: int = 128, n_workers: int = 32, device: str = 'cuda')` : Validation(ckpt_root: str = './checkpoints/contrib/semseg', dump_to: str = './logs/contrib/semseg', acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , patch_size_px: tuple[int, int] = (14, 14), batch_size: int = 128, n_workers: int = 32, device: str = 'cuda') ### Class variables `acts: saev.config.DataLoad` : Configuration for the saved ADE20K validation ViT activations. `batch_size: int` : Batch size for calculating F1 scores. `ckpt_root: str` : Root to all checkpoints to evaluate. `device: str` : Hardware for linear probe inference. `dump_to: str` : Directory to dump results to. `imgs: saev.config.Ade20kDataset` : Configuration for the ADE20K validation dataset. `n_workers: int` : Number of dataloader workers. `patch_size_px: tuple[int, int]` : Patch size in pixels. `Visuals(sae_ckpt: str = './checkpoints/sae.pt', ade20k_cls: int = 29, k: int = 32, acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda')` : Visuals(sae_ckpt: str = './checkpoints/sae.pt', ade20k_cls: int = 29, k: int = 32, acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda') ### Class variables `acts: saev.config.DataLoad` : Configuration for the saved ADE20K training ViT activations. `ade20k_cls: int` : ADE20K class to probe for. `batch_size: int` : Batch size for calculating F1 scores. `device: str` : Hardware for SAE inference. `imgs: saev.config.Ade20kDataset` : Configuration for the ADE20K training dataset. `k: int` : Top K features to save. `label_threshold: float` : `n_workers: int` : Number of dataloader workers. `sae_ckpt: str` : Path to the sae.pt file. Module contrib.semseg.dashboard =============================== Module contrib.semseg.dashboard2 ================================ Module contrib.semseg.interactive ================================= Module contrib.semseg.manipulation ================================== Manipulate representations by increasing or decreasing the presence of a feature in a ViT activation, then use the linear probe for inference. Record class-specific scores before and after manipulation to see that you can directly manipulate abilities to complete downstream tasks. Functions --------- `main(cfg: contrib.semseg.config.Manipulation)` : `manipulate(cfg: contrib.semseg.config.Manipulation, sae: saev.nn.SparseAutoencoder, acts_BWHD: jaxtyping.Float[Tensor, 'batch width height d_vit']) ‑> tuple[jaxtyping.Float[Tensor, 'batch width height d_vit'], jaxtyping.Float[Tensor, 'batch width height d_vit']]` : Module contrib.semseg.training ============================== Functions --------- `batched_idx(total_size: int, batch_size: int) ‑> Iterator[tuple[int, int]]` : Iterate over (start, end) indices for total_size examples, where end - start is at most batch_size. Args: total_size: total number of examples batch_size: maximum distance between the generated indices. Returns: A generator of (int, int) tuples that can slice up a list or a tensor. 
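A small usage illustration of `batched_idx` (the activation tensor and per-batch computation here are placeholders):

```python
import torch

from contrib.semseg.training import batched_idx

acts_ND = torch.randn(10_000, 768)  # placeholder activations
outs = []
for start, end in batched_idx(total_size=len(acts_ND), batch_size=1024):
    outs.append(acts_ND[start:end].mean(dim=-1))  # any per-batch computation
out_N = torch.cat(outs)
```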
`batched_upsample_and_pred(tensor: jaxtyping.Float[Tensor, 'n channels width height'], *, size: tuple[int, int], mode: str, batch_size: int = 128) ‑> jaxtyping.Int[Tensor, 'n {size[0]} {size[1]}']` : `check_cfgs(cfgs: list[contrib.semseg.config.Train])` : `count_patches(ade20k: saev.config.Ade20kDataset, patch_size_px: tuple[int, int] = (14, 14), threshold: float = 0.9, n_workers: int = 8)` : Count the number of patches in the data that meets `dump(cfg: contrib.semseg.config.Train, model: torch.nn.modules.module.Module, *, step: int | None = None)` : Save a model checkpoint to disk along with configuration, using the [trick from equinox](https://docs.kidger.site/equinox/examples/serialisation). `get_class_ious(y_pred: jaxtyping.Int[Tensor, 'models batch width height'], y_true: jaxtyping.Int[Tensor, 'models batch width height'], n_classes: int, ignore_class: int | None = 0) ‑> jaxtyping.Float[Tensor, 'models n_classes']` : Calculate mean IoU for predicted masks. Arguments: y_pred: y_true: n_classes: Number of classes. Returns: Mean IoU as a float tensor. `get_dataloader(cfg: contrib.semseg.config.Train, *, is_train: bool)` : `load(fpath: str, *, device: str = 'cpu') ‑> torch.nn.modules.module.Module` : Loads a sparse autoencoder from disk. `load_latest(dpath: str, *, device: str = 'cpu') ‑> torch.nn.modules.module.Module` : Loads the latest checkpoint by picking out the checkpoint file in dpath with the largest _step{step} suffix. Arguments: dpath: Directory to search. device: optional torch device to pass to load. `main(cfgs: list[contrib.semseg.config.Train])` : `make_models(cfgs: list[contrib.semseg.config.Train], d_vit: int) ‑> tuple[torch.nn.modules.container.ModuleList, list[dict[str, object]]]` : Classes ------- `Dataset(acts_cfg: saev.config.DataLoad, imgs_cfg: saev.config.Ade20kDataset, patch_size_px: tuple[int, int])` : An abstract class representing a :class:`Dataset`. All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:`__len__`, which is expected to return the size of the dataset by many :class:`~torch.utils.data.Sampler` implementations and the default options of :class:`~torch.utils.data.DataLoader`. Subclasses could also optionally implement :meth:`__getitems__`, for speedup batched samples loading. This method accepts list of indices of samples of batch and returns list of samples. .. note:: :class:`~torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided. ### Ancestors (in MRO) * torch.utils.data.dataset.Dataset * typing.Generic ### Instance variables `d_vit: int` : Module contrib.semseg.validation ================================ Make some predictions for a bunch of checkpoints. See which checkpoints have the best validation loss, mean IoU, class-specific IoU, validation accuracy, and qualitative results. Writes results to CSV files and hparam graphs (in-progress). Functions --------- `load_ckpts(root: str, *, device: str = 'cpu') ‑> list[tuple[contrib.semseg.config.Train, torch.nn.modules.module.Module]]` : Loads the latest checkpoints for each directory within root. Arguments: root: directory containing other directories with cfg.json and model_step{step}.pt files. device: where to load models. Returns: List of cfg, model pairs. 
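For example, to iterate over every trained probe under a checkpoint root (the path and the downstream loop are illustrative):

```python
from contrib.semseg.validation import load_ckpts

# One (cfg, model) pair per sub-directory of the checkpoint root.
for cfg, probe in load_ckpts("./checkpoints/contrib/semseg", device="cuda"):
    probe.eval()
    # ...run validation batches and record per-class IoU for this probe...
```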
`main(cfg: contrib.semseg.config.Validation)` : Module contrib.semseg.visuals ============================= Propose features for manual verification. Functions --------- `axis_unique(a: jaxtyping.Shaped[ndarray, '*axes'], axis: int = -1, return_counts: bool = True, *, null_value: int = -1) ‑> jaxtyping.Shaped[ndarray, '*axes'] | tuple[jaxtyping.Shaped[ndarray, '*axes'], jaxtyping.Int[ndarray, '*axes']]` : Calculate unique values and their counts along any axis of a matrix. Arguments: a: Input array axis: The axis along which to find unique values. return_counts: If true, also return the count of each unique value Returns: unique: Array of unique values, with zeros replacing duplicates counts: (optional) Count of each unique value (only if return_counts=True) `main(cfg: contrib.semseg.config.Visuals)` :