From 2582c7b7ad92d953d0ecfae314d6236e10ab8ca0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Mon, 18 Mar 2024 01:06:31 +0000 Subject: [PATCH] =?UTF-8?q?Deploying=20to=20gh-pages=20from=20@=20graphnet?= =?UTF-8?q?-team/graphnet@6f7db39f909f0a5c10334caa446aac01675befa6=20?= =?UTF-8?q?=F0=9F=9A=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- _modules/graphnet/deployment/deployment_module.html | 8 ++++++-- _modules/graphnet/models/graphs/nodes/nodes.html | 2 ++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/_modules/graphnet/deployment/deployment_module.html b/_modules/graphnet/deployment/deployment_module.html index 812e10b3c..d035ff90b 100644 --- a/_modules/graphnet/deployment/deployment_module.html +++ b/_modules/graphnet/deployment/deployment_module.html @@ -327,7 +327,7 @@

Source code fo from typing import Any, List, Union, Dict import numpy as np -from torch import Tensor +from torch import Tensor, load from torch_geometric.data import Data, Batch from graphnet.models import Model @@ -387,7 +387,11 @@

Source code fo ) -> Model: """Load `Model` from config and insert learned weights.""" model = Model.from_config(model_config, trust=True) - model.load_state_dict(state_dict) + if isinstance(state_dict, str) and state_dict.endswith(".ckpt"): + ckpt = load(state_dict) + model.load_state_dict(ckpt["state_dict"]) + else: + model.load_state_dict(state_dict) return model def _resolve_prediction_columns( diff --git a/_modules/graphnet/models/graphs/nodes/nodes.html b/_modules/graphnet/models/graphs/nodes/nodes.html index b9e868b76..e11b5d06f 100644 --- a/_modules/graphnet/models/graphs/nodes/nodes.html +++ b/_modules/graphnet/models/graphs/nodes/nodes.html @@ -613,6 +613,8 @@

Source code for g """Construct nodes from raw node features ´x´.""" # Cast to Numpy x = x.numpy() + if x.shape[0] == 0: + return Data(x=torch.tensor(np.column_stack([x, []]))) # if there is no charge column add a dummy column of zeros with the same shape as the time column if self._charge_index is None: charge_index: int = len(self._keys)