Skip to content

Commit

Permalink
fix: add check for keep and drop branches (#70)
Browse files Browse the repository at this point in the history
* CLI commands

* formatting...
  • Loading branch information
zbilodea authored Feb 12, 2024
1 parent 7d9d2ea commit c3be4c1
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 2 deletions.
18 changes: 18 additions & 0 deletions src/hepconvert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def parquet_to_root(
@click.argument("destination", type=click.Path())
@click.argument("file")
@click.option("--drop-branches", default=None, type=list or dict or str, required=False)
@click.option("--keep-branches", default=None, type=list or dict or str, required=False)
@click.option("--drop-trees", default=None, type=list or str, required=False)
@click.option("--keep-trees", default=None, type=list or str, required=False)
@click.option("--title", required=False, default="")
@click.option(
"--initial-basket-capacity",
Expand All @@ -108,7 +110,9 @@ def copy_root(
file,
*,
drop_branches=None,
keep_branches=None,
drop_trees=None,
keep_trees=None,
force=False,
title="",
field_name=lambda outer, inner: inner if outer == "" else outer + "_" + inner,
Expand All @@ -128,7 +132,9 @@ def copy_root(
destination,
file,
drop_branches=drop_branches,
keep_branches=keep_branches,
drop_trees=drop_trees,
keep_trees=keep_trees,
force=force,
title=title,
field_name=field_name,
Expand Down Expand Up @@ -225,6 +231,10 @@ def add(
default=100,
help="If an integer, the maximum number of entries to include in each iteration step; if a string, the maximum memory size to include. The string must be a number followed by a memory unit, such as “100 MB”.",
)
@click.option("--drop-branches", default=None, type=list or dict or str, required=False)
@click.option("--keep-branches", default=None, type=list or dict or str, required=False)
@click.option("--drop-trees", default=None, type=list or str, required=False)
@click.option("--keep-trees", default=None, type=list or str, required=False)
@click.option(
"--force", default=True, help="Overwrite destination file if it already exists"
)
Expand All @@ -251,6 +261,10 @@ def merge_root(
fieldname_separator="_",
title="",
field_name=lambda outer, inner: inner if outer == "" else outer + "_" + inner,
drop_branches=None,
keep_branches=None,
drop_trees=None,
keep_trees=None,
initial_basket_capacity=10,
resize_factor=10.0,
counter_name=lambda counted: "n" + counted,
Expand All @@ -272,6 +286,10 @@ def merge_root(
fieldname_separator=fieldname_separator,
title=title,
field_name=field_name,
drop_branches=drop_branches,
keep_branches=keep_branches,
drop_trees=drop_trees,
keep_trees=keep_trees,
initial_basket_capacity=initial_basket_capacity,
resize_factor=resize_factor,
counter_name=counter_name,
Expand Down
6 changes: 4 additions & 2 deletions src/hepconvert/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# def get_counter_branches(tree):
# counter_branches =
from __future__ import annotations

import numpy as np
Expand Down Expand Up @@ -53,6 +51,10 @@ def filter_branches(tree, keep_branches, drop_branches, count_branches):
"""
Creates lambda function for filtering branches based on keep_branches or drop_branches.
"""
if drop_branches and keep_branches:
msg = "Can specify either drop_branches or keep_branches, not both."
raise ValueError(msg) from None

if drop_branches:
if isinstance(drop_branches, dict): # noqa: SIM102
if (
Expand Down
4 changes: 4 additions & 0 deletions src/hepconvert/copy_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ def copy_root(
trees = f.keys(filter_classname="TTree", cycle=False, recursive=False)

# Check that drop_trees keys are valid/refer to a tree:
if drop_trees and keep_trees:
msg = "Can specify either drop_trees or keep_trees, not both."
raise ValueError(msg) from None
if keep_trees:
if isinstance(keep_trees, list):
for key in keep_trees:
Expand All @@ -162,6 +165,7 @@ def copy_root(
drop_trees = [tree for tree in trees if tree not in keep_trees]
else:
drop_trees = [tree for tree in trees if tree != keep_trees[0]]

if drop_trees:
if isinstance(drop_trees, list):
for key in drop_trees:
Expand Down
4 changes: 4 additions & 0 deletions src/hepconvert/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def merge_root(
trees = f.keys(filter_classname="TTree", cycle=False, recursive=False)

# Check that drop_trees keys are valid/refer to a tree:
if drop_trees and keep_trees:
msg = "Can specify either drop_trees or keep_trees, not both."
raise ValueError(msg) from None

if keep_trees:
if isinstance(keep_trees, list):
for key in keep_trees:
Expand Down
4 changes: 4 additions & 0 deletions src/hepconvert/root_to_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,10 @@ def _filter_branches(tree, keep_branches, drop_branches):
"""
Creates lambda function for filtering branches based on keep_branches or drop_branches.
"""
if drop_branches and keep_branches:
msg = "Can specify either drop_branches or keep_branches, not both."
raise ValueError(msg) from None

if drop_branches:
if isinstance(drop_branches, str):
drop_branches = tree.keys(filter_name=drop_branches)
Expand Down

0 comments on commit c3be4c1

Please sign in to comment.