Skip to content

Commit

Permalink
Add serialization methods for decays
Browse files Browse the repository at this point in the history
  • Loading branch information
nickkamp1 committed Feb 15, 2024
1 parent b8384fd commit bdf4755
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 37 deletions.
2 changes: 1 addition & 1 deletion projects/interactions/private/DarkNewsDecay.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace interactions {

DarkNewsDecay::DarkNewsDecay() {}

pybind11::object DarkNewsDecay::get_self() {
pybind11::object DarkNewsDecay::get_representation() {
return pybind11::cast<pybind11::none>(Py_None);
}

Expand Down
74 changes: 39 additions & 35 deletions projects/interactions/private/pybindings/DarkNewsDecay.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <pybind11/stl.h>
#include <pybind11/embed.h>

#include "../../public/LeptonInjector/interactions/CrossSection.h"
#include "../../public/LeptonInjector/interactions/DarkNewsDecay.h"
#include "../../../dataclasses/public/LeptonInjector/dataclasses/Particle.h"
#include "../../../dataclasses/public/LeptonInjector/dataclasses/InteractionRecord.h"
Expand Down Expand Up @@ -68,7 +67,7 @@

namespace LI {
namespace interactions {
// Trampoline class for CrossSection
// Trampoline class for Decay
class pyDecay : public Decay {
public:
using Decay::Decay;
Expand Down Expand Up @@ -195,7 +194,7 @@ class pyDecay : public Decay {
)
}

pybind11::object get_self() {
pybind11::object get_representation() {
return self;
}
};
Expand Down Expand Up @@ -316,8 +315,37 @@ class pyDarkNewsDecay : public DarkNewsDecay {
)
}

pybind11::object get_self() override {
return self;
pybind11::object get_representation() override {
const DarkNewsDecay * ref;
if(self) {
ref = self.cast<DarkNewsDecay *>();
} else {
ref = this;
}
auto *tinfo = pybind11::detail::get_type_info(typeid(DarkNewsDecay));
pybind11::function override_func =
tinfo ? pybind11::detail::get_type_override(static_cast<const DarkNewsDecay *>(ref), tinfo, "get_representation") : pybind11::function();
if (override_func) {
pybind11::object o = override_func();
if(not pybind11::isinstance<pybind11::dict>(o)) {
throw std::runtime_error("get_representation must return a dict");
}
return o;
}

pybind11::object _self;
if(this->self) {
self = pybind11::reinterpret_borrow<pybind11::object>(this->self);
} else {
auto *tinfo = pybind11::detail::get_type_info(typeid(DarkNewsDecay));
pybind11::handle self_handle = get_object_handle(static_cast<const DarkNewsDecay *>(this), tinfo);
_self = pybind11::reinterpret_borrow<pybind11::object>(self_handle);
}
pybind11::dict d;
if (pybind11::hasattr(self, "__dict__")) {
d = _self.attr("__dict__");
}
return d;
}
};
} // end interactions namespace
Expand All @@ -344,22 +372,10 @@ void register_DarkNewsDecay(pybind11::module_ & m) {
.def("FinalStateProbability",&DarkNewsDecay::FinalStateProbability)
.def("SampleFinalState",&DarkNewsDecay::SampleFinalState)
.def("SampleRecordFromDarkNews",&DarkNewsDecay::SampleRecordFromDarkNews)
.def("get_self", &pyDarkNewsDecay::get_self)
.def("get_representation", &pyDarkNewsDecay::get_representation)
.def(pybind11::pickle(
[](const LI::interactions::pyDarkNewsDecay & cpp_obj) {
pybind11::object self;
if(cpp_obj.self) {
self = pybind11::reinterpret_borrow<pybind11::object>(cpp_obj.self);
} else {
auto *tinfo = pybind11::detail::get_type_info(typeid(DarkNewsDecay));
pybind11::handle self_handle = get_object_handle(static_cast<const DarkNewsDecay *>(&cpp_obj), tinfo);
self = pybind11::reinterpret_borrow<pybind11::object>(self_handle);
}
pybind11::dict d;
if (pybind11::hasattr(self, "__dict__")) {
d = self.attr("__dict__");
}
return pybind11::make_tuple(d);
[](LI::interactions::pyDarkNewsDecay & cpp_obj) {
return pybind11::make_tuple(cpp_obj.get_representation());
},
[](const pybind11::tuple &t) {
if (t.size() != 1) {
Expand Down Expand Up @@ -389,22 +405,10 @@ void register_DarkNewsDecay(pybind11::module_ & m) {
.def("FinalStateProbability",&DarkNewsDecay::FinalStateProbability)
.def("SampleFinalState",&DarkNewsDecay::SampleFinalState)
.def("SampleRecordFromDarkNews",&DarkNewsDecay::SampleRecordFromDarkNews)
.def("get_self", &DarkNewsDecay::get_self)
.def("get_representation", &DarkNewsDecay::get_representation)
.def(pybind11::pickle(
[](const LI::interactions::DarkNewsDecay & cpp_obj) {
pybind11::object self;
if(dynamic_cast<LI::interactions::pyDarkNewsDecay const *>(&cpp_obj) != nullptr and dynamic_cast<LI::interactions::pyDarkNewsDecay const *>(&cpp_obj)->self) {
self = pybind11::reinterpret_borrow<pybind11::object>(dynamic_cast<LI::interactions::pyDarkNewsDecay const *>(&cpp_obj)->self);
} else {
auto *tinfo = pybind11::detail::get_type_info(typeid(LI::interactions::DarkNewsDecay));
pybind11::handle self_handle = get_object_handle(static_cast<const LI::interactions::DarkNewsDecay *>(&cpp_obj), tinfo);
self = pybind11::reinterpret_borrow<pybind11::object>(self_handle);
}
pybind11::dict d;
if (pybind11::hasattr(self, "__dict__")) {
d = self.attr("__dict__");
}
return pybind11::make_tuple(d);
[](LI::interactions::DarkNewsDecay & cpp_obj) {
return pybind11::make_tuple(cpp_obj.get_representation());
},
[](const pybind11::tuple &t) {
if (t.size() != 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ friend cereal::access;

DarkNewsDecay();

virtual pybind11::object get_self();
virtual pybind11::object get_representation();

virtual bool equal(Decay const & other) const override;

Expand Down
12 changes: 12 additions & 0 deletions python/LIDarkNews.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,18 @@ def __init__(self, dec_case, table_dir=None):
if table_dir_exists:
self.SetIntegratorAndNorm()

# serialization method
def get_representation(self):
return {"decay_integrator":self.decay_integrator,
"decay_norm":self.decay_norm,
"dec_case":self.dec_case,
"PS_samples":self.PS_samples,
"PS_weights":self.PS_weights,
"PS_weights_CDF":self.PS_weights_CDF,
"total_width":self.total_width,
"table_dir":self.table_dir
}

def SetIntegratorAndNorm(self):
# Try to find the decay integrator
int_file = os.path.join(self.table_dir, "decay_integrator.pkl")
Expand Down

0 comments on commit bdf4755

Please sign in to comment.