diff --git a/src/graphnet/deployment/i3modules/graphnet_module.py b/src/graphnet/deployment/i3modules/graphnet_module.py index dee0973b8..40fb7b388 100644 --- a/src/graphnet/deployment/i3modules/graphnet_module.py +++ b/src/graphnet/deployment/i3modules/graphnet_module.py @@ -238,7 +238,9 @@ def _inference(self, data: Data) -> np.ndarray: len(task_predictions) == 1 ), f"""This method assumes a single task. \n Got {len(task_predictions)} tasks.""" - return self.model(data)[0].detach().numpy() + return ( + task_predictions[0].detach().numpy() + ) # self.model(data)[0].detach().numpy() class I3PulseCleanerModule(I3InferenceModule):