Module contrib.semseg.config

Functions

def grid(cfg: Train,
sweep_dct: dict[str, object]) -> tuple[list[Train], list[str]]

Classes

class Manipulation (probe_ckpt: str = './checkpoints/semseg/lr_0_001__wd_0_1/model.pt',
sae_ckpt: str = './checkpoints/abcdef/sae.pt',
ade20k_classes: list[int] = <factory>,
sae_latents: list[int] = <factory>,
acts: DataLoad = <factory>,
imgs: Ade20kDataset = <factory>,
batch_size: int = 128,
n_workers: int = 32,
device: str = 'cuda')

Manipulation(probe_ckpt: str = './checkpoints/semseg/lr_0_001__wd_0_1/model.pt', sae_ckpt: str = './checkpoints/abcdef/sae.pt', ade20k_classes: list[int] = <factory>, sae_latents: list[int] = <factory>, acts: saev.config.DataLoad = <factory>, imgs: saev.config.Ade20kDataset = <factory>, batch_size: int = 128, n_workers: int = 32, device: str = 'cuda')

Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Manipulation:
    """Configuration for manipulating SAE latents while tracking ADE20K classes.

    Bundles the linear probe and SAE checkpoints, the ADE20K classes to track,
    the SAE latents to manipulate, and the data-loading and hardware settings
    shared by linear probe and SAE inference.
    """

    probe_ckpt: str = os.path.join(
        ".", "checkpoints", "semseg", "lr_0_001__wd_0_1", "model.pt"
    )
    """Linear probe checkpoint."""
    sae_ckpt: str = os.path.join(".", "checkpoints", "abcdef", "sae.pt")
    """SAE checkpoint."""
    # Mutable defaults must go through default_factory so that instances never
    # share the same list object.
    ade20k_classes: list[int] = dataclasses.field(default_factory=lambda: [29])
    """One or more ADE20K classes to track."""
    sae_latents: list[int] = dataclasses.field(default_factory=lambda: [0, 1, 2])
    """One or more SAE latents to manipulate."""
    acts: saev.config.DataLoad = dataclasses.field(default_factory=saev.config.DataLoad)
    """Configuration for the saved ADE20K validation ViT activations."""
    imgs: saev.config.Ade20kDataset = dataclasses.field(
        default_factory=lambda: saev.config.Ade20kDataset(split="validation")
    )
    """Configuration for the ADE20K validation dataset."""
    batch_size: int = 128
    """Batch size for both linear probe and SAE."""
    n_workers: int = 32
    """Number of dataloader workers."""
    device: str = "cuda"
    """Hardware for linear probe and SAE inference."""

Class variables

var acts : DataLoad

Configuration for the saved ADE20K validation ViT activations.

var ade20k_classes : list[int]

One or more ADE20K classes to track.

var batch_size : int

Batch size for both linear probe and SAE.

var device : str

Hardware for linear probe and SAE inference.

var imgs : Ade20kDataset

Configuration for the ADE20K validation dataset.

var n_workers : int

Number of dataloader workers.

var probe_ckpt : str

Linear probe checkpoint.

var sae_ckpt : str

SAE checkpoint.

var sae_latents : list[int]

One or more SAE latents to manipulate.

class Train (learning_rate: float = 0.0001,
weight_decay: float = 0.001,
n_epochs: int = 400,
batch_size: int = 1024,
n_workers: int = 32,
imgs: Ade20kDataset = <factory>,
patch_size_px: tuple[int, int] = (14, 14),
eval_every: int = 100,
device: str = 'cuda',
ckpt_path: str = './checkpoints/contrib/semseg',
seed: int = 42,
log_to: str = './logs/contrib/semseg')

Train(learning_rate: float = 0.0001, weight_decay: float = 0.001, n_epochs: int = 400, batch_size: int = 1024, n_workers: int = 32, imgs: saev.config.Ade20kDataset = <factory>, patch_size_px: tuple[int, int] = (14, 14), eval_every: int = 100, device: str = 'cuda', ckpt_path: str = './checkpoints/contrib/semseg', seed: int = 42, log_to: str = './logs/contrib/semseg')

Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Train:
    """Configuration for training a linear layer on the ADE20K dataset.

    Covers the AdamW hyperparameters, the training schedule, data loading,
    hardware, the random seed, and the checkpoint and log paths.
    """

    learning_rate: float = 1e-4
    """Linear layer learning rate."""
    weight_decay: float = 1e-3
    """Weight decay for AdamW."""
    n_epochs: int = 400
    """Number of training epochs for linear layer."""
    batch_size: int = 1024
    """Training batch size for linear layer."""
    n_workers: int = 32
    """Number of dataloader workers."""
    imgs: saev.config.Ade20kDataset = dataclasses.field(
        default_factory=saev.config.Ade20kDataset
    )
    """Configuration for the ADE20K dataset."""
    patch_size_px: tuple[int, int] = (14, 14)
    """Patch size in pixels."""
    eval_every: int = 100
    """How many epochs between evaluations."""
    device: str = "cuda"
    """Hardware to train on."""
    # NOTE(review): how this path is used is not visible in this module; the
    # default matches `Validation.ckpt_root`. Confirm against the training script.
    ckpt_path: str = os.path.join(".", "checkpoints", "contrib", "semseg")
    """Checkpoint path (presumably where trained checkpoints are written)."""
    seed: int = 42
    """Random seed."""
    # NOTE(review): how this path is used is not visible in this module; the
    # default matches `Validation.dump_to`. Confirm against the training script.
    log_to: str = os.path.join(".", "logs", "contrib", "semseg")
    """Log path (presumably where training logs are written)."""

Class variables

var batch_size : int

Training batch size for linear layer.

var ckpt_path : str
var device : str

Hardware to train on.

var eval_every : int

How many epochs between evaluations.

var imgs : Ade20kDataset

Configuration for the ADE20K dataset.

var learning_rate : float

Linear layer learning rate.

var log_to : str
var n_epochs : int

Number of training epochs for linear layer.

var n_workers : int

Number of dataloader workers.

var patch_size_px : tuple[int, int]

Patch size in pixels.

var seed : int

Random seed.

var weight_decay : float

Weight decay for AdamW.

class Validation (ckpt_root: str = './checkpoints/contrib/semseg',
dump_to: str = './logs/contrib/semseg',
acts: DataLoad = <factory>,
imgs: Ade20kDataset = <factory>,
patch_size_px: tuple[int, int] = (14, 14),
batch_size: int = 128,
n_workers: int = 32,
device: str = 'cuda')

Validation(ckpt_root: str = './checkpoints/contrib/semseg', dump_to: str = './logs/contrib/semseg', acts: saev.config.DataLoad = <factory>, imgs: saev.config.Ade20kDataset = <factory>, patch_size_px: tuple[int, int] = (14, 14), batch_size: int = 128, n_workers: int = 32, device: str = 'cuda')

Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Validation:
    """Configuration for evaluating linear probe checkpoints on ADE20K validation data.

    Points at the root of the checkpoints to evaluate and the directory to dump
    results to, and carries the data-loading and hardware settings used while
    calculating F1 scores.
    """

    ckpt_root: str = os.path.join(".", "checkpoints", "contrib", "semseg")
    """Root to all checkpoints to evaluate."""
    dump_to: str = os.path.join(".", "logs", "contrib", "semseg")
    """Directory to dump results to."""
    acts: saev.config.DataLoad = dataclasses.field(default_factory=saev.config.DataLoad)
    """Configuration for the saved ADE20K validation ViT activations."""
    imgs: saev.config.Ade20kDataset = dataclasses.field(
        default_factory=lambda: saev.config.Ade20kDataset(split="validation")
    )
    """Configuration for the ADE20K validation dataset."""
    patch_size_px: tuple[int, int] = (14, 14)
    """Patch size in pixels."""
    batch_size: int = 128
    """Batch size for calculating F1 scores."""
    n_workers: int = 32
    """Number of dataloader workers."""
    device: str = "cuda"
    """Hardware for linear probe inference."""

Class variables

var acts : DataLoad

Configuration for the saved ADE20K validation ViT activations.

var batch_size : int

Batch size for calculating F1 scores.

var ckpt_root : str

Root to all checkpoints to evaluate.

var device : str

Hardware for linear probe inference.

var dump_to : str

Directory to dump results to.

var imgs : Ade20kDataset

Configuration for the ADE20K validation dataset.

var n_workers : int

Number of dataloader workers.

var patch_size_px : tuple[int, int]

Patch size in pixels.

class Visuals (sae_ckpt: str = './checkpoints/sae.pt',
ade20k_cls: int = 29,
k: int = 32,
acts: DataLoad = <factory>,
imgs: Ade20kDataset = <factory>,
batch_size: int = 128,
n_workers: int = 32,
label_threshold: float = 0.9,
device: str = 'cuda')

Visuals(sae_ckpt: str = './checkpoints/sae.pt', ade20k_cls: int = 29, k: int = 32, acts: saev.config.DataLoad = <factory>, imgs: saev.config.Ade20kDataset = <factory>, batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda')

Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Visuals:
    """Configuration for saving the top K SAE features for one ADE20K class.

    Names the SAE checkpoint, the ADE20K class to probe for, how many features
    to save, and the data-loading and hardware settings for SAE inference on
    the ADE20K training split.
    """

    sae_ckpt: str = os.path.join(".", "checkpoints", "sae.pt")
    """Path to the sae.pt file."""
    ade20k_cls: int = 29
    """ADE20K class to probe for."""
    k: int = 32
    """Top K features to save."""
    acts: saev.config.DataLoad = dataclasses.field(default_factory=saev.config.DataLoad)
    """Configuration for the saved ADE20K training ViT activations."""
    imgs: saev.config.Ade20kDataset = dataclasses.field(
        default_factory=lambda: saev.config.Ade20kDataset(split="training")
    )
    """Configuration for the ADE20K training dataset."""
    batch_size: int = 128
    """Batch size for calculating F1 scores."""
    n_workers: int = 32
    """Number of dataloader workers."""
    # NOTE(review): the exact meaning of this threshold is not visible in this
    # module; confirm against the script that consumes this config.
    label_threshold: float = 0.9
    """Label threshold; see the consuming script for its exact semantics."""
    device: str = "cuda"
    """Hardware for SAE inference."""

Class variables

var acts : DataLoad

Configuration for the saved ADE20K training ViT activations.

var ade20k_cls : int

ADE20K class to probe for.

var batch_size : int

Batch size for calculating F1 scores.

var device : str

Hardware for SAE inference.

var imgs : Ade20kDataset

Configuration for the ADE20K training dataset.

var k : int

Top K features to save.

var label_threshold : float
var n_workers : int

Number of dataloader workers.

var sae_ckpt : str

Path to the sae.pt file.