Skip to content

Commit

Permalink
Remove requirement for target_types
Browse files Browse the repository at this point in the history
  • Loading branch information
austinschneider committed Sep 12, 2024
1 parent 04a1e70 commit 31f769d
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ double ColumnDepthPositionDistribution::GenerationProbability(std::shared_ptr<si
return prob_density;
}

ColumnDepthPositionDistribution::ColumnDepthPositionDistribution(double radius, double endcap_length, std::shared_ptr<DepthFunction> depth_function, std::set<siren::dataclasses::ParticleType> target_types) : radius(radius), endcap_length(endcap_length), depth_function(depth_function), target_types(target_types) {}
ColumnDepthPositionDistribution::ColumnDepthPositionDistribution(double radius, double endcap_length, std::shared_ptr<DepthFunction> depth_function) : radius(radius), endcap_length(endcap_length), depth_function(depth_function) {}

std::string ColumnDepthPositionDistribution::Name() const {
return "ColumnDepthPositionDistribution";
Expand Down Expand Up @@ -215,7 +215,7 @@ bool ColumnDepthPositionDistribution::equal(WeightableDistribution const & other
(depth_function and x->depth_function and *depth_function == *x->depth_function)
or (!depth_function and !x->depth_function)
)
and target_types == x->target_types);
);
}

bool ColumnDepthPositionDistribution::less(WeightableDistribution const & other) const {
Expand All @@ -226,9 +226,9 @@ bool ColumnDepthPositionDistribution::less(WeightableDistribution const & other)
and *depth_function < *x->depth_function); // Less than
bool f = false;
return
std::tie(radius, endcap_length, f, target_types)
std::tie(radius, endcap_length, f)
<
std::tie(radius, x->endcap_length, depth_less, x->target_types);
std::tie(radius, x->endcap_length, depth_less);
}

} // namespace distributions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ double PointSourcePositionDistribution::GenerationProbability(std::shared_ptr<si

PointSourcePositionDistribution::PointSourcePositionDistribution() {}

PointSourcePositionDistribution::PointSourcePositionDistribution(siren::math::Vector3D origin, double max_distance, std::set<siren::dataclasses::ParticleType> target_types) : origin(origin), max_distance(max_distance), target_types(target_types) {}
PointSourcePositionDistribution::PointSourcePositionDistribution(siren::math::Vector3D origin, double max_distance) : origin(origin), max_distance(max_distance) {}

std::string PointSourcePositionDistribution::Name() const {
return "PointSourcePositionDistribution";
Expand Down Expand Up @@ -177,16 +177,15 @@ bool PointSourcePositionDistribution::equal(WeightableDistribution const & other
return false;
else
return (origin == x->origin
and max_distance == x->max_distance
and target_types == x->target_types);
and max_distance == x->max_distance);
}

bool PointSourcePositionDistribution::less(WeightableDistribution const & other) const {
const PointSourcePositionDistribution* x = dynamic_cast<const PointSourcePositionDistribution*>(&other);
return
std::tie(origin, max_distance, target_types)
std::tie(origin, max_distance)
<
std::tie(origin, x->max_distance, x->target_types);
std::tie(origin, x->max_distance);
}

} // namespace distributions
Expand Down
4 changes: 2 additions & 2 deletions projects/distributions/private/pybindings/distributions.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ PYBIND11_MODULE(distributions,m) {
.def("Name",&CylinderVolumePositionDistribution::Name);

class_<ColumnDepthPositionDistribution, std::shared_ptr<ColumnDepthPositionDistribution>, VertexPositionDistribution>(m, "ColumnDepthPositionDistribution")
.def(init<double, double, std::shared_ptr<DepthFunction>, std::set<siren::dataclasses::ParticleType>>())
.def(init<double, double, std::shared_ptr<DepthFunction>>())
.def("GenerationProbability",&ColumnDepthPositionDistribution::GenerationProbability)
.def("InjectionBounds",&ColumnDepthPositionDistribution::InjectionBounds)
.def("Name",&ColumnDepthPositionDistribution::Name)
Expand All @@ -202,7 +202,7 @@ PYBIND11_MODULE(distributions,m) {

class_<PointSourcePositionDistribution, std::shared_ptr<PointSourcePositionDistribution>, VertexPositionDistribution>(m, "PointSourcePositionDistribution")
.def(init<>())
.def(init<siren::math::Vector3D, double, std::set<siren::dataclasses::ParticleType>>())
.def(init<siren::math::Vector3D, double>())
.def("GenerationProbability",&PointSourcePositionDistribution::GenerationProbability)
.def("InjectionBounds",&PointSourcePositionDistribution::InjectionBounds)
.def("Name",&PointSourcePositionDistribution::Name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,14 @@ friend cereal::access;
double radius;
double endcap_length;
std::shared_ptr<DepthFunction> depth_function;
std::set<siren::dataclasses::ParticleType> target_types;

siren::math::Vector3D SampleFromDisk(std::shared_ptr<siren::utilities::SIREN_random> rand, siren::math::Vector3D const & dir) const;

std::tuple<siren::math::Vector3D, siren::math::Vector3D> SamplePosition(std::shared_ptr<siren::utilities::SIREN_random> rand, std::shared_ptr<siren::detector::DetectorModel const> detector_model, std::shared_ptr<siren::interactions::InteractionCollection const> interactions, siren::dataclasses::PrimaryDistributionRecord & record) const override;
public:
std::tuple<siren::math::Vector3D, siren::math::Vector3D> GetSamplePosition(std::shared_ptr<siren::utilities::SIREN_random> rand, std::shared_ptr<siren::detector::DetectorModel const> detector_model, std::shared_ptr<siren::interactions::InteractionCollection const> interactions, siren::dataclasses::PrimaryDistributionRecord & record);
virtual double GenerationProbability(std::shared_ptr<siren::detector::DetectorModel const> detector_model, std::shared_ptr<siren::interactions::InteractionCollection const> interactions, siren::dataclasses::InteractionRecord const & record) const override;
ColumnDepthPositionDistribution(double radius, double endcap_length, std::shared_ptr<DepthFunction> depth_function, std::set<siren::dataclasses::ParticleType> target_types);
ColumnDepthPositionDistribution(double radius, double endcap_length, std::shared_ptr<DepthFunction> depth_function);
std::string Name() const override;
virtual std::shared_ptr<PrimaryInjectionDistribution> clone() const override;
virtual std::tuple<siren::math::Vector3D, siren::math::Vector3D> InjectionBounds(std::shared_ptr<siren::detector::DetectorModel const> detector_model, std::shared_ptr<siren::interactions::InteractionCollection const> interactions, siren::dataclasses::InteractionRecord const & interaction) const override;
Expand All @@ -56,7 +55,6 @@ friend cereal::access;
archive(::cereal::make_nvp("Radius", radius));
archive(::cereal::make_nvp("EndcapLength", endcap_length));
archive(::cereal::make_nvp("DepthFunction", depth_function));
archive(::cereal::make_nvp("TargetTypes", target_types));
archive(cereal::virtual_base_class<VertexPositionDistribution>(this));
} else {
throw std::runtime_error("ColumnDepthPositionDistribution only supports version <= 0!");
Expand All @@ -72,8 +70,7 @@ friend cereal::access;
archive(::cereal::make_nvp("Radius", r));
archive(::cereal::make_nvp("EndcapLength", l));
archive(::cereal::make_nvp("DepthFunction", f));
archive(::cereal::make_nvp("TargetTypes", t));
construct(r, l, f, t);
construct(r, l, f);
archive(cereal::virtual_base_class<VertexPositionDistribution>(construct.ptr()));
} else {
throw std::runtime_error("ColumnDepthPositionDistribution only supports version <= 0!");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ friend cereal::access;
private:
siren::math::Vector3D origin;
double max_distance;
std::set<siren::dataclasses::ParticleType> target_types;

siren::math::Vector3D SampleFromDisk(std::shared_ptr<siren::utilities::SIREN_random> rand, siren::math::Vector3D const & dir) const;

Expand All @@ -43,7 +42,7 @@ friend cereal::access;
virtual double GenerationProbability(std::shared_ptr<siren::detector::DetectorModel const> detector_model, std::shared_ptr<siren::interactions::InteractionCollection const> interactions, siren::dataclasses::InteractionRecord const & record) const override;
PointSourcePositionDistribution();
PointSourcePositionDistribution(const PointSourcePositionDistribution &) = default;
PointSourcePositionDistribution(siren::math::Vector3D origin, double max_distance, std::set<siren::dataclasses::ParticleType> target_types);
PointSourcePositionDistribution(siren::math::Vector3D origin, double max_distance);
std::string Name() const override;
virtual std::tuple<siren::math::Vector3D, siren::math::Vector3D> InjectionBounds(std::shared_ptr<siren::detector::DetectorModel const> detector_model, std::shared_ptr<siren::interactions::InteractionCollection const> interactions, siren::dataclasses::InteractionRecord const & interaction) const override;
virtual std::shared_ptr<PrimaryInjectionDistribution> clone() const override;
Expand All @@ -52,7 +51,6 @@ friend cereal::access;
if(version == 0) {
archive(::cereal::make_nvp("Origin", origin));
archive(::cereal::make_nvp("MaxDistance", max_distance));
archive(::cereal::make_nvp("TargetTypes", target_types));
archive(cereal::virtual_base_class<VertexPositionDistribution>(this));
} else {
throw std::runtime_error("PointSourcePositionDistribution only supports version <= 0!");
Expand All @@ -66,8 +64,7 @@ friend cereal::access;
std::set<siren::dataclasses::ParticleType> t;
archive(::cereal::make_nvp("Origin", r));
archive(::cereal::make_nvp("MaxDistance", l));
archive(::cereal::make_nvp("TargetTypes", t));
construct(r, l, t);
construct(r, l);
archive(cereal::virtual_base_class<VertexPositionDistribution>(construct.ptr()));
} else {
throw std::runtime_error("PointSourcePositionDistribution only supports version <= 0!");
Expand Down

0 comments on commit 31f769d

Please sign in to comment.