diff --git a/src/graphnet/data/dataset/dataset.py b/src/graphnet/data/dataset/dataset.py index 110da8a73..49329e4ac 100644 --- a/src/graphnet/data/dataset/dataset.py +++ b/src/graphnet/data/dataset/dataset.py @@ -224,6 +224,7 @@ def __init__( loss_weight_default_value: Optional[float] = None, seed: Optional[int] = None, labels: Optional[Dict[str, Any]] = None, + use_super_selection: bool = False, ): """Construct Dataset. @@ -271,6 +272,10 @@ def __init__( events ~ event_no % 5 > 0"`). graph_definition: Method that defines the graph representation. labels: Dictionary of labels to be added to the dataset. + use_super_selection: If True, the string selection is handled by + the query function of the dataset class, rather than + pd.DataFrame.query. Defaults to False and should + only be used with sqlite. """ # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) @@ -297,6 +302,7 @@ def __init__( self._graph_definition = deepcopy(graph_definition) self._labels = labels self._string_column = graph_definition._detector.string_index_name + self._use_super_selection = use_super_selection if node_truth is not None: assert isinstance(node_truth_table, str) @@ -347,6 +353,7 @@ def __init__( self, index_column=index_column, seed=seed, + use_super_selection=self._use_super_selection, ) if self._labels is not None: @@ -609,10 +616,13 @@ def _create_graph( """ # Convert truth to dict if len(truth.shape) == 1: - truth = truth.reshape(1, -1) - truth_dict = { - key: truth[:, index] for index, key in enumerate(self._truth) - } + truth_dict = { + key: truth[0][index] for index, key in enumerate(self._truth) + } + else: + truth_dict = { + key: truth[:, index] for index, key in enumerate(self._truth) + } # Define custom labels labels_dict = self._get_labels(truth_dict) diff --git a/src/graphnet/data/utilities/string_selection_resolver.py b/src/graphnet/data/utilities/string_selection_resolver.py index 8a1c61513..c19311bef 100644 --- a/src/graphnet/data/utilities/string_selection_resolver.py +++ b/src/graphnet/data/utilities/string_selection_resolver.py @@ -53,14 +53,17 @@ def __init__( index_column: str, seed: Optional[int] = None, use_cache: bool = True, + use_super_selection: bool = False, ): """Construct `StringSelectionResolver`.""" self._dataset = dataset self._index_column = index_column self._seed = seed self._use_cache = use_cache - + self._use_super_selection = use_super_selection # Base class constructor + if self._use_super_selection: + self._use_cache = False super().__init__(name=__name__, class_name=self.__class__.__name__) # Public method(s) @@ -214,19 +217,32 @@ def _query_selection_from_dataset(self, selection: str) -> pd.DataFrame: df_values = self._load_values_cache(values_cache_path) else: - df_values = pd.DataFrame( - data=self._dataset.query_table( - self._dataset.truth_table, - list(variables), - ), - columns=list(variables), - ) + if self._use_super_selection: + df_values = pd.DataFrame( + data=self._dataset.query_table( + self._dataset.truth_table, + list(variables), + selection=selection, + ).tolist(), + columns=list(variables), + ) + + else: + df_values = pd.DataFrame( + data=self._dataset.query_table( + self._dataset.truth_table, + list(variables), + ).tolist(), + columns=list(variables), + ) # (Opt.) Cache indices. if self._use_cache and not os.path.exists(values_cache_path): self._save_values_cache(df_values, values_cache_path) - - df_selection = df_values.query(selection) + if not self._use_super_selection: + df_selection = df_values.query(selection) + else: + df_selection = df_values return df_selection def _get_random_state(self, selection: str) -> Optional[int]: