From 4c6142f6740a6c4f55141ff00feafb1da901850b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Tue, 6 Dec 2022 14:27:32 +0100 Subject: [PATCH 01/16] Remove duplication of, standardise create_table, save_to_sql, and attach_index --- src/graphnet/data/pipeline.py | 35 +------- .../data/sqlite/sqlite_dataconverter.py | 81 ++++--------------- src/graphnet/data/sqlite/sqlite_utilities.py | 47 +++++++---- .../data/utilities/parquet_to_sqlite.py | 17 ++-- src/graphnet/pisa/fitting.py | 32 +------- src/graphnet/training/weight_fitting.py | 4 +- .../data/test_dataconverters_and_datasets.py | 14 ++-- 7 files changed, 73 insertions(+), 157 deletions(-) diff --git a/src/graphnet/data/pipeline.py b/src/graphnet/data/pipeline.py index 0f6923e14..296e66041 100644 --- a/src/graphnet/data/pipeline.py +++ b/src/graphnet/data/pipeline.py @@ -13,7 +13,7 @@ import torch from torch.utils.data import DataLoader -from graphnet.data.sqlite.sqlite_utilities import run_sql_code, save_to_sql +from graphnet.data.sqlite.sqlite_utilities import save_to_sql, create_table from graphnet.training.utils import get_predictions, make_dataloader from graphnet.utilities.logging import get_logger @@ -216,38 +216,11 @@ def _append_to_pipeline( pipeline_database = outdir + "/%s.db" % self._pipeline_name if i == 0: # Only setup table schemes if its the first time appending - self._create_table(pipeline_database, "reconstruction", df) - self._create_table(pipeline_database, "truth", truth) + create_table(df.columns, "reconstruction", pipeline_database) + create_table(truth.columns, "truth", pipeline_database) save_to_sql(df, "reconstruction", pipeline_database) save_to_sql(truth, "truth", pipeline_database) if isinstance(retro, pd.DataFrame): if i == 0: - self._create_table(pipeline_database, "retro", retro) + create_table(retro.columns, "retro", pipeline_database) save_to_sql(retro, self._retro_table_name, pipeline_database) - - # @FIXME: Duplicate. - def _create_table( - self, pipeline_database: str, table_name: str, df: pd.DataFrame - ) -> None: - """Create a table. - - Args: - pipeline_database: Path to the pipeline database. - table_name: Name of the table in pipeline database. - df: DataFrame of combined predictions. 
- """ - query_columns_list = list() - for column in df.columns: - if column == "event_no": - type_ = "INTEGER PRIMARY KEY NOT NULL" - else: - type_ = "FLOAT" - query_columns_list.append(f"{column} {type_}") - query_columns = ", ".join(query_columns_list) - - code = ( - "PRAGMA foreign_keys=off;\n" - f"CREATE TABLE {table_name} ({query_columns});\n" - "PRAGMA foreign_keys=on;" - ) - run_sql_code(pipeline_database, code) diff --git a/src/graphnet/data/sqlite/sqlite_dataconverter.py b/src/graphnet/data/sqlite/sqlite_dataconverter.py index b19acf788..7f3fdfdfb 100644 --- a/src/graphnet/data/sqlite/sqlite_dataconverter.py +++ b/src/graphnet/data/sqlite/sqlite_dataconverter.py @@ -10,7 +10,7 @@ from tqdm import tqdm from graphnet.data.dataconverter import DataConverter # type: ignore[attr-defined] -from graphnet.data.sqlite.sqlite_utilities import run_sql_code, save_to_sql +from graphnet.data.sqlite.sqlite_utilities import save_to_sql, create_table class SQLiteDataConverter(DataConverter): @@ -92,12 +92,13 @@ def merge_files( input_files, table_name ) if len(column_names) > 1: - is_pulse_map = is_pulsemap_check(table_name) - self._create_table( - output_file, - table_name, + create_table( column_names, - is_pulse_map=is_pulse_map, + table_name, + output_file, + integer_primary_key=not ( + is_pulse_map(table_name) or is_mc_tree(table_name) + ), ) # Merge temporary databases into newly created one @@ -157,60 +158,6 @@ def any_pulsemap_is_non_empty(self, data_dict: Dict[str, Dict]) -> bool: pulsemap_dicts = [data_dict[pulsemap] for pulsemap in self._pulsemaps] return any(d["dom_x"] for d in pulsemap_dicts) - def _attach_index(self, database: str, table_name: str) -> None: - """Attach the table index. - - Important for query times! - """ - code = ( - "PRAGMA foreign_keys=off;\n" - "BEGIN TRANSACTION;\n" - f"CREATE INDEX event_no_{table_name} ON {table_name} (event_no);\n" - "COMMIT TRANSACTION;\n" - "PRAGMA foreign_keys=on;" - ) - run_sql_code(database, code) - - def _create_table( - self, - database: str, - table_name: str, - columns: List[str], - is_pulse_map: bool = False, - ) -> None: - """Create a table. - - Args: - database: Path to the database. - table_name: Name of the table. - columns: The names of the columns of the table. - is_pulse_map: Whether or not this is a pulse map table. 
- """ - query_columns = list() - for column in columns: - if column == "event_no": - if not is_pulse_map: - type_ = "INTEGER PRIMARY KEY NOT NULL" - else: - type_ = "NOT NULL" - else: - type_ = "FLOAT" - query_columns.append(f"{column} {type_}") - query_columns_string = ", ".join(query_columns) - - code = ( - "PRAGMA foreign_keys=off;\n" - f"CREATE TABLE {table_name} ({query_columns_string});\n" - "PRAGMA foreign_keys=on;" - ) - run_sql_code(database, code) - - if is_pulse_map: - self.debug(table_name) - self.debug("Attaching indices") - self._attach_index(database, table_name) - return - def _submit_to_database( self, database: str, key: str, data: pd.DataFrame ) -> None: @@ -280,9 +227,11 @@ def construct_dataframe(extraction: Dict[str, Any]) -> pd.DataFrame: return out -def is_pulsemap_check(table_name: str) -> bool: - """Check whether `table_name` corresponds to a pulsemap.""" - if "pulse" in table_name.lower(): - return True - else: - return False +def is_pulse_map(table_name: str) -> bool: + """Check whether `table_name` corresponds to a pulse map.""" + return "pulse" in table_name.lower() or "series" in table_name.lower() + + +def is_mc_tree(table_name: str) -> bool: + """Check whether `table_name` corresponds to an MC tree.""" + return "I3MCTree" in table_name diff --git a/src/graphnet/data/sqlite/sqlite_utilities.py b/src/graphnet/data/sqlite/sqlite_utilities.py index 696d6ccde..0691bc3e0 100644 --- a/src/graphnet/data/sqlite/sqlite_utilities.py +++ b/src/graphnet/data/sqlite/sqlite_utilities.py @@ -1,5 +1,7 @@ """SQLite-specific utility functions for use in `graphnet.data`.""" +from typing import List + import pandas as pd import sqlalchemy import sqlite3 @@ -33,47 +35,59 @@ def save_to_sql(df: pd.DataFrame, table_name: str, database: str) -> None: engine.dispose() -def attach_index(database: str, table_name: str) -> None: - """Attaches the table index. +def attach_index( + database_path: str, table_name: str, index_column: str = "event_no" +) -> None: + """Attach the table (i.e., event) index. Important for query times! """ code = ( "PRAGMA foreign_keys=off;\n" "BEGIN TRANSACTION;\n" - f"CREATE INDEX event_no_{table_name} ON {table_name} (event_no);\n" + f"CREATE INDEX {index_column}_{table_name} " + f"ON {table_name} ({index_column});\n" "COMMIT TRANSACTION;\n" "PRAGMA foreign_keys=on;" ) - run_sql_code(database, code) + run_sql_code(database_path, code) def create_table( - df: pd.DataFrame, + columns: List[str], table_name: str, database_path: str, - is_pulse_map: bool = False, + index_column: str = "event_no", + integer_primary_key: bool = True, ) -> None: """Create a table. Args: - df: Data to be saved to table + columns: Column names to be created in table. table_name: Name of the table. database_path: Path to the database. - is_pulse_map: Whether or not this is a pulse map table. + index_column: Name of the index column. + integer_primary_key: Whether or not to create the `index_column` with + the `INTEGER PRIMARY KEY` type. Such a column is required to have + unique, integer values for each row. This is appropriate when the + table has one row per event, e.g., event-level MC truth. It is not + appropriate for pulse map series, particle-level MC truth, and + other such data that is expected to have more that one row per + event (i.e., with the same index). 
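+
+    Example:
+        Creating a pulse-map-style table without an integer primary key
+        (the database path below is illustrative)::
+
+            create_table(
+                ["event_no", "dom_x", "dom_y", "dom_z"],
+                "SRTInIcePulses",
+                "/path/to/database.db",
+                integer_primary_key=False,
+            )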
""" - query_columns = list() - for column in df.columns: - if column == "event_no": - if not is_pulse_map: - type_ = "INTEGER PRIMARY KEY NOT NULL" - else: - type_ = "NOT NULL" + # Prepare column names and types + query_columns = [] + for column in columns: + if column == index_column and integer_primary_key: + type_ = "INTEGER PRIMARY KEY NOT NULL" else: type_ = "NOT NULL" + query_columns.append(f"{column} {type_}") + query_columns_string = ", ".join(query_columns) + # Run SQL code code = ( "PRAGMA foreign_keys=off;\n" f"CREATE TABLE {table_name} ({query_columns_string});\n" @@ -83,3 +97,6 @@ def create_table( database_path, code, ) + + if not integer_primary_key: + attach_index(database_path, table_name) diff --git a/src/graphnet/data/utilities/parquet_to_sqlite.py b/src/graphnet/data/utilities/parquet_to_sqlite.py index 71961ce7a..7e62a78de 100644 --- a/src/graphnet/data/utilities/parquet_to_sqlite.py +++ b/src/graphnet/data/utilities/parquet_to_sqlite.py @@ -118,23 +118,26 @@ def _save_to_sql( df = self._convert_to_dataframe(ak_array, field_name, n_events_in_file) if field_name in self._created_tables: save_to_sql( - database_path, - field_name, df, + field_name, + database_path, ) else: if len(df) > n_events_in_file: is_pulse_map = True else: is_pulse_map = False - create_table(df, field_name, database_path, is_pulse_map) - if is_pulse_map: - attach_index(database_path, table_name=field_name) + create_table( + df.columns, + field_name, + database_path, + integer_primary_key=not is_pulse_map, + ) self._created_tables.append(field_name) save_to_sql( - database_path, - field_name, df, + field_name, + database_path, ) def _convert_to_dataframe( diff --git a/src/graphnet/pisa/fitting.py b/src/graphnet/pisa/fitting.py index 210795c9c..daf93370b 100644 --- a/src/graphnet/pisa/fitting.py +++ b/src/graphnet/pisa/fitting.py @@ -20,7 +20,7 @@ from pisa.analysis.analysis import Analysis from pisa import ureg -from graphnet.data.sqlite import run_sql_code, save_to_sql +from graphnet.data.sqlite import save_to_sql, create_table mpl.use("pdf") plt.rc("font", family="serif") @@ -157,38 +157,10 @@ def fit_weights( results = results.append(data) if add_to_database: - self._create_table(self._database_path, weight_name, results) + create_table(results.columns, weight_name, self._database_path) save_to_sql(results, weight_name, self._database_path) return results.sort_values("event_no").reset_index(drop=True) - # @TODO: Remove duplication wrt. method with same name in - # `src/graphnet/data/sqlite/sqlite_dataconverter.py` - def _create_table( - self, database: str, table_name: str, df: pd.DataFrame - ) -> None: - """Create a table. - - Args: - database: Path to the pipeline database. - table_name: Name of the table to be created. - df: DataFrame of combined predictions. 
- """ - query_columns_list = list() - for column in df.columns: - if column == "event_no": - type_ = "INTEGER PRIMARY KEY NOT NULL" - else: - type_ = "FLOAT" - query_columns_list.append(f"{column} {type_}") - query_columns = ", ".join(query_columns_list) - - code = ( - "PRAGMA foreign_keys=off;\n" - f"CREATE TABLE {table_name} ({query_columns});\n" - "PRAGMA foreign_keys=on;" - ) - run_sql_code(database, code) - def _make_config( self, config_outdir: str, diff --git a/src/graphnet/training/weight_fitting.py b/src/graphnet/training/weight_fitting.py index c6c23f353..bd7559575 100644 --- a/src/graphnet/training/weight_fitting.py +++ b/src/graphnet/training/weight_fitting.py @@ -92,7 +92,9 @@ def fit( weights = self._fit_weights(truth, **kwargs) if add_to_database: - create_table(weights, self._weight_name, self._database_path) + create_table( + weights.columns, self._weight_name, self._database_path + ) save_to_sql(weights, self._weight_name, self._database_path) return weights.sort_values(self._index_column).reset_index(drop=True) diff --git a/tests/data/test_dataconverters_and_datasets.py b/tests/data/test_dataconverters_and_datasets.py index 208494e60..e2adac9ba 100644 --- a/tests/data/test_dataconverters_and_datasets.py +++ b/tests/data/test_dataconverters_and_datasets.py @@ -23,7 +23,7 @@ SQLiteDataConverter, ) from graphnet.data.sqlite.sqlite_dataconverter import ( - is_pulsemap_check, + is_pulse_map, ) from graphnet.utilities.imports import has_icecube_package @@ -57,12 +57,12 @@ def get_file_path(backend: str) -> str: # Unit test(s) def test_is_pulsemap_check() -> None: """Test behaviour of `is_pulsemap_check`.""" - assert is_pulsemap_check("SplitInIcePulses") is True - assert is_pulsemap_check("SRTInIcePulses") is True - assert is_pulsemap_check("InIceDSTPulses") is True - assert is_pulsemap_check("RTTWOfflinePulses") is True - assert is_pulsemap_check("truth") is False - assert is_pulsemap_check("retro") is False + assert is_pulse_map("SplitInIcePulses") is True + assert is_pulse_map("SRTInIcePulses") is True + assert is_pulse_map("InIceDSTPulses") is True + assert is_pulse_map("RTTWOfflinePulses") is True + assert is_pulse_map("truth") is False + assert is_pulse_map("retro") is False @pytest.mark.order(1) From 8ac1ff89d8eb217539cf47af965cbadc4f037dd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Tue, 6 Dec 2022 14:38:18 +0100 Subject: [PATCH 02/16] Allow for specifying the non-index column types --- src/graphnet/data/sqlite/sqlite_utilities.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/graphnet/data/sqlite/sqlite_utilities.py b/src/graphnet/data/sqlite/sqlite_utilities.py index 0691bc3e0..feb9d725f 100644 --- a/src/graphnet/data/sqlite/sqlite_utilities.py +++ b/src/graphnet/data/sqlite/sqlite_utilities.py @@ -57,7 +57,9 @@ def create_table( columns: List[str], table_name: str, database_path: str, + *, index_column: str = "event_no", + default_type: str = "NOT NULL", integer_primary_key: bool = True, ) -> None: """Create a table. @@ -67,6 +69,7 @@ def create_table( table_name: Name of the table. database_path: Path to the database. index_column: Name of the index column. + default_type: The type used for all non-index columns. integer_primary_key: Whether or not to create the `index_column` with the `INTEGER PRIMARY KEY` type. Such a column is required to have unique, integer values for each row. 
This is appropriate when the @@ -78,10 +81,12 @@ def create_table( # Prepare column names and types query_columns = [] for column in columns: - if column == index_column and integer_primary_key: - type_ = "INTEGER PRIMARY KEY NOT NULL" - else: - type_ = "NOT NULL" + type_ = default_type + if column == index_column: + if integer_primary_key: + type_ = "INTEGER PRIMARY KEY NOT NULL" + else: + type_ = "NOT NULL" query_columns.append(f"{column} {type_}") From 340458291380067738d9f12616d97c34745bca61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Tue, 6 Dec 2022 14:38:41 +0100 Subject: [PATCH 03/16] Specify non-index column type --- src/graphnet/data/sqlite/sqlite_dataconverter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/graphnet/data/sqlite/sqlite_dataconverter.py b/src/graphnet/data/sqlite/sqlite_dataconverter.py index 7f3fdfdfb..30fc3ce26 100644 --- a/src/graphnet/data/sqlite/sqlite_dataconverter.py +++ b/src/graphnet/data/sqlite/sqlite_dataconverter.py @@ -96,6 +96,7 @@ def merge_files( column_names, table_name, output_file, + default_type="FLOAT", integer_primary_key=not ( is_pulse_map(table_name) or is_mc_tree(table_name) ), From d574846af7802a34894be496fd055401ed829070 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Thu, 8 Dec 2022 09:29:41 +0100 Subject: [PATCH 04/16] Remove if-statement that breaks order of arguments --- src/graphnet/utilities/config/base_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/utilities/config/base_config.py b/src/graphnet/utilities/config/base_config.py index 15f6b078b..6daecc05a 100644 --- a/src/graphnet/utilities/config/base_config.py +++ b/src/graphnet/utilities/config/base_config.py @@ -57,7 +57,7 @@ def get_all_argument_values( # Get all default argument values cfg = OrderedDict() for key, parameter in inspect.signature(fn).parameters.items(): - if key == "self" or parameter.default == inspect._empty: + if key == "self": continue cfg[key] = parameter.default From eb714262a747fd77464840cf8c80df3d30d3e034 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Thu, 8 Dec 2022 09:30:12 +0100 Subject: [PATCH 05/16] Add unit test of ParquetToSQLiteConverter --- .../data/test_dataconverters_and_datasets.py | 56 +++++++++++++------ 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/tests/data/test_dataconverters_and_datasets.py b/tests/data/test_dataconverters_and_datasets.py index e2adac9ba..3bbd4180c 100644 --- a/tests/data/test_dataconverters_and_datasets.py +++ b/tests/data/test_dataconverters_and_datasets.py @@ -2,8 +2,8 @@ import os -import numpy as np import pytest +import torch import graphnet.constants from graphnet.data.constants import FEATURES, TRUTH @@ -12,19 +12,11 @@ I3FeatureExtractorIceCube86, I3TruthExtractor, I3RetroExtractor, - I3GenericExtractor, -) -from graphnet.data.parquet import ( - ParquetDataset, - ParquetDataConverter, -) -from graphnet.data.sqlite import ( - SQLiteDataset, - SQLiteDataConverter, -) -from graphnet.data.sqlite.sqlite_dataconverter import ( - is_pulse_map, ) +from graphnet.data.parquet import ParquetDataset, ParquetDataConverter +from graphnet.data.sqlite import SQLiteDataset, SQLiteDataConverter +from graphnet.data.sqlite.sqlite_dataconverter import is_pulse_map +from graphnet.data.utilities.parquet_to_sqlite import ParquetToSQLiteConverter from graphnet.utilities.imports import has_icecube_package if has_icecube_package(): @@ -102,7 +94,7 @@ def test_dataconverter( assert 
os.path.exists(path), path -@pytest.mark.order(3) +@pytest.mark.order(2) @pytest.mark.parametrize("backend", ["sqlite", "parquet"]) def test_dataset(backend: str) -> None: """Test the implementation of `Dataset` for `backend`.""" @@ -147,7 +139,7 @@ def test_dataset(backend: str) -> None: assert len(event.features) == len(opt["features"]) -@pytest.mark.order(4) +@pytest.mark.order(3) @pytest.mark.parametrize("backend", ["sqlite", "parquet"]) def test_dataset_query_table(backend: str) -> None: """Test the implementation of `Dataset._query_table` for `backend`.""" @@ -190,5 +182,37 @@ def test_dataset_query_table(backend: str) -> None: assert results_all_subset == results_single +@pytest.mark.order(4) +def test_parquet_to_sqlite_converter() -> None: + """Test the implementation of `ParquetToSQLiteConverter`.""" + # Constructor ParquetToSQLiteConverter instance + converter = ParquetToSQLiteConverter( + parquet_path=get_file_path("parquet"), + mc_truth_table="truth", + ) + + # Perform conversion from I3 to `backend` + database_name = FILE_NAME + "_from_parquet" + converter.run(OUTPUT_DATA_DIR, database_name) + + # Check that output exists + path = f"{OUTPUT_DATA_DIR}/{database_name}/data/{database_name}.db" + assert os.path.exists(path), path + + # Check that datasets agree + opt = dict( + pulsemaps="SRTInIcePulses", + features=FEATURES.DEEPCORE, + truth=TRUTH.DEEPCORE, + ) + + dataset_from_parquet = SQLiteDataset(path, **opt) # type: ignore[arg-type] + dataset = SQLiteDataset(get_file_path("sqlite"), **opt) # type: ignore[arg-type] + + assert len(dataset_from_parquet) == len(dataset) + for ix in range(len(dataset)): + assert torch.allclose(dataset_from_parquet[ix].x, dataset[ix].x) + + if __name__ == "__main__": - test_dataset_query_table("parquet") + test_parquet_to_sqlite_converter() From 9980b09d85c9111da1a0c65f623566146096e2c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Thu, 8 Dec 2022 10:43:00 +0100 Subject: [PATCH 06/16] Add check for kwargs-like argument --- src/graphnet/utilities/config/base_config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/graphnet/utilities/config/base_config.py b/src/graphnet/utilities/config/base_config.py index 6daecc05a..19427b35b 100644 --- a/src/graphnet/utilities/config/base_config.py +++ b/src/graphnet/utilities/config/base_config.py @@ -57,7 +57,9 @@ def get_all_argument_values( # Get all default argument values cfg = OrderedDict() for key, parameter in inspect.signature(fn).parameters.items(): - if key == "self": + if key == "self" or ( + key in ["kwargs", "kwds"] and parameter.default == inspect._empty + ): continue cfg[key] = parameter.default From 482b97fc3d5fd5c3c809b8df1d9bd9746666df70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Thu, 8 Dec 2022 10:57:59 +0100 Subject: [PATCH 07/16] More robust checking --- src/graphnet/utilities/config/base_config.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/graphnet/utilities/config/base_config.py b/src/graphnet/utilities/config/base_config.py index 19427b35b..658a816dc 100644 --- a/src/graphnet/utilities/config/base_config.py +++ b/src/graphnet/utilities/config/base_config.py @@ -56,12 +56,14 @@ def get_all_argument_values( """Return dict of all argument values to `fn`, including defaults.""" # Get all default argument values cfg = OrderedDict() - for key, parameter in inspect.signature(fn).parameters.items(): - if key == "self" or ( - key in ["kwargs", "kwds"] and parameter.default == 
inspect._empty - ): + for key, param in inspect.signature(fn).parameters.items(): + # Don't save `self`, `*args`, or `**kwargs` + if key == "self" or param.kind in [ + param.VAR_POSITIONAL, + param.VAR_KEYWORD, + ]: continue - cfg[key] = parameter.default + cfg[key] = param.default # Add positional arguments for key, val in zip(cfg.keys(), args): From cd5a7aecb141c78c0cae409a544d946f91b6e3b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Thu, 8 Dec 2022 13:12:23 +0100 Subject: [PATCH 08/16] Add create_table_and_save_to_sql utility function for simplifying optimised table creation --- src/graphnet/data/sqlite/sqlite_utilities.py | 58 ++++++++++++++++++-- 1 file changed, 52 insertions(+), 6 deletions(-) diff --git a/src/graphnet/data/sqlite/sqlite_utilities.py b/src/graphnet/data/sqlite/sqlite_utilities.py index feb9d725f..997023752 100644 --- a/src/graphnet/data/sqlite/sqlite_utilities.py +++ b/src/graphnet/data/sqlite/sqlite_utilities.py @@ -1,5 +1,6 @@ """SQLite-specific utility functions for use in `graphnet.data`.""" +import os.path from typing import List import pandas as pd @@ -7,20 +8,38 @@ import sqlite3 -def run_sql_code(database: str, code: str) -> None: +def database_exists(database_path: str) -> bool: + """Check whether database exists at `database_path`.""" + assert database_path.endswith( + ".db" + ), "Provided database path does not end in `.db`." + return os.path.exists(database_path) + + +def database_table_exists(database_path: str, table_name: str) -> bool: + """Check whether `table_name` exists in database at `database_path`.""" + if not database_exists(database_path): + return False + query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';" + with sqlite3.connect(database_path) as conn: + result = pd.read_sql(query, conn) + return len(result) == 1 + + +def run_sql_code(database_path: str, code: str) -> None: """Execute SQLite code. Args: - database: Path to databases + database_path: Path to databases code: SQLite code """ - conn = sqlite3.connect(database) + conn = sqlite3.connect(database_path) c = conn.cursor() c.executescript(code) c.close() -def save_to_sql(df: pd.DataFrame, table_name: str, database: str) -> None: +def save_to_sql(df: pd.DataFrame, table_name: str, database_path: str) -> None: """Save a dataframe `df` to a table `table_name` in SQLite `database`. Table must exist already. @@ -28,9 +47,9 @@ def save_to_sql(df: pd.DataFrame, table_name: str, database: str) -> None: Args: df: Dataframe with data to be stored in sqlite table table_name: Name of table. Must exist already - database: Path to SQLite database + database_path: Path to SQLite database """ - engine = sqlalchemy.create_engine("sqlite:///" + database) + engine = sqlalchemy.create_engine("sqlite:///" + database_path) df.to_sql(table_name, con=engine, index=False, if_exists="append") engine.dispose() @@ -78,6 +97,9 @@ def create_table( other such data that is expected to have more that one row per event (i.e., with the same index). """ + print( + f"!! {table_name} in {database_path} has integer_primary_key = {integer_primary_key}" + ) # Prepare column names and types query_columns = [] for column in columns: @@ -103,5 +125,29 @@ def create_table( code, ) + # Attaching index to all non-truth-like tables (e.g., pulse maps). if not integer_primary_key: + print(f"!! 
Attaching index for {table_name} in {database_path}") attach_index(database_path, table_name) + + +def create_table_and_save_to_sql( + df: pd.DataFrame, + table_name: str, + database_path: str, + *, + index_column: str = "event_no", + default_type: str = "NOT NULL", + integer_primary_key: bool = True, +) -> None: + """Create table if it doesn't exist and save dataframe to it.""" + if not database_table_exists(database_path, table_name): + create_table( + df.columns, + table_name, + database_path, + index_column=index_column, + default_type=default_type, + integer_primary_key=integer_primary_key, + ) + save_to_sql(df, table_name=table_name, database_path=database_path) From 78d367613508d8ca82735df6c23eb0a23fbfad54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Thu, 8 Dec 2022 13:12:51 +0100 Subject: [PATCH 09/16] Switch to create_table_and_save_to_sql --- src/graphnet/data/pipeline.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/src/graphnet/data/pipeline.py b/src/graphnet/data/pipeline.py index 296e66041..946570785 100644 --- a/src/graphnet/data/pipeline.py +++ b/src/graphnet/data/pipeline.py @@ -13,7 +13,7 @@ import torch from torch.utils.data import DataLoader -from graphnet.data.sqlite.sqlite_utilities import save_to_sql, create_table +from graphnet.data.sqlite.sqlite_utilities import create_table_and_save_to_sql from graphnet.training.utils import get_predictions, make_dataloader from graphnet.utilities.logging import get_logger @@ -97,7 +97,7 @@ def __call__( df = self._inference(device, dataloader) truth = self._get_truth(database, event_batches[i].tolist()) retro = self._get_retro(database, event_batches[i].tolist()) - self._append_to_pipeline(outdir, truth, retro, df, i) + self._append_to_pipeline(outdir, truth, retro, df) i += 1 else: logger.info(outdir) @@ -210,17 +210,12 @@ def _append_to_pipeline( truth: pd.DataFrame, retro: pd.DataFrame, df: pd.DataFrame, - i: int, ) -> None: os.makedirs(outdir, exist_ok=True) pipeline_database = outdir + "/%s.db" % self._pipeline_name - if i == 0: - # Only setup table schemes if its the first time appending - create_table(df.columns, "reconstruction", pipeline_database) - create_table(truth.columns, "truth", pipeline_database) - save_to_sql(df, "reconstruction", pipeline_database) - save_to_sql(truth, "truth", pipeline_database) + create_table_and_save_to_sql(df, "reconstruction", pipeline_database) + create_table_and_save_to_sql(truth, "truth", pipeline_database) if isinstance(retro, pd.DataFrame): - if i == 0: - create_table(retro.columns, "retro", pipeline_database) - save_to_sql(retro, self._retro_table_name, pipeline_database) + create_table_and_save_to_sql( + retro, self._retro_table_name, pipeline_database + ) From 699b1e0cac99158e6d20ecb0e1cd79f2f02c5045 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Thu, 8 Dec 2022 13:14:12 +0100 Subject: [PATCH 10/16] Switch to create_table_and_save_to_sql --- src/graphnet/data/sqlite/__init__.py | 7 +++- .../data/sqlite/sqlite_dataconverter.py | 15 ++++++- .../data/utilities/parquet_to_sqlite.py | 40 ++++++------------- src/graphnet/training/weight_fitting.py | 16 ++++---- 4 files changed, 38 insertions(+), 40 deletions(-) diff --git a/src/graphnet/data/sqlite/__init__.py b/src/graphnet/data/sqlite/__init__.py index d914a2c11..d9201e6e3 100644 --- a/src/graphnet/data/sqlite/__init__.py +++ b/src/graphnet/data/sqlite/__init__.py @@ -3,7 +3,12 @@ from graphnet.utilities.imports import has_torch_package from 
.sqlite_dataconverter import SQLiteDataConverter -from .sqlite_utilities import run_sql_code, save_to_sql, create_table +from .sqlite_utilities import ( + run_sql_code, + save_to_sql, + create_table, + create_table_and_save_to_sql, +) if has_torch_package(): from .sqlite_dataset import SQLiteDataset diff --git a/src/graphnet/data/sqlite/sqlite_dataconverter.py b/src/graphnet/data/sqlite/sqlite_dataconverter.py index 30fc3ce26..b520ee583 100644 --- a/src/graphnet/data/sqlite/sqlite_dataconverter.py +++ b/src/graphnet/data/sqlite/sqlite_dataconverter.py @@ -10,7 +10,10 @@ from tqdm import tqdm from graphnet.data.dataconverter import DataConverter # type: ignore[attr-defined] -from graphnet.data.sqlite.sqlite_utilities import save_to_sql, create_table +from graphnet.data.sqlite.sqlite_utilities import ( + create_table, + create_table_and_save_to_sql, +) class SQLiteDataConverter(DataConverter): @@ -51,7 +54,15 @@ def save_data(self, data: List[OrderedDict], output_file: str) -> None: saved_any = False for table, df in dataframe.items(): if len(df) > 0: - save_to_sql(df, table, output_file) + create_table_and_save_to_sql( + df, + table, + output_file, + default_type="FLOAT", + integer_primary_key=not ( + is_pulse_map(table) or is_mc_tree(table) + ), + ) saved_any = True if saved_any: diff --git a/src/graphnet/data/utilities/parquet_to_sqlite.py b/src/graphnet/data/utilities/parquet_to_sqlite.py index 7e62a78de..139acfe91 100644 --- a/src/graphnet/data/utilities/parquet_to_sqlite.py +++ b/src/graphnet/data/utilities/parquet_to_sqlite.py @@ -9,11 +9,7 @@ import pandas as pd from tqdm.auto import trange -from graphnet.data.sqlite.sqlite_utilities import ( - save_to_sql, - attach_index, - create_table, -) +from graphnet.data.sqlite.sqlite_utilities import create_table_and_save_to_sql from graphnet.utilities.logging import LoggerMixin @@ -54,7 +50,6 @@ def __init__( self._excluded_fields = [] self._mc_truth_table = mc_truth_table self._event_counter = 0 - self._created_tables: List[str] = [] def _find_parquet_files(self, paths: Union[str, List[str]]) -> List[str]: if isinstance(paths, str): @@ -116,29 +111,18 @@ def _save_to_sql( n_events_in_file: int, ) -> None: df = self._convert_to_dataframe(ak_array, field_name, n_events_in_file) - if field_name in self._created_tables: - save_to_sql( - df, - field_name, - database_path, - ) + + if len(df) > n_events_in_file: + is_pulse_map = True else: - if len(df) > n_events_in_file: - is_pulse_map = True - else: - is_pulse_map = False - create_table( - df.columns, - field_name, - database_path, - integer_primary_key=not is_pulse_map, - ) - self._created_tables.append(field_name) - save_to_sql( - df, - field_name, - database_path, - ) + is_pulse_map = False + + create_table_and_save_to_sql( + df, + field_name, + database_path, + integer_primary_key=not is_pulse_map, + ) def _convert_to_dataframe( self, diff --git a/src/graphnet/training/weight_fitting.py b/src/graphnet/training/weight_fitting.py index bd7559575..404f740d7 100644 --- a/src/graphnet/training/weight_fitting.py +++ b/src/graphnet/training/weight_fitting.py @@ -1,14 +1,13 @@ """Classes for fitting per-event weights for training.""" +from abc import ABC, abstractmethod +from typing import Any, Optional, List, Callable + import numpy as np import pandas as pd import sqlite3 -from typing import Any, Optional, List, Callable -from graphnet.data.sqlite.sqlite_utilities import ( - save_to_sql, - create_table, -) -from abc import ABC, abstractmethod + +from graphnet.data.sqlite.sqlite_utilities import 
create_table_and_save_to_sql from graphnet.utilities.logging import LoggerMixin @@ -92,10 +91,9 @@ def fit( weights = self._fit_weights(truth, **kwargs) if add_to_database: - create_table( - weights.columns, self._weight_name, self._database_path + create_table_and_save_to_sql( + weights, self._weight_name, self._database_path ) - save_to_sql(weights, self._weight_name, self._database_path) return weights.sort_values(self._index_column).reset_index(drop=True) @abstractmethod From c63d742edec2fbe27aaef6e227480486d8dd2471 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Thu, 8 Dec 2022 13:18:31 +0100 Subject: [PATCH 11/16] Clean-up --- src/graphnet/data/sqlite/sqlite_utilities.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/graphnet/data/sqlite/sqlite_utilities.py b/src/graphnet/data/sqlite/sqlite_utilities.py index 997023752..a0e88c0ab 100644 --- a/src/graphnet/data/sqlite/sqlite_utilities.py +++ b/src/graphnet/data/sqlite/sqlite_utilities.py @@ -97,9 +97,6 @@ def create_table( other such data that is expected to have more that one row per event (i.e., with the same index). """ - print( - f"!! {table_name} in {database_path} has integer_primary_key = {integer_primary_key}" - ) # Prepare column names and types query_columns = [] for column in columns: @@ -127,7 +124,6 @@ def create_table( # Attaching index to all non-truth-like tables (e.g., pulse maps). if not integer_primary_key: - print(f"!! Attaching index for {table_name} in {database_path}") attach_index(database_path, table_name) From 929f9c6e4f5aff298ccec7c38ff659d5805243d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Thu, 8 Dec 2022 13:18:54 +0100 Subject: [PATCH 12/16] Standardise progress bar --- src/graphnet/data/utilities/parquet_to_sqlite.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/graphnet/data/utilities/parquet_to_sqlite.py b/src/graphnet/data/utilities/parquet_to_sqlite.py index 139acfe91..1e3d5df3f 100644 --- a/src/graphnet/data/utilities/parquet_to_sqlite.py +++ b/src/graphnet/data/utilities/parquet_to_sqlite.py @@ -75,8 +75,12 @@ def run(self, outdir: str, database_name: str) -> None: database_path = os.path.join( outdir, database_name, "data", database_name + ".db" ) + self.info(f"Processing {len(self._parquet_files)} Parquet file(s)") for i in trange( - len(self._parquet_files), desc="Main", colour="#0000ff", position=0 + len(self._parquet_files), + unit="file(s)", + colour="green", + position=0, ): parquet_file = ak.from_parquet(self._parquet_files[i]) n_events_in_file = self._count_events(parquet_file) From e02452bd864a0d6c801e37402eb6c631bfcc5a4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Thu, 8 Dec 2022 13:20:07 +0100 Subject: [PATCH 13/16] Add unit test for SQLite table query plans --- .../data/test_dataconverters_and_datasets.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/data/test_dataconverters_and_datasets.py b/tests/data/test_dataconverters_and_datasets.py index 3bbd4180c..dc45f6655 100644 --- a/tests/data/test_dataconverters_and_datasets.py +++ b/tests/data/test_dataconverters_and_datasets.py @@ -2,7 +2,9 @@ import os +import pandas as pd import pytest +import sqlite3 import torch import graphnet.constants @@ -214,5 +216,35 @@ def test_parquet_to_sqlite_converter() -> None: assert torch.allclose(dataset_from_parquet[ix].x, dataset[ix].x) +@pytest.mark.order(5) +@pytest.mark.parametrize("pulsemap", ["SRTInIcePulses"]) +@pytest.mark.parametrize("event_no", 
[1]) +def test_database_query_plan(pulsemap: str, event_no: int) -> None: + """Test query plan agreement in original and parquet-converted database.""" + # Configure paths to databases to compare + database_name = FILE_NAME + "_from_parquet" + parquet_converted_database = ( + f"{OUTPUT_DATA_DIR}/{database_name}/data/{database_name}.db" + ) + sqlite_database = get_file_path("sqlite") + + # Get query plans + query = f"EXPLAIN QUERY PLAN SELECT * FROM {pulsemap} WHERE event_no={event_no}" + with sqlite3.connect(sqlite_database) as conn: + sqlite_plan = pd.read_sql(query, conn) + + with sqlite3.connect(parquet_converted_database) as conn: + parquet_plan = pd.read_sql(query, conn) + + # Compare + assert "USING INDEX event_no" in sqlite_plan["detail"].iloc[0] + assert "USING INDEX event_no" in parquet_plan["detail"].iloc[0] + + assert (sqlite_plan["detail"] == parquet_plan["detail"]).all() + + if __name__ == "__main__": + test_dataconverter("sqlite") + test_dataconverter("parquet") test_parquet_to_sqlite_converter() + test_database_query_plan("SRTInIcePulses", 1) From 98c5066714639c6403e43cb974b435794b84bf5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Thu, 8 Dec 2022 13:22:15 +0100 Subject: [PATCH 14/16] Wrap line --- tests/data/test_dataconverters_and_datasets.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/data/test_dataconverters_and_datasets.py b/tests/data/test_dataconverters_and_datasets.py index dc45f6655..d505f41d8 100644 --- a/tests/data/test_dataconverters_and_datasets.py +++ b/tests/data/test_dataconverters_and_datasets.py @@ -229,7 +229,11 @@ def test_database_query_plan(pulsemap: str, event_no: int) -> None: sqlite_database = get_file_path("sqlite") # Get query plans - query = f"EXPLAIN QUERY PLAN SELECT * FROM {pulsemap} WHERE event_no={event_no}" + query = f""" + EXPLAIN QUERY PLAN + SELECT * FROM {pulsemap} + WHERE event_no={event_no} + """ with sqlite3.connect(sqlite_database) as conn: sqlite_plan = pd.read_sql(query, conn) From dbbf2b5b8a62fe8facfcc8fc3a2930b305222293 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Thu, 8 Dec 2022 13:27:15 +0100 Subject: [PATCH 15/16] Wrap lines --- .../data/utilities/parquet_to_sqlite.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/graphnet/data/utilities/parquet_to_sqlite.py b/src/graphnet/data/utilities/parquet_to_sqlite.py index 1e3d5df3f..f16275846 100644 --- a/src/graphnet/data/utilities/parquet_to_sqlite.py +++ b/src/graphnet/data/utilities/parquet_to_sqlite.py @@ -100,8 +100,9 @@ def run(self, outdir: str, database_name: str) -> None: ) self._event_counter += n_events_in_file self._save_config(outdir, database_name) - print( - f"Database saved at: \n{outdir}/{database_name}/data/{database_name}.db" + self.info( + "Database saved at: \n" + f"{outdir}/{database_name}/data/{database_name}.db" ) def _count_events(self, open_parquet_file: ak.Array) -> int: @@ -138,9 +139,10 @@ def _convert_to_dataframe( if len(df.columns) == 1: if df.columns == ["values"]: df.columns = [field_name] - if ( - len(df) != n_events_in_file - ): # if true, the dataframe contains more than 1 row pr. event (e.g. Pulsemap). + + # If true, the dataframe contains more than 1 row pr. event (i.e., + # pulsemap). 
+ if len(df) != n_events_in_file: event_nos = [] c = 0 for event_no in range( @@ -150,7 +152,10 @@ def _convert_to_dataframe( event_nos.extend( np.repeat(event_no, len(df[df.columns[0]][c])).tolist() ) - except KeyError: # KeyError indicates that this df has no entry for event_no (e.g. an event with no detector response) + + # KeyError indicates that this df has no entry for event_no + # (e.g., an event with no detector response). + except KeyError: pass c += 1 else: From 6aef937e0acc6923d1d13ba0c925b8711fdd4014 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Thu, 8 Dec 2022 13:33:32 +0100 Subject: [PATCH 16/16] Switch to create_table_and_save_to_sql --- src/graphnet/data/sqlite/__init__.py | 7 +------ src/graphnet/pisa/fitting.py | 7 ++++--- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/graphnet/data/sqlite/__init__.py b/src/graphnet/data/sqlite/__init__.py index d9201e6e3..db9d66616 100644 --- a/src/graphnet/data/sqlite/__init__.py +++ b/src/graphnet/data/sqlite/__init__.py @@ -3,12 +3,7 @@ from graphnet.utilities.imports import has_torch_package from .sqlite_dataconverter import SQLiteDataConverter -from .sqlite_utilities import ( - run_sql_code, - save_to_sql, - create_table, - create_table_and_save_to_sql, -) +from .sqlite_utilities import create_table_and_save_to_sql if has_torch_package(): from .sqlite_dataset import SQLiteDataset diff --git a/src/graphnet/pisa/fitting.py b/src/graphnet/pisa/fitting.py index daf93370b..ae3bc41c9 100644 --- a/src/graphnet/pisa/fitting.py +++ b/src/graphnet/pisa/fitting.py @@ -20,7 +20,7 @@ from pisa.analysis.analysis import Analysis from pisa import ureg -from graphnet.data.sqlite import save_to_sql, create_table +from graphnet.data.sqlite import create_table_and_save_to_sql mpl.use("pdf") plt.rc("font", family="serif") @@ -157,8 +157,9 @@ def fit_weights( results = results.append(data) if add_to_database: - create_table(results.columns, weight_name, self._database_path) - save_to_sql(results, weight_name, self._database_path) + create_table_and_save_to_sql( + results.columns, weight_name, self._database_path + ) return results.sort_values("event_no").reset_index(drop=True) def _make_config(
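
Example usage of the consolidated utility introduced above (a minimal sketch;
the database path and the truth-level columns other than `event_no` are
illustrative):

    import pandas as pd

    from graphnet.data.sqlite import create_table_and_save_to_sql

    # Hypothetical event-level data with one row per `event_no`.
    truth = pd.DataFrame(
        {
            "event_no": [0, 1, 2],
            "energy": [1.2, 3.4, 5.6],
            "zenith": [0.1, 0.2, 0.3],
        }
    )

    # Creates the `truth` table on the first call (with `event_no` as
    # INTEGER PRIMARY KEY) and appends to the existing table on later calls.
    create_table_and_save_to_sql(
        truth,
        "truth",
        "/path/to/example.db",
        default_type="FLOAT",
    )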