Module biobench.plantnet
Pl@ntNet-300K is a plant image dataset "with high label ambiguity and a long-tailed distribution", published at NeurIPS 2021 (Datasets and Benchmarks track). We fit a ridge classifier from scikit-learn to a backbone's embeddings and evaluate on the validation split.
There are two pieces that make Pl@ntNet more than a simple classification task:
- Because of the long tail, we pass class_weight='balanced' to the classifier, which weights each class inversely to its frequency.
- Because of the massive class imbalance, we use macro F1 rather than accuracy, both to choose the alpha parameter and to evaluate the final classifier.
Both choices appear in the sketch below.
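A minimal sketch of how these two choices look in scikit-learn; the data, alpha value, and shapes are stand-ins, not the module's actual settings:

import numpy as np
from sklearn.linear_model import RidgeClassifier
from sklearn.metrics import f1_score

# Stand-ins for backbone embeddings and integer labels.
rng = np.random.default_rng(0)
x_train, y_train = rng.normal(size=(512, 128)), rng.integers(0, 10, size=512)
x_val, y_val = rng.normal(size=(128, 128)), rng.integers(0, 10, size=128)

# class_weight='balanced' reweights each class inversely to its frequency,
# so the long tail is not drowned out by the head classes.
clf = RidgeClassifier(alpha=1.0, class_weight="balanced").fit(x_train, y_train)

# Macro F1 averages per-class F1 scores, giving every class equal weight;
# plain accuracy would be dominated by the most common species.
score = f1_score(y_val, clf.predict(x_val), average="macro")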
If you use this task, please cite the original paper:
@inproceedings{plantnet-300k,
  author={Garcin, Camille and Joly, Alexis and Bonnet, Pierre and Lombardo, Jean-Christophe and Affouard, Antoine and Chouet, Mathias and Servajean, Maximilien and Lorieul, Titouan and Salmon, Joseph},
  booktitle={NeurIPS Datasets and Benchmarks 2021},
  title={{Pl@ntNet-300K}: a plant image dataset with high label ambiguity and a long-tailed distribution},
  year={2021},
}
Sub-modules
biobench.plantnet.download
Functions
def benchmark(args: Args, model_args: ModelArgs) ‑> tuple[ModelArgs, TaskReport]
-
Steps:
1. Get features for all images.
2. Select lambda (the ridge regularization strength; alpha in scikit-learn) using cross-validation splits.
3. Report the score on the test data.
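A hedged sketch of step 2 using scikit-learn's cross-validation tools; the alpha grid and number of splits are assumptions, not the module's actual choices:

import numpy as np
from sklearn.linear_model import RidgeClassifier
from sklearn.model_selection import GridSearchCV

# Stand-ins for the features returned by get_features.
rng = np.random.default_rng(0)
x, y = rng.normal(size=(512, 128)), rng.integers(0, 10, size=512)

# Sweep regularization strengths and keep the one with the best macro F1
# across cross-validation splits.
search = GridSearchCV(
    RidgeClassifier(class_weight="balanced"),
    param_grid={"alpha": np.logspace(-3, 3, 7)},
    scoring="f1_macro",
    cv=5,
)
search.fit(x, y)
best_alpha = search.best_params_["alpha"]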
def calc_macro_top1(examples: list[Example]) ‑> float
-
Macro top-1 accuracy: the mean of per-class top-1 accuracies, so every class counts equally regardless of how many examples it has.
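The module's implementation consumes a list of Example objects; a minimal NumPy sketch of the same quantity, on made-up arrays:

import numpy as np

def macro_top1(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    # Average per-class accuracies so each class contributes equally.
    classes = np.unique(y_true)
    return float(np.mean([(y_pred[y_true == c] == c).mean() for c in classes]))

macro_top1(np.array([0, 0, 1, 1]), np.array([0, 1, 1, 1]))  # (0.5 + 1.0) / 2 = 0.75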
def get_features(args: Args, backbone: VisionBackbone, *, split: str) ‑> Features
def init_clf(args: Args)
Classes
class Args (seed: int = 42, datadir: str = '', device: str = 'cuda', debug: bool = False, batch_size: int = 256, n_workers: int = 4, log_every: int = 10)
Source code
@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Args(interfaces.TaskArgs):
    batch_size: int = 256
    """batch size for deep model."""
    n_workers: int = 4
    """number of dataloader worker processes."""
    log_every: int = 10
    """how often (number of batches) to log progress."""
Ancestors
- biobench.interfaces.TaskArgs
Class variables
var batch_size : int
-
batch size for deep model.
var log_every : int
-
how often (number of batches) to log progress.
var n_workers : int
-
number of dataloader worker processes.
Inherited members
- biobench.interfaces.TaskArgs: datadir, debug, device, seed
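Since Args is a frozen dataclass, a run is configured once up front; the values below are placeholders:

args = Args(datadir="/data/plantnet", batch_size=128, n_workers=8)  # placeholder values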
class Dataset (root: str, transform)
-
An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__, which many Sampler implementations and the default options of DataLoader expect to return the size of the dataset. Subclasses could also optionally implement __getitems__ to speed up batched sample loading; this method accepts a list of batch sample indices and returns the list of samples.

Note: DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

Source code
@jaxtyped(typechecker=beartype.beartype)
class Dataset(torch.utils.data.Dataset):
    transform: typing.Any | None
    """Optional function that transforms an image into a format expected by a neural network."""

    samples: list[tuple[str, str, str]]
    """List of all image ids, image paths, and classnames."""

    def __init__(self, root: str, transform):
        self.transform = transform
        self.samples = []
        if not os.path.exists(root) or not os.path.isdir(root):
            msg = f"Path '{root}' doesn't exist. Did you download the Pl@ntNet dataset? See the docstring at the top of this file for instructions. If you did download it, pass the path as --dataset-dir PATH"
            raise RuntimeError(msg)

        # Walk the class-per-directory layout: each subdirectory name
        # (relative to root) is the classname for the images inside it.
        for dirpath, dirnames, filenames in os.walk(root):
            image_class = os.path.relpath(dirpath, root)
            for filename in filenames:
                image_id = filename.removesuffix(".jpg")
                image_path = os.path.join(dirpath, filename)
                self.samples.append((image_id, image_path, image_class))

    def __getitem__(self, i: int) -> tuple[str, Float[Tensor, "3 width height"], str]:
        image_id, image_path, image_class = self.samples[i]
        image = Image.open(image_path)
        if self.transform is not None:
            image = self.transform(image)
        return image_id, image, image_class

    def __len__(self) -> int:
        return len(self.samples)
Ancestors
- torch.utils.data.dataset.Dataset
- typing.Generic
Class variables
var samples : list[tuple[str, str, str]]
-
List of all image ids, image paths, and classnames.
var transform : typing.Any | None
-
Optional function that transforms an image into a format expected by a neural network.
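A usage sketch, assuming the images live under a placeholder path and that torchvision supplies the transform; neither assumption comes from this module:

import torch
import torchvision.transforms as T

transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
dataset = Dataset("/data/plantnet/images/train", transform)  # placeholder path

# Each item is (image_id, image tensor, classname); the default collate
# stacks the tensors and gathers the strings into lists.
loader = torch.utils.data.DataLoader(dataset, batch_size=256, num_workers=4)
for image_ids, images, classnames in loader:
    ...  # images: Float[Tensor, "batch 3 224 224"]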
class Features (x: jaxtyping.Float[ndarray, 'n dim'], labels: jaxtyping.Shaped[ndarray, 'n'], ids: jaxtyping.Shaped[ndarray, 'n'])
Source code
@jaxtyped(typechecker=beartype.beartype)
@dataclasses.dataclass(frozen=True)
class Features:
    x: Float[np.ndarray, "n dim"]
    labels: Shaped[np.ndarray, " n"]
    ids: Shaped[np.ndarray, " n"]

    def y(self, encoder):
        # The encoder's transform expects a 2-D array, hence the reshapes.
        return encoder.transform(self.labels.reshape(-1, 1)).reshape(-1)
Class variables
var ids : jaxtyping.Shaped[ndarray, 'n']
var labels : jaxtyping.Shaped[ndarray, 'n']
var x : jaxtyping.Float[ndarray, 'n dim']
Methods
def y(self, encoder)
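y maps the raw classname strings to integer codes via a fitted encoder whose transform takes a 2-D array, such as scikit-learn's OrdinalEncoder; the classnames and feature dimension below are made up:

import numpy as np
from sklearn.preprocessing import OrdinalEncoder

labels = np.array(["rosa_canina", "quercus_robur", "rosa_canina"])  # made-up classnames
encoder = OrdinalEncoder().fit(labels.reshape(-1, 1))

feats = Features(x=np.zeros((3, 128)), labels=labels, ids=np.array(["a", "b", "c"]))
feats.y(encoder)  # array([1., 0., 1.]) -- codes are floats, shape (n,)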