Module biobench.kabr
Kenyan Animal Behavior Recognition (KABR)
KABR is a video recognition task (paper, website, Huggingface) where the model predicts Kenyan animal behavior in short video segments.
This can be framed as a classification task: given a short video segment of a single animal, which behavior is most common within the segment?
While specialized video architectures exist, we train a simple classifier over video representations, which works well for few-shot tasks. We get video representations by embedding each frame of the video and aggregating over the frame dimension (see aggregate_frames).
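As a rough sketch of that pipeline (toy shapes; scikit-learn's LogisticRegression stands in here for the classifier that helpers.init_logreg_clf builds in benchmark below):

import torch
from sklearn.linear_model import LogisticRegression

# Stand-ins for backbone.img_encode() outputs: 16 frames per clip,
# 8 train clips and 4 test clips, 512-dim per-frame features.
train_frame_feats = torch.randn(16, 8, 512)
test_frame_feats = torch.randn(16, 4, 512)
train_y = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1]).numpy()  # fake per-video behavior labels

# Pool over the frame dimension to get one vector per video (see aggregate_frames).
train_x = train_frame_feats.max(dim=0).values.numpy()
test_x = test_frame_feats.max(dim=0).values.numpy()

clf = LogisticRegression(max_iter=1000).fit(train_x, train_y)
preds = clf.predict(test_x)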
Data
To download the data, you need to use the dataset download script:
- Copy-paste the download script to your data directory, like /scratch/KABR/download.py.
- Run python download.py. It doesn't have any requirements beyond the Python standard library.
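Once downloaded, the Dataset class documented below reads the annotation/<split>.csv files and the frames under dataset/image/. A minimal, hypothetical usage (substitute the path for wherever you ran download.py):

from biobench import kabr

# Hypothetical path; point this at the directory containing annotation/ and dataset/.
dataset = kabr.Dataset("/scratch/KABR", split="train")
print(len(dataset))  # number of clips with at least 16 frames
frames, labels, clip_id = dataset[0]
print(len(frames), labels[:3], clip_id)  # 16 PIL images (no transform), first labels, clip index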
Sub-modules
biobench.kabr.download
-
Download the Kenyan Animal Behavior Recognition (KABR) dataset …
biobench.kabr.test_download
biobench.kabr.test_kabr
Functions
def aggregate_frames(features: jaxtyping.Float[Tensor, 'n_frames n_examples dim']) ‑> jaxtyping.Float[Tensor, 'n_examples dim']
-
Expand source code
@jaxtyped(typechecker=beartype.beartype)
def aggregate_frames(
    features: Float[Tensor, "n_frames n_examples dim"],
) -> Float[Tensor, "n_examples dim"]:
    """Aggregate per-frame features to a per-video feature by max-pooling over the frame dimension."""
    return torch.max(features, dim=0).values
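A quick shape check (toy sizes): max-pooling over the first axis collapses the frame dimension, leaving one vector per video.

import torch
from biobench.kabr import aggregate_frames

features = torch.randn(16, 4, 8)  # 16 frames, 4 videos, 8-dim features
video_features = aggregate_frames(features)
assert video_features.shape == (4, 8)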
def aggregate_labels(labels: jaxtyping.Int[Tensor, 'n_frames n_examples']) ‑> jaxtyping.Int[Tensor, 'n_examples']
-
Expand source code
@jaxtyped(typechecker=beartype.beartype)
def aggregate_labels(
    labels: Int[Tensor, "n_frames n_examples"],
) -> Int[Tensor, " n_examples"]:
    """Aggregate per-frame labels to a per-video label. Uses the most common label (mode)."""
    return torch.mode(labels, dim=0).values
Aggregate per-frame labels to a per-video label. Uses the most common label (mode).
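For example, with 3 frames and 2 videos, each column's mode becomes that video's label:

import torch
from biobench.kabr import aggregate_labels

labels = torch.tensor([
    [0, 2],
    [0, 1],
    [1, 2],
])  # (n_frames=3, n_examples=2)
print(aggregate_labels(labels))  # tensor([0, 2])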
def benchmark(cfg: Experiment) ‑> Report
-
Expand source code
@beartype.beartype
def benchmark(cfg: config.Experiment) -> reporting.Report:
    """Runs KABR benchmark."""
    # 1. Load model.
    backbone = registry.load_vision_backbone(cfg.model)
    backbone = backbone.to(cfg.device)

    # 2. Load data.
    train_features = get_features(cfg, backbone, is_train=True)
    test_features = get_features(cfg, backbone, is_train=False)

    # 3. Fit a logistic-regression classifier on the training features.
    clf = helpers.init_logreg_clf(cfg)
    clf.fit(train_features.x, train_features.y)

    true_labels = test_features.y
    pred_labels = clf.predict(test_features.x)

    # 4. Return benchmark report.
    preds = [
        reporting.Prediction(
            str(video_id),
            float(pred == true),
            {"y_pred": pred.item(), "y_true": true.item()},
        )
        for video_id, pred, true in zip(test_features.ids, pred_labels, true_labels)
    ]
    return reporting.Report("kabr", preds, cfg)
Runs KABR benchmark.
def bootstrap_scores(df: polars.dataframe.frame.DataFrame, *, b: int = 0, rng: numpy.random._generator.Generator | None = None) ‑> dict[str, jaxtyping.Float[ndarray, 'b']]
-
Expand source code
@jaxtyped(typechecker=beartype.beartype)
def bootstrap_scores(
    df: pl.DataFrame, *, b: int = 0, rng: np.random.Generator | None = None
) -> dict[str, Float[np.ndarray, " b"]]:
    """Bootstrap macro-F1 scores for KABR predictions."""
    assert df.get_column("task_name").unique().to_list() == ["kabr"]
    return reporting.bootstrap_scores_macro_f1(df, b=b, rng=rng)
def get_features(cfg: Experiment, backbone: VisionBackbone, *, is_train: bool) ‑> Features
-
Expand source code
@jaxtyped(typechecker=beartype.beartype)
@torch.no_grad()
def get_features(
    cfg: config.Experiment, backbone: registry.VisionBackbone, *, is_train: bool
) -> Features:
    """Embed every clip in the chosen split and return pooled per-video features, labels, and ids."""
    img_transform = backbone.make_img_transform()
    backbone = torch.compile(backbone)

    split = "train" if is_train else "val"
    dataset = Dataset(cfg.data.kabr, split, transform=img_transform)
    if is_train and cfg.n_train > 0:
        i = helpers.balanced_random_sample(dataset.labels, cfg.n_train)
        assert len(i) == cfg.n_train
        dataset = torch.utils.data.Subset(dataset, i)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=max(1, cfg.batch_size // 32),
        num_workers=cfg.n_workers,
        drop_last=False,
        shuffle=False,
        pin_memory=True,
    )

    all_feats, all_labels, all_ids = [], [], []

    def probe(batch):
        frames, _, _ = batch
        frames = torch.stack(frames, dim=0)
        frames = frames.to(cfg.device, non_blocking=True)
        with torch.amp.autocast(cfg.device):
            # conv2d doesn't support multiple batch dimensions, so we have to
            # view() before and after the model.img_encode() call.
            n_frames, bsz, c, h, w = frames.shape
            frames = frames.view(bsz * n_frames, c, h, w)
            outputs = backbone.img_encode(frames)
            features = outputs.img_features.view(n_frames, bsz, -1)
            features = aggregate_frames(features)

    with helpers.auto_batch_size(dataloader, probe=probe, backoff=1):
        total = len(dataloader) if not cfg.debug else 2
        it = iter(dataloader)
        for b in helpers.progress(range(total), desc=f"kabr/{split}"):
            frames, labels, ids = next(it)
            frames = torch.stack(frames, dim=0)
            labels = torch.stack(labels, dim=0)
            frames = frames.to(cfg.device, non_blocking=True)

            with torch.amp.autocast(cfg.device):
                # conv2d doesn't support multiple batch dimensions, so we have to
                # view() before and after the model.img_encode() call.
                n_frames, bsz, c, h, w = frames.shape
                frames = frames.view(bsz * n_frames, c, h, w)
                outputs = backbone.img_encode(frames)
                features = outputs.img_features.view(n_frames, bsz, -1)
                features = aggregate_frames(features)

            all_feats.append(features.cpu())
            labels = aggregate_labels(labels)
            all_labels.append(labels.cpu())
            logger.debug("Embedded batch %d/%d", b + 1, total)
            all_ids.extend(ids)

    all_feats = torch.cat(all_feats, dim=0).cpu().numpy()
    all_labels = torch.cat(all_labels, dim=0).cpu().numpy()
    all_ids = np.array(all_ids)

    return Features(all_feats, all_labels, all_ids)
Classes
class Dataset (path, split: str, transform=None, seed: int = 42)
-
Expand source code
@jaxtyped(typechecker=beartype.beartype)
class Dataset(torch.utils.data.Dataset):
    """
    Clips of at most 90 frames in Charades format with each frame stored as an image.
    """

    def __init__(self, path, split: str, transform=None, seed: int = 42):
        self.path = path
        self.split = split
        self.transform = transform
        self.seed = seed
        self.rng = np.random.default_rng(seed=seed)
        self.n_frames = 16
        self.n_every = 5

        # Load videos
        #############
        frames: dict[int, list[str]] = {}
        labels: dict[int, list[int]] = {}

        if not os.path.exists(self.path) or not os.path.isdir(self.path):
            msg = f"Path '{self.path}' doesn't exist. Did you download the KABR dataset? See the docstring at the top of this file for instructions."
            raise RuntimeError(msg)

        with open(os.path.join(self.path, "annotation", f"{split}.csv")) as fd:
            reader = csv.reader(fd, delimiter=" ")
            next(reader)  # skip headers
            for _, video_id, frame_id, path, label in reader:
                video_id = int(video_id)
                frame_id = int(frame_id)
                label = int(label)

                if video_id not in frames:
                    frames[video_id] = []
                if video_id not in labels:
                    labels[video_id] = []

                if frame_id > len(frames[video_id]) + 1:
                    raise ValueError(f"Video {video_id} is missing a frame.")

                path = os.path.join(self.path, "dataset", "image", path)
                frames[video_id].append(path)
                labels[video_id].append(label)

        self.videos = [
            Video(video_id, frames[video_id], labels[video_id])
            for video_id in frames.keys()
            if len(frames[video_id]) >= self.n_frames
        ]

    def __getitem__(
        self, i: int
    ) -> tuple[list[Float[Tensor, "3 width height"]], list[int], str]:
        """
        Returns 16 frames and their labels sampled every 5 frames from a clip.
        The start of the clip is uniformly sampled. If there are fewer than
        (n_frames - 1) * n_every + 1 frames, the sampling stride is reduced
        until the clip is long enough.
        """
        n_every = self.n_every
        video = self.videos[i]
        while len(video.frames) < ((self.n_frames - 1) * n_every + 1):
            n_every -= 1
            if n_every <= 0:
                print(n_every, len(video.frames), ((self.n_frames - 1) * n_every + 1))
            assert n_every >= 1

        # margin is the number of extra frames on either side of the 16x5 sampled frames.
        margin = len(video.frames) - ((self.n_frames - 1) * n_every + 1)

        # Pick a random start, then pick n_frames frames every n_every frames.
        # (sam) This is likely not clear and there are probably better ways to express this in Python that is more clear to other video ML devs. Please open a PR if you know a better way!
        start = self.rng.integers(0, margin + 1)
        frames = video.frames[start:None:n_every][: self.n_frames]
        labels = video.labels[start:None:n_every][: self.n_frames]

        images = [Image.open(frame) for frame in frames]

        if self.transform is not None:
            images = [self.transform(image) for image in images]

        return images, labels, str(i)

    def __len__(self) -> int:
        return len(self.videos)
Clips of at most 90 frames in Charades format with each frame stored as an image.
Ancestors
- torch.utils.data.dataset.Dataset
- typing.Generic
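A hypothetical usage sketch, mirroring how get_features consumes the dataset (the path and transform are placeholders; with the default collate, each batch arrives as a list of 16 per-frame batches that you can stack into a frame-major tensor):

import torch
import torchvision.transforms as T
from biobench import kabr

# Hypothetical transform and data path.
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
dataset = kabr.Dataset("/scratch/KABR", split="val", transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False)

frames, labels, ids = next(iter(loader))
# `frames` is a list of 16 tensors, each (batch, 3, H, W); stack to (16, batch, 3, H, W).
frames = torch.stack(frames, dim=0)
labels = torch.stack(labels, dim=0)  # (16, batch)
print(frames.shape, labels.shape)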
class Features (x: jaxtyping.Float[ndarray, 'n dim'], y: jaxtyping.Int[ndarray, 'n'], ids: jaxtyping.Shaped[ndarray, 'n'])
-
Expand source code
@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass(frozen=True)
class Features:
    x: Float[np.ndarray, "n dim"]
    y: Int[np.ndarray, " n"]
    ids: Shaped[np.ndarray, " n"]
Pooled per-video features (x), per-video labels (y), and clip ids (ids), as produced by get_features().
Instance variables
var ids : jaxtyping.Shaped[ndarray, 'n']
var x : jaxtyping.Float[ndarray, 'n dim']
var y : jaxtyping.Int[ndarray, 'n']
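A toy construction, with shapes invented for illustration (get_features fills these arrays in practice):

import numpy as np
from biobench import kabr

feats = kabr.Features(
    x=np.zeros((4, 512), dtype=np.float32),  # one 512-dim vector per video
    y=np.array([0, 1, 1, 2]),                # per-video behavior labels
    ids=np.array(["0", "1", "2", "3"]),      # clip identifiers
)
print(feats.x.shape, feats.y.shape, feats.ids.shape)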
class Video (video_id: int, frames: list[str], labels: list[int])
-
Expand source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Video:
    """A single video instance as a sequence of frames."""

    video_id: int
    """Video ID."""

    frames: list[str]
    """Paths to actual frame images."""

    labels: list[int]
    """Frame-level labels."""

    def __post_init__(self):
        err_msg = f"Video {self.video_id} has a different number of frames ({len(self.frames)}) and labels ({len(self.labels)})."
        assert len(self.frames) == len(self.labels), err_msg
A single video instance as a sequence of frames.
Instance variables
var frames : list[str]
-
Paths to actual frame images.
var labels : list[int]
-
Frame-level labels.
var video_id : int
-
Video ID.
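A tiny construction example (frame paths are placeholders; in practice they point into dataset/image/). Mismatched frame and label lists trip the __post_init__ assertion:

from biobench import kabr

video = kabr.Video(
    video_id=7,
    frames=["frame_0001.jpg", "frame_0002.jpg"],
    labels=[3, 3],
)

# kabr.Video(video_id=7, frames=["frame_0001.jpg"], labels=[3, 3])
# would raise an AssertionError because frames and labels differ in length.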