Skip to content

Commit

Permalink
split names with with _
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewfeickert committed Oct 10, 2023
1 parent 8e1c3da commit c1a5f08
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions src/pyhf/experimental/modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def func(d: Sequence[float]) -> Any:


def make_builder(
funcname: str, deps: list[str], newparams: dict[str, dict[str, Sequence[float]]]
func_name: str, deps: list[str], new_params: dict[str, dict[str, Sequence[float]]]
) -> BaseBuilder:
class _builder(BaseBuilder):
is_shared = False
Expand All @@ -83,13 +83,13 @@ def append(self, key, channel, sample, thismod, defined_samp):
moddata = self.collect(thismod, nom)
self.builder_data[key][sample]["data"]["mask"] += moddata["mask"]
if thismod:
if thismod["name"] != funcname:
if thismod["name"] != func_name:
print(thismod)
self.builder_data["funcs"].setdefault(
thismod["name"], thismod["data"]["expr"]
)
self.required_parsets = {
k: [_allocate_new_param(v)] for k, v in newparams.items()
k: [_allocate_new_param(v)] for k, v in new_params.items()
}

def finalize(self):
Expand All @@ -99,10 +99,10 @@ def finalize(self):


def make_applier(
funcname: str, deps: list[str], newparams: dict[str, dict[str, Sequence[float]]]
func_name: str, deps: list[str], new_params: dict[str, dict[str, Sequence[float]]]
) -> BaseApplier:
class _applier(BaseApplier):
name = funcname
name = func_name
op_code = "multiplication"

def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None):
Expand All @@ -120,7 +120,7 @@ def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None):
self.param_viewer = ParamViewer(
parfield_shape, pdfconfig.par_map, pars_for_applier
)
self._custommod_mask = [
self._custom_mod_mask = [
[[builder_data[modname][s]["data"]["mask"]] for s in pdfconfig.samples]
for modname in _modnames
]
Expand All @@ -131,14 +131,14 @@ def _precompute(self):
tensorlib, _ = get_backend()
if not self.param_viewer.index_selection:
return

Check warning on line 133 in src/pyhf/experimental/modifiers.py

View check run for this annotation

Codecov / codecov/patch

src/pyhf/experimental/modifiers.py#L133

Added line #L133 was not covered by tests
self.custommod_mask = tensorlib.tile(
tensorlib.astensor(self._custommod_mask),
self.custom_mod_mask = tensorlib.tile(
tensorlib.astensor(self._custom_mod_mask),
(1, 1, self.batch_size or 1, 1),
)
self.custommod_mask_bool = tensorlib.astensor(
self.custommod_mask, dtype="bool"
self.custom_mod_mask_bool = tensorlib.astensor(
self.custom_mod_mask, dtype="bool"
)
self.custommod_default = tensorlib.ones(self.custommod_mask.shape)
self.custom_mod_default = tensorlib.ones(self.custom_mod_mask.shape)

def apply(self, pars):
"""
Expand All @@ -152,27 +152,29 @@ def apply(self, pars):
deps = self.param_viewer.get(pars)
print("deps", deps.shape)

Check warning on line 153 in src/pyhf/experimental/modifiers.py

View check run for this annotation

Codecov / codecov/patch

src/pyhf/experimental/modifiers.py#L152-L153

Added lines #L152 - L153 were not covered by tests
results = tensorlib.astensor([f(deps) for f in self.funcs])
results = tensorlib.einsum("msab,m->msab", self.custommod_mask, results)
results = tensorlib.einsum(

Check warning on line 155 in src/pyhf/experimental/modifiers.py

View check run for this annotation

Codecov / codecov/patch

src/pyhf/experimental/modifiers.py#L155

Added line #L155 was not covered by tests
"msab,m->msab", self.custom_mod_mask, results
)
else:
deps = self.param_viewer.get(pars)
print("deps", deps.shape)
results = tensorlib.astensor([f(deps) for f in self.funcs])
results = tensorlib.einsum(
"msab,ma->msab", self.custommod_mask, results
"msab,ma->msab", self.custom_mod_mask, results
)
results = tensorlib.where(
self.custommod_mask_bool, results, self.custommod_default
self.custom_mod_mask_bool, results, self.custom_mod_default
)
return results

return _applier


def add_custom_modifier(
funcname: str, deps: list[str], newparams: dict[str, dict[str, Sequence[float]]]
func_name: str, deps: list[str], new_params: dict[str, dict[str, Sequence[float]]]
) -> dict[str, tuple[BaseBuilder, BaseApplier]]:
_builder = make_builder(funcname, deps, newparams)
_applier = make_applier(funcname, deps, newparams)
_builder = make_builder(func_name, deps, new_params)
_applier = make_applier(func_name, deps, new_params)

modifier_set = {_applier.name: (_builder, _applier)}
modifier_set.update(**pyhf.modifiers.histfactory_set)
Expand Down

0 comments on commit c1a5f08

Please sign in to comment.