Module contrib.mae.modeling
Functions
def load_ckpt(ckpt: str, *, chunk_size_kb: int = 1024) ‑> MaskedAutoencoder
-
Loads a pre-trained MAE ViT from disk. If the checkpoint is not on disk, it is first downloaded from Hugging Face and then loaded into the MaskedAutoencoder class.
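A minimal usage sketch; the checkpoint path and the 224x224 input size below are illustrative assumptions, not values taken from this module:

```python
import torch

from contrib.mae.modeling import load_ckpt

# Hypothetical local path; if it does not exist, the checkpoint is fetched from Hugging Face.
mae = load_ckpt("checkpoints/mae-vit.pt")
mae.eval()

with torch.no_grad():
    out = mae(torch.randn(1, 3, 224, 224))  # assumes the checkpoint was trained on 224x224 images
print(out.keys())  # dict_keys(['latents', 'decoded', 'ids_restore'])
```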
def random_masking(x_BND: jaxtyping.Float[Tensor, 'batch n d'],
mask_ratio: float,
noise_BN: jaxtyping.Float[Tensor, 'batch n'] | None = None) ‑> tuple[jaxtyping.Float[Tensor, 'batch m d'], jaxtyping.Float[Tensor, 'batch n'], jaxtyping.Int[Tensor, 'batch n']]
-
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsorting random noise.
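A shape sketch. The exact number of kept tokens m depends on the implementation; the values below assume the usual MAE convention m = int(n * (1 - mask_ratio)):

```python
import torch

from contrib.mae.modeling import random_masking

x_BND = torch.randn(2, 196, 768)  # 196 patch tokens per image
kept_BMD, mask_BN, ids_restore_BN = random_masking(x_BND, mask_ratio=0.75)

print(kept_BMD.shape)        # e.g. torch.Size([2, 49, 768]): only the unmasked tokens
print(mask_BN.shape)         # torch.Size([2, 196]): per-patch mask (0 = kept, 1 = masked in the reference MAE)
print(ids_restore_BN.shape)  # torch.Size([2, 196]): indices that undo the shuffle
```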
Classes
class Attention (*, d: int, n_heads: int)
-
Multi-head self-attention. The model dimension d is split across n_heads heads, scaled dot-product attention is applied per head, and the heads are merged again and passed through an output projection. d must be divisible by n_heads.
Expand source code
    @jaxtyped(typechecker=beartype.beartype)
    class Attention(torch.nn.Module):
        def __init__(self, *, d: int, n_heads: int):
            super().__init__()
            assert d % n_heads == 0, f"n_heads={n_heads} must evenly divide d={d}"
            self.n_heads = n_heads
            self.query = torch.nn.Linear(d, d)
            self.key = torch.nn.Linear(d, d)
            self.value = torch.nn.Linear(d, d)
            self.output = torch.nn.Linear(d, d)

        @jaxtyped(typechecker=beartype.beartype)
        def split(
            self, x_BND: Float[Tensor, "batch n d"]
        ) -> Float[Tensor, "batch n_heads n d_head"]:
            return einops.rearrange(
                x_BND,
                "batch n (n_heads d_head) -> batch n_heads n d_head",
                n_heads=self.n_heads,
            )

        @jaxtyped(typechecker=beartype.beartype)
        def forward(self, x_BND: Float[Tensor, "batch n d"]) -> Float[Tensor, "batch n d"]:
            q_BHNd = self.split(self.query(x_BND))
            k_BHNd = self.split(self.key(x_BND))
            v_BHNd = self.split(self.value(x_BND))

            x_BHNd = torch.nn.functional.scaled_dot_product_attention(
                q_BHNd, k_BHNd, v_BHNd, dropout_p=0.0, is_causal=False, scale=None
            )
            x_BND = einops.rearrange(
                x_BHNd, "batch n_heads n d_head -> batch n (n_heads d_head)"
            ).contiguous()
            x_BND = self.output(x_BND)
            return x_BND
Ancestors
- torch.nn.modules.module.Module
Methods
def forward(self, x_BND: jaxtyping.Float[Tensor, 'batch n d']) ‑> jaxtyping.Float[Tensor, 'batch n d']
def split(self, x_BND: jaxtyping.Float[Tensor, 'batch n d']) ‑> jaxtyping.Float[Tensor, 'batch n_heads n d_head']
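A quick shape check (sizes are arbitrary; d must be divisible by n_heads):

```python
import torch

from contrib.mae.modeling import Attention

attn = Attention(d=768, n_heads=12)
x_BND = torch.randn(2, 50, 768)
print(attn(x_BND).shape)  # torch.Size([2, 50, 768]); attention does not change the sequence shape
```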
class Decoder (*,
d_in: int,
d: int,
d_hidden: int,
n_layers: int,
n_heads: int,
patch_size_px: tuple[int, int],
image_size_px: tuple[int, int],
ln_eps: float)
-
The MAE decoder. It projects encoder latents from d_in to the decoder width d, re-inserts a learned mask token at every masked position (using the restore indices), adds fixed positional embeddings, runs n_layers transformer blocks, and predicts the raw pixel values of every patch (patch_w * patch_h * 3 outputs per patch).
Expand source code
    @jaxtyped(typechecker=beartype.beartype)
    class Decoder(torch.nn.Module):
        def __init__(
            self,
            *,
            d_in: int,
            d: int,
            d_hidden: int,
            n_layers: int,
            n_heads: int,
            patch_size_px: tuple[int, int],
            image_size_px: tuple[int, int],
            ln_eps: float,
        ):
            super().__init__()
            image_w_px, image_h_px = image_size_px
            patch_w_px, patch_h_px = patch_size_px
            n_patches = (image_w_px // patch_w_px) * (image_h_px // patch_h_px)

            self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, d))
            self.embd = torch.nn.Linear(d_in, d)
            self.pos_embd = torch.nn.Parameter(
                torch.zeros(1, n_patches + 1, d), requires_grad=False
            )
            self.layers = torch.nn.ModuleList([
                TransformerBlock(d=d, d_hidden=d_hidden, n_heads=n_heads, ln_eps=ln_eps)
                for _ in range(n_layers)
            ])
            self.layernorm = torch.nn.LayerNorm(d, eps=ln_eps)
            self.head = torch.nn.Linear(d, patch_w_px * patch_h_px * 3)

        def forward(
            self,
            x_BMD: Float[Tensor, "batch m d_in"],
            ids_restore_BN: Int[Tensor, "batch n"],
        ) -> Float[Tensor, "batch n patch_pixels"]:
            batch_size, m, _ = x_BMD.shape
            _, n = ids_restore_BN.shape

            # Linear projection from encoder dimension to decoder dimension
            # CHECKED AGAINST REF---WORKS
            x_BMD = self.embd(x_BMD)
            _, _, d_decoder = x_BMD.shape

            # Add the mask tokens back
            n_mask_tokens = n + 1 - m
            masks_BOD = self.mask_token.repeat(batch_size, n_mask_tokens, 1)
            x_BND = torch.cat([x_BMD[:, 1:, :], masks_BOD], dim=1)  # no cls token
            index_BND = ids_restore_BN[..., None].repeat(1, 1, d_decoder).to(x_BND.device)
            x_BND = torch.gather(x_BND, dim=1, index=index_BND)
            x_BND = torch.cat([x_BMD[:, :1, :], x_BND], dim=1)  # append cls token

            # Add positional embeddings again.
            x_BND = x_BND + self.pos_embd

            for layer in self.layers:
                x_BND = layer(x_BND)
            x_BND = self.layernorm(x_BND)
            logits_BNP = self.head(x_BND)

            # Remove cls token
            logits_BNP = logits_BNP[:, 1:, :]
            return logits_BNP
Ancestors
- torch.nn.modules.module.Module
Methods
def forward(self,
x_BMD: jaxtyping.Float[Tensor, 'batch m d_in'],
ids_restore_BN: jaxtyping.Int[Tensor, 'batch n']) ‑> jaxtyping.Float[Tensor, 'batch n patch_pixels']
-
Projects the encoder latents to the decoder width, re-inserts mask tokens at the positions given by ids_restore_BN, adds positional embeddings, runs the decoder blocks, and returns per-patch pixel logits with the [CLS] token removed.
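A sketch with hypothetical ViT-B/16-style sizes (224x224 image, 16x16 patches, 75% masking, so 196 patches and 49 visible tokens plus a [CLS] token):

```python
import torch

from contrib.mae.modeling import Decoder

decoder = Decoder(
    d_in=768, d=512, d_hidden=2048, n_layers=8, n_heads=16,
    patch_size_px=(16, 16), image_size_px=(224, 224), ln_eps=1e-6,
)
x_BMD = torch.randn(2, 50, 768)  # [CLS] + 49 visible tokens from the encoder
ids_restore_BN = torch.stack([torch.randperm(196) for _ in range(2)])  # stand-in for random_masking output

logits_BNP = decoder(x_BMD, ids_restore_BN)
print(logits_BNP.shape)  # torch.Size([2, 196, 768]); 768 = 16 * 16 * 3 pixel values per patch
```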
class Embeddings (*,
d: int,
image_size_px: tuple[int, int],
patch_size_px: tuple[int, int],
mask_ratio: float)
-
Patchifies an image, adds fixed positional embeddings, randomly masks a mask_ratio fraction of the patch tokens, and prepends a learnable [CLS] token. Returns the kept tokens together with the mask and the indices needed to restore the original patch order.
Expand source code
    @jaxtyped(typechecker=beartype.beartype)
    class Embeddings(torch.nn.Module):
        class Output(typing.TypedDict):
            x_BMD: Float[Tensor, "batch m d"]
            mask_BN: Float[Tensor, "batch n"]
            ids_restore_BN: Int[Tensor, "batch n"]

        def __init__(
            self,
            *,
            d: int,
            image_size_px: tuple[int, int],
            patch_size_px: tuple[int, int],
            mask_ratio: float,
        ):
            super().__init__()
            image_w_px, image_h_px = image_size_px
            patch_w_px, patch_h_px = patch_size_px
            n_patches = (image_w_px // patch_w_px) * (image_h_px // patch_h_px)

            self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, d))
            self.position_embeddings = torch.nn.Parameter(
                torch.zeros(1, n_patches + 1, d), requires_grad=False
            )
            self.patch_embeddings = PatchEmbeddings(d=d, patch_size_px=patch_size_px)
            self.mask_ratio = mask_ratio

        @jaxtyped(typechecker=beartype.beartype)
        def forward(
            self,
            x_BCWH: Float[Tensor, "batch 3 height width"],
            noise_BN: Float[Tensor, "batch n"] | None = None,
        ) -> Output:
            batch_size, _, _, _ = x_BCWH.shape

            x_BND = self.patch_embeddings(x_BCWH) + self.position_embeddings[:, 1:, :]
            x_BMD, mask_BN, ids_restore_BN = random_masking(
                x_BND, self.mask_ratio, noise_BN=noise_BN
            )

            cls_x_11D = self.cls_token + self.position_embeddings[:, :1, :]
            cls_x_B1D = cls_x_11D.expand(batch_size, -1, -1)

            return self.Output(
                x_BMD=torch.cat((cls_x_B1D, x_BMD), dim=1),
                mask_BN=mask_BN,
                ids_restore_BN=ids_restore_BN,
            )
Ancestors
- torch.nn.modules.module.Module
Class variables
var Output
-
Return type of Embeddings.forward(): a TypedDict with keys x_BMD (the kept tokens, [CLS] first), mask_BN (the per-patch mask over all n patches), and ids_restore_BN (indices that undo the masking shuffle).
Methods
def forward(self,
x_BCWH: jaxtyping.Float[Tensor, 'batch 3 height width'],
noise_BN: jaxtyping.Float[Tensor, 'batch n'] | None = None) ‑> Embeddings.Output
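A sketch with the same hypothetical 224x224 / 16x16 sizes; the kept-token count assumes m = int(n * (1 - mask_ratio)):

```python
import torch

from contrib.mae.modeling import Embeddings

embeddings = Embeddings(d=768, image_size_px=(224, 224), patch_size_px=(16, 16), mask_ratio=0.75)
out = embeddings(torch.randn(2, 3, 224, 224))

print(out["x_BMD"].shape)           # e.g. torch.Size([2, 50, 768]): [CLS] + 49 kept patch tokens
print(out["mask_BN"].shape)         # torch.Size([2, 196])
print(out["ids_restore_BN"].shape)  # torch.Size([2, 196])
```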
class Encoder (*, d: int, d_hidden: int, n_heads: int, n_layers: int, ln_eps: float)
-
A stack of n_layers pre-norm transformer blocks applied to an already embedded (and masked) token sequence.
Expand source code
    @jaxtyped(typechecker=beartype.beartype)
    class Encoder(torch.nn.Module):
        def __init__(
            self,
            *,
            d: int,
            d_hidden: int,
            n_heads: int,
            n_layers: int,
            ln_eps: float,
        ):
            super().__init__()
            self.layers = torch.nn.ModuleList([
                TransformerBlock(d=d, d_hidden=d_hidden, n_heads=n_heads, ln_eps=ln_eps)
                for _ in range(n_layers)
            ])

        @jaxtyped(typechecker=beartype.beartype)
        def forward(self, x_BMD: Float[Tensor, "batch m d"]) -> Float[Tensor, "batch m d"]:
            for layer in self.layers:
                x_BMD = layer(x_BMD)
            return x_BMD
Ancestors
- torch.nn.modules.module.Module
Methods
def forward(self, x_BMD: jaxtyping.Float[Tensor, 'batch m d']) ‑> jaxtyping.Float[Tensor, 'batch m d']
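The encoder leaves shapes untouched; a minimal check with arbitrary sizes:

```python
import torch

from contrib.mae.modeling import Encoder

encoder = Encoder(d=768, d_hidden=3072, n_heads=12, n_layers=12, ln_eps=1e-6)
print(encoder(torch.randn(2, 50, 768)).shape)  # torch.Size([2, 50, 768])
```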
class Feedforward (*, d: int, d_hidden: int)
-
Position-wise two-layer MLP with a GELU non-linearity between the d -> d_hidden and d_hidden -> d projections.
Expand source code
    @jaxtyped(typechecker=beartype.beartype)
    class Feedforward(torch.nn.Module):
        def __init__(self, *, d: int, d_hidden: int):
            super().__init__()
            self.linear1 = torch.nn.Linear(d, d_hidden)
            self.linear2 = torch.nn.Linear(d_hidden, d)

        @jaxtyped(typechecker=beartype.beartype)
        def forward(self, x_BND: Float[Tensor, "batch n d"]) -> Float[Tensor, "batch n d"]:
            x_BNF = self.linear1(x_BND)
            x_BNF = torch.nn.functional.gelu(x_BNF)
            x_BND = self.linear2(x_BNF)
            return x_BND
Ancestors
- torch.nn.modules.module.Module
Methods
def forward(self, x_BND: jaxtyping.Float[Tensor, 'batch n d']) ‑> jaxtyping.Float[Tensor, 'batch n d']
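Shape check (d_hidden = 4 * d is the usual transformer choice, not a requirement):

```python
import torch

from contrib.mae.modeling import Feedforward

ffnn = Feedforward(d=768, d_hidden=3072)
print(ffnn(torch.randn(2, 50, 768)).shape)  # torch.Size([2, 50, 768])
```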
class MaskedAutoencoder (*,
d_encoder: int,
d_hidden_encoder: int,
n_heads_encoder: int,
n_layers_encoder: int,
d_decoder: int,
d_hidden_decoder: int,
n_heads_decoder: int,
n_layers_decoder: int,
image_size_px: tuple[int, int],
patch_size_px: tuple[int, int],
mask_ratio: float,
ln_eps: float)
-
The full masked autoencoder: a VisionTransformer encoder that sees only the unmasked patches, followed by a lightweight Decoder that reconstructs pixel values for every patch. Its forward pass returns the encoder latents, the decoded per-patch pixel logits, and the shuffle-restore indices.
Expand source code
    @beartype.beartype
    class MaskedAutoencoder(torch.nn.Module):
        class Output(typing.TypedDict):
            latents: Float[Tensor, "batch n d"]
            decoded: Float[Tensor, "batch n d"]
            ids_restore: Int[Tensor, "batch n"]

        def __init__(
            self,
            *,
            d_encoder: int,
            d_hidden_encoder: int,
            n_heads_encoder: int,
            n_layers_encoder: int,
            d_decoder: int,
            d_hidden_decoder: int,
            n_heads_decoder: int,
            n_layers_decoder: int,
            image_size_px: tuple[int, int],
            patch_size_px: tuple[int, int],
            mask_ratio: float,
            ln_eps: float,
        ):
            super().__init__()
            self.vit = VisionTransformer(
                d=d_encoder,
                d_hidden=d_hidden_encoder,
                n_heads=n_heads_encoder,
                n_layers=n_layers_encoder,
                image_size_px=image_size_px,
                patch_size_px=patch_size_px,
                mask_ratio=mask_ratio,
                ln_eps=ln_eps,
            )
            self.decoder = Decoder(
                d_in=d_encoder,
                d=d_decoder,
                d_hidden=d_hidden_decoder,
                n_layers=n_layers_decoder,
                n_heads=n_heads_decoder,
                image_size_px=image_size_px,
                patch_size_px=patch_size_px,
                ln_eps=ln_eps,
            )

        def forward(
            self,
            x_B3WH: Float[Tensor, "batch 3 width height"],
            noise_BN: Float[Tensor, "batch n"] | None = None,
        ) -> Output:
            encoded = self.vit(x_B3WH, noise_BN=noise_BN)
            decoded_BND = self.decoder(encoded["x_BMD"], encoded["ids_restore_BN"])
            return self.Output(
                latents=encoded["x_BMD"],
                decoded=decoded_BND,
                ids_restore=encoded["ids_restore_BN"],
            )
Ancestors
- torch.nn.modules.module.Module
Class variables
var Output
-
Return type of MaskedAutoencoder.forward(): a TypedDict with keys latents (encoder output for the [CLS] and visible tokens), decoded (per-patch pixel logits from the decoder), and ids_restore (shuffle-restore indices).
Methods
def forward(self,
x_B3WH: jaxtyping.Float[Tensor, 'batch 3 width height'],
noise_BN: jaxtyping.Float[Tensor, 'batch n'] | None = None) ‑> MaskedAutoencoder.Output
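An end-to-end sketch with hypothetical MAE ViT-B/16-style hyperparameters (illustrative values, not the configuration of the released checkpoint):

```python
import torch

from contrib.mae.modeling import MaskedAutoencoder

mae = MaskedAutoencoder(
    d_encoder=768, d_hidden_encoder=3072, n_heads_encoder=12, n_layers_encoder=12,
    d_decoder=512, d_hidden_decoder=2048, n_heads_decoder=16, n_layers_decoder=8,
    image_size_px=(224, 224), patch_size_px=(16, 16), mask_ratio=0.75, ln_eps=1e-6,
)
out = mae(torch.randn(2, 3, 224, 224))

print(out["latents"].shape)      # e.g. torch.Size([2, 50, 768]): [CLS] + visible-token latents
print(out["decoded"].shape)      # torch.Size([2, 196, 768]): pixel logits for every patch
print(out["ids_restore"].shape)  # torch.Size([2, 196])
```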
class PatchEmbeddings (d: int, patch_size_px: tuple[int, int])
-
Projects an RGB image into a sequence of patch embeddings with a Conv2d whose kernel size and stride both equal the patch size, i.e. non-overlapping patches.
Expand source code
    @jaxtyped(typechecker=beartype.beartype)
    class PatchEmbeddings(torch.nn.Module):
        def __init__(self, d: int, patch_size_px: tuple[int, int]):
            super().__init__()
            self.projection = torch.nn.Conv2d(
                3, d, kernel_size=patch_size_px, stride=patch_size_px
            )

        @jaxtyped(typechecker=beartype.beartype)
        def forward(
            self, x_BCWH: Float[Tensor, "batch 3 width height"]
        ) -> Float[Tensor, "batch n_patches d"]:
            return einops.rearrange(self.projection(x_BCWH), "batch d w h -> batch (w h) d")
Ancestors
- torch.nn.modules.module.Module
Methods
def forward(self, x_BCWH: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> jaxtyping.Float[Tensor, 'batch n_patches d']
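Because the Conv2d kernel and stride equal the patch size, each output position corresponds to one non-overlapping patch:

```python
import torch

from contrib.mae.modeling import PatchEmbeddings

patch_embeddings = PatchEmbeddings(d=768, patch_size_px=(16, 16))
print(patch_embeddings(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 196, 768]); 196 = (224 // 16) ** 2
```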
class TransformerBlock (*, d: int, d_hidden: int, n_heads: int, ln_eps: float)
-
Pre-norm transformer block: LayerNorm, multi-head attention, residual add, then LayerNorm, feedforward, residual add.
Expand source code
    @jaxtyped(typechecker=beartype.beartype)
    class TransformerBlock(torch.nn.Module):
        def __init__(self, *, d: int, d_hidden: int, n_heads: int, ln_eps: float):
            super().__init__()
            self.attention = Attention(d=d, n_heads=n_heads)
            self.ffnn = Feedforward(d=d, d_hidden=d_hidden)
            self.layernorm1 = torch.nn.LayerNorm(d, eps=ln_eps)
            self.layernorm2 = torch.nn.LayerNorm(d, eps=ln_eps)

        @jaxtyped(typechecker=beartype.beartype)
        def forward(self, x: Float[Tensor, "batch n d"]) -> Float[Tensor, "batch n d"]:
            x_ = self.attention(self.layernorm1(x))
            x = x_ + x
            x_ = self.ffnn(self.layernorm2(x))
            return x_ + x
Ancestors
- torch.nn.modules.module.Module
Methods
def forward(self, x: jaxtyping.Float[Tensor, 'batch n d']) ‑> jaxtyping.Float[Tensor, 'batch n d']
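Shape check of the pre-norm residual block (arbitrary sizes):

```python
import torch

from contrib.mae.modeling import TransformerBlock

block = TransformerBlock(d=768, d_hidden=3072, n_heads=12, ln_eps=1e-6)
print(block(torch.randn(2, 50, 768)).shape)  # torch.Size([2, 50, 768])
```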
class VisionTransformer (*,
d: int,
d_hidden: int,
n_heads: int,
n_layers: int,
image_size_px: tuple[int, int],
patch_size_px: tuple[int, int],
mask_ratio: float,
ln_eps: float)
-
The MAE encoder: Embeddings (patchify, positional embeddings, random masking, [CLS] token) followed by the transformer Encoder and a final LayerNorm. Returns the latents of the visible tokens plus the restore indices.
Expand source code
    @jaxtyped(typechecker=beartype.beartype)
    class VisionTransformer(torch.nn.Module):
        class Output(typing.TypedDict):
            x_BMD: Float[Tensor, "batch m d"]
            ids_restore_BN: Int[Tensor, "batch n"]

        def __init__(
            self,
            *,
            d: int,
            d_hidden: int,
            n_heads: int,
            n_layers: int,
            image_size_px: tuple[int, int],
            patch_size_px: tuple[int, int],
            mask_ratio: float,
            ln_eps: float,
        ):
            super().__init__()
            self.embeddings = Embeddings(
                d=d,
                image_size_px=image_size_px,
                patch_size_px=patch_size_px,
                mask_ratio=mask_ratio,
            )
            self.encoder = Encoder(
                d=d,
                d_hidden=d_hidden,
                n_heads=n_heads,
                n_layers=n_layers,
                ln_eps=ln_eps,
            )
            self.layernorm = torch.nn.LayerNorm(d, eps=ln_eps)

        def forward(
            self,
            x_B3WH: Float[Tensor, "batch 3 width height"],
            noise_BN: Float[Tensor, "batch n"] | None = None,
        ) -> Float[Tensor, "batch ..."]:
            embedded = self.embeddings(x_B3WH, noise_BN=noise_BN)
            x_BMD = self.encoder(embedded["x_BMD"])
            x_BMD = self.layernorm(x_BMD)
            return self.Output(x_BMD=x_BMD, ids_restore_BN=embedded["ids_restore_BN"])
Ancestors
- torch.nn.modules.module.Module
Class variables
var Output
-
Return type of VisionTransformer.forward(): a TypedDict with keys x_BMD (encoder latents for the [CLS] and visible tokens) and ids_restore_BN (shuffle-restore indices).
Methods
def forward(self,
x_B3WH: jaxtyping.Float[Tensor, 'batch 3 width height'],
noise_BN: jaxtyping.Float[Tensor, 'batch n'] | None = None) ‑> jaxtyping.Float[Tensor, 'batch ...']
-
Patchifies and masks the image, encodes the visible tokens, applies the final LayerNorm, and returns an Output dict with the latents (x_BMD) and the shuffle-restore indices (ids_restore_BN).
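Encoder-only usage sketch (same hypothetical sizes as above); note that despite the tensor-typed annotation, the method returns the Output dict:

```python
import torch

from contrib.mae.modeling import VisionTransformer

vit = VisionTransformer(
    d=768, d_hidden=3072, n_heads=12, n_layers=12,
    image_size_px=(224, 224), patch_size_px=(16, 16), mask_ratio=0.75, ln_eps=1e-6,
)
out = vit(torch.randn(2, 3, 224, 224))

print(out["x_BMD"].shape)           # e.g. torch.Size([2, 50, 768])
print(out["ids_restore_BN"].shape)  # torch.Size([2, 196])
```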