Module contrib.classification.config
Functions
def grid(cfg: Train,
sweep_dct: dict[str, object]) -> tuple[list[Train], list[str]]
Classes
class Train (learning_rate: float = 0.0001,
weight_decay: float = 0.001,
n_epochs: int = 20,
batch_size: int = 512,
n_workers: int = 32,
train_imgs: ImageFolderDataset = <factory>,
val_imgs: ImageFolderDataset = <factory>,
device: str = 'cuda',
ckpt_path: str = './checkpoints/contrib/classification',
seed: int = 42,
log_to: str = './logs/contrib/classification')
Train(learning_rate: float = 0.0001, weight_decay: float = 0.001, n_epochs: int = 20, batch_size: int = 512, n_workers: int = 32, train_imgs: saev.config.ImageFolderDataset = <factory>, val_imgs: saev.config.ImageFolderDataset = <factory>, device: str = 'cuda', ckpt_path: str = './checkpoints/contrib/classification', seed: int = 42, log_to: str = './logs/contrib/classification')
Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Train:
    """Configuration for training the linear classification layer.

    Instances are frozen (immutable) and every field has a default.
    """

    learning_rate: float = 1e-4
    """Linear layer learning rate."""
    weight_decay: float = 1e-3
    """Weight decay for AdamW."""
    n_epochs: int = 20
    """Number of training epochs for linear layer."""
    batch_size: int = 512
    """Training batch size."""
    n_workers: int = 32
    """Number of dataloader workers."""
    # default_factory gives each Train instance its own ImageFolderDataset
    # rather than sharing one default object across instances.
    train_imgs: saev.config.ImageFolderDataset = dataclasses.field(
        default_factory=saev.config.ImageFolderDataset
    )
    """Configuration for the training images."""
    val_imgs: saev.config.ImageFolderDataset = dataclasses.field(
        default_factory=saev.config.ImageFolderDataset
    )
    """Configuration for the validation images."""
    device: str = "cuda"
    """Hardware to train on."""
    ckpt_path: str = os.path.join(".", "checkpoints", "contrib", "classification")
    """Checkpoint path; defaults to checkpoints/contrib/classification under the current working directory."""
    seed: int = 42
    """Random seed."""
    log_to: str = os.path.join(".", "logs", "contrib", "classification")
    """Log output path; defaults to logs/contrib/classification under the current working directory."""
Class variables
var batch_size : int
-
Training batch size.
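var ckpt_path : str
-
Checkpoint path; defaults to ./checkpoints/contrib/classification (relative to the current working directory).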
var device : str
-
Hardware to train on.
var learning_rate : float
-
Linear layer learning rate.
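var log_to : str
-
Log output path; defaults to ./logs/contrib/classification (relative to the current working directory).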
var n_epochs : int
-
Number of training epochs for linear layer.
var n_workers : int
-
Number of dataloader workers.
var seed : int
-
Random seed.
var train_imgs : ImageFolderDataset
-
Configuration for the training images.
var val_imgs : ImageFolderDataset
-
Configuration for the validation images.
var weight_decay : float
-
Weight decay for AdamW.