Pre-trained ViT (Vision Transformer) models are usually trained on 224x224 or 384x384 images, but I need to fine-tune a custom ViT model (all the layers of ViT plus some additional layers) on 640x640 images. How do I handle the positional embeddings in this case? When I load the pre-trained weights from ViT-base-patch16-384 (available on Hugging Face) into my model using load_state_dict(), I get the following error:
size mismatch for pos_embed: copying a param with shape torch.Size([1, 577, 768]) from checkpoint, the shape in current model is torch.Size([1, 1601, 768]).
This error is expected: the original model has 24x24=576 patch position embeddings (plus one for the class token, hence 577), while my model has 40x40=1600 (hence 1601), since the image size is now 640x640 with 16x16-pixel patches.
From what I have read, researchers use interpolation to solve this problem. My question is: if I use bicubic interpolation, does the following code resolve the issue I am facing? Will it properly interpolate the positional embeddings?
import torch
a = torch.rand(1, 577, 768)  # say this is the original pos_embed
a_temp = a.unsqueeze(0)  # shape of a_temp is (1, 1, 577, 768)
# making it 4D because torch's bicubic interpolation only accepts 4D input
b = torch.nn.functional.interpolate(a_temp, [1601, 768], mode='bicubic')
# shape of b is now (1, 1, 1601, 768)
b = torch.squeeze(b, 0)  # squeezing b back from 4D to 3D
# shape of b is now (1, 1601, 768)
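For comparison, the grid-based interpolation I have seen in DeiT-style implementations (as far as I understand it) separates the class token, reshapes the 576 patch embeddings into their 24x24 grid, interpolates that grid to 40x40, and flattens it back. The helper below is my own sketch, not library code:
import torch
import torch.nn.functional as F

def resize_pos_embed_grid(pos_embed, old_grid=24, new_grid=40):
    # pos_embed: (1, 1 + old_grid*old_grid, dim); the first token is assumed to be the class token
    cls_tok, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    dim = pos_embed.shape[-1]
    # reshape the patch embeddings into their 2D grid so interpolation acts only on the spatial axes
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos, size=(new_grid, new_grid),
                              mode='bicubic', align_corners=False)
    # flatten the 40x40 grid back to a sequence and re-attach the class token
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    return torch.cat([cls_tok, patch_pos], dim=1)

new_pos_embed = resize_pos_embed_grid(torch.rand(1, 577, 768))  # shape: (1, 1601, 768)
Is this grid-based version what is actually needed, rather than my flat interpolation above?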
N.B. I followed the video to implement the original ViT and then added some additional layers.
Downsampling the images to 224x224 or 384x384 is one option, but you will degrade image quality due to the bicubic (or other) interpolation.
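For completeness, a minimal sketch of that downsampling route, assuming torchvision is available (the target size and interpolation mode here are just examples):
import torchvision.transforms as T

# resize 640x640 inputs down to the pretrained resolution; this resampling is where quality is lost
resize_to_pretrained = T.Compose([
    T.Resize((384, 384), interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
])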
Another option is to write your own ViT model so that it can process images at the full 640x640 resolution. You only have to change the patch embedding part of the model; the other blocks stay the same. For example, with a 640x640 input you can embed the image into a 16x16 grid of patches of 40x40 pixels each. You then get sequences of shape [256+1 cls, 768] (you could also use a smaller embedding dimension, e.g. [257, 384], which further reduces the complexity of the model).
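A minimal sketch of that patch-embedding change, assuming the usual Conv2d-based projection from the original ViT (the class and attribute names here are illustrative):
import torch
import torch.nn as nn

class PatchEmbed640(nn.Module):
    # splits a 640x640 image into a 16x16 grid of 40x40-pixel patches -> 256 tokens
    def __init__(self, img_size=640, patch_size=40, in_chans=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2  # 256
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                    # x: (B, 3, 640, 640)
        x = self.proj(x)                     # (B, 768, 16, 16)
        return x.flatten(2).transpose(1, 2)  # (B, 256, 768); the cls token is prepended afterwards

tokens = PatchEmbed640()(torch.rand(2, 3, 640, 640))  # shape: (2, 256, 768)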