Exporting Model as .pth file having dimensions issues #2591

Open
jamesgwen opened this issue Nov 8, 2024 · 0 comments

Hello!
I am trying to export the trained model from checkpoint_final.pth as a single .pth file that stores the model architecture together with the weights.

To do this, I modified the __main__ block of nnUNet/nnunetv2/inference/predict_from_raw_data.py so that it saves the trained model as a .pth file:

if __name__ == '__main__':
    # predict a bunch of files
    from nnunetv2.paths import nnUNet_results, nnUNet_raw

    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=True,
        perform_everything_on_device=True,
        device=torch.device('cuda', 0),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True
    )
    predictor.initialize_from_trained_model_folder(
        join(nnUNet_results, 'Dataset089_data/nnUNetTrainer__nnUNetPlans__3d_fullres'),
        use_folds=(0,),
        checkpoint_name='checkpoint_final.pth',
    )
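    # Load the fold-0 weights into the network, then pickle the whole module
    # (architecture + weights) instead of only the state_dict
    model = predictor.network
    param = predictor.list_of_parameters[0]
    model.load_state_dict(param)
    torch.save(model, '/home/jupyter/nnunet_Dataset089_data.pth')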

I then load the model and try to run a dummy tensor through it:

model = torch.load('nnunet_Dataset089_data.pth', weights_only=False)
model.eval()
data = torch.rand((1, 64, 192, 160))
output = model(data)

The above code gives the following error:
RuntimeError: Given groups=1, weight of size [320, 640, 3, 3, 3], expected input[1, 320, 16, 12, 10] to have 640 channels, but got 320 channels instead
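
One possible reading of the shapes in this error, offered as a sketch rather than a confirmed diagnosis: the dummy tensor is 4-D, and PyTorch's 3-D convolution layers accept a 4-D tensor as an unbatched (C, D, H, W) input. If the decoder concatenates skip connections along dimension 1, that dimension is depth in the unbatched layout, not channels, so the channel count stays at 320 while depth doubles. The plain-Python bookkeeping below reproduces the numbers in the message. The feature-map size (8, 12, 10) and the helper `cat_shape` are assumptions made for illustration; they are not taken from the nnU-Net code.

```python
# Plain-Python shape bookkeeping (no torch required).
# Assumption: the decoder joins the upsampled features and the skip connection
# by concatenating along dim=1. The spatial size (8, 12, 10) is inferred from
# the error message, not read from the model.

def cat_shape(a, b, dim):
    """Return the shape that concatenating shapes a and b along dim would give."""
    if len(a) != len(b):
        raise ValueError("shapes must have the same rank")
    for i, (x, y) in enumerate(zip(a, b)):
        if i != dim and x != y:
            raise ValueError("all dimensions except dim must match")
    out = list(a)
    out[dim] = a[dim] + b[dim]
    return tuple(out)

# Batched layout (N, C, D, H, W): dim=1 is the channel axis.
batched = cat_shape((1, 320, 8, 12, 10), (1, 320, 8, 12, 10), dim=1)
print(batched)    # (1, 640, 8, 12, 10): 640 channels, as the conv weight expects

# Unbatched layout (C, D, H, W): dim=1 is now the depth axis.
unbatched = cat_shape((320, 8, 12, 10), (320, 8, 12, 10), dim=1)
print(unbatched)  # (320, 16, 12, 10): still 320 channels, depth doubled

# With a leading batch axis of 1 added, this matches the shape in the error.
reported = (1,) + unbatched
print(reported)   # (1, 320, 16, 12, 10)
```

If this reading is right, passing a 5-D tensor such as `torch.rand((1, 1, 64, 192, 160))` would be the thing to try, assuming the model takes a single input channel and (64, 192, 160) is the intended patch size.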

I would really appreciate any help debugging this issue.
