Merge pull request #277 from RasmusOrsoe/parquet_to_sqlite_converter
Parquet to sqlite converter
RasmusOrsoe authored Oct 6, 2022
2 parents d109faf + 0849615 commit 30f338c
Showing 3 changed files with 214 additions and 7 deletions.
14 changes: 14 additions & 0 deletions examples/parquet_to_sqlite.py
@@ -0,0 +1,14 @@
from graphnet.data.utilities.parquet_to_sqlite import ParquetToSQLiteConverter

if __name__ == "__main__":
    # Path to a parquet file or a directory containing parquet files
    parquet_path = "/my_file.parquet"
    # Path to where you want the database to be stored
    outdir = "/home/my_databases/"
    # Name of the database. Will be saved in outdir/database_name/data/database_name.db
    database_name = "my_database_from_parquet"

    converter = ParquetToSQLiteConverter(
        mc_truth_table="mc_truth", parquet_path=parquet_path
    )
    converter.run(outdir=outdir, database_name=database_name)
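
To sanity-check a conversion, the resulting database can be opened directly with sqlite3 and pandas. A minimal sketch, assuming the paths from the example above (the exact set of tables depends on the fields in the Parquet file):

import sqlite3

import pandas as pd

# Hypothetical path following the pattern outdir/database_name/data/database_name.db
db_path = "/home/my_databases/my_database_from_parquet/data/my_database_from_parquet.db"

with sqlite3.connect(db_path) as con:
    # List the tables the converter created (one per extracted field)
    print(pd.read_sql("SELECT name FROM sqlite_master WHERE type = 'table'", con))
    # Peek at the truth table; the name matches the mc_truth_table argument
    print(pd.read_sql("SELECT * FROM mc_truth LIMIT 5", con))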
39 changes: 32 additions & 7 deletions src/graphnet/data/sqlite/sqlite_utilities.py
@@ -31,21 +31,42 @@ def save_to_sql(df: pd.DataFrame, table_name: str, database: str):
    engine = sqlalchemy.create_engine("sqlite:///" + database)
    df.to_sql(table_name, con=engine, index=False, if_exists="append")
    engine.dispose()


def attach_index(database: str, table_name: str):
    """Attaches the table index. Important for query times!"""
    code = (
        "PRAGMA foreign_keys=off;\n"
        "BEGIN TRANSACTION;\n"
        f"CREATE INDEX event_no_{table_name} ON {table_name} (event_no);\n"
        "COMMIT TRANSACTION;\n"
        "PRAGMA foreign_keys=on;"
    )
    run_sql_code(database, code)
    return


def create_table(
    df: pd.DataFrame,
    table_name: str,
    database_path: str,
    is_pulse_map: bool = False,
):
"""Creates a table.
Args:
pipeline_database (str): path to the pipeline database
df (str): pandas.DataFrame of combined predictions
database (str): path to the database
table_name (str): name of the table
columns (str): the names of the columns of the table
is_pulse_map (bool, optional): whether or not this is a pulse map table. Defaults to False.
"""
    # event_no is the primary key for one-row-per-event tables; pulse maps
    # repeat event_no across rows, so there it is only constrained NOT NULL.
    query_columns = list()
    for column in df.columns:
        if column == "event_no":
            if not is_pulse_map:
                type_ = "INTEGER PRIMARY KEY NOT NULL"
            else:
                type_ = "NOT NULL"
        else:
            type_ = "FLOAT"
        query_columns.append(f"{column} {type_}")
    query_columns = ", ".join(query_columns)

@@ -54,5 +75,9 @@ def create_table(database, table_name, df):
f"CREATE TABLE {table_name} ({query_columns});\n"
"PRAGMA foreign_keys=on;"
)
run_sql_code(database, code)
run_sql_code(
database_path,
code,
)

return
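
Taken together, a minimal sketch of how these utilities combine on a toy pulse-map-like DataFrame (the table and column names here are purely illustrative):

import pandas as pd

from graphnet.data.sqlite.sqlite_utilities import (
    attach_index,
    create_table,
    save_to_sql,
)

# Two events with a variable number of pulses each; event_no repeats,
# so this is treated as a pulse map (no primary key, but indexed).
df = pd.DataFrame(
    {
        "event_no": [0, 0, 0, 1, 1],
        "dom_time": [1.2, 0.4, 0.9, 2.1, 0.7],
    }
)

database_path = "/tmp/toy.db"  # illustrative output path
create_table(df, "toy_pulses", database_path, is_pulse_map=True)
attach_index(database_path, table_name="toy_pulses")  # speeds up per-event queries
save_to_sql(df, "toy_pulses", database_path)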
168 changes: 168 additions & 0 deletions src/graphnet/data/utilities/parquet_to_sqlite.py
@@ -0,0 +1,168 @@
import glob
import os
from typing import List, Optional, Union

import awkward as ak
import numpy as np
import pandas as pd
from tqdm.auto import trange

from graphnet.data.sqlite.sqlite_utilities import (
    attach_index,
    create_table,
    save_to_sql,
)
from graphnet.utilities.logging import LoggerMixin


class ParquetToSQLiteConverter(LoggerMixin):
    """Converts Parquet files to a SQLite database.

    Each event in the Parquet file(s) is assigned a unique event id. By
    default, every field in the Parquet file(s) is extracted; fields can be
    excluded via the `excluded_fields` argument.
    """

    def __init__(
        self,
        parquet_path: Union[str, List[str]],
        mc_truth_table: str = "mc_truth",
        excluded_fields: Optional[Union[str, List[str]]] = None,
    ):
        # Input checks
        if isinstance(parquet_path, str):
            pass
        elif isinstance(parquet_path, list):
            assert isinstance(
                parquet_path[0], str
            ), "Argument `parquet_path` must be a string or list of strings"
        else:
            assert isinstance(
                parquet_path, str
            ), "Argument `parquet_path` must be a string or list of strings"

        assert isinstance(
            mc_truth_table, str
        ), "Argument `mc_truth_table` must be a string"
        self._parquet_files = self._find_parquet_files(parquet_path)
        if excluded_fields is not None:
            self._excluded_fields = excluded_fields
        else:
            self._excluded_fields = []
        self._mc_truth_table = mc_truth_table
        self._event_counter = 0
        self._created_tables = []

    def _find_parquet_files(self, paths: Union[str, List[str]]) -> List[str]:
        if isinstance(paths, str):
            if paths.endswith(".parquet"):
                files = [paths]
            else:
                files = glob.glob(f"{paths}/*.parquet")
        elif isinstance(paths, list):
            files = []
            for path in paths:
                files.extend(self._find_parquet_files(path))
        assert len(files) > 0, f"No files found in {paths}"
        return files

    def run(self, outdir: str, database_name: str):
        self._create_output_directories(outdir, database_name)
        database_path = os.path.join(
            outdir, database_name, "data", database_name + ".db"
        )
        for i in trange(
            len(self._parquet_files), desc="Main", colour="#0000ff", position=0
        ):
            parquet_file = ak.from_parquet(self._parquet_files[i])
            n_events_in_file = self._count_events(parquet_file)
            for j in trange(
                len(parquet_file.fields),
                desc=self._parquet_files[i].split("/")[-1],
                colour="#ffa500",
                position=1,
                leave=False,
            ):
                if parquet_file.fields[j] not in self._excluded_fields:
                    self._save_to_sql(
                        database_path,
                        parquet_file,
                        parquet_file.fields[j],
                        n_events_in_file,
                    )
            self._event_counter += n_events_in_file
        self._save_config(outdir, database_name)
        print(
            f"Database saved at: \n{outdir}/{database_name}/data/{database_name}.db"
        )

    def _count_events(self, open_parquet_file: ak.Array) -> int:
        return len(open_parquet_file[self._mc_truth_table])

    def _save_to_sql(
        self,
        database_path: str,
        ak_array: ak.Array,
        field_name: str,
        n_events_in_file: int,
    ):
        df = self._convert_to_dataframe(ak_array, field_name, n_events_in_file)
        if field_name in self._created_tables:
            save_to_sql(df, field_name, database_path)
        else:
            # A table with more rows than events holds repeated event_nos,
            # i.e. a pulse map.
            is_pulse_map = len(df) > n_events_in_file
            create_table(df, field_name, database_path, is_pulse_map)
            if is_pulse_map:
                attach_index(database_path, table_name=field_name)
            self._created_tables.append(field_name)
            save_to_sql(df, field_name, database_path)

    def _convert_to_dataframe(
        self,
        ak_array: ak.Array,
        field_name: str,
        n_events_in_file: int,
    ) -> pd.DataFrame:
        df = pd.DataFrame(ak.to_pandas(ak_array[field_name]))
        if len(df.columns) == 1:
            if df.columns == ["values"]:
                df.columns = [field_name]
        if len(df) != n_events_in_file:
            # The DataFrame has more than one row per event (e.g. a pulse
            # map), so repeat each event_no once per row of that event.
            event_nos = []
            c = 0
            for event_no in range(
                self._event_counter, self._event_counter + n_events_in_file, 1
            ):
                try:
                    event_nos.extend(
                        np.repeat(event_no, len(df[df.columns[0]][c])).tolist()
                    )
                except KeyError:
                    # No entry for this event (e.g. an event with no detector
                    # response), so it contributes no rows.
                    pass
                c += 1
        else:
            event_nos = np.arange(0, n_events_in_file, 1) + self._event_counter
        df["event_no"] = event_nos
        return df

    def _create_output_directories(self, outdir: str, database_name: str):
        os.makedirs(outdir + "/" + database_name + "/data", exist_ok=True)
        os.makedirs(outdir + "/" + database_name + "/config", exist_ok=True)

    def _save_config(self, outdir: str, database_name: str):
        """Save the list of converted Parquet files to a CSV file."""
        df = pd.DataFrame(data=self._parquet_files, columns=["files"])
        df.to_csv(outdir + "/" + database_name + "/config/files.csv")
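
The event_no bookkeeping in _convert_to_dataframe can be illustrated in isolation. A minimal sketch of the same idea in vectorized form (toy values; not the converter's exact code path, which also handles events missing from the field):

import awkward as ak
import numpy as np

# Toy pulse-map-like field: three events with 2, 0 and 3 pulses, respectively
dom_times = ak.Array([[1.0, 2.0], [], [3.0, 4.0, 5.0]])

event_counter = 100  # running offset maintained by the converter across files
counts = ak.to_numpy(ak.num(dom_times))  # rows per event: [2, 0, 3]
event_nos = np.repeat(
    np.arange(event_counter, event_counter + len(dom_times)), counts
)
print(event_nos)  # [100 100 102 102 102]; the empty event contributes no rows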
