+"""Contains a Generic class for curated DataModules/Datasets.
+
+Inheriting subclasses are data-specific implementations that allow the user to
+import and download pre-converteddatasets for training of deep learning based
+methods in GraphNeT.
+"""
+
+fromtypingimportDict,Any,Optional,List,Tuple,Union
+fromabcimportabstractmethod
+importos
+
+from.datamoduleimportGraphNeTDataModule
+fromgraphnet.models.graphsimportGraphDefinition
+fromgraphnet.data.datasetimportParquetDataset,SQLiteDataset
+
+
+
+class CuratedDataset(GraphNeTDataModule):
+    """Generic base class for curated datasets.
+
+    Curated Datasets in GraphNeT are pre-converted datasets that have been
+    prepared for training and evaluation of deep learning models. On these
+    Datasets, GraphNeT users can train and benchmark their models against
+    SOTA methods.
+    """
+
+    def __init__(
+        self,
+        graph_definition: GraphDefinition,
+        download_dir: str,
+        truth: Optional[List[str]] = None,
+        features: Optional[List[str]] = None,
+        backend: str = "parquet",
+        train_dataloader_kwargs: Optional[Dict[str, Any]] = None,
+        validation_dataloader_kwargs: Optional[Dict[str, Any]] = None,
+        test_dataloader_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Construct CuratedDataset.
+
+        Args:
+            graph_definition: Method that defines the data representation.
+            download_dir: Directory to download dataset to.
+            truth (Optional): List of event-level truth to include. Will
+                include all available information if not given.
+            features (Optional): List of input features from pulsemap to use.
+                If not given, all available features will be used.
+            backend (Optional): Data backend to use. Either "parquet" or
+                "sqlite". Defaults to "parquet".
+            train_dataloader_kwargs (Optional): Arguments for the training
+                DataLoader. Defaults to None.
+            validation_dataloader_kwargs (Optional): Arguments for the
+                validation DataLoader. Defaults to None.
+            test_dataloader_kwargs (Optional): Arguments for the test
+                DataLoader. Defaults to None.
+        """
+        # From user
+        self._download_dir = download_dir
+        self._graph_definition = graph_definition
+        self._backend = backend.lower()
+
+        # Checks
+        assert backend.lower() in self.available_backends
+        assert backend.lower() in ["sqlite", "parquet"]  # Double-check
+        if backend.lower() == "parquet":
+            dataset_ref = ParquetDataset  # type: ignore
+        elif backend.lower() == "sqlite":
+            dataset_ref = SQLiteDataset  # type: ignore
+
+        # Methods:
+        features, truth = self._verify_args(features=features, truth=truth)
+        self.prepare_data()
+        self._check_properties()
+        dataset_args, selec, test_selec = self._prepare_args(
+            backend=backend, features=features, truth=truth
+        )
+        # Instantiate
+        super().__init__(
+            dataset_reference=dataset_ref,
+            dataset_args=dataset_args,
+            train_dataloader_kwargs=train_dataloader_kwargs,
+            validation_dataloader_kwargs=validation_dataloader_kwargs,
+            test_dataloader_kwargs=test_dataloader_kwargs,
+            selection=selec,
+            test_selection=test_selec,
+        )
+
+    @abstractmethod
+    def prepare_data(self) -> None:
+        """Download and prepare data."""
+
+    @abstractmethod
+    def _prepare_args(
+        self, backend: str, features: List[str], truth: List[str]
+    ) -> Tuple[Dict[str, Any], Union[List[int], None], Union[List[int], None]]:
+        """Prepare arguments to DataModule.
+
+        Args:
+            backend: Backend of dataset. Either "parquet" or "sqlite".
+            features: List of features from user to use as input.
+            truth: List of event-level truth from user.
+
+        This method should return three outputs in the following order:
+
+        A) `dataset_args`
+        B) `selection` if wanted, else None
+        C) `test_selection` if wanted, else None.
+
+        See documentation on GraphNeTDataModule for details on these
+        arguments:
+        https://graphnet-team.github.io/graphnet/api/graphnet.data.datamodule.html
+        """
+
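As an illustration, a minimal sketch of how a subclass might implement `_prepare_args` (the file name and backend below are hypothetical; the dictionary keys mirror the `Dataset` constructor arguments, as in the `TestDataset` example further down in this diff):

    def _prepare_args(
        self, backend: str, features: List[str], truth: List[str]
    ) -> Tuple[Dict[str, Any], Union[List[int], None], Union[List[int], None]]:
        # Path to the (hypothetical) downloaded file for the chosen backend
        dataset_path = os.path.join(self.dataset_dir, "merged.db")
        dataset_args = {
            "truth_table": self._truth_table,
            "pulsemaps": self._pulsemaps,
            "path": dataset_path,
            "graph_definition": self._graph_definition,
            "features": features,
            "truth": truth,
        }
        # No explicit selections: let the DataModule split events itself
        return dataset_args, None, None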
+    def _verify_args(
+        self, features: Union[List[str], None], truth: Union[List[str], None]
+    ) -> Tuple[List[str], List[str]]:
+        """Check arguments for truth and features from the user.
+
+        Will check to make sure that the given args are available. If they
+        are not available, an AssertionError is thrown.
+        """
+        if features is None:
+            features = self._features
+        else:
+            self._assert_isin(given=features, available=self._features)
+        if truth is None:
+            truth = self._event_truth
+        else:
+            self._assert_isin(given=truth, available=self._event_truth)
+
+        return features, truth
+
+    def _assert_isin(self, given: List[str], available: List[str]) -> None:
+        for key in given:
+            assert key in available
+
+
+    def description(self) -> None:
+        """Print details on the Dataset."""
+        event_counts = self.events
+        print(
+            "\n",
+            f"{self.__class__.__name__} contains data from",
+            f"{self.experiment} and was added to GraphNeT by",
+            f"{self.creator}.",
+            "\n\n",
+            "COMMENTS ON USAGE: \n",
+            f"{self.creator}: {self.comments}\n",
+            "\n",
+            "DATASET DETAILS: \n",
+            f"pulsemaps: {self.pulsemaps}\n",
+            f"truth table: {self.truth_table}\n",
+            f"input features: {self.features}\n",
+            f"pulse truth: {self.pulse_truth}\n",
+            f"event truth: {self.event_truth}\n",
+            f"Number of training events: {event_counts['train']}\n",
+            f"Number of validation events: {event_counts['val']}\n",
+            f"Number of test events: {event_counts['test']}\n",
+            "\n",
+            "CITATION:\n",
+            f"{self.citation}",
+        )
+
+
+    def _check_properties(self) -> None:
+        """Check that fields have been filled out."""
+        attr = [
+            "pulsemaps",
+            "truth_table",
+            "event_truth",
+            "pulse_truth",
+            "features",
+            "experiment",
+            "citation",
+            "creator",
+            "available_backends",
+        ]
+        for attribute in attr:
+            assert hasattr(self, "_" + attribute), f"missing {attribute}"
+
+    @property
+    def pulsemaps(self) -> List[str]:
+        """Produce a list of available pulsemaps in Dataset."""
+        return self._pulsemaps
+
+    @property
+    def truth_table(self) -> List[str]:
+        """Produce name of table containing event-level truth in Dataset."""
+        return self._truth_table
+
+    @property
+    def event_truth(self) -> List[str]:
+        """Produce a list of available event-level truth in Dataset."""
+        return self._event_truth
+
+    @property
+    def pulse_truth(self) -> Union[List[str], None]:
+        """Produce a list of available pulse-level truth in Dataset."""
+        return self._pulse_truth
+
+    @property
+    def features(self) -> List[str]:
+        """Produce a list of available input features in Dataset."""
+        return self._features
+
+    @property
+    def experiment(self) -> str:
+        """Produce the name of the experiment that the data comes from."""
+        return self._experiment
+
+    @property
+    def citation(self) -> str:
+        """Produce a string that describes how to cite this Dataset."""
+        return self._citation
+
+    @property
+    def comments(self) -> str:
+        """Produce comments on the dataset from the creator."""
+        return self._comments
+
+    @property
+    def creator(self) -> str:
+        """Produce name of person who created the Dataset."""
+        return self._creator
+
+    @property
+    def events(self) -> Dict[str, int]:
+        """Produce a dict with the number of events in each selection."""
+        n_train = len(self._train_dataset)
+        if hasattr(self, "_val_dataset"):
+            n_val = len(self._val_dataset)
+        else:
+            n_val = 0
+        if hasattr(self, "_test_dataset"):
+            n_test = len(self._test_dataset)
+        else:
+            n_test = 0
+
+        return {"train": n_train, "val": n_val, "test": n_test}
+
+    @property
+    def available_backends(self) -> List[str]:
+        """Produce a list of available data formats that the data comes in."""
+        return self._available_backends
+
+    @property
+    def dataset_dir(self) -> str:
+        """Produce path to the directory that contains the dataset files."""
+        dataset_dir = os.path.join(
+            self._download_dir, self.__class__.__name__, self._backend
+        )
+        return dataset_dir
+
+
+
+
+class ERDAHostedDataset(CuratedDataset):
+    """A base class for datasets/datamodules hosted at ERDA.
+
+    Inheriting subclasses will just need to fill out the `_file_hashes`
+    attribute, which points to the file-id of an ERDA-hosted sharelink. It
+    is assumed that sharelinks point to a single compressed file that has
+    been compressed using `tar` with extension ".tar.gz".
+
+    E.g. suppose that the sharelink below
+    https://sid.erda.dk/share_redirect/FbEEzAbg5A
+    points to a compressed sqlite database. Then:
+    _file_hashes = {'sqlite' : "FbEEzAbg5A"}
+    """
+
+    # Member variables
+    _mirror = "https://sid.erda.dk/share_redirect"
+    _file_hashes: Dict[str, str] = {}  # Must be filled out by you!
+
+
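A minimal usage sketch (the subclass name is hypothetical, and it is assumed that the class attributes checked by `_check_properties` and the `_prepare_args` method are also filled out, as in the `TestDataset` example later in this diff; `graph_definition` is a configured `GraphDefinition`):

    class MyERDADataset(ERDAHostedDataset):
        # Hypothetical ERDA sharelink file-id for the sqlite backend
        _file_hashes = {"sqlite": "FbEEzAbg5A"}
        ...

    dataset = MyERDADataset(
        graph_definition=graph_definition,
        download_dir="/tmp/graphnet_data",
        backend="sqlite",
        train_dataloader_kwargs={"batch_size": 16, "num_workers": 4},
    )
    dataset.description()  # Print details on the downloaded dataset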
            num_workers: The number of CPUs used for parallel processing.
                Defaults to 1 (no multiprocessing).
        """
+        # Base class constructor
+        super().__init__(name=__name__, class_name=self.__class__.__name__)
+
        # Member Variable Assignment
        self._file_reader = file_reader
        self._save_method = save_method
@@ -400,10 +439,8 @@
        # with reader.
        if not isinstance(extractors, list):
            extractors = [extractors]
-        self._file_reader.set_extractors(extractors=extractors)
-        # Base class constructor
-        super().__init__(name=__name__, class_name=self.__class__.__name__)
+        self._file_reader.set_extractors(extractors=extractors)

    @final
    def __call__(self, input_dir: Union[str, List[str]]) -> None:
@@ -447,10 +484,9 @@
        # Iterate over files
        for _ in map_fn(
            self._process_file,
-            tqdm(input_files, unit="file(s)", colour="green"),
+            tqdm(input_files, unit=" file(s)", colour="green"),
        ):
            self.debug("processing file.")
-
        self._update_shared_variables(pool)

    @final
@@ -463,13 +499,27 @@
        This function is called in parallel.
        """
        # Read and apply extractors
-        data: List[OrderedDict] = self._file_reader(file_path=file_path)
-
-        # Count number of events
-        n_events = len(data)
-
-        # Assign event_no's to each event in data and transform to pd.DataFrame
-        dataframes = self._assign_event_no(data=data)
+        data = self._file_reader(file_path=file_path)
+
+        # Handle the two allowed reader output formats
+        if isinstance(data, list):
+            # Assign event_no's to each event in data
+            # and transform to pd.DataFrame
+            n_events = len(data)
+            dataframes = self._assign_event_no(data=data)
+        elif isinstance(data, dict):
+            keys = [key for key in data.keys()]
+            counter = []
+            for key in keys:
+                assert isinstance(data[key], pd.DataFrame)
+                assert self._index_column in data[key].columns
+                counter.append(len(data[key][self._index_column]))
+            dataframes = data
+            n_events = len(
+                pd.unique(data[keys[np.argmin(counter)]][self._index_column])
+            )
+        else:
+            assert 1 == 2, "should not reach here."

        # Delete `data` to save memory
        del data
@@ -539,7 +589,6 @@
    ) -> int:
        """Count number of rows that features from `extractor_name` have."""
        extractor_dict = event_dict[extractor_name]
-
        try:
            # If all features in extractor_name have the same length
            # this line of code will execute without error and result
@@ -632,7 +681,9 @@
    @final
-    def merge_files(self, files: Optional[List[str]] = None) -> None:
+    def merge_files(
+        self, files: Optional[List[str]] = None, **kwargs: Any
+    ) -> None:
        """Merge converted files.

        `DataConverter` will call the `.merge_files` method in the
@@ -660,8 +711,7 @@
        merge_path = os.path.join(self._output_dir, "merged")
        self.info(f"Merging files to {merge_path}")
        self._save_method.merge_files(
-            files=files_to_merge,
-            output_dir=merge_path,
+            files=files_to_merge, output_dir=merge_path, **kwargs
        )
diff --git a/_modules/graphnet/data/dataloader.html b/_modules/graphnet/data/dataloader.html
index 6aa911562..c39bee23e 100644
--- a/_modules/graphnet/data/dataloader.html
+++ b/_modules/graphnet/data/dataloader.html
@@ -281,14 +280,42 @@
            selection: (Optional) a list of event id's used for training
                and validation. Defaults to None.
            test_selection: (Optional) a list of event id's used for testing,
-                Default None.
+                Defaults to None.
            train_dataloader_kwargs: Arguments for the training DataLoader,
-                Default None.
+                Defaults to {"batch_size": 2, "num_workers": 1}.
            validation_dataloader_kwargs: Arguments for the validation
-                DataLoader, Default None.
+                DataLoader. Defaults to
+                `train_dataloader_kwargs`.
            test_dataloader_kwargs: Arguments for the test DataLoader,
-                Default None.
+                Defaults to `train_dataloader_kwargs`.
            train_val_split (Optional): Split ratio for training and
                validation sets. Default is [0.9, 0.10].
            split_seed: seed used for shuffling and splitting selections into
@@ -387,17 +417,101 @@
        self._train_val_split = train_val_split or [0.0]
        self._rng = split_seed

-        self._train_dataloader_kwargs = train_dataloader_kwargs or {}
-        self._validation_dataloader_kwargs = validation_dataloader_kwargs or {}
-        self._test_dataloader_kwargs = test_dataloader_kwargs or {}
+        if train_dataloader_kwargs is None:
+            train_dataloader_kwargs = {"batch_size": 2, "num_workers": 1}
+
+        self._set_dataloader_kwargs(
+            train_dataloader_kwargs,
+            validation_dataloader_kwargs,
+            test_dataloader_kwargs,
+        )

        # If multiple dataset paths are given, we should use EnsembleDataset
        self._use_ensemble_dataset = isinstance(
            self._dataset_args["path"], list
        )

+        # Create Dataloaders
        self.setup("fit")

+    def _set_dataloader_kwargs(
+        self,
+        train_dl_args: Dict[str, Any],
+        val_dl_args: Union[Dict[str, Any], None],
+        test_dl_args: Union[Dict[str, Any], None],
+    ) -> None:
+        """Copy train dataloader args to other dataloaders if not given.
+
+        Also checks that ParquetDataset dataloaders have multiprocessing
+        context set to "spawn" as this is strictly required.
+
+        See: https://docs.pola.rs/user-guide/misc/multiprocessing/
+        """
+        if val_dl_args is None:
+            self.info(
+                "No `val_dataloader_kwargs` given. This arg has "
+                "been set to `train_dataloader_kwargs` with `shuffle` = False."
+            )
+            val_dl_args = deepcopy(train_dl_args)
+            val_dl_args["shuffle"] = False  # Important for inference
+        if (test_dl_args is None) & (self._test_selection is not None):
+            test_dl_args = deepcopy(train_dl_args)
+            test_dl_args["shuffle"] = False  # Important for inference
+            self.info(
+                "No `test_dataloader_kwargs` given. This arg has "
+                "been set to `train_dataloader_kwargs` with `shuffle` = False."
+            )
+
+        if self._dataset == ParquetDataset:
+            train_dl_args = self._add_context(train_dl_args, "training")
+            val_dl_args = self._add_context(val_dl_args, "validation")
+            if self._test_selection is not None:
+                assert test_dl_args is not None
+                test_dl_args = self._add_context(test_dl_args, "test")
+
+        self._train_dataloader_kwargs = train_dl_args
+        self._validation_dataloader_kwargs = val_dl_args
+        self._test_dataloader_kwargs = test_dl_args or {}
+
+    def _add_context(
+        self, dataloader_args: Dict[str, Any], dataloader_type: str
+    ) -> Dict[str, Any]:
+        """Handle assignment of `multiprocessing_context` arg to loaders.
+
+        Datasets relying on threaded libraries often require the
+        multiprocessing context to be set to "spawn" if "num_workers" > 0.
+        This method will check the arguments for this entry and throw an
+        error if the field is already assigned to a wrong value. If the value
+        is not specified, it is added automatically with a log entry.
+        """
+        arg = "multiprocessing_context"
+        if dataloader_args["num_workers"] != 0:
+            # If using multiprocessing
+            if arg in dataloader_args:
+                if dataloader_args[arg] != "spawn":
+                    # Wrongly assigned by user
+                    self.error(
+                        "DataLoaders using `ParquetDataset` must have "
+                        "multiprocessing_context = 'spawn'. "
+                        f" Found '{dataloader_args[arg]}' in ",
+                        f"{dataloader_type} dataloader.",
+                    )
+                    raise ValueError("multiprocessing_context must be 'spawn'")
+                else:
+                    # Correctly assigned by user
+                    return dataloader_args
+            else:
+                # Forgotten assignment by user
+                dataloader_args[arg] = "spawn"
+                self.warning_once(
+                    f"{self.__class__.__name__} has automatically "
+                    "set multiprocessing_context = 'spawn' in "
+                    f"{dataloader_type} dataloader."
+                )
+                return dataloader_args
+        else:
+            return dataloader_args
+
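For example, when the underlying dataset is a `ParquetDataset`, dataloader kwargs could be passed like this (a sketch; batch size, worker count and `dataset_args` are placeholders):

    train_kwargs = {
        "batch_size": 16,
        "num_workers": 4,
        "shuffle": True,
        # Required for ParquetDataset when num_workers > 0; added
        # automatically by `_add_context` if omitted.
        "multiprocessing_context": "spawn",
    }
    dm = GraphNeTDataModule(
        dataset_reference=ParquetDataset,
        dataset_args=dataset_args,
        train_dataloader_kwargs=train_kwargs,
    )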
Source code for graphn
"""Return a list of all unique values in `self._index_column`."""@abstractmethod
- def_get_event_index(
- self,sequential_index:Optional[int]
- )->Optional[int]:
+ def_get_event_index(self,sequential_index:int)->int:"""Return the event index corresponding to a `sequential_index`."""
@@ -756,7 +774,7 @@
Source code for graphn
columns:Union[List[str],str],sequential_index:Optional[int]=None,selection:Optional[str]=None,
- )->List[Tuple[Any,...]]:
+ )->np.ndarray:"""Query a table at a specific index, optionally with some selection. Args:
@@ -878,7 +896,9 @@
Source code for graphn
"""Return a list missing columns in `table`."""forcolumnincolumns:try:
- self.query_table(table,[column],0)
+ self.query_table(
+ table=table,columns=[column],sequential_index=0
+ )exceptColumnMissingException:iftablenotinself._missing_variables:self._missing_variables[table]=[]
@@ -890,12 +910,7 @@
Source code for graphn
def_query(self,sequential_index:int
- )->Tuple[
- List[Tuple[float,...]],
- Tuple[Any,...],
- Optional[List[Tuple[Any,...]]],
- Optional[float],
- ]:
+ )->Tuple[np.ndarray,np.ndarray,Optional[np.ndarray],Optional[float]]:"""Query file for event features and truth information. The returned lists have lengths corresponding to the number of pulses
@@ -917,11 +932,14 @@
Source code for graphn
Returns: Graph object. """
- # Convert nested list to simple dict
+ # Convert truth to dict
+ iflen(truth.shape)==1:
+ truth=truth.reshape(1,-1)truth_dict={
- key:truth[index]forindex,keyinenumerate(self._truth)
+ key:truth[:,index]forindex,keyinenumerate(self._truth)}# Define custom labels
@@ -977,10 +993,9 @@
Source code for graphn
# Convert nested list to simple dictifnode_truthisnotNone:
- node_truth_array=np.asarray(node_truth)assertself._node_truthisnotNonenode_truth_dict={
- key:node_truth_array[:,index]
+ key:node_truth[:,index]forindex,keyinenumerate(self._node_truth)}
@@ -991,19 +1006,16 @@
Source code for graphn
# Catch cases with no reconstructed pulsesiflen(features):
- node_features=np.asarray(features)[
- :,1:
- ]# first entry is index column
+ node_features=featureselse:
- node_features=np.array([]).reshape((0,len(self._features)-1))
+ node_features=np.array([]).reshape((0,len(self._features)))
+ assertisinstance(features,np.ndarray)# Construct graph data objectassertself._graph_definitionisnotNonegraph=self._graph_definition(input_features=node_features,
- input_feature_names=self._features[
- 1:
- ],# first entry is index column
+ input_feature_names=self._features,truth_dicts=truth_dicts,custom_label_functions=self._label_fns,loss_weight_column=self._loss_weight_column,
@@ -1017,13 +1029,11 @@
Source code for graphn
"""Return dictionary of labels, to be added as graph attributes."""if"pid"intruth_dict.keys():abs_pid=abs(truth_dict["pid"])
- sim_type=truth_dict["sim_type"]labels_dict={self._index_column:truth_dict[self._index_column],"muon":int(abs_pid==13),"muon_stopped":int(truth_dict.get("stopped_muon")==1),
- "noise":int((abs_pid==1)&(sim_type!="data")),"neutrino":int((abs_pid!=13)&(abs_pid!=1)),# @TODO: `abs_pid in [12,14,16]`?
@@ -1031,7 +1041,7 @@
[docs]classParquetDataset(Dataset):
-"""Pytorch dataset for reading from Parquet files."""
+"""Dataset class for Parquet-files converted with `ParquetWriter`."""
- # Implementing abstract method(s)
- def_init(self)->None:
- # Check(s)
- ifnotisinstance(self._path,list):
-
- assertisinstance(self._path,str)
-
- assertself._path.endswith(
- ".parquet"
- ),f"Format of input file `{self._path}` is not supported"
-
- assert(
- self._node_truthisNone
- ),"Argument `node_truth` is currently not supported."
- assert(
- self._node_truth_tableisNone
- ),"Argument `node_truth_table` is currently not supported."
- assert(
- self._string_selectionisNone
- ),"Argument `string_selection` is currently not supported"
-
- # Set custom member variable(s)
- ifnotisinstance(self._path,list):
- self._parquet_hook=ak.from_parquet(self._path,lazy=False)
- else:
- self._parquet_hook=ak.concatenate(
- ak.from_parquet(file)forfileinself._path
- )
+ def__init__(
+ self,
+ path:str,
+ graph_definition:GraphDefinition,
+ pulsemaps:Union[str,List[str]],
+ features:List[str],
+ truth:List[str],
+ *,
+ node_truth:Optional[List[str]]=None,
+ index_column:str="event_no",
+ truth_table:str="truth",
+ node_truth_table:Optional[str]=None,
+ string_selection:Optional[List[int]]=None,
+ selection:Optional[Union[str,List[int],List[List[int]]]]=None,
+ dtype:torch.dtype=torch.float32,
+ loss_weight_table:Optional[str]=None,
+ loss_weight_column:Optional[str]=None,
+ loss_weight_default_value:Optional[float]=None,
+ seed:Optional[int]=None,
+ cache_size:int=1,
+ ):
+"""Construct Dataset.
+
+ NOTE: DataLoaders using this Dataset should have
+ "multiprocessing_context = 'spawn'" set to avoid thread locking.
+
+ Args:
+ path: Path to the file(s) from which this `Dataset` should read.
+ pulsemaps: Name(s) of the pulse map series that should be used to
+ construct the nodes on the individual graph objects, and their
+ features. Multiple pulse series maps can be used, e.g., when
+ different DOM types are stored in different maps.
+ features: List of columns in the input files that should be used as
+ node features on the graph objects.
+ truth: List of event-level columns in the input files that should
+ be added as attributes on the graph objects.
+ node_truth: List of node-level columns in the input files that
+ should be added as attributes on the graph objects.
+ index_column: Name of the column in the input files that contains
+ unique indices to identify and map events across tables.
+ truth_table: Name of the table containing event-level truth
+ information.
+ node_truth_table: Name of the table containing node-level truth
+ information.
+ string_selection: Subset of strings for which data should be read
+ and used to construct graph objects. Defaults to None, meaning
+ all strings for which data exists are used.
+ selection: The batch ids to include in the dataset.
+ Defaults to None, meaning that all batches are read.
+ dtype: Type of the feature tensor on the graph objects returned.
+ loss_weight_table: Name of the table containing per-event loss
+ weights.
+ loss_weight_column: Name of the column in `loss_weight_table`
+ containing per-event loss weights. This is also the name of the
+ corresponding attribute assigned to the graph object.
+ loss_weight_default_value: Default per-event loss weight.
+ NOTE: This default value is only applied when
+ `loss_weight_table` and `loss_weight_column` are specified, and
+ in this case to events with no value in the corresponding
+ table/column. That is, if no per-event loss weight table/column
+ is provided, this value is ignored. Defaults to None.
+ seed: Random number generator seed, used for selecting a random
+ subset of events when resolving a string-based selection (e.g.,
+ `"10000 random events ~ event_no % 5 > 0"` or `"20% random
+ events ~ event_no % 5 > 0"`).
+ graph_definition: Method that defines the graph representation.
+ cache_size: Number of batches to cache in memory.
+ Must be at least 1. Defaults to 1.
+ """
+ self._validate_selection(selection)
+ # Base class constructor
+ super().__init__(
+ path=path,
+ pulsemaps=pulsemaps,
+ features=features,
+ truth=truth,
+ node_truth=node_truth,
+ index_column=index_column,
+ truth_table=truth_table,
+ node_truth_table=node_truth_table,
+ string_selection=string_selection,
+ selection=selection,
+ dtype=dtype,
+ loss_weight_table=loss_weight_table,
+ loss_weight_column=loss_weight_column,
+ loss_weight_default_value=loss_weight_default_value,
+ seed=seed,
+ graph_definition=graph_definition,
+ )
- def_get_all_indices(self)->List[int]:
- returnnp.arange(
- len(
- ak.to_numpy(
- self._parquet_hook[self._truth_table][self._index_column]
- ).tolist()
- )
- ).tolist()
+ # mypy..
+ assertisinstance(self._path,str)
+ self._path:str=self._path
+ # Member Variables
+ self._cache_size=cache_size
+ self._batch_sizes=self._calculate_sizes()
+ self._batch_cumsum=np.cumsum(self._batch_sizes)
+ self._file_cache=self._initialize_file_cache(
+ truth_table=truth_table,
+ node_truth_table=node_truth_table,
+ pulsemaps=pulsemaps,
+ )
+ self._string_selection=string_selection
+ # Purely internal member variables
+ self._missing_variables:Dict[str,List[str]]={}
+ self._remove_missing_columns()
- def_get_event_index(
- self,sequential_index:Optional[int]
- )->Optional[int]:
- index:Optional[int]
- ifsequential_indexisNone:
- index=None
- else:
- index=cast(List[int],self._indices)[sequential_index]
-
- returnindex
-
- def_format_dictionary_result(
- self,dictionary:Dict
- )->List[Tuple[Any,...]]:
-"""Convert the output of `ak.to_list()` into a list of tuples."""
- # All scalar values
- ifall(map(np.isscalar,dictionary.values())):
- return[tuple(dictionary.values())]
-
- # All arrays should have same length
- array_lengths=[
- len(values)
- forvaluesindictionary.values()
- ifnotnp.isscalar(values)
- ]
- assertlen(set(array_lengths))==1,(
- f"Arrays in {dictionary} have differing lengths "
- f"({set(array_lengths)})."
+ def_initialize_file_cache(
+ self,
+ truth_table:str,
+ node_truth_table:Optional[str],
+ pulsemaps:Union[str,List[str]],
+ )->Dict[str,OrderedDict]:
+ tables=[truth_table]
+ ifnode_truth_tableisnotNone:
+ tables.append(node_truth_table)
+ ifisinstance(pulsemaps,str):
+ tables.append(pulsemaps)
+ elifisinstance(pulsemaps,list):
+ tables.extend(pulsemaps)
+
+ cache:Dict[str,OrderedDict]={}
+ fortableintables:
+ cache[table]=OrderedDict()
+ returncache
+
+ def_validate_selection(
+ self,
+ selection:Optional[Union[str,List[int],List[List[int]]]]=None,
+ )->None:
+ ifselectionisnotNone:
+ try:
+ assertnotisinstance(selection,str)
+ exceptAssertionError:
+ e=AssertionError(
+ f"{self.__class__.__name__} does not support "
+ "str-selections."
+ )
+ raisee
+
+ def_init(self)->None:
+ return
+
+ def_get_event_index(self,sequential_index:int)->int:
+ event_index=self.query_table(
+ table=self._truth_table,
+ sequential_index=sequential_index,
+ columns=[self._index_column],)
- nb_elements=array_lengths[0]
+ returnevent_index
-        # Broadcast scalars
-        for key in dictionary:
-            value = dictionary[key]
-            if np.isscalar(value):
-                dictionary[key] = np.repeat(
-                    value, repeats=nb_elements
-                ).tolist()
+    def __len__(self) -> int:
+        """Return length of dataset, i.e. number of training examples."""
+        return sum(self._batch_sizes)
-        return list(map(tuple, list(zip(*dictionary.values()))))
+    def _get_all_indices(self) -> List[int]:
+        """Return a list of all unique values in `self._index_column`."""
+        files = glob(os.path.join(self._path, self._truth_table, "*.parquet"))
+        return np.arange(0, len(files), 1)
+
+    def _calculate_sizes(self) -> List[int]:
+        """Calculate the number of events in each batch."""
+        sizes = []
+        for batch_id in self._indices:
+            path = os.path.join(
+                self._path,
+                self._truth_table,
+                f"{self.truth_table}_{batch_id}.parquet",
+            )
+            sizes.append(len(pol.read_parquet(path)))
+        return sizes
+
+    def _get_row_idx(self, sequential_index: int) -> int:
+        """Return the row index corresponding to a `sequential_index`."""
+        file_idx = bisect_right(self._batch_cumsum, sequential_index)
+        if file_idx > 0:
+            idx = int(sequential_index - self._batch_cumsum[file_idx - 1])
+        else:
+            idx = sequential_index
+        return idx
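The sequential-to-row mapping above relies only on the cumulative batch sizes; a standalone sketch of the same logic (the batch sizes are made up):

    import numpy as np
    from bisect import bisect_right

    batch_sizes = [100, 250, 50]           # events per parquet batch
    batch_cumsum = np.cumsum(batch_sizes)  # -> [100, 350, 400]

    def locate(sequential_index: int):
        # Index of the batch file containing this event
        file_idx = bisect_right(batch_cumsum, sequential_index)
        # Row within that batch file
        row = sequential_index - (batch_cumsum[file_idx - 1] if file_idx else 0)
        return file_idx, int(row)

    assert locate(0) == (0, 0)
    assert locate(99) == (0, 99)
    assert locate(100) == (1, 0)
    assert locate(399) == (2, 49)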
[docs]
- defquery_table(
+ defquery_table(# type: ignoreself,table:str,columns:Union[List[str],str],sequential_index:Optional[int]=None,selection:Optional[str]=None,
- )->List[Tuple[Any,...]]:
-"""Query table at a specific index, optionally with some selection."""
- # Check(s)
- assert(
- selectionisNone
- ),"Argument `selection` is currently not supported"
-
- index=self._get_event_index(sequential_index)
-
- try:
- ifindexisNone:
- ak_array=self._parquet_hook[table][columns][:]
- else:
- ak_array=self._parquet_hook[table][columns][index]
- exceptValueErrorase:
- if"does not exist (not in record)"instr(e):
- raiseColumnMissingException(str(e))
- else:
- raisee
+ )->np.ndarray:
+"""Query a table at a specific index, optionally with some selection.
+
+ Args:
+ table: Table to be queried.
+ columns: Columns to read out.
+ sequential_index: Sequentially numbered index
+ (i.e. in [0,len(self))) of the event to query. This _may_
+ differ from the indexation used in `self._indices`. If no value
+ is provided, the entire column is returned.
+ selection: Selection to be imposed before reading out data.
+ Defaults to None.
+
+ Returns:
+ List of tuples containing the values in `columns`. If the `table`
+ contains only scalar data for `columns`, a list of length 1 is
+ returned
+
+ Raises:
+ ColumnMissingException: If one or more element in `columns` is not
+ present in `table`.
+ """
+ ifisinstance(columns,str):
+ columns=[columns]
- output=ak_array.to_list()
+ ifsequential_indexisNone:
+ file_idx=np.arange(0,len(self._batch_cumsum),1)
+ else:
+ file_idx=[bisect_right(self._batch_cumsum,sequential_index)]
+
+ file_indices=[self._indices[idx]foridxinfile_idx]
+
+ arrays=[]
+ forfile_idxinfile_indices:
+ array=self._query_table(
+ table=table,
+ columns=columns,
+ file_idx=file_idx,
+ sequential_index=sequential_index,
+ selection=selection,
+ )
+ arrays.append(array)
+ returnnp.concatenate(arrays,axis=0)
Source code for graphnet.data.dataset.sqlite.sqlite_dataset
"""`Dataset` class(es) for reading data from SQLite databases."""
-fromtypingimportAny,List,Optional,Tuple,Union
+fromtypingimportAny,List,Optional,Tuple,Union,Dictimportpandasaspdimportsqlite3
+importnumpyasnpfromgraphnet.data.dataset.datasetimportDataset,ColumnMissingException
@@ -363,6 +391,9 @@
Source c
self._conn:Optional[sqlite3.Connection]=Nonedef_post_init(self)->None:
+ # Purely internal member variables
+ self._missing_variables:Dict[str,List[str]]={}
+ self._remove_missing_columns()self._close_connection()
@@ -405,7 +436,7 @@
Source c
raiseColumnMissingException(str(e))else:raisee
- returnresult
Source code for graphnet.data.extractors.combine_extractors
+"""Module for combining multiple extractors into a single extractor."""
+fromtypingimportTYPE_CHECKING
+
+fromgraphnet.utilities.importsimporthas_icecube_package
+fromgraphnet.data.extractors.icecube.i3extractorimportI3Extractor
+fromtypingimportList,Dict
+
+ifhas_icecube_package()orTYPE_CHECKING:
+ fromicecubeimporticetray# pyright: reportMissingImports=false
+
+
+
+[docs]
+classCombinedExtractor(I3Extractor):
+"""Class for combining multiple extractors.
+
+ This class is used to combine multiple extractors into a single extractor
+ with a new name.
+ """
+
+ def__init__(self,extractors:List[I3Extractor],extractor_name:str):
+"""Construct CombinedExtractor.
+
+ Args:
+ extractors: List of extractors to combine. The extractors must all return data on the same level; e.g. all event-level data or pulse-level data. Mixing tables that contain event-level and pulse-level information will fail.
+ extractor_name: Name of the new extractor.
+ """
+ super().__init__(extractor_name=extractor_name)
+ self._extractors=extractors
+
+ def__call__(self,frame:"icetray.I3Frame")->Dict[str,float]:
+"""Extract data from frame using all extractors.
+
+ Args:
+ frame: I3Frame to extract data from.
+ """
+ output={}
+ forextractorinself._extractors:
+ output.update(extractor(frame))
+ returnoutput
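A usage sketch (hypothetical; it assumes `I3TruthExtractor` and `I3RetroExtractor`, both event-level extractors, are importable from `graphnet.data.extractors.icecube`):

    from graphnet.data.extractors.icecube import I3TruthExtractor, I3RetroExtractor

    # Combine two event-level extractors under a single table name.
    combined = CombinedExtractor(
        extractors=[I3TruthExtractor(), I3RetroExtractor()],
        extractor_name="event_level",
    )
    # Applied to an I3Frame, the per-extractor dicts are merged into one:
    # row = combined(frame)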
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/_modules/graphnet/data/extractors/extractor.html b/_modules/graphnet/data/extractors/extractor.html
index 3d5037a63..22f43696f 100644
--- a/_modules/graphnet/data/extractors/extractor.html
+++ b/_modules/graphnet/data/extractors/extractor.html
@@ -281,14 +280,42 @@
        super().__init__(name=__name__, class_name=self.__class__.__name__)

    @abstractmethod
-    def __call__(self, data: Any) -> dict:
+    def __call__(self, data: Any) -> Union[dict, pd.DataFrame]:
        """Extract information from data."""
        pass
diff --git a/_modules/graphnet/data/extractors/icecube/i3extractor.html b/_modules/graphnet/data/extractors/icecube/i3extractor.html
index 4715a0a3c..92e23878c 100644
--- a/_modules/graphnet/data/extractors/icecube/i3extractor.html
+++ b/_modules/graphnet/data/extractors/icecube/i3extractor.html
diff --git a/_modules/graphnet/data/extractors/icecube/i3quesoextractor.html b/_modules/graphnet/data/extractors/icecube/i3quesoextractor.html
index a0d16868c..5231089b1 100644
--- a/_modules/graphnet/data/extractors/icecube/i3quesoextractor.html
+++ b/_modules/graphnet/data/extractors/icecube/i3quesoextractor.html
diff --git a/_modules/graphnet/data/extractors/icecube/i3retroextractor.html b/_modules/graphnet/data/extractors/icecube/i3retroextractor.html
index 5299392c8..ccfdfa786 100644
--- a/_modules/graphnet/data/extractors/icecube/i3retroextractor.html
+++ b/_modules/graphnet/data/extractors/icecube/i3retroextractor.html
diff --git a/_modules/graphnet/data/extractors/icecube/i3splinempeextractor.html b/_modules/graphnet/data/extractors/icecube/i3splinempeextractor.html
index d2c33d7d9..d794fc6b0 100644
--- a/_modules/graphnet/data/extractors/icecube/i3splinempeextractor.html
+++ b/_modules/graphnet/data/extractors/icecube/i3splinempeextractor.html
diff --git a/_modules/graphnet/data/extractors/icecube/i3tumextractor.html b/_modules/graphnet/data/extractors/icecube/i3tumextractor.html
index 2bfb35fdd..ebafcaef6 100644
--- a/_modules/graphnet/data/extractors/icecube/i3tumextractor.html
+++ b/_modules/graphnet/data/extractors/icecube/i3tumextractor.html
diff --git a/_modules/graphnet/data/extractors/icecube/utilities/i3_filters.html b/_modules/graphnet/data/extractors/icecube/utilities/i3_filters.html
index e02d4474e..7237fa78b 100644
--- a/_modules/graphnet/data/extractors/icecube/utilities/i3_filters.html
+++ b/_modules/graphnet/data/extractors/icecube/utilities/i3_filters.html
Source code for graphnet.data.extractors.internal.parquet_extractor
+"""Parquet Extractor for conversion from internal parquet format."""
+importpolarsaspol
+importpandasaspd
+
+fromgraphnet.data.extractorsimportExtractor
+
+
+
+[docs]
+classParquetExtractor(Extractor):
+"""Class for extracting information from internal GraphNeT parquet files.
+
+ Contains functionality required to extract data from internal parquet
+ files, i.e files saved using the ParquetWriter. This allows for conversion
+ between internal data formats.
+ """
+
+ def__init__(self,extractor_name:str):
+"""Construct ParquetExtractor.
+
+ Args:
+ extractor_name: Name of the `ParquetExtractor` instance.
+ Used to keep track of the provenance of different data,
+ and to name tables to which this data is saved.
+ """
+ # Member variable(s)
+ self._table=extractor_name
+ # Base class constructor
+ super().__init__(extractor_name=extractor_name)
+
+ def__call__(self,file_path:str)->pd.DataFrame:
+"""Extract information from parquet file."""
+ ifself._tableinfile_path:
+ returnpol.read_parquet(file_path).to_pandas()
+ else:
+ returnNone
Source code for graphnet.data.extractors.liquido.h5_extractor
+"""H5 Extractor for LiquidO data files."""
+fromtypingimportList
+importnumpyasnp
+importpandasaspd
+importh5py
+
+fromgraphnet.data.extractorsimportExtractor
+
+
+
+[docs]
+classH5Extractor(Extractor):
+"""Class for extracting information from LiquidO h5 files."""
+
+ def__init__(self,extractor_name:str,column_names:List[str]):
+"""Construct H5Extractor.
+
+ Args:
+ extractor_name: Name of the `ParquetExtractor` instance.
+ Used to keep track of the provenance of different data,
+ and to name tables to which this data is saved.
+ column_names: Name of the columns in `extractor_name`.
+ """
+ # Member variable(s)
+ self._table=extractor_name
+ self._column_names=column_names
+ # Base class constructor
+ super().__init__(extractor_name=extractor_name)
+
+ def__call__(self,file_path:str)->pd.DataFrame:
+"""Extract information from h5 file."""
+ withh5py.File(file_path,"r")asf:
+ available_tables=[fforfinf.keys()]
+ ifself._tableinavailable_tables:
+ array=f[self._table][:]
+ # Will throw error if the number of columns don't match
+ self._verify_columns(array)
+ df=pd.DataFrame(array,columns=self._column_names)
+ returndf
+ else:
+ returnNone
+
+ def_verify_columns(self,array:np.ndarray)->None:
+ try:
+ assertarray.shape[1]==len(self._column_names)
+ exceptAssertionErrorase:
+ self.error(
+ f"Got {len(self._column_names)} column names but "
+ f"{self._table} has {array.shape[1]}. Please make sure "
+ f"that the column names match. ({self._column_names})"
+ )
+ raisee
+
+
+
+
+class H5HitExtractor(H5Extractor):
+    """Extractor for `HitData` in LiquidO H5 files."""
+
+    def __init__(self) -> None:
+        """Extractor for `HitData` in LiquidO H5 files."""
+        # Base class constructor
+        super().__init__(
+            extractor_name="HitData",
+            column_names=[
+                "event_no",
+                "sipmID",
+                "sipm_x",
+                "sipm_y",
+                "sipm_z",
+                "t",
+                "var",
+            ],
+        )
+
+
+class H5TruthExtractor(H5Extractor):
+    """Extractor for `TruthData` in LiquidO H5 files."""
+
+    def __init__(self) -> None:
+        """Extractor for `TruthData` in LiquidO H5 files."""
+        # Base class constructor
+        super().__init__(
+            extractor_name="TruthData",
+            column_names=[
+                "event_no",
+                "vertex_x",
+                "vertex_y",
+                "vertex_z",
+                "zenith",
+                "azimuth",
+                "interaction_time",
+                "energy",
+                "pid",
+            ],
+        )
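A usage sketch (the file path is hypothetical; the number of columns in each table must match the extractor's `column_names`):

    hit_extractor = H5HitExtractor()
    truth_extractor = H5TruthExtractor()

    # Each call returns a pandas DataFrame, or None if the table
    # is absent from the file.
    hits_df = hit_extractor("/path/to/liquido_file.h5")
    truth_df = truth_extractor("/path/to/liquido_file.h5")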
Source code for graphnet.data.extractors.prometheus.prometheus_extractor
+"""Parquet Extractor for conversion of simulation files from PROMETHEUS."""
+fromtypingimportList
+importpandasaspd
+importnumpyasnp
+
+fromgraphnet.data.extractorsimportExtractor
+
+
+
+[docs]
+classPrometheusExtractor(Extractor):
+"""Class for extracting information from PROMETHEUS parquet files.
+
+ Contains functionality required to extract data from PROMETHEUS parquet
+ files.
+ """
+
+ def__init__(self,extractor_name:str,columns:List[str]):
+"""Construct PrometheusExtractor.
+
+ Args:
+ extractor_name: Name of the `PrometheusExtractor` instance.
+ Used to keep track of the provenance of different data,
+ and to name tables to which this data is saved.
+ columns: List of column names to extract from the table.
+ """
+ # Member variable(s)
+ self._table=extractor_name
+ self._columns=columns
+ # Base class constructor
+ super().__init__(extractor_name=extractor_name)
+
+ def__call__(self,event:pd.DataFrame)->pd.DataFrame:
+"""Extract information from parquet file."""
+ output={key:[]forkeyinself._columns}# type: ignore
+ forkeyinself._columns:
+ ifkeyinevent.keys():
+ data=event[key]
+ ifisinstance(data,np.ndarray):
+ data=data.tolist()
+ ifisinstance(data,list):
+ output[key].extend(data)
+ else:
+ output[key].append(data)
+ else:
+ self.warning_once(f"{key} not found in {self._table}!")
+ returnoutput
+
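A usage sketch (the parquet path is hypothetical; each row of a PROMETHEUS file holds one event, mirroring how `PrometheusReader` applies extractors further down in this diff):

    import pandas as pd

    extractor = PrometheusExtractor(
        extractor_name="photons",
        columns=["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"],
    )
    events = pd.read_parquet("/path/to/prometheus_file.parquet")
    # Apply to a single event (row) of the "photons" table:
    photons = extractor(events["photons"][0])
    # -> dict mapping each column name to a list of per-photon values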
+
+
+
+class PrometheusTruthExtractor(PrometheusExtractor):
+    """Class for extracting event-level truth from Prometheus parquet files.
+
+    This Extractor will extract the "initial_state" quantities, i.e. the
+    neutrino truth.
+    """
+
+    def __init__(self, table_name: str = "mc_truth") -> None:
+        """Construct PrometheusTruthExtractor.
+
+        Args:
+            table_name: Name of the table in the parquet files that contains
+                event-level truth. Defaults to "mc_truth".
+        """
+        columns = [
+            "interaction",
+            "initial_state_energy",
+            "initial_state_type",
+            "initial_state_zenith",
+            "initial_state_azimuth",
+            "initial_state_x",
+            "initial_state_y",
+            "initial_state_z",
+        ]
+        super().__init__(extractor_name=table_name, columns=columns)
+
+
+class PrometheusFeatureExtractor(PrometheusExtractor):
+    """Class for extracting pulses/photons from Prometheus parquet files."""
+
+    def __init__(self, table_name: str = "photons"):
+        """Construct PrometheusFeatureExtractor.
+
+        Args:
+            table_name: Name of table in parquet files that contains the
+                photons/pulses. Defaults to "photons".
+        """
+        columns = [
+            "sensor_pos_x",
+            "sensor_pos_y",
+            "sensor_pos_z",
+            "string_id",
+            "sensor_id",
+            "t",
+        ]
+        super().__init__(extractor_name=table_name, columns=columns)
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/_modules/graphnet/data/parquet/deprecated_methods.html b/_modules/graphnet/data/parquet/deprecated_methods.html
index 238ca2429..ca1425ee8 100644
--- a/_modules/graphnet/data/parquet/deprecated_methods.html
+++ b/_modules/graphnet/data/parquet/deprecated_methods.html
from typing import List, Union, Type

from graphnet.data import DataConverter
-from graphnet.data.readers import I3Reader
+from graphnet.data.readers import I3Reader, ParquetReader
from graphnet.data.writers import ParquetWriter, SQLiteWriter
from graphnet.data.extractors.icecube import I3Extractor
+from graphnet.data.extractors.internal import ParquetExtractor
from graphnet.data.extractors.icecube.utilities.i3_filters import I3Filter
@@ -427,6 +455,45 @@
            outdir=outdir,
        )
+
+
+
+class ParquetToSQLiteConverter(DataConverter):
+    """Preconfigured DataConverter for converting Parquet to SQLite files.
+
+    This class converts Parquet files written by ParquetWriter to SQLite.
+    """
+
+    def __init__(
+        self,
+        extractors: List[ParquetExtractor],
+        outdir: str,
+        index_column: str = "event_no",
+        num_workers: int = 1,
+    ):
+        """Convert internal Parquet files to SQLite.
+
+        Args:
+            extractors: The `Extractor`(s) that will be applied to the input
+                files.
+            outdir: The directory to save the files in.
+            index_column: Name of the event id column added to the events.
+                Defaults to "event_no".
+            num_workers: The number of CPUs used for parallel processing.
+                Defaults to 1 (no multiprocessing).
+        """
+        super().__init__(
+            file_reader=ParquetReader(),
+            save_method=SQLiteWriter(),
+            extractors=extractors,
+            num_workers=num_workers,
+            index_column=index_column,
+            outdir=outdir,
+        )
+
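A usage sketch (paths and the extractor table names are placeholders):

    converter = ParquetToSQLiteConverter(
        extractors=[ParquetExtractor("truth"), ParquetExtractor("photons")],
        outdir="/path/to/sqlite_output",
        num_workers=2,
    )
    converter("/path/to/internal_parquet_dir")  # convert all found files
    converter.merge_files()  # merge the per-file databases into one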
diff --git a/_modules/graphnet/data/readers/graphnet_file_reader.html b/_modules/graphnet/data/readers/graphnet_file_reader.html
index b14cda367..f7706ba94 100644
--- a/_modules/graphnet/data/readers/graphnet_file_reader.html
+++ b/_modules/graphnet/data/readers/graphnet_file_reader.html
file formats."""

-from typing import List, Union, OrderedDict, Any
+from typing import List, Union, OrderedDict, Any, Dict
from abc import abstractmethod, ABC
import glob
import os
+import pandas as pd

from graphnet.utilities.decorators import final
from graphnet.utilities.logging import Logger
from graphnet.data.dataclasses import I3FileSet
from graphnet.data.extractors.extractor import Extractor
from graphnet.data.extractors.icecube import I3Extractor
+from graphnet.data.extractors.internal import ParquetExtractor
+from graphnet.data.extractors.liquido import H5Extractor
+from graphnet.data.extractors.prometheus import PrometheusExtractor
@@ -357,13 +388,21 @@
    _accepted_extractors: List[Any] = []

    @abstractmethod
-    def __call__(self, file_path: Union[str, I3FileSet]) -> List[OrderedDict]:
+    def __call__(
+        self, file_path: Any
+    ) -> Union[List[OrderedDict[str, pd.DataFrame]], Dict[str, pd.DataFrame]]:
        """Open and apply extractors to a single file.

-        The `output` must be a list of dictionaries, where the number of
-        events in the file `n_events` satisfies `len(output) = n_events`.
-        I.e. each element in the list is a dictionary, and each field in the
-        dictionary is the output of a single extractor.
+        The `output` must be either
+        A) a list of dictionaries, where the number of events
+           in the file `n_events` satisfies `len(output) = n_events`.
+           I.e. each element in the list is a dictionary, and each field in
+           the dictionary is the output of a single extractor. If this is
+           provided, the `DataConverter` will automatically assign event ids.
+        B) a single dictionary where each field contains a single dataframe,
+           which holds the data from the `Extractor` for the entire file. In
+           this case, the `Reader` must itself assign event ids. This method
+           is faster if your files are not storing events serially.
        """

    @property
@@ -412,7 +451,14 @@
    @final
    def set_extractors(
-        self, extractors: Union[List[Extractor], List[I3Extractor]]
+        self,
+        extractors: Union[
+            List[Extractor],
+            List[I3Extractor],
+            List[ParquetExtractor],
+            List[H5Extractor],
+            List[PrometheusExtractor],
+        ],
    ) -> None:
        """Set `Extractor`(s) as member variable.
@@ -427,7 +473,14 @@
            icetray_verbose: Set the level of verbosity of icetray.
                Defaults to 0.
        """
+        # checks
+        assert isinstance(gcd_rescue, str)
        # Set verbosity
        if icetray_verbose == 0:
            icetray.I3Logger.global_logger = icetray.I3NullLogger()
@@ -442,12 +471,16 @@
Source code for graphnet.data.readers.prometheus_reader
+"""Modules for reading data files from the Prometheus project."""
+
+fromtypingimportList,Union,OrderedDict
+importpandasaspd
+frompathlibimportPath
+
+fromgraphnet.data.extractors.prometheusimportPrometheusExtractor
+from.graphnet_file_readerimportGraphNeTFileReader
+
+
+
+[docs]
+classPrometheusReader(GraphNeTFileReader):
+"""A class for reading parquet files from Prometheus simulation."""
+
+ _accepted_file_extensions=[".parquet"]
+ _accepted_extractors=[PrometheusExtractor]
+
+ def__call__(self,file_path:str)->List[OrderedDict]:
+"""Extract data from single parquet file.
+
+ Args:
+ file_path: Path to parquet file.
+
+ Returns:
+ Extracted data.
+ """
+ # Open file
+ outputs=[]
+ file=pd.read_parquet(file_path)
+ forkinrange(len(file)):# Loop over events in file
+ extracted_event=OrderedDict()
+ forextractorinself._extractors:
+ assertisinstance(extractor,PrometheusExtractor)
+ ifextractor._tableinfile.columns:
+ output=extractor(file[extractor._table][k])
+ extracted_event[extractor._extractor_name]=output
+ outputs.append(extracted_event)
+ returnoutputs
+
+
+[docs]
+ deffind_files(self,path:Union[str,List[str]])->List[str]:
+"""Search folder(s) for parquet files.
+
+ Args:
+ path: directory to search for parquet files.
+
+ Returns:
+ List of parquet files in the folders.
+ """
+ files=[]
+ ifisinstance(path,str):
+ path=[path]
+
+ # List of files as Path objects
+ forpinpath:
+ files.extend(
+ list(Path(p).rglob(f"*{self.accepted_file_extensions}"))
+ )
+
+ # List of files as str's
+ paths_as_str:List[str]=[]
+ forfinfiles:
+ paths_as_str.append(f.absolute().as_posix())
+
+ returnpaths_as_str
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/_modules/graphnet/data/sqlite/deprecated_methods.html b/_modules/graphnet/data/sqlite/deprecated_methods.html
index d97e492fc..d431d8126 100644
--- a/_modules/graphnet/data/sqlite/deprecated_methods.html
+++ b/_modules/graphnet/data/sqlite/deprecated_methods.html
Source code for graphnet.data.utilities.parquet_to_sqlite
-"""Utilities for converting files from Parquet to SQLite."""
-
-importglob
-importos
-fromtypingimportList,Optional,Union
-
-importawkwardasak
-importnumpyasnp
-importpandasaspd
-fromtqdm.autoimporttrange
-
-fromgraphnet.data.utilities.sqlite_utilitiesimport(
- create_table_and_save_to_sql,
-)
-fromgraphnet.utilities.loggingimportLogger
-
-
-
-[docs]
-classParquetToSQLiteConverter(Logger):
-"""Convert Parquet files to a SQLite database.
-
- Each event in the parquet file(s) are assigned a unique event id. By
- default, every field in the parquet file(s) are extracted. One can choose
- to exclude certain fields by using the argument exclude_fields.
- """
-
- def__init__(
- self,
- parquet_path:Union[str,List[str]],
- mc_truth_table:str="mc_truth",
- excluded_fields:Optional[Union[str,List[str]]]=None,
- ):
-"""Construct `ParquetToSQLiteConverter`."""
- # checks
- ifisinstance(parquet_path,str):
- pass
- elifisinstance(parquet_path,list):
- assertisinstance(
- parquet_path[0],str
- ),"Argument `parquet_path` must be a string or list of strings"
- else:
- assertisinstance(
- parquet_path,str
- ),"Argument `parquet_path` must be a string or list of strings"
-
- assertisinstance(
- mc_truth_table,str
- ),"Argument `mc_truth_table` must be a string"
- self._parquet_files=self._find_parquet_files(parquet_path)
- ifexcluded_fieldsisnotNone:
- self._excluded_fields=excluded_fields
- else:
- self._excluded_fields=[]
- self._mc_truth_table=mc_truth_table
- self._event_counter=0
-
- # Base class constructor
- super().__init__(name=__name__,class_name=self.__class__.__name__)
-
- def_find_parquet_files(self,paths:Union[str,List[str]])->List[str]:
- ifisinstance(paths,str):
- ifpaths.endswith(".parquet"):
- files=[paths]
- else:
- files=glob.glob(f"{paths}/*.parquet")
- elifisinstance(paths,list):
- files=[]
- forpathinpaths:
- files.extend(self._find_parquet_files(path))
- assertlen(files)>0,f"No files found in {paths}"
- returnfiles
-
-
diff --git a/_modules/graphnet/data/writers/graphnet_writer.html b/_modules/graphnet/data/writers/graphnet_writer.html
index 872aa88c1..81aeb36eb 100644
--- a/_modules/graphnet/data/writers/graphnet_writer.html
+++ b/_modules/graphnet/data/writers/graphnet_writer.html
    @final
-    def merge_files(self, files: List[str], output_dir: str) -> None:
-        """Merge parquet files.
+    def merge_files(
+        self,
+        files: List[str],
+        output_dir: str,
+        events_per_batch: int = 200000,
+        num_workers: int = 1,
+    ) -> None:
+        """Convert files into shuffled batches.
-        Args:
-            files: input files for merging.
-            output_dir: directory to store merged file(s) in.
+        Events will be shuffled, and the resulting batches will constitute
+        random subsamples of the full dataset.
-        Raises:
-            NotImplementedError
+        Args:
+            files: Files converted to parquet. Note this argument is ignored
+                by this method, as these files are automatically found
+                using the `output_dir`.
+            output_dir: The directory to store the batched data.
+            events_per_batch: Number of events in each batch.
+                Defaults to 200000.
+            num_workers: Number of workers to use for merging. Defaults to 1.
        """
-        self.error(f"{self.__class__.__name__} does not have a merge method.")
-        raise NotImplementedError
            you have many events, as tables exceeding 400 million rows tend
            to be noticeably slower to query.
            Defaults to None (All events are put into a single database).
+        index_column: Name of column that contains event id.
        """
        # Member Variables
        self._file_extension = ".db"
        self._merge_dataframes = True
        self._max_table_size = max_table_size
        self._database_name = merged_database_name
+        self._index_column = index_column

        # Add file extension to database name if forgotten
        if not self._database_name.endswith(self._file_extension):
@@ -409,6 +439,7 @@
                    output_file_path,
                    default_type="FLOAT",
                    integer_primary_key=len(df) <= n_events,
+                    index_column=self._index_column,
                )
                saved_any = True
@@ -423,6 +454,7 @@
        self,
        files: List[str],
        output_dir: str,
+        primary_key_rescue: str = "event_no",
    ) -> None:
        """SQLite-specific method for merging output files/databases.
@@ -437,6 +469,9 @@
            you have many events, as tables exceeding 400 million rows tend
            to be noticeably slower to query. Defaults to None (All events
            are put into a single database.)
+            primary_key_rescue: The name of the columns on which the primary
+                key is constructed. This will only be used if it is not
+                possible to infer the primary key name.
        """
        # Warnings
        if self._max_table_size:
@@ -448,10 +483,10 @@
        # Set variables
        self._partition_count = 1
+        self._primary_key_rescue = primary_key_rescue

        # Construct full database path
        database_path = os.path.join(output_dir, self._database_name)
-        print(database_path)

        # Start merging if files are given
        if len(files) > 0:
            os.makedirs(output_dir, exist_ok=True)
@@ -487,10 +522,11 @@
        # Merge temporary databases into newly created one
        for file_count, input_file in tqdm(enumerate(files), colour="green"):
-
            # Extract table names and index column name in database
            try:
                tables, primary_key = get_primary_keys(database=input_file)
+                if primary_key is None:
+                    primary_key = self._primary_key_rescue
            except AssertionError as e:
                if "No tables found in database." in str(e):
                    self.warning(f"Database {input_file} is empty. Skipping.")
@@ -585,7 +621,7 @@
+"""A CuratedDataset for unit tests."""
+fromtypingimportDict,Any,List,Tuple,Union
+importos
+
+fromgraphnet.dataimportERDAHostedDataset
+fromgraphnet.data.constantsimportFEATURES
+
+
+
+[docs]
+classTestDataset(ERDAHostedDataset):
+"""A CuratedDataset class for unit tests of ERDAHosted Datasets.
+
+ This dataset should not be used outside the context of unit tests.
+ """
+
+ # Static Member Variables:
+ _pulsemaps=["photons"]
+ _truth_table="mc_truth"
+ _event_truth=[
+ "interaction",
+ "initial_state_energy",
+ "initial_state_type",
+ "initial_state_zenith",
+ "initial_state_azimuth",
+ "initial_state_x",
+ "initial_state_y",
+ "initial_state_z",
+ ]
+ _pulse_truth=None
+ _features=FEATURES.PROMETHEUS
+ _experiment="ARCA Prometheus Simulation"
+ _creator="Rasmus F. Ørsøe"
+ _comments=(
+ "This Dataset should be used for unit tests only."
+ " Simulation produced by Stephan Meighen-Berger, "
+ "U. Melbourne."
+ )
+ _available_backends=["sqlite"]
+ _file_hashes={"sqlite":"EK3hSNgYr5"}
+ _citation=None
+
+ def_prepare_args(
+ self,backend:str,features:List[str],truth:List[str]
+ )->Tuple[Dict[str,Any],Union[List[int],None],Union[List[int],None]]:
+"""Prepare arguments for dataset.
+
+ Args:
+ backend: backend of dataset. Either "parquet" or "sqlite"
+ features: List of features from user to use as input.
+ truth: List of event-level truth form user.
+
+ Returns: Dataset arguments and selections
+ """
+ dataset_path=os.path.join(self.dataset_dir,"merged.db")
+
+ dataset_args={
+ "truth_table":self._truth_table,
+ "pulsemaps":self._pulsemaps,
+ "path":dataset_path,
+ "graph_definition":self._graph_definition,
+ "features":features,
+ "truth":truth,
+ }
+ selection=[0,1,2,3,4,5,6,7,8,9]# event 5 is empty
+ returndataset_args,selection,None
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/_modules/graphnet/deployment/deployer.html b/_modules/graphnet/deployment/deployer.html
index ea737ce89..a861b606f 100644
--- a/_modules/graphnet/deployment/deployer.html
+++ b/_modules/graphnet/deployment/deployer.html
@@ -281,14 +280,42 @@
        This module incorporates sinusoidal positional embeddings and
        auxiliary embeddings to process input sequences and produce meaningful
-       representations.
+       representations. The module assumes that the input data is in the
+       format of (x, y, z, time, charge, auxiliary), the first four features
+       being mandatory.
    """

    def __init__(
        self,
        seq_length: int = 128,
+       mlp_dim: Optional[int] = None,
        output_dim: int = 384,
        scaled: bool = False,
+       n_features: int = 6,
    ):
        """Construct `FourierEncoder`.

        Args:
            seq_length: Dimensionality of the base sinusoidal positional
                embeddings.
-           output_dim: Output dimensionality of the final projection.
+           mlp_dim (Optional): Size of hidden, latent space of MLP. If not
+               given, `mlp_dim` is set automatically as a multiple of
+               `seq_length` (consistent with the 2nd place solution),
+               depending on `n_features`.
+           output_dim: Dimension of the output (i.e. number of columns).
            scaled: Whether or not to scale the embeddings.
+           n_features: The number of features in the input data.
        """
        super().__init__()
+
        self.sin_emb = SinusoidalPosEmb(dim=seq_length, scaled=scaled)
        self.aux_emb = nn.Embedding(2, seq_length // 2)
        self.sin_emb2 = SinusoidalPosEmb(dim=seq_length // 2, scaled=scaled)
-       self.projection = nn.Sequential(
-           nn.Linear(6 * seq_length, 6 * seq_length),
-           nn.LayerNorm(6 * seq_length),
+
+       if n_features < 4:
+           raise ValueError(
+               f"At least x, y, z and time of the DOM are required. Got only "
+               f"{n_features} features."
+           )
+       elif n_features >= 6:
+           hidden_dim = 6 * seq_length
+       else:
+           hidden_dim = int((n_features + 0.5) * seq_length)
+
+       if mlp_dim is None:
+           mlp_dim = hidden_dim
+
+       self.mlp = nn.Sequential(
+           nn.Linear(hidden_dim, mlp_dim),
+           nn.LayerNorm(mlp_dim),
            nn.GELU(),
-           nn.Linear(6 * seq_length, output_dim),
+           nn.Linear(mlp_dim, output_dim),
        )
+       self.n_features = n_features
+
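A small sketch of how the hidden dimension is resolved (the numbers follow the branching above; values are illustrative only):

    # n_features >= 6 -> hidden_dim = 6 * seq_length
    # n_features == 5 -> hidden_dim = int(5.5 * seq_length)
    # n_features == 4 -> hidden_dim = int(4.5 * seq_length)
    # n_features <  4 -> ValueError
    encoder = FourierEncoder(
        seq_length=128,
        mlp_dim=None,    # resolved to hidden_dim, here int(4.5 * 128) = 576
        output_dim=384,
        scaled=False,
        n_features=4,
    )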
diff --git a/_modules/graphnet/models/components/pool.html b/_modules/graphnet/models/components/pool.html
index 11ce78f3d..ce921b2ce 100644
--- a/_modules/graphnet/models/components/pool.html
+++ b/_modules/graphnet/models/components/pool.html
diff --git a/_modules/graphnet/models/detector/icecube.html b/_modules/graphnet/models/detector/icecube.html
index 67974ea57..b43da6a7e 100644
--- a/_modules/graphnet/models/detector/icecube.html
+++ b/_modules/graphnet/models/detector/icecube.html
class ORCA150(Detector):
-    """`Detector` class for Prometheus prototype."""
+    """`Detector` class for Prometheus ORCA150."""

    geometry_table_path = os.path.join(
-        PROMETHEUS_GEOMETRY_TABLE_DIR, "orca_150.parquet"
+        PROMETHEUS_GEOMETRY_TABLE_DIR, "orca.parquet"
    )
    xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
    string_id_column = "sensor_string_id"
@@ -357,11 +636,149 @@
Source code for gr
Created using
- Sphinx 7.2.6.
+ Sphinx 7.3.7.
and
Material for
Sphinx
diff --git a/_modules/graphnet/models/gnn/dynedge_kaggle_tito.html b/_modules/graphnet/models/gnn/dynedge_kaggle_tito.html
index 61998b585..29744d2af 100644
--- a/_modules/graphnet/models/gnn/dynedge_kaggle_tito.html
+++ b/_modules/graphnet/models/gnn/dynedge_kaggle_tito.html
@@ -281,14 +280,42 @@
Source code for graphnet.
        scaled_emb: bool = False,
        include_dynedge: bool = False,
        dynedge_args: Dict[str, Any] = None,
+        n_features: int = 6,
+    ):
+        """Construct `DeepIce`.
+
+        Args:
+            hidden_dim: The latent feature dimension.
+            mlp_ratio: Mlp expansion ratio of FourierEncoder and Transformer.
            seq_length: The base feature dimension.
            depth: The depth of the transformer.
            head_size: The size of the attention heads.
@@ -385,11 +415,16 @@
Source code for graphnet.
            provided, DynEdge will be initialized with the original Kaggle
            Competition settings. If `include_dynedge` is False, this argument
            has no impact.
+            n_features: The number of features in the input data.
+        """
+        super().__init__(seq_length, hidden_dim)
+        fourier_out_dim = hidden_dim // 2 if include_dynedge else hidden_dim
+        self.fourier_ext = FourierEncoder(
-            seq_length, fourier_out_dim, scaled=scaled_emb
+            seq_length=seq_length,
+            mlp_dim=None,
+            output_dim=fourier_out_dim,
+            scaled=scaled_emb,
+            n_features=n_features,
+        )
+        self.rel_pos = SpacetimeEncoder(head_size)
+        self.sandwich = nn.ModuleList(
@@ -406,7 +441,7 @@
Source code for g
Created using
- Sphinx 7.2.6.
+ Sphinx 7.3.7.
and
Material for
Sphinx
diff --git a/_modules/graphnet/models/graphs/edges/minkowski.html b/_modules/graphnet/models/graphs/edges/minkowski.html
index eb0721ff9..1c218cda2 100644
--- a/_modules/graphnet/models/graphs/edges/minkowski.html
+++ b/_modules/graphnet/models/graphs/edges/minkowski.html
@@ -281,14 +280,42 @@
Source code for g
max_pulses: Maximum number of pulses to keep in the event. z_name: Name of the z-coordinate column. hlc_name: Name of the `Hard Local Coincidence Check` column.
+            add_ice_properties: If True, scattering and absorption length of
+ ice in IceCube are added to the feature set based on z coordinate.
+ ice_args: Offset and scaling of the z coordinate in the Detector,
+ to be able to make similar conversion in the ice data. """
- super().__init__(input_feature_names=input_feature_names)
-
        if input_feature_names is None:
            input_feature_names = ["dom_x",
@@ -690,33 +724,39 @@
Source code for g
"rde",]
-        if z_name not in input_feature_names:
-            raise ValueError(
-                f"z name {z_name} not found in "
-                f"input_feature_names {input_feature_names}"
-            )
+        if add_ice_properties:
+            if z_name not in input_feature_names:
+                raise ValueError(
+                    f"z name '{z_name}' not found in "
+                    f"input_feature_names {input_feature_names}"
+                )
+            self.all_features = input_feature_names + [
+                "scatt_lenght",
+                "abs_lenght",
+            ]
+            self.f_scattering, self.f_absoprtion = ice_transparency(**ice_args)
+        else:
+            self.all_features = input_feature_names
+
+        super().__init__(input_feature_names=input_feature_names)
+
        if hlc_name not in input_feature_names:
-            raise ValueError(
-                f"hlc name {hlc_name} not found in "
-                f"input_feature_names {input_feature_names}"
+            self.warning(
+                f"hlc name '{hlc_name}' not found in input_feature_names"
+                f" '{input_feature_names}', subsampling will be random.")
-
-        self.all_features = input_feature_names + [
-            "scatt_lenght",
-            "abs_lenght",
-        ]
+            hlc_name = None
+
+        self.feature_indexes = {
+            feat: self.all_features.index(feat) for feat in input_feature_names
+        }
-        self.f_scattering, self.f_absoprtion = ice_transparency()
-
        self.input_feature_names = input_feature_names
        self.n_features = len(self.all_features)
        self.max_length = max_pulses
        self.z_name = z_name
        self.hlc_name = hlc_name
+        self.add_ice_properties = add_ice_properties

    def _define_output_feature_names(
        self, input_feature_names: List[str]
@@ -743,35 +783,47 @@
Source code for g
Created using
- Sphinx 7.2.6.
+ Sphinx 7.3.7.
and
Material for
Sphinx
diff --git a/_modules/graphnet/models/graphs/utils.html b/_modules/graphnet/models/graphs/utils.html
index 18ba17add..99d9f1c47 100644
--- a/_modules/graphnet/models/graphs/utils.html
+++ b/_modules/graphnet/models/graphs/utils.html
@@ -281,14 +280,42 @@
"""Utility functions for construction of graphs."""
-from typing import List, Tuple
+from typing import List, Tuple, Optional, Dict, Union
import os
import numpy as np
import pandas as pd
@@ -503,13 +530,19 @@
Source code for graphne
[docs]
-def ice_transparency() -> Tuple[interp1d, interp1d]:
+def ice_transparency(
+    z_offset: float = None, z_scaling: float = None
+) -> Tuple[interp1d, interp1d]:
    """Return interpolation functions for optical properties of IceCube.

    NOTE: The resulting interpolation functions assume that the Z-coordinate
    of the pulses is scaled as `z = z/500`. Any deviation from this scaling
    method results in inaccurate results.
+ Args:
+ z_offset: Offset to be added to the depth of the DOM.
+ z_scaling: Scaling factor to be applied to the depth of the DOM.
+
Returns: f_scattering: Function that takes a normalized depth and returns the corresponding normalized scattering length.
@@ -520,8 +553,11 @@
-[docs]
- deffit_weights(
- self,
- config_outdir:str,
- weight_name:str="",
- pisa_config_dict:Optional[Dict]=None,
- add_to_database:bool=False,
- )->pd.DataFrame:
-"""Fit flux weights to each neutrino event in `self._database_path`.
-
- If `statistical_fit=True`, only statistical effects are accounted for.
- If `True`, certain systematic effects are included, but not
- hypersurfaces.
-
- Args:
- config_outdir: The output directory in which to store the
- configuration.
- weight_name: The name of the weight. If `add_to_database=True`,
- this will be the name of the table.
- pisa_config_dict: The dictionary of PISA configurations. Can be
- used to change assumptions regarding the fit.
- add_to_database: If `True`, a table will be added to the database
- called `weight_name` with two columns:
- `[index_column, weight_name]`
-
- Returns:
- A dataframe with columns `[index_column, weight_name]`.
- """
- # If its a standard weight
- ifpisa_config_dictisNone:
- ifnotweight_name:
- print(weight_name)
- weight_name="pisa_weight_graphnet_standard"
-
- # If it is a custom weight without name
- elifpisa_config_dictisnotNone:
- ifnotweight_name:
- weight_name="pisa_custom_weight"
-
- pisa_config_path=self._make_config(
- config_outdir,weight_name,pisa_config_dict
- )
-
- model=Pipeline(pisa_config_path)
-
- ifself._statistical_fit=="True":
- # Only free parameters will be [aeff_scale] - corresponding to a statistical fit
- free_params=model.params.free.names
- forfree_paraminfree_params:
- iffree_paramnotin["aeff_scale"]:
- model.params[free_param].is_fixed=True
-
- # for stage in range(len(model.stages)):
- model.stages[-1].apply_mode="events"
- model.stages[-1].calc_mode="events"
- model.run()
-
- all_data=[]
- forcontainerinmodel.data:
- data=pd.DataFrame(container["event_no"],columns=["event_no"])
- data[weight_name]=container["weights"]
- all_data.append(data)
- results=pd.concat(all_data)
-
- ifadd_to_database:
- create_table_and_save_to_sql(
- results.columns,weight_name,self._database_path
- )
- returnresults.sort_values("event_no").reset_index(drop=True)
-[docs]
-defplot_2D_contour(
- contour_data:List[Dict],
- xlim:Tuple[float,float]=(0.4,0.6),
- ylim:Tuple[float,float]=(2.38*1e-3,2.55*1e-3),
- chi2_critical_value:float=4.605,
- width:float=3.176,
- height:float=2.388,
-)->Figure:
-"""Plot 2D contours from GraphNeT PISA fits.
-
- Args:
- contour_data: List of dictionaries with plotting information. Format is
- for each dictionary is:
- {'path': path_to_pisa_fit_result,
- 'model': 'name_of_my_model_in_fit'}.
- One can specify optional fields in the dictionary: "label" - the
- legend label, "color" - the color of the contour, "linestyle" - the
- style of the contour line.
- xlim: Lower and upper bound of x-axis.
- ylim: Lower and upper bound of y-axis.
- chi2_critical_value: The critical value of the chi2 fits. Defaults to
- 4.605 (90% CL). @NOTE: This, and the below, can't both be right.
- width: width of figure in inches.
- height: height of figure in inches.
-
- Returns:
- The figure with contours.
- """
- fig,ax=plt.subplots(figsize=(width,height),constrained_layout=True)
- proxy=[]
- labels=[]
- forentryincontour_data:
- entry_data,model_name,label,ls,color=read_entry(entry)
- model_idx=entry_data["model"]==model_name
- model_data=entry_data.loc[model_idx]
- x=pd.unique(model_data.sort_values("theta23_fixed")["theta23_fixed"])
- y=pd.unique(model_data.sort_values("dm31_fixed")["dm31_fixed"])
- z=np.zeros((len(y),len(x)))
- foriinrange(len(x)):
- forkinrange(len(y)):
- idx=(model_data["theta23_fixed"]==x[i])&(
- model_data["dm31_fixed"]==y[k]
- )
- match=model_data["mod_chi2"][idx]
- iflen(match)>0:
- ifmodel_data["converged"][idx].valuesisTrue:
- match=float(match)
- else:
- match=10000# Sets the z value very high to exclude it from contour
- else:
- match=10000# Sets the z value very high to exclude it from contour
- z[k,i]=match
-
- CS=ax.contour(
- np.sin(np.deg2rad(x))**2,
- y,
- z,
- levels=[chi2_critical_value],
- colors=color,
- label=label,
- linestyles=ls,
- linewidths=2,
- )
- # ax.clabel(CS, inline=1, fontsize=10)
- proxy.extend(
- [plt.Rectangle((0,0),1,1,fc=color)forpcinCS.collections]
- )
- ifchi2_critical_value==4.605:
- label=label+" 90 $\\%$ CL"
- labels.append(label)
- plt.legend(proxy,labels,frameon=False,loc="upper right")
- plt.xlim(xlim[0],xlim[1])
- plt.ylim(ylim[0],ylim[1])
- plt.xlabel("$\\sin^2(\\theta_{23})$",fontsize=12)
- plt.ylabel("$\\Delta m_{31}^2 [eV^2]$",fontsize=12)
- plt.ticklabel_format(axis="y",style="sci",scilimits=(0,0))
- plt.title("Sensitivity (Simplified Analysis)")
- returnfig
-
-
-
-
-[docs]
-defplot_1D_contour(
- contour_data:List[Dict],
- chi2_critical_value:float=2.706,
- width:float=2*3.176,
- height:float=2.388,
-)->Figure:
-"""Plot 1D contours from GraphNeT PISA fits.
-
- Args:
- contour_data: List of dictionaries with plotting information. Format is
- for each dictionary is:
- {'path': path_to_pisa_fit_result,
- 'model': 'name_of_my_model_in_fit'}.
- One can specify optional fields in the dictionary: "label" - the
- legend label, "color" - the color of the contour, "linestyle" - the
- style of the contour line.
- chi2_critical_value: The critical value of the chi2 fits. Defaults to
- 2.706 (90% CL). @NOTE: This, and the above, can't both be right.
- width: width of figure in inches.
- height: height of figure in inches.
-
- Returns:
- The figure with contours.
- """
- variables=["theta23_fixed","dm31_fixed"]
- fig,ax=plt.subplots(
- 1,2,figsize=(width,height),constrained_layout=True
- )
- ls=0
- forentryincontour_data:
- entry_data,model_name,label,ls,color=read_entry(entry)
- forvariableinvariables:
- model_idx=entry_data["model"]==model_name
- padding_idx=entry_data[variable]!=-1
- chi2_idx=entry_data["mod_chi2"]<chi2_critical_value
- model_data=entry_data.loc[
- (model_idx)&(padding_idx)&(chi2_idx),:
- ]
- x=model_data.sort_values(variable)
- ifvariable=="theta23_fixed":
- ax[0].set_ylabel("$\\chi^2$",fontsize=12)
- ax[0].plot(
- np.sin(np.deg2rad(x[variable]))**2,
- x["mod_chi2"],
- color=color,
- label=label,
- ls=ls,
- )
- ax[0].set_xlabel("$\\sin(\\theta_{23})^2$",fontsize=12)
- elifvariable=="dm31_fixed":
- ax[1].plot(
- x[variable],x["mod_chi2"],color=color,label=label,ls=ls
- )
- ax[1].ticklabel_format(axis="x",style="sci",scilimits=(0,0))
- ax[1].set_xlabel("$\\Delta m_{31}^2 [eV^2]$",fontsize=12)
- h=[item.get_text()foriteminax[1].get_yticklabels()]
- empty_string_labels=[""]*len(h)
- ax[1].set_yticklabels(empty_string_labels)
- ax[0].set_ylim(0,chi2_critical_value)
- ax[1].set_ylim(0,chi2_critical_value)
- ax[0].legend()
- returnfig
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/_modules/graphnet/training/callbacks.html b/_modules/graphnet/training/callbacks.html
index 7bea62b10..223ead861 100644
--- a/_modules/graphnet/training/callbacks.html
+++ b/_modules/graphnet/training/callbacks.html
@@ -281,14 +280,42 @@
Source code for graphnet.tr
        z = torch.cos(graph[self._zenith_key]).reshape(-1, 1)
        return torch.cat((x, y, z), dim=1)
+
+
+
+[docs]
+class Track(Label):
+"""Class for producing NuMuCC label.
+
+ Label is set to `1` if the event is a NuMu CC event, else `0`.
+ """
+
+    def __init__(
+        self,
+        key: str = "track",
+        pid_key: str = "pid",
+        interaction_key: str = "interaction_type",
+ ):
+"""Construct `Track` label.
+
+ Args:
+ key: The name of the field in `Data` where the label will be
+ stored. That is, `graph[key] = label`.
+ pid_key: The name of the pre-existing key in `graph` that will
+ be used to access the pdg encoding, used when calculating
+ the direction.
+ interaction_key: The name of the pre-existing key in `graph` that
+ will be used to access the interaction type (1 denoting CC),
+ used when calculating the direction.
+ """
+        self._pid_key = pid_key
+        self._int_key = interaction_key
+
+ # Base class constructor
+ super().__init__(key=key)
+
+    def __call__(self, graph: Data) -> torch.tensor:
+        """Compute label for `graph`."""
+        label = (graph[self._pid_key] == 14) & (graph[self._int_key] == 1)
+        return label.type(torch.int)
+
@@ -422,7 +488,7 @@
Source code for graphnet.tr
Created using
- Sphinx 7.2.6.
+ Sphinx 7.3.7.
and
Material for
Sphinx
diff --git a/_modules/graphnet/training/loss_functions.html b/_modules/graphnet/training/loss_functions.html
index a5dab2a8f..2b907c806 100644
--- a/_modules/graphnet/training/loss_functions.html
+++ b/_modules/graphnet/training/loss_functions.html
@@ -281,14 +280,42 @@
Source code for gra
Created using
- Sphinx 7.2.6.
+ Sphinx 7.3.7.
and
Material for
Sphinx
diff --git a/_modules/graphnet/training/utils.html b/_modules/graphnet/training/utils.html
index b2e9890bd..249ad1eb7 100644
--- a/_modules/graphnet/training/utils.html
+++ b/_modules/graphnet/training/utils.html
@@ -281,14 +280,42 @@
Source code for gra
Created using
- Sphinx 7.2.6.
+ Sphinx 7.3.7.
and
Material for
Sphinx
diff --git a/_modules/graphnet/utilities/argparse.html b/_modules/graphnet/utilities/argparse.html
index f46e09075..05cfe51ff 100644
--- a/_modules/graphnet/utilities/argparse.html
+++ b/_modules/graphnet/utilities/argparse.html
@@ -281,14 +280,42 @@
Source code for gr
Created using
- Sphinx 7.2.6.
+ Sphinx 7.3.7.
and
Material for
Sphinx
diff --git a/_modules/graphnet/utilities/config/training_config.html b/_modules/graphnet/utilities/config/training_config.html
index dda012503..0677ef70d 100644
--- a/_modules/graphnet/utilities/config/training_config.html
+++ b/_modules/graphnet/utilities/config/training_config.html
@@ -281,14 +280,42 @@
Source code for graphnet.
        folder_i3_files = list(filter(is_i3_file, folder_files))
        folder_gcd_files = list(filter(is_gcd_file, folder_files))
- # Make sure that no more than one GCD file is found; and use rescue file of none is found.
+ # Make sure that no more than one GCD file is found;
+        # and use rescue file if none is found.
        assert len(folder_gcd_files) <= 1
        if len(folder_gcd_files) == 0:
            assert gcd_rescue is not None
@@ -453,7 +481,7 @@
Source code for graphnet.
Created using
- Sphinx 7.2.6.
+ Sphinx 7.3.7.
and
Material for
Sphinx
diff --git a/_modules/graphnet/utilities/imports.html b/_modules/graphnet/utilities/imports.html
index 67929efcd..6726c5d70 100644
--- a/_modules/graphnet/utilities/imports.html
+++ b/_modules/graphnet/utilities/imports.html
@@ -281,14 +280,42 @@
Created using
- Sphinx 7.2.6.
+ Sphinx 7.3.7.
and
Material for
Sphinx
diff --git a/_sources/about.md.txt b/_sources/about.md.txt
deleted file mode 100644
index 013075b7d..000000000
--- a/_sources/about.md.txt
+++ /dev/null
@@ -1,48 +0,0 @@
-# About
-
-`GraphNeT` is an open-source Python framework aimed at providing high quality, user friendly, end-to-end functionality to perform reconstruction tasks at neutrino telescopes using graph neural networks (GNNs). `GraphNeT` makes it fast and easy to train complex models that can provide event reconstruction with state-of-the-art performance, for arbitrary detector configurations, with inference times that are orders of magnitude faster than traditional reconstruction techniques.
-
-## Impact
-
-`GraphNeT` provides a common framework for ML developers and physicists that wish to use the state-of-the-art GNN tools in their research. By uniting both user groups, `GraphNeT` aims to increase the longevity and usability of individual code contributions from ML developers by building a general, reusable software package based on software engineering best practices, and lowers the technical threshold for physicists that wish to use the most performant tools for their scientific problems.
-
-The `GraphNeT` models can improve event classification and yield very accurate reconstruction, e.g., for low energy neutrinos observed in IceCube. Here, a GNN implemented in `GraphNeT` was applied to the problem of neutrino oscillations in IceCube, leading to significant improvements in both energy and angular reconstruction in the energy range relevant to oscillation studies. Furthermore, it was shown that the GNN could improve muon vs. neutrino classification and thereby the efficiency and purity of a neutrino sample for such an analysis.
-
-Similarly, improved angular reconstruction has a great impact on, e.g., neutrino point source analyses.
-
-Finally, the fast (order millisecond) reconstruction allows for a whole new type of cosmic alerts at lower energies, which were previously unfeasible. GNN-based reconstruction makes it possible to identify low energy (< 10 TeV) neutrinos and monitor their rate, direction, and energy in real-time. This will enable cosmic neutrino alerts based on such neutrinos for the first time ever, despite a large background of neutrinos that are not of cosmic origin.
-
-## Usage
-
-`GraphNeT` comprises a number of modules providing the necessary tools to build workflows from ingesting raw training data in domain-specific formats to deploying trained models in domain-specific reconstruction chains, as illustrated in [the Figure](flowchart).
-
-
-:::{figure-md} flowchart
-:class: figclass
-
-
-
-High-level overview of a typical workflow using `GraphNeT`: `graphnet.data` enables converting domain-specific data to industry-standard, intermediate file formats and reading this data; `graphnet.models` allows for configuring and building complex GNN models using simple, physics-oriented components; `graphnet.training` manages model training and experiment logging; and finally, `graphnet.deployment` allows for using trained models for inference in domain-specific reconstruction chains.
-:::
-
-`graphnet.models` provides modular components subclassing `torch.nn.Module`, meaning that users only need to import a few existing, purpose-built components and chain them together to form a complete GNN. ML developers can contribute to `GraphNeT` by extending this suite of model components — through new layer types, physics tasks, graph connectivities, etc. — and experiment with optimising these for different reconstruction tasks using experiment tracking.
-
-These models are trained using `graphnet.training` on data prepared using `graphnet.data`, to satisfy the high I/O loads required when training ML models on large batches of events, which domain-specific neutrino physics data formats typically do not allow.
-
-Trained models are deployed to a domain-specific reconstruction chain, yielding predictions, using the components in `graphnet.deployment`. This can either be through model files or container images, making deployment as portable and dependency-free as possible.
-
-By splitting up the GNN development as in [the Figure](flowchart), `GraphNeT` allows physics users to interface only with high-level building blocks or pre-trained models that can be used directly in their reconstruction chains, while allowing ML developers to continuously improve and expand the framework’s capabilities.
-
-## Acknowledgements
-
-
-:::{figure-md} eu-emblem
-
-
-
-
-:::
-
-This project has received funding from the European Union’s Horizon 2020 research and innovation programme under the Marie Skłodowska-Curie grant agreement No. 890778.
-
-The work of Rasmus Ørsøe was partly performed in the framework of the PUNCH4NFDI consortium supported by DFG fund "NFDI 39/1", Germany.
\ No newline at end of file
diff --git a/_sources/about/about.rst.txt b/_sources/about/about.rst.txt
new file mode 100644
index 000000000..dee9f2902
--- /dev/null
+++ b/_sources/about/about.rst.txt
@@ -0,0 +1,35 @@
+.. include:: ../substitutions.rst
+
+
+|graphnet|\ GraphNeT is an open-source Python framework aimed at providing high quality, user friendly, end-to-end functionality to perform reconstruction tasks at neutrino telescopes using deep learning. |graphnet|\ GraphNeT makes it fast and easy to train complex models that can provide event reconstruction with state-of-the-art performance, for arbitrary detector configurations, with inference times that are orders of magnitude faster than traditional reconstruction techniques.
+|graphnet|\ GraphNeT provides a common, detector agnostic framework for ML developers and physicists that wish to use the state-of-the-art tools in their research. By uniting both user groups, |graphnet|\ GraphNeT aims to increase the longevity and usability of individual code contributions from ML developers by building a general, reusable software package based on software engineering best practices, and lowers the technical threshold for physicists that wish to use the most performant tools for their scientific problems.
+
+Usage
+-----
+
+|graphnet|\ GraphNeT comprises a number of modules providing the necessary tools to build workflows from ingesting raw training data in domain-specific formats to deploying trained models in domain-specific reconstruction chains, as illustrated in :numref:`flowchart`.
+
+.. _flowchart:
+.. figure:: ../../../paper/flowchart.png
+
+ High-level overview of a typical workflow using |graphnet|\ GraphNeT: :code:`graphnet.data` enables converting domain-specific data to industry-standard, intermediate file formats and reading this data; :code:`graphnet.models` allows for configuring and building complex models using simple, physics-oriented components; :code:`graphnet.training` manages model training and experiment logging; and finally, :code:`graphnet.deployment` allows for using trained models for inference in domain-specific reconstruction chains.
+
+:code:`graphnet.models` provides modular components subclassing :code:`torch.nn.Module`, meaning that users only need to import a few existing, purpose-built components and chain them together to form a complete model. ML developers can contribute to |graphnet|\ GraphNeT by extending this suite of model components — through new layer types, physics tasks, graph connectivities, etc. — and experiment with optimising these for different reconstruction tasks using experiment tracking.
+
+These models are trained using :code:`graphnet.training` on data prepared using :code:`graphnet.data`, to satisfy the high I/O loads required when training ML models on large batches of events, which domain-specific neutrino physics data formats typically do not allow.
+
+Trained models are deployed to a domain-specific reconstruction chain, yielding predictions, using the components in :code:`graphnet.deployment`. This can either be through model files or container images, making deployment as portable and dependency-free as possible.
+
+By splitting up the model development as in :numref:`flowchart`, |graphnet|\ GraphNeT allows physics users to interface only with high-level building blocks or pre-trained models that can be used directly in their reconstruction chains, while allowing ML developers to continuously improve and expand the framework’s capabilities.
+
+
+Acknowledgements
+----------------
+
+
+.. image:: ../../../assets/images/eu-emblem.jpg
+ :width: 150
+
+This project has received funding from the European Union’s Horizon 2020 research and innovation programme under the Marie Skłodowska-Curie grant agreement No. 890778.
+
+The work of Rasmus Ørsøe was partly performed in the framework of the PUNCH4NFDI consortium supported by DFG fund "NFDI 39/1", Germany.
\ No newline at end of file
diff --git a/_sources/api/graphnet.data.curated_datamodule.rst.txt b/_sources/api/graphnet.data.curated_datamodule.rst.txt
new file mode 100644
index 000000000..e544b93e9
--- /dev/null
+++ b/_sources/api/graphnet.data.curated_datamodule.rst.txt
@@ -0,0 +1,8 @@
+
+curated\_datamodule
+===================
+
+.. automodule:: graphnet.data.curated_datamodule
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/_sources/api/graphnet.data.extractors.combine_extractors.rst.txt b/_sources/api/graphnet.data.extractors.combine_extractors.rst.txt
new file mode 100644
index 000000000..c03d4a24d
--- /dev/null
+++ b/_sources/api/graphnet.data.extractors.combine_extractors.rst.txt
@@ -0,0 +1,8 @@
+
+combine\_extractors
+===================
+
+.. automodule:: graphnet.data.extractors.combine_extractors
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/_sources/api/graphnet.data.extractors.internal.parquet_extractor.rst.txt b/_sources/api/graphnet.data.extractors.internal.parquet_extractor.rst.txt
new file mode 100644
index 000000000..8c1b55330
--- /dev/null
+++ b/_sources/api/graphnet.data.extractors.internal.parquet_extractor.rst.txt
@@ -0,0 +1,8 @@
+
+parquet\_extractor
+==================
+
+.. automodule:: graphnet.data.extractors.internal.parquet_extractor
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/_sources/api/graphnet.data.extractors.internal.rst.txt b/_sources/api/graphnet.data.extractors.internal.rst.txt
new file mode 100644
index 000000000..677d3169d
--- /dev/null
+++ b/_sources/api/graphnet.data.extractors.internal.rst.txt
@@ -0,0 +1,36 @@
+
+internal
+========
+
+
+.. automodule:: graphnet.data.extractors.internal
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+
+|start-h2| Submodules |end-h2|
+
+
+.. toctree::
+ :maxdepth: 2
+
+ graphnet.data.extractors.internal.parquet_extractor
+
+
+
+.. |start-h2| raw:: html
+
+
\ No newline at end of file
diff --git a/_sources/api/graphnet.pisa.plotting.rst.txt b/_sources/api/graphnet.pisa.plotting.rst.txt
deleted file mode 100644
index c3df95bbf..000000000
--- a/_sources/api/graphnet.pisa.plotting.rst.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-
-plotting
-========
-
-.. automodule:: graphnet.pisa.plotting
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/_sources/api/graphnet.rst.txt b/_sources/api/graphnet.rst.txt
index 867e8c81e..cee836074 100644
--- a/_sources/api/graphnet.rst.txt
+++ b/_sources/api/graphnet.rst.txt
@@ -15,9 +15,10 @@ API
:maxdepth: 2
graphnet.data
+ graphnet.datasets
graphnet.deployment
+ graphnet.exceptions
graphnet.models
- graphnet.pisa
graphnet.training
graphnet.utilities
diff --git a/_sources/contribute.md.txt b/_sources/contribute.md.txt
deleted file mode 100644
index 4a45b1d1b..000000000
--- a/_sources/contribute.md.txt
+++ /dev/null
@@ -1,37 +0,0 @@
-# Contribute
-
-To make sure that the process of contributing is as smooth and effective as possible, we provide a few guidelines in this contributing guide that we encourage contributors to follow.
-
-## GitHub issues
-
-Use [GitHub issues](https://github.com/graphnet-team/graphnet/issues) for tracking and discussing requests and bugs. If there is anything you'd wish to contribute, the best place to start is to create a new issues and describe what you would like to work on. Alternatively you can assign open issues to yourself, to indicate that you would like to take ownership of a particular task. Using issues actively in this way ensures transparency and agreement on priorities. This helps avoid situations with a lot of development effort going into a feature that e.g. turns out to be outside of scope for the project; or a specific solution to a problem that could have been better solved differently.
-
-## Pull requests
-
-Develop code in a fork of the [main repo](https://github.com/graphnet-team/graphnet). Make contributions in dedicated development/feature branches on your forked repositories, e.g. if you are implementing a specific `GraphDefinition` class you could create a branch named `add-euclidean-graph-definition` on your own fork.
-
-Create pull requests from your development branch into `graphnet-team/graphnet:main` to contribute to the project. **To be accepted,** pull requests must:
- * pass all automated checks,
- * be reviewed by at least one other contributor. These reviews should check for:
- * standard python coding conventions, e.g. [PEP8](https://www.python.org/dev/peps/pep-0008/)
- * docstring (Google-style) and type hinting as necessary,
- * unit tests as necessary,
- * clean coding practices, see e.g. [here](https://gist.github.com/wojteklu/73c6914cc446146b8b533c0988cf8d29).
-
-## Conventions
-
-This repository aims to support python 3 version that are actively supported (currently `>=3.8`). Standard python coding conventions should be followed:
-
-* Adhere to [PEP 8](https://www.python.org/dev/peps/pep-0008/)
-* Use [pylint](https://www.pylint.org/)/[flake8](https://flake8.pycqa.org/) and [black](https://black.readthedocs.io/) to ensure as clean and well-formatted code as possible
-* When relevant, adhere to [clean coding practices](https://gist.github.com/wojteklu/73c6914cc446146b8b533c0988cf8d29)
-
-## Code quality
-
-To ensure consistency in code style and adherence to select best practices, we recommend that all developers use `black`, `flake8`, `mypy`, `pydocstyle`, and `docformatter` for automatically formatting and checking their code. This can conveniently be done using pre-commit hooks. To set this up, first make sure that you have installed the `pre-commit` python package. It comes with included when installing `graphnet` with the `develop` tag, i.e., `pip install -e .[develop]`. Then, do
-```bash
-$ pre-commit install
-```
-Then, everytime you commit a change, your code and docstrings will automatically be formatted using `black` and `docformatter`, and `flake8`, `mypy`, and `pydocstyle` will check for errors and adherence to PEP8, PEP257, and static typing. See an illustration of the concept below:
-![pre-commit pipeline](../../assets/images/precommit_pipeline.png)
-Image source: https://ljvmiranda921.github.io/notebook/2018/06/21/precommits-using-black-and-flake8/
\ No newline at end of file
diff --git a/_sources/contribute/contribute.rst.txt b/_sources/contribute/contribute.rst.txt
new file mode 100644
index 000000000..70ac9492f
--- /dev/null
+++ b/_sources/contribute/contribute.rst.txt
@@ -0,0 +1,54 @@
+.. include:: ../substitutions.rst
+
+Contributing To GraphNeT\ |graphnet-header|
+===========================================
+To make sure that the process of contributing is as smooth and effective as possible, we provide a few guidelines in this contributing guide that we encourage contributors to follow.
+
+GitHub issues
+-------------
+
+Use `GitHub issues `_ for tracking and discussing requests and bugs. If there is anything you wish to contribute, the best place to start is to create a new issue and describe what you would like to work on. Alternatively, you can assign open issues to yourself, to indicate that you would like to take ownership of a particular task. Using issues actively in this way ensures transparency and agreement on priorities. This helps avoid situations where a lot of development effort goes into a feature that, e.g., turns out to be outside the scope of the project, or into a specific solution to a problem that could have been solved better in a different way.
+
+Pull requests
+-------------
+
+Develop code in a fork of the `main repo `_. Make contributions in dedicated development/feature branches on your forked repositories, e.g. if you are implementing a specific :code:`GraphDefinition` class you could create a branch named :code:`add-euclidean-graph-definition` on your own fork.
+
+Create pull requests from your development branch into :code:`graphnet-team/graphnet:main` to contribute to the project. **To be accepted,** pull requests must:
+
+* pass all automated checks,
+
+* be reviewed by at least one other contributor. These reviews should check for:
+
+ #. standard python coding conventions, e.g. `PEP8 `_
+
+ #. docstring (Google-style) and type hinting as necessary
+
+ #. unit tests as necessary
+
+ #. clean coding practices, see e.g. `here `_.
+
+Conventions
+-----------
+
+This repository aims to support the Python 3 versions that are actively supported (currently :code:`>=3.8`). Standard Python coding conventions should be followed:
+
+* Adhere to `PEP8 `_
+* Use `pylint `_ / `flake8 `_ and `black `_ to ensure as clean and well-formatted code as possible
+* When relevant, adhere to `clean code practices `_
+
+Code quality
+------------
+
+To ensure consistency in code style and adherence to select best practices, we **require** that all developers use :code:`black`, :code:`flake8`, :code:`mypy`, :code:`pydocstyle`, and :code:`docformatter` for automatically formatting and checking their code. This can conveniently be done using pre-commit hooks. To set this up, first make sure that you have installed the :code:`pre-commit` python package. It is included when installing |graphnet|\ GraphNeT with the :code:`develop` tag, i.e., :code:`pip install -e .[develop]`. Then, do
+
+.. code-block:: bash
+
+ pre-commit install
+
+
+Then, every time you commit a change, your code and docstrings will automatically be formatted using :code:`black` and :code:`docformatter`, while :code:`flake8`, :code:`mypy`, and :code:`pydocstyle` will check for errors and adherence to PEP8, PEP257, and static typing. See an illustration of the concept below:
+
+.. image:: ../../../assets/images/precommit_pipeline.png
+
+Image source: https://ljvmiranda921.github.io/notebook/2018/06/21/precommits-using-black-and-flake8/
\ No newline at end of file
diff --git a/_sources/data_conversion/data_conversion.rst.txt b/_sources/data_conversion/data_conversion.rst.txt
new file mode 100644
index 000000000..6658ce656
--- /dev/null
+++ b/_sources/data_conversion/data_conversion.rst.txt
@@ -0,0 +1,278 @@
+.. include:: ../substitutions.rst
+
+Data Conversion in GraphNeT\ |graphnet-header|
+==============================================
+
+GraphNeT comes with powerful data conversion code that converts experiment-specific data formats into deep-learning-friendly data formats.
+
+Data conversion in GraphNeT follows a reader/writer scheme, where the :code:`DataConverter` does most of the heavy lifting.
+
+.. image:: ../../../assets/images/dataconverter.svg
+ :width: 500
+ :alt: Illustration of the reader/writer scheme of data conversion in GraphNeT.
+ :align: right
+ :class: with-shadow
+
+
+
+
+In the illustration, the "reader" module represents an experiment-specific implementation of :code:`GraphNeTFileReader`, able to parse your data files. The "Writer" module denotes a :code:`GraphNeTWriter` module that saves the interim data format from :code:`DataConverter` to disk.
+
+
+:code:`DataConverter`
+-------------------------
+
+:code:`DataConverter` provides parallel processing of file conversion and
+extraction from experiment-specific file formats to GraphNeT-supported data formats out-of-the-box.
+:code:`DataConverter` can also assign event ids to your events.
+
+Specifically, :code:`DataConverter` will manage multiprocessing calls to :code:`GraphNeTFileReader`\ s and pass their output to
+a :code:`GraphNeTWriter`, which will save the extracted data from your files in a specific file format.
+Below is an example of configuring :code:`DataConverter` to extract data from :code:`.h5` files from the LiquidO experiment,
+and to save the data as :code:`.parquet` files which are compatible with the :code:`ParquetDataset` in GraphNeT.
+
+
+.. code-block::
+
+ from graphnet.data.extractors.liquido import H5HitExtractor, H5TruthExtractor
+ from graphnet.data.dataconverter import DataConverter
+ from graphnet.data.readers import LiquidOReader
+ from graphnet.data.writers import ParquetWriter
+
+ # Your settings
+ dir_with_files = '/home/my_files'
+ outdir = '/home/my_outdir'
+ num_workers = 5
+
+ # Instantiate DataConverter - exports data from LiquidO to Parquet
+ converter = DataConverter(file_reader = LiquidOReader(),
+ save_method = ParquetWriter(),
+ extractors=[H5HitExtractor(), H5TruthExtractor()],
+ outdir=outdir,
+ num_workers=num_workers,
+ )
+ # Run Converter
+ converter(input_dir = dir_with_files)
+ # Merge files (Optional)
+ converter.merge_files()
+
+When :code:`converter(input_dir = dir_with_files)` is called, a `bijective` conversion is run, where every file
+is converted independently and in parallel. I.e. the parallelization is done over files.
+
+The :code:`converter.merge_files()` call merges these many smaller files into larger chunks of data, and the specific behavior of :code:`GraphNeTWriter.merge_files()`
+depends fully on the specific implementation of the :code:`GraphNeTWriter`.
+
+This modular structure means that extending GraphNeT conversion code to export experiment data to new file formats is as easy as implementing a new :code:`GraphNeTWriter`.
+Similarly, extending GraphNeT conversion code to work on data from a new experiment only requires implementing a new :code:`GraphNeTFileReader` and its associated :code:`Extractor`\ s.
+
+
+:code:`Readers`
+~~~~~~~~~~~~~~~~~~~~
+
+Readers are experiment-specific file readers, written to be able to read and parse data from specific experiments.
+
+Readers must subclass :code:`GraphNeTFileReader` and implement a :code:`__call__` method that opens a file, applies :code:`Extractor`\ s, and returns its output in one of two forms (sketched after the list below):
+
+- Serial Output: list of dictionaries, where the number of events in the file :code:`n_events` satisfies :code:`len(output) = n_events`. I.e. each element in the list is a dictionary, and each field in the dictionary is the output of a single extractor. If this is provided, the :code:`DataConverter` will automatically assign event ids.
+
+- Vectorized Output: A single dictionary where each field contains a single dataframe, which holds the data from the :code:`Extractor` for the entire file. In this case, the :code:`Reader` must itself assign event ids. This method is faster if your files are not storing events serially.
+
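+As a rough sketch, the two output forms could look as follows (the table names :code:`truth` and :code:`hits` are purely illustrative, not a required convention):
+
+.. code-block:: python
+
+    import pandas as pd
+
+    # Serial output: one dict per event; `DataConverter` assigns the event ids.
+    serial_output = [
+        {"truth": pd.DataFrame({"energy": [1.2]}), "hits": pd.DataFrame({"t": [0.1, 0.4]})},
+        {"truth": pd.DataFrame({"energy": [3.4]}), "hits": pd.DataFrame({"t": [0.2]})},
+    ]
+
+    # Vectorized output: one dict for the whole file; the reader assigns event ids itself.
+    vectorized_output = {
+        "truth": pd.DataFrame({"event_no": [0, 1], "energy": [1.2, 3.4]}),
+        "hits": pd.DataFrame({"event_no": [0, 0, 1], "t": [0.1, 0.4, 0.2]}),
+    }
+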
+In addition, classes inheriting from :code:`GraphNeTFileReader` must set class properties :code:`accepted_file_extensions` and :code:`accepted_extractors`.
+
+.. raw:: html
+
+
+ Example of a Reader
+
+Implementing a :code:`GraphNeTFileReader` to read data from your experiment requires writing just a few lines of code.
+Below is an example of a reader meant to parse and extract data from :code:`.h5` files from LiquidO, which outputs data in the vectorized format.
+
+.. code-block:: python
+
+ from typing import List, Union, Dict
+ from glob import glob
+ import os
+ import pandas as pd
+
+ from graphnet.data.extractors.liquido import H5Extractor
+ from .graphnet_file_reader import GraphNeTFileReader
+
+ class LiquidOReader(GraphNeTFileReader):
+ """A class for reading h5 files from LiquidO."""
+
+ _accepted_file_extensions = [".h5"]
+ _accepted_extractors = [H5Extractor]
+
+ def __call__(self, file_path: str) -> Dict[str, pd.DataFrame]:
+            """Extract data from a single h5 file.
+
+ Args:
+ file_path: Path to h5 file.
+
+ Returns:
+ Extracted data.
+ """
+ # Open file
+ outputs = {}
+ for extractor in self._extractors:
+ output = extractor(file_path)
+ if output is not None:
+ outputs[extractor._extractor_name] = output
+ return outputs
+
+ def find_files(self, path: Union[str, List[str]]) -> List[str]:
+ """Search folder(s) for h5 files.
+
+ Args:
+ path: directory to search for h5 files.
+
+ Returns:
+ List of h5 files in the folders.
+ """
+ files = []
+ if isinstance(path, str):
+ path = [path]
+ for p in path:
+ files.extend(glob(os.path.join(p, "*.h5")))
+ return files
+
+
+.. raw:: html
+
+
+
+
+:code:`Extractors`
+~~~~~~~~~~~~~~~~~~~~~~
+
+Rarely is `all` the data available in files from experiments needed for training deep learning models; therefore, GraphNeT uses :code:`Extractors` to extract only
+specific parts of the available data.
+
+:code:`Extractors` are written to work with a specific :code:`GraphNeTFileReader` and should subclass :code:`Extractor`.
+
+.. raw:: html
+
+
+ Example of an Extractor
+
+Implementing an :code:`Extractor` to retrieve specific parts of your data files is easy.
+Below is an example of an :code:`Extractor` that will retrieve tables from :code:`.h5` files from LiquidO.
+
+.. code-block:: python
+
+    import h5py
+    import numpy as np
+    import pandas as pd
+    from typing import List
+
+    # Base `Extractor` class (the import path may vary between GraphNeT versions)
+    from graphnet.data.extractors import Extractor
+
+    class H5Extractor(Extractor):
+ """Class for extracting information from LiquidO h5 files."""
+
+ def __init__(self, extractor_name: str, column_names: List[str]):
+ """Construct H5Extractor.
+
+ Args:
+ extractor_name: Name of the `H5Extractor` instance.
+ Used to keep track of the provenance of different data,
+ and to name tables to which this data is saved.
+ column_names: Name of the columns in `extractor_name`.
+ """
+ # Member variable(s)
+ self._table = extractor_name
+ self._column_names = column_names
+ # Base class constructor
+ super().__init__(extractor_name=extractor_name)
+
+ def __call__(self, file_path: str) -> pd.DataFrame:
+ """Extract information from h5 file."""
+ with h5py.File(file_path, "r") as f:
+ available_tables = [f for f in f.keys()]
+ if self._table in available_tables:
+ array = f[self._table][:]
+ # Will throw error if the number of columns don't match
+ self._verify_columns(array)
+ df = pd.DataFrame(array, columns=self._column_names)
+ return df
+ else:
+ return None
+
+ def _verify_columns(self, array: np.ndarray) -> None:
+ try:
+ assert array.shape[1] == len(self._column_names)
+ except AssertionError as e:
+ self.error(
+ f"Got {len(self._column_names)} column names but "
+ f"{self._table} has {array.shape[1]}. Please make sure "
+ f"that the column names match. ({self._column_names})"
+ )
+ raise e
+
+
+.. raw:: html
+
+
+
+
+
+
+:code:`Writers`
+~~~~~~~~~~~~~~~~~~~~
+
+Writers are methods used to save the interim data format from :code:`DataConverter` to disk. They are subclasses of :code:`GraphNeTWriter` and should
+implement the :code:`_save_file` method, which receives the interim data format from a single file, and optionally the :code:`merge_files` method,
+which will be called by :code:`DataConverter` through :code:`DataConverter.merge_files()`.
+
+Below is a conceptual example of how easy it is to extend the data conversion API to save files in a different format.
+In this example, the writer will save the entire set of extractor outputs - a dictionary with pd.DataFrames - as a single pickle file.
+
+.. code:: python
+
+ from graphnet.data.writers import GraphNeTWriter
+ import pickle
+
+ class MyPickleWriter(GraphNeTWriter):
+
+ _file_extension = ".pickle"
+ _merge_dataframes = True # `data` will be Dict[str, pd.DataFrame]
+
+ def _save_file(
+ self,
+ data: Union[Dict[str, pd.DataFrame], Dict[str, List[pd.DataFrame]]],
+ output_file_path: str,
+ n_events: int,
+ ) -> None:
+ """Save the interim data format from a single input file.
+
+ Args:
+ data: the interim data from a single input file.
+ output_file_path: output file path.
+                n_events: Number of events contained in `data`.
+ """
+
+ # Save file contents as .pickle
+ with open(output_file_path, 'wb') as handle:
+ pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+
+ def merge_files(
+ self,
+ files: List[str],
+ output_dir: str,
+ ) -> None:
+ """Merge smaller files.
+
+ Args:
+ files: Files to be merged.
+ output_dir: The directory to store the merged files in.
+ """
+ raise NotImplementedError
+
+
+
+Two writers are implemented in GraphNeT: :code:`SQLiteWriter` and :code:`ParquetWriter`, which produce files that are directly used for
+training by :code:`SQLiteDataset` and :code:`ParquetDataset`, respectively.
+
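+For instance, the LiquidO example from the top of this page could be switched to produce SQLite files, ready for :code:`SQLiteDataset`, simply by swapping the writer. A minimal sketch, reusing the variables defined in that example:
+
+.. code-block:: python
+
+    from graphnet.data.writers import SQLiteWriter
+
+    # Same reader and extractors as in the LiquidO example above,
+    # but saving to SQLite instead of Parquet.
+    converter = DataConverter(file_reader = LiquidOReader(),
+                              save_method = SQLiteWriter(),
+                              extractors=[H5HitExtractor(), H5TruthExtractor()],
+                              outdir=outdir,
+                              num_workers=num_workers,
+                              )
+    converter(input_dir = dir_with_files)
+    converter.merge_files()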
+
+
+
+
+
+
+
+
diff --git a/_sources/datasets/datasets.rst.txt b/_sources/datasets/datasets.rst.txt
new file mode 100644
index 000000000..8716d6113
--- /dev/null
+++ b/_sources/datasets/datasets.rst.txt
@@ -0,0 +1,368 @@
+.. include:: ../substitutions.rst
+
+Datasets In GraphNeT\ |graphnet-header|
+=======================================
+
+
+
+:code:`Dataset`
+---------------
+
+The `Dataset `_ class in GraphNeT is a generic base class from which all Datasets in GraphNeT are expected to originate. :code:`Dataset` is based on `torch.utils.data.Dataset `_\ s, and
+is responsible for reading data from a file and preparing user-specified data representations as `torch_geometric.data.Data `_ objects.
+`Dataset `_ provides structure and common functionality without ties to any specific file format.
+
+Subclasses of :code:`Dataset` inherit the ability to be exported as a `DatasetConfig `_ file:
+
+.. code-block:: python
+
+ dataset = Dataset(...)
+ dataset.config.dump("dataset.yml")
+
+This :code:`.yml` file will contain details about the path to the input data, the tables and columns that should be loaded, any selection that should be applied to data, etc.
+In another session, you can then recreate the same :code:`Dataset`:
+
+.. code-block:: python
+
+ from graphnet.data.dataset import Dataset
+
+ dataset = Dataset.from_config("dataset.yml")
+
+You also have the option to define multiple datasets from the same data file(s) using a single :code:`DatasetConfig` file but with multiple selections:
+
+.. code-block:: python
+
+ dataset = Dataset(...)
+ dataset.config.selection = {
+ "train": "event_no % 2 == 0",
+ "test": "event_no % 2 == 1",
+ }
+ dataset.config.dump("dataset.yml")
+
+When you then re-create your dataset, it will appear as a :code:`Dict` containing your datasets:
+
+.. code-block:: python
+
+ datasets = Dataset.from_config("dataset.yml")
+ >>> datasets
+ {"train": Dataset(...),
+ "test": Dataset(...),}
+
+You can also combine multiple selections into a single, named dataset:
+
+.. code-block:: python
+
+ dataset = Dataset(..)
+ dataset.config.selection = {
+ "train": [
+ "event_no % 2 == 0 & abs(injection_type) == 12",
+ "event_no % 2 == 0 & abs(injection_type) == 14",
+ "event_no % 2 == 0 & abs(injection_type) == 16",
+ ],
+ (...)
+ }
+ >>> dataset.config.dump("dataset.yml")
+ >>> datasets = Dataset.from_config("dataset.yml")
+ >>> datasets
+ {"train": EnsembleDataset(...),
+ (...)}
+
+You also have the option to select random subsets of your data in :code:`DatasetConfig` using the :code:`N random events ~ ...` syntax, e.g.:
+
+.. code-block:: python
+
+ dataset = Dataset(..)
+ dataset.config.selection = "1000 random events ~ abs(injection_type) == 14"
+
+Finally, you can also reference selections that you have stored as external CSV or JSON files on disk:
+
+.. code-block:: python
+
+ dataset.config.selection = {
+ "train": "50000 random events ~ train_selection.csv",
+ "test": "test_selection.csv",
+ }
+
+.. raw:: html
+
+
+ Example of DataConfig
+
+GraphNeT comes with a pre-defined :code:`DatasetConfig` file for the small open-source dataset which can be found at :code:`graphnet/configs/datasets/training_example_data_sqlite.yml`.
+It looks like so:
+
+.. code-block:: yaml
+
+ path: $GRAPHNET/data/examples/sqlite/prometheus/prometheus-events.db
+ graph_definition:
+ arguments:
+ columns: [0, 1, 2]
+ detector:
+ arguments: {}
+ class_name: Prometheus
+ dtype: null
+ nb_nearest_neighbours: 8
+ node_definition:
+ arguments: {}
+ class_name: NodesAsPulses
+ node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
+ class_name: KNNGraph
+ pulsemaps:
+ - total
+ features:
+ - sensor_pos_x
+ - sensor_pos_y
+ - sensor_pos_z
+ - t
+ truth:
+ - injection_energy
+ - injection_type
+ - injection_interaction_type
+ - injection_zenith
+ - injection_azimuth
+ - injection_bjorkenx
+ - injection_bjorkeny
+ - injection_position_x
+ - injection_position_y
+ - injection_position_z
+ - injection_column_depth
+ - primary_lepton_1_type
+ - primary_hadron_1_type
+ - primary_lepton_1_position_x
+ - primary_lepton_1_position_y
+ - primary_lepton_1_position_z
+ - primary_hadron_1_position_x
+ - primary_hadron_1_position_y
+ - primary_hadron_1_position_z
+ - primary_lepton_1_direction_theta
+ - primary_lepton_1_direction_phi
+ - primary_hadron_1_direction_theta
+ - primary_hadron_1_direction_phi
+ - primary_lepton_1_energy
+ - primary_hadron_1_energy
+ - total_energy
+ - dummy_pid
+ index_column: event_no
+ truth_table: mc_truth
+ seed: 21
+ selection:
+ test: event_no % 5 == 0
+ validation: event_no % 5 == 1
+ train: event_no % 5 > 1
+
+.. raw:: html
+
+
+
+
+:code:`SQLiteDataset` & :code:`ParquetDataset`
+----------------------------------------------
+
+Two specific implementations of :code:`Dataset` exist:
+
+- `ParquetDataset `_ : Constructs :code:`Dataset` from files created by :code:`ParquetWriter`.
+- `SQLiteDataset `_ : Constructs :code:`Dataset` from files created by :code:`SQLiteWriter`.
+
+
+To instantiate a :code:`Dataset` from your files, you must specify at least the following:
+
+- :code:`pulsemaps`: These are named fields in your Parquet files, or tables in your SQLite databases, which store one or more pulse series from which you would like to create a dataset. A pulse series represents the detector response, in the form of a series of PMT hits or pulses, in some time window, usually triggered by a single neutrino or atmospheric muon interaction. This is the data that will be served as input to the `Model`.
+- :code:`truth_table`: The name of a table/array that contains the truth-level information associated with the pulse series, and should contain the truth labels that you would like to reconstruct or classify. Often this table will contain the true physical attributes of the primary particle — such as its true direction, energy, PID, etc. — and is therefore graph- or event-level (as opposed to the pulse series tables, which are node- or hit-level) truth information.
+- :code:`features`: The names of the columns in your pulse series table(s) that you would like to include for training; they typically constitute the per-node/-hit features such as xyz-position of sensors, charge, and photon arrival times.
+- :code:`truth`: The columns in your truth table/array that you would like to include in the dataset.
+- :code:`graph_definition`: A :code:`GraphDefinition` that prepares the raw data from the :code:`Dataset` into your choice of data representation.
+
+After that, you can construct your :code:`Dataset` from a SQLite database with just a few lines of code:
+
+.. code-block:: python
+
+ from graphnet.data.sqlite import SQLiteDataset
+ from graphnet.models.detector.prometheus import Prometheus
+ from graphnet.models.graphs import KNNGraph
+ from graphnet.models.graphs.nodes import NodesAsPulses
+
+ graph_definition = KNNGraph(
+ detector=Prometheus(),
+ node_definition=NodesAsPulses(),
+ nb_nearest_neighbours=8,
+ )
+
+ dataset = SQLiteDataset(
+ path="data/examples/sqlite/prometheus/prometheus-events.db",
+ pulsemaps="total",
+ truth_table="mc_truth",
+ features=["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t", ...],
+ truth=["injection_energy", "injection_zenith", ...],
+        graph_definition=graph_definition,
+ )
+
+ graph = dataset[0] # torch_geometric.data.Data
+..
+
+Or similarly for Parquet files:
+
+.. code-block:: python
+
+ from graphnet.data.parquet import ParquetDataset
+ from graphnet.models.detector.prometheus import Prometheus
+ from graphnet.models.graphs import KNNGraph
+ from graphnet.models.graphs.nodes import NodesAsPulses
+
+ graph_definition = KNNGraph(
+ detector=Prometheus(),
+ node_definition=NodesAsPulses(),
+ nb_nearest_neighbours=8,
+ )
+
+ dataset = ParquetDataset(
+ path="data/examples/parquet/prometheus/prometheus-events.parquet",
+ pulsemaps="total",
+ truth_table="mc_truth",
+ features=["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t", ...],
+ truth=["injection_energy", "injection_zenith", ...],
+        graph_definition=graph_definition,
+ )
+
+ graph = dataset[0] # torch_geometric.data.Data
+
+It's then straightforward to create a :code:`DataLoader` for training, which will take care of batching, shuffling, and such:
+
+.. code-block:: python
+
+ from graphnet.data.dataloader import DataLoader
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=128,
+ num_workers=10,
+ )
+
+ for batch in dataloader:
+ ...
+
+The :code:`Dataset`s in GraphNeT use :code:`torch_geometric.data.Data` objects to present data as graphs, and graphs in GraphNeT are therefore compatible with PyG and its handling of graph objects.
+By default, the following fields will be available in a graph built by :code:`Dataset` :
+
+- :code:`graph.x`: Node feature matrix with shape :code:`[num_nodes, num_features]`
+- :code:`graph.edge_index`: Graph connectivity in `COO format `_ with shape :code:`[2, num_edges]` and type :code:`torch.long` (by default this will be :code:`None`, i.e., the nodes will all be disconnected).
+- :code:`graph.edge_attr`: Edge feature matrix with shape :code:`[num_edges, num_edge_features]` (will be :code:`None` by default).
+- :code:`graph.features`: A copy of your :code:`features` argument to :code:`Dataset`, see above.
+- :code:`graph[truth_label] for truth_label in truth`: For each truth label in the :code:`truth` argument, the corresponding data is stored as a :code:`[num_rows, 1]` dimensional tensor. E.g., :code:`graph["energy"] = torch.tensor(26, dtype=torch.float)`
+- :code:`graph[feature] for feature in features`: For each feature given in the :code:`features` argument, the corresponding data is stored as a :code:`[num_rows, 1]` dimensional tensor. E.g., :code:`graph["sensor_x"] = torch.tensor([100, -200, -300, 200], dtype=torch.float)`
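+
+For example, these fields can be accessed directly on a graph from the datasets above (a small sketch, assuming the Prometheus examples earlier on this page):
+
+.. code-block:: python
+
+    graph = dataset[0]  # torch_geometric.data.Data
+
+    print(graph.x.shape)              # [num_nodes, num_features]
+    print(graph.features)             # the `features` argument given to the Dataset
+    print(graph["injection_zenith"])  # event-level truth stored as a tensor
+    print(graph["sensor_pos_x"])      # per-hit feature stored as a [num_nodes, 1] tensor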
+
+:code:`SQLiteDataset` vs. :code:`ParquetDataset`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Besides working on different file formats, :code:`SQLiteDataset` and :code:`ParquetDataset` have significant differences,
+which may lead you to choose one over the other, depending on the problem at hand.
+
+:SQLiteDataset: SQLite provides fast random access to all events inside it. This makes plotting and subsampling your dataset particularly easy,
+    as you can use the :code:`selection` argument to :code:`SQLiteDataset` to pick out exactly which events you want to use. However, one clear downside of SQLite is that it is **uncompressed**,
+    meaning that it is intractable to use for very large datasets. Converting raw files to SQLite also takes a while, and query speed scales roughly as :code:`log(n)`, where n is the number of rows in the table being queried.
+
+:ParquetDataset: Parquet files produced by :code:`ParquetWriter` are compressed by roughly a factor of 8 and stored in shuffled batches of 200,000 events (default) in separate :code:`.parquet` files.
+    Unlike SQLite, the query speed remains constant regardless of dataset size, but Parquet does not offer fast random access. :code:`ParquetDataset` works on the merged files from :code:`ParquetWriter` and will read them serially file-by-file, row-by-row.
+    This means that the subsampling of your dataset needs to happen prior to the conversion to :code:`parquet`, unlike :code:`SQLiteDataset`, which allows for subsampling after conversion due to its fast random access.
+    Conversion of files to :code:`parquet` is significantly faster than its :code:`SQLite` counterpart.
+
+
+.. note::
+
+ :code:`ParquetDataset` is scalable to ultra large datasets, but is more difficult to work with and has a higher memory consumption.
+
+ :code:`SQLiteDataset` does not scale to very large datasets, but is easy to work with and has minimal memory consumption.
+
+
+Choosing a subset of events using `selection`
+----------------------------------------------
+
+You can choose to include only a subset of the events in your data file(s) in your :code:`Dataset` by providing a :code:`selection` and :code:`index_column` argument.
+`selection` is a list of integer event IDs that defines your subset and :code:`index_column` is the name of the column in your data that contains these IDs.
+
+Suppose you wanted to include only events with IDs :code:`[10, 5, 100, 21, 5001]` in your dataset, and that your index column was named :code:`"event_no"`, then
+
+.. code-block:: python
+
+ from graphnet.data.sqlite import SQLiteDataset
+
+ dataset = SQLiteDataset(
+ ...,
+ index_column="event_no",
+ selection=[10, 5, 100, 21, 5001],
+ )
+
+ assert len(dataset) == 5
+
+would produce a :code:`Dataset` with only those five events.
+
+.. note::
+
+    For :code:`SQLiteDataset`, the :code:`selection` argument specifies individual events chosen for the dataset,
+    whereas for :code:`ParquetDataset`, the :code:`selection` argument specifies which batches are used in the dataset.
+
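+For illustration, a sketch (using the same placeholder arguments as above) of how the two interpretations differ:
+
+.. code-block:: python
+
+    from graphnet.data.sqlite import SQLiteDataset
+    from graphnet.data.parquet import ParquetDataset
+
+    # SQLite: `selection` picks individual events by their ID.
+    sqlite_dataset = SQLiteDataset(
+        ...,
+        index_column="event_no",
+        selection=[10, 5, 100, 21, 5001],  # Five specific events
+    )
+
+    # Parquet: `selection` picks whole batches (the shuffled chunks
+    # written by `ParquetWriter`) by their batch index.
+    parquet_dataset = ParquetDataset(
+        ...,
+        selection=[0, 1, 2],  # The first three batches
+    )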
+
+Adding custom :code:`Label`\ s
+--------------------------------
+
+Some specific applications of :code:`Model`\ s in GraphNeT might require truth labels that are not included by default in your truth table.
+In these cases you can define a `Label `_ that calculates your label on the fly:
+
+.. code-block:: python
+
+ import torch
+ from torch_geometric.data import Data
+
+ from graphnet.training.labels import Label
+
+ class MyCustomLabel(Label):
+ """Class for producing my label."""
+ def __init__(self):
+ """Construct `MyCustomLabel`."""
+ # Base class constructor
+ super().__init__(key="my_custom_label")
+
+ def __call__(self, graph: Data) -> torch.tensor:
+ """Compute label for `graph`."""
+ label = ... # Your computations here.
+ return label
+
+You can then pass your :code:`Label` to your :code:`Dataset` as:
+
+.. code-block:: python
+
+ dataset.add_label(MyCustomLabel())
+
+ graph = dataset[0]
+ graph["my_custom_label"]
+ >>> ...
+
+
+Combining Multiple Datasets
+---------------------------
+
+You can combine multiple instances of :code:`Dataset` from GraphNeT into a single :code:`Dataset` by using the `EnsembleDataset `_ class:
+
+.. code-block:: python
+
+ from graphnet.data import EnsembleDataset
+ from graphnet.data.parquet import ParquetDataset
+ from graphnet.data.sqlite import SQLiteDataset
+
+ dataset_1 = SQLiteDataset(...)
+ dataset_2 = SQLiteDataset(...)
+ dataset_3 = ParquetDataset(...)
+
+ ensemble_dataset = EnsembleDataset([dataset_1, dataset_2, dataset_3])
+
+You can find a detailed example `here `_.
+
+
+Implementing a new :code:`Dataset`
+----------------------------------
+
+You can extend GraphNeT to work on new file formats by implementing a subclass of :code:`Dataset` that works on the files you have.
+
+To do so, all you need to do is to implement the :code:`abstractmethod` `query_table `_
+which defines the logic of retrieving information from your files.
+
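+A minimal sketch of what this could look like (the exact signature of :code:`query_table` may differ slightly between versions, so consult the API documentation; the class below is hypothetical):
+
+.. code-block:: python
+
+    from typing import Any, List, Optional, Tuple, Union
+
+    from graphnet.data.dataset import Dataset
+
+    class MyCustomDataset(Dataset):
+        """Hypothetical `Dataset` reading events from a custom file format."""
+
+        def query_table(
+            self,
+            table: str,
+            columns: Union[List[str], str],
+            sequential_index: Optional[int] = None,
+            selection: Optional[str] = None,
+        ) -> List[Tuple[Any, ...]]:
+            """Return the requested `columns` from `table` for one event."""
+            # Open your file(s), look up the event identified by
+            # `sequential_index`, and return the requested columns as a
+            # list of row tuples. The base class turns these into graphs.
+            ...
+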
+The GraphNeT development team is willing to support such efforts, so please consider reaching out to us if you need help.
diff --git a/_sources/getting_started/getting_started.md.txt b/_sources/getting_started/getting_started.md.txt
new file mode 100644
index 000000000..35e1e7d16
--- /dev/null
+++ b/_sources/getting_started/getting_started.md.txt
@@ -0,0 +1,695 @@
+GraphNeT tutorial
+=================
+
+Contents
+--------
+
+1. `Introduction <#1-introduction>`__
+2. `Overview of GraphNeT <#2-overview-of-graphnet>`__
+3. `Data <#3-data>`__
+4. `The Dataset and DataLoader classes <#4-the-dataset-and-dataloader-classes>`__
+
+Appendix
+--------
+
+- A. `Interfacing your data with GraphNeT <#a-interfacing-your-data-with-graphnet>`__
+- B. `Converting your data to a supported format <#b-converting-your-data-to-a-supported-format>`__
+- C. `Basics for SQLite databases in GraphNeT <#c-basics-for-sqlite-databases-in-graphnet>`__
+
+Introduction
+------------
+
+GraphNeT is an open-source Python framework aimed at providing high quality, user friendly, end-to-end functionality to perform reconstruction tasks at neutrino telescopes using deep learning (DL). The framework builds on `PyTorch `__, `PyG `__, and `PyTorch-Lightning `__, but attempts to abstract away many of the lower-level implementation details and instead provide simple, high-level components that make it easy and fast for physicists to use DL in their research.
+
+This tutorial aims to introduce the various elements of GraphNeT to new users. It will go through the main modules, explain some of the structure and design behind these, and show concrete code examples. Users should be able to follow along and run the code themselves, after having `installed `__ GraphNeT. After completing the tutorial, users should be able to continue running some of the provided `example scripts `__ and start modifying these to suit their own needs.
+
+However, this tutorial and the accompanying example scripts are not comprehensive: they are intended as a simple starting point, showing just some of the things you can do with GraphNeT. If you have any questions, run into any problems, or just need help, consider first joining the `GraphNeT team's Slack group `__ to talk to like-minded folks, or `open an issue `__ if you have a feature to suggest or are confident you have encountered a bug.
+
+If you want a quick lay of the land, you can start with `Section 2 - Overview of GraphNeT <#2-overview-of-graphnet>`__. If you want to get your hands dirty right away, feel free to skip to `Section 3 - Data <#3-data>`__ and the subsequent sections.
+
+Overview of GraphNeT
+---------------------
+
+The main modules of GraphNeT are, in the order that you will likely use them:
+
+- `graphnet.data `__: For converting domain-specific data (i.e., I3 in the case of IceCube) to generic, intermediate file formats (e.g., SQLite or Parquet) using `DataConverter `__; and for reading data as graphs from these intermediate files when training, using `Dataset `__ and its format-specific subclasses, and `DataLoader `__.
+- `graphnet.models `__: For building models to perform a variety of physics tasks. The base `Model `__ class provides common interfaces for training and inference, as well as for model management (saving, loading, configs, etc.). This can be subclassed to build and train any model using GraphNeT functionality. The more specialised `StandardModel `__ provides a simple way to create a standard type of `Model` with a fixed structure. This type of model is composed of the following components, in sequence:
+
+ - `GraphDefinition `__: A single, self-contained module that handles all processing from raw data to graph representation. It consists of the following sub-modules in sequence:
+ - `Detector `__: For handling detector-specific preprocessing of data. Currently, this module provides standardization of experiment specific input data.
+    - `NodeDefinition `__: A swappable module that defines what a node/row represents. It is in charge of transforming the collection of standardized Cherenkov pulses associated with a triggered event into a node/row representation of choice. It is the choice in this module that defines whether nodes/rows represent single Cherenkov pulses, DOMs, entire strings or something completely different. **Note**: You can create `NodeDefinition`s that represent the data as sequences, images or whatever you fancy, making GraphNeT compatible with any deep learning paradigm, such as CNNs, Transformers, etc.
+    - `EdgeDefinition `__ (Optional): A module that defines how edges are drawn between your nodes. This could be connecting the _N_ nearest neighbours of each node or connecting all nodes within a radius of _R_ meters of each other. For methods that do not directly use edges in their data representations, this module can be skipped.
+ - `backbone `__: For implementing the actual model architecture. These are the components of GraphNeT that are actually being trained, and the architecture and complexity of these are central to the performance and optimisation on the physics/learning task being performed. For now, we provide a few different example architectures, e.g., `DynEdge `__ and `ConvNet `__, but in principle any DL architecture could be implemented here — and we encourage you to contribute your favourite!
+ - `Task `__: For choosing a certain physics/learning task or tasks with respect to which the model should be trained. We provide a number of common `reconstruction `__ (`DirectionReconstructionWithKappa` and `EnergyReconstructionWithUncertainty`) and `classification `__ (e.g., `BinaryClassificationTask` and `MulticlassClassificationTask`) tasks, but we encourage you to expand on these with new, more specialised tasks appropriate to your physics use case. For now, `Task` instances also require an appropriate `LossFunction `__ to specify how the models should be trained (see below).
+
+ These components are packaged in a particularly simple way in `StandardModel`, but they are not specific to it.
+ That is, they can be used in any combination, and alongside more specialised PyTorch/PyG code, as part of a more generic `Model`.
+
+- `graphnet.training `__: For training GraphNeT models, including specifying a `LossFunction `__, defining custom truth `Label`s, and configuring callbacks and experiment logging (see the sections below).
+
+
+Adding custom truth labels
+--------------------------
+
+Some specific applications of `Model`s in GraphNeT might require truth labels that are not included by default in your truth table. In these cases you can define a :doc:`Label ` that calculates your label on the fly:
+
+.. code-block:: python
+
+ import torch
+ from torch_geometric.data import Data
+
+ from graphnet.training.labels import Label
+
+ class MyCustomLabel(Label):
+ """Class for producing my label."""
+ def __init__(self):
+ """Construct `MyCustomLabel`."""
+ # Base class constructor
+ super().__init__(key="my_custom_label")
+
+ def __call__(self, graph: Data) -> torch.tensor:
+ """Compute label for `graph`."""
+ label = ... # Your computations here.
+ return label
+
+You can then pass your `Label` to your `Dataset` as:
+
+.. code-block:: python
+
+ dataset.add_label(MyCustomLabel())
+
+ graph = dataset[0]
+ graph["my_custom_label"]
+ >>> ...
+
+Combining Multiple Datasets
+---------------------------
+
+You can combine multiple instances of `Dataset` from GraphNeT into a single `Dataset` by using the :doc:`EnsembleDataset ` class:
+
+.. code-block:: python
+
+ from graphnet.data import EnsembleDataset
+ from graphnet.data.parquet import ParquetDataset
+ from graphnet.data.sqlite import SQLiteDataset
+
+ dataset_1 = SQLiteDataset(...)
+ dataset_2 = SQLiteDataset(...)
+ dataset_3 = ParquetDataset(...)
+
+ ensemble_dataset = EnsembleDataset([dataset_1, dataset_2, dataset_3])
+
+You can find a detailed example `here `_.
+
+Creating reproducible Datasets using DatasetConfig
+--------------------------------------------------
+
+You can summarise your `Dataset` and its configuration by exporting it as a :doc:`DatasetConfig ` file:
+
+.. code-block:: python
+
+ dataset = Dataset(...)
+ dataset.config.dump("dataset.yml")
+
+This YAML file will contain details about the path to the input data, the tables and columns that should be loaded, any selection that should be applied to data, etc.
+In another session, you can then recreate the same `Dataset`:
+
+.. code-block:: python
+
+ from graphnet.data.dataset import Dataset
+
+ dataset = Dataset.from_config("dataset.yml")
+
+You also have the option to define multiple datasets from the same data file(s) using a single `DatasetConfig` file but with multiple selections:
+
+.. code-block:: python
+
+ dataset = Dataset(...)
+ dataset.config.selection = {
+ "train": "event_no % 2 == 0",
+ "test": "event_no % 2 == 1",
+ }
+ dataset.config.dump("dataset.yml")
+
+When you then re-create your dataset, it will appear as a `Dict` containing your datasets:
+
+.. code-block:: python
+
+ datasets = Dataset.from_config("dataset.yml")
+ >>> datasets
+ {"train": Dataset(...),
+ "test": Dataset(...),}
+
+You can also combine multiple selections into a single, named dataset:
+
+.. code-block:: python
+
+    dataset = Dataset(...)
+ dataset.config.selection = {
+ "train": [
+ "event_no % 2 == 0 & abs(injection_type) == 12",
+ "event_no % 2 == 0 & abs(injection_type) == 14",
+ "event_no % 2 == 0 & abs(injection_type) == 16",
+ ],
+ (...)
+ }
+ >>> dataset.config.dump("dataset.yml")
+ >>> datasets = Dataset.from_config("dataset.yml")
+ >>> datasets
+ {"train": EnsembleDataset(...),
+ (...)}
+
+You also have the option to select random subsets of your data with `DatasetConfig`, using the `N random events ~ ...` syntax, e.g.:
+
+.. code-block:: python
+
+    dataset = Dataset(...)
+ dataset.config.selection = "1000 random events ~ abs(injection_type) == 14"
+
+Finally, you can also reference selections that you have stored as external CSV or JSON files on disk:
+
+.. code-block:: python
+
+ dataset.config.selection = {
+ "train": "50000 random events ~ train_selection.csv",
+ "test": "test_selection.csv",
+ }
+
+Example `DatasetConfig`
+-------------------------
+
+GraphNeT comes with a pre-defined `DatasetConfig` file for the small open-source dataset which can be found at ``graphnet/configs/datasets/training_example_data_sqlite.yml``.
+It looks like so:
+
+.. code-block:: yaml
+
+ path: $GRAPHNET/data/examples/sqlite/prometheus/prometheus-events.db
+ graph_definition:
+ arguments:
+ columns: [0, 1, 2]
+ detector:
+ arguments: {}
+ class_name: Prometheus
+ dtype: null
+ nb_nearest_neighbours: 8
+ node_definition:
+ arguments: {}
+ class_name: NodesAsPulses
+ node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
+ class_name: KNNGraph
+ pulsemaps:
+ - total
+ features:
+ - sensor_pos_x
+ - sensor_pos_y
+ - sensor_pos_z
+ - t
+ truth:
+ - injection_energy
+ - injection_type
+ - injection_interaction_type
+ - injection_zenith
+ - injection_azimuth
+ - injection_bjorkenx
+ - injection_bjorkeny
+ - injection_position_x
+ - injection_position_y
+ - injection_position_z
+ - injection_column_depth
+ - primary_lepton_1_type
+ - primary_hadron_1_type
+ - primary_lepton_1_position_x
+ - primary_lepton_1_position_y
+ - primary_lepton_1_position_z
+ - primary_hadron_1_position_x
+ - primary_hadron_1_position_y
+ - primary_hadron_1_position_z
+ - primary_lepton_1_direction_theta
+ - primary_lepton_1_direction_phi
+ - primary_hadron_1_direction_theta
+ - primary_hadron_1_direction_phi
+ - primary_lepton_1_energy
+ - primary_hadron_1_energy
+ - total_energy
+ - dummy_pid
+ index_column: event_no
+ truth_table: mc_truth
+ seed: 21
+ selection:
+ test: event_no % 5 == 0
+ validation: event_no % 5 == 1
+ train: event_no % 5 > 1
+
+
+Advanced Functionality in SQLiteDataset
+---------------------------------------
+
+**@TODO**: node_truth_table, string selections ...
+
+The `Model` class
+-----------------
+
+One important part of the philosophy for :doc:`Model `s in GraphNeT is that they are self-contained.
+Functionality that a specific model requires (data pre-processing, transformation and other auxiliary calculations) should exist within the `Model` itself such that it is portable and deployable as a single package that only depends on data.
+That is, conceptually,
+
+> Data → `Model` → Predictions
+
+You can subclass the `Model` class to create any model implementation using GraphNeT components (such as instances of, e.g., the `GraphDefinition`, `Backbone`, and `Task` classes) along with PyTorch and PyG functionality.
+All `Model`s that are applicable to the same detector configuration, regardless of how the `Model`s themselves are implemented, should be able to act on the same graph (`torch_geometric.data.Data`) objects, thereby making them interchangeable and directly comparable.
+
+The `StandardModel` class
+----------------------------
+
+The simplest way to define a `Model` in GraphNeT is through the `StandardModel` subclass.
+This is uniquely defined based on one each of `GraphDefinition` and `Backbone`, and one or more `Task`s. Each of these components will be a problem-specific instance of these parent classes. This structure guarantees modularity and reusability. For example, the only adaptation needed to run a `Model` made for IceCube on a different experiment — say, KM3NeT — would be to switch out the `Detector` component in `GraphDefinition` representing IceCube with one that represents KM3NeT. Similarly, a `Model` developed for `EnergyReconstruction` can be put to work on a different problem, e.g., `DirectionReconstructionWithKappa`, by switching out just the `Task` component.
+
+GraphNeT comes with many pre-defined components that you can simply import and use out-of-the-box.
+So to get started, all you need to do is import your choice of these components and build the model.
+Below is a snippet that defines a `Model` that reconstructs the zenith angle with uncertainties using the `GNN published by IceCube `_ for the Prometheus detector:
+
+.. code-block:: python
+
+ # Choice of graph representation, GNN architecture, and physics task
+ from graphnet.models.detector.prometheus import Prometheus
+ from graphnet.models.graphs import KNNGraph
+ from graphnet.models.graphs.nodes import NodesAsPulses
+ from graphnet.models.gnn.dynedge import DynEdge
+ from graphnet.models.task.reconstruction import ZenithReconstructionWithKappa
+
+ # Choice of loss function and Model class
+ from graphnet.training.loss_functions import VonMisesFisher2DLoss
+ from graphnet.models import StandardModel
+
+ # Configuring the components
+
+ # Represents the data as a point-cloud graph where each
+ # node represents a pulse of Cherenkov radiation
+ # edges drawn to the 8 nearest neighbours
+
+ graph_definition = KNNGraph(
+ detector=Prometheus(),
+ node_definition=NodesAsPulses(),
+ nb_nearest_neighbours=8,
+ )
+ backbone = DynEdge(
+        nb_inputs=graph_definition.nb_outputs,
+ global_pooling_schemes=["min", "max", "mean"],
+ )
+ task = ZenithReconstructionWithKappa(
+ hidden_size=backbone.nb_outputs,
+ target_labels="injection_zenith",
+ loss_function=VonMisesFisher2DLoss(),
+ )
+
+ # Construct the Model
+ model = StandardModel(
+ graph_definition=graph_definition,
+ backbone=backbone,
+ tasks=[task],
+ )
+
+**Note:** We're adding the argument ``global_pooling_schemes=["min", "max", "mean"],`` to the ``Backbone`` component, since event-level tasks like this one require the node-level outputs of the backbone to be aggregated (pooled) into a single, event-level representation before the ``Task`` can make its prediction.
+
+Creating reproducible `Model`s using `ModelConfig`
+--------------------------------------------------
+
+You can export your choices of `Model` components and their configuration to a `ModelConfig` file, and recreate your `Model` in a different session. That is,
+
+.. code-block:: python
+
+ model = Model(...)
+ model.save_config("model.yml")
+
+You can then reconstruct the same model architecture from the `.yml` file:
+
+.. code-block:: python
+
+ from graphnet.models import Model
+
+ # Indicate that you `trust` the config file after inspecting it, to allow for
+    # dynamically loading class references in the file.
+ model = Model.from_config("model.yml", trust=True)
+
+**Please note**: Models built from a `ModelConfig` are initialised with random weights.
+The `ModelConfig` class is only meant for capturing model _definitions_ in a portable, human-readable format.
+To also save trained model weights, you need to save the entire model; see below.
+
+Example `ModelConfig`
+-------------------------
+
+You can find several pre-defined `ModelConfig`s under `graphnet/configs/models`. Below are the contents of `example_energy_reconstruction_model.yml`:
+
+```yml
+arguments:
+ architecture:
+ ModelConfig:
+ arguments:
+ add_global_variables_after_pooling: false
+ dynedge_layer_sizes: null
+ features_subset: null
+ global_pooling_schemes: [min, max, mean, sum]
+ nb_inputs: 4
+ nb_neighbours: 8
+ post_processing_layer_sizes: null
+ readout_layer_sizes: null
+ class_name: DynEdge
+ graph_definition:
+ ModelConfig:
+ arguments:
+ columns: [0, 1, 2]
+ detector:
+ ModelConfig:
+ arguments: {}
+ class_name: Prometheus
+ dtype: null
+ nb_nearest_neighbours: 8
+ node_definition:
+ ModelConfig:
+ arguments: {}
+ class_name: NodesAsPulses
+ node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
+ class_name: KNNGraph
+ optimizer_class: '!class torch.optim.adam Adam'
+ optimizer_kwargs: {eps: 0.001, lr: 0.001}
+ scheduler_class: '!class graphnet.training.callbacks PiecewiseLinearLR'
+ scheduler_config: {interval: step}
+ scheduler_kwargs:
+ factors: [0.01, 1, 0.01]
+ milestones: [0, 20.0, 80]
+ tasks:
+ - ModelConfig:
+ arguments:
+ hidden_size: 128
+ loss_function:
+ ModelConfig:
+ arguments: {}
+ class_name: LogCoshLoss
+ loss_weight: null
+ prediction_labels: null
+ target_labels: total_energy
+ transform_inference: '!lambda x: torch.pow(10,x)'
+ transform_prediction_and_target: '!lambda x: torch.log10(x)'
+ transform_support: null
+ transform_target: null
+ class_name: EnergyReconstruction
+class_name: StandardModel
+```
+
+Building your own `Model` class
+--------------------------------
+
+**@TODO**
+
+
+Training `Model`s and tracking experiments
+------------------------------------------------
+
+`Model`s in GraphNeT come with a powerful built-in :py:func:`~graphnet.models.model.Model.fit` method that reduces training a model on neutrino telescope data to a syntax similar to that of `sklearn`:
+
+.. code-block:: python
+
+ model = Model(...)
+ train_dataloader = DataLoader(...)
+ model.fit(train_dataloader=train_dataloader, max_epochs=10)
+
+`Model`s in GraphNeT are PyTorch modules and fully compatible with PyTorch-Lightning.
+You can therefore choose to write your own custom training loops if needed, or use the regular PyTorch-Lightning training functionality.
+The snippet above is equivalent to:
+
+.. code-block:: python
+
+ from pytorch_lightning import Trainer
+
+ from graphnet.training.callbacks import ProgressBar
+
+ model = Model(...)
+ train_dataloader = DataLoader(...)
+
+ # Configure Trainer
+ trainer = Trainer(
+ gpus=None,
+ max_epochs=10,
+ callbacks=[ProgressBar()],
+ log_every_n_steps=1,
+ logger=None,
+ strategy="ddp",
+ )
+
+ # Train model
+ trainer.fit(model, train_dataloader)
+
+
+Experiment Tracking
+--------------------
+
+You can track your experiment using `Weights & Biases `_ by passing the `WandbLogger` to :py:func:`~graphnet.models.model.Model.fit`:
+
+.. code-block:: python
+
+ import os
+
+ from pytorch_lightning.loggers import WandbLogger
+
+ # Create wandb directory
+ wandb_dir = "./wandb/"
+ os.makedirs(wandb_dir, exist_ok=True)
+
+ # Initialise Weights & Biases (W&B) run
+ wandb_logger = WandbLogger(
+ project="example-script",
+ entity="graphnet-team",
+ save_dir=wandb_dir,
+ log_model=True,
+ )
+
+ # Fit Model
+ model = Model(...)
+ model.fit(
+ ...,
+ logger=wandb_logger,
+ )
+
+By using `WandbLogger`, your training and validation loss is logged and you have the full functionality of Weights & Biases available.
+This means, e.g., that you can log your :py:class:`~graphnet.utilities.config.model_config.ModelConfig`, :py:class:`~graphnet.utilities.config.dataset_config.DatasetConfig`, and :py:class:`~graphnet.utilities.config.training_config.TrainingConfig` as:
+
+.. code-block:: python
+
+ wandb_logger.experiment.config.update(training_config)
+ wandb_logger.experiment.config.update(model_config.as_dict())
+ wandb_logger.experiment.config.update(dataset_config.as_dict())
+
+Using an experiment tracking system like Weights & Biases to track training metrics as well as artifacts like configuration files greatly improves reproducibility, experiment transparency, and collaboration.
+This is because you can easily recreate a previous run from the saved artifacts, directly compare runs with different model configurations and hyperparameter choices, and share and compare your results with other people on your team.
+Therefore, we strongly recommend using Weights & Biases or a similar system when training and optimising models meant for actual physics use.
+
+
+Saving, loading, and checkpointing `Model`s
+--------------------------------------------
+
+There are several methods for saving models in GraphNeT and each comes with its own pros and cons.
+
+Using `Model.save`
+~~~~~~~~~~~~~~~~~~
+
+You can pickle your entire model (including the `state_dict`) by calling the :py:meth:`~graphnet.models.model.Model.save` method:
+
+.. code-block:: python
+
+ model.save("model.pth")
+
+You can then load this model by calling :py:meth:`~graphnet.models.model.Model.load` classmethod:
+
+.. code-block:: python
+
+ from graphnet.models import Model
+
+ loaded_model = Model.load("model.pth")
+
+This method is rather convenient as it lets you store everything in a single file but it comes with a big caveat: **it's not version-proof**.
+That is, if you share a pickled model with a user who runs a different version of GraphNeT than what was used to train the model, you might experience compatibility issues.
+This is due to how pickle serialises `Model` objects.
+
+
+Using `ModelConfig` and `state_dict`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+You can summarise your `Model` components and their configurations by exporting it to a `.yml` file.
+This only captures the _definition_ of the model, not any trained weights, but by saving the `state_dict` too, you have effectively saved the entire model, both definition and weights.
+You can do so by:
+
+.. code-block:: python
+
+ model.save_config('model.yml')
+ model.save_state_dict('state_dict.pth')
+
+You can then reconstruct your model again by building the model from the `ModelConfig` file and loading in the `state_dict`:
+
+.. code-block:: python
+
+ from graphnet.models import Model
+ from graphnet.utilities.config import ModelConfig
+
+ model_config = ModelConfig.load("model.yml")
+ model = Model.from_config(model_config) # With randomly initialised weights.
+    model.load_state_dict("state_dict.pth") # Now with trained weights.
+
+This method is less prone to version incompatibility issues, such as those mentioned above, and is therefore our recommended way of storing and porting `Model`s.
+
+
+Using checkpoints
+~~~~~~~~~~~~~~~~~~
+
+Because `Model`s in GraphNeT are also PyTorch-Lightning `LightningModule`s, you have the option to use the `load_from_checkpoint` method:
+
+.. code-block:: python
+
+ model_config = ModelConfig.load("model.yml")
+ model = Model.from_config(model_config) # With randomly initialised weights.
+    model.load_from_checkpoint("checkpoint.ckpt") # Now with trained weights.
+
+You can find more information on checkpointing `here `_.
+
+Example: Energy Reconstruction using `ModelConfig`
+--------------------------------------------------
+
+Below is a minimal example for training a GNN in GraphNeT for energy reconstruction on the tiny data sample using configuration files:
+
+.. code-block:: python
+
+ # Import(s)
+ import os
+
+ from graphnet.constants import CONFIG_DIR # Local path to graphnet/configs
+ from graphnet.data.dataloader import DataLoader
+ from graphnet.models import Model
+ from graphnet.utilities.config import DatasetConfig, ModelConfig
+
+ # Configuration
+ dataset_config_path = f"{CONFIG_DIR}/datasets/training_example_data_sqlite.yml"
+ model_config_path = f"{CONFIG_DIR}/models/example_energy_reconstruction_model.yml"
+
+ # Build model
+ model_config = ModelConfig.load(model_config_path)
+ model = Model.from_config(model_config, trust=True)
+
+ # Construct dataloaders
+ dataset_config = DatasetConfig.load(dataset_config_path)
+ dataloaders = DataLoader.from_dataset_config(
+ dataset_config,
+ batch_size=16,
+ num_workers=1,
+ )
+
+ # Train model
+ model.fit(
+ dataloaders["train"],
+ dataloaders["validation"],
+ gpus=[0],
+ max_epochs=5,
+ )
+
+ # Predict on test set and return as pandas.DataFrame
+ results = model.predict_as_dataframe(
+ dataloaders["test"],
+ additional_attributes=model.target_labels + ["event_no"],
+ )
+
+ # Save predictions and model to file
+ outdir = "tutorial_output"
+ os.makedirs(outdir, exist_ok=True)
+ results.to_csv(f"{outdir}/results.csv")
+ model.save_state_dict(f"{outdir}/state_dict.pth")
+ model.save(f"{outdir}/model.pth")
+
+Because `ModelConfig` summarises a `Model` completely, including its `Task`(s), the only modification required to change the example to reconstruct (or classify) a different attribute than energy is to pass a `ModelConfig` that defines a model with the corresponding `Task`.
+Similarly, if you wanted to train on a different `Dataset`, you would just have to pass a `DatasetConfig` that defines *that* `Dataset` instead.
+
+
+Deploying `Model`s in physics analyses
+---------------------------------------
+
+**@TODO**
+
+
+Utilities
+---------
+
+The `Logger` class
+~~~~~~~~~~~~~~~~~~
+
+GraphNeT automatically logs messages from your training run (and in other instances) to the terminal and writes them to a `logs` directory next to your script (by default).
+You can add your own custom messages to the :class:`~graphnet.utilities.logging.Logger` by:
+
+.. code-block:: python
+
+ from graphnet.utilities.logging import Logger
+
+ logger = Logger()
+
+ logger.info("My very informative message")
+ logger.warning("My warning shown every time")
+ logger.warning_once("My warning shown once")
+ logger.debug("My debug call")
+ logger.error("My error")
+ logger.critical("My critical call")
+
+Similarly, every class inheriting from `Logger` can call the same methods on itself, e.g., `self.info("...")`.
+
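+For example, a small sketch (the class below is hypothetical) of a component that uses the inherited logging methods:
+
+.. code-block:: python
+
+    from graphnet.utilities.logging import Logger
+
+    class MyComponent(Logger):
+        """Hypothetical class using the inherited logging methods."""
+
+        def run(self) -> None:
+            self.info("Running MyComponent")          # Same as `logger.info`
+            self.warning_once("Only shown one time")   # Same as `logger.warning_once`
+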
+Appendix
+--------
+
+A. Interfacing your data with GraphNeT
+---------------------------------------
+
+GraphNeT currently supports two data formats — Parquet and SQLite — and you must therefore provide your data in either of these formats for training a `Model`.
+This is done using the `DataConverter` class.
+Performing this conversion into one of the two supported formats can be a somewhat time-consuming task, but it is only done once, and then you are free to perform all of the training and optimization you want.
+
+In addition, GraphNeT expects your data to contain at least:
+
+- `pulsemap`: A per-hit table, containing a series of sensor measurements that represents the detector response to some interaction in some time window, as described in [Section 4 - The `Dataset` and `DataLoader` classes](#4-the-dataset-and-dataloader-classes).
+- `truth_table`: A per-event table, containing the global truth of each event, as described in [Section 4 - The `Dataset` and `DataLoader` classes](#4-the-dataset-and-dataloader-classes).
+- (Optional) `node_truth_table`: A per-hit truth array, containing truth labels for each node in your graph. This could be labels indicating whether each reconstructed pulse/photon was a result of noise in the event, or a label indicating which particle in the simulation tree caused a specific pulse/photon. These are the node-level quantities that could be classification/reconstructing targets for certain physics/learning tasks.
+- `index_column`: A unique ID that maps each row in `pulsemap` with its corresponding row in `truth_table` and/or `node_truth_table`.
+
+Since `pulsemap`, `truth_table`, and `node_truth_table` are named fields in your Parquet files (or tables in SQLite) you may name these fields however you like.
+You can also freely name your `index_column`. For instance, the `truth_table` could be called `"mc_truth"` and the `index_column` could be called `"event_no"`, see the snippets above.
+However, the following constraints exist:
+
+1. The naming of fields/columns within `pulsemap`, `truth_table`, and `node_truth_table` must be unique. For instance, the _x_-coordinate of the PMTs and the _x_-coordinate of interaction vertex may not both be called *pos_x*.
+2. No field/column in `pulsemap`, `truth_table`, or `node_truth_table` may be called `x`, `features`, `edge_attr`, or `edge_index`, as this leads to naming conflicts with attributes of [`torch_geometric.data.Data`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data).
+
+B. Converting your data to a supported format
+----------------------------------------------
+
+**@TODO**
+
+C. Basics for SQLite databases in GraphNeT
+------------------------------------------
+
+In SQLite databases, `pulsemap`, `truth_table`, and optionally `node_truth_table` exist as separate tables.
+Each table has a column `index_column` on which the tables are indexed, in addition to the data that it contains.
+The schemas are:
+
+- `truth_table`: The `index_column` is set to `INTEGER PRIMARY KEY NOT NULL` and other columns are `NOT NULL`.
+- `pulsemap` and `node_truth_table`: All columns are set to `NOT NULL` but a non-unique index is created on the table(s) using `index_column`. This is important for query times.
+
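+As an illustration only (the tables are created for you by `DataConverter`; the table and column names below are taken from the example configuration above), these schemas correspond to SQL statements along the lines of:
+
+.. code-block:: python
+
+    import sqlite3
+
+    with sqlite3.connect("my_database.db") as conn:
+        # Event-level truth: one row per event, indexed by the primary key.
+        conn.execute(
+            "CREATE TABLE mc_truth ("
+            "event_no INTEGER PRIMARY KEY NOT NULL, "
+            "injection_energy FLOAT NOT NULL)"
+        )
+        # Pulse-level table: many rows per event, with a non-unique index
+        # on `event_no` to keep per-event queries fast.
+        conn.execute(
+            "CREATE TABLE total ("
+            "event_no INTEGER NOT NULL, "
+            "sensor_pos_x FLOAT NOT NULL)"
+        )
+        conn.execute("CREATE INDEX event_no_index ON total (event_no)")
+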
+Below is a snippet that extracts all the contents of `pulsemap` and `truth_table` for the event with `index_column == 120`:
+
+.. code-block:: python
+
+ import pandas as pd
+ import sqlite3
+
+ database = "data/examples/sqlite/prometheus/prometheus-events.db"
+ pulsemap = "total"
+ index_column = "event_no"
+ truth_table = "mc_truth"
+ event_no = 120
+
+ with sqlite3.connect(database) as conn:
+ query = f"SELECT * from {pulsemap} WHERE {index_column} == {event_no}"
+ detector_response = pd.read_sql(query, conn)
+
+ query = f"SELECT * from {truth_table} WHERE {index_column} == {event_no}"
+ truth = pd.read_sql(query, conn)
diff --git a/_sources/index.rst.txt b/_sources/index.rst.txt
index da9a39a0e..574357040 100644
--- a/_sources/index.rst.txt
+++ b/_sources/index.rst.txt
@@ -1,10 +1,18 @@
-.. include:: about.md
- :parser: myst_parser.sphinx_
+
+.. include:: substitutions.rst
+
+.. include:: intro/intro.rst
.. toctree::
- :maxdepth: 1
+ :maxdepth: 2
:hidden:
- install
- contribute
- api/graphnet.rst
\ No newline at end of file
+ installation/install.rst
+ models/models.rst
+ datasets/datasets.rst
+ data_conversion/data_conversion.rst
+ integration/integration.rst
+ contribute/contribute.rst
+ api/graphnet.rst
+
+
\ No newline at end of file
diff --git a/_sources/install.md.txt b/_sources/install.md.txt
deleted file mode 100644
index e4f95cd7b..000000000
--- a/_sources/install.md.txt
+++ /dev/null
@@ -1,82 +0,0 @@
-# Install
-
-We recommend installing `graphnet` in a separate environment, e.g. using a Python virtual environment or Anaconda (see details on installation [here](https://www.anaconda.com/products/individual)). Below we prove installation instructions for different setups.
-
-## Installing with IceTray
-
-You may want `graphnet` to be able to interface with IceTray, e.g., when converting I3 files to an intermediate file format for training GNN models (e.g., SQLite or parquet),[^1] or when running GNN inference as part of an IceTray chain. In these cases, you need to install `graphnet` in a Python runtime that has IceTray installed.
-
-To achieve this, we recommend running the following commands in a clean bash shell:
-```bash
-$ eval `/cvmfs/icecube.opensciencegrid.org/py3-v4.2.1/setup.sh`
-$ /cvmfs/icecube.opensciencegrid.org/py3-v4.2.1/RHEL_7_x86_64/metaprojects/icetray/v1.5.1/env-shell.sh
-```
-Optionally, you can alias these commands or save them as a bash script for convenience, as you will have to run these commands every time you want to use IceTray (with `graphnet`) in a clean shell.
-
-With the IceTray environment active, you can now install `graphnet`, either at a user level or in a Python virtual environment. You can either install a light-weight version of `graphnet` without the `torch` extras, i.e., without the machine learning packages (pytorch and pytorch-geometric); this is useful when you just want to convert data from I3 files to, e.g., SQLite, and won't be running inference on I3 files later on. In this case, you don't need to specify a requirements file. If you want torch, you do.
-
-
-Install without torch
-
-```bash
-$ pip install --user -e .[develop] # Without torch, i.e. only for file conversion
-```
-
-
-
-
-Install with torch
-
-```bash
-$ pip install --user -r requirements/torch_cpu.txt -e .[develop,torch] # CPU-only torch
-$ pip install --user -r requirements/torch_gpu.txt -e .[develop,torch] # GPU support
-```
-
-
-
-This should allow you to run the I3 conversion scripts in [examples/](./examples/) with your preferred I3 files.
-
-## Installing stand-alone
-
-If you don't need to interface with [IceTray](https://github.com/icecube/icetray/) (e.g., for reading data from I3 files or running inference on these), the following commands should provide a fast way to get up and running on most UNIX systems:
-```bash
-$ git clone git@github.com:/graphnet.git
-$ cd graphnet
-$ conda create --name graphnet python=3.8 gcc_linux-64 gxx_linux-64 libgcc cudatoolkit=11.5 -c conda-forge -y # Optional
-$ conda activate graphnet # Optional
-(graphnet) $ pip install -r requirements/torch_cpu.txt -e .[develop,torch] # CPU-only torch
-(graphnet) $ pip install -r requirements/torch_gpu.txt -e .[develop,torch] # GPU support
-```
-This should allow you to e.g. run the scripts in [examples/](./examples/) out of the box.
-
-A stand-alone installation requires specifying a supported Python version (see above), ensuring that the C++ compilers (gcc) are up to date, and possibly installing the CUDA Toolkit. Here, we have installed recent C++ compilers using conda (`gcc_linux-64 gxx_linux-64 libgcc`), but if your system already has a recent version (`$gcc --version` should be > 5, at least) you should be able to omit these from the setup.
-If you install the CUDA Toolkit and/or newer compilers using the above command, you should add **one of**:
-```bash
-$ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/anaconda3/lib/
-$ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/miniconda3/lib/
-$ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/anaconda3/envs/graphnet/lib/
-$ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/miniconda3/envs/graphnet/lib/
-```
-depending on your setup to your `.bashrc` script or similar to make sure that the corresponding library files are accessible. Check which one of the above paths contains the `.so`-files you're looking to use, and add that path.
-
-## Running in Docker
-
-If you want to run GraphNeT (with IceTray), and don't intend to contribute to the package, consider using the provided [Docker image](https://hub.docker.com/repository/docker/asogaard/graphnet). With Docker, you can then run GraphNeT as:
-```bash
-$ docker run --rm -it asogaard/graphnet:latest
-🐳 graphnet@dc423315742c ❯ ~/graphnet $ python examples/01_icetray/01_convert_i3_files.py sqlite icecube-upgrade
-graphnet: INFO 2023-01-24 13:41:27 - Logger.__init__ - Writing log to logs/graphnet_20230124-134127.log
-(...)
-graphnet: INFO 2023-01-24 13:41:46 - SQLiteDataConverter.info - Saving results to /root/graphnet/data/examples/outputs/convert_i3_files/ic86
-graphnet: INFO 2023-01-24 13:41:46 - SQLiteDataConverter.info - Processing 1 I3 file(s) in main thread (not multiprocessing)
-100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:39<00:00, 39.79s/file(s)]
-graphnet: INFO 2023-01-24 13:42:26 - SQLiteDataConverter.info - Merging files output by current instance.
-graphnet: INFO 2023-01-24 13:42:26 - SQLiteDataConverter.info - Merging 1 database files
-100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 413.88it/s]
-```
-This should allow you to run all provided examples (excluding the specialised ones requiring [PISA](https://github.com/icecube/pisa)) out of the box, and to start working on your own analysis scripts.
-
-You can use any of the following Docker image tags:
-* `main`: Image corresponding to the latest push to the `main` branch.
-* `latest`: Image corresponding to the latest named tagged version of `graphnet`.
-* `vX.Y.Z`: Image corresponding to the specific named tagged version of `graphnet`.
\ No newline at end of file
diff --git a/_sources/installation/install.rst.txt b/_sources/installation/install.rst.txt
new file mode 100644
index 000000000..e21000426
--- /dev/null
+++ b/_sources/installation/install.rst.txt
@@ -0,0 +1,46 @@
+.. include:: ../substitutions.rst
+
+Installation
+============
+
+|graphnet|\ GraphNeT is available for Python 3.8 to Python 3.11.
+
+.. note::
+ We recommend installing |graphnet|\ GraphNeT in a separate environment, e.g. using a Python virtual environment or Anaconda (see details on installation `here `_).
+
+Quick Start
+-----------
+
+.. raw:: html
+ :file: quick-start.html
+
+
+When installation is completed, you should be able to run `the examples `_.
+
+Installation in CVMFS (IceCube)
+-------------------------------
+
+You may want |graphnet|\ GraphNeT to be able to interface with IceTray, e.g., when converting I3 files to a deep learning friendly file format, or when deploying models as part of an IceTray chain. In these cases, you need to install |graphnet|\ GraphNeT in a Python runtime that has IceTray installed.
+
+To achieve this, we recommend installing |graphnet|\ GraphNeT within a CVMFS environment that has IceTray installed, like so:
+
+.. code-block:: bash
+
+ # Download GraphNeT
+ git clone https://github.com/graphnet-team/graphnet.git
+ cd graphnet
+ # Open your favorite CVMFS distribution
+ eval `/cvmfs/icecube.opensciencegrid.org/py3-v4.2.1/setup.sh`
+ /cvmfs/icecube.opensciencegrid.org/py3-v4.2.1/RHEL_7_x86_64/metaprojects/icetray/v1.5.1/env-shell.sh
+ # Update central utils
+    pip install --upgrade "pip>=20"
+ pip install wheel setuptools==59.5.0
+ # Install graphnet into the CVMFS as a user
+ pip install --user -r requirements/torch_cpu.txt -e .[torch,develop]
+
+
+Once installed, |graphnet|\ GraphNeT is available whenever you activate that CVMFS environment locally.
+
+.. note::
+    We recommend installing |graphnet|\ GraphNeT without GPU support (i.e., using the CPU-only torch requirements, as above) in clean metaprojects.
+
diff --git a/_sources/integration/integration.rst.txt b/_sources/integration/integration.rst.txt
new file mode 100644
index 000000000..bbde4640c
--- /dev/null
+++ b/_sources/integration/integration.rst.txt
@@ -0,0 +1,261 @@
+.. include:: ../substitutions.rst
+
+Integrating New Experiments into GraphNeT\ |graphnet-header|
+============================================================
+
+GraphNeT is built to host data conversion, model and deployment code from different neutrino telescopes and related experiments.
+Part of the design is to minimize the technical overhead of implementing support for an experiment, which can typically be done in 200 - 300 lines of code.
+
+.. note::
+ The GraphNeT development team is willing to support the integration efforts of new experiments, so please consider reaching out to us if you need help.
+
+The general steps to integrate an experiment into GraphNeT are outlined below.
+
+**1) Adding Support for Your Data**
+----------------------------------------
+
+The most critical element of implementing support for an experiment in GraphNeT is an interface to data from your experiment.
+This can be done in two ways:
+
+- **a)** Adding a :code:`Dataset` class that is able to read your data directly during training
+- **b)** Adding a :code:`GraphNeTFileReader` and associated :code:`Extractors` to convert your data to a supported data format.
+
+Option **a)** is only viable if the data format from your experiment is suitable for deep learning.
+
+Option **b)** requires adding your own reader and defining extractors. Below is a step-by-step example.
+
+
+
+
+Writing your own :code:`Extractor` and :code:`GraphNeTFileReader`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. note::
+ We recommend visiting the :code:`DataConverter` documentation before reading this example
+
+Suppose you want to add a reader and extractor for MyExperiment and that the data
+is stored as pickled dictionaries with two kinds of entries:
+
+**hits**: a single pandas.DataFrame with dimensions [n_pulses,d] where `n_pulses` is the number of observed Cherenkov pulses across all events in the file, and `d` is the number of features we know about each measurement.
+These features would normally include the sensor position, time of measurement, event id etc.
+
+**truth**: a single pandas.DataFrame with dimensions [n_events, t] where `n_events` denotes the total number of events in the file and `t` is the number of event-level truth variables we have available for this simulation.
+Ordinarily, these truth variables would include particle id, energy, direction etc.
+
+
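+As a purely hypothetical illustration of such a file (the column names below are placeholders, not assumptions about any real experiment):
+
+.. code-block:: python
+
+    import pickle
+
+    import pandas as pd
+
+    # One pickled dictionary per file, with "hits" and "truth" entries.
+    data = {
+        "hits": pd.DataFrame(
+            {"sensor_x": [10.0, 20.0], "t": [1.0, 2.0], "event_no": [0, 0]}
+        ),
+        "truth": pd.DataFrame({"event_no": [0], "energy": [42.0]}),
+    }
+    with open("my_file.pickle", "wb") as f:
+        pickle.dump(data, f)
+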
+To convert these pickle files to a supported backend in GraphNeT, we first have to define our reader. This reader should open a pickle file and apply the :code:`Extractor` that we must also implement. Let's start with the reader:
+
+.. code-block:: python
+
+ from typing import List, Union, Dict
+ import pandas as pd
+ import pickle
+
+ # Import the generic file reader
+ from .graphnet_file_reader import GraphNeTFileReader
+
+ # Import your own extractor
+ from graphnet.data.extractors.myexperiment import MyExtractor
+
+ class MyReader(GraphNeTFileReader):
+ """A class for reading my pickle files from MyExperiment."""
+
+ _accepted_file_extensions = [".pickle"]
+        _accepted_extractors = [MyExtractor]
+
+ def __call__(self, file_path: str) -> Dict[str, pd.DataFrame]:
+ """Extract data from single pickle file.
+
+ Args:
+ file_path: Path to pickle file.
+
+ Returns:
+ Extracted data.
+ """
+
+ # Open file
+            with open(file_path, 'rb') as file:
+                data = pickle.load(file)
+
+ # Apply extractors
+ outputs = {}
+ for extractor in self._extractors:
+ output = extractor(data)
+ if output is not None:
+ outputs[extractor._extractor_name] = output
+ return outputs
+
+When the :code:`DataConverter` is instantiated, it will set the :code:`Extractors` that it was instantiated with as member
+variables of our :code:`GraphNeTFileReader`, making them available to us under :code:`self._extractors`. When the conversion is running,
+the :code:`DataConverter` will pass a :code:`file_path` to our :code:`__call__` function, and it is the job of our reader to open this
+file and apply extractors to it. These calls will happen in parallel automatically.
+
+So - the reader above first opens the `.pickle` file, and then applies the extractors. Job done! Let's now define the extractor:
+
+
+The purpose of an :code:`Extractor` is to extract only part of the information available in files. When an :code:`Extractor` is instantiated, it is given a name:
+
+.. code-block:: python
+
+ extractor = MyExtractor(extractor_name = "hits")
+
+and this name is used to select specific tables in the file.
+
+.. code-block:: python
+
+    from typing import Dict
+
+    import pandas as pd
+
+    from graphnet.data.extractors import Extractor
+
+ class MyExtractor(Extractor):
+ """
+        Class for extracting information from pickle files in MyExperiment.
+ """
+
+ def __call__(self, dictionary: Dict[str, pd.DataFrame]) -> pd.DataFrame:
+ """Extract information from pickle file."""
+
+ # Check if the table is in the dict
+ if self._extractor_name in dictionary.keys():
+ return dictionary[self._extractor_name]
+ else:
+ return None
+
+We defined our reader in such a way that our extractor receives a :code:`Dict[str, pd.DataFrame]` argument. Our extractor therefore only has to find the field it needs to extract and return it.
+
+
+With both a reader and extractor defined, we're ready to convert data to a supported backend in GraphNeT! Below is an example of using the code above in conversion:
+
+.. code-block:: python
+
+ from graphnet.data.extractors.myexperiment import MyExtractor
+ from graphnet.data.dataconverter import DataConverter
+ from graphnet.data.readers import MyReader
+ from graphnet.data.writers import ParquetWriter
+
+ # Your settings
+ dir_with_files = '/home/my_files'
+ outdir = '/home/my_outdir'
+ num_workers = 5
+
+ # Instantiate DataConverter - exports data from MyExperiment to Parquet
+ converter = DataConverter(file_reader = MyReader(),
+ save_method = ParquetWriter(),
+ extractors=[MyExtractor('hits'), MyExtractor('truth')],
+ outdir=outdir,
+ num_workers=num_workers,
+ )
+ # Run Converter
+ converter(input_dir = dir_with_files)
+ # Merge files (Optional)
+ converter.merge_files()
+
+
+
+
+
+**2) Implementing a Detector Class**
+----------------------------------------
+
+GraphNeT requires a :code:`Detector` class to represent details that are specific to your experiment.
+
+A :code:`Detector` holds a geometry table, standardization functions that map your raw data into a numerical range suitable for deep learning, and the names of important columns in your data.
+
+Below is an example of a :code:`Detector` class:
+
+.. code-block:: python
+
+    from typing import Callable, Dict
+
+    import torch
+
+    from graphnet.models.detector import Detector
+
+ class MyDetector(Detector):
+ """`Detector` class for my experiment."""
+
+ geometry_table_path = "path_to_geometry_table.parquet"
+ xyz = ["sensor_x", "sensor_y", "sensor_z"]
+ string_id_column = "string_id"
+ sensor_id_column = "sensor_id"
+
+ def feature_map(self) -> Dict[str, Callable]:
+ """Map standardization functions to each dimension of input data."""
+ feature_map = {
+ "sensor_x": self._sensor_xyz,
+ "sensor_y": self._sensor_xyz,
+ "sensor_z": self._sensor_xyz,
+ "sensor_time": self._sensor_time,
+ }
+ return feature_map
+
+ def _sensor_xyz(self, x: torch.tensor) -> torch.tensor:
+ return x / 500.0
+
+ def _sensor_time(self, x: torch.tensor) -> torch.tensor:
+ return (x - 1.0e04) / 3.0e4
+
+:code:`feature_map` is a function that maps a standardization function to each possible feature from your experiment.
+The class variable :code:`xyz` contains the names of the xyz-position of sensors in your detector.
+:code:`string_id_column` and :code:`sensor_id_column` hold the names of the columns in your input data that contain the string and sensor IDs, respectively.
+
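+For illustration (a simplified sketch, not the actual internal call sequence used during training; instantiating a :code:`Detector` directly may require additional setup depending on your version), the standardization functions can be thought of as being applied column-wise:
+
+.. code-block:: python
+
+    import torch
+
+    detector = MyDetector()
+    feature_map = detector.feature_map()
+
+    # Standardize a column of raw x-positions (illustration only).
+    raw_x = torch.tensor([10.0, 20.0, 30.0])
+    standardized_x = feature_map["sensor_x"](raw_x)  # i.e. raw_x / 500.0
+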
+Lastly, :code:`geometry_table_path` points to a file that you should add to :code:`graphnet/data/geometry_tables/name-of-your-experiment/name-of-detector.parquet`.
+A geometry table is an array containing all sensors in your experiment and has dimensions [n, d] where `n` denotes the number of sensors in your detector and `d` is the number of available features.
+
+Suppose the detector represented by the :code:`Detector` class above had 5 sensors in total on a single string; the corresponding geometry table would then be:
+
+.. list-table:: Example of geometry table before applying multi-index
+ :widths: 20 20 20 20 20 20
+ :header-rows: 1
+
+ * - sensor_x
+ - sensor_y
+ - sensor_z
+ - sensor_time
+ - string_id
+ - sensor_id
+ * - 10
+ - 10
+ - 10
+ - 1
+ - 0
+ - 0
+ * - 20
+ - 20
+ - 20
+ - 1
+ - 0
+ - 1
+ * - 30
+ - 30
+ - 30
+ - 1
+ - 0
+ - 2
+ * - 40
+ - 40
+ - 40
+ - 1
+ - 0
+ - 3
+ * - 50
+ - 50
+ - 50
+ - 1
+ - 0
+ - 4
+
+Here, every row represents a unique sensor identified by :code:`sensor_id`.
+GraphNeT will use this id to add/remove/filter sensors from your training examples, if you specify so in your data representations.
+
+To convert the table above into a geometry table, you must set a :code:`MultiIndex` on the xyz position variables, and save it as :code:`.parquet`:
+
+.. code-block:: python
+
+ import pandas as pd
+
+ path_to_array = 'my_table.csv'
+
+ table_without_index = pd.read_csv(path_to_array)
+ geometry_table = table_without_index.set_index(['sensor_x','sensor_y','sensor_z'])
+ geometry_table.to_parquet('my_geometry_table.parquet')
+
+Here, :code:`'my_table.csv'` refers to the table above, and the resulting :code:`'my_geometry_table.parquet'` would be the file to include under :code:`graphnet/data/geometry_tables/name-of-your-experiment/`.
+
+
diff --git a/_sources/intro/intro.rst.txt b/_sources/intro/intro.rst.txt
new file mode 100644
index 000000000..52d8701d4
--- /dev/null
+++ b/_sources/intro/intro.rst.txt
@@ -0,0 +1,40 @@
+.. |graphnet| image:: ../../assets/identity/favicon.svg
+ :width: 25px
+ :height: 25px
+ :alt: graphnet
+ :align: bottom
+
+.. |graphnet-header| image:: ../../assets/identity/favicon.svg
+ :width: 50px
+ :height: 50px
+ :alt: graphnet
+ :align: bottom
+
+GraphNeT\ |graphnet-header|
+##############################
+
+|graphnet|\ GraphNeT is an open-source Python framework aimed at providing high quality, user friendly, end-to-end functionality to perform reconstruction tasks at neutrino telescopes using deep learning. |graphnet|\ GraphNeT makes it fast and easy to train complex models that can provide event reconstruction with state-of-the-art performance, for arbitrary detector configurations, with inference times that are orders of magnitude faster than traditional reconstruction techniques.
+|graphnet|\ GraphNeT provides a common, detector agnostic framework for ML developers and physicists that wish to use the state-of-the-art tools in their research. By uniting both user groups, |graphnet|\ GraphNeT aims to increase the longevity and usability of individual code contributions from ML developers by building a general, reusable software package based on software engineering best practices, and lowers the technical threshold for physicists that wish to use the most performant tools for their scientific problems.
+
+
+|graphnet|\ GraphNeT comprises a number of modules providing the necessary tools to build workflows. These workflows range from ingesting raw training data in domain-specific formats to deploying trained models in domain-specific reconstruction chains, as illustrated in the flowchart below.
+
+.. figure:: ../../paper/flowchart.png
+
+ High-level overview of a typical workflow using |graphnet|\ GraphNeT: :code:`graphnet.data` enables converting domain-specific data to industry-standard, intermediate file formats and reading this data; :code:`graphnet.models` allows for configuring and building complex models using simple, physics-oriented components; :code:`graphnet.training` manages model training and experiment logging; and finally, :code:`graphnet.deployment` allows for using trained models for inference in domain-specific reconstruction chains.
+
+:code:`graphnet.models` provides modular components subclassing :code:`torch.nn.Module`, meaning that users only need to import a few existing, purpose-built components and chain them together to form a complete model. ML developers can contribute to |graphnet|\ GraphNeT by extending this suite of model components — through new layer types, physics tasks, graph connectivities, etc. — and experiment with optimising these for different reconstruction tasks using experiment tracking.
+
+These models are trained using :code:`graphnet.training` on data prepared using :code:`graphnet.data`, to satisfy the high I/O loads required when training ML models on large batches of events, which domain-specific neutrino physics data formats typically do not allow.
+
+Trained models are deployed to a domain-specific reconstruction chain, yielding predictions, using the components in :code:`graphnet.deployment`. This can either be through model files or container images, making deployment as portable and dependency-free as possible.
+
+By splitting up the model development as in the flowchart, |graphnet|\ GraphNeT allows physics users to interface only with high-level building blocks or pre-trained models that can be used directly in their reconstruction chains, while allowing ML developers to continuously improve and expand the framework’s capabilities.
+
+
+.. image:: ../../assets/images/eu-emblem.jpg
+ :width: 150
+
+This project has received funding from the European Union’s Horizon 2020 research and innovation programme under the Marie Skłodowska-Curie grant agreement No. 890778.
+
+The work of Rasmus Ørsøe was partly performed in the framework of the PUNCH4NFDI consortium supported by DFG fund "NFDI 39/1", Germany.
diff --git a/_sources/models/models.rst.txt b/_sources/models/models.rst.txt
new file mode 100644
index 000000000..4ec16ac76
--- /dev/null
+++ b/_sources/models/models.rst.txt
@@ -0,0 +1,515 @@
+.. include:: ../substitutions.rst
+
+Models In GraphNeT\ |graphnet-header|
+=====================================
+
+Three ideals form the philosophy behind `Model `_\ s in GraphNeT:
+
+:Self-containment: Functionality that a specific model requires (data pre-processing, transformation and other auxiliary calculations) should exist within the :code:`Model` itself such that it is portable and deployable as a single package that only depends on data.
+ I.e.,
+ Data → :code:`Model` → Predictions
+
+:Summarizable: Trained `Model `_\ s should be fully summarizable to configuration files, allowing you to easily distribute the results of your experimentation to other |graphnet|\ GraphNeT users.
+
+:Reusable: It should be easy and intuitive to repurpose existing `Model `_\ s to tackle new problems or work on new physics experiments.
+
+To help developers adhere to these ideals, |graphnet|\ GraphNeT provides structure and functionality through class inheritance of :code:`Model`.
+
+.. note::
+    A `Model `_ in |graphnet|\ GraphNeT is a :code:`LightningModule` configured to receive as input :code:`torch_geometric.data.Data` objects.
+    If you're unfamiliar with these terms, we recommend that you consult the `Lightning `_ and `PyG `_ documentation for details.
+
+
+The :code:`Model` class
+-------------------------
+
+:code:`Model` is the generic base class for `Model `_\ s in GraphNeT, and forms the basis from which all models are expected to originate.
+It comes with very few restrictions and you can therefore implement nearly any deep learning technique using |graphnet|\ GraphNeT by subclassing :code:`Model`.
+
+.. image:: ../../../assets/images/model.svg
+ :width: 250
+ :align: center
+ :class: with-shadow
+
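+As a purely schematic illustration of what such a subclass might look like (the exact set of methods you need to implement depends on your use case, and the class below is a placeholder, not a working model):
+
+.. code-block:: python
+
+    from torch import Tensor
+    from torch_geometric.data import Data
+
+    from graphnet.models import Model
+
+    class MyCustomModel(Model):
+        """Schematic custom model; implementation details omitted."""
+
+        def forward(self, data: Data) -> Tensor:
+            """Map a graph to a prediction tensor."""
+            ...
+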
+Any subclass of :code:`Model` will inherit the methods :code:`Model.to_config` and :code:`Model.from_config`, which allow the model to be exported as a :code:`.yml` file and reloaded in a different session. E.g.
+
+.. code-block:: python
+
+ model = Model(...)
+ model.save_config("model.yml")
+
+You can then reconstruct the same model architecture from the :code:`.yml` file:
+
+.. code-block:: python
+
+ from graphnet.models import Model
+
+ # Indicate that you `trust` the config file after inspecting it, to allow for
+    # dynamically loading the classes referenced in the file.
+ model = Model.from_config("model.yml", trust=True)
+
+
+You can find several pre-defined :code:`ModelConfig`\ s under :code:`graphnet/configs/models`. Below are the contents of :code:`example_energy_reconstruction_model.yml`:
+
+.. code-block:: yaml
+
+ arguments:
+ architecture:
+ ModelConfig:
+ arguments:
+ add_global_variables_after_pooling: false
+ dynedge_layer_sizes: null
+ features_subset: null
+ global_pooling_schemes: [min, max, mean, sum]
+ nb_inputs: 4
+ nb_neighbours: 8
+ post_processing_layer_sizes: null
+ readout_layer_sizes: null
+ class_name: DynEdge
+ graph_definition:
+ ModelConfig:
+ arguments:
+ columns: [0, 1, 2]
+ detector:
+ ModelConfig:
+ arguments: {}
+ class_name: Prometheus
+ dtype: null
+ nb_nearest_neighbours: 8
+ node_definition:
+ ModelConfig:
+ arguments: {}
+ class_name: NodesAsPulses
+ node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
+ class_name: KNNGraph
+ optimizer_class: '!class torch.optim.adam Adam'
+ optimizer_kwargs: {eps: 0.001, lr: 0.001}
+ scheduler_class: '!class graphnet.training.callbacks PiecewiseLinearLR'
+ scheduler_config: {interval: step}
+ scheduler_kwargs:
+ factors: [0.01, 1, 0.01]
+ milestones: [0, 20.0, 80]
+ tasks:
+ - ModelConfig:
+ arguments:
+ hidden_size: 128
+ loss_function:
+ ModelConfig:
+ arguments: {}
+ class_name: LogCoshLoss
+ loss_weight: null
+ prediction_labels: null
+ target_labels: total_energy
+ transform_inference: '!lambda x: torch.pow(10,x)'
+ transform_prediction_and_target: '!lambda x: torch.log10(x)'
+ transform_support: null
+ transform_target: null
+ class_name: EnergyReconstruction
+ class_name: StandardModel
+
+Thus the second ideal outlined for models is directly addressed simply by subclassing :code:`Model`. In addition, :code:`Model` comes with extra functionality for saving, loading, and checkpointing:
+
+
+Saving, loading, and checkpointing :code:`Model`\ s
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+There are several methods for saving models in GraphNeT and each comes with its own pros and cons.
+
+:code:`Model.save`
+~~~~~~~~~~~~~~~~~~
+
+You can pickle your entire model (including the :code:`state_dict`) by calling the :py:meth:`~graphnet.models.model.Model.save` method:
+
+.. code-block:: python
+
+ model.save("model.pth")
+
+You can then load this model by calling the :py:meth:`~graphnet.models.model.Model.load` classmethod:
+
+.. code-block:: python
+
+ from graphnet.models import Model
+
+ loaded_model = Model.load("model.pth")
+
+.. warning::
+ This method is rather convenient as it lets you store everything in a single file but it comes with a big caveat: **it's not version-proof**.
+ That is, if you share a pickled model with a user who runs a different version of GraphNeT than what was used to train the model, you might experience compatibility issues.
+ This is due to how pickle serialises Python objects.
+
+
+:code:`ModelConfig` and :code:`state_dict`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+You can summarise your :code:`Model` and the configuration of its components by exporting it to a :code:`.yml` file.
+This only captures the `definition` of the model, not any trained weights, but by saving the :code:`state_dict` too, you have effectively saved the entire model, both definition and weights.
+You can do so as follows:
+
+.. code-block:: python
+
+ model.save_config('model.yml')
+ model.save_state_dict('state_dict.pth')
+
+You can then reconstruct your model by building it from the :code:`ModelConfig` file and loading in the :code:`state_dict`:
+
+.. code-block:: python
+
+ from graphnet.models import Model
+ from graphnet.utilities.config import ModelConfig
+
+ model_config = ModelConfig.load("model.yml")
+ model = Model.from_config(model_config) # With randomly initialised weights.
+    model.load_state_dict("state_dict.pth")  # Now with trained weights.
+
+.. note::
+ This method is the recommended way of storing and sharing :code:`Model`\ s.
+
+
+Using checkpoints
+~~~~~~~~~~~~~~~~~~
+
+Because :code:`Model`\ s in GraphNeT are PyTorch Lightning :code:`LightningModule`\ s, you also have the option to use the :code:`load_from_checkpoint` method:
+
+.. code-block:: python
+
+ model_config = ModelConfig.load("model.yml")
+ model = Model.from_config(model_config) # With randomly initialised weights.
+    model.load_from_checkpoint("checkpoint.ckpt")  # Now with trained weights.
+
+You can find more information on checkpointing `here `_.
+
+
+The :code:`StandardModel` class
+-------------------------------
+The simplest way to define a :code:`Model` in GraphNeT is through the :code:`StandardModel` subclass, which provides additional functionality on top of :code:`Model`.
+
+The :code:`StandardModel` consists of a series of modules: a `GraphDefinition `_, which defines the representation of the raw data; a `Backbone `_, which defines the actual architecture; and one or more `Task `_\ s, which define the problem(s) that the model needs to solve.
+
+.. image:: ../../../assets/images/standardmodel.svg
+ :width: 350
+ :align: right
+ :class: with-shadow
+
+This structure guarantees modularity and reusability of models in |graphnet|\ GraphNeT, as these modules are interchangeable. The role of each of these model components is outlined below.
+
+For example, the only adaptation needed to run a :code:`StandardModel` made for IceCube on a different experiment — say, KM3NeT — would be to switch out the :code:`Detector` component in :code:`GraphDefinition`
+representing IceCube with one that represents KM3NeT. Similarly, a :code:`Model` developed for `EnergyReconstruction `_
+can be put to work on a different problem, e.g., `DirectionReconstructionWithKappa `_ ,
+by switching out just the `Task `_ component.
+
+
+A :code:`StandardModel` is composed of interchangeable components (such as instances of, e.g., the :code:`GraphDefinition`, :code:`Backbone`, and :code:`Task` classes) along with PyTorch and PyG functionality.
+All :code:`Model`\ s that are applicable to the same detector configuration, regardless of how the :code:`Model`\ s themselves are implemented, should be able to act on the same graph (:code:`torch_geometric.data.Data`) objects, thereby making them interchangeable and directly comparable.
+
+
+:code:`GraphDefinition`, :code:`Backbone` & :code:`Task`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+These components are packaged in a particularly simple way in `StandardModel`, but they are not specific to it.
+That is, they can be used in any combination, and alongside more specialised PyTorch/PyG code.
+
+.. image:: ../../../assets/images/datarepresentation.svg
+ :width: 600
+ :align: right
+ :class: with-shadow
+
+
+:GraphDefinition: A :code:`GraphDefinition` in GraphNeT is a data representation class that uniquely defines how the raw data is processed and presented to the model architecture. `Graphs` are a flexible data representation format that allows raw neutrino telescope data to be presented as point-cloud graphs, images, sequences, and more, making GraphNeT compatible with nearly all deep learning paradigms.
+
+ A :code:`GraphDefinition` is itself composed of interchangeable modules, namely :code:`Detector`, :code:`NodeDefinition` and :code:`EdgeDefinition`.
+
+ :Detector: The :code:`Detector` class holds experiment-specific details such as sensor geometry and index column names and defines standardization functions for each variable in the raw data.
+
+    :NodeDefinition: :code:`NodeDefinition` defines what a node (i.e. a row and its columns) represents. It is in charge of transforming the collection of standardized Cherenkov pulses associated with a triggered event into the node/row representation of choice. The choice made in this module determines whether nodes/rows represent single Cherenkov pulses, DOMs, entire strings, or something completely different.
+
+    :EdgeDefinition: The optional :code:`EdgeDefinition` defines how `edges` are drawn between nodes in a graph, which for graph neural networks can define how information may flow in the network. Methods that do not require edges, such as CNNs or transformers, can simply omit this module in their :code:`GraphDefinition`.
+
+
+ .. note::
+
+ The modularity of :code:`GraphDefinition` ensures that the only adaptation needed to run a :code:`StandardModel` made for IceCube on a different experiment — say, KM3NeT — would be to switch out the :code:`Detector` component in :code:`GraphDefinition`
+ representing IceCube with one that represents KM3NeT.
+
+
+:Backbone: The :code:`Backbone` defines the actual model architecture that will be used to process the data representation, and its output is directly passed to :code:`Task`.
+ The model architecture could be based on CNNs, GNNs, transformers or any of the other established deep learning paradigms. :code:`Backbone` should be a subclass of :code:`Model`.
+
+:Task:
+
+    Different applications of deep learning in neutrino telescopes (i.e. the problems we want to solve using DL) are represented as individual, detector-agnostic :code:`Task`\ s.
+
+    A :code:`Task` fully defines the physics problem that the model is trained to solve, and it is in charge of scaling/unscaling truth values and of calculating the loss.
+    Multiple subclasses of :code:`Task` exist, the most popular of which is :code:`StandardLearnedTask`, which acts as a learnable prediction head that maps
+    the latent output of :code:`backbone` to the target value(s). Many instances of :code:`StandardLearnedTask` have been implemented in GraphNeT to tackle a wide range of supervised learning tasks, such as binary classification and energy reconstruction.
+
+ Below is an example of a :code:`StandardLearnedTask` that defines binary classification in general:
+
+ .. code-block:: python
+
+        import torch
+        from torch import Tensor
+        from graphnet.models.task import StandardLearnedTask
+
+ class BinaryClassificationTask(StandardLearnedTask):
+ """Performs binary classification."""
+
+ # Requires one feature, logit for being signal class.
+ nb_inputs = 1
+ default_target_labels = ["target"]
+ default_prediction_labels = ["target_pred"]
+
+ def _forward(self, x: Tensor) -> Tensor:
+ # transform probability of being muon
+ return torch.sigmoid(x)
+
+    The class variable **nb_inputs** specifies the dimensions that this specific :code:`Task` expects its input :code:`x` to have. In the case of :code:`StandardLearnedTask`, a simple MLP is used to adjust the dimensions of the latent prediction from :code:`backbone` to **nb_inputs**.
+    :code:`_forward(self, x: Tensor)` defines what the :code:`Task` does to the latent predictions.
+    In this task, :code:`x` will be a [batch_size, 1]-dimensional latent vector, and the :code:`Task` simply returns its sigmoid.
+
+    As such, the code under :code:`_forward(self, x: Tensor)` defines the last steps of the :code:`Model`. A short usage sketch of this task is given below.
+
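+Below is a hedged usage sketch (not taken from the GraphNeT codebase) of how the task defined above could be attached to a backbone; the :code:`backbone` object and the :code:`"is_muon"` truth label are assumptions for illustration:
+
+.. code-block:: python
+
+    from graphnet.training.loss_functions import BinaryCrossEntropyLoss
+
+    # Attach the custom task to a backbone (e.g. the DynEdge instance defined
+    # in the next section); `hidden_size` must match the backbone's output size.
+    task = BinaryClassificationTask(
+        hidden_size=backbone.nb_outputs,
+        target_labels=["is_muon"],  # assumed name of the truth column
+        loss_function=BinaryCrossEntropyLoss(),
+    )
+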
+Instantiating a :code:`StandardModel`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+GraphNeT comes with many pre-defined :code:`GraphDefinition`\ s, :code:`Backbone`\ s, and :code:`Task`\ s that you can simply import and use out-of-the-box.
+So to get started, all you need to do is import your choice of these components and build the model.
+Below is a snippet that defines a :code:`Model` that reconstructs the zenith angle with uncertainties using the `GNN published by IceCube `_ for ORCA150:
+
+.. code-block:: python
+
+ # Choice of graph representation, architecture, and physics task
+ from graphnet.models.detector.prometheus import ORCA150
+ from graphnet.models.graphs import KNNGraph
+ from graphnet.models.graphs.nodes import NodesAsPulses
+ from graphnet.models.gnn.dynedge import DynEdge
+ from graphnet.models.task.reconstruction import ZenithReconstructionWithKappa
+
+ # Choice of loss function and Model class
+ from graphnet.training.loss_functions import VonMisesFisher2DLoss
+ from graphnet.models import StandardModel
+
+ # Configuring the components
+
+ # Represents the data as a point-cloud graph where each
+ # node represents a pulse of Cherenkov radiation
+ # edges drawn to the 8 nearest neighbours
+
+ graph_definition = KNNGraph(
+ detector=ORCA150(),
+ node_definition=NodesAsPulses(),
+ nb_nearest_neighbours=8,
+ )
+    backbone = DynEdge(
+        nb_inputs=graph_definition.nb_outputs,
+        global_pooling_schemes=["min", "max", "mean"],
+    )
+ task = ZenithReconstructionWithKappa(
+ hidden_size=backbone.nb_outputs,
+ target_labels="injection_zenith",
+ loss_function=VonMisesFisher2DLoss(),
+ )
+
+ # Construct the Model
+ model = StandardModel(
+ graph_definition=graph_definition,
+ backbone=backbone,
+ tasks=[task],
+ )
+
+The only change required to get this :code:`Model` to work on a different integrated experiment in GraphNeT
+is to switch out the :code:`Detector` component. Similarly, the model can be repurposed to solve a completely different problem
+by switching out the :code:`Task` component, as sketched below.
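+
+As a hedged sketch (and noting that, in practice, the model would need to be retrained after such changes), both swaps could look as follows, reusing the :code:`backbone` defined above; the target labels of the new task are left at their defaults and would need to match the truth available in your dataset:
+
+.. code-block:: python
+
+    from graphnet.models.detector.icecube import IceCube86
+    from graphnet.models.task.reconstruction import DirectionReconstructionWithKappa
+    from graphnet.training.loss_functions import VonMisesFisher3DLoss
+
+    # Same data representation, now configured for IceCube instead of ORCA150
+    icecube_graph_definition = KNNGraph(
+        detector=IceCube86(),
+        node_definition=NodesAsPulses(),
+        nb_nearest_neighbours=8,
+    )
+
+    # Same backbone, now solving direction reconstruction instead of zenith
+    direction_task = DirectionReconstructionWithKappa(
+        hidden_size=backbone.nb_outputs,
+        loss_function=VonMisesFisher3DLoss(),
+    )
+
+    model = StandardModel(
+        graph_definition=icecube_graph_definition,
+        backbone=backbone,
+        tasks=[direction_task],
+    )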
+
+Training Syntax for :code:`StandardModel`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Models inheriting from :code:`StandardModel` in GraphNeT come with a powerful built-in :code:`model.fit` method that reduces training models on neutrino telescope data to a syntax similar to that of :code:`sklearn`:
+
+.. code-block:: python
+
+ model = Model(...)
+ train_dataloader = DataLoader(...)
+ model.fit(train_dataloader=train_dataloader, max_epochs=10)
+
+:code:`model.fit` is built upon `pytorch_lightning.Trainer.fit `_ and therefore accepts the same arguments,
+allowing GraphNeT users to train :code:`Model`\ s with the exact same functionality but with less boilerplate code.
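+
+For example, Trainer-style options can be passed straight to :code:`fit`. The sketch below uses only arguments that appear elsewhere in this guide; check the :code:`Model.fit` signature in your GraphNeT version for the full set:
+
+.. code-block:: python
+
+    from graphnet.training.callbacks import ProgressBar
+
+    model.fit(
+        train_dataloader=train_dataloader,
+        max_epochs=10,
+        gpus=[0],                   # train on a single GPU
+        callbacks=[ProgressBar()],  # custom progress bar from GraphNeT
+        log_every_n_steps=1,
+    )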
+
+But because :code:`Model`\ s in GraphNeT are PyTorch modules and fully compatible with PyTorch Lightning, you can also choose to write your own custom training loops if needed, or use the regular PyTorch Lightning training functionality.
+The snippet above is equivalent to:
+
+.. code-block:: python
+
+    from pytorch_lightning import Trainer
+ from graphnet.training.callbacks import ProgressBar
+
+ model = Model(...)
+ train_dataloader = DataLoader(...)
+
+ # Configure Trainer
+ trainer = Trainer(
+ gpus=None,
+ max_epochs=10,
+ callbacks=[ProgressBar()],
+ log_every_n_steps=1,
+ logger=None,
+ strategy="ddp",
+ )
+
+ # Train model
+ trainer.fit(model, train_dataloader)
+
+
+
+Adding Your Own Model
+-------------------------------------
+
+Model architectures in GraphNeT are ordinary PyTorch :code:`torch.nn.Module`\ s that inherit from the generic :code:`graphnet.models.Model` class,
+and they are configured to receive :code:`torch_geometric.data.Data` objects as input to their :code:`forward` pass.
+Therefore, adding your PyTorch models to GraphNeT is as easy as changing the inheritance and adjusting the input to expect :code:`Data` objects.
+
+Below is an example of a simple PyTorch model:
+
+.. code-block:: python
+
+ import torch
+
+
+ class MyModel(torch.nn.Module):
+
+ def __init__(self,
+ input_dim : int = 5,
+ output_dim : int = 10):
+
+ super().__init__()
+ self._layer = torch.nn.Linear(input_dim, output_dim)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self._layer(x)
+
+Modifying this model to work in GraphNeT requires changing the inheritance to :code:`Model` and configuring the input to be :code:`Data` objects:
+
+.. code-block:: python
+
+ import torch
+ from graphnet.models import Model
+ from torch_geometric.data import Data
+
+
+ class MyGraphNeTModel(Model):
+
+ def __init__(self,
+ input_dim : int = 5,
+ output_dim : int = 10):
+
+ super().__init__()
+ self._layer = torch.nn.Linear(input_dim, output_dim)
+
+ def forward(self, data: Data) -> torch.Tensor:
+ x = data.x
+ return self._layer(x)
+
+The model is then ready to be used as a :code:`backbone` in :code:`StandardModel`, or to be included in your own implementation.
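+
+As a hedged sketch (assuming a 4-dimensional node feature set, a matching :code:`graph_definition` such as the :code:`KNNGraph` from the earlier example, and the :code:`total_energy` truth label from the example config), plugging the custom model in as a backbone could look like:
+
+.. code-block:: python
+
+    from graphnet.models import StandardModel
+    from graphnet.models.task.reconstruction import EnergyReconstruction
+    from graphnet.training.loss_functions import LogCoshLoss
+
+    backbone = MyGraphNeTModel(input_dim=4, output_dim=10)
+    task = EnergyReconstruction(
+        hidden_size=10,                # must match the backbone's output dimension
+        target_labels="total_energy",  # assumed truth column, as in the example config
+        loss_function=LogCoshLoss(),
+    )
+    model = StandardModel(
+        graph_definition=graph_definition,  # e.g. the KNNGraph defined earlier
+        backbone=backbone,
+        tasks=[task],
+    )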
+
+Experiment Tracking
+--------------------
+
+You can track your experiment using `Weights & Biases `_ by passing the :code:`WandbLogger` to :py:func:`~graphnet.models.model.Model.fit`:
+
+.. code-block:: python
+
+ import os
+
+ from pytorch_lightning.loggers import WandbLogger
+
+ # Create wandb directory
+ wandb_dir = "./wandb/"
+ os.makedirs(wandb_dir, exist_ok=True)
+
+ # Initialise Weights & Biases (W&B) run
+ wandb_logger = WandbLogger(
+ project="example-script",
+ entity="graphnet-team",
+ save_dir=wandb_dir,
+ log_model=True,
+ )
+
+ # Fit Model
+ model = Model(...)
+ model.fit(
+ ...,
+ logger=wandb_logger,
+ )
+
+By using :code:`WandbLogger`, your training and validation losses are logged, and you have the full functionality of Weights & Biases available.
+This means, e.g., that you can log your :py:class:`~graphnet.utilities.config.model_config.ModelConfig`, :py:class:`~graphnet.utilities.config.dataset_config.DatasetConfig`, and :py:class:`~graphnet.utilities.config.training_config.TrainingConfig` as:
+
+.. code-block:: python
+
+ wandb_logger.experiment.config.update(training_config)
+ wandb_logger.experiment.config.update(model_config.as_dict())
+ wandb_logger.experiment.config.update(dataset_config.as_dict())
+
+Using an experiment tracking system like Weights & Biases to track training metrics as well as artifacts like configuration files greatly improves reproducibility, experiment transparency, and collaboration.
+This is because you can easily recreate a previous run from the saved artifacts, directly compare runs with different model configurations and hyperparameter choices, and share your results with other people on your team.
+Therefore, we strongly recommend using Weights & Biases or a similar system when training and optimising models meant for actual physics use.
+
+
+
+Example: Energy Reconstruction using :code:`ModelConfig`
+----------------------------------------------------------
+
+Below is a minimal example of training a GNN for energy reconstruction on a small data sample, using GraphNeT configuration files:
+
+.. code-block:: python
+
+ # Import(s)
+ import os
+
+ from graphnet.constants import CONFIG_DIR # Local path to graphnet/configs
+ from graphnet.data.dataloader import DataLoader
+ from graphnet.models import Model
+ from graphnet.utilities.config import DatasetConfig, ModelConfig
+
+ # Configuration
+ dataset_config_path = f"{CONFIG_DIR}/datasets/training_example_data_sqlite.yml"
+ model_config_path = f"{CONFIG_DIR}/models/example_energy_reconstruction_model.yml"
+
+ # Build model
+ model_config = ModelConfig.load(model_config_path)
+ model = Model.from_config(model_config, trust=True)
+
+ # Construct dataloaders
+ dataset_config = DatasetConfig.load(dataset_config_path)
+ dataloaders = DataLoader.from_dataset_config(
+ dataset_config,
+ batch_size=16,
+ num_workers=1,
+ )
+
+ # Train model
+ model.fit(
+ dataloaders["train"],
+ dataloaders["validation"],
+ gpus=[0],
+ max_epochs=5,
+ )
+
+ # Predict on test set and return as pandas.DataFrame
+ results = model.predict_as_dataframe(
+ dataloaders["test"],
+ additional_attributes=model.target_labels + ["event_no"],
+ )
+
+ # Save predictions and model to file
+ outdir = "tutorial_output"
+ os.makedirs(outdir, exist_ok=True)
+ results.to_csv(f"{outdir}/results.csv")
+ model.save_state_dict(f"{outdir}/state_dict.pth")
+ model.save(f"{outdir}/model.pth")
+
+Because :code:`ModelConfig` summarises a :code:`Model` completely, including its :code:`Task`\ (s),
+the only modifications required to change the example to reconstruct (or classify) a different attribute than energy, is to pass a :code:`ModelConfig` that defines a model with the corresponding :code:`Task`.
+Similarly, if you wanted to train on a different :code:`Dataset`, you would just have to pass a :code:`DatasetConfig` that defines *that* :code:`Dataset` instead.
diff --git a/_sources/substitutions.rst.txt b/_sources/substitutions.rst.txt
new file mode 100644
index 000000000..d53e41319
--- /dev/null
+++ b/_sources/substitutions.rst.txt
@@ -0,0 +1,15 @@
+.. |graphnet| image:: ../../../assets/identity/favicon.svg
+ :width: 25px
+ :height: 25px
+ :alt: graphnet
+ :align: bottom
+
+.. |graphnet-header| image:: ../../../assets/identity/favicon.svg
+ :width: 50px
+ :height: 50px
+ :alt: graphnet
+ :align: bottom
+
+.. |clearfloat| raw:: html
+
+
\ No newline at end of file
diff --git a/_static/basic.css b/_static/basic.css
index 30fee9d0f..f316efcb4 100644
--- a/_static/basic.css
+++ b/_static/basic.css
@@ -4,7 +4,7 @@
*
* Sphinx stylesheet -- basic theme.
*
- * :copyright: Copyright 2007-2023 by the Sphinx team, see AUTHORS.
+ * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS.
* :license: BSD, see LICENSE for details.
*
*/
diff --git a/_static/doctools.js b/_static/doctools.js
index d06a71d75..4d67807d1 100644
--- a/_static/doctools.js
+++ b/_static/doctools.js
@@ -4,7 +4,7 @@
*
* Base JavaScript utilities for all Sphinx HTML documentation.
*
- * :copyright: Copyright 2007-2023 by the Sphinx team, see AUTHORS.
+ * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS.
* :license: BSD, see LICENSE for details.
*
*/
diff --git a/_static/language_data.js b/_static/language_data.js
index 250f5665f..367b8ed81 100644
--- a/_static/language_data.js
+++ b/_static/language_data.js
@@ -5,7 +5,7 @@
* This script contains the language-specific data used by searchtools.js,
* namely the list of stopwords, stemmer, scorer and splitter.
*
- * :copyright: Copyright 2007-2023 by the Sphinx team, see AUTHORS.
+ * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS.
* :license: BSD, see LICENSE for details.
*
*/
@@ -13,7 +13,7 @@
var stopwords = ["a", "and", "are", "as", "at", "be", "but", "by", "for", "if", "in", "into", "is", "it", "near", "no", "not", "of", "on", "or", "such", "that", "the", "their", "then", "there", "these", "they", "this", "to", "was", "will", "with"];
-/* Non-minified version is copied as a separate JS file, is available */
+/* Non-minified version is copied as a separate JS file, if available */
/**
* Porter Stemmer
diff --git a/_static/searchtools.js b/_static/searchtools.js
index 7918c3fab..92da3f8b2 100644
--- a/_static/searchtools.js
+++ b/_static/searchtools.js
@@ -4,7 +4,7 @@
*
* Sphinx JavaScript utilities for the full-text search.
*
- * :copyright: Copyright 2007-2023 by the Sphinx team, see AUTHORS.
+ * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS.
* :license: BSD, see LICENSE for details.
*
*/
@@ -99,7 +99,7 @@ const _displayItem = (item, searchTerms, highlightTerms) => {
.then((data) => {
if (data)
listItem.appendChild(
- Search.makeSearchSummary(data, searchTerms)
+ Search.makeSearchSummary(data, searchTerms, anchor)
);
// highlight search terms in the summary
if (SPHINX_HIGHLIGHT_ENABLED) // set in sphinx_highlight.js
@@ -116,8 +116,8 @@ const _finishSearch = (resultCount) => {
);
else
Search.status.innerText = _(
- `Search finished, found ${resultCount} page(s) matching the search query.`
- );
+ "Search finished, found ${resultCount} page(s) matching the search query."
+ ).replace('${resultCount}', resultCount);
};
const _displayNextItem = (
results,
@@ -137,6 +137,22 @@ const _displayNextItem = (
// search finished, update title and status message
else _finishSearch(resultCount);
};
+// Helper function used by query() to order search results.
+// Each input is an array of [docname, title, anchor, descr, score, filename].
+// Order the results by score (in opposite order of appearance, since the
+// `_displayNextItem` function uses pop() to retrieve items) and then alphabetically.
+const _orderResultsByScoreThenName = (a, b) => {
+ const leftScore = a[4];
+ const rightScore = b[4];
+ if (leftScore === rightScore) {
+ // same score: sort alphabetically
+ const leftTitle = a[1].toLowerCase();
+ const rightTitle = b[1].toLowerCase();
+ if (leftTitle === rightTitle) return 0;
+ return leftTitle > rightTitle ? -1 : 1; // inverted is intentional
+ }
+ return leftScore > rightScore ? 1 : -1;
+};
/**
* Default splitQuery function. Can be overridden in ``sphinx.search`` with a
@@ -160,13 +176,26 @@ const Search = {
_queued_query: null,
_pulse_status: -1,
- htmlToText: (htmlString) => {
+ htmlToText: (htmlString, anchor) => {
const htmlElement = new DOMParser().parseFromString(htmlString, 'text/html');
- htmlElement.querySelectorAll(".headerlink").forEach((el) => { el.remove() });
+ for (const removalQuery of [".headerlinks", "script", "style"]) {
+ htmlElement.querySelectorAll(removalQuery).forEach((el) => { el.remove() });
+ }
+ if (anchor) {
+ const anchorContent = htmlElement.querySelector(`[role="main"] ${anchor}`);
+ if (anchorContent) return anchorContent.textContent;
+
+ console.warn(
+ `Anchored content block not found. Sphinx search tries to obtain it via DOM query '[role=main] ${anchor}'. Check your theme or template.`
+ );
+ }
+
+ // if anchor not specified or not found, fall back to main content
const docContent = htmlElement.querySelector('[role="main"]');
- if (docContent !== undefined) return docContent.textContent;
+ if (docContent) return docContent.textContent;
+
console.warn(
- "Content block not found. Sphinx search tries to obtain it via '[role=main]'. Could you check your theme or template."
+ "Content block not found. Sphinx search tries to obtain it via DOM query '[role=main]'. Check your theme or template."
);
return "";
},
@@ -239,16 +268,7 @@ const Search = {
else Search.deferQuery(query);
},
- /**
- * execute search (requires search index to be loaded)
- */
- query: (query) => {
- const filenames = Search._index.filenames;
- const docNames = Search._index.docnames;
- const titles = Search._index.titles;
- const allTitles = Search._index.alltitles;
- const indexEntries = Search._index.indexentries;
-
+ _parseQuery: (query) => {
// stem the search terms and add them to the correct list
const stemmer = new Stemmer();
const searchTerms = new Set();
@@ -284,16 +304,32 @@ const Search = {
// console.info("required: ", [...searchTerms]);
// console.info("excluded: ", [...excludedTerms]);
- // array of [docname, title, anchor, descr, score, filename]
- let results = [];
+ return [query, searchTerms, excludedTerms, highlightTerms, objectTerms];
+ },
+
+ /**
+ * execute search (requires search index to be loaded)
+ */
+ _performSearch: (query, searchTerms, excludedTerms, highlightTerms, objectTerms) => {
+ const filenames = Search._index.filenames;
+ const docNames = Search._index.docnames;
+ const titles = Search._index.titles;
+ const allTitles = Search._index.alltitles;
+ const indexEntries = Search._index.indexentries;
+
+ // Collect multiple result groups to be sorted separately and then ordered.
+ // Each is an array of [docname, title, anchor, descr, score, filename].
+ const normalResults = [];
+ const nonMainIndexResults = [];
+
_removeChildren(document.getElementById("search-progress"));
- const queryLower = query.toLowerCase();
+ const queryLower = query.toLowerCase().trim();
for (const [title, foundTitles] of Object.entries(allTitles)) {
- if (title.toLowerCase().includes(queryLower) && (queryLower.length >= title.length/2)) {
+ if (title.toLowerCase().trim().includes(queryLower) && (queryLower.length >= title.length/2)) {
for (const [file, id] of foundTitles) {
let score = Math.round(100 * queryLower.length / title.length)
- results.push([
+ normalResults.push([
docNames[file],
titles[file] !== title ? `${titles[file]} > ${title}` : title,
id !== null ? "#" + id : "",
@@ -308,46 +344,47 @@ const Search = {
// search for explicit entries in index directives
for (const [entry, foundEntries] of Object.entries(indexEntries)) {
if (entry.includes(queryLower) && (queryLower.length >= entry.length/2)) {
- for (const [file, id] of foundEntries) {
- let score = Math.round(100 * queryLower.length / entry.length)
- results.push([
+ for (const [file, id, isMain] of foundEntries) {
+ const score = Math.round(100 * queryLower.length / entry.length);
+ const result = [
docNames[file],
titles[file],
id ? "#" + id : "",
null,
score,
filenames[file],
- ]);
+ ];
+ if (isMain) {
+ normalResults.push(result);
+ } else {
+ nonMainIndexResults.push(result);
+ }
}
}
}
// lookup as object
objectTerms.forEach((term) =>
- results.push(...Search.performObjectSearch(term, objectTerms))
+ normalResults.push(...Search.performObjectSearch(term, objectTerms))
);
// lookup as search terms in fulltext
- results.push(...Search.performTermsSearch(searchTerms, excludedTerms));
+ normalResults.push(...Search.performTermsSearch(searchTerms, excludedTerms));
// let the scorer override scores with a custom scoring function
- if (Scorer.score) results.forEach((item) => (item[4] = Scorer.score(item)));
-
- // now sort the results by score (in opposite order of appearance, since the
- // display function below uses pop() to retrieve items) and then
- // alphabetically
- results.sort((a, b) => {
- const leftScore = a[4];
- const rightScore = b[4];
- if (leftScore === rightScore) {
- // same score: sort alphabetically
- const leftTitle = a[1].toLowerCase();
- const rightTitle = b[1].toLowerCase();
- if (leftTitle === rightTitle) return 0;
- return leftTitle > rightTitle ? -1 : 1; // inverted is intentional
- }
- return leftScore > rightScore ? 1 : -1;
- });
+ if (Scorer.score) {
+ normalResults.forEach((item) => (item[4] = Scorer.score(item)));
+ nonMainIndexResults.forEach((item) => (item[4] = Scorer.score(item)));
+ }
+
+ // Sort each group of results by score and then alphabetically by name.
+ normalResults.sort(_orderResultsByScoreThenName);
+ nonMainIndexResults.sort(_orderResultsByScoreThenName);
+
+ // Combine the result groups in (reverse) order.
+ // Non-main index entries are typically arbitrary cross-references,
+ // so display them after other results.
+ let results = [...nonMainIndexResults, ...normalResults];
// remove duplicate search results
// note the reversing of results, so that in the case of duplicates, the highest-scoring entry is kept
@@ -361,7 +398,12 @@ const Search = {
return acc;
}, []);
- results = results.reverse();
+ return results.reverse();
+ },
+
+ query: (query) => {
+ const [searchQuery, searchTerms, excludedTerms, highlightTerms, objectTerms] = Search._parseQuery(query);
+ const results = Search._performSearch(searchQuery, searchTerms, excludedTerms, highlightTerms, objectTerms);
// for debugging
//Search.lastresults = results.slice(); // a copy
@@ -466,14 +508,18 @@ const Search = {
// add support for partial matches
if (word.length > 2) {
const escapedWord = _escapeRegExp(word);
- Object.keys(terms).forEach((term) => {
- if (term.match(escapedWord) && !terms[word])
- arr.push({ files: terms[term], score: Scorer.partialTerm });
- });
- Object.keys(titleTerms).forEach((term) => {
- if (term.match(escapedWord) && !titleTerms[word])
- arr.push({ files: titleTerms[word], score: Scorer.partialTitle });
- });
+ if (!terms.hasOwnProperty(word)) {
+ Object.keys(terms).forEach((term) => {
+ if (term.match(escapedWord))
+ arr.push({ files: terms[term], score: Scorer.partialTerm });
+ });
+ }
+ if (!titleTerms.hasOwnProperty(word)) {
+ Object.keys(titleTerms).forEach((term) => {
+ if (term.match(escapedWord))
+ arr.push({ files: titleTerms[term], score: Scorer.partialTitle });
+ });
+ }
}
// no match but word was a required one
@@ -496,9 +542,8 @@ const Search = {
// create the mapping
files.forEach((file) => {
- if (fileMap.has(file) && fileMap.get(file).indexOf(word) === -1)
- fileMap.get(file).push(word);
- else fileMap.set(file, [word]);
+ if (!fileMap.has(file)) fileMap.set(file, [word]);
+ else if (fileMap.get(file).indexOf(word) === -1) fileMap.get(file).push(word);
});
});
@@ -549,8 +594,8 @@ const Search = {
* search summary for a given text. keywords is a list
* of stemmed words.
*/
- makeSearchSummary: (htmlText, keywords) => {
- const text = Search.htmlToText(htmlText);
+ makeSearchSummary: (htmlText, keywords, anchor) => {
+ const text = Search.htmlToText(htmlText, anchor);
if (text === "") return null;
const textLower = text.toLowerCase();
diff --git a/about/about.html b/about/about.html
new file mode 100644
index 000000000..42b3b8f24
--- /dev/null
+++ b/about/about.html
@@ -0,0 +1,414 @@
+ Usage — graphnet documentation
GraphNeT is an open-source Python framework aimed at providing high quality, user friendly, end-to-end functionality to perform reconstruction tasks at neutrino telescopes using deep learning. GraphNeT makes it fast and easy to train complex models that can provide event reconstruction with state-of-the-art performance, for arbitrary detector configurations, with inference times that are orders of magnitude faster than traditional reconstruction techniques.
+GraphNeT provides a common, detector agnostic framework for ML developers and physicists that wish to use the state-of-the-art tools in their research. By uniting both user groups, GraphNeT aims to increase the longevity and usability of individual code contributions from ML developers by building a general, reusable software package based on software engineering best practices, and lowers the technical threshold for physicists that wish to use the most performant tools for their scientific problems.
GraphNeT comprises a number of modules providing the necessary tools to build workflows from ingesting raw training data in domain-specific formats to deploying trained models in domain-specific reconstruction chains, as illustrated in [the Figure](flowchart).
+
+
graphnet.models provides modular components subclassing torch.nn.Module, meaning that users only need to import a few existing, purpose-built components and chain them together to form a complete model. ML developers can contribute to GraphNeT by extending this suite of model components — through new layer types, physics tasks, graph connectivities, etc. — and experiment with optimising these for different reconstruction tasks using experiment tracking.
+
These models are trained using graphnet.training on data prepared using graphnet.data, to satisfy the high I/O loads required when training ML models on large batches of events, which domain-specific neutrino physics data formats typically do not allow.
+
Trained models are deployed to a domain-specific reconstruction chain, yielding predictions, using the components in graphnet.deployment. This can either be through model files or container images, making deployment as portable and dependency-free as possible.
+
By splitting up the model development as in flowchart, GraphNeT allows physics users to interface only with high-level building blocks or pre-trained models that can be used directly in their reconstruction chains, while allowing ML developers to continuously improve and expand the framework’s capabilities.
This project has received funding from the European Union’s Horizon 2020 research and innovation programme under the Marie Skłodowska-Curie grant agreement No. 890778.
+
The work of Rasmus Ørsøe was partly performed in the framework of the PUNCH4NFDI consortium supported by DFG fund “NFDI 39/1”, Germany.
+
\ No newline at end of file
diff --git a/api/graphnet.constants.html b/api/graphnet.constants.html
index 8b00f9333..b9ecca20c 100644
--- a/api/graphnet.constants.html
+++ b/api/graphnet.constants.html
@@ -123,10 +123,9 @@
-
+
-
@@ -283,14 +282,42 @@
Contains a Generic class for curated DataModules/Datasets.
+
Inheriting subclasses are data-specific implementations that allow the user to
+import and download pre-converted datasets for training of deep learning based
+methods in GraphNeT.
Curated Datasets in GraphNeT are pre-converted datasets that have been
+prepared for training and evaluation of deep learning models. On these
+Datasets, graphnet users can train and benchmark their models against SOTA
+methods.
+
Construct CuratedDataset.
+
+
Parameters:
+
+
graph_definition (GraphDefinition) – Method that defines the data representation.
+
download_dir (str) – Directory to download dataset to.
+
truth (Optional) – List of event-level truth to include. Will
+include all available information if not given.
+
features (Optional) – List of input features from pulsemap to use.
+If not given, all available features will be
+used.
+
backend (Optional) – data backend to use. Either “parquet” or
+“sqlite”. Defaults to “parquet”.
+
train_dataloader_kwargs (Optional) – Arguments for the training
+DataLoader. Default None.
+
validation_dataloader_kwargs (Optional) – Arguments for the
+validation DataLoader, Default None.
+
test_dataloader_kwargs (Optional) – Arguments for the test
+DataLoader. Default None.
A base class for dataset/datamodule hosted at ERDA.
+
Inheriting subclasses will just need to fill out the _file_hashes
+attribute, which points to the file-id of a ERDA-hosted sharelink. It
+is assumed that sharelinks point to a single compressed file that has
+been compressed using tar with extension “.tar.gz”.
Dataset class for Parquet-files converted with ParquetWriter.
Construct Dataset.
+
+
NOTE: DataLoaders using this Dataset should have
+“multiprocessing_context = ‘spawn’” set to avoid thread locking.
+
Parameters:
-
path (Union[str, List[str]]) – Path to the file(s) from which this Dataset should read.
+
path (str) – Path to the file(s) from which this Dataset should read.
pulsemaps (Union[str, List[str]]) – Name(s) of the pulse map series that should be used to
construct the nodes on the individual graph objects, and their
features. Multiple pulse series maps can be used, e.g., when
@@ -579,11 +617,8 @@
string_selection (Optional[List[int]], default: None) – Subset of strings for which data should be read
and used to construct graph objects. Defaults to None, meaning
all strings for which data exists are used.
-
selection (Union[str, List[int], List[List[int]], None], default: None) – The events that should be read. This can be given either
-as list of indicies (in index_column); or a string-based
-selection used to query the Dataset for events passing the
-selection. Defaults to None, meaning that all events in the
-input files are read.
+
selection (Union[str, List[int], List[List[int]], None], default: None) – The batch ids to include in the dataset.
+Defaults to None, meaning that all batches are read.
dtype (dtype, default: torch.float32) – Type of the feature tensor on the graph objects returned.
loss_weight_table (Optional[str], default: None) – Name of the table containing per-event loss
weights.
@@ -601,9 +636,10 @@
“10000 random events ~ event_no % 5 > 0” or “20% random
events ~ event_no % 5 > 0”).
graph_definition (GraphDefinition) – Method that defines the graph representation.
-
labels (Optional[Dict[str, Any]], default: None) – Dictionary of labels to be added to the dataset.
-
args (Any) –
-
kwargs (Any) –
+
cache_size (int, default: 1) – Number of batches to cache in memory.
+Must be at least 1. Defaults to 1.
Query table at a specific index, optionally with some selection.
+
Query a table at a specific index, optionally with some selection.
-
Return type:
-
List[Tuple[Any, ...]]
-
-
Parameters:
-
-
table (str) –
-
columns (List[str] | str) –
-
sequential_index (int | None) –
-
selection (str | None) –
+
Parameters:
+
+
table (str) – Table to be queried.
+
columns (Union[List[str], str]) – Columns to read out.
+
sequential_index (Optional[int], default: None) – Sequentially numbered index
+(i.e. in [0,len(self))) of the event to query. This _may_
+differ from the indexation used in self._indices. If no value
+is provided, the entire column is returned.
+
selection (Optional[str], default: None) – Selection to be imposed before reading out data.
+Defaults to None.
+
Return type:
+
ndarray
+
+
Returns:
+
+
List of tuples containing the values in columns. If the table
contains only scalar data for columns, a list of length 1 is
+returned