Module contrib.classification.training

Train a linear probe on [CLS] activations from a ViT.

Functions

def check_cfgs(cfgs: list[Train])
def dump_model(cfg: Train, model: torch.nn.modules.module.Module)

Save a model checkpoint to disk along with its configuration, using the trick from equinox.

def get_dataloader(cfg: Train, *, is_train: bool)
def load_acts(cfg: DataLoad) ‑> jaxtyping.Float[Tensor, 'n d_vit']
def load_class_headers(cfg: ImageFolderDataset) ‑> list[str]
def load_model(fpath: str, *, device: str = 'cpu') ‑> torch.nn.modules.module.Module

Loads a linear layer from disk.

def load_targets(cfg: ImageFolderDataset) ‑> jaxtyping.Int[Tensor, 'n']
def main(cfgs: list[Train])
def make_models(cfgs: list[Train], d_out: int) ‑> tuple[torch.nn.modules.container.ModuleList, list[dict[str, object]]]
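Judging by its return type, `make_models` builds one linear probe per config plus a parallel list of hyperparameter dicts. A simplified sketch under that assumption, taking raw hyperparameters instead of `Train` configs (the parameter names here are illustrative, not the module's):

```python
import torch


def make_models(
    d_vit: int, d_out: int, learning_rates: list[float]
) -> tuple[torch.nn.ModuleList, list[dict[str, object]]]:
    # One linear probe per hyperparameter setting; the parallel list of
    # dicts records each probe's hyperparameters for later bookkeeping.
    models = torch.nn.ModuleList(
        [torch.nn.Linear(d_vit, d_out) for _ in learning_rates]
    )
    params: list[dict[str, object]] = [
        {"learning_rate": lr, "d_vit": d_vit} for lr in learning_rates
    ]
    return models, params
```

Training all probes in one sweep amortizes the cost of loading activations: each batch of `[CLS]` activations can be fed to every probe in the `ModuleList`.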

Classes

class Dataset (acts_cfg: DataLoad, imgs_cfg: ImageFolderDataset)

An abstract class representing a :class:Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:__getitem__, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:__len__, which is expected to return the size of the dataset by many :class:~torch.utils.data.Sampler implementations and the default options of :class:~torch.utils.data.DataLoader. Subclasses could also optionally implement :meth:__getitems__ to speed up batched sample loading. This method accepts a list of sample indices for a batch and returns a list of samples.

Note

:class:~torch.utils.data.DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
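To make the note concrete, here is a hedged sketch (the class names are invented for illustration) of a map-style dataset keyed by strings, paired with the custom sampler that the default integral-index sampler cannot replace:

```python
import torch


class KeyedDataset(torch.utils.data.Dataset):
    """Map-style dataset whose keys are strings, not integers."""

    def __init__(self, table: dict[str, int]):
        self.table = table

    def __getitem__(self, key: str) -> int:
        return self.table[key]

    def __len__(self) -> int:
        return len(self.table)


class KeySampler(torch.utils.data.Sampler):
    """Yields the dataset's string keys instead of integer indices."""

    def __init__(self, keys: list[str]):
        self.keys = keys

    def __iter__(self):
        return iter(self.keys)

    def __len__(self) -> int:
        return len(self.keys)


data = {"a": 1, "b": 2, "c": 3}
loader = torch.utils.data.DataLoader(
    KeyedDataset(data), sampler=KeySampler(list(data)), batch_size=2
)
```

The `Dataset` class below uses plain integer indices, so it works with the default sampler and needs no such customization.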

@beartype.beartype
class Dataset(torch.utils.data.Dataset):
    def __init__(
        self, acts_cfg: saev.config.DataLoad, imgs_cfg: saev.config.ImageFolderDataset
    ):
        self.acts = saev.activations.Dataset(acts_cfg)

        img_dataset = saev.activations.ImageFolder(imgs_cfg.root)
        self.targets = [tgt for sample, tgt in img_dataset.samples]
        self.labels = [img_dataset.classes[tgt] for tgt in self.targets]

    @property
    def d_vit(self) -> int:
        return self.acts.metadata.d_vit

    @property
    def n_classes(self) -> int:
        return len(set(self.targets))

    def __getitem__(self, i: int) -> dict[str, object]:
        act_D = self.acts[i]["act"]
        label = self.labels[i]
        target = self.targets[i]

        return {"index": i, "acts": act_D, "labels": label, "targets": target}

    def __len__(self) -> int:
        assert len(self.acts) == len(self.targets)
        return len(self.targets)

Ancestors

  • torch.utils.data.dataset.Dataset
  • typing.Generic

Instance variables

prop d_vit : int

prop n_classes : int
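Because `__getitem__` returns a plain dict, torch's default collate function batches each field by key, stacking tensors and collecting integers into a `LongTensor`. A toy stand-in shows the pattern (`ToyActsDataset` is hypothetical, replacing the saev-backed activation store with a random tensor):

```python
import torch


class ToyActsDataset(torch.utils.data.Dataset):
    """Dict-returning dataset mimicking the Dataset above."""

    def __init__(self, acts: torch.Tensor, targets: list[int]):
        self.acts = acts
        self.targets = targets

    def __getitem__(self, i: int) -> dict[str, object]:
        return {"index": i, "acts": self.acts[i], "targets": self.targets[i]}

    def __len__(self) -> int:
        return len(self.targets)


ds = ToyActsDataset(torch.randn(10, 4), [i % 2 for i in range(10)])
dl = torch.utils.data.DataLoader(ds, batch_size=5)
batch = next(iter(dl))
# batch["acts"] is stacked to shape (5, 4); batch["targets"] and
# batch["index"] are collated into LongTensors of length 5.
```

The string-valued `"labels"` field in the real dataset is collated into a list of strings by the same default collate function.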