I am looking for a way to use a custom patch embedding layer with a vanilla ViT while keeping the rest of the ViT from a pre-trained model. Is there a way to do this in PyTorch?
I can load a pre-trained model or write the ViT code from scratch, but I want a setup where I can reuse the pre-trained weights for all the layers after the patch embedding.
Thanks,
It depends a lot on the architecture of the two models, but you can do something like this:
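A minimal sketch of that idea, where `CombinedModel` is just an illustrative wrapper name and you pass in your own modules:

```python
import torch.nn as nn

class CombinedModel(nn.Module):
    """Custom patch embedding followed by the pretrained ViT layers."""
    def __init__(self, embedding, vit_model):
        super().__init__()
        self.embedding = embedding    # your custom patch-embedding module
        self.vit_model = vit_model    # pretrained ViT layers after the embedding

    def forward(self, x):
        x = self.embedding(x)         # e.g. (batch, num_patches, embed_dim)
        return self.vit_model(x)      # run the pretrained trunk on the embedded patches
```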
In this example, `embedding` would be your custom embedding, and `vit_model` would be all the trained layers of the ViT model except the embedding. Depending on how the ViT model is structured, you may need to hack into it to extract the non-embedding layers in a way that allows you to simply pass an input to them.
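For instance, with torchvision's ViT you may not need to extract anything: the patch embedding lives in the `conv_proj` attribute, so you can swap that attribute for your own module as long as it produces the same `(batch, hidden_dim, H/patch, W/patch)` output. A sketch under that assumption (`MyPatchEmbed` is a hypothetical placeholder for your custom layer):

```python
import torch
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights

class MyPatchEmbed(nn.Module):
    """Hypothetical custom patch embedding; must output (B, 768, 14, 14) for ViT-B/16 at 224x224."""
    def __init__(self, in_channels=3, hidden_dim=768, patch_size=16):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, hidden_dim,
                              kernel_size=patch_size, stride=patch_size)
        # ... your custom logic here ...

    def forward(self, x):
        return self.proj(x)

model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)  # downloads pretrained weights
model.conv_proj = MyPatchEmbed()                    # replace only the patch embedding

out = model(torch.randn(1, 3, 224, 224))            # the rest of the ViT is reused as-is
print(out.shape)                                    # torch.Size([1, 1000])
```

If you are using timm instead, the analogous attribute is `patch_embed` on its ViT models, but either way check the source of the implementation you load to see where the embedding ends and the encoder begins.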