Skip to content

Commit

Permalink
Merge pull request graphnet-team#679 from Aske-Rosted/inference_from_…
Browse files Browse the repository at this point in the history
…checkpoint

add_checkpoint_handling
  • Loading branch information
Aske-Rosted authored Mar 18, 2024
2 parents 45befd9 + dcdbe3e commit 6f7db39
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/graphnet/deployment/deployment_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, List, Union, Dict

import numpy as np
from torch import Tensor
from torch import Tensor, load
from torch_geometric.data import Data, Batch

from graphnet.models import Model
Expand Down Expand Up @@ -61,7 +61,11 @@ def _load_model(
) -> Model:
"""Load `Model` from config and insert learned weights."""
model = Model.from_config(model_config, trust=True)
model.load_state_dict(state_dict)
if isinstance(state_dict, str) and state_dict.endswith(".ckpt"):
ckpt = load(state_dict)
model.load_state_dict(ckpt["state_dict"])
else:
model.load_state_dict(state_dict)
return model

def _resolve_prediction_columns(
Expand Down

0 comments on commit 6f7db39

Please sign in to comment.