Module contrib.semseg.config
Functions
def grid(cfg: Train,
sweep_dct: dict[str, object]) -> tuple[list[Train], list[str]]
Classes
class Manipulation (probe_ckpt: str = './checkpoints/semseg/lr_0_001__wd_0_1/model.pt',
sae_ckpt: str = './checkpoints/abcdef/sae.pt',
ade20k_classes: list[int] = <factory>,
sae_latents: list[int] = <factory>,
acts: DataLoad = <factory>,
imgs: Ade20kDataset = <factory>,
batch_size: int = 128,
n_workers: int = 32,
device: str = 'cuda')-
Manipulation(probe_ckpt: str = './checkpoints/semseg/lr_0_001__wd_0_1/model.pt', sae_ckpt: str = './checkpoints/abcdef/sae.pt', ade20k_classes: list[int] = <factory>, sae_latents: list[int] = <factory>, acts: saev.config.DataLoad = <factory>, imgs: saev.config.Ade20kDataset = <factory>, batch_size: int = 128, n_workers: int = 32, device: str = 'cuda')
Expand source code
# Frozen, beartype-checked dataclass of options. Each field is described by
# the string literal that immediately follows it: a linear-probe checkpoint,
# an SAE checkpoint, the ADE20K classes to track, the SAE latents to
# manipulate, plus data-loading and hardware settings.
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Manipulation:
    # --- checkpoints -------------------------------------------------------
    probe_ckpt: str = os.path.join(
        ".", "checkpoints", "semseg", "lr_0_001__wd_0_1", "model.pt"
    )
    """Linear probe checkpoint."""

    sae_ckpt: str = os.path.join(".", "checkpoints", "abcdef", "sae.pt")
    """SAE checkpoint."""

    # --- what to track / manipulate ----------------------------------------
    # Mutable defaults go through default_factory so that each instance gets
    # its own list.
    ade20k_classes: list[int] = dataclasses.field(default_factory=lambda: [29])
    """One or more ADE20K classes to track."""

    sae_latents: list[int] = dataclasses.field(default_factory=lambda: [0, 1, 2])
    """One or more SAE latents to manipulate."""

    # --- data --------------------------------------------------------------
    acts: saev.config.DataLoad = dataclasses.field(
        default_factory=saev.config.DataLoad
    )
    """Configuration for the saved ADE20K validation ViT activations."""

    imgs: saev.config.Ade20kDataset = dataclasses.field(
        default_factory=lambda: saev.config.Ade20kDataset(split="validation")
    )
    """Configuration for the ADE20K validation dataset."""

    # --- runtime -----------------------------------------------------------
    batch_size: int = 128
    """Batch size for both linear probe and SAE."""

    n_workers: int = 32
    """Number of dataloader workers."""

    device: str = "cuda"
    """Hardware for linear probe and SAE inference."""
Class variables
var acts : DataLoad
-
Configuration for the saved ADE20K validation ViT activations.
var ade20k_classes : list[int]
-
One or more ADE20K classes to track.
var batch_size : int
-
Batch size for both linear probe and SAE.
var device : str
-
Hardware for linear probe and SAE inference.
var imgs : Ade20kDataset
-
Configuration for the ADE20K validation dataset.
var n_workers : int
-
Number of dataloader workers.
var probe_ckpt : str
-
Linear probe checkpoint.
var sae_ckpt : str
-
SAE checkpoint.
var sae_latents : list[int]
-
One or more SAE latents to manipulate.
class Train (learning_rate: float = 0.0001,
weight_decay: float = 0.001,
n_epochs: int = 400,
batch_size: int = 1024,
n_workers: int = 32,
imgs: Ade20kDataset = <factory>,
patch_size_px: tuple[int, int] = (14, 14),
eval_every: int = 100,
device: str = 'cuda',
ckpt_path: str = './checkpoints/contrib/semseg',
seed: int = 42,
log_to: str = './logs/contrib/semseg')-
Train(learning_rate: float = 0.0001, weight_decay: float = 0.001, n_epochs: int = 400, batch_size: int = 1024, n_workers: int = 32, imgs: saev.config.Ade20kDataset = <factory>, patch_size_px: tuple[int, int] = (14, 14), eval_every: int = 100, device: str = 'cuda', ckpt_path: str = './checkpoints/contrib/semseg', seed: int = 42, log_to: str = './logs/contrib/semseg')
Expand source code
# Frozen, beartype-checked dataclass of options for training the linear
# layer. Each field is described by the string literal that immediately
# follows it; fields without one carry a review note instead.
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Train:
    # --- optimisation ------------------------------------------------------
    learning_rate: float = 1e-4
    """Linear layer learning rate."""

    weight_decay: float = 1e-3
    """Weight decay for AdamW."""

    n_epochs: int = 400
    """Number of training epochs for linear layer."""

    batch_size: int = 1024
    """Training batch size for linear layer."""

    # --- data --------------------------------------------------------------
    n_workers: int = 32
    """Number of dataloader workers."""

    imgs: saev.config.Ade20kDataset = dataclasses.field(
        default_factory=saev.config.Ade20kDataset
    )
    """Configuration for the ADE20K dataset."""

    patch_size_px: tuple[int, int] = (14, 14)
    """Patch size in pixels."""

    # --- evaluation / runtime ----------------------------------------------
    eval_every: int = 100
    """How many epochs between evaluations."""

    device: str = "cuda"
    """Hardware to train on."""

    # NOTE(review): no docstring in the original, so the generated docs list
    # this field without a description. Presumably the directory checkpoints
    # are written to - confirm and document.
    ckpt_path: str = os.path.join(".", "checkpoints", "contrib", "semseg")

    seed: int = 42
    """Random seed."""

    # NOTE(review): no docstring in the original. Presumably the directory
    # logs are written to - confirm and document.
    log_to: str = os.path.join(".", "logs", "contrib", "semseg")
Class variables
var batch_size : int
-
Training batch size for linear layer.
var ckpt_path : str
var device : str
-
Hardware to train on.
var eval_every : int
-
How many epochs between evaluations.
var imgs : Ade20kDataset
-
Configuration for the ADE20K dataset.
var learning_rate : float
-
Linear layer learning rate.
var log_to : str
var n_epochs : int
-
Number of training epochs for linear layer.
var n_workers : int
-
Number of dataloader workers.
var patch_size_px : tuple[int, int]
-
Patch size in pixels.
var seed : int
-
Random seed.
var weight_decay : float
-
Weight decay for AdamW.
class Validation (ckpt_root: str = './checkpoints/contrib/semseg',
dump_to: str = './logs/contrib/semseg',
acts: DataLoad = <factory>,
imgs: Ade20kDataset = <factory>,
patch_size_px: tuple[int, int] = (14, 14),
batch_size: int = 128,
n_workers: int = 32,
device: str = 'cuda')-
Validation(ckpt_root: str = './checkpoints/contrib/semseg', dump_to: str = './logs/contrib/semseg', acts: saev.config.DataLoad = <factory>, imgs: saev.config.Ade20kDataset = <factory>, patch_size_px: tuple[int, int] = (14, 14), batch_size: int = 128, n_workers: int = 32, device: str = 'cuda')
Expand source code
# Frozen, beartype-checked dataclass of options for evaluating checkpoints.
# Each field is described by the string literal that immediately follows it.
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Validation:
    # --- filesystem --------------------------------------------------------
    ckpt_root: str = os.path.join(".", "checkpoints", "contrib", "semseg")
    """Root to all checkpoints to evaluate."""

    dump_to: str = os.path.join(".", "logs", "contrib", "semseg")
    """Directory to dump results to."""

    # --- data --------------------------------------------------------------
    acts: saev.config.DataLoad = dataclasses.field(
        default_factory=saev.config.DataLoad
    )
    """Configuration for the saved ADE20K validation ViT activations."""

    imgs: saev.config.Ade20kDataset = dataclasses.field(
        default_factory=lambda: saev.config.Ade20kDataset(split="validation")
    )
    """Configuration for the ADE20K validation dataset."""

    patch_size_px: tuple[int, int] = (14, 14)
    """Patch size in pixels."""

    # --- runtime -----------------------------------------------------------
    batch_size: int = 128
    """Batch size for calculating F1 scores."""

    n_workers: int = 32
    """Number of dataloader workers."""

    device: str = "cuda"
    """Hardware for linear probe inference."""
Class variables
var acts : DataLoad
-
Configuration for the saved ADE20K validation ViT activations.
var batch_size : int
-
Batch size for calculating F1 scores.
var ckpt_root : str
-
Root to all checkpoints to evaluate.
var device : str
-
Hardware for linear probe inference.
var dump_to : str
-
Directory to dump results to.
var imgs : Ade20kDataset
-
Configuration for the ADE20K validation dataset.
var n_workers : int
-
Number of dataloader workers.
var patch_size_px : tuple[int, int]
-
Patch size in pixels.
class Visuals (sae_ckpt: str = './checkpoints/sae.pt',
ade20k_cls: int = 29,
k: int = 32,
acts: DataLoad = <factory>,
imgs: Ade20kDataset = <factory>,
batch_size: int = 128,
n_workers: int = 32,
label_threshold: float = 0.9,
device: str = 'cuda')-
Visuals(sae_ckpt: str = './checkpoints/sae.pt', ade20k_cls: int = 29, k: int = 32, acts: saev.config.DataLoad = <factory>, imgs: saev.config.Ade20kDataset = <factory>, batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda')
Expand source code
# Frozen, beartype-checked dataclass of options. Each field is described by
# the string literal that immediately follows it; fields without one carry a
# review note instead.
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Visuals:
    # --- what to look at ---------------------------------------------------
    sae_ckpt: str = os.path.join(".", "checkpoints", "sae.pt")
    """Path to the sae.pt file."""

    ade20k_cls: int = 29
    """ADE20K class to probe for."""

    k: int = 32
    """Top K features to save."""

    # --- data --------------------------------------------------------------
    acts: saev.config.DataLoad = dataclasses.field(
        default_factory=saev.config.DataLoad
    )
    """Configuration for the saved ADE20K training ViT activations."""

    imgs: saev.config.Ade20kDataset = dataclasses.field(
        default_factory=lambda: saev.config.Ade20kDataset(split="training")
    )
    """Configuration for the ADE20K training dataset."""

    # --- runtime -----------------------------------------------------------
    batch_size: int = 128
    """Batch size for calculating F1 scores."""

    n_workers: int = 32
    """Number of dataloader workers."""

    # NOTE(review): no docstring in the original, so the generated docs list
    # this field without a description. What the 0.9 threshold is applied to
    # cannot be determined from this file - confirm and document.
    label_threshold: float = 0.9

    device: str = "cuda"
    """Hardware for SAE inference."""
Class variables
var acts : DataLoad
-
Configuration for the saved ADE20K training ViT activations.
var ade20k_cls : int
-
ADE20K class to probe for.
var batch_size : int
-
Batch size for calculating F1 scores.
var device : str
-
Hardware for SAE inference.
var imgs : Ade20kDataset
-
Configuration for the ADE20K training dataset.
var k : int
-
Top K features to save.
var label_threshold : float
var n_workers : int
-
Number of dataloader workers.
var sae_ckpt : str
-
Path to the sae.pt file.