Behold, My Stuff


How To Convert a PyTorch Model to Jax

I wanted a Jax version of DINOv3, which is written in PyTorch. Here’s how I did it.

  1. Download the PyTorch code and weights.
  2. Write the Jax code.
  3. Convert the PyTorch weights.
  4. Test against the PyTorch reference implementation.

Download the PyTorch Code and Weights

This step depends on the model you’re using. For DINOv3, I followed the instructions in the repo. The goal is to load the model into memory and print its PyTorch module representation:

import torch

# CKPT_PATH is the local path to the downloaded DINOv3 ViT-S/16 checkpoint.
vit_pt = torch.hub.load(
    "facebookresearch/dinov3", "dinov3_vits16", source="github", weights=CKPT_PATH
)

Then I can print vit_pt and see:

DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (rope_embed): RopePositionEmbedding()
  (blocks): ModuleList(
    (0-11): 12 x SelfAttentionBlock(
      (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (attn): SelfAttention(
        (qkv): LinearKMaskedBias(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (ls2): LayerScale()
    )
  )
  (norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
  (head): Identity()
)

So now I can inspect the model and its weights.
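
One way to do that inspection, as a minimal sketch: a small helper (the name `list_params` is hypothetical) that walks any mapping from parameter name to array and prints each name with its shape. A PyTorch `state_dict()` qualifies, since tensors expose `.shape`.

```python
from typing import Any, Mapping


def list_params(state_dict: Mapping[str, Any]) -> list[tuple[str, tuple[int, ...]]]:
    """Print and return (name, shape) pairs, in the mapping's order."""
    rows = [(name, tuple(value.shape)) for name, value in state_dict.items()]
    for name, shape in rows:
        print(f"{name:<40} {shape}")
    return rows


# Usage with the model loaded above (hypothetical helper, real state_dict()):
# list_params(vit_pt.state_dict())
```

The dotted names follow the module tree printed above, so `blocks.0.attn.qkv.weight` should be the fused query/key/value projection of the first block. Because `nn.Linear` stores its weight as (out_features, in_features), that entry should have shape (1152, 384), which is exactly the kind of layout detail the Jax code has to reproduce.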

Write the Jax Code
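
A minimal functional sketch of one `SelfAttentionBlock` in plain `jax.numpy`, with parameters held in nested dicts whose keys mirror the PyTorch module names printed above (`norm1`, `attn.qkv`, `attn.proj`, `ls1`, `norm2`, `mlp.fc1`, `mlp.fc2`, `ls2`). Several things here are assumptions rather than details taken from the DINOv3 code: the helper names (`layer_norm`, `linear`, `attention`, `block`), the `kernel` key for transposed weights, `gamma` as the LayerScale parameter name, and the q/k/v chunk order of the fused `qkv` output. The rotary embedding (`rope_embed`) and whatever `LinearKMaskedBias` does beyond a plain linear layer are left out, so this will not match the reference numerically until those are added.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, p, eps=1e-5):
    # Same math as torch.nn.LayerNorm((D,), eps=1e-05, elementwise_affine=True):
    # normalize over the last axis with the biased variance, then scale and shift.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * p["weight"] + p["bias"]


def linear(x, p):
    # Assumes p["kernel"] has shape (in, out), i.e. the PyTorch weight transposed.
    return x @ p["kernel"] + p["bias"]


def attention(x, p, n_heads):
    # Plain multi-head self-attention over one image's tokens, x: (n_tokens, dim).
    # Assumption: the RoPE step DINOv3 applies to q and k is omitted here.
    n, d = x.shape
    head_dim = d // n_heads
    q, k, v = jnp.split(linear(x, p["qkv"]), 3, axis=-1)
    # (n, d) -> (n_heads, n, head_dim)
    q, k, v = (t.reshape(n, n_heads, head_dim).transpose(1, 0, 2) for t in (q, k, v))
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(head_dim)  # (n_heads, n, n)
    out = jax.nn.softmax(scores, axis=-1) @ v  # (n_heads, n, head_dim)
    return linear(out.transpose(1, 0, 2).reshape(n, d), p["proj"])


def block(x, p, n_heads):
    # Pre-norm residual block with LayerScale; dropout is skipped since p=0.0.
    x = x + p["ls1"]["gamma"] * attention(layer_norm(x, p["norm1"]), p["attn"], n_heads)
    # GELU(approximate='none') in PyTorch is the exact erf-based GELU.
    h = jax.nn.gelu(linear(layer_norm(x, p["norm2"]), p["mlp"]["fc1"]), approximate=False)
    return x + p["ls2"]["gamma"] * linear(h, p["mlp"]["fc2"])
```

The block works on a single image's tokens; for a batch, something like `jax.vmap(lambda t: block(t, params, n_heads))(tokens)` maps it over the leading axis. Keeping the parameters as plain dicts, rather than a framework module, makes the weight conversion a matter of building the right nested dict.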

Convert the PyTorch Weights
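
A minimal sketch of the conversion, assuming the Jax side wants nested dicts keyed by the PyTorch module path, linear weights stored as (in, out) so a layer computes `x @ kernel + bias`, and convolution weights in `HWIO` layout for `jax.lax.conv_general_dilated` with `dimension_numbers=("NHWC", "HWIO", "NHWC")`. The function name `convert_state_dict` and the `kernel` key are hypothetical. It uses the tensor rank to tell layers apart, which works for the modules printed above but is a heuristic; buffers and any parameter with special semantics pass through untouched and need checking by hand.

```python
import numpy as np


def convert_state_dict(state_dict):
    """Turn a flat {"a.b.weight": array} dict into nested dicts of NumPy arrays.

    Assumed conventions (hypothetical, not from the DINOv3 code):
      * 2-D ".weight" (nn.Linear, shape (out, in)) -> "kernel" of shape (in, out)
      * 4-D ".weight" (nn.Conv2d, OIHW layout)     -> "kernel" in HWIO layout
      * everything else (norm weights, biases, LayerScale, buffers) is copied as-is
    """
    params = {}
    for name, value in state_dict.items():
        value = np.asarray(value)
        *path, leaf = name.split(".")
        if leaf == "weight" and value.ndim == 2:
            leaf, value = "kernel", value.T
        elif leaf == "weight" and value.ndim == 4:
            leaf, value = "kernel", value.transpose(2, 3, 1, 0)
        # Walk/create the nested dicts; ModuleList indices stay string keys ("0", "1", ...).
        node = params
        for key in path:
            node = node.setdefault(key, {})
        node[leaf] = value
    return params


# Usage with the loaded model (detach before converting tensors to NumPy):
# state = {k: v.detach().cpu().numpy() for k, v in vit_pt.state_dict().items()}
# params = jax.tree_util.tree_map(jnp.asarray, convert_state_dict(state))
# first_block = params["blocks"]["0"]
```

The transpose matters because PyTorch's `nn.Linear` computes `x @ W.T + b`; storing `W.T` once at conversion time lets the Jax code use a plain `x @ kernel`.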

Test Against Reference Implementation
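
A minimal sketch of the comparison step: feed the same input to both models, convert both outputs to NumPy, and check them with `np.testing.assert_allclose`, printing the largest absolute difference so you can see how close a near-miss is. The helper name `compare`, the hypothetical `jax_forward` in the usage comment, and the default float32 tolerances are all assumptions; loosen them if the two frameworks accumulate rounding differently.

```python
import numpy as np


def compare(name, jax_out, ref_out, rtol=1e-5, atol=1e-5):
    """Check a Jax output against the PyTorch reference; return the max abs diff.

    The default tolerances are an assumption for float32 and may need loosening.
    """
    jax_out, ref_out = np.asarray(jax_out), np.asarray(ref_out)
    assert jax_out.shape == ref_out.shape, f"{name}: {jax_out.shape} vs {ref_out.shape}"
    max_diff = float(np.max(np.abs(jax_out - ref_out)))
    print(f"{name}: max abs diff {max_diff:.2e}")
    np.testing.assert_allclose(jax_out, ref_out, rtol=rtol, atol=atol, err_msg=name)
    return max_diff


# Usage sketch (jax_forward is a hypothetical Jax forward pass):
# x = np.random.default_rng(0).standard_normal((1, 3, 224, 224)).astype(np.float32)
# vit_pt.eval()
# with torch.no_grad():
#     ref = vit_pt(torch.from_numpy(x)).numpy()
# compare("final output", jax_forward(params, x.transpose(0, 2, 3, 1)), ref)
```

Checking only the final output tells you that something is wrong but not where. Comparing block by block, for example by capturing intermediate PyTorch outputs with `register_forward_hook`, narrows it down to the first layer that diverges.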

Unstructured Notes

This is a pain in the ass. A huge pain in the ass.



Sam Stevens, 2024