Error while Loading vit_small weights #26

Open
shubhaminnani opened this issue May 31, 2023 · 6 comments

@shubhaminnani

shubhaminnani commented May 31, 2023

Hi @Xiyue-Wang,
Thank you for the amazing repo.
I am trying to load the pretrained MoCo v3 vit_small weights.

model = moco.builder_infence.MoCo_ViT(partial(vits.__dict__['vit_small'], stop_grad_conv1=True))
pretext_model = torch.load('TransPath/vit_small.pth.tar')['state_dict']
model = nn.DataParallel(model).cuda()
model.load_state_dict(pretext_model,strict=False)

But I get the error below:
[image: error screenshot]

I checked the keys of the model and of the weights, and they seem to be the same, but the error above still occurs.

Looking forward to your reply.
Thanks!
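
Since the key comparison above was done by eye, a minimal sketch of a programmatic check may help. It works on plain key collections, so with a real model one would pass `model.state_dict().keys()` and `pretext_model.keys()`. The key names below are hypothetical and only illustrate the comparison. Note that matching keys do not rule out a tensor shape mismatch, which this sketch does not check.

```python
def diff_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists for a model/checkpoint pair."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)      # in the model, not in the checkpoint
    unexpected = sorted(checkpoint_keys - model_keys)   # in the checkpoint, not in the model
    return missing, unexpected


# Hypothetical key names, used only to illustrate the comparison.
model_keys = ["module.base_encoder.cls_token", "module.base_encoder.head.weight"]
checkpoint_keys = ["module.base_encoder.cls_token", "module.predictor.0.weight"]

missing, unexpected = diff_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

With `strict=False`, `load_state_dict` also returns an object with `missing_keys` and `unexpected_keys` fields, which reports the same information for the actual load.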

@Xiyue-Wang
Owner

Xiyue-Wang commented Jun 1, 2023

Can you try the following?
model = moco.builder_infence.MoCo_ViT(
    partial(vits.dict[args.arch], stop_grad_conv1=True))

pretext_model = torch.load(r'./vit_small.pth.tar')['state_dict']
model = nn.DataParallel(model).cuda()
model.load_state_dict(pretext_model, strict=True)
Many people have used it without problems, so I do not know why you get an error.

@shubhaminnani
Author

Thank you, but dict cannot be accessed the way you wrote it.

I tried it anyway, and it raised the error below: AttributeError: module 'vits' has no attribute 'dict'
[image: error screenshot]

Thanks
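
The AttributeError is consistent with the attribute name: a Python module exposes its namespace as `__dict__` (with double underscores, as in the first post), not `dict`. It is possible, though only a guess, that the underscores were lost when the reply was rendered. The sketch below uses a stand-in module built with `types.ModuleType` (the real `vits` module is not reproduced here, and the `vit_small` stub is hypothetical) to show the lookup and an equivalent `getattr` form.

```python
import types
from functools import partial

# Stand-in for the repository's `vits` module (hypothetical, for illustration only).
vits = types.ModuleType("vits")


def vit_small(stop_grad_conv1=False):
    # Stub factory; the real function would build a ViT model.
    return {"arch": "vit_small", "stop_grad_conv1": stop_grad_conv1}


vits.vit_small = vit_small

arch = "vit_small"
factory_a = vits.__dict__[arch]   # module namespace lookup, as in the first post
factory_b = getattr(vits, arch)   # equivalent lookup by attribute name

build = partial(factory_b, stop_grad_conv1=True)
result = build()

has_plain_dict = hasattr(vits, "dict")  # modules have no attribute named `dict`
```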

@Xiyue-Wang
Owner

[image]
Maybe remove model = nn.DataParallel(model).cuda()?

@shubhaminnani
Author

I tried removing that line, but then all the keys mismatch and none of them can be loaded.

[image: error screenshot]
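
One possible reason every key mismatches once the wrapper is removed: `nn.DataParallel` stores the wrapped model under the attribute `module`, so a checkpoint saved from a wrapped model has keys that start with `module.`, while an unwrapped model's keys do not. Whether this checkpoint carries that prefix is an assumption, and the key names below are hypothetical. A minimal sketch of stripping the prefix before loading into an unwrapped model:

```python
def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Hypothetical checkpoint contents; the values stand in for tensors.
checkpoint = {
    "module.base_encoder.cls_token": 1,
    "module.base_encoder.pos_embed": 2,
}
cleaned = strip_prefix(checkpoint)
print(list(cleaned))
```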

@shubhaminnani
Author

Do I need something like model.module.online_encoder.net.head = nn.Identity(), as in the TransPath code, to extract the features?

@shubhaminnani
Author

Hi @Xiyue-Wang,
I think I found the problem. The model was trained with moco.builder, but for inference you call the module from moco.builder_infence, where the code below, in the forward function, is commented out. The trained weights have keys from the former module, and they do not match the keys in moco.builder_infence.
[image: screenshot of the commented-out code]

As I understand the code, feature extraction should use only the base encoder rather than the complete architecture, so a solution is needed for that. It may be better to load the complete weights and then truncate the model. What is your view?

Thanks!
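
An alternative to loading everything and truncating afterwards is to filter the checkpoint down to the base encoder before loading. The sketch below assumes the encoder's weights sit under a `module.base_encoder.` prefix; that prefix and all key names here are hypothetical and should be checked against the actual checkpoint keys.

```python
def extract_submodule_weights(state_dict, prefix):
    """Keep only entries under `prefix` and drop the prefix from their keys."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Hypothetical full checkpoint; the string values stand in for tensors.
full = {
    "module.base_encoder.cls_token": "w0",
    "module.base_encoder.head.weight": "w1",
    "module.momentum_encoder.cls_token": "w2",
    "module.predictor.0.weight": "w3",
}
encoder_weights = extract_submodule_weights(full, "module.base_encoder.")
print(sorted(encoder_weights))
```

The filtered dictionary could then be passed to `load_state_dict` on a bare encoder with `strict=False`, inspecting the returned `missing_keys` and `unexpected_keys` to confirm what was actually loaded.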
