Module sweep

Launch script to kick off a hyperparameter sweep using a Slurm cluster (with submitit).

Functions

def expand(config: dict[str, float | int | bool | str | list[float | int | bool | str] | Distribution], *, n_per_discrete: int) -> collections.abc.Iterator[dict[str, float | int | bool | str]]
def expand_discrete(config: dict[str, float | int | bool | str | list[float | int | bool | str] | Distribution]) -> collections.abc.Iterator[dict[str, float | int | bool | str]]

Expands any list values in config.
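The description above leaves the expansion rule implicit. The following is a minimal sketch, not the module's code. It assumes each list-valued key is swept independently, so the output is the Cartesian product of all lists. Non-list values, including Distribution entries, are copied unchanged into every yielded config.

```python
import itertools
from collections.abc import Iterator


def expand_discrete(config: dict[str, object]) -> Iterator[dict[str, object]]:
    """Yield one config per combination of list-valued entries (Cartesian product).

    Sketch only: assumes every list value is an independent axis of the sweep.
    """
    # Split keys into swept (list-valued) and fixed (everything else).
    list_keys = [k for k, v in config.items() if isinstance(v, list)]
    fixed = {k: v for k, v in config.items() if not isinstance(v, list)}

    # itertools.product() over zero iterables yields one empty tuple,
    # so a config without any lists is yielded exactly once, unchanged.
    for combo in itertools.product(*(config[k] for k in list_keys)):
        yield {**fixed, **dict(zip(list_keys, combo))}


configs = list(
    expand_discrete({"learning_rate": [1e-3, 3e-4], "n_layers": [6, 12], "seed": 42})
)
print(len(configs))  # 4
```

Under this assumption, the number of configs is the product of the list lengths. The sweep therefore grows quickly as more keys become lists.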

def main(config_file: str, /, n_per_discrete: int = 1, override: Args = Args(seed=42, p_dropout=0.2, model_d=128, n_layers=6, init_std=0.02, resize_size=256, crop_size=224, n_classes=1000, v2_dir='', batch_size=256, n_workers=4, p_mixup=0.2, pin_memory=False, learning_rate=0.001, lr_schedule='warmup', n_lr_warmup=10000, beta1=0.9, beta2=0.999, grad_clip=1.0, grad_accum=1, weight_decay=0.0001, n_epochs=90, do_mup=True, mup_base_d=128, log_every=10, track=True, ckpt_dir='./checkpoints', tags=[]), slurm: bool = False, n_cpus: int = 0, n_gpus: int = 0, n_hours: int = 0, sacct: str = '')

Start a hyperparameter sweep of training runs using either a Slurm cluster or a local GPU. Results are written to a SQLite file, which can be queried for final metrics to make plots like those in SAE papers (comparing sparsity and reconstruction loss).

Args

config_file
path to the config file describing the sweep.
n_per_discrete
number of random samples to draw for each discrete config.
override
individual arguments that you want to override for all jobs.
slurm
whether to run jobs on a Slurm cluster (True) or on a local GPU (False).
n_cpus
(Slurm only) how many CPUs to use; should be at least Args.n_workers.
n_gpus
(Slurm only) how many GPUs to use.
n_hours
(Slurm only) time limit for each Slurm job, in hours.
sacct
(Slurm only) the Slurm account.
def overwrite(args: Args, override: Args) -> Args

If override has any non-default values, returns a copy of args with those values applied.

Args

args
sweep args.
override
incoming args with zero or more non-default values.

Returns

frx.train.Args
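The following is a minimal sketch of the merge rule described above. It assumes Args is a dataclass, which its repr suggests. The three-field Args here is a hypothetical stand-in for frx.train.Args, not the real class.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Args:
    # Hypothetical stand-in for frx.train.Args with only three fields.
    seed: int = 42
    learning_rate: float = 0.001
    n_layers: int = 6


def overwrite(args: Args, override: Args) -> Args:
    """Return a copy of `args` with every non-default field of `override` applied."""
    defaults = Args()
    # Collect only the fields whose value in `override` differs from the default.
    changes = {
        f.name: getattr(override, f.name)
        for f in dataclasses.fields(override)
        if getattr(override, f.name) != getattr(defaults, f.name)
    }
    # dataclasses.replace builds a new instance; `args` itself is not modified.
    return dataclasses.replace(args, **changes)


sweep_args = Args(learning_rate=3e-4, n_layers=12)
merged = overwrite(sweep_args, Args(seed=0))
print(merged)  # Args(seed=0, learning_rate=0.0003, n_layers=12)
```

One consequence of comparing against defaults is that an override explicitly set to its default value (for example `Args(seed=42)`) cannot be told apart from "not set", so it never replaces a non-default value in args.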

def roberts_sequence(num_points: int, dim: int, root_iters: int = 10000, complement_basis: bool = True, perturb: bool = True, key: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, None] = None, dtype=float)

Returns the Roberts sequence, a low-discrepancy quasi-random sequence. Low-discrepancy sequences are useful for quasi-Monte Carlo methods.

Reference: Martin Roberts. The Unreasonable Effectiveness of Quasirandom Sequences. extremelearning.com.au/unreasonable-effectiveness-of-quasirandom-sequences

Args

num_points
Number of points to return.
dim
The dimensionality of each point in the sequence.
root_iters
Number of iterations to use to find the root.
complement_basis
Complement the basis to improve precision, as described in https://www.martysmods.com/a-better-r2-sequence.
key
A PRNG key.
dtype
Optional float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns

An array of shape (num_points, dim) containing the sequence.

From https://github.com/jax-ml/jax/pull/23808
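The following is a minimal pure-Python sketch of the basic R_d construction from Roberts' article. It is not the JAX PR code. Let phi_d be the unique positive root of x^(d+1) = x + 1. Set alpha_j = phi_d^-(j+1) and take point n to be frac(0.5 + n * alpha). The sketch omits complement_basis, perturb, key and dtype. It finds the root by fixed-point iteration, which is an assumption about how root_iters is used. Starting n at 1 is also a choice made here.

```python
import math


def roberts_sequence(num_points: int, dim: int, root_iters: int = 10_000) -> list[list[float]]:
    """Basic R_d sequence as a list of `num_points` points, each with `dim` coordinates in [0, 1)."""
    # Fixed-point iteration phi <- (1 + phi) ** (1 / (dim + 1)) converges to the
    # positive root of x ** (dim + 1) = x + 1 (the golden ratio when dim == 1).
    phi = 2.0
    for _ in range(root_iters):
        phi = (1.0 + phi) ** (1.0 / (dim + 1))
    alpha = [phi ** -(j + 1) for j in range(dim)]
    # Additive recurrence with seed 0.5; "% 1.0" keeps the fractional part.
    return [[(0.5 + n * a) % 1.0 for a in alpha] for n in range(1, num_points + 1)]


points = roberts_sequence(4, 2)
first = roberts_sequence(1, 1)[0][0]  # frac(0.5 + 1 / golden_ratio)
print(round(first, 6))  # 0.118034
```

Each coordinate is a rotation by an irrational step. This is why successive points fill [0, 1)^d evenly without the clumping of independent uniform draws.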

def sample_from(config: dict[str, float | int | bool | str | Distribution], *, n: int) -> collections.abc.Iterator[dict[str, float | int | bool | str]]
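sample_from has no docstring. The following is a minimal sketch of what its signature suggests. It assumes each Distribution entry is replaced by an independent log-uniform draw and every other value is copied through. The `seed` parameter and the use of `random.Random` are hypothetical; the module might instead draw quasi-random values with roberts_sequence.

```python
import math
import random
from collections.abc import Iterator


def sample_from(config: dict[str, object], *, n: int, seed: int = 0) -> Iterator[dict[str, object]]:
    """Yield `n` configs in which every Distribution entry is replaced by a sample."""
    rng = random.Random(seed)  # hypothetical seed argument, not in the documented signature
    for _ in range(n):
        sample = {}
        for key, value in config.items():
            if isinstance(value, dict) and value.get("dist") == "loguniform":
                # Log-uniform: draw uniformly in log space, then exponentiate.
                lo, hi = math.log(value["min"]), math.log(value["max"])
                sample[key] = math.exp(rng.uniform(lo, hi))
            else:
                sample[key] = value  # fixed values pass through unchanged
        yield sample


config = {"learning_rate": {"dist": "loguniform", "min": 1e-5, "max": 1e-2}, "n_layers": 6}
samples = list(sample_from(config, n=3))
```

Sampling in log space gives each decade between min and max equal probability, which suits scale parameters such as learning rates.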

Classes

class Distribution (*args, **kwargs)


Ancestors

  • builtins.dict

Class variables

var dist : Literal['loguniform']
var max : float
var min : float
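Distribution subclasses dict and declares typed class variables, which suggests a typing.TypedDict. The following is a sketch under that assumption, mirroring the three variables listed above. The config below is a hypothetical example that mixes a fixed value, a discrete list, and a distribution, using field names from Args.

```python
from typing import Literal, TypedDict


class Distribution(TypedDict):
    # Mirrors the documented class variables; "loguniform" is the only documented kind.
    dist: Literal["loguniform"]
    min: float
    max: float


# Hypothetical sweep config: one fixed value, one discrete list, one distribution.
config = {
    "n_layers": 6,
    "batch_size": [128, 256],
    "learning_rate": Distribution(dist="loguniform", min=1e-5, max=1e-2),
}
```

At runtime a TypedDict call returns a plain dict, so such an entry can be written directly as a dict literal in a config file.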