diff --git a/_modules/graphnet/deployment/deployment_module.html b/_modules/graphnet/deployment/deployment_module.html
index 812e10b3c..d035ff90b 100644
--- a/_modules/graphnet/deployment/deployment_module.html
+++ b/_modules/graphnet/deployment/deployment_module.html
@@ -327,7 +327,7 @@
Source code fo
from typing import Any, List, Union, Dict
import numpy as np
-from torch import Tensor
+from torch import Tensor, load
from torch_geometric.data import Data, Batch
from graphnet.models import Model
@@ -387,7 +387,11 @@ Source code fo
) -> Model:
"""Load `Model` from config and insert learned weights."""
model = Model.from_config(model_config, trust=True)
- model.load_state_dict(state_dict)
+ if isinstance(state_dict, str) and state_dict.endswith(".ckpt"):
+ ckpt = load(state_dict)
+ model.load_state_dict(ckpt["state_dict"])
+ else:
+ model.load_state_dict(state_dict)
return model
def _resolve_prediction_columns(
diff --git a/_modules/graphnet/models/graphs/nodes/nodes.html b/_modules/graphnet/models/graphs/nodes/nodes.html
index b9e868b76..e11b5d06f 100644
--- a/_modules/graphnet/models/graphs/nodes/nodes.html
+++ b/_modules/graphnet/models/graphs/nodes/nodes.html
@@ -613,6 +613,8 @@ Source code for g
"""Construct nodes from raw node features ´x´."""
# Cast to Numpy
x = x.numpy()
+ if x.shape[0] == 0:
+ return Data(x=torch.tensor(np.column_stack([x, []])))
# if there is no charge column add a dummy column of zeros with the same shape as the time column
if self._charge_index is None:
charge_index: int = len(self._keys)