Module contrib.mae.modeling

Functions

def load_ckpt(ckpt: str, *, chunk_size_kb: int = 1024) ‑> MaskedAutoencoder

Loads a pre-trained MAE ViT checkpoint. If the checkpoint is not already on disk, it is downloaded from Hugging Face and then loaded into a MaskedAutoencoder instance.
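
A minimal usage sketch. The checkpoint identifier below is hypothetical and stands in for whatever string load_ckpt accepts (a local path or a Hugging Face checkpoint name), and the 224x224 input assumes a standard ViT-style configuration:

import torch
from contrib.mae import modeling

# Hypothetical checkpoint identifier; substitute a real path or checkpoint name.
model = modeling.load_ckpt("facebook/vit-mae-base")
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))  # a MaskedAutoencoder.Output dict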

def random_masking(x_BND: jaxtyping.Float[Tensor, 'batch n d'],
mask_ratio: float,
noise_BN: jaxtyping.Float[Tensor, 'batch n'] | None = None) ‑> tuple[jaxtyping.Float[Tensor, 'batch m d'], jaxtyping.Float[Tensor, 'batch n'], jaxtyping.Int[Tensor, 'batch n']]

Performs per-sample random masking via per-sample shuffling, where the shuffle order is obtained by argsorting random noise (supplied via noise_BN or, if omitted, drawn at random). Returns the kept patch embeddings, a per-patch mask, and the indices needed to restore the original patch order.
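
The following is a minimal sketch of the standard MAE masking recipe this describes (argsort noise, keep the lowest-noise patches). It is illustrative only; the module's actual implementation may differ in details such as rounding or mask dtype:

import torch

def random_masking_sketch(x_BND, mask_ratio, noise_BN=None):
    batch, n, d = x_BND.shape
    m = int(n * (1 - mask_ratio))  # number of patches to keep

    if noise_BN is None:
        noise_BN = torch.rand(batch, n, device=x_BND.device)

    # Argsort the noise: the first m indices per sample are the patches we keep.
    ids_shuffle_BN = torch.argsort(noise_BN, dim=1)
    ids_restore_BN = torch.argsort(ids_shuffle_BN, dim=1)

    ids_keep_BM = ids_shuffle_BN[:, :m]
    x_BMD = torch.gather(x_BND, dim=1, index=ids_keep_BM[..., None].repeat(1, 1, d))

    # Binary mask in the original patch order: 0 = kept, 1 = masked.
    mask_BN = torch.ones(batch, n, device=x_BND.device)
    mask_BN[:, :m] = 0
    mask_BN = torch.gather(mask_BN, dim=1, index=ids_restore_BN)

    return x_BMD, mask_BN, ids_restore_BN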

Classes

class Attention (*, d: int, n_heads: int)

Multi-head self-attention. Separate linear layers produce queries, keys, and values; split reshapes them into n_heads heads, torch.nn.functional.scaled_dot_product_attention combines them, and the heads are re-merged and passed through a final output projection. d must be divisible by n_heads.

Source code
@jaxtyped(typechecker=beartype.beartype)
class Attention(torch.nn.Module):
    def __init__(self, *, d: int, n_heads: int):
        super().__init__()
        assert d % n_heads == 0, f"n_heads={n_heads} must evenly divide d={d}"

        self.n_heads = n_heads

        self.query = torch.nn.Linear(d, d)
        self.key = torch.nn.Linear(d, d)
        self.value = torch.nn.Linear(d, d)
        self.output = torch.nn.Linear(d, d)

    @jaxtyped(typechecker=beartype.beartype)
    def split(
        self, x_BND: Float[Tensor, "batch n d"]
    ) -> Float[Tensor, "batch n_heads n d_head"]:
        return einops.rearrange(
            x_BND,
            "batch n (n_heads d_head) -> batch n_heads n d_head",
            n_heads=self.n_heads,
        )

    @jaxtyped(typechecker=beartype.beartype)
    def forward(self, x_BND: Float[Tensor, "batch n d"]) -> Float[Tensor, "batch n d"]:
        q_BHNd = self.split(self.query(x_BND))
        k_BHNd = self.split(self.key(x_BND))
        v_BHNd = self.split(self.value(x_BND))

        x_BHNd = torch.nn.functional.scaled_dot_product_attention(
            q_BHNd, k_BHNd, v_BHNd, dropout_p=0.0, is_causal=False, scale=None
        )

        x_BND = einops.rearrange(
            x_BHNd, "batch n_heads n d_head -> batch n (n_heads d_head)"
        ).contiguous()

        x_BND = self.output(x_BND)
        return x_BND

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x_BND: jaxtyping.Float[Tensor, 'batch n d']) ‑> jaxtyping.Float[Tensor, 'batch n d']
def split(self, x_BND: jaxtyping.Float[Tensor, 'batch n d']) ‑> jaxtyping.Float[Tensor, 'batch n_heads n d_head']
class Decoder (*,
d_in: int,
d: int,
d_hidden: int,
n_layers: int,
n_heads: int,
patch_size_px: tuple[int, int],
image_size_px: tuple[int, int],
ln_eps: float)

MAE decoder. Projects the encoder outputs from d_in to the decoder width d, re-inserts a learnable mask token at every masked position using the restore indices, adds fixed (non-trainable) positional embeddings, runs n_layers transformer blocks, and predicts the patch_w * patch_h * 3 pixel values of each patch with a linear head. The class token is dropped from the returned predictions.

Source code
@jaxtyped(typechecker=beartype.beartype)
class Decoder(torch.nn.Module):
    def __init__(
        self,
        *,
        d_in: int,
        d: int,
        d_hidden: int,
        n_layers: int,
        n_heads: int,
        patch_size_px: tuple[int, int],
        image_size_px: tuple[int, int],
        ln_eps: float,
    ):
        super().__init__()

        image_w_px, image_h_px = image_size_px
        patch_w_px, patch_h_px = patch_size_px
        n_patches = (image_w_px // patch_w_px) * (image_h_px // patch_h_px)

        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, d))

        self.embd = torch.nn.Linear(d_in, d)

        self.pos_embd = torch.nn.Parameter(
            torch.zeros(1, n_patches + 1, d), requires_grad=False
        )

        self.layers = torch.nn.ModuleList([
            TransformerBlock(d=d, d_hidden=d_hidden, n_heads=n_heads, ln_eps=ln_eps)
            for _ in range(n_layers)
        ])

        self.layernorm = torch.nn.LayerNorm(d, eps=ln_eps)
        self.head = torch.nn.Linear(d, patch_w_px * patch_h_px * 3)

    def forward(
        self,
        x_BMD: Float[Tensor, "batch m d_in"],
        ids_restore_BN: Int[Tensor, "batch n"],
    ) -> Float[Tensor, "batch n patch_pixels"]:
        batch_size, m, _ = x_BMD.shape
        _, n = ids_restore_BN.shape

        # Project from the encoder dimension to the decoder dimension.
        x_BMD = self.embd(x_BMD)

        _, _, d_decoder = x_BMD.shape

        # Add the mask tokens back
        n_mask_tokens = n + 1 - m
        masks_BOD = self.mask_token.repeat(batch_size, n_mask_tokens, 1)
        x_BND = torch.cat([x_BMD[:, 1:, :], masks_BOD], dim=1)  # no cls token

        index_BND = ids_restore_BN[..., None].repeat(1, 1, d_decoder).to(x_BND.device)
        x_BND = torch.gather(x_BND, dim=1, index=index_BND)

        x_BND = torch.cat([x_BMD[:, :1, :], x_BND], dim=1)  # append cls token

        # Add positional embeddings again.
        x_BND = x_BND + self.pos_embd

        for layer in self.layers:
            x_BND = layer(x_BND)

        x_BND = self.layernorm(x_BND)
        logits_BNP = self.head(x_BND)

        # Remove cls token
        logits_BNP = logits_BNP[:, 1:, :]

        return logits_BNP

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self,
x_BMD: jaxtyping.Float[Tensor, 'batch m d_in'],
ids_restore_BN: jaxtyping.Int[Tensor, 'batch n']) ‑> jaxtyping.Float[Tensor, 'batch n patch_pixels']

Reconstructs per-patch pixel predictions from the visible-token latents x_BMD and the restore indices ids_restore_BN. The class token is dropped before returning.
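
The torch.gather with ids_restore_BN is what undoes the shuffling performed by random_masking. A toy illustration of the mechanism with hypothetical tensors (one batch element, four patches, d = 1):

import torch

# Suppose random_masking shuffled the patches as [2, 0, 3, 1] and kept the first two.
ids_shuffle = torch.tensor([[2, 0, 3, 1]])
ids_restore = torch.argsort(ids_shuffle, dim=1)   # [[1, 3, 0, 2]]

visible = torch.tensor([[[3.0], [1.0]]])          # decoder inputs for patches 2 and 0
mask_tok = torch.zeros(1, 2, 1)                   # mask tokens for patches 3 and 1
x = torch.cat([visible, mask_tok], dim=1)         # still in shuffled order

restored = torch.gather(x, dim=1, index=ids_restore[..., None])
# restored.squeeze() == tensor([1., 0., 3., 0.]): the visible patches are back at
# positions 0 and 2, and mask tokens sit at the masked positions 1 and 3.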

class Embeddings (*,
d: int,
image_size_px: tuple[int, int],
patch_size_px: tuple[int, int],
mask_ratio: float)

Patch-and-mask embedding stage of the encoder. Computes patch embeddings, adds fixed positional embeddings, applies random_masking with the configured mask_ratio, and prepends a learnable class token (with its own positional embedding). Returns the visible-token embeddings together with the per-patch mask and restore indices.

Source code
@jaxtyped(typechecker=beartype.beartype)
class Embeddings(torch.nn.Module):
    class Output(typing.TypedDict):
        x_BMD: Float[Tensor, "batch m d"]
        mask_BN: Float[Tensor, "batch n"]
        ids_restore_BN: Int[Tensor, "batch n"]

    def __init__(
        self,
        *,
        d: int,
        image_size_px: tuple[int, int],
        patch_size_px: tuple[int, int],
        mask_ratio: float,
    ):
        super().__init__()

        image_w_px, image_h_px = image_size_px
        patch_w_px, patch_h_px = patch_size_px
        n_patches = (image_w_px // patch_w_px) * (image_h_px // patch_h_px)

        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, d))
        self.position_embeddings = torch.nn.Parameter(
            torch.zeros(1, n_patches + 1, d), requires_grad=False
        )
        self.patch_embeddings = PatchEmbeddings(d=d, patch_size_px=patch_size_px)

        self.mask_ratio = mask_ratio

    @jaxtyped(typechecker=beartype.beartype)
    def forward(
        self,
        x_BCWH: Float[Tensor, "batch 3 height width"],
        noise_BN: Float[Tensor, "batch n"] | None = None,
    ) -> Output:
        batch_size, _, _, _ = x_BCWH.shape

        x_BND = self.patch_embeddings(x_BCWH) + self.position_embeddings[:, 1:, :]

        x_BMD, mask_BN, ids_restore_BN = random_masking(
            x_BND, self.mask_ratio, noise_BN=noise_BN
        )

        cls_x_11D = self.cls_token + self.position_embeddings[:, :1, :]

        cls_x_B1D = cls_x_11D.expand(batch_size, -1, -1)
        return self.Output(
            x_BMD=torch.cat((cls_x_B1D, x_BMD), dim=1),
            mask_BN=mask_BN,
            ids_restore_BN=ids_restore_BN,
        )

Ancestors

  • torch.nn.modules.module.Module

Class variables

var Output

TypedDict returned by forward: x_BMD, the visible-token embeddings (class token plus kept patches); mask_BN, the per-patch mask; and ids_restore_BN, the indices that restore the original patch order.

Methods

def forward(self,
x_BCWH: jaxtyping.Float[Tensor, 'batch 3 height width'],
noise_BN: jaxtyping.Float[Tensor, 'batch n'] | None = None) ‑> Embeddings.Output
class Encoder (*, d: int, d_hidden: int, n_heads: int, n_layers: int, ln_eps: float)

A stack of n_layers transformer blocks applied to the (visible) token sequence.

Source code
@jaxtyped(typechecker=beartype.beartype)
class Encoder(torch.nn.Module):
    def __init__(
        self,
        *,
        d: int,
        d_hidden: int,
        n_heads: int,
        n_layers: int,
        ln_eps: float,
    ):
        super().__init__()
        self.layers = torch.nn.ModuleList([
            TransformerBlock(d=d, d_hidden=d_hidden, n_heads=n_heads, ln_eps=ln_eps)
            for _ in range(n_layers)
        ])

    @jaxtyped(typechecker=beartype.beartype)
    def forward(self, x_BMD: Float[Tensor, "batch m d"]) -> Float[Tensor, "batch m d"]:
        for layer in self.layers:
            x_BMD = layer(x_BMD)
        return x_BMD

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x_BMD: jaxtyping.Float[Tensor, 'batch m d']) ‑> jaxtyping.Float[Tensor, 'batch m d']
class Feedforward (*, d: int, d_hidden: int)

Position-wise feed-forward network: a linear layer up to d_hidden, a GELU nonlinearity, and a linear layer back down to d.

Source code
@jaxtyped(typechecker=beartype.beartype)
class Feedforward(torch.nn.Module):
    def __init__(self, *, d: int, d_hidden: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(d, d_hidden)
        self.linear2 = torch.nn.Linear(d_hidden, d)

    @jaxtyped(typechecker=beartype.beartype)
    def forward(self, x_BND: Float[Tensor, "batch n d"]) -> Float[Tensor, "batch n d"]:
        x_BNF = self.linear1(x_BND)
        x_BNF = torch.nn.functional.gelu(x_BNF)
        x_BND = self.linear2(x_BNF)
        return x_BND

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x_BND: jaxtyping.Float[Tensor, 'batch n d']) ‑> jaxtyping.Float[Tensor, 'batch n d']
class MaskedAutoencoder (*,
d_encoder: int,
d_hidden_encoder: int,
n_heads_encoder: int,
n_layers_encoder: int,
d_decoder: int,
d_hidden_decoder: int,
n_heads_decoder: int,
n_layers_decoder: int,
image_size_px: tuple[int, int],
patch_size_px: tuple[int, int],
mask_ratio: float,
ln_eps: float)

The full masked autoencoder: a VisionTransformer encoder over the visible patches followed by a Decoder that reconstructs pixel values for every patch. forward returns the encoder latents, the decoded patch predictions, and the restore indices.

Source code
@beartype.beartype
class MaskedAutoencoder(torch.nn.Module):
    class Output(typing.TypedDict):
        latents: Float[Tensor, "batch m d"]
        decoded: Float[Tensor, "batch n patch_pixels"]
        ids_restore: Int[Tensor, "batch n"]

    def __init__(
        self,
        *,
        d_encoder: int,
        d_hidden_encoder: int,
        n_heads_encoder: int,
        n_layers_encoder: int,
        d_decoder: int,
        d_hidden_decoder: int,
        n_heads_decoder: int,
        n_layers_decoder: int,
        image_size_px: tuple[int, int],
        patch_size_px: tuple[int, int],
        mask_ratio: float,
        ln_eps: float,
    ):
        super().__init__()
        self.vit = VisionTransformer(
            d=d_encoder,
            d_hidden=d_hidden_encoder,
            n_heads=n_heads_encoder,
            n_layers=n_layers_encoder,
            image_size_px=image_size_px,
            patch_size_px=patch_size_px,
            mask_ratio=mask_ratio,
            ln_eps=ln_eps,
        )
        self.decoder = Decoder(
            d_in=d_encoder,
            d=d_decoder,
            d_hidden=d_hidden_decoder,
            n_layers=n_layers_decoder,
            n_heads=n_heads_decoder,
            image_size_px=image_size_px,
            patch_size_px=patch_size_px,
            ln_eps=ln_eps,
        )

    def forward(
        self,
        x_B3WH: Float[Tensor, "batch 3 width height"],
        noise_BN: Float[Tensor, "batch n"] | None = None,
    ) -> Output:
        encoded = self.vit(x_B3WH, noise_BN=noise_BN)
        decoded_BND = self.decoder(encoded["x_BMD"], encoded["ids_restore_BN"])

        return self.Output(
            latents=encoded["x_BMD"],
            decoded=decoded_BND,
            ids_restore=encoded["ids_restore_BN"],
        )

Ancestors

  • torch.nn.modules.module.Module

Class variables

var Output

TypedDict returned by forward: latents, the encoder outputs for the visible tokens; decoded, the decoder's per-patch pixel predictions; and ids_restore, the indices that restore the original patch order.

Methods

def forward(self,
x_B3WH: jaxtyping.Float[Tensor, 'batch 3 width height'],
noise_BN: jaxtyping.Float[Tensor, 'batch n'] | None = None) ‑> MaskedAutoencoder.Output
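
An end-to-end sketch. The hyperparameters below are hypothetical (a ViT-Base-style encoder with a lighter decoder, chosen only so the shapes work out), not the configuration of any particular checkpoint, and the shape comments assume the standard MAE keep count of int(n * (1 - mask_ratio)):

import torch

# Assumes MaskedAutoencoder has been imported from contrib.mae.modeling.
mae = MaskedAutoencoder(
    d_encoder=768, d_hidden_encoder=3072, n_heads_encoder=12, n_layers_encoder=12,
    d_decoder=512, d_hidden_decoder=2048, n_heads_decoder=16, n_layers_decoder=8,
    image_size_px=(224, 224), patch_size_px=(16, 16),
    mask_ratio=0.75, ln_eps=1e-6,
)

x = torch.randn(2, 3, 224, 224)
out = mae(x)
# 224 / 16 = 14, so 196 patches; with mask_ratio=0.75, 49 stay visible (+1 cls token).
out["latents"].shape      # torch.Size([2, 50, 768])
out["decoded"].shape      # torch.Size([2, 196, 768]); 768 = 16 * 16 * 3 pixels per patch
out["ids_restore"].shape  # torch.Size([2, 196])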
class PatchEmbeddings (d: int, patch_size_px: tuple[int, int])

Turns an image into a sequence of patch embeddings using a Conv2d whose kernel size and stride both equal the patch size, then flattens the spatial grid into a single patch dimension.

Source code
@jaxtyped(typechecker=beartype.beartype)
class PatchEmbeddings(torch.nn.Module):
    def __init__(self, d: int, patch_size_px: tuple[int, int]):
        super().__init__()
        self.projection = torch.nn.Conv2d(
            3, d, kernel_size=patch_size_px, stride=patch_size_px
        )

    @jaxtyped(typechecker=beartype.beartype)
    def forward(
        self, x_BCWH: Float[Tensor, "batch 3 width height"]
    ) -> Float[Tensor, "batch n_patches d"]:
        return einops.rearrange(self.projection(x_BCWH), "batch d w h -> batch (w h) d")

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x_BCWH: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> jaxtyping.Float[Tensor, 'batch n_patches d']
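
For example (hypothetical sizes), a 224x224 input with 16x16 patches produces (224 // 16) * (224 // 16) = 196 patch embeddings:

import torch

# Assumes PatchEmbeddings has been imported from contrib.mae.modeling.
patchify = PatchEmbeddings(d=768, patch_size_px=(16, 16))
x = torch.randn(2, 3, 224, 224)
patchify(x).shape  # torch.Size([2, 196, 768])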
class TransformerBlock (*, d: int, d_hidden: int, n_heads: int, ln_eps: float)

Pre-norm transformer block: layernorm followed by multi-head self-attention with a residual connection, then layernorm followed by the feed-forward network with a second residual connection.

Source code
@jaxtyped(typechecker=beartype.beartype)
class TransformerBlock(torch.nn.Module):
    def __init__(self, *, d: int, d_hidden: int, n_heads: int, ln_eps: float):
        super().__init__()
        self.attention = Attention(d=d, n_heads=n_heads)
        self.ffnn = Feedforward(d=d, d_hidden=d_hidden)
        self.layernorm1 = torch.nn.LayerNorm(d, eps=ln_eps)
        self.layernorm2 = torch.nn.LayerNorm(d, eps=ln_eps)

    @jaxtyped(typechecker=beartype.beartype)
    def forward(self, x: Float[Tensor, "batch n d"]) -> Float[Tensor, "batch n d"]:
        x_ = self.attention(self.layernorm1(x))

        x = x_ + x

        x_ = self.ffnn(self.layernorm2(x))
        return x_ + x

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x: jaxtyping.Float[Tensor, 'batch n d']) ‑> jaxtyping.Float[Tensor, 'batch n d']
class VisionTransformer (*,
d: int,
d_hidden: int,
n_heads: int,
n_layers: int,
image_size_px: tuple[int, int],
patch_size_px: tuple[int, int],
mask_ratio: float,
ln_eps: float)

MAE encoder. Applies Embeddings (patchify, positional embeddings, random masking, class token), runs the Encoder stack over the visible tokens, and applies a final layernorm. Returns the encoded visible tokens and the restore indices.

Source code
@jaxtyped(typechecker=beartype.beartype)
class VisionTransformer(torch.nn.Module):
    class Output(typing.TypedDict):
        x_BMD: Float[Tensor, "batch m d"]
        ids_restore_BN: Int[Tensor, "batch n"]

    def __init__(
        self,
        *,
        d: int,
        d_hidden: int,
        n_heads: int,
        n_layers: int,
        image_size_px: tuple[int, int],
        patch_size_px: tuple[int, int],
        mask_ratio: float,
        ln_eps: float,
    ):
        super().__init__()
        self.embeddings = Embeddings(
            d=d,
            image_size_px=image_size_px,
            patch_size_px=patch_size_px,
            mask_ratio=mask_ratio,
        )
        self.encoder = Encoder(
            d=d,
            d_hidden=d_hidden,
            n_heads=n_heads,
            n_layers=n_layers,
            ln_eps=ln_eps,
        )
        self.layernorm = torch.nn.LayerNorm(d, eps=ln_eps)

    def forward(
        self,
        x_B3WH: Float[Tensor, "batch 3 width height"],
        noise_BN: Float[Tensor, "batch n"] | None = None,
    ) -> Output:
        embedded = self.embeddings(x_B3WH, noise_BN=noise_BN)
        x_BMD = self.encoder(embedded["x_BMD"])
        x_BMD = self.layernorm(x_BMD)

        return self.Output(x_BMD=x_BMD, ids_restore_BN=embedded["ids_restore_BN"])

Ancestors

  • torch.nn.modules.module.Module

Class variables

var Output

TypedDict returned by forward: x_BMD, the encoded visible tokens (including the class token), and ids_restore_BN, the indices that restore the original patch order.

Methods

def forward(self,
x_B3WH: jaxtyping.Float[Tensor, 'batch 3 width height'],
noise_BN: jaxtyping.Float[Tensor, 'batch n'] | None = None) ‑> VisionTransformer.Output

Embeds and randomly masks the input image, encodes the visible tokens, applies the final layernorm, and returns the encoded tokens together with the restore indices.