diff --git a/_modules/graphnet/deployment/deployment_module.html b/_modules/graphnet/deployment/deployment_module.html index 812e10b3c..d035ff90b 100644 --- a/_modules/graphnet/deployment/deployment_module.html +++ b/_modules/graphnet/deployment/deployment_module.html @@ -327,7 +327,7 @@

Source code fo from typing import Any, List, Union, Dict import numpy as np -from torch import Tensor +from torch import Tensor, load from torch_geometric.data import Data, Batch from graphnet.models import Model @@ -387,7 +387,11 @@

Source code fo ) -> Model: """Load `Model` from config and insert learned weights.""" model = Model.from_config(model_config, trust=True) - model.load_state_dict(state_dict) + if isinstance(state_dict, str) and state_dict.endswith(".ckpt"): + ckpt = load(state_dict) + model.load_state_dict(ckpt["state_dict"]) + else: + model.load_state_dict(state_dict) return model def _resolve_prediction_columns( diff --git a/_modules/graphnet/models/graphs/nodes/nodes.html b/_modules/graphnet/models/graphs/nodes/nodes.html index b9e868b76..e11b5d06f 100644 --- a/_modules/graphnet/models/graphs/nodes/nodes.html +++ b/_modules/graphnet/models/graphs/nodes/nodes.html @@ -613,6 +613,8 @@

Source code for g """Construct nodes from raw node features ´x´.""" # Cast to Numpy x = x.numpy() + if x.shape[0] == 0: + return Data(x=torch.tensor(np.column_stack([x, []]))) # if there is no charge column add a dummy column of zeros with the same shape as the time column if self._charge_index is None: charge_index: int = len(self._keys)