Module biobench.newt

NeWT: Natural World Tasks

NeWT is a collection of 164 binary classification tasks related to visual understanding of the natural world (see the CVPR 2021 paper and its accompanying code release, cited below).

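We evaluate a vision model by extracting visual features for each image, fitting an SVM to the training examples of each task, and evaluating on that task's test data. The SVM is a plain linear SVM when `n_train` is below 10; otherwise it is chosen by a randomized hyperparameter search over kernel, C and gamma (see `init_svc`). We aggregate scores across all 164 tasks.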

If you use this evaluation, be sure to cite the original work:

@inproceedings{van2021benchmarking,
  title={Benchmarking Representation Learning for Natural World Image Collections},
  author={Van Horn, Grant and Cole, Elijah and Beery, Sara and Wilber, Kimberly and Belongie, Serge and Mac Aodha, Oisin},
  booktitle={Computer Vision and Pattern Recognition},
  year={2021}
}

Sub-modules

biobench.newt.download

A script to download the NeWT dataset …

Functions

def benchmark(cfg: Experiment) ‑> Report
Expand source code
@beartype.beartype
def benchmark(cfg: config.Experiment) -> reporting.Report:
    """
    Run the NeWT benchmark and report per-example correctness.

    First, get features for all images (done once, inside `get_all_tasks`).
    Second, for each task, centre and L2-normalize its features and fit an SVM
    on the training split.
    Third, predict the test split and record one `reporting.Prediction` per
    test example, scored 1.0 if the prediction matches the label and 0.0
    otherwise.

    Args:
        cfg: Experiment configuration. `cfg.n_train` selects the classifier
            built by `init_svc`; the whole config is forwarded to
            `get_all_tasks` and stored in the report.

    Returns:
        A `reporting.Report` named "newt" holding the predictions of all tasks.
    """

    # Fit SVMs.
    all_preds = []
    for task in get_all_tasks(cfg):
        (x_train, y_train), (x_test, y_test) = task.splits

        # Centre with statistics of the training split only, so nothing about
        # the test split leaks into the preprocessing.
        x_mean = x_train.mean(axis=0, keepdims=True)
        x_train = l2_normalize(x_train - x_mean)
        x_test = l2_normalize(x_test - x_mean)

        svc = init_svc(cfg.n_train)
        svc.fit(x_train, y_train)
        y_pred = svc.predict(x_test)

        info = {
            "task": task.name,
            "cluster": task.cluster,
            "subcluster": task.subcluster,
        }

        # `task.example_ids` covers train *and* test rows, while `y_pred` and
        # `y_test` only cover the test rows. Apply the same mask that
        # `task.splits` uses so every prediction is paired with its own ID.
        test_ids = task.example_ids[~task.is_train]
        preds = [
            reporting.Prediction(str(example_id), float(pred == true), info)
            for example_id, pred, true in zip(test_ids, y_pred, y_test)
        ]

        all_preds.extend(preds)

    return reporting.Report("newt", all_preds, cfg)

The NeWT benchmark. First, get features for all images. Second, select the subsets of features that correspond to different tasks and train an SVM. Third, evaluate the SVM and report results.

def bootstrap_scores(df: polars.dataframe.frame.DataFrame,
*,
b: int = 0,
rng: numpy.random._generator.Generator | None = None) ‑> dict[str, jaxtyping.Float[ndarray, 'b']]
Expand source code
@jaxtyped(typechecker=beartype.beartype)
def bootstrap_scores(
    df: pl.DataFrame, *, b: int = 0, rng: np.random.Generator | None = None
) -> dict[str, Float[np.ndarray, " b"]]:
    """
    Compute the mean NeWT score per model checkpoint, optionally bootstrapped.

    Args:
        df: Prediction rows. The code reads the columns `task_name` (must be
            exactly "newt"), `model_ckpt`, `img_id` and `score`. Every
            checkpoint must contribute the same number of rows.
        b: Number of bootstrap resamples. With `b == 0` the plain mean is
            returned instead.
        rng: Random generator used to draw resampling indices; required when
            `b > 0`.

    Returns:
        A dict keyed by checkpoint. Each value is an array with `b` resampled
        means when `b > 0`, or a length-1 array with the plain mean otherwise.
        Checkpoints with no scores are left out.
    """
    assert df.get_column("task_name").unique().to_list() == ["newt"]

    # All checkpoints must have been scored on the same number of rows; the
    # shared count `n` fixes the shape of the resampling indices.
    rows_per_ckpt = (
        df.group_by("model_ckpt").agg(n=pl.len()).get_column("n").to_list()
    )
    n, *rest = rows_per_ckpt
    assert all(n == i for i in rest)

    if b > 0:
        assert rng is not None, "must provide rng argument"
        # One set of indices is shared by every checkpoint.
        i_bs = rng.integers(0, n, size=(b, n), dtype=np.int32)

    scores = {}

    # Scratch space reused for every checkpoint's resampled scores.
    scores_buf = np.empty((b, n), dtype=np.float32)

    ckpts = df.get_column("model_ckpt").unique().sort().to_list()
    for ckpt in ckpts:
        # Scores of *one* model, put into a fixed order by image ID.
        per_img = (
            df.filter(pl.col("model_ckpt") == ckpt)
            .select("img_id", "score")
            .unique()
            .sort("img_id")
        )
        values = per_img.get_column("score").cast(pl.Float32).to_numpy()

        if values.size == 0:
            continue

        if b <= 0:
            scores[ckpt] = np.array([values.mean()])
            continue

        # Gather the resampled scores into the shared buffer, then average
        # each resample.
        np.take(values, i_bs, axis=0, out=scores_buf)
        scores[ckpt] = scores_buf.mean(axis=1)

    return scores
def get_all_tasks(cfg: Experiment) ‑> Iterator[Task]
Expand source code
@jaxtyped(typechecker=beartype.beartype)
@torch.no_grad()
def get_all_tasks(cfg: config.Experiment) -> collections.abc.Iterator[Task]:
    """
    Encode every selected NeWT image once and yield one `Task` per task name.

    The labels CSV is read from `cfg.data.newt`, the training rows are
    subsampled with `sample`, all remaining images are pushed through the
    vision backbone named by `cfg.model`, and the L2-normalized features are
    then sliced per task.

    Args:
        cfg: Experiment configuration. The code reads `seed`, `model`,
            `device`, `data.newt`, `n_train`, `n_workers` and `debug`.

    Yields:
        `Task` objects in sorted task-name order, each carrying the features,
        labels, train mask and image IDs of that task's rows.

    Raises:
        RuntimeError: If the labels CSV is not found under `cfg.data.newt`.

    NOTE(review): with `cfg.debug` only two batches are encoded, while the
    per-task slicing below still indexes every row of the sampled frame.
    Confirm that debug runs are expected to cover the data in two batches.
    """
    rng = np.random.default_rng(seed=cfg.seed)

    # Load model
    backbone = registry.load_vision_backbone(cfg.model)
    img_transform = backbone.make_img_transform()
    backbone = torch.compile(backbone.to(cfg.device))

    labels_csv_name = "newt2021_labels.csv"
    labels_csv_path = os.path.join(cfg.data.newt, labels_csv_name)
    imgs_dir_name = "newt2021_images"
    imgs_dir_path = os.path.join(cfg.data.newt, imgs_dir_name)

    if not os.path.isfile(labels_csv_path):
        msg = f"Path '{labels_csv_path}' doesn't exist. Did you download the Newt dataset? See the docstring at the top of this file for instructions. If you did download it, pass the path with '--data'; see --help for more."
        raise RuntimeError(msg)

    # Read the CSV and add row indices
    df = pl.read_csv(labels_csv_path).with_row_index(name="original_index")

    # Sample balanced training data for each task. `sampled_index` is the row
    # position in the sampled frame, which is also the position of the row's
    # features below because the loader does not shuffle.
    df = sample(rng, df, cfg.n_train).with_row_index(name="sampled_index")

    # Get all image IDs and labels
    all_data = df.select("id", "label").to_numpy(structured=True)

    # Create dataset with all samples
    dataset = Dataset(
        imgs_dir_path,
        all_data["id"],
        all_data["label"],
        img_transform,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.n_workers,
        drop_last=False,
        shuffle=False,
        pin_memory=False,
        persistent_workers=False,
    )

    def probe(batch):
        imgs = batch["img"].to(cfg.device, non_blocking=True)
        with torch.amp.autocast(cfg.device):
            _ = backbone.img_encode(imgs).img_features  # forward only

    feature_chunks, seen_ids = [], []

    with helpers.auto_batch_size(dataloader, probe=probe):
        total = len(dataloader) if not cfg.debug else 2
        it = iter(dataloader)
        for _ in helpers.progress(range(total), every=10, desc="newt"):
            batch = next(it)
            imgs = batch["img"].to(cfg.device)

            # Use the same autocast device as `probe` so the real forward pass
            # runs under the conditions that were probed (and also works when
            # `cfg.device` is not CUDA).
            with torch.amp.autocast(cfg.device):
                features = backbone.img_encode(imgs).img_features
                features = torch.nn.functional.normalize(features, dim=-1)

            # Store float32 on the CPU: a no-op when the features already are
            # float32, and it keeps the later `.numpy()` call valid for
            # reduced-precision autocast dtypes.
            feature_chunks.append(features.float().cpu())
            seen_ids.extend(batch["img_id"])

    all_features = torch.cat(feature_chunks, dim=0)
    seen_ids = np.array(seen_ids)

    # `unique()` alone gives no ordering guarantee; sort so tasks are yielded
    # in the same order on every run.
    for task in df.get_column("task").unique().sort():
        task_df = df.filter(pl.col("task") == task)

        task_idx = task_df.get_column("sampled_index").to_numpy()
        features = all_features[task_idx].numpy()
        ids = seen_ids[task_idx]

        labels = task_df.get_column("label").to_numpy()
        is_train = task_df.select(pl.col("split") == "train").get_column("split")

        cluster = task_df.item(row=0, column="task_cluster")
        subcluster = task_df.item(row=0, column="task_subcluster")
        yield Task(
            task, cluster, subcluster, features, labels, is_train.to_numpy(), ids
        )
def init_svc(n_train: int)
Expand source code
def init_svc(n_train: int):
    """Build the (unfitted) SVM classifier used for one NeWT task.

    When `n_train` is at least 10, the result is a randomized hyperparameter
    search (100 candidates, fixed `random_state=42`) over the SVM's kernel, C
    and gamma, applied to a standardize-then-SVM pipeline. It uses only 16
    jobs in parallel to prevent overloading the CPUs on a shared machine.

    When `n_train` is below 10, the result is a pipeline containing just a
    linear SVM, with no scaling and no search.

    NOTE(review): `sample` in this module treats `n_train <= 0` as "keep every
    training row", and such values also take the `n_train < 10` branch here.
    Confirm that a plain linear SVM is intended for the full-data setting.

    Args:
        n_train: The configured number of training samples per task.

    Returns:
        An scikit-learn estimator exposing `fit` and `predict`.
    """
    if n_train >= 10:
        pipeline = sklearn.pipeline.make_pipeline(
            sklearn.preprocessing.StandardScaler(),
            sklearn.svm.SVC(C=1.0, kernel="rbf"),
        )
        search_space = {
            "svc__C": scipy.stats.loguniform(a=1e-3, b=1e1),
            "svc__kernel": ["rbf", "linear", "sigmoid", "poly"],
            "svc__gamma": scipy.stats.loguniform(a=1e-4, b=1e-3),
        }
        return sklearn.model_selection.RandomizedSearchCV(
            pipeline,
            search_space,
            n_iter=100,
            n_jobs=16,
            random_state=42,
        )

    linear_svm = sklearn.svm.SVC(kernel="linear")
    return sklearn.pipeline.make_pipeline(linear_svm)

Create a new, unfitted SVM classifier for one task. When n_train is at least 10, this is a randomized hyperparameter search over kernel, C and gamma; otherwise it is a plain linear SVM. The search uses only 16 jobs in parallel to prevent overloading the CPUs on a shared machine.

def l2_normalize(features: jaxtyping.Float[ndarray, 'batch dim']) ‑> jaxtyping.Float[ndarray, 'batch dim']
Expand source code
@jaxtyped(typechecker=beartype.beartype)
def l2_normalize(
    features: Float[np.ndarray, "batch dim"],
) -> Float[np.ndarray, "batch dim"]:
    """Normalizes a batch of vectors to have L2 unit norm.

    Rows whose norm is exactly zero are returned unchanged (all zeros) instead
    of being divided by zero, which would fill them with NaN.

    Args:
        features: One vector per row.

    Returns:
        An array of the same shape in which every non-zero row has unit L2
        norm.
    """
    norms = np.linalg.norm(features, ord=2, axis=1, keepdims=True)
    # `norms` is a fresh array, so it is safe to patch in place. Dividing a
    # zero row by 1 leaves it at zero; e.g. mean-centering a single-row or
    # constant split produces exactly such rows.
    norms[norms == 0] = 1
    return features / norms

Normalizes a batch of vectors to have L2 unit norm.

def sample(rng: numpy.random._generator.Generator,
df: polars.dataframe.frame.DataFrame,
n_train: int) ‑> polars.dataframe.frame.DataFrame
Expand source code
def _subsample_rows(
    rng: np.random.Generator, rows: pl.DataFrame, n: int
) -> pl.DataFrame:
    """Return `n` rows of `rows` drawn without replacement, or all rows if there are at most `n`."""
    if n >= rows.height:
        return rows

    picked = rng.choice(rows.height, size=n, replace=False)
    # Filtering on a temporary row index keeps the surviving rows in their
    # original relative order.
    return (
        rows.with_row_index(name="tmp")
        .filter(pl.col("tmp").is_in(picked))
        .drop("tmp")
    )


@jaxtyped(typechecker=beartype.beartype)
def sample(rng: np.random.Generator, df: pl.DataFrame, n_train: int) -> pl.DataFrame:
    """Sample a balanced subset of training data points for each task.

    For every task, at most `n_train // 2` training rows with label 0 and at
    most `n_train - n_train // 2` training rows with label 1 are kept. A class
    with fewer rows than its quota is kept in full. Tasks are processed in
    sorted name order so that a given seed always produces the same subset.

    Args:
        rng: Random number generator.
        df: NeWT dataframe. The code reads the columns `task`, `split` and
            `label`.
        n_train: Number of training samples per task to return. Zero or a
            negative value returns `df` unchanged.

    Returns:
        A DataFrame with all test samples first, followed by the balanced
        training samples of each task.

    Raises:
        ValueError: If `n_train` is 1, because a single sample cannot be
            split across two classes.
    """
    if n_train <= 0:
        return df  # Return all data if n_train is not positive

    # Per-class quotas are the same for every task, so compute them once.
    n0 = n_train // 2
    n1 = n_train - n0
    if n0 <= 0 or n1 <= 0:
        raise ValueError(
            f"n_train must be at least 2 to sample both classes, got {n_train}"
        )

    # Keep all test samples
    result_dfs = [df.filter(pl.col("split") != "train")]

    # `unique()` alone gives no ordering guarantee. Sorting fixes which draws
    # of `rng` go to which task, making the sampling reproducible.
    for task in df.get_column("task").unique().sort():
        task_df = df.filter((pl.col("task") == task) & (pl.col("split") == "train"))

        # Skip if the task has no training samples
        if task_df.height == 0:
            continue

        # Label 0 is drawn before label 1 for each task.
        for label, quota in ((0, n0), (1, n1)):
            class_df = task_df.filter(pl.col("label") == label)
            result_dfs.append(_subsample_rows(rng, class_df, quota))

    # Combine all dataframes
    return pl.concat(result_dfs)

Sample a balanced subset of training data points for each task.

Args

rng
Random number generator.
df
NeWT dataframe.
n_train
Number of training samples per task to return.

Returns

A DataFrame with balanced training samples and all test samples.

Classes

class Dataset (root: str,
img_ids: jaxtyping.Shaped[ndarray, 'n'],
labels: jaxtyping.Int[ndarray, 'n'],
transform=None)
Expand source code
@jaxtyped(typechecker=beartype.beartype)
class Dataset(torch.utils.data.Dataset):
    """A dataset that returns `Sample` dictionaries, one per image ID."""

    def __init__(
        self,
        root: str,
        img_ids: Shaped[np.ndarray, " n"],
        labels: Int[np.ndarray, " n"],
        transform=None,
    ):
        """Initialize the dataset with image paths and labels.

        Args:
            root: Root directory containing the images, stored as
                `<img_id>.jpg`.
            img_ids: Array of image IDs.
            labels: Array of binary labels corresponding to the images.
            transform: Optional transform to apply to the images.
        """
        self.transform = transform
        self.root = root
        self.img_ids = img_ids
        self.labels = labels

    def __getitem__(self, i: int) -> Sample:
        """Get a sample by its index.

        Args:
            i: Index of the sample to retrieve.

        Returns:
            A dictionary containing the image ID, the image (transformed if a
            transform was given) and the label.
        """
        img_id = self.img_ids[i]
        img_path = os.path.join(self.root, f"{img_id}.jpg")

        # `Image.open` is lazy and keeps the file open. `convert` reads the
        # pixel data, so the file can be closed right away, and it guarantees
        # three RGB channels even for grayscale or CMYK files.
        with Image.open(img_path) as fd:
            img = fd.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        label = self.labels[i]
        return {"img_id": img_id, "img": img, "label": label}

    def __len__(self) -> int:
        """Return the number of samples in the dataset.

        Returns:
            The number of samples.
        """
        return len(self.img_ids)

A dataset that returns Sample dictionaries.

Initialize the dataset with image paths and labels.

Args

root
Root directory containing the images.
img_ids
Array of image IDs.
labels
Array of binary labels corresponding to the images.
transform
Optional transform to apply to the images.

Ancestors

  • torch.utils.data.dataset.Dataset
  • typing.Generic
class Sample (*args, **kwargs)
Expand source code
@jaxtyped(typechecker=beartype.beartype)
class Sample(typing.TypedDict):
    """A dictionary representing a single image sample with its metadata.

    This is the item type that `Dataset.__getitem__` is annotated to return.

    Attributes:
        img_id: Unique identifier for the image.
        img: The image, annotated as a float tensor of shape
            `3 width height` (RGB channels first).
        label: Binary class label (0 or 1) for the image.

    NOTE(review): `Dataset.__getitem__` in this module returns `labels[i]`
    taken straight from a NumPy array, and returns the untransformed image
    when no transform is set. Confirm that the tensor annotations below
    describe what consumers actually receive.
    """

    # Identifier used to build the image file name (`<img_id>.jpg`) in `Dataset`.
    img_id: str
    # Image data; channel-first according to the annotation.
    img: Float[Tensor, "3 width height"]
    # Scalar (zero-dimensional) integer label according to the annotation.
    label: Int[Tensor, ""]

A dictionary representing a single image sample with its metadata.

Attributes

img_id
Unique identifier for the image.
img
The image tensor with shape [3, width, height] (RGB channels first).
label
Binary class label (0 or 1) for the image.

Ancestors

  • builtins.dict

Class variables

var img : jaxtyping.Float[Tensor, '3 width height']
var img_id : str
var label : jaxtyping.Int[Tensor, '']
class Task (name: str,
cluster: str,
subcluster: str | None,
features: jaxtyping.Float[ndarray, 'batch dim'],
labels: jaxtyping.Int[ndarray, 'batch'],
is_train: jaxtyping.Bool[ndarray, 'batch'],
example_ids: jaxtyping.Shaped[ndarray, 'batch'])
Expand source code
@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass(frozen=True)
class Task:
    """
    One NeWT task: the features and labels of all its examples, plus a boolean
    mask that says which examples belong to the training split.
    """

    # Task name, used to label predictions.
    name: str
    # Cluster (and optional subcluster) the task belongs to.
    cluster: str
    subcluster: str | None
    # One feature vector per example.
    features: Float[np.ndarray, "batch dim"]
    # One integer label per example.
    labels: Int[np.ndarray, " batch"]
    # True for training examples, False for test examples.
    is_train: Bool[np.ndarray, " batch"]
    # One ID per example, covering both splits.
    example_ids: Shaped[np.ndarray, " batch"]  # Should be String[...]

    def __repr__(self) -> str:
        return f"Task(task={self.name}, cluster={self.cluster}, features={self.features.shape})"

    @property
    def splits(
        self,
    ) -> tuple[
        tuple[Float[np.ndarray, "n_train dim"], Int[np.ndarray, " n_train"]],
        tuple[Float[np.ndarray, "n_test dim"], Int[np.ndarray, " n_test"]],
    ]:
        """
        Partition the features and labels by the `is_train` mask.

        Returned as `(x_train, y_train), (x_test, y_test)`.
        """
        train_mask = self.is_train
        test_mask = ~train_mask

        train = (self.features[train_mask], self.labels[train_mask])
        test = (self.features[test_mask], self.labels[test_mask])
        return train, test

Task is a group of features and labels for an SVM + a train/test split.

Instance variables

var cluster : str
var example_ids : jaxtyping.Shaped[ndarray, 'batch']
var features : jaxtyping.Float[ndarray, 'batch dim']
var is_train : jaxtyping.Bool[ndarray, 'batch']
var labels : jaxtyping.Int[ndarray, 'batch']
var name : str
prop splits : tuple[tuple[jaxtyping.Float[ndarray, 'n_train dim'], jaxtyping.Int[ndarray, 'n_train']], tuple[jaxtyping.Float[ndarray, 'n_test dim'], jaxtyping.Int[ndarray, 'n_test']]]
Expand source code
@property
def splits(
    self,
) -> tuple[
    tuple[Float[np.ndarray, "n_train dim"], Int[np.ndarray, " n_train"]],
    tuple[Float[np.ndarray, "n_test dim"], Int[np.ndarray, " n_test"]],
]:
    """
    Partition the features and labels by the `is_train` mask.

    Returned as `(x_train, y_train), (x_test, y_test)`.
    """
    train_mask = self.is_train
    test_mask = ~train_mask

    train = (self.features[train_mask], self.labels[train_mask])
    test = (self.features[test_mask], self.labels[test_mask])
    return train, test

The features and labels for train and test splits.

Returned as (x_train, y_train), (x_test, y_test).

var subcluster : str | None