How To Convert a PyTorch Model to Jax
I wanted a Jax version of DINOv3, which is written in PyTorch. Here’s how I did it.
- Download the PyTorch code and weights.
- Write the Jax code.
- Convert the PyTorch weights.
- Test against the PyTorch reference implementation.
Download the PyTorch Code and Weights
This depends on whatever model you’re using. For DINOv3, I followed the instructions in the repo. You want to be able to load the model into memory and see the Torch representation:
vit_pt = torch.hub.load(
    "facebookresearch/dinov3", "dinov3_vits16", source="github", weights=CKPT_PATH
)
Then I can print vit_pt and see:
DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (rope_embed): RopePositionEmbedding()
  (blocks): ModuleList(
    (0-11): 12 x SelfAttentionBlock(
      (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (attn): SelfAttention(
        (qkv): LinearKMaskedBias(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (ls2): LayerScale()
    )
  )
  (norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
  (head): Identity()
)
So now I can inspect the model and its weights.
Write the Jax Code
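As one concrete piece, here is a minimal sketch of the Mlp block from the printout above (fc1 → GELU → fc2) in plain jax.numpy. The function name and the dict-of-arrays parameter layout are my own choices, not DINOv3's; the shapes are the ViT-S/16 ones (384 → 1536 → 384).

```python
import jax
import jax.numpy as jnp


def mlp(params, x):
    """Sketch of the Mlp block: fc1 -> GELU -> fc2."""
    x = x @ params["fc1"]["kernel"] + params["fc1"]["bias"]
    # approximate=False matches GELU(approximate='none') in the printout.
    x = jax.nn.gelu(x, approximate=False)
    return x @ params["fc2"]["kernel"] + params["fc2"]["bias"]


# Toy parameters with the ViT-S/16 shapes; kernels are stored (in, out),
# which is why the forward pass is `x @ kernel` with no transpose.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "fc1": {"kernel": jax.random.normal(k1, (384, 1536)) * 0.02,
            "bias": jnp.zeros(1536)},
    "fc2": {"kernel": jax.random.normal(k2, (1536, 384)) * 0.02,
            "bias": jnp.zeros(384)},
}
out = mlp(params, jnp.ones((1, 197, 384)))
print(out.shape)  # (1, 197, 384)
```

Storing kernels as (in, out) here is deliberate: it matches the usual JAX/Flax convention, which determines the transposes needed in the next step.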
Convert the PyTorch Weights
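Most of the conversion is mechanical transposition, because PyTorch and JAX disagree on weight layouts. A sketch of the two cases that come up constantly, with NumPy stand-ins (`pt_linear` and `pt_conv` are placeholders for tensors you would pull out of `vit_pt.state_dict()` and call `.numpy()` on):

```python
import numpy as np


def convert_linear(w):
    # PyTorch nn.Linear stores weights as (out_features, in_features);
    # the usual JAX/Flax kernel layout is (in_features, out_features).
    return w.T


def convert_conv(w):
    # PyTorch nn.Conv2d stores weights as (out_ch, in_ch, kh, kw);
    # Flax's Conv expects (kh, kw, in_ch, out_ch).
    return w.transpose(2, 3, 1, 0)


# Placeholder arrays with the shapes from the printout above:
pt_linear = np.zeros((1152, 384))     # the qkv weight
pt_conv = np.zeros((384, 3, 16, 16))  # the patch_embed proj weight

print(convert_linear(pt_linear).shape)  # (384, 1152)
print(convert_conv(pt_conv).shape)      # (16, 16, 3, 384)
```

Biases carry over unchanged. LayerNorm and LayerScale parameters are 1-D, so they also copy straight across.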
Test Against Reference Implementation
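The pattern here is simple: feed the same input to both models and compare outputs elementwise under a tolerance. A sketch with NumPy stand-ins (the two lambdas are placeholders for the real forward passes, i.e. calling `vit_pt` and the Jax model; the tolerances are a starting point, not something DINOv3 prescribes):

```python
import numpy as np


def check_close(pt_out, jax_out, atol=1e-5, rtol=1e-5):
    """Assert the two outputs agree within tolerance; return the max abs error."""
    pt_out = np.asarray(pt_out)
    jax_out = np.asarray(jax_out)
    np.testing.assert_allclose(pt_out, jax_out, atol=atol, rtol=rtol)
    return np.abs(pt_out - jax_out).max()


# Stand-in "models" that agree up to float32 rounding.
x = np.random.default_rng(0).normal(size=(1, 197, 384)).astype(np.float32)
pt_forward = lambda x: x * 2.0
jax_forward = lambda x: (x.astype(np.float64) * 2.0).astype(np.float32)

max_err = check_close(pt_forward(x), jax_forward(x))
print(max_err <= 1e-5)  # True
```

When the check fails, comparing intermediate activations block by block (patch embed first, then each SelfAttentionBlock) narrows down which conversion is wrong much faster than staring at the final output.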
Unstructured Notes
This is a pain in the ass. A huge pain in the ass.
Sam Stevens, 2024