From bdf47557fe284f0d661b7b9a75226602e5453b74 Mon Sep 17 00:00:00 2001 From: Nicholas Kamp Date: Thu, 15 Feb 2024 14:22:20 -0500 Subject: [PATCH] Add serialization methods for decays --- .../interactions/private/DarkNewsDecay.cxx | 2 +- .../private/pybindings/DarkNewsDecay.h | 74 ++++++++++--------- .../interactions/DarkNewsDecay.h | 2 +- python/LIDarkNews.py | 12 +++ 4 files changed, 53 insertions(+), 37 deletions(-) diff --git a/projects/interactions/private/DarkNewsDecay.cxx b/projects/interactions/private/DarkNewsDecay.cxx index fdf99bad..1a15da46 100644 --- a/projects/interactions/private/DarkNewsDecay.cxx +++ b/projects/interactions/private/DarkNewsDecay.cxx @@ -22,7 +22,7 @@ namespace interactions { DarkNewsDecay::DarkNewsDecay() {} -pybind11::object DarkNewsDecay::get_self() { +pybind11::object DarkNewsDecay::get_representation() { return pybind11::cast(Py_None); } diff --git a/projects/interactions/private/pybindings/DarkNewsDecay.h b/projects/interactions/private/pybindings/DarkNewsDecay.h index 87981a83..f93428a3 100644 --- a/projects/interactions/private/pybindings/DarkNewsDecay.h +++ b/projects/interactions/private/pybindings/DarkNewsDecay.h @@ -7,7 +7,6 @@ #include #include -#include "../../public/LeptonInjector/interactions/CrossSection.h" #include "../../public/LeptonInjector/interactions/DarkNewsDecay.h" #include "../../../dataclasses/public/LeptonInjector/dataclasses/Particle.h" #include "../../../dataclasses/public/LeptonInjector/dataclasses/InteractionRecord.h" @@ -68,7 +67,7 @@ namespace LI { namespace interactions { -// Trampoline class for CrossSection +// Trampoline class for Decay class pyDecay : public Decay { public: using Decay::Decay; @@ -195,7 +194,7 @@ class pyDecay : public Decay { ) } - pybind11::object get_self() { + pybind11::object get_representation() { return self; } }; @@ -316,8 +315,37 @@ class pyDarkNewsDecay : public DarkNewsDecay { ) } - pybind11::object get_self() override { - return self; + pybind11::object get_representation() override { + const DarkNewsDecay * ref; + if(self) { + ref = self.cast(); + } else { + ref = this; + } + auto *tinfo = pybind11::detail::get_type_info(typeid(DarkNewsDecay)); + pybind11::function override_func = + tinfo ? pybind11::detail::get_type_override(static_cast(ref), tinfo, "get_representation") : pybind11::function(); + if (override_func) { + pybind11::object o = override_func(); + if(not pybind11::isinstance(o)) { + throw std::runtime_error("get_representation must return a dict"); + } + return o; + } + + pybind11::object _self; + if(this->self) { + self = pybind11::reinterpret_borrow(this->self); + } else { + auto *tinfo = pybind11::detail::get_type_info(typeid(DarkNewsDecay)); + pybind11::handle self_handle = get_object_handle(static_cast(this), tinfo); + _self = pybind11::reinterpret_borrow(self_handle); + } + pybind11::dict d; + if (pybind11::hasattr(self, "__dict__")) { + d = _self.attr("__dict__"); + } + return d; } }; } // end interactions namespace @@ -344,22 +372,10 @@ void register_DarkNewsDecay(pybind11::module_ & m) { .def("FinalStateProbability",&DarkNewsDecay::FinalStateProbability) .def("SampleFinalState",&DarkNewsDecay::SampleFinalState) .def("SampleRecordFromDarkNews",&DarkNewsDecay::SampleRecordFromDarkNews) - .def("get_self", &pyDarkNewsDecay::get_self) + .def("get_representation", &pyDarkNewsDecay::get_representation) .def(pybind11::pickle( - [](const LI::interactions::pyDarkNewsDecay & cpp_obj) { - pybind11::object self; - if(cpp_obj.self) { - self = pybind11::reinterpret_borrow(cpp_obj.self); - } else { - auto *tinfo = pybind11::detail::get_type_info(typeid(DarkNewsDecay)); - pybind11::handle self_handle = get_object_handle(static_cast(&cpp_obj), tinfo); - self = pybind11::reinterpret_borrow(self_handle); - } - pybind11::dict d; - if (pybind11::hasattr(self, "__dict__")) { - d = self.attr("__dict__"); - } - return pybind11::make_tuple(d); + [](LI::interactions::pyDarkNewsDecay & cpp_obj) { + return pybind11::make_tuple(cpp_obj.get_representation()); }, [](const pybind11::tuple &t) { if (t.size() != 1) { @@ -389,22 +405,10 @@ void register_DarkNewsDecay(pybind11::module_ & m) { .def("FinalStateProbability",&DarkNewsDecay::FinalStateProbability) .def("SampleFinalState",&DarkNewsDecay::SampleFinalState) .def("SampleRecordFromDarkNews",&DarkNewsDecay::SampleRecordFromDarkNews) - .def("get_self", &DarkNewsDecay::get_self) + .def("get_representation", &DarkNewsDecay::get_representation) .def(pybind11::pickle( - [](const LI::interactions::DarkNewsDecay & cpp_obj) { - pybind11::object self; - if(dynamic_cast(&cpp_obj) != nullptr and dynamic_cast(&cpp_obj)->self) { - self = pybind11::reinterpret_borrow(dynamic_cast(&cpp_obj)->self); - } else { - auto *tinfo = pybind11::detail::get_type_info(typeid(LI::interactions::DarkNewsDecay)); - pybind11::handle self_handle = get_object_handle(static_cast(&cpp_obj), tinfo); - self = pybind11::reinterpret_borrow(self_handle); - } - pybind11::dict d; - if (pybind11::hasattr(self, "__dict__")) { - d = self.attr("__dict__"); - } - return pybind11::make_tuple(d); + [](LI::interactions::DarkNewsDecay & cpp_obj) { + return pybind11::make_tuple(cpp_obj.get_representation()); }, [](const pybind11::tuple &t) { if (t.size() != 1) { diff --git a/projects/interactions/public/LeptonInjector/interactions/DarkNewsDecay.h b/projects/interactions/public/LeptonInjector/interactions/DarkNewsDecay.h index 47417ff6..7bfec6c5 100644 --- a/projects/interactions/public/LeptonInjector/interactions/DarkNewsDecay.h +++ b/projects/interactions/public/LeptonInjector/interactions/DarkNewsDecay.h @@ -36,7 +36,7 @@ friend cereal::access; DarkNewsDecay(); - virtual pybind11::object get_self(); + virtual pybind11::object get_representation(); virtual bool equal(Decay const & other) const override; diff --git a/python/LIDarkNews.py b/python/LIDarkNews.py index 18d98f6f..fad7ba83 100644 --- a/python/LIDarkNews.py +++ b/python/LIDarkNews.py @@ -614,6 +614,18 @@ def __init__(self, dec_case, table_dir=None): if table_dir_exists: self.SetIntegratorAndNorm() + # serialization method + def get_representation(self): + return {"decay_integrator":self.decay_integrator, + "decay_norm":self.decay_norm, + "dec_case":self.dec_case, + "PS_samples":self.PS_samples, + "PS_weights":self.PS_weights, + "PS_weights_CDF":self.PS_weights_CDF, + "total_width":self.total_width, + "table_dir":self.table_dir + } + def SetIntegratorAndNorm(self): # Try to find the decay integrator int_file = os.path.join(self.table_dir, "decay_integrator.pkl")