Module sweep
Launch script to kick off a hyperparameter sweep using a Slurm cluster (with submitit).
Functions
def expand(config: dict[str, float | int | bool | str | list[float | int | bool | str] | Distribution], *, n_per_discrete: int) ‑> collections.abc.Iterator[dict[str, float | int | bool | str]]
def expand_discrete(config: dict[str, float | int | bool | str | list[float | int | bool | str] | Distribution]) ‑> collections.abc.Iterator[dict[str, float | int | bool | str]]
-
Expands any list values in config.
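A minimal sketch of how these two generators might be used; the config keys below are illustrative, not the module's actual schema, and the Cartesian-product expansion of list values is an assumption based on the docstring:

# Hypothetical config; keys are illustrative only.
discrete_config = {
    "seed": 42,              # scalar: passed through unchanged
    "n_layers": [4, 6, 8],   # list: expanded into one config per element
    "p_dropout": [0.0, 0.2],
}

# Assumed behavior: one config per element of the cross product of the
# list values, i.e. 3 * 2 = 6 configs here.
for cfg in expand_discrete(discrete_config):
    print(cfg["n_layers"], cfg["p_dropout"])

# expand also accepts Distribution values (see the Distribution class
# below) and draws n_per_discrete random samples per discrete config.
random_config = {
    **discrete_config,
    "learning_rate": Distribution(dist="loguniform", min=1e-5, max=1e-2),
}
for cfg in expand(random_config, n_per_discrete=2):
    print(cfg["learning_rate"])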
def main(config_file: str, /, n_per_discrete: int = 1, override: Args = Args(seed=42, p_dropout=0.2, model_d=128, n_layers=6, init_std=0.02, resize_size=256, crop_size=224, n_classes=1000, v2_dir='', batch_size=256, n_workers=4, p_mixup=0.2, pin_memory=False, learning_rate=0.001, lr_schedule='warmup', n_lr_warmup=10000, beta1=0.9, beta2=0.999, grad_clip=1.0, grad_accum=1, weight_decay=0.0001, n_epochs=90, do_mup=True, mup_base_d=128, log_every=10, track=True, ckpt_dir='./checkpoints', tags=[]), slurm: bool = False, n_cpus: int = 0, n_gpus: int = 0, n_hours: int = 0, sacct: str = '')
-
Start a hyperparameter sweep of training runs using either a Slurm cluster or a local GPU. Results are written to a SQLite file, which can be queried for final metrics to make plots like those in SAE papers (comparing sparsity and reconstruction loss).
Args
config_file
- filepath to a config file.
n_per_discrete
- number of random samples to draw for each discrete config.
override
- individual arguments that you want to override for all jobs.
slurm
- whether to use a slurm cluster for running jobs or a local GPU.
n_cpus
- (slurm only) how many cpus to use; should be at least Args.n_workers.
n_gpus
- (slurm only) how many gpus to use.
n_hours
- (slurm only) how many hours to run a slurm job for.
sacct
- (slurm only) the slurm account.
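A sketch of launching a sweep programmatically; the config path and account name are hypothetical, and the Args fields are taken from the signature above:

# Local run: draw 4 random samples per discrete config and override the
# learning rate for every job.
main(
    "configs/sweep.toml",            # hypothetical config filepath
    n_per_discrete=4,
    override=Args(learning_rate=3e-4),
    slurm=False,                     # run on the local GPU
)

# Slurm run: the (slurm only) options become relevant.
main(
    "configs/sweep.toml",
    slurm=True,
    n_cpus=8,                        # at least Args.n_workers
    n_gpus=1,
    n_hours=12,
    sacct="my-account",              # hypothetical Slurm account
)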
def overwrite(args: Args, override: Args) ‑> Args
-
If there are any non-default values in override, returns a copy of args with all those values included.
Args
args
- sweep args.
override
- incoming args with zero or more non-default values.
Returns
frx.train.Args
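A small sketch of the intended semantics, assuming Args is a dataclass whose defaults match those shown in main above:

args = Args(n_epochs=300)              # sweep args; n_epochs is non-default
override = Args(learning_rate=3e-4)    # only learning_rate is non-default

merged = overwrite(args, override)
assert merged.learning_rate == 3e-4    # non-default value taken from override
assert merged.n_epochs == 300          # default in override, so kept from args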
def roberts_sequence(num_points: int, dim: int, root_iters: int = 10000, complement_basis: bool = True, perturb: bool = True, key: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, None] = None, dtype=float)
-
Returns the Roberts sequence, a low-discrepancy quasi-random sequence. Low-discrepancy sequences are useful for quasi-Monte Carlo methods.
Reference: Martin Roberts. The Unreasonable Effectiveness of Quasirandom Sequences. extremelearning.com.au/unreasonable-effectiveness-of-quasirandom-sequences
Args
num_points
- Number of points to return.
dim
- The dimensionality of each point in the sequence.
root_iters
- Number of iterations to use to find the root.
complement_basis
- Complement the basis to improve precision, as described in https://www.martysmods.com/a-better-r2-sequence.
perturb
- whether to randomly perturb the sequence using key.
key
- a PRNG key.
dtype
- optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns
An array of shape (num_points, dim) containing the sequence.
Adapted from https://github.com/jax-ml/jax/pull/23808.
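A sketch of how the sequence might seed a sweep: points land in [0, 1)^dim, one dimension per hyperparameter; the log-uniform rescaling below is an illustration, not part of this module:

import jax
import jax.numpy as jnp

# 16 quasi-random points in the unit square.
points = roberts_sequence(16, dim=2, key=jax.random.key(0))

# Illustrative rescaling: map [0, 1)^2 to (learning_rate, weight_decay),
# log-uniform in both.
log_lr = jnp.log(1e-5) + points[:, 0] * (jnp.log(1e-2) - jnp.log(1e-5))
log_wd = jnp.log(1e-6) + points[:, 1] * (jnp.log(1e-3) - jnp.log(1e-6))
learning_rates, weight_decays = jnp.exp(log_lr), jnp.exp(log_wd)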
def sample_from(config: dict[str, float | int | bool | str | Distribution], *, n: int) ‑> collections.abc.Iterator[dict[str, float | int | bool | str]]
-
Draws n concrete configs from config, sampling a value for each Distribution.
Classes
class Distribution (*args, **kwargs)
-
A dict describing a random distribution to sample hyperparameter values from, with keys dist, min, and max.
Ancestors
- builtins.dict
Class variables
var dist : Literal['loguniform']
var max : float
var min : float
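A sketch tying Distribution and sample_from together; the exact sampling behavior is assumed from the signatures above:

# Distribution is a plain dict subclass, so it can be written inline in a
# config. Per the Literal annotation, only 'loguniform' is supported.
config = {
    "batch_size": 256,
    "learning_rate": Distribution(dist="loguniform", min=1e-5, max=1e-2),
}

# Assumed behavior: each yielded config replaces the Distribution with a
# concrete draw from loguniform(min, max).
for cfg in sample_from(config, n=8):
    print(cfg["learning_rate"])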