Module contrib.mae.modeling


def load_ckpt(ckpt: str, *, chunk_size_kb: int = 1024) -> MaskedAutoencoder

Loads a pre-trained MAE ViT from disk. If the checkpoint is not on disk, it is first downloaded from Hugging Face; the weights are then loaded into a MaskedAutoencoder, which is returned.

def random_masking(x_BND: jaxtyping.Float[Tensor, 'batch n d'],
mask_ratio: float,
noise_BN: jaxtyping.Float[Tensor, 'batch n'] | None = None) -> tuple[jaxtyping.Float[Tensor, 'batch m d'], jaxtyping.Float[Tensor, 'batch n'], jaxtyping.Int[Tensor, 'batch n']]

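Performs per-sample random masking by shuffling the patches of each sample independently; the shuffle is obtained by argsorting random noise (noise_BN, if given, supplies that noise). Returns the kept patches, a mask over all n patches, and the indices that restore the original patch order.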


class Attention (*, d: int, n_heads: int)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:`to`, etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool

Initialize internal Module state, shared by both nn.Module and ScriptModule.

class Attention(torch.nn.Module):
    def __init__(self, *, d: int, n_heads: int):
        super().__init__()
        assert d % n_heads == 0, f"n_heads={n_heads} must evenly divide d={d}"

        self.n_heads = n_heads

        self.query = torch.nn.Linear(d, d)
        self.key = torch.nn.Linear(d, d)
        self.value = torch.nn.Linear(d, d)
        self.output = torch.nn.Linear(d, d)

    def split(
        self, x_BND: Float[Tensor, "batch n d"]
    ) -> Float[Tensor, "batch n_heads n d_head"]:
        # Split the feature dimension into heads and move the head axis before n.
        return einops.rearrange(
            x_BND,
            "batch n (n_heads d_head) -> batch n_heads n d_head",
            n_heads=self.n_heads,
        )

    def forward(self, x_BND: Float[Tensor, "batch n d"]) -> Float[Tensor, "batch n d"]:
        q_BHNd = self.split(self.query(x_BND))
        k_BHNd = self.split(self.key(x_BND))
        v_BHNd = self.split(self.value(x_BND))

        # Non-causal attention; scale=None uses the default 1/sqrt(d_head).
        x_BHNd = torch.nn.functional.scaled_dot_product_attention(
            q_BHNd, k_BHNd, v_BHNd, dropout_p=0.0, is_causal=False, scale=None
        )

        # Merge the heads back into a single d-dimensional feature.
        x_BND = einops.rearrange(
            x_BHNd, "batch n_heads n d_head -> batch n (n_heads d_head)"
        )

        x_BND = self.output(x_BND)
        return x_BND


Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x_BND: jaxtyping.Float[Tensor, 'batch n d']) -> jaxtyping.Float[Tensor, 'batch n d']
def split(self, x_BND: jaxtyping.Float[Tensor, 'batch n d']) -> jaxtyping.Float[Tensor, 'batch n_heads n d_head']

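To make the head split concrete, the sketch below checks that splitting `d` into `n_heads` heads, calling `scaled_dot_product_attention`, and merging the heads matches the textbook formula softmax(QK^T / sqrt(d_head)) V applied per head. It assumes plain `reshape`/`transpose` are equivalent to the `einops.rearrange` patterns used in `Attention`, and it feeds random tensors instead of the module's learned query/key/value projections; `split_heads`, `merge_heads` and `manual_attention` are hypothetical helper names.

```python
import math

import torch
import torch.nn.functional as F


def split_heads(x_BND: torch.Tensor, n_heads: int) -> torch.Tensor:
    # Same layout as "batch n (n_heads d_head) -> batch n_heads n d_head".
    b, n, d = x_BND.shape
    return x_BND.reshape(b, n, n_heads, d // n_heads).transpose(1, 2)


def merge_heads(x_BHNd: torch.Tensor) -> torch.Tensor:
    # Same layout as "batch n_heads n d_head -> batch n (n_heads d_head)".
    b, h, n, d_head = x_BHNd.shape
    return x_BHNd.transpose(1, 2).reshape(b, n, h * d_head)


def manual_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Explicit softmax(Q K^T / sqrt(d_head)) V, computed independently per head.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return scores.softmax(dim=-1) @ v


torch.manual_seed(0)
batch, n, d, n_heads = 2, 5, 12, 3
q, k, v = (split_heads(torch.randn(batch, n, d), n_heads) for _ in range(3))

fused = merge_heads(F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False))
manual = merge_heads(manual_attention(q, k, v))
print(fused.shape)  # torch.Size([2, 5, 12])
```

The two results agree up to floating-point error, which is why `Attention` can pass `scale=None` and rely on the default `1/sqrt(d_head)` factor.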
class Decoder (*,
d_in: int,
d: int,
d_hidden: int,
n_layers: int,
n_heads: int,
patch_size_px: tuple[int, int],
image_size_px: tuple[int, int],
ln_eps: float)

class Decoder(torch.nn.Module):
    def __init__(
        self,
        *,
        d_in: int,
        d: int,
        d_hidden: int,
        n_layers: int,
        n_heads: int,
        patch_size_px: tuple[int, int],
        image_size_px: tuple[int, int],
        ln_eps: float,
    ):
        super().__init__()

        image_w_px, image_h_px = image_size_px
        patch_w_px, patch_h_px = patch_size_px
        n_patches = (image_w_px // patch_w_px) * (image_h_px // patch_h_px)

        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, d))

        self.embd = torch.nn.Linear(d_in, d)

        # Non-trainable positional embeddings; the +1 slot belongs to the cls token.
        self.pos_embd = torch.nn.Parameter(
            torch.zeros(1, n_patches + 1, d), requires_grad=False
        )

        self.layers = torch.nn.ModuleList([
            TransformerBlock(d=d, d_hidden=d_hidden, n_heads=n_heads, ln_eps=ln_eps)
            for _ in range(n_layers)
        ])

        self.layernorm = torch.nn.LayerNorm(d, eps=ln_eps)
        # One output per RGB pixel value in a patch.
        self.head = torch.nn.Linear(d, patch_w_px * patch_h_px * 3)

    def forward(
        self,
        x_BMD: Float[Tensor, "batch m d_in"],
        ids_restore_BN: Int[Tensor, "batch n"],
    ) -> Float[Tensor, "batch n patch_pixels"]:
        batch_size, m, _ = x_BMD.shape
        _, n = ids_restore_BN.shape

        # Linear projection from encoder dimension to decoder dimension
        x_BMD = self.embd(x_BMD)

        _, _, d_decoder = x_BMD.shape

        # Add the mask tokens back. x_BMD holds the cls token plus m - 1 visible
        # patches, so n - (m - 1) = n + 1 - m patches were masked out.
        n_mask_tokens = n + 1 - m
        masks_BOD = self.mask_token.repeat(batch_size, n_mask_tokens, 1)
        x_BND = torch.cat([x_BMD[:, 1:, :], masks_BOD], dim=1)  # drop the cls token

        # Unshuffle: move every token back to its original patch position.
        index_BND = ids_restore_BN[..., None].repeat(1, 1, d_decoder).to(x_BND.device)
        x_BND = torch.gather(x_BND, dim=1, index=index_BND)

        x_BND = torch.cat([x_BMD[:, :1, :], x_BND], dim=1)  # prepend the cls token

        # Add positional embeddings again.
        x_BND = x_BND + self.pos_embd

        for layer in self.layers:
            x_BND = layer(x_BND)

        x_BND = self.layernorm(x_BND)
        logits_BNP = self.head(x_BND)

        # Remove cls token
        logits_BNP = logits_BNP[:, 1:, :]

        return logits_BNP


Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self,
x_BMD: jaxtyping.Float[Tensor, 'batch m d_in'],
ids_restore_BN: jaxtyping.Int[Tensor, 'batch n']) -> jaxtyping.Float[Tensor, 'batch n patch_pixels']

Define the computation performed at every call.

Should be overridden by all subclasses.


Although the recipe for forward pass needs to be defined within this function, one should call the :class:`Module` instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
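The mask-token bookkeeping in `Decoder.forward` is its least obvious step. The sketch below replays it on toy tensors without any learned layers: it drops the cls slot, appends `n + 1 - m` copies of a mask token, gathers with `ids_restore_BN` so that each visible patch returns to its original position, and prepends the cls token again. The argsort-based shuffle that produces `ids_restore_BN`, the helper name `restore_order`, and the marker values (-1 for cls, 100 for mask tokens) are assumptions chosen for illustration.

```python
import torch


def restore_order(
    x_BMD: torch.Tensor, ids_restore_BN: torch.Tensor, mask_token_11D: torch.Tensor
) -> torch.Tensor:
    """Hypothetical mirror of the Decoder's unshuffling step (no linear layers)."""
    batch, m, d = x_BMD.shape
    n = ids_restore_BN.shape[1]
    # x_BMD holds cls + (m - 1) visible patches, so n + 1 - m patches were masked.
    n_mask_tokens = n + 1 - m
    masks_BOD = mask_token_11D.repeat(batch, n_mask_tokens, 1)
    x_BND = torch.cat([x_BMD[:, 1:, :], masks_BOD], dim=1)
    index_BND = ids_restore_BN[..., None].repeat(1, 1, d)
    x_BND = torch.gather(x_BND, dim=1, index=index_BND)
    # Put cls back in front: the result has n + 1 tokens.
    return torch.cat([x_BMD[:, :1, :], x_BND], dim=1)


# Toy example: one image, n = 6 patches, 2 kept; patch i is the vector [i, i].
torch.manual_seed(0)
n, d, n_keep = 6, 2, 2
patches_1ND = torch.arange(n).float()[None, :, None].repeat(1, 1, d)
ids_shuffle_1N = torch.argsort(torch.rand(1, n), dim=1)  # assumed masking scheme
ids_restore_1N = torch.argsort(ids_shuffle_1N, dim=1)
visible_1KD = torch.gather(patches_1ND, 1, ids_shuffle_1N[:, :n_keep, None].repeat(1, 1, d))
cls_11D = torch.full((1, 1, d), -1.0)
mask_token_11D = torch.full((1, 1, d), 100.0)

restored = restore_order(torch.cat([cls_11D, visible_1KD], dim=1), ids_restore_1N, mask_token_11D)
# restored[0, :, 0]: cls is -1, kept patches hold their own index, masked slots hold 100.
```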

class Embeddings (*,
d: int,
image_size_px: tuple[int, int],
patch_size_px: tuple[int, int],
mask_ratio: float)

class Embeddings(torch.nn.Module):
    class Output(typing.TypedDict):
        x_BMD: Float[Tensor, "batch m d"]
        mask_BN: Float[Tensor, "batch n"]
        ids_restore_BN: Int[Tensor, "batch n"]

    def __init__(
        self,
        *,
        d: int,
        image_size_px: tuple[int, int],
        patch_size_px: tuple[int, int],
        mask_ratio: float,
    ):
        super().__init__()

        image_w_px, image_h_px = image_size_px
        patch_w_px, patch_h_px = patch_size_px
        n_patches = (image_w_px // patch_w_px) * (image_h_px // patch_h_px)

        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, d))
        # Non-trainable positional embeddings; index 0 belongs to the cls token.
        self.position_embeddings = torch.nn.Parameter(
            torch.zeros(1, n_patches + 1, d), requires_grad=False
        )
        self.patch_embeddings = PatchEmbeddings(d=d, patch_size_px=patch_size_px)

        self.mask_ratio = mask_ratio

    def forward(
        self,
        x_BCWH: Float[Tensor, "batch 3 height width"],
        noise_BN: Float[Tensor, "batch n"] | None = None,
    ) -> Output:
        batch_size, _, _, _ = x_BCWH.shape

        # Positional embeddings are added before masking (skipping the cls slot).
        x_BND = self.patch_embeddings(x_BCWH) + self.position_embeddings[:, 1:, :]

        x_BMD, mask_BN, ids_restore_BN = random_masking(
            x_BND, self.mask_ratio, noise_BN=noise_BN
        )

        cls_x_11D = self.cls_token + self.position_embeddings[:, :1, :]

        # Prepend the same cls token to every sample in the batch.
        cls_x_B1D = cls_x_11D.expand(batch_size, -1, -1)
        return self.Output(
            x_BMD=torch.cat((cls_x_B1D, x_BMD), dim=1),
            mask_BN=mask_BN,
            ids_restore_BN=ids_restore_BN,
        )


Ancestors

  • torch.nn.modules.module.Module

Class variables

var Output

Typed dictionary (typing.TypedDict) returned by forward(), with the keys x_BMD (the cls token followed by the visible patch embeddings), mask_BN and ids_restore_BN. Instances are plain dicts built with keyword arguments.

Methods

def forward(self,
x_BCWH: jaxtyping.Float[Tensor, 'batch 3 height width'],
noise_BN: jaxtyping.Float[Tensor, 'batch n'] | None = None) -> Embeddings.Output

class Encoder (*, d: int, d_hidden: int, n_heads: int, n_layers: int, ln_eps: float)

class Encoder(torch.nn.Module):
    def __init__(
        self,
        *,
        d: int,
        d_hidden: int,
        n_heads: int,
        n_layers: int,
        ln_eps: float,
    ):
        super().__init__()
        self.layers = torch.nn.ModuleList([
            TransformerBlock(d=d, d_hidden=d_hidden, n_heads=n_heads, ln_eps=ln_eps)
            for _ in range(n_layers)
        ])

    def forward(self, x_BMD: Float[Tensor, "batch m d"]) -> Float[Tensor, "batch m d"]:
        for layer in self.layers:
            x_BMD = layer(x_BMD)
        return x_BMD


Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x_BMD: jaxtyping.Float[Tensor, 'batch m d']) -> jaxtyping.Float[Tensor, 'batch m d']

class Feedforward (*, d: int, d_hidden: int)

class Feedforward(torch.nn.Module):
    def __init__(self, *, d: int, d_hidden: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(d, d_hidden)
        self.linear2 = torch.nn.Linear(d_hidden, d)

    def forward(self, x_BND: Float[Tensor, "batch n d"]) -> Float[Tensor, "batch n d"]:
        x_BNF = self.linear1(x_BND)
        x_BNF = torch.nn.functional.gelu(x_BNF)
        x_BND = self.linear2(x_BNF)
        return x_BND


Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x_BND: jaxtyping.Float[Tensor, 'batch n d']) -> jaxtyping.Float[Tensor, 'batch n d']

class MaskedAutoencoder (*,
d_encoder: int,
d_hidden_encoder: int,
n_heads_encoder: int,
n_layers_encoder: int,
d_decoder: int,
d_hidden_decoder: int,
n_heads_decoder: int,
n_layers_decoder: int,
image_size_px: tuple[int, int],
patch_size_px: tuple[int, int],
mask_ratio: float,
ln_eps: float)

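class MaskedAutoencoder(torch.nn.Module):
    class Output(typing.TypedDict):
        latents: Float[Tensor, "batch n d"]
        decoded: Float[Tensor, "batch n d"]
        ids_restore: Int[Tensor, "batch n"]

    def __init__(
        self,
        *,
        d_encoder: int,
        d_hidden_encoder: int,
        n_heads_encoder: int,
        n_layers_encoder: int,
        d_decoder: int,
        d_hidden_decoder: int,
        n_heads_decoder: int,
        n_layers_decoder: int,
        image_size_px: tuple[int, int],
        patch_size_px: tuple[int, int],
        mask_ratio: float,
        ln_eps: float,
    ):
        super().__init__()
        self.vit = VisionTransformer(
            d=d_encoder,
            d_hidden=d_hidden_encoder,
            n_heads=n_heads_encoder,
            n_layers=n_layers_encoder,
            image_size_px=image_size_px,
            patch_size_px=patch_size_px,
            mask_ratio=mask_ratio,
            ln_eps=ln_eps,
        )
        self.decoder = Decoder(
            d_in=d_encoder,
            d=d_decoder,
            d_hidden=d_hidden_decoder,
            n_layers=n_layers_decoder,
            n_heads=n_heads_decoder,
            patch_size_px=patch_size_px,
            image_size_px=image_size_px,
            ln_eps=ln_eps,
        )

    def forward(
        self,
        x_B3WH: Float[Tensor, "batch 3 width height"],
        noise_BN: Float[Tensor, "batch n"] | None = None,
    ) -> Output:
        encoded = self.vit(x_B3WH, noise_BN=noise_BN)
        decoded_BND = self.decoder(encoded["x_BMD"], encoded["ids_restore_BN"])

        return self.Output(
            latents=encoded["x_BMD"],
            decoded=decoded_BND,
            ids_restore=encoded["ids_restore_BN"],
        )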


Ancestors

  • torch.nn.modules.module.Module

Class variables

var Output

Typed dictionary (typing.TypedDict) returned by forward(), with the keys latents (the encoder output for the cls token and visible patches), decoded (the decoder's per-patch pixel predictions) and ids_restore (the indices that undo the random patch shuffle).

Methods

def forward(self,
x_B3WH: jaxtyping.Float[Tensor, 'batch 3 width height'],
noise_BN: jaxtyping.Float[Tensor, 'batch n'] | None = None) -> MaskedAutoencoder.Output

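For a feel of the tensor sizes flowing through `MaskedAutoencoder.forward`, the helper below does the bookkeeping in plain Python. The function `mae_shapes` is hypothetical, the configuration (224x224 images, 16x16 patches, `mask_ratio=0.75`, `d_encoder=768`) is only an example, and the rule that `int(n * (1 - mask_ratio))` patches stay visible is an assumption about `random_masking`.

```python
def mae_shapes(
    batch: int,
    image_size_px: tuple[int, int],
    patch_size_px: tuple[int, int],
    mask_ratio: float,
    d_encoder: int,
) -> dict[str, tuple[int, ...]]:
    """Hypothetical helper: tensor shapes at each stage of the MAE forward pass."""
    image_w_px, image_h_px = image_size_px
    patch_w_px, patch_h_px = patch_size_px
    # Same patch-count formula as Embeddings and Decoder.
    n = (image_w_px // patch_w_px) * (image_h_px // patch_h_px)
    n_keep = int(n * (1 - mask_ratio))  # assumption about random_masking
    m = n_keep + 1  # visible patches plus the cls token
    patch_pixels = patch_w_px * patch_h_px * 3  # output size of Decoder.head
    return {
        "latents": (batch, m, d_encoder),
        "ids_restore": (batch, n),
        "decoded": (batch, n, patch_pixels),
    }


print(mae_shapes(8, (224, 224), (16, 16), 0.75, d_encoder=768))
# {'latents': (8, 50, 768), 'ids_restore': (8, 196), 'decoded': (8, 196, 768)}
```

With these settings the encoder sees only 50 of the 197 tokens a full ViT would process, while the decoder predicts pixels for all 196 patches.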
class PatchEmbeddings (d: int, patch_size_px: tuple[int, int])

class PatchEmbeddings(torch.nn.Module):
    def __init__(self, d: int, patch_size_px: tuple[int, int]):
        super().__init__()
        # kernel_size == stride: each non-overlapping patch is projected to d dims.
        self.projection = torch.nn.Conv2d(
            3, d, kernel_size=patch_size_px, stride=patch_size_px
        )

    def forward(
        self, x_BCWH: Float[Tensor, "batch 3 width height"]
    ) -> Float[Tensor, "batch n_patches d"]:
        # Flatten the patch grid into a sequence of n_patches tokens.
        return einops.rearrange(self.projection(x_BCWH), "batch d w h -> batch (w h) d")


Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x_BCWH: jaxtyping.Float[Tensor, 'batch 3 width height']) -> jaxtyping.Float[Tensor, 'batch n_patches d']

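`PatchEmbeddings` relies on the fact that a `Conv2d` whose `kernel_size` equals its `stride` is the same as cutting the image into non-overlapping patches and applying one shared linear layer to each flattened patch. The sketch below checks that equivalence with `torch.nn.functional.unfold` on a random image; the sizes are arbitrary choices for illustration, and `flatten(2).transpose(1, 2)` stands in for the module's `einops.rearrange(..., "batch d w h -> batch (w h) d")`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, patch = 8, 4
conv = torch.nn.Conv2d(3, d, kernel_size=patch, stride=patch)
x_BCHW = torch.randn(2, 3, 16, 16)

# Conv route: (B, d, 4, 4) -> (B, 16 patches, d), row-major over the patch grid.
conv_BND = conv(x_BCHW).flatten(2).transpose(1, 2)

# Unfold route: each column is one flattened 3 x 4 x 4 patch, in the same order.
patches_BNP = F.unfold(x_BCHW, kernel_size=patch, stride=patch).transpose(1, 2)
weight_DP = conv.weight.reshape(d, -1)  # conv kernels as a (d, 3*4*4) matrix
linear_BND = patches_BNP @ weight_DP.T + conv.bias

print(conv_BND.shape)  # torch.Size([2, 16, 8])
```

The strided convolution is simply the more compact and efficient way to express this per-patch linear projection.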
class TransformerBlock (*, d: int, d_hidden: int, n_heads: int, ln_eps: float)

class TransformerBlock(torch.nn.Module):
    def __init__(self, *, d: int, d_hidden: int, n_heads: int, ln_eps: float):
        super().__init__()
        self.attention = Attention(d=d, n_heads=n_heads)
        self.ffnn = Feedforward(d=d, d_hidden=d_hidden)
        self.layernorm1 = torch.nn.LayerNorm(d, eps=ln_eps)
        self.layernorm2 = torch.nn.LayerNorm(d, eps=ln_eps)

    def forward(self, x: Float[Tensor, "batch n d"]) -> Float[Tensor, "batch n d"]:
        # Pre-LayerNorm block: normalize, apply the sublayer, add the residual.
        x_ = self.attention(self.layernorm1(x))

        x = x_ + x

        x_ = self.ffnn(self.layernorm2(x))
        return x_ + x


Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x: jaxtyping.Float[Tensor, 'batch n d']) -> jaxtyping.Float[Tensor, 'batch n d']

class VisionTransformer (*,
d: int,
d_hidden: int,
n_heads: int,
n_layers: int,
image_size_px: tuple[int, int],
patch_size_px: tuple[int, int],
mask_ratio: float,
ln_eps: float)

class VisionTransformer(torch.nn.Module):
    class Output(typing.TypedDict):
        x_BMD: Float[Tensor, "batch m d"]
        ids_restore_BN: Int[Tensor, "batch n"]

    def __init__(
        self,
        *,
        d: int,
        d_hidden: int,
        n_heads: int,
        n_layers: int,
        image_size_px: tuple[int, int],
        patch_size_px: tuple[int, int],
        mask_ratio: float,
        ln_eps: float,
    ):
        super().__init__()
        self.embeddings = Embeddings(
            d=d,
            image_size_px=image_size_px,
            patch_size_px=patch_size_px,
            mask_ratio=mask_ratio,
        )
        self.encoder = Encoder(
            d=d, d_hidden=d_hidden, n_heads=n_heads, n_layers=n_layers, ln_eps=ln_eps
        )
        self.layernorm = torch.nn.LayerNorm(d, eps=ln_eps)

    def forward(
        self,
        x_B3WH: Float[Tensor, "batch 3 width height"],
        noise_BN: Float[Tensor, "batch n"] | None = None,
    ) -> Float[Tensor, "batch ..."]:
        # Patchify, mask and prepend cls; only the visible tokens are encoded.
        embedded = self.embeddings(x_B3WH, noise_BN=noise_BN)
        x_BMD = self.encoder(embedded["x_BMD"])
        x_BMD = self.layernorm(x_BMD)

        return self.Output(x_BMD=x_BMD, ids_restore_BN=embedded["ids_restore_BN"])


Ancestors

  • torch.nn.modules.module.Module

Class variables

var Output

Typed dictionary (typing.TypedDict) returned by forward(), with the keys x_BMD (the layer-normalized encoder output for the cls token and visible patches) and ids_restore_BN (passed through unchanged from Embeddings).

Methods

def forward(self,
x_B3WH: jaxtyping.Float[Tensor, 'batch 3 width height'],
noise_BN: jaxtyping.Float[Tensor, 'batch n'] | None = None) -> jaxtyping.Float[Tensor, 'batch ...']
