Module contrib.classification.training
Train a linear probe on [CLS] activations from a ViT.
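A linear probe here is a single linear layer trained on cached [CLS] activations with a standard classification loss. The module's own training loop is driven by main(cfgs); the snippet below is only a minimal, self-contained sketch of the idea, with placeholder shapes and hyperparameters rather than this module's actual configuration.

import torch

d_vit, n_classes = 768, 10          # placeholder dimensions
acts = torch.randn(1024, d_vit)     # stand-in for cached [CLS] activations, shape (n, d_vit)
targets = torch.randint(0, n_classes, (1024,))

probe = torch.nn.Linear(d_vit, n_classes)
optim = torch.optim.AdamW(probe.parameters(), lr=1e-3)

for epoch in range(10):
    logits = probe(acts)
    loss = torch.nn.functional.cross_entropy(logits, targets)
    loss.backward()
    optim.step()
    optim.zero_grad()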
Functions
def check_cfgs(cfgs: list[Train])
def dump_model(cfg: Train, model: torch.nn.modules.module.Module)
Save a model checkpoint to disk along with its configuration, using the trick from equinox (see the sketch after this list of functions).
def get_dataloader(cfg: Train, *, is_train: bool)
def load_acts(cfg: DataLoad) -> jaxtyping.Float[Tensor, 'n d_vit']
def load_class_headers(cfg: ImageFolderDataset) -> list[str]
def load_model(fpath: str, *, device: str = 'cpu') -> torch.nn.modules.module.Module
Loads a linear layer from disk.
def load_targets(cfg: ImageFolderDataset) -> jaxtyping.Int[Tensor, 'n']
def main(cfgs: list[Train])
def make_models(cfgs: list[Train], d_out: int) -> tuple[torch.nn.modules.container.ModuleList, list[dict[str, object]]]
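The "trick from equinox" referenced by dump_model is to write a small JSON header describing how to rebuild the module, followed by the serialized weights, so that load_model can reconstruct the layer before loading its weights. The sketch below illustrates that pattern for a plain torch.nn.Linear; the function names, header fields, and file layout are illustrative assumptions, not this module's exact on-disk format.

import io
import json
import torch

def dump_linear(fpath: str, model: torch.nn.Linear) -> None:
    # First line: a JSON header with everything needed to rebuild the layer.
    # Rest of the file: the serialized state dict.
    header = {"in_features": model.in_features, "out_features": model.out_features}
    with open(fpath, "wb") as fd:
        fd.write((json.dumps(header) + "\n").encode("utf-8"))
        torch.save(model.state_dict(), fd)

def load_linear(fpath: str, *, device: str = "cpu") -> torch.nn.Linear:
    with open(fpath, "rb") as fd:
        header = json.loads(fd.readline().decode("utf-8"))
        buffer = io.BytesIO(fd.read())
    model = torch.nn.Linear(header["in_features"], header["out_features"])
    model.load_state_dict(torch.load(buffer, map_location=device, weights_only=True))
    return model.to(device)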
Classes
class Dataset (acts_cfg: DataLoad, imgs_cfg: ImageFolderDataset)
A map-style dataset that pairs cached ViT activations with the targets and human-readable labels of an ImageFolder dataset.
Docstring inherited from torch.utils.data.Dataset: An abstract class representing a Dataset. All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__, which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader. Subclasses could also optionally implement __getitems__ to speed up loading of batched samples; this method accepts a list of sample indices for a batch and returns a list of samples.
Note: DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
@beartype.beartype
class Dataset(torch.utils.data.Dataset):
    def __init__(
        self, acts_cfg: saev.config.DataLoad, imgs_cfg: saev.config.ImageFolderDataset
    ):
        self.acts = saev.activations.Dataset(acts_cfg)
        img_dataset = saev.activations.ImageFolder(imgs_cfg.root)
        self.targets = [tgt for sample, tgt in img_dataset.samples]
        self.labels = [img_dataset.classes[tgt] for tgt in self.targets]

    @property
    def d_vit(self) -> int:
        return self.acts.metadata.d_vit

    @property
    def n_classes(self) -> int:
        return len(set(self.targets))

    def __getitem__(self, i: int) -> dict[str, object]:
        act_D = self.acts[i]["act"]
        label = self.labels[i]
        target = self.targets[i]
        return {"index": i, "acts": act_D, "labels": label, "targets": target}

    def __len__(self) -> int:
        assert len(self.acts) == len(self.targets)
        return len(self.targets)
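A usage sketch, assuming activations have already been cached with saev.activations. The shard_root field on DataLoad is an assumed name; root on ImageFolderDataset is the field the constructor above actually reads.

import saev.config

# Hypothetical paths; `shard_root` is an assumed DataLoad field name.
acts_cfg = saev.config.DataLoad(shard_root="/path/to/cached/acts")
imgs_cfg = saev.config.ImageFolderDataset(root="/path/to/imagefolder")

dataset = Dataset(acts_cfg, imgs_cfg)
print(dataset.d_vit, dataset.n_classes, len(dataset))

sample = dataset[0]
# sample["acts"] is a (d_vit,) activation tensor; sample["targets"] is the
# integer class index and sample["labels"] is its human-readable class name.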
Ancestors
- torch.utils.data.dataset.Dataset
- typing.Generic
Instance variables
prop d_vit : int
The width of the ViT activations, read from the activation dataset's metadata.
@property
def d_vit(self) -> int:
    return self.acts.metadata.d_vit
prop n_classes : int
The number of distinct classes, computed from the unique targets in the image dataset.
@property
def n_classes(self) -> int:
    return len(set(self.targets))
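Together these two properties determine the probe shape. A hedged sketch of how a head could be sized from them; the actual construction lives in make_models, which is not shown on this page.

import torch

def make_probe(dataset: Dataset) -> torch.nn.Linear:
    # One linear head mapping ViT activations to class logits.
    return torch.nn.Linear(dataset.d_vit, dataset.n_classes)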