diff --git a/src/cpp/CMakeLists.txt b/src/cpp/CMakeLists.txt index 640b6f3e7..ccefbca8d 100644 --- a/src/cpp/CMakeLists.txt +++ b/src/cpp/CMakeLists.txt @@ -120,6 +120,8 @@ SET( LCIO_UTIL_SRCS ./src/UTIL/PIDHandler.cc ./src/UTIL/ILDConf.cc ./src/UTIL/ProcessFlag.cc + ./src/UTIL/LCCollectionTools.cc + ./src/UTIL/ReconstructedParticleTools.cc ) SET( LCIO_MT_SRCS diff --git a/src/cpp/include/UTIL/LCCollectionTools.h b/src/cpp/include/UTIL/LCCollectionTools.h new file mode 100644 index 000000000..ab54a7984 --- /dev/null +++ b/src/cpp/include/UTIL/LCCollectionTools.h @@ -0,0 +1,15 @@ +#ifndef UTIL_LCCollectionTools_H +#define UTIL_LCCollectionTools_H 1 + +#include "EVENT/LCObject.h" +#include "EVENT/LCCollection.h" + +namespace UTIL{ + /** Extract object index inside a given LCCollection. If object is not found, return -1. + * @author Bohdan Dudar + * @version August 2022 + */ + int getElementIndex(const EVENT::LCObject* item, EVENT::LCCollection* collection); +} + +#endif diff --git a/src/cpp/include/UTIL/LCRelationNavigator.h b/src/cpp/include/UTIL/LCRelationNavigator.h index 022c2bfa2..7c2020821 100644 --- a/src/cpp/include/UTIL/LCRelationNavigator.h +++ b/src/cpp/include/UTIL/LCRelationNavigator.h @@ -38,55 +38,75 @@ namespace UTIL { LCRelationNavigator( const EVENT::LCCollection* col ) ; /// Destructor. - virtual ~LCRelationNavigator() { /* nop */; } + ~LCRelationNavigator() { /* nop */; } /**The type of the 'from' objects in this relation. */ - virtual const std::string & getFromType() const ; + const std::string & getFromType() const ; /**The type of the 'to' objects in this relation. */ - virtual const std::string & getToType() const ; + const std::string & getToType() const ; /** All objects that the given from-object is related to. * LCObjects are of type getToType(). */ - virtual const EVENT::LCObjectVec & getRelatedToObjects(EVENT::LCObject * from) const ; + const EVENT::LCObjectVec & getRelatedToObjects(EVENT::LCObject * from) const ; /** All from-objects related to the given object ( the inverse relationship). * LCObjects are of type getFromType(). */ - virtual const EVENT::LCObjectVec & getRelatedFromObjects(EVENT::LCObject * to) const ; + const EVENT::LCObjectVec & getRelatedFromObjects(EVENT::LCObject * to) const ; /** The weights of the relations returned by a call to getRelatedToObjects(from). * @see getRelatedToObjects */ - virtual const EVENT::FloatVec & getRelatedToWeights(EVENT::LCObject * from) const ; + const EVENT::FloatVec & getRelatedToWeights(EVENT::LCObject * from) const ; /** The weights of the relations returned by a call to getRelatedFromObjects(to). * @see getRelatedFromObjects */ - virtual const EVENT::FloatVec & getRelatedFromWeights(EVENT::LCObject * to) const ; + const EVENT::FloatVec & getRelatedFromWeights(EVENT::LCObject * to) const ; + + /** Object with a highest weight that the given from-object is related to. + * LCObject is of type getToType(). + */ + const EVENT::LCObject* getRelatedToMaxWeightObject(EVENT::LCObject* from, const std::string& weightType) const ; + + /** From-object related to the given object with a highest weight (the inverse relationship). + * LCObject is of type getFromType(). + */ + const EVENT::LCObject* getRelatedFromMaxWeightObject(EVENT::LCObject* to, const std::string& weightType) const ; + + /** The highest weight of the relations returned by a call to getRelatedToObjects(from). + * @see getRelatedToObjects + */ + float getRelatedToMaxWeight(EVENT::LCObject* from, const std::string& weightType) const ; + + /** The highest weight of the relations returned by a call to getRelatedFromObjects(to). + * @see getRelatedFromObjects + */ + float getRelatedFromMaxWeight(EVENT::LCObject* to, const std::string& weightType) const ; /** Adds a relation. If there is already an existing relation between the two given objects * the weight (or default weight 1.0) is added to that relationship's weight. */ - virtual void addRelation(EVENT::LCObject * from, EVENT::LCObject * to, float weight = 1.0) ; + void addRelation(EVENT::LCObject * from, EVENT::LCObject * to, float weight = 1.0) ; /** Remove a given relation. */ - virtual void removeRelation(EVENT::LCObject * from, EVENT::LCObject * to) ; + void removeRelation(EVENT::LCObject * from, EVENT::LCObject * to) ; /** Remove a given relation. To reduce the weight of the relationship, call * addRelation( from, to, weight ) with weight<0. */ - virtual EVENT::LCCollection * createLCCollection() ; + EVENT::LCCollection * createLCCollection() ; protected: LCRelationNavigator() ; - virtual void initialize( const EVENT::LCCollection* col ) ; + void initialize( const EVENT::LCCollection* col ) ; void removeRelation(EVENT::LCObject * from, EVENT::LCObject * to, RelMap& map ) ; void addRelation(EVENT::LCObject * from, EVENT::LCObject * to, float weight, RelMap& map) ; diff --git a/src/cpp/include/UTIL/ReconstructedParticleTools.h b/src/cpp/include/UTIL/ReconstructedParticleTools.h new file mode 100644 index 000000000..b703decb2 --- /dev/null +++ b/src/cpp/include/UTIL/ReconstructedParticleTools.h @@ -0,0 +1,15 @@ +#ifndef UTIL_ReconstructedParticleTools_H +#define UTIL_ReconstructedParticleTools_H 1 + +#include "EVENT/ReconstructedParticle.h" +#include "EVENT/Track.h" + +namespace UTIL{ + /** Extract the leading track (sorted by momentum) in case of multiple tracks attached to a single ReconstructedParticle. + * @author Bohdan Dudar + * @version August 2022 + */ + EVENT::Track* getLeadingTrack(const EVENT::ReconstructedParticle* particle); +} + +#endif diff --git a/src/cpp/include/UTIL/TrackTools.h b/src/cpp/include/UTIL/TrackTools.h new file mode 100644 index 000000000..952675735 --- /dev/null +++ b/src/cpp/include/UTIL/TrackTools.h @@ -0,0 +1,25 @@ +#ifndef UTIL_TrackTools_H +#define UTIL_TrackTools_H 1 + +#include "EVENT/Track.h" +#include +#include + +namespace UTIL{ + /** Extract track momentum from its track parameters and magnetic field + * @author Bohdan Dudar + * @version August 2022 + */ + template + std::array getTrackMomentum(const TrackLikeT* track, double bz){ + double omega = track->getOmega(); + if (omega == 0.) return {0., 0., 0.}; + double phi = track->getPhi(); + double tanL = track->getTanLambda(); + double c_light = 299.792458; // mm/ns + double pt = (1e-6 * c_light * bz) / std::abs(omega); + return {pt*std::cos(phi), pt*std::sin(phi), pt*tanL}; + } +} + +#endif diff --git a/src/cpp/src/TESTS/test_tracks.cc b/src/cpp/src/TESTS/test_tracks.cc index c81d71a73..60c5e1d18 100644 --- a/src/cpp/src/TESTS/test_tracks.cc +++ b/src/cpp/src/TESTS/test_tracks.cc @@ -15,6 +15,7 @@ #include "UTIL/Operators.h" #include "UTIL/LCIterator.h" +#include "UTIL/TrackTools.h" #include #include @@ -218,6 +219,19 @@ int main(int /*argc*/, char** /*argv*/ ){ ss << " ref[" << k << "] " ; MYTEST( ref[k] , float(k+1) , ss.str() ) ; } + + //Test of the getTrackMomentum + std::array trkRecoMomentum = UTIL::getTrackMomentum(trk, 3.5); + std::array trkTrueMomentum = {0., 0., 0.}; + if (trk->getOmega() != 0.){ + double trkPx = (1e-6 * 299.792458 * 3.5) / std::abs((i+j) * .1)*std::cos((i+j) * .3); + double trkPy = (1e-6 * 299.792458 * 3.5) / std::abs((i+j) * .1)*std::sin((i+j) * .3); + double trkPz = (1e-6 * 299.792458 * 3.5) / std::abs((i+j) * .1)*(i+j) * .2; + trkTrueMomentum = {trkPx, trkPy, trkPz}; + } + std::stringstream failMsg; + failMsg << " getTrackMomentum : [" << trkRecoMomentum << "] != [" << trkTrueMomentum << "]"; + MYTEST(approxEqArray(trkRecoMomentum, trkTrueMomentum), failMsg.str()); ++j ; } @@ -438,9 +452,8 @@ int main(int /*argc*/, char** /*argv*/ ){ } catch( Exception &e ){ MYTEST.FAILED( e.what() ); } - return 0; + } //============================================================================= - diff --git a/src/cpp/src/TESTS/test_trackstate.cc b/src/cpp/src/TESTS/test_trackstate.cc index 47a41fc72..314294755 100644 --- a/src/cpp/src/TESTS/test_trackstate.cc +++ b/src/cpp/src/TESTS/test_trackstate.cc @@ -7,6 +7,7 @@ #include "EVENT/TrackState.h" #include "IMPL/TrackStateImpl.h" +#include "UTIL/TrackTools.h" //#include "UTIL/Operators.h" @@ -41,7 +42,8 @@ int main(int /*argc*/, char** /*argv*/ ){ MYTEST( a.getPhi(), float( .0 ), "getPhi" ) ; MYTEST( a.getOmega(), float( .0 ), "getOmega" ) ; - + // Omega == 0, will yield zero momentum as there is no reasonable value in that case + MYTEST( getTrackMomentum(&a, 3.5), std::array{0, 0, 0}, "getTrackMomentum" ); MYTEST.LOG( "test constructor with arguments" ); @@ -65,6 +67,17 @@ int main(int /*argc*/, char** /*argv*/ ){ MYTEST( b.getPhi(), float( .2 ), "getPhi" ) ; MYTEST( b.getOmega(), float( .3 ), "getOmega" ) ; + const std::array trueMomentum = { + 1e-6 * 299.7922458 * 3.5 / .3 * std::cos(.2), // pt * cos(phi), with pt = c * B / omega + 1e-6 * 299.7922458 * 3.5 / .3 * std::sin(.2), // pt * sin(phi) + 1e-6 * 299.7922458 * 3.5 / .3 * .5 // pt * tanL + }; + + std::stringstream failMsg; + const auto tsMomentum = getTrackMomentum(&b, 3.5); + failMsg << " getTrackMomentum : [" << tsMomentum << "] != << [" << trueMomentum << "]"; + MYTEST(approxEqArray(tsMomentum, trueMomentum), failMsg.str()); + MYTEST.LOG( "test default copy constructor" ); diff --git a/src/cpp/src/TESTS/tutil.h b/src/cpp/src/TESTS/tutil.h index 513792be7..bdc2bb093 100644 --- a/src/cpp/src/TESTS/tutil.h +++ b/src/cpp/src/TESTS/tutil.h @@ -3,6 +3,22 @@ #include #include #include +#include +#include +#include + +template +std::ostream& operator<<(std::ostream& os, const std::array& arr) { + if constexpr (N == 0) { + return os << "[]"; + } + + os << "[" << arr[0]; + for (size_t i = 1; i < N; ++i) { + os << ", " << arr[i]; + } + return os << "]"; +} class TEST{ @@ -34,10 +50,10 @@ class TEST{ return ; } -// void operator()(bool cond, const std::string msg) { -// if ( ! cond ) FAILED( msg ) ; -// return ; -// } + void operator()(bool cond, const std::string msg) { + if ( ! cond ) FAILED( msg ) ; + return ; + } void FAILED( const std::string& msg ){ @@ -61,3 +77,21 @@ class TEST{ std::string _testname; std::ostream& _out; }; + +bool approxEqual(double lhs, double rhs) { + // Following a similar, but slightly simplified approach as Catch2::Approx here + constexpr double epsilon = std::numeric_limits::epsilon() * 100; + const double margin = std::fabs(lhs) * epsilon; + return (lhs + margin >= rhs) && (rhs + margin >= lhs); +} + +template +bool approxEqArray(const std::array& arr1, const std::array& arr2, ApproxComp&& comp=approxEqual) { + for (size_t i = 0; i < N; ++i) { + if (!comp(arr1[i], arr2[i])) { + return false; + } + } + return true; +} + diff --git a/src/cpp/src/UTIL/LCCollectionTools.cc b/src/cpp/src/UTIL/LCCollectionTools.cc new file mode 100644 index 000000000..9294e7859 --- /dev/null +++ b/src/cpp/src/UTIL/LCCollectionTools.cc @@ -0,0 +1,10 @@ +#include "UTIL/LCCollectionTools.h" + +namespace UTIL{ + int getElementIndex(const EVENT::LCObject* item, EVENT::LCCollection* collection){ + for(int i=0; i < collection->getNumberOfElements(); ++i){ + if ( item == collection->getElementAt(i) ) return i; + } + return -1; + } +} diff --git a/src/cpp/src/UTIL/LCRelationNavigator.cc b/src/cpp/src/UTIL/LCRelationNavigator.cc index 2af07b5a6..67e57817f 100644 --- a/src/cpp/src/UTIL/LCRelationNavigator.cc +++ b/src/cpp/src/UTIL/LCRelationNavigator.cc @@ -68,6 +68,56 @@ namespace UTIL{ return _rMap[ to ].second ; } + auto getMaxWeightIt(const EVENT::FloatVec& weights, const std::string& weightType) { + if (weightType == "track") { + std::max_element(weights.begin(), weights.end(), [](float a, float b) { + return (int(a) % 10000) / 1000. < (int(b) % 10000) / 1000.; + }); + } else if (weightType == "cluster") { + std::max_element(weights.begin(), weights.end(), [](float a, float b) { + return (int(a) / 10000) / 1000. < (int(b) / 10000) /1000.; + }); + } + return std::max_element(weights.begin(), weights.end()); + } + + const EVENT::LCObject* LCRelationNavigator::getRelatedToMaxWeightObject(EVENT::LCObject* from, const std::string& weightType) const { + const auto& objects = getRelatedToObjects(from); + if ( objects.empty() ) return nullptr; + + const auto& weights = getRelatedToWeights(from); + const auto maxWeightIt = getMaxWeightIt(weights, weightType); + int i = std::distance(weights.begin(), maxWeightIt); + return objects[i]; + } + + const EVENT::LCObject* LCRelationNavigator::getRelatedFromMaxWeightObject(EVENT::LCObject* to, const std::string& weightType) const { + const auto& objects = getRelatedToObjects(to); + if ( objects.empty() ) return nullptr; + + const auto& weights = getRelatedToWeights(to); + const auto maxWeightIt = getMaxWeightIt(weights, weightType); + + int i = std::distance(weights.begin(), maxWeightIt); + return objects[i]; + } + + float LCRelationNavigator::getRelatedToMaxWeight(EVENT::LCObject* from, const std::string& weightType) const { + const auto& objects = getRelatedToObjects(from); + if ( objects.empty() ) return 0.; + + const auto& weights = getRelatedToWeights(from); + return *getMaxWeightIt(weights, weightType); + } + + float LCRelationNavigator::getRelatedFromMaxWeight(EVENT::LCObject* to, const std::string& weightType) const { + const auto& objects = getRelatedToObjects(to); + if ( objects.empty() ) return 0.; + + const auto& weights = getRelatedToWeights(to); + return *getMaxWeightIt(weights, weightType); + } + void LCRelationNavigator::addRelation(EVENT::LCObject * from, EVENT::LCObject * to, float weight) { diff --git a/src/cpp/src/UTIL/ReconstructedParticleTools.cc b/src/cpp/src/UTIL/ReconstructedParticleTools.cc new file mode 100644 index 000000000..019bfc688 --- /dev/null +++ b/src/cpp/src/UTIL/ReconstructedParticleTools.cc @@ -0,0 +1,16 @@ +#include "UTIL/ReconstructedParticleTools.h" +#include +#include "UTIL/TrackTools.h" +#include + +namespace UTIL{ + EVENT::Track* getLeadingTrack(const EVENT::ReconstructedParticle* particle){ + const EVENT::TrackVec& tracks = particle->getTracks(); + if ( tracks.empty() ) return nullptr; + // compare momentum w/o using its common scale factor (which includes magnetic field) + auto sortByMomentum = [](const EVENT::Track* a, const EVENT::Track* b) { return std::hypot(1., a->getTanLambda())/std::abs(a->getOmega()) < std::hypot(1., b->getTanLambda())/std::abs(b->getOmega()); }; + auto* leadingTrack = *(std::max_element(tracks.begin(), tracks.end(), sortByMomentum)); + + return leadingTrack; + } +}