Module benchmark

Entrypoint for running all tasks in biobench.

Most of this script is self-documenting. Run python benchmark.py --help to see all the options.

Note that you have to download all the datasets yourself, but each dataset includes its own download script with instructions. For example, see biobench.newt.download.

Examples

Run Everything

Suppose you want to get started by running all the tasks for all the default models on a local GPU (device 4, for example). You need to pass a --TASK-run flag for each task, along with a matching --TASK-args.datadir so that each task knows where to load its data from.

CUDA_VISIBLE_DEVICES=4 python benchmark.py \
  --kabr-run --kabr-args.datadir /local/scratch/stevens.994/datasets/kabr \
  --iwildcam-run --iwildcam-args.datadir /local/scratch/stevens.994/datasets/iwildcam \
  --plantnet-run --plantnet-args.datadir /local/scratch/stevens.994/datasets/plantnet \
  --birds525-run --birds525-args.datadir /local/scratch/stevens.994/datasets/birds525 \
  --newt-run --newt-args.datadir /local/scratch/stevens.994/datasets/newt \
  --beluga-run --beluga-args.datadir /local/scratch/stevens.994/datasets/beluga \
  --ages-run --ages-args.datadir /local/scratch/stevens.994/datasets/newt \
  --fishnet-run --fishnet-args.datadir /local/scratch/stevens.994/datasets/fishnet
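Because every task follows the same --TASK-run / --TASK-args.datadir pattern, a long command like this can also be generated. The helper below is a hypothetical convenience, not part of biobench; DATA_ROOT and the build_command name are assumptions, and the task-to-directory mapping simply mirrors the command above (including ages reading from the newt directory).

```python
import shlex

# Hypothetical dataset root: adjust to wherever you downloaded the data.
DATA_ROOT = "/local/scratch/stevens.994/datasets"

# Task name -> dataset subdirectory, mirroring the command above
# (ages points at the newt directory there, so it does here too).
TASK_DIRS = {
    "kabr": "kabr",
    "iwildcam": "iwildcam",
    "plantnet": "plantnet",
    "birds525": "birds525",
    "newt": "newt",
    "beluga": "beluga",
    "ages": "newt",
    "fishnet": "fishnet",
}


def build_command(root: str, task_dirs: dict[str, str]) -> list[str]:
    """Return argv for benchmark.py with one --TASK-run / --TASK-args.datadir pair per task."""
    argv = ["python", "benchmark.py"]
    for task, subdir in task_dirs.items():
        argv += [f"--{task}-run", f"--{task}-args.datadir", f"{root}/{subdir}"]
    return argv


if __name__ == "__main__":
    # Print a single shell-quoted command line; prefix it with CUDA_VISIBLE_DEVICES=4 as above.
    print(shlex.join(build_command(DATA_ROOT, TASK_DIRS)))
```

The returned list can also be passed straight to subprocess.run instead of being printed.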

More generally, you can configure options for individual tasks using --TASK-args.<OPTION>; all of these are documented in python benchmark.py --help.
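These flag names come from the nested dataclasses in Args (for example newt_args: newt.Args), which tyro turns into CLI options. The sketch below imitates that naming convention with the standard library only; it is not tyro's code, it assumes tyro's default of turning underscores into hyphens and joining nested fields with a dot, and the NewtArgs/TopArgs classes and the batch_size field are hypothetical stand-ins.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class NewtArgs:
    # Hypothetical stand-in for a task's Args; batch_size is an invented example field.
    datadir: str = ""
    batch_size: int = 256


@dataclasses.dataclass(frozen=True)
class TopArgs:
    # Hypothetical stand-in for benchmark.Args with a single task.
    newt_run: bool = False
    newt_args: NewtArgs = dataclasses.field(default_factory=NewtArgs)


def flag_names(cls: type, prefix: str = "") -> list[str]:
    """Flatten a (possibly nested) dataclass into --dotted-hyphenated flag names."""
    names: list[str] = []
    for field in dataclasses.fields(cls):
        name = prefix + field.name.replace("_", "-")
        if dataclasses.is_dataclass(field.type):
            # Nested dataclass: recurse, using "<field>." as the prefix.
            names += flag_names(field.type, name + ".")
        else:
            names.append("--" + name)
    return names


print(flag_names(TopArgs))
```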

Just One Task

Suppose you just want to run one task (NeWT).

CUDA_VISIBLE_DEVICES=4 python benchmark.py \
  --newt-run --newt-args.datadir /local/scratch/stevens.994/datasets/newt

Just One Model

Suppose you only want to run the SigLIP SO400M ViT from Open CLIP, but on all tasks. Since that model is an Open CLIP checkpoint, we can load it with the OpenClip class by passing --model open-clip followed by the checkpoint name.

CUDA_VISIBLE_DEVICES=4 python benchmark.py \
  --kabr-run --kabr-args.datadir /local/scratch/stevens.994/datasets/kabr \
  --iwildcam-run --iwildcam-args.datadir /local/scratch/stevens.994/datasets/iwildcam \
  --plantnet-run --plantnet-args.datadir /local/scratch/stevens.994/datasets/plantnet \
  --birds525-run --birds525-args.datadir /local/scratch/stevens.994/datasets/birds525 \
  --newt-run --newt-args.datadir /local/scratch/stevens.994/datasets/newt \
  --beluga-run --beluga-args.datadir /local/scratch/stevens.994/datasets/beluga \
  --ages-run --ages-args.datadir /local/scratch/stevens.994/datasets/newt \
  --fishnet-run --fishnet-args.datadir /local/scratch/stevens.994/datasets/fishnet \
  --model open-clip ViT-SO400M-14-SigLIP/webli  # <- This is the new line!
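Passing --model replaces the default list of model_args, which is why only the SigLIP model runs. If you build the arguments in Python instead, note that Args is a frozen dataclass, so you cannot assign to its fields; dataclasses.replace returns a modified copy. The ModelArgs and Args classes below are trimmed, hypothetical stand-ins for interfaces.ModelArgs and benchmark.Args.

```python
import dataclasses
import typing


class ModelArgs(typing.NamedTuple):
    # Hypothetical stand-in for interfaces.ModelArgs: (model org, checkpoint).
    org: str
    ckpt: str


@dataclasses.dataclass(frozen=True)
class Args:
    # Trimmed, hypothetical stand-in for benchmark.Args.
    model_args: list[ModelArgs] = dataclasses.field(
        default_factory=lambda: [
            ModelArgs("open-clip", "RN50/openai"),
            ModelArgs("open-clip", "ViT-B-16/openai"),
        ]
    )
    newt_run: bool = False


defaults = Args(newt_run=True)

# Frozen dataclasses reject attribute assignment, so build a modified copy instead.
one_model = dataclasses.replace(
    defaults, model_args=[ModelArgs("open-clip", "ViT-SO400M-14-SigLIP/webli")]
)
print(one_model)
```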

Use Slurm

Slurm clusters with many GPUs can run many tasks in parallel, and biobench supports this directly.

python benchmark.py \
  --kabr-run --kabr-args.datadir /local/scratch/stevens.994/datasets/kabr \
  --iwildcam-run --iwildcam-args.datadir /local/scratch/stevens.994/datasets/iwildcam \
  --plantnet-run --plantnet-args.datadir /local/scratch/stevens.994/datasets/plantnet \
  --birds525-run --birds525-args.datadir /local/scratch/stevens.994/datasets/birds525 \
  --newt-run --newt-args.datadir /local/scratch/stevens.994/datasets/newt \
  --beluga-run --beluga-args.datadir /local/scratch/stevens.994/datasets/beluga \
  --ages-run --ages-args.datadir /local/scratch/stevens.994/datasets/newt \
  --fishnet-run --fishnet-args.datadir /local/scratch/stevens.994/datasets/fishnet \
  --slurm  # <- Just add --slurm to use slurm!

Note that you no longer need to specify CUDA_VISIBLE_DEVICES, because the jobs run on the cluster rather than on the local machine.
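Slurm helps because the work splits into independent pieces. Assuming one job per (model, task) pair (an assumption; this page only says that main launches all jobs), the six default models from Args.model_args and the eight tasks in the command above already amount to 48 jobs. The plan_jobs helper is hypothetical.

```python
import itertools

# The default models from Args.model_args, as (org, checkpoint) pairs.
DEFAULT_MODELS = [
    ("open-clip", "RN50/openai"),
    ("open-clip", "ViT-B-16/openai"),
    ("open-clip", "ViT-B-16/laion400m_e32"),
    ("open-clip", "hf-hub:imageomics/bioclip"),
    ("open-clip", "ViT-B-16-SigLIP/webli"),
    ("timm-vit", "vit_base_patch14_reg4_dinov2.lvd142m"),
]

# The --TASK-run flags set in the Slurm command above.
RUN_FLAGS = {
    task: True
    for task in ["kabr", "iwildcam", "plantnet", "birds525", "newt", "beluga", "ages", "fishnet"]
}


def plan_jobs(
    models: list[tuple[str, str]], run_flags: dict[str, bool]
) -> list[tuple[tuple[str, str], str]]:
    """Pair every model with every enabled task (one job per pair, by assumption)."""
    tasks = [task for task, run in run_flags.items() if run]
    return list(itertools.product(models, tasks))


jobs = plan_jobs(DEFAULT_MODELS, RUN_FLAGS)
print(len(jobs))
```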

Design

biobench is designed to make it easy to add new models and new tasks, so that any model can be evaluated on any task.

To add a new model, look at biobench.registry's documentation, which includes a tutorial for adding a new model.

Functions

def export_to_csv(args: Args) -> set[str]

Exports the results to a wide table format and writes it to disk. Wide tables are easy to read; long table formats are better for further manipulation or graphing.
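To make the long-versus-wide distinction concrete: a long table has one row per (model, task, score) record, while a wide table has one row per model and one column per task. The sketch below pivots long rows into a wide CSV with the standard library. The column names and the dummy scores are hypothetical, and this is not how export_to_csv itself is implemented.

```python
import csv
import io
import typing

# Dummy long-format rows (placeholder numbers, not real results).
LONG_ROWS = [
    {"model": "model-a", "task": "kabr", "score": 0.1},
    {"model": "model-a", "task": "newt", "score": 0.2},
    {"model": "model-b", "task": "kabr", "score": 0.3},
]


def write_wide_csv(rows: list[dict], out: typing.TextIO) -> None:
    """Pivot (model, task, score) rows into one row per model, one column per task."""
    tasks = sorted({row["task"] for row in rows})
    table: dict[str, dict[str, float]] = {}
    for row in rows:
        # If a (model, task) pair appears twice, the later row wins.
        table.setdefault(row["model"], {})[row["task"]] = row["score"]
    writer = csv.writer(out)
    writer.writerow(["model", *tasks])
    for model in sorted(table):
        # Missing (model, task) combinations become empty cells.
        writer.writerow([model, *(table[model].get(task, "") for task in tasks)])


buf = io.StringIO()
write_wide_csv(LONG_ROWS, buf)
print(buf.getvalue())
```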

def main(args: Args)

Launch all jobs, using either a local GPU or a Slurm cluster. Then report results and save to disk.

def plot_task(conn: sqlite3.Connection, task: str)

Plots the most recent result for each model on the given task, including confidence intervals. Returns the figure so the caller can save or display it.

Args

conn
connection to database.
task
which task to plot.

Returns

matplotlib.pyplot.Figure
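This page does not say how the confidence intervals are computed. One common option, shown here purely as an illustration and not necessarily what biobench does, is a percentile bootstrap over per-example scores. The function name and its defaults (1000 resamples, a 95% interval, a fixed seed) are assumptions.

```python
import random
import statistics


def bootstrap_ci(
    scores: list[float], n_resamples: int = 1000, alpha: float = 0.05, seed: int = 0
) -> tuple[float, float, float]:
    """Return (mean, low, high) using a percentile bootstrap of the mean."""
    if not scores:
        raise ValueError("need at least one score")
    rng = random.Random(seed)
    n = len(scores)
    # Resample with replacement and record each resample's mean.
    means = sorted(
        statistics.fmean(rng.choices(scores, k=n)) for _ in range(n_resamples)
    )
    low = means[int((alpha / 2) * n_resamples)]
    high = means[int((1 - alpha / 2) * n_resamples) - 1]
    return statistics.fmean(scores), low, high


print(bootstrap_ci([1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0]))
```

The resulting (mean, low, high) triples map naturally onto matplotlib's Axes.errorbar, which accepts asymmetric error bars as yerr=[[mean - low], [high - mean]].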

def save(args: Args, model_args: ModelArgs, report: TaskReport) -> None

Saves the report to disk in a machine-readable SQLite format.

Args

args
launch script arguments.
model_args
a pair of model_org, model_ckpt strings.
report
the task report produced for model_args.
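The actual layout of reports.sqlite is not documented on this page. As a hypothetical sketch of the idea, the code below stores one timestamped row per report and then selects the most recent row per model for a task, which is the kind of query plot_task describes. The table name, columns, helper names, and dummy scores are all assumptions.

```python
import json
import sqlite3
import time

# Hypothetical schema; the real reports.sqlite layout is not documented here.
SCHEMA = """
CREATE TABLE IF NOT EXISTS results (
    model_org TEXT NOT NULL,
    model_ckpt TEXT NOT NULL,
    task TEXT NOT NULL,
    mean_score REAL NOT NULL,
    args_json TEXT NOT NULL,
    posix_time REAL NOT NULL
)
"""


def save_result(conn, model_org, model_ckpt, task, mean_score, args, posix_time=None):
    """Append one report row, stamped with the current time unless one is given."""
    conn.execute(SCHEMA)
    stamp = time.time() if posix_time is None else posix_time
    conn.execute(
        "INSERT INTO results VALUES (?, ?, ?, ?, ?, ?)",
        (model_org, model_ckpt, task, mean_score, json.dumps(args), stamp),
    )
    conn.commit()


def latest_per_model(conn, task):
    """Return (org, ckpt, score) for each model's most recent row on one task."""
    return conn.execute(
        """
        SELECT r.model_org, r.model_ckpt, r.mean_score
        FROM results AS r
        WHERE r.task = ?
          AND r.posix_time = (
              SELECT MAX(posix_time) FROM results
              WHERE task = r.task
                AND model_org = r.model_org
                AND model_ckpt = r.model_ckpt
          )
        ORDER BY r.model_ckpt
        """,
        (task,),
    ).fetchall()
```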

Classes

class Args (slurm: bool = False, slurm_acct: str = 'PAS2136', model_args: typing.Annotated[list[ModelArgs], _ArgConfiguration(name='model', metavar=None, help=None, help_behavior_hint=None, aliases=None, prefix_name=None, constructor_factory=None)] = <factory>, device: Literal['cpu', 'cuda'] = 'cuda', debug: bool = False, ssl: bool = True, ages_run: bool = False, ages_args: Args = <factory>, beluga_run: bool = False, beluga_args: Args = <factory>, birds525_run: bool = False, birds525_args: Args = <factory>, fishnet_run: bool = False, fishnet_args: Args = <factory>, imagenet_run: bool = False, imagenet_args: Args = <factory>, inat21_run: bool = False, inat21_args: Args = <factory>, iwildcam_run: bool = False, iwildcam_args: Args = <factory>, kabr_run: bool = False, kabr_args: Args = <factory>, leopard_run: bool = False, leopard_args: Args = <factory>, newt_run: bool = False, newt_args: Args = <factory>, plankton_run: bool = False, plankton_args: Args = <factory>, plantnet_run: bool = False, plantnet_args: Args = <factory>, rarespecies_run: bool = False, rarespecies_args: Args = <factory>, report_to: str = './reports', graph: bool = True, graph_to: str = './graphs', log_to: str = './logs')

Params to run one or more benchmarks in a parallel setting.

@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Args:
    """Params to run one or more benchmarks in a parallel setting."""

    slurm: bool = False
    """whether to use submitit to run jobs on a slurm cluster."""
    slurm_acct: str = "PAS2136"
    """slurm account string."""

    model_args: typing.Annotated[
        list[interfaces.ModelArgs], tyro.conf.arg(name="model")
    ] = dataclasses.field(
        default_factory=lambda: [
            interfaces.ModelArgs("open-clip", "RN50/openai"),
            interfaces.ModelArgs("open-clip", "ViT-B-16/openai"),
            interfaces.ModelArgs("open-clip", "ViT-B-16/laion400m_e32"),
            interfaces.ModelArgs("open-clip", "hf-hub:imageomics/bioclip"),
            interfaces.ModelArgs("open-clip", "ViT-B-16-SigLIP/webli"),
            interfaces.ModelArgs("timm-vit", "vit_base_patch14_reg4_dinov2.lvd142m"),
        ]
    )
    """model; a pair of model org (interface) and checkpoint."""
    device: typing.Literal["cpu", "cuda"] = "cuda"
    """which kind of accelerator to use."""
    debug: bool = False
    """whether to run in debug mode."""
    ssl: bool = True
    """Use SSL when connecting to remote servers to download checkpoints; use --no-ssl if your machine has certificate issues. See `biobench.third_party_models.get_ssl()` for a discussion of how this works."""

    # Individual benchmarks.
    ages_run: bool = False
    """Whether to run the bird age benchmark."""
    ages_args: ages.Args = dataclasses.field(default_factory=ages.Args)
    """Arguments for the bird age benchmark."""
    beluga_run: bool = False
    """Whether to run the Beluga whale re-ID benchmark."""
    beluga_args: beluga.Args = dataclasses.field(default_factory=beluga.Args)
    """Arguments for the Beluga whale re-ID benchmark."""
    birds525_run: bool = False
    """whether to run the Birds 525 benchmark."""
    birds525_args: birds525.Args = dataclasses.field(default_factory=birds525.Args)
    """arguments for the Birds 525 benchmark."""
    fishnet_run: bool = False
    """Whether to run the FishNet benchmark."""
    fishnet_args: fishnet.Args = dataclasses.field(default_factory=fishnet.Args)
    """Arguments for the FishNet benchmark."""
    imagenet_run: bool = False
    """Whether to run the ImageNet-1K benchmark."""
    imagenet_args: imagenet.Args = dataclasses.field(default_factory=imagenet.Args)
    """Arguments for the ImageNet-1K benchmark."""
    inat21_run: bool = False
    """Whether to run the iNat21 benchmark."""
    inat21_args: inat21.Args = dataclasses.field(default_factory=inat21.Args)
    """Arguments for the iNat21 benchmark."""
    iwildcam_run: bool = False
    """whether to run the iWildCam benchmark."""
    iwildcam_args: iwildcam.Args = dataclasses.field(default_factory=iwildcam.Args)
    """arguments for the iWildCam benchmark."""
    kabr_run: bool = False
    """whether to run the KABR benchmark."""
    kabr_args: kabr.Args = dataclasses.field(default_factory=kabr.Args)
    """arguments for the KABR benchmark."""
    leopard_run: bool = False
    """Whether to run the leopard re-ID benchmark."""
    leopard_args: leopard.Args = dataclasses.field(default_factory=leopard.Args)
    """Arguments for the leopard re-ID benchmark."""
    newt_run: bool = False
    """whether to run the NeWT benchmark."""
    newt_args: newt.Args = dataclasses.field(default_factory=newt.Args)
    """arguments for the NeWT benchmark."""
    plankton_run: bool = False
    """Whether to run the Plankton benchmark."""
    plankton_args: plankton.Args = dataclasses.field(default_factory=plankton.Args)
    """Arguments for the Plankton benchmark."""
    plantnet_run: bool = False
    """whether to run the Pl@ntNet benchmark."""
    plantnet_args: plantnet.Args = dataclasses.field(default_factory=plantnet.Args)
    """arguments for the Pl@ntNet benchmark."""
    rarespecies_run: bool = False
    """Whether to run the Rare Species benchmark."""
    rarespecies_args: rarespecies.Args = dataclasses.field(
        default_factory=rarespecies.Args
    )
    """Arguments for the Rare Species benchmark."""

    # Reporting and graphing.
    report_to: str = os.path.join(".", "reports")
    """where to save reports to."""
    graph: bool = True
    """whether to make graphs."""
    graph_to: str = os.path.join(".", "graphs")
    """where to save graphs to."""
    log_to: str = os.path.join(".", "logs")
    """where to save logs to."""

    def to_dict(self) -> dict[str, object]:
        return dataclasses.asdict(self)

    def get_sqlite_connection(self) -> sqlite3.Connection:
        """Get a connection to the reports database.

        Returns:
            a connection to a sqlite3 database.
        """
        return sqlite3.connect(os.path.join(self.report_to, "reports.sqlite"))

Class variables

var ages_args : Args

Arguments for the bird age benchmark.

var ages_run : bool

Whether to run the bird age benchmark.

var beluga_args : Args

Arguments for the Beluga whale re-ID benchmark.

var beluga_run : bool

Whether to run the Beluga whale re-ID benchmark.

var birds525_args : Args

arguments for the Birds 525 benchmark.

var birds525_run : bool

whether to run the Birds 525 benchmark.

var debug : bool

whether to run in debug mode.

var device : Literal['cpu', 'cuda']

which kind of accelerator to use.

var fishnet_args : Args

Arguments for the FishNet benchmark.

var fishnet_run : bool

Whether to run the FishNet benchmark.

var graph : bool

whether to make graphs.

var graph_to : str

where to save graphs to.

var imagenet_args : Args

Arguments for the ImageNet-1K benchmark.

var imagenet_run : bool

Whether to run the ImageNet-1K benchmark.

var inat21_args : Args

Arguments for the iNat21 benchmark.

var inat21_run : bool

Whether to run the iNat21 benchmark.

var iwildcam_args : Args

arguments for the iWildCam benchmark.

var iwildcam_run : bool

whether to run the iWildCam benchmark.

var kabr_args : Args

arguments for the KABR benchmark.

var kabr_run : bool

whether to run the KABR benchmark.

var leopard_args : Args

Arguments for the leopard re-ID benchmark.

var leopard_run : bool

Whether to run the leopard re-ID benchmark.

var log_to : str

where to save logs to.

var model_args : list[ModelArgs]

model; a pair of model org (interface) and checkpoint.

var newt_args : Args

arguments for the NeWT benchmark.

var newt_run : bool

whether to run the NeWT benchmark.

var plankton_args : Args

Arguments for the Plankton benchmark.

var plankton_run : bool

Whether to run the Plankton benchmark.

var plantnet_args : Args

arguments for the Pl@ntNet benchmark.

var plantnet_run : bool

whether to run the Pl@ntNet benchmark.

var rarespecies_args : Args

Arguments for the Rare Species benchmark.

var rarespecies_run : bool

Whether to run the Rare Species benchmark.

var report_to : str

where to save reports to.

var slurm : bool

whether to use submitit to run jobs on a slurm cluster.

var slurm_acct : str

slurm account string.

var ssl : bool

Use SSL when connecting to remote servers to download checkpoints; use --no-ssl if your machine has certificate issues. See biobench.third_party_models.get_ssl() for a discussion of how this works.

Methods

def get_sqlite_connection(self) -> sqlite3.Connection

Get a connection to the reports database.

Returns

a connection to a sqlite3 database.
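get_sqlite_connection opens reports.sqlite inside report_to. sqlite3.connect creates the database file if it is missing, but not its parent directory; connecting inside a directory that does not exist fails with sqlite3.OperationalError ("unable to open database file"). Whether benchmark.py creates report_to before connecting is not stated on this page, so here is a defensive sketch; connect_reports is a hypothetical helper, not part of biobench.

```python
import os
import sqlite3


def connect_reports(report_to: str) -> sqlite3.Connection:
    """Open <report_to>/reports.sqlite, creating report_to first if needed."""
    # sqlite3.connect creates the file but not missing parent directories.
    os.makedirs(report_to, exist_ok=True)
    return sqlite3.connect(os.path.join(report_to, "reports.sqlite"))
```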

def to_dict(self) -> dict[str, object]

Returns the arguments as a nested dictionary (a thin wrapper around dataclasses.asdict).
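Because to_dict delegates to dataclasses.asdict, it recurses into nested dataclasses such as the per-task *_args fields, so the result consists of plain dicts and lists and can be serialized with json as long as every leaf value is JSON-compatible. The classes below are trimmed, hypothetical stand-ins for benchmark.Args and a task's Args.

```python
import dataclasses
import json


@dataclasses.dataclass(frozen=True)
class NewtArgs:
    # Hypothetical stand-in for newt.Args.
    datadir: str = ""


@dataclasses.dataclass(frozen=True)
class Args:
    # Trimmed, hypothetical stand-in for benchmark.Args.
    newt_run: bool = False
    newt_args: NewtArgs = dataclasses.field(default_factory=NewtArgs)
    model_args: list[tuple[str, str]] = dataclasses.field(
        default_factory=lambda: [("open-clip", "RN50/openai")]
    )

    def to_dict(self) -> dict[str, object]:
        # asdict converts nested dataclasses to dicts recursively.
        return dataclasses.asdict(self)


d = Args(newt_run=True, newt_args=NewtArgs(datadir="/data/newt")).to_dict()
print(json.dumps(d, sort_keys=True))
```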