Skip to content

Commit

Permalink
🚧
Browse files Browse the repository at this point in the history
  • Loading branch information
aaarrti committed Apr 28, 2024
1 parent 3022a28 commit 0e4b1ea
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions quantus/helpers/model/pytorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
Optional,
Tuple,
Union,
TypeGuard,
TypedDict,
)

Expand All @@ -36,6 +35,12 @@
from quantus.helpers.model.model_interface import ModelInterface


if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard


class PyTorchModel(ModelInterface[nn.Module]):
"""Interface for torch models."""

Expand Down Expand Up @@ -134,9 +139,7 @@ def _obtain_predictions(

elif isinstance(self.model, nn.Module):
pred_model = self.get_softmax_arg_model()
return pred_model(
torch.Tensor(x).to(self.device), **model_predict_kwargs
)
return pred_model(torch.Tensor(x).to(self.device), **model_predict_kwargs)
else:
raise ValueError("Predictions cant be null")

Expand Down

0 comments on commit 0e4b1ea

Please sign in to comment.