Skip to content

Commit

Permalink
Implement serialization methods for DarkNewsCrossSection
Browse files Browse the repository at this point in the history
  • Loading branch information
nickkamp1 committed Feb 15, 2024
1 parent 309047d commit b8384fd
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 40 deletions.
16 changes: 2 additions & 14 deletions projects/interactions/private/pybindings/DarkNewsCrossSection.h
Original file line number Diff line number Diff line change
Expand Up @@ -498,20 +498,8 @@ void register_DarkNewsCrossSection(pybind11::module_ & m) {
.def("SampleFinalState",&DarkNewsCrossSection::SampleFinalState)
.def("get_representation", &DarkNewsCrossSection::get_representation)
.def(pybind11::pickle(
[](const LI::interactions::DarkNewsCrossSection & cpp_obj) {
pybind11::object self;
if(dynamic_cast<LI::interactions::pyDarkNewsCrossSection const *>(&cpp_obj) != nullptr and dynamic_cast<LI::interactions::pyDarkNewsCrossSection const *>(&cpp_obj)->self) {
self = pybind11::reinterpret_borrow<pybind11::object>(dynamic_cast<LI::interactions::pyDarkNewsCrossSection const *>(&cpp_obj)->self);
} else {
auto *tinfo = pybind11::detail::get_type_info(typeid(LI::interactions::DarkNewsCrossSection));
pybind11::handle self_handle = get_object_handle(static_cast<const LI::interactions::DarkNewsCrossSection *>(&cpp_obj), tinfo);
self = pybind11::reinterpret_borrow<pybind11::object>(self_handle);
}
pybind11::dict d;
if (pybind11::hasattr(self, "__dict__")) {
d = self.attr("__dict__");
}
return pybind11::make_tuple(d);
[](LI::interactions::DarkNewsCrossSection & cpp_obj) {
return pybind11::make_tuple(cpp_obj.get_representation());
},
[](const pybind11::tuple &t) {
if (t.size() != 1) {
Expand Down
52 changes: 30 additions & 22 deletions python/LIDarkNews.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,26 +147,18 @@ def __init__(
self.table_dir = table_dir
self.interpolate_differential = interpolate_differential

# Define the target particle
# make sure protons are stored as H nuclei
self.target_type = Particle.ParticleType(self.ups_case.nuclear_target.pdgid)
if self.target_type==Particle.ParticleType.PPlus:
self.target_type = Particle.ParticleType.HNucleus

# 2D table in E, sigma
self.total_cross_section_table = np.empty((0, 2), dtype=float)
self.total_cross_section_interpolator = None
# 3D table in E, z, dsigma/dQ2 where z = (Q2 - Q2min) / (Q2max - Q2min)
self.differential_cross_section_table = np.empty((0, 3), dtype=float)
self.differential_cross_section_interpolator = None

if table_dir is None:
print(
"No table_dir specified; disabling interpolation\nWARNING: this will siginficantly slow down event generation"
)
return

# Make the table directory where will we store cross section integrators
# Make the table directory where will we store cross section tables
table_dir_exists = False
if os.path.exists(self.table_dir):
# print("Directory '%s' already exists"%self.table_dir)
Expand All @@ -190,7 +182,35 @@ def __init__(
if os.path.exists(diff_xsec_file):
self.differential_cross_section_table = np.load(diff_xsec_file)

self._redefine_interpolation_objects(total=True, diff=True)
self.configure()

# serialization method
def get_representation(self):
return {"total_cross_section_table":self.total_cross_section_table,
"differential_cross_section_table":self.differential_cross_section_table,
"ups_case":self.ups_case,
"tolerance":self.tolerance,
"interp_tolerance":self.interp_tolerance,
"table_dir":self.table_dir,
"interpolate_differential":self.interpolate_differential
}

# Configure function to set up member variables
# assumes we have defined the following:
# ups_case, total_cross_section_table, differential_cross_section_table,
# tolerance, interp_tolerance, table_dir, interpolate_differential
def configure(self):

# Define the target particle
# make sure protons are stored as H nuclei
self.target_type = Particle.ParticleType(self.ups_case.nuclear_target.pdgid)
if self.target_type==Particle.ParticleType.PPlus:
self.target_type = Particle.ParticleType.HNucleus

# Initialize interpolation objects
self.total_cross_section_interpolator = None
self.differential_cross_section_interpolator = None
self._redefine_interpolation_objects(total=True, diff=True)

# Sorts and redefines scipy interpolation objects
def _redefine_interpolation_objects(self, total=False, diff=False):
Expand Down Expand Up @@ -380,18 +400,6 @@ def SaveInterpolationTables(self, total=True, diff=True):
) as f:
np.save(f, self.differential_cross_section_table)

##### START METHODS FOR SERIALIZATION #########
# def get_initialized_dict(config):
# # do the intitialization step
# pddn = PyDerivedDarkNews(config)
# return pddn.__dict__
# # return the conent of __dict__ for PyDerivedDarkNews

# @staticmethod
# def get_config(self):
# return self.config
##### END METHODS FOR SERIALIZATION #########

def GetPossiblePrimaries(self):
return [Particle.ParticleType(self.ups_case.nu_projectile.pdgid)]

Expand Down
4 changes: 0 additions & 4 deletions resources/Examples/Example1/DIS_ATLAS.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@
primary_injection_distributions["energy"] = edist
primary_physical_distributions["energy"] = edist

# we need this conversion to make sure the flux is in units of 1/m2
flux_unit_conv = LI.distributions.NormalizationConstant((100 / 1)**2)
primary_physical_distributions["flux_norm"] = flux_unit_conv

# direction distribution
# let's just inject upwards
injection_dir = LI.math.Vector3D(0, 0, 1)
Expand Down

0 comments on commit b8384fd

Please sign in to comment.