Module benchmark
Entrypoint for running all tasks in biobench.
Most of this script is self-documenting; run python benchmark.py --help to see all the options.
Note that you will have to download all the datasets, but each dataset includes its own download script with instructions; see biobench.newt.download for an example.
Examples
Run Everything
Suppose you want to get started by running all the tasks for all the default models on a local GPU (device 4, for example).
You need to pass each --TASK-run flag and each --TASK-args.datadir option so that every task knows where to load its data from.
CUDA_VISIBLE_DEVICES=4 python benchmark.py \
--kabr-run --kabr-args.datadir /local/scratch/stevens.994/datasets/kabr \
--iwildcam-run --iwildcam-args.datadir /local/scratch/stevens.994/datasets/iwildcam \
--plantnet-run --plantnet-args.datadir /local/scratch/stevens.994/datasets/plantnet \
--birds525-run --birds525-args.datadir /local/scratch/stevens.994/datasets/birds525 \
--newt-run --newt-args.datadir /local/scratch/stevens.994/datasets/newt \
--beluga-run --beluga-args.datadir /local/scratch/stevens.994/datasets/beluga \
--ages-run --ages-args.datadir /local/scratch/stevens.994/datasets/newt \
--fishnet-run --fishnet-args.datadir /local/scratch/stevens.994/datasets/fishnet
More generally, you can configure options for individual tasks using --TASK-args.<OPTION>; these are all documented in the output of python benchmark.py --help.
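Each of these flags maps onto a field of the Args dataclass documented under Classes below. As a rough programmatic equivalent (the import paths and the datadir field name are assumptions inferred from the flags above, so treat this as a sketch, not canonical usage):

    # Hypothetical equivalent of: --newt-run --newt-args.datadir /path/to/newt
    # (import paths are assumptions; benchmark.py may not be importable this way)
    from biobench import newt
    from benchmark import Args

    args = Args(newt_run=True, newt_args=newt.Args(datadir="/path/to/newt"))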
Just One Task
Suppose you just want to run one task (NeWT).
CUDA_VISIBLE_DEVICES=4 python benchmark.py \
--newt-run --newt-args.datadir /local/scratch/stevens.994/datasets/newt
Just One Model
Suppose you only want to run the SigLIP SO400M ViT from Open CLIP, but you want to run it on all tasks.
Since that model is a checkpoint in Open CLIP, we can use the OpenClip class to load the checkpoint.
CUDA_VISIBLE_DEVICES=4 python benchmark.py \
--kabr-run --kabr-args.datadir /local/scratch/stevens.994/datasets/kabr \
--iwildcam-run --iwildcam-args.datadir /local/scratch/stevens.994/datasets/iwildcam \
--plantnet-run --plantnet-args.datadir /local/scratch/stevens.994/datasets/plantnet \
--birds525-run --birds525-args.datadir /local/scratch/stevens.994/datasets/birds525 \
--newt-run --newt-args.datadir /local/scratch/stevens.994/datasets/newt \
--beluga-run --beluga-args.datadir /local/scratch/stevens.994/datasets/beluga \
--ages-run --ages-args.datadir /local/scratch/stevens.994/datasets/newt \
--fishnet-run --fishnet-args.datadir /local/scratch/stevens.994/datasets/fishnet \
--model open-clip ViT-SO400M-14-SigLIP/webli # <- This is the new line!
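On the Python side, each --model flag parses into a ModelArgs pair of (model org, checkpoint), the same type used for the defaults in Args below. A minimal sketch, assuming the interfaces import shown in the class source:

    # "open-clip ViT-SO400M-14-SigLIP/webli" on the CLI becomes this pair,
    # overriding the default model list in Args.
    from biobench import interfaces

    model_args = [interfaces.ModelArgs("open-clip", "ViT-SO400M-14-SigLIP/webli")]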
Use Slurm
Slurm clusters with lots of GPUs can run many tasks in parallel, and it's really easy with biobench.
python benchmark.py \
--kabr-run --kabr-args.datadir /local/scratch/stevens.994/datasets/kabr \
--iwildcam-run --iwildcam-args.datadir /local/scratch/stevens.994/datasets/iwildcam \
--plantnet-run --plantnet-args.datadir /local/scratch/stevens.994/datasets/plantnet \
--birds525-run --birds525-args.datadir /local/scratch/stevens.994/datasets/birds525 \
--newt-run --newt-args.datadir /local/scratch/stevens.994/datasets/newt \
--beluga-run --beluga-args.datadir /local/scratch/stevens.994/datasets/beluga \
--ages-run --ages-args.datadir /local/scratch/stevens.994/datasets/newt \
--fishnet-run --fishnet-args.datadir /local/scratch/stevens.994/datasets/fishnet \
--slurm # <- Just add --slurm to use slurm!
Note that you no longer need to specify CUDA_VISIBLE_DEVICES because the jobs don't run on the local machine.
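Under the hood, --slurm hands the jobs to submitit; main() owns the exact wiring, but the general pattern looks roughly like this (a sketch using the standard submitit API; run_one and selected_tasks are hypothetical names, and only the account and log folder values come from the Args defaults):

    import submitit

    # AutoExecutor submits to Slurm when available and can fall back to local execution.
    executor = submitit.AutoExecutor(folder="./logs")
    executor.update_parameters(slurm_account="PAS2136", timeout_min=60, gpus_per_node=1)

    # One job per selected task; job.result() blocks until the report is ready.
    jobs = [executor.submit(run_one, model_args, task_args) for task_args in selected_tasks]
    reports = [job.result() for job in jobs]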
Design
biobench is designed to make it easy to add new models that work with the existing tasks, and new tasks that work with the existing models.
To add a new model, see the documentation for biobench.registry, which includes a tutorial for adding a new model.
Functions
def export_to_csv(args: Args) ‑> set[str]
-
Exports (and writes) results to a wide table format for viewing (long table formats are better for further manipulation and graphing, but wide tables are easier to read).
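To make the long vs. wide distinction concrete, here is a small illustration with made-up data (not biobench's actual schema):

    # Long format: one row per (model, task, score) -- easy to filter and graph.
    long_rows = [
        ("bioclip", "newt", 0.81),
        ("bioclip", "kabr", 0.62),
        ("siglip", "newt", 0.79),
    ]

    # Wide format: one row per model, one column per task -- easy to eyeball in a CSV viewer.
    tasks = sorted({task for _, task, _ in long_rows})
    wide: dict[str, dict[str, float]] = {}
    for model, task, score in long_rows:
        wide.setdefault(model, {})[task] = score
    for model, scores in wide.items():
        print(model, *(scores.get(task, "") for task in tasks))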
def main(args: Args)
-
Launch all jobs, using either a local GPU or a Slurm cluster. Then report results and save to disk.
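The Args source below shows that the module uses tyro for argument parsing, so the entrypoint presumably follows the usual tyro pattern (a sketch, not the verbatim source):

    import tyro

    if __name__ == "__main__":
        # Parse the CLI flags into an Args instance, then launch all selected jobs.
        main(tyro.cli(Args))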
def plot_task(conn: sqlite3.Connection, task: str)
-
Plots the most recent result for each model on the given task, including confidence intervals. Returns the figure so the caller can save or display it.
Args
conn
- connection to the reports database.
task
- which task to plot.
Returns
matplotlib.pyplot.Figure
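A short usage sketch, assuming the default Args so the connection points at ./reports/reports.sqlite (the task name string is an assumption):

    # Plot the most recent results for every model on NeWT and save the figure.
    args = Args()
    conn = args.get_sqlite_connection()
    fig = plot_task(conn, "newt")
    fig.savefig("./graphs/newt.png")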
def save(args: Args, model_args: ModelArgs, report: TaskReport) ‑> None
-
Saves the report to disk in a machine-readable SQLite format.
Args
args
- launch script arguments.
model_args
- a pair of model_org, model_ckpt strings.
report
- the task report produced by the model specified by model_args.
Classes
class Args (slurm: bool = False, slurm_acct: str = 'PAS2136', model_args: typing.Annotated[list[ModelArgs], _ArgConfiguration(name='model', metavar=None, help=None, help_behavior_hint=None, aliases=None, prefix_name=None, constructor_factory=None)] = <factory>, device: Literal['cpu', 'cuda'] = 'cuda', debug: bool = False, ssl: bool = True, ages_run: bool = False, ages_args: Args = <factory>, beluga_run: bool = False, beluga_args: Args = <factory>, birds525_run: bool = False, birds525_args: Args = <factory>, fishnet_run: bool = False, fishnet_args: Args = <factory>, imagenet_run: bool = False, imagenet_args: Args = <factory>, inat21_run: bool = False, inat21_args: Args = <factory>, iwildcam_run: bool = False, iwildcam_args: Args = <factory>, kabr_run: bool = False, kabr_args: Args = <factory>, leopard_run: bool = False, leopard_args: Args = <factory>, newt_run: bool = False, newt_args: Args = <factory>, plankton_run: bool = False, plankton_args: Args = <factory>, plantnet_run: bool = False, plantnet_args: Args = <factory>, rarespecies_run: bool = False, rarespecies_args: Args = <factory>, report_to: str = './reports', graph: bool = True, graph_to: str = './graphs', log_to: str = './logs')
-
Params to run one or more benchmarks in a parallel setting.
Source code

@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Args:
    """Params to run one or more benchmarks in a parallel setting."""

    slurm: bool = False
    """whether to use submitit to run jobs on a slurm cluster."""
    slurm_acct: str = "PAS2136"
    """slurm account string."""
    model_args: typing.Annotated[
        list[interfaces.ModelArgs], tyro.conf.arg(name="model")
    ] = dataclasses.field(
        default_factory=lambda: [
            interfaces.ModelArgs("open-clip", "RN50/openai"),
            interfaces.ModelArgs("open-clip", "ViT-B-16/openai"),
            interfaces.ModelArgs("open-clip", "ViT-B-16/laion400m_e32"),
            interfaces.ModelArgs("open-clip", "hf-hub:imageomics/bioclip"),
            interfaces.ModelArgs("open-clip", "ViT-B-16-SigLIP/webli"),
            interfaces.ModelArgs("timm-vit", "vit_base_patch14_reg4_dinov2.lvd142m"),
        ]
    )
    """model; a pair of model org (interface) and checkpoint."""
    device: typing.Literal["cpu", "cuda"] = "cuda"
    """which kind of accelerator to use."""
    debug: bool = False
    """whether to run in debug mode."""
    ssl: bool = True
    """Use SSL when connecting to remote servers to download checkpoints; use --no-ssl if your machine has certificate issues. See `biobench.third_party_models.get_ssl()` for a discussion of how this works."""

    # Individual benchmarks.
    ages_run: bool = False
    """Whether to run the bird age benchmark."""
    ages_args: ages.Args = dataclasses.field(default_factory=ages.Args)
    """Arguments for the bird age benchmark."""
    beluga_run: bool = False
    """Whether to run the Beluga whale re-ID benchmark."""
    beluga_args: beluga.Args = dataclasses.field(default_factory=beluga.Args)
    """Arguments for the Beluga whale re-ID benchmark."""
    birds525_run: bool = False
    """whether to run the Birds 525 benchmark."""
    birds525_args: birds525.Args = dataclasses.field(default_factory=birds525.Args)
    """arguments for the Birds 525 benchmark."""
    fishnet_run: bool = False
    """Whether to run the FishNet benchmark."""
    fishnet_args: fishnet.Args = dataclasses.field(default_factory=fishnet.Args)
    """Arguments for the FishNet benchmark."""
    imagenet_run: bool = False
    """Whether to run the ImageNet-1K benchmark."""
    imagenet_args: imagenet.Args = dataclasses.field(default_factory=imagenet.Args)
    """Arguments for the ImageNet-1K benchmark."""
    inat21_run: bool = False
    """Whether to run the iNat21 benchmark."""
    inat21_args: inat21.Args = dataclasses.field(default_factory=inat21.Args)
    """Arguments for the iNat21 benchmark."""
    iwildcam_run: bool = False
    """whether to run the iWildCam benchmark."""
    iwildcam_args: iwildcam.Args = dataclasses.field(default_factory=iwildcam.Args)
    """arguments for the iWildCam benchmark."""
    kabr_run: bool = False
    """whether to run the KABR benchmark."""
    kabr_args: kabr.Args = dataclasses.field(default_factory=kabr.Args)
    """arguments for the KABR benchmark."""
    leopard_run: bool = False
    """Whether to run the leopard re-ID benchmark."""
    leopard_args: leopard.Args = dataclasses.field(default_factory=leopard.Args)
    """Arguments for the leopard re-ID benchmark."""
    newt_run: bool = False
    """whether to run the NeWT benchmark."""
    newt_args: newt.Args = dataclasses.field(default_factory=newt.Args)
    """arguments for the NeWT benchmark."""
    plankton_run: bool = False
    """Whether to run the Plankton benchmark."""
    plankton_args: plankton.Args = dataclasses.field(default_factory=plankton.Args)
    """Arguments for the Plankton benchmark."""
    plantnet_run: bool = False
    """whether to run the Pl@ntNet benchmark."""
    plantnet_args: plantnet.Args = dataclasses.field(default_factory=plantnet.Args)
    """arguments for the Pl@ntNet benchmark."""
    rarespecies_run: bool = False
    rarespecies_args: rarespecies.Args = dataclasses.field(
        default_factory=rarespecies.Args
    )
    """Arguments for the Rare Species benchmark."""

    # Reporting and graphing.
    report_to: str = os.path.join(".", "reports")
    """where to save reports to."""
    graph: bool = True
    """whether to make graphs."""
    graph_to: str = os.path.join(".", "graphs")
    """where to save graphs to."""
    log_to: str = os.path.join(".", "logs")
    """where to save logs to."""

    def to_dict(self) -> dict[str, object]:
        return dataclasses.asdict(self)

    def get_sqlite_connection(self) -> sqlite3.Connection:
        """Get a connection to the reports database.

        Returns:
            a connection to a sqlite3 database.
        """
        return sqlite3.connect(os.path.join(self.report_to, "reports.sqlite"))
Class variables
var ages_args : Args
-
Arguments for the bird age benchmark.
var ages_run : bool
-
Whether to run the bird age benchmark.
var beluga_args : Args
-
Arguments for the Beluga whale re-ID benchmark.
var beluga_run : bool
-
Whether to run the Beluga whale re-ID benchmark.
var birds525_args : Args
-
arguments for the Birds 525 benchmark.
var birds525_run : bool
-
whether to run the Birds 525 benchmark.
var debug : bool
-
whether to run in debug mode.
var device : Literal['cpu', 'cuda']
-
which kind of accelerator to use.
var fishnet_args : Args
-
Arguments for the FishNet benchmark.
var fishnet_run : bool
-
Whether to run the FishNet benchmark.
var graph : bool
-
whether to make graphs.
var graph_to : str
-
where to save graphs to.
var imagenet_args : Args
-
Arguments for the ImageNet-1K benchmark.
var imagenet_run : bool
-
Whether to run the ImageNet-1K benchmark.
var inat21_args : Args
-
Arguments for the iNat21 benchmark.
var inat21_run : bool
-
Whether to run the iNat21 benchmark.
var iwildcam_args : Args
-
arguments for the iWildCam benchmark.
var iwildcam_run : bool
-
whether to run the iWildCam benchmark.
var kabr_args : Args
-
arguments for the KABR benchmark.
var kabr_run : bool
-
whether to run the KABR benchmark.
var leopard_args : Args
-
Arguments for the leopard re-ID benchmark.
var leopard_run : bool
-
Whether to run the leopard re-ID benchmark.
var log_to : str
-
where to save logs to.
var model_args : list[ModelArgs]
-
model; a pair of model org (interface) and checkpoint.
var newt_args : Args
-
arguments for the NeWT benchmark.
var newt_run : bool
-
whether to run the NeWT benchmark.
var plankton_args : Args
-
Arguments for the Plankton benchmark.
var plankton_run : bool
-
Whether to run the Plankton benchmark.
var plantnet_args : Args
-
arguments for the Pl@ntNet benchmark.
var plantnet_run : bool
-
whether to run the Pl@ntNet benchmark.
var rarespecies_args : Args
-
Arguments for the Rare Species benchmark.
var rarespecies_run : bool
var report_to : str
-
where to save reports to.
var slurm : bool
-
whether to use submitit to run jobs on a slurm cluster.
var slurm_acct : str
-
slurm account string.
var ssl : bool
-
Use SSL when connecting to remote servers to download checkpoints; use --no-ssl if your machine has certificate issues. See biobench.third_party_models.get_ssl() for a discussion of how this works.
Methods
def get_sqlite_connection(self) ‑> sqlite3.Connection
-
Get a connection to the reports database.
Returns
a connection to a sqlite3 database.
def to_dict(self) ‑> dict[str, object]
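-
Returns the launch parameters as a plain dict; the source above shows it simply calls dataclasses.asdict(self). A short usage sketch:

    # Serialize the launch configuration, e.g. to log it next to a report.
    args = Args(newt_run=True)
    config = args.to_dict()
    print(config["newt_run"], config["report_to"])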