From b318e195f6b0b75ce373ce5e35e7d75043265b67 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Sat, 3 Feb 2024 19:10:52 +0100 Subject: [PATCH 1/2] bugfix + minkowski mypy --- src/graphnet/models/graphs/edges/minkowski.py | 11 +++--- src/graphnet/models/standard_model.py | 39 +++++++++++++------ 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/src/graphnet/models/graphs/edges/minkowski.py b/src/graphnet/models/graphs/edges/minkowski.py index 5d1134ec5..2526de1cb 100644 --- a/src/graphnet/models/graphs/edges/minkowski.py +++ b/src/graphnet/models/graphs/edges/minkowski.py @@ -69,12 +69,13 @@ def _construct_edges(self, graph: Data) -> Data: row = [] col = [] for batch in range(x.shape[0]): + x_masked = x[batch][mask[batch]] distance_mat = compute_minkowski_distance_mat( - x_masked := x[batch][mask[batch]], - x_masked, - self.c, - self.space_coords, - self.time_coord, + x=x_masked, + y=x_masked, + c=self.c, + space_coords=self.space_coords, + time_coord=self.time_coord, ) num_points = x_masked.shape[0] num_edges = min(self.nb_nearest_neighbours, num_points) diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index 663664996..098f63d98 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -410,6 +410,9 @@ def predict_as_dataframe( f"number of output columns ({predictions.shape[1]}) don't match." ) + # Check if predictions are on event- or pulse-level + pulse_level_predictions = len(predictions) > len(dataloader.dataset) + # Get additional attributes attributes: Dict[str, List[np.ndarray]] = OrderedDict( [(attr, []) for attr in additional_attributes] @@ -423,25 +426,39 @@ def predict_as_dataframe( # Check if node level predictions # If true, additional attributes are repeated # to make dimensions fit - if len(predictions) != len(dataloader.dataset): + if pulse_level_predictions: if len(attribute) < np.sum( batch.n_pulses.detach().cpu().numpy() ): attribute = np.repeat( attribute, batch.n_pulses.detach().cpu().numpy() ) - try: - assert len(attribute) == len(batch.x) - except AssertionError: - self.warning_once( - "Could not automatically adjust length" - f"of additional attribute {attr} to match length of" - f"predictions. Make sure {attr} is a graph-level or" - "node-level attribute. Attribute skipped." - ) - pass attributes[attr].extend(attribute) + # Confirm that attributes match length of predictions + skip_attributes = [] + for attr in attributes.keys(): + try: + assert len(attributes[attr]) == len(predictions) + except AssertionError: + self.warning_once( + "Could not automatically adjust length" + f" of additional attribute '{attr}' to match length of" + f" predictions.This error can be caused by heavy" + " disagreement between number of examples in the" + " dataset vs. actual events in the dataloader, e.g. " + " heavy filtering of events in `collate_fn` passed to" + " `dataloader`. This can also be caused by requesting" + " pulse-level attributes for `Task`s that produce" + " event-level predictions. Attribute skipped." + ) + skip_attributes.append(attr) + + # Remove bad attributes + for attr in skip_attributes: + attributes.pop(attr) + additional_attributes.remove(attr) + data = np.concatenate( [predictions] + [ From 36a2384ff06afabec9fc143ef3d548672b398761 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Sat, 3 Feb 2024 19:52:56 +0100 Subject: [PATCH 2/2] check requirement links --- requirements/torch_cpu.txt | 2 +- requirements/torch_gpu.txt | 2 +- requirements/torch_macos.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements/torch_cpu.txt b/requirements/torch_cpu.txt index 59e273288..babb4fb8e 100644 --- a/requirements/torch_cpu.txt +++ b/requirements/torch_cpu.txt @@ -1,2 +1,2 @@ --find-links https://download.pytorch.org/whl/cpu ---find-links https://data.pyg.org/whl/torch-2.1.0+cpu.html \ No newline at end of file +--find-links https://data.pyg.org/whl/torch-2.2.0+cpu.html \ No newline at end of file diff --git a/requirements/torch_gpu.txt b/requirements/torch_gpu.txt index 1f1abba3f..ddcb85038 100644 --- a/requirements/torch_gpu.txt +++ b/requirements/torch_gpu.txt @@ -1,4 +1,4 @@ # Contains packages requirements for GPU installation --find-links https://download.pytorch.org/whl/torch_stable.html torch==2.1.0+cu118 ---find-links https://data.pyg.org/whl/torch-2.1.0+cu118.html +--find-links https://data.pyg.org/whl/torch-2.2.0+cu118.html diff --git a/requirements/torch_macos.txt b/requirements/torch_macos.txt index 3e9d75df4..2b5009a8e 100644 --- a/requirements/torch_macos.txt +++ b/requirements/torch_macos.txt @@ -1,2 +1,2 @@ --find-links https://download.pytorch.org/whl/torch_stable.html ---find-links https://data.pyg.org/whl/torch-2.1.0+cpu.html \ No newline at end of file +--find-links https://data.pyg.org/whl/torch-2.2.0+cpu.html \ No newline at end of file