Skip to content

Commit

Permalink
add_checkpoint_handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Aske-Rosted committed Mar 14, 2024
1 parent 4ffa0f1 commit 275da0e
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/graphnet/deployment/deployment_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, List, Union, Dict

import numpy as np
from torch import Tensor
from torch import Tensor, load
from torch_geometric.data import Data, Batch

from graphnet.models import Model
Expand Down Expand Up @@ -61,7 +61,11 @@ def _load_model(
) -> Model:
"""Load `Model` from config and insert learned weights."""
model = Model.from_config(model_config, trust=True)
model.load_state_dict(state_dict)
if isinstance(state_dict, str) and state_dict.endswith(".ckpt"):
ckpt = load(state_dict)
self.model.load_state_dict(ckpt["state_dict"])
else:
self.model.load_state_dict(state_dict)
return model

def _resolve_prediction_columns(
Expand Down

0 comments on commit 275da0e

Please sign in to comment.