diff --git a/src/graphnet/deployment/deployment_module.py b/src/graphnet/deployment/deployment_module.py index 0083d5dce..13e2b96f7 100644 --- a/src/graphnet/deployment/deployment_module.py +++ b/src/graphnet/deployment/deployment_module.py @@ -3,7 +3,7 @@ from typing import Any, List, Union, Dict import numpy as np -from torch import Tensor +from torch import Tensor, load from torch_geometric.data import Data, Batch from graphnet.models import Model @@ -61,7 +61,11 @@ def _load_model( ) -> Model: """Load `Model` from config and insert learned weights.""" model = Model.from_config(model_config, trust=True) - model.load_state_dict(state_dict) + if isinstance(state_dict, str) and state_dict.endswith(".ckpt"): + ckpt = load(state_dict) + self.model.load_state_dict(ckpt["state_dict"]) + else: + self.model.load_state_dict(state_dict) return model def _resolve_prediction_columns(