diff --git a/examples/parquet_to_sqlite.py b/examples/parquet_to_sqlite.py
new file mode 100644
index 000000000..fff623050
--- /dev/null
+++ b/examples/parquet_to_sqlite.py
@@ -0,0 +1,14 @@
+from graphnet.data.utilities.parquet_to_sqlite import ParquetToSQLiteConverter
+
+if __name__ == "__main__":
+    # Path to a Parquet file or a directory containing Parquet files.
+    parquet_path = "/my_file.parquet"
+    # Directory in which the database will be stored.
+    outdir = "/home/my_databases/"
+    # Name of the database; it will be saved to outdir/database_name/data/database_name.db.
+    database_name = "my_database_from_parquet"
+
+    converter = ParquetToSQLiteConverter(
+        mc_truth_table="mc_truth", parquet_path=parquet_path
+    )
+    converter.run(outdir=outdir, database_name=database_name)
diff --git a/src/graphnet/data/sqlite/sqlite_utilities.py b/src/graphnet/data/sqlite/sqlite_utilities.py
index 0823a9a64..06547b91e 100644
--- a/src/graphnet/data/sqlite/sqlite_utilities.py
+++ b/src/graphnet/data/sqlite/sqlite_utilities.py
@@ -31,21 +31,42 @@ def save_to_sql(df: pd.DataFrame, table_name: str, database: str):
     engine = sqlalchemy.create_engine("sqlite:///" + database)
     df.to_sql(table_name, con=engine, index=False, if_exists="append")
     engine.dispose()
+
+
+def attach_index(database: str, table_name: str):
+    """Attaches an index on the event_no column. Important for query times!"""
+    code = (
+        "PRAGMA foreign_keys=off;\n"
+        "BEGIN TRANSACTION;\n"
+        f"CREATE INDEX event_no_{table_name} ON {table_name} (event_no);\n"
+        "COMMIT TRANSACTION;\n"
+        "PRAGMA foreign_keys=on;"
+    )
+    run_sql_code(database, code)
     return
 
 
-def create_table(database, table_name, df):
+def create_table(
+    df: pd.DataFrame,
+    table_name: str,
+    database_path: str,
+    is_pulse_map: bool = False,
+):
     """Creates a table.
 
     Args:
-        pipeline_database (str): path to the pipeline database
-        df (str): pandas.DataFrame of combined predictions
+        df (pd.DataFrame): data whose columns define the table schema
+        table_name (str): name of the table
+        database_path (str): path to the database
+        is_pulse_map (bool, optional): whether this is a pulse map table. Defaults to False.
     """
     query_columns = list()
     for column in df.columns:
         if column == "event_no":
-            type_ = "INTEGER PRIMARY KEY NOT NULL"
-        else:
-            type_ = "FLOAT"
+            if not is_pulse_map:
+                # One row per event: event_no can act as the primary key.
+                type_ = "INTEGER PRIMARY KEY NOT NULL"
+            else:
+                # Several rows per event: event_no repeats and cannot be a
+                # primary key, so it is indexed separately via attach_index.
+                type_ = "NOT NULL"
+        else:
+            type_ = "FLOAT"
         query_columns.append(f"{column} {type_}")
     query_columns = ", ".join(query_columns)
@@ -54,5 +75,9 @@ def create_table(database, table_name, df):
         f"CREATE TABLE {table_name} ({query_columns});\n"
         "PRAGMA foreign_keys=on;"
     )
-    run_sql_code(database, code)
+    run_sql_code(database_path, code)
+    return
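For context, here is a minimal usage sketch of the two helpers above (not part of the patch). It assumes graphnet is installed and that /tmp is writable; the table name "pulses" and the column "dom_x" are made up for illustration:

    import pandas as pd

    from graphnet.data.sqlite.sqlite_utilities import (
        attach_index,
        create_table,
        save_to_sql,
    )

    database = "/tmp/example.db"

    # Two rows for event 0 and one row for event 1: more rows than events,
    # i.e. a pulse-map-like table.
    df = pd.DataFrame({"event_no": [0, 0, 1], "dom_x": [1.0, 2.0, 3.0]})

    # With is_pulse_map=True, event_no is declared NOT NULL rather than
    # INTEGER PRIMARY KEY (it repeats, so it cannot be a primary key);
    # attach_index then adds the (event_no) index for fast per-event queries.
    create_table(df, "pulses", database, is_pulse_map=True)
    attach_index(database, table_name="pulses")
    save_to_sql(df, "pulses", database)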
diff --git a/src/graphnet/data/utilities/parquet_to_sqlite.py b/src/graphnet/data/utilities/parquet_to_sqlite.py
new file mode 100644
index 000000000..1b3dbc8dc
--- /dev/null
+++ b/src/graphnet/data/utilities/parquet_to_sqlite.py
@@ -0,0 +1,168 @@
+import glob
+import os
+from typing import List, Optional, Union
+
+import awkward as ak
+import numpy as np
+import pandas as pd
+from tqdm.auto import trange
+
+from graphnet.data.sqlite.sqlite_utilities import (
+    attach_index,
+    create_table,
+    save_to_sql,
+)
+from graphnet.utilities.logging import LoggerMixin
+
+
+class ParquetToSQLiteConverter(LoggerMixin):
+    """Converts Parquet files to a SQLite database.
+
+    Each event in the Parquet file(s) is assigned a unique event id. By
+    default, every field in the Parquet file(s) is extracted. Certain fields
+    can be excluded via the argument `excluded_fields`.
+    """
+
+    def __init__(
+        self,
+        parquet_path: Union[str, List[str]],
+        mc_truth_table: str = "mc_truth",
+        excluded_fields: Optional[Union[str, List[str]]] = None,
+    ):
+        # Checks
+        if isinstance(parquet_path, list):
+            assert all(
+                isinstance(path, str) for path in parquet_path
+            ), "Argument `parquet_path` must be a string or list of strings"
+        else:
+            assert isinstance(
+                parquet_path, str
+            ), "Argument `parquet_path` must be a string or list of strings"
+
+        assert isinstance(
+            mc_truth_table, str
+        ), "Argument `mc_truth_table` must be a string"
+
+        self._parquet_files = self._find_parquet_files(parquet_path)
+        if excluded_fields is not None:
+            self._excluded_fields = excluded_fields
+        else:
+            self._excluded_fields = []
+        self._mc_truth_table = mc_truth_table
+        # Running offset that keeps event ids unique across input files.
+        self._event_counter = 0
+        self._created_tables = []
+
+    def _find_parquet_files(self, paths: Union[str, List[str]]) -> List[str]:
+        if isinstance(paths, str):
+            if paths.endswith(".parquet"):
+                files = [paths]
+            else:
+                files = glob.glob(f"{paths}/*.parquet")
+        elif isinstance(paths, list):
+            files = []
+            for path in paths:
+                files.extend(self._find_parquet_files(path))
+        assert len(files) > 0, f"No files found in {paths}"
+        return files
+
+    def run(self, outdir: str, database_name: str):
+        self._create_output_directories(outdir, database_name)
+        database_path = os.path.join(
+            outdir, database_name, "data", database_name + ".db"
+        )
+        for i in trange(
+            len(self._parquet_files), desc="Main", colour="#0000ff", position=0
+        ):
+            parquet_file = ak.from_parquet(self._parquet_files[i])
+            n_events_in_file = self._count_events(parquet_file)
+            for j in trange(
+                len(parquet_file.fields),
+                desc=os.path.basename(self._parquet_files[i]),
+                colour="#ffa500",
+                position=1,
+                leave=False,
+            ):
+                if parquet_file.fields[j] not in self._excluded_fields:
+                    self._save_to_sql(
+                        database_path,
+                        parquet_file,
+                        parquet_file.fields[j],
+                        n_events_in_file,
+                    )
+            self._event_counter += n_events_in_file
+        self._save_config(outdir, database_name)
+        print(f"Database saved at:\n{database_path}")
+
+    def _count_events(self, open_parquet_file: ak.Array) -> int:
+        return len(open_parquet_file[self._mc_truth_table])
+
+    def _save_to_sql(
+        self,
+        database_path: str,
+        ak_array: ak.Array,
+        field_name: str,
+        n_events_in_file: int,
+    ):
+        df = self._convert_to_dataframe(ak_array, field_name, n_events_in_file)
+        if field_name not in self._created_tables:
+            # More rows than events means several rows per event, i.e. a
+            # pulse map.
+            is_pulse_map = len(df) > n_events_in_file
+            create_table(df, field_name, database_path, is_pulse_map)
+            if is_pulse_map:
+                attach_index(database_path, table_name=field_name)
+            self._created_tables.append(field_name)
+        save_to_sql(df, field_name, database_path)
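+    # Note on the conversion below: ak.to_pandas flattens the requested
+    # field to one row per value. Truth-like fields give exactly one row
+    # per event, while pulse-map-like fields give one row per pulse, so
+    # the event id must be repeated for every row belonging to the same
+    # event. Offsetting by self._event_counter keeps event ids unique
+    # across files.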
+    def _convert_to_dataframe(
+        self,
+        ak_array: ak.Array,
+        field_name: str,
+        n_events_in_file: int,
+    ) -> pd.DataFrame:
+        df = pd.DataFrame(ak.to_pandas(ak_array[field_name]))
+        if len(df.columns) == 1:
+            if df.columns == ["values"]:
+                df.columns = [field_name]
+        if len(df) != n_events_in_file:
+            # The dataframe contains more than one row per event (e.g. a
+            # pulse map).
+            event_nos = []
+            c = 0
+            for event_no in range(
+                self._event_counter, self._event_counter + n_events_in_file, 1
+            ):
+                try:
+                    event_nos.extend(
+                        np.repeat(event_no, len(df[df.columns[0]][c])).tolist()
+                    )
+                except KeyError:
+                    # A KeyError means this event has no entry in the field
+                    # (e.g. an event with no detector response).
+                    pass
+                c += 1
+        else:
+            event_nos = np.arange(0, n_events_in_file, 1) + self._event_counter
+        df["event_no"] = event_nos
+        return df
+
+    def _create_output_directories(self, outdir: str, database_name: str):
+        os.makedirs(
+            os.path.join(outdir, database_name, "data"), exist_ok=True
+        )
+        os.makedirs(
+            os.path.join(outdir, database_name, "config"), exist_ok=True
+        )
+
+    def _save_config(self, outdir: str, database_name: str):
+        """Saves the list of converted Parquet files to a CSV file."""
+        df = pd.DataFrame(data=self._parquet_files, columns=["files"])
+        df.to_csv(
+            os.path.join(outdir, database_name, "config", "files.csv")
+        )
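As a quick sanity check after conversion, the resulting database can be queried per event; the event_no index attached to pulse-map tables is what keeps such queries fast. A minimal sketch, reusing the paths from the example script above ("mc_truth" is the default truth table; "my_pulse_map" is a placeholder, since the actual table names mirror the fields of the input Parquet file(s)):

    import sqlite3

    import pandas as pd

    database = "/home/my_databases/my_database_from_parquet/data/my_database_from_parquet.db"

    with sqlite3.connect(database) as conn:
        # One row for the event in the truth table ...
        truth = pd.read_sql("SELECT * FROM mc_truth WHERE event_no = 0", conn)
        # ... and all rows belonging to that event in a pulse-map table.
        pulses = pd.read_sql("SELECT * FROM my_pulse_map WHERE event_no = 0", conn)

    print(truth)
    print(pulses)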