Module contrib.classification.config

Functions

def grid(cfg: Train,
sweep_dct: dict[str, object]) -> tuple[list[Train], list[str]]

Classes

class Train (learning_rate: float = 0.0001,
weight_decay: float = 0.001,
n_epochs: int = 20,
batch_size: int = 512,
n_workers: int = 32,
train_imgs: ImageFolderDataset = <factory>,
val_imgs: ImageFolderDataset = <factory>,
device: str = 'cuda',
ckpt_path: str = './checkpoints/contrib/classification',
seed: int = 42,
log_to: str = './logs/contrib/classification')

Train(learning_rate: float = 0.0001, weight_decay: float = 0.001, n_epochs: int = 20, batch_size: int = 512, n_workers: int = 32, train_imgs: saev.config.ImageFolderDataset = <factory>, val_imgs: saev.config.ImageFolderDataset = <factory>, device: str = 'cuda', ckpt_path: str = './checkpoints/contrib/classification', seed: int = 42, log_to: str = './logs/contrib/classification')

Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Train:
    """Configuration for training a linear classification layer.

    The dataclass is frozen, so instances are immutable after construction;
    use `dataclasses.replace` to derive a modified copy. Argument types are
    checked at runtime by `beartype`.
    """

    learning_rate: float = 1e-4
    """Linear layer learning rate."""
    weight_decay: float = 1e-3
    """Weight decay for AdamW."""
    n_epochs: int = 20
    """Number of training epochs for linear layer."""
    batch_size: int = 512
    """Training batch size."""
    n_workers: int = 32
    """Number of dataloader workers."""
    # default_factory builds a fresh ImageFolderDataset config per instance
    # rather than sharing one default object across all Train instances.
    train_imgs: saev.config.ImageFolderDataset = dataclasses.field(
        default_factory=saev.config.ImageFolderDataset
    )
    """Configuration for the training images."""
    val_imgs: saev.config.ImageFolderDataset = dataclasses.field(
        default_factory=saev.config.ImageFolderDataset
    )
    """Configuration for the validation images."""
    device: str = "cuda"
    """Hardware to train on."""
    # NOTE(review): presumably a directory that checkpoints are written to;
    # confirm against the training script.
    ckpt_path: str = os.path.join(".", "checkpoints", "contrib", "classification")
    """Path for model checkpoints."""
    seed: int = 42
    """Random seed."""
    # NOTE(review): presumably a directory that log files are written to;
    # confirm against the training script.
    log_to: str = os.path.join(".", "logs", "contrib", "classification")
    """Path to write logs to."""

Class variables

var batch_size : int

Training batch size.

var ckpt_path : str
var device : str

Hardware to train on.

var learning_rate : float

Linear layer learning rate.

var log_to : str
var n_epochs : int

Number of training epochs for linear layer.

var n_workers : int

Number of dataloader workers.

var seed : int

Random seed.

var train_imgs : ImageFolderDataset

Configuration for the training images.

var val_imgs : ImageFolderDataset

Configuration for the validation images.

var weight_decay : float

Weight decay for AdamW.