From 31f769d60e6e23fce6209de6f236dc32be7dc6d3 Mon Sep 17 00:00:00 2001 From: Austin Schneider Date: Thu, 12 Sep 2024 12:51:09 -0600 Subject: [PATCH] Remove requirement for target_types --- .../primary/vertex/ColumnDepthPositionDistribution.cxx | 8 ++++---- .../primary/vertex/PointSourcePositionDistribution.cxx | 9 ++++----- .../distributions/private/pybindings/distributions.cxx | 4 ++-- .../primary/vertex/ColumnDepthPositionDistribution.h | 7 ++----- .../primary/vertex/PointSourcePositionDistribution.h | 7 ++----- 5 files changed, 14 insertions(+), 21 deletions(-) diff --git a/projects/distributions/private/primary/vertex/ColumnDepthPositionDistribution.cxx b/projects/distributions/private/primary/vertex/ColumnDepthPositionDistribution.cxx index a4eed9b3..fe901fd8 100644 --- a/projects/distributions/private/primary/vertex/ColumnDepthPositionDistribution.cxx +++ b/projects/distributions/private/primary/vertex/ColumnDepthPositionDistribution.cxx @@ -173,7 +173,7 @@ double ColumnDepthPositionDistribution::GenerationProbability(std::shared_ptr depth_function, std::set target_types) : radius(radius), endcap_length(endcap_length), depth_function(depth_function), target_types(target_types) {} +ColumnDepthPositionDistribution::ColumnDepthPositionDistribution(double radius, double endcap_length, std::shared_ptr depth_function) : radius(radius), endcap_length(endcap_length), depth_function(depth_function) {} std::string ColumnDepthPositionDistribution::Name() const { return "ColumnDepthPositionDistribution"; @@ -215,7 +215,7 @@ bool ColumnDepthPositionDistribution::equal(WeightableDistribution const & other (depth_function and x->depth_function and *depth_function == *x->depth_function) or (!depth_function and !x->depth_function) ) - and target_types == x->target_types); + ); } bool ColumnDepthPositionDistribution::less(WeightableDistribution const & other) const { @@ -226,9 +226,9 @@ bool ColumnDepthPositionDistribution::less(WeightableDistribution const & other) and *depth_function < *x->depth_function); // Less than bool f = false; return - std::tie(radius, endcap_length, f, target_types) + std::tie(radius, endcap_length, f) < - std::tie(radius, x->endcap_length, depth_less, x->target_types); + std::tie(radius, x->endcap_length, depth_less); } } // namespace distributions diff --git a/projects/distributions/private/primary/vertex/PointSourcePositionDistribution.cxx b/projects/distributions/private/primary/vertex/PointSourcePositionDistribution.cxx index 977ae58d..db194e97 100644 --- a/projects/distributions/private/primary/vertex/PointSourcePositionDistribution.cxx +++ b/projects/distributions/private/primary/vertex/PointSourcePositionDistribution.cxx @@ -144,7 +144,7 @@ double PointSourcePositionDistribution::GenerationProbability(std::shared_ptr target_types) : origin(origin), max_distance(max_distance), target_types(target_types) {} +PointSourcePositionDistribution::PointSourcePositionDistribution(siren::math::Vector3D origin, double max_distance) : origin(origin), max_distance(max_distance) {} std::string PointSourcePositionDistribution::Name() const { return "PointSourcePositionDistribution"; @@ -177,16 +177,15 @@ bool PointSourcePositionDistribution::equal(WeightableDistribution const & other return false; else return (origin == x->origin - and max_distance == x->max_distance - and target_types == x->target_types); + and max_distance == x->max_distance); } bool PointSourcePositionDistribution::less(WeightableDistribution const & other) const { const PointSourcePositionDistribution* x = dynamic_cast(&other); return - std::tie(origin, max_distance, target_types) + std::tie(origin, max_distance) < - std::tie(origin, x->max_distance, x->target_types); + std::tie(origin, x->max_distance); } } // namespace distributions diff --git a/projects/distributions/private/pybindings/distributions.cxx b/projects/distributions/private/pybindings/distributions.cxx index e698f5a1..f70348f0 100644 --- a/projects/distributions/private/pybindings/distributions.cxx +++ b/projects/distributions/private/pybindings/distributions.cxx @@ -187,7 +187,7 @@ PYBIND11_MODULE(distributions,m) { .def("Name",&CylinderVolumePositionDistribution::Name); class_, VertexPositionDistribution>(m, "ColumnDepthPositionDistribution") - .def(init, std::set>()) + .def(init>()) .def("GenerationProbability",&ColumnDepthPositionDistribution::GenerationProbability) .def("InjectionBounds",&ColumnDepthPositionDistribution::InjectionBounds) .def("Name",&ColumnDepthPositionDistribution::Name) @@ -202,7 +202,7 @@ PYBIND11_MODULE(distributions,m) { class_, VertexPositionDistribution>(m, "PointSourcePositionDistribution") .def(init<>()) - .def(init>()) + .def(init()) .def("GenerationProbability",&PointSourcePositionDistribution::GenerationProbability) .def("InjectionBounds",&PointSourcePositionDistribution::InjectionBounds) .def("Name",&PointSourcePositionDistribution::Name); diff --git a/projects/distributions/public/SIREN/distributions/primary/vertex/ColumnDepthPositionDistribution.h b/projects/distributions/public/SIREN/distributions/primary/vertex/ColumnDepthPositionDistribution.h index d0669daf..ab067713 100644 --- a/projects/distributions/public/SIREN/distributions/primary/vertex/ColumnDepthPositionDistribution.h +++ b/projects/distributions/public/SIREN/distributions/primary/vertex/ColumnDepthPositionDistribution.h @@ -38,7 +38,6 @@ friend cereal::access; double radius; double endcap_length; std::shared_ptr depth_function; - std::set target_types; siren::math::Vector3D SampleFromDisk(std::shared_ptr rand, siren::math::Vector3D const & dir) const; @@ -46,7 +45,7 @@ friend cereal::access; public: std::tuple GetSamplePosition(std::shared_ptr rand, std::shared_ptr detector_model, std::shared_ptr interactions, siren::dataclasses::PrimaryDistributionRecord & record); virtual double GenerationProbability(std::shared_ptr detector_model, std::shared_ptr interactions, siren::dataclasses::InteractionRecord const & record) const override; - ColumnDepthPositionDistribution(double radius, double endcap_length, std::shared_ptr depth_function, std::set target_types); + ColumnDepthPositionDistribution(double radius, double endcap_length, std::shared_ptr depth_function); std::string Name() const override; virtual std::shared_ptr clone() const override; virtual std::tuple InjectionBounds(std::shared_ptr detector_model, std::shared_ptr interactions, siren::dataclasses::InteractionRecord const & interaction) const override; @@ -56,7 +55,6 @@ friend cereal::access; archive(::cereal::make_nvp("Radius", radius)); archive(::cereal::make_nvp("EndcapLength", endcap_length)); archive(::cereal::make_nvp("DepthFunction", depth_function)); - archive(::cereal::make_nvp("TargetTypes", target_types)); archive(cereal::virtual_base_class(this)); } else { throw std::runtime_error("ColumnDepthPositionDistribution only supports version <= 0!"); @@ -72,8 +70,7 @@ friend cereal::access; archive(::cereal::make_nvp("Radius", r)); archive(::cereal::make_nvp("EndcapLength", l)); archive(::cereal::make_nvp("DepthFunction", f)); - archive(::cereal::make_nvp("TargetTypes", t)); - construct(r, l, f, t); + construct(r, l, f); archive(cereal::virtual_base_class(construct.ptr())); } else { throw std::runtime_error("ColumnDepthPositionDistribution only supports version <= 0!"); diff --git a/projects/distributions/public/SIREN/distributions/primary/vertex/PointSourcePositionDistribution.h b/projects/distributions/public/SIREN/distributions/primary/vertex/PointSourcePositionDistribution.h index d14279ac..dd05add3 100644 --- a/projects/distributions/public/SIREN/distributions/primary/vertex/PointSourcePositionDistribution.h +++ b/projects/distributions/public/SIREN/distributions/primary/vertex/PointSourcePositionDistribution.h @@ -34,7 +34,6 @@ friend cereal::access; private: siren::math::Vector3D origin; double max_distance; - std::set target_types; siren::math::Vector3D SampleFromDisk(std::shared_ptr rand, siren::math::Vector3D const & dir) const; @@ -43,7 +42,7 @@ friend cereal::access; virtual double GenerationProbability(std::shared_ptr detector_model, std::shared_ptr interactions, siren::dataclasses::InteractionRecord const & record) const override; PointSourcePositionDistribution(); PointSourcePositionDistribution(const PointSourcePositionDistribution &) = default; - PointSourcePositionDistribution(siren::math::Vector3D origin, double max_distance, std::set target_types); + PointSourcePositionDistribution(siren::math::Vector3D origin, double max_distance); std::string Name() const override; virtual std::tuple InjectionBounds(std::shared_ptr detector_model, std::shared_ptr interactions, siren::dataclasses::InteractionRecord const & interaction) const override; virtual std::shared_ptr clone() const override; @@ -52,7 +51,6 @@ friend cereal::access; if(version == 0) { archive(::cereal::make_nvp("Origin", origin)); archive(::cereal::make_nvp("MaxDistance", max_distance)); - archive(::cereal::make_nvp("TargetTypes", target_types)); archive(cereal::virtual_base_class(this)); } else { throw std::runtime_error("PointSourcePositionDistribution only supports version <= 0!"); @@ -66,8 +64,7 @@ friend cereal::access; std::set t; archive(::cereal::make_nvp("Origin", r)); archive(::cereal::make_nvp("MaxDistance", l)); - archive(::cereal::make_nvp("TargetTypes", t)); - construct(r, l, t); + construct(r, l); archive(cereal::virtual_base_class(construct.ptr())); } else { throw std::runtime_error("PointSourcePositionDistribution only supports version <= 0!");