Skip to content

Commit

Permalink
update extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
RasmusOrsoe committed Aug 6, 2024
1 parent 0c8071d commit a7ec4c0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
24 changes: 23 additions & 1 deletion src/graphnet/data/extractors/prometheus/prometheus_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,19 @@ class PrometheusTruthExtractor(PrometheusExtractor):
This Extractor will "initial_state" i.e. neutrino truth.
"""

def __init__(self, table_name: str = "mc_truth") -> None:
def __init__(
self,
table_name: str = "mc_truth",
transform_azimuth: bool = True,
) -> None:
"""Construct PrometheusTruthExtractor.
Args:
table_name: Name of the table in the parquet files that contain
event-level truth. Defaults to "mc_truth".
transform_azimuth: Some simulation has the azimuthal angle
written in a [-pi, pi] projection instead of [0, 2pi].
If True, the azimuthal angle will be transformed to [0, 2pi].
"""
columns = [
"interaction",
Expand All @@ -67,9 +74,24 @@ def __init__(self, table_name: str = "mc_truth") -> None:
"initial_state_x",
"initial_state_y",
"initial_state_z",
"bjorken_x",
"bjorken_y",
]
self._transform_az = transform_azimuth
super().__init__(extractor_name=table_name, columns=columns)

def __call__(self, event: pd.DataFrame) -> pd.DataFrame:
"""Extract event-level truth information."""
# Extract data
res = super().__call__(event=event)
# transform azimuth from [-pi, pi] to [0, 2pi] if wanted
if self._transform_az:
if len(res["initial_state_azimuth"]) > 0:
azimuth = np.asarray(res["initial_state_azimuth"]) + np.pi
azimuth = azimuth.tolist() # back to list
res["initial_state_azimuth"] = azimuth
return res


class PrometheusFeatureExtractor(PrometheusExtractor):
"""Class for extracting pulses/photons from Prometheus parquet files."""
Expand Down
2 changes: 1 addition & 1 deletion src/graphnet/data/readers/prometheus_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _keep_event(self, extracted_event: OrderedDict) -> bool:
)
else:
self.warning_once(
f"{filter._filter_on} not in file." " Filter skipped."
f"{filter._filter_on} not in file. Filter skipped."
)
filter_counter += True
if filter_counter < len(self._filters):
Expand Down

0 comments on commit a7ec4c0

Please sign in to comment.