Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add various utility functions for LCIO objects #150

Merged
merged 11 commits into from
Oct 19, 2022
2 changes: 2 additions & 0 deletions src/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ SET( LCIO_UTIL_SRCS
./src/UTIL/PIDHandler.cc
./src/UTIL/ILDConf.cc
./src/UTIL/ProcessFlag.cc
./src/UTIL/LCCollectionTools.cc
./src/UTIL/ReconstructedParticleTools.cc
)

SET( LCIO_MT_SRCS
Expand Down
15 changes: 15 additions & 0 deletions src/cpp/include/UTIL/LCCollectionTools.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef UTIL_LCCollectionTools_H
#define UTIL_LCCollectionTools_H 1

#include "EVENT/LCObject.h"
#include "EVENT/LCCollection.h"

namespace UTIL{
/** Extract object index inside a given LCCollection. If object is not found, return -1.
* @author Bohdan Dudar
* @version August 2022
*/
int getElementIndex(const EVENT::LCObject* item, EVENT::LCCollection* collection);
}

#endif
42 changes: 31 additions & 11 deletions src/cpp/include/UTIL/LCRelationNavigator.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,55 +38,75 @@ namespace UTIL {
LCRelationNavigator( const EVENT::LCCollection* col ) ;

/// Destructor.
virtual ~LCRelationNavigator() { /* nop */; }
~LCRelationNavigator() { /* nop */; }

/**The type of the 'from' objects in this relation.
*/
virtual const std::string & getFromType() const ;
const std::string & getFromType() const ;

/**The type of the 'to' objects in this relation.
*/
virtual const std::string & getToType() const ;
const std::string & getToType() const ;

/** All objects that the given from-object is related to.
* LCObjects are of type getToType().
*/
virtual const EVENT::LCObjectVec & getRelatedToObjects(EVENT::LCObject * from) const ;
const EVENT::LCObjectVec & getRelatedToObjects(EVENT::LCObject * from) const ;

/** All from-objects related to the given object ( the inverse relationship).
* LCObjects are of type getFromType().
*/
virtual const EVENT::LCObjectVec & getRelatedFromObjects(EVENT::LCObject * to) const ;
const EVENT::LCObjectVec & getRelatedFromObjects(EVENT::LCObject * to) const ;

/** The weights of the relations returned by a call to getRelatedToObjects(from).
* @see getRelatedToObjects
*/
virtual const EVENT::FloatVec & getRelatedToWeights(EVENT::LCObject * from) const ;
const EVENT::FloatVec & getRelatedToWeights(EVENT::LCObject * from) const ;

/** The weights of the relations returned by a call to getRelatedFromObjects(to).
* @see getRelatedFromObjects
*/
virtual const EVENT::FloatVec & getRelatedFromWeights(EVENT::LCObject * to) const ;
const EVENT::FloatVec & getRelatedFromWeights(EVENT::LCObject * to) const ;

/** Object with a highest weight that the given from-object is related to.
* LCObject is of type getToType().
*/
const EVENT::LCObject* getRelatedToMaxWeightObject(EVENT::LCObject* from, const std::string& weightType) const ;

/** From-object related to the given object with a highest weight (the inverse relationship).
* LCObject is of type getFromType().
*/
const EVENT::LCObject* getRelatedFromMaxWeightObject(EVENT::LCObject* to, const std::string& weightType) const ;

/** The highest weight of the relations returned by a call to getRelatedToObjects(from).
* @see getRelatedToObjects
*/
float getRelatedToMaxWeight(EVENT::LCObject* from, const std::string& weightType) const ;

/** The highest weight of the relations returned by a call to getRelatedFromObjects(to).
* @see getRelatedFromObjects
*/
float getRelatedFromMaxWeight(EVENT::LCObject* to, const std::string& weightType) const ;

/** Adds a relation. If there is already an existing relation between the two given objects
* the weight (or default weight 1.0) is added to that relationship's weight.
*/
virtual void addRelation(EVENT::LCObject * from, EVENT::LCObject * to, float weight = 1.0) ;
void addRelation(EVENT::LCObject * from, EVENT::LCObject * to, float weight = 1.0) ;

/** Remove a given relation.
*/
virtual void removeRelation(EVENT::LCObject * from, EVENT::LCObject * to) ;
void removeRelation(EVENT::LCObject * from, EVENT::LCObject * to) ;

/** Remove a given relation. To reduce the weight of the relationship, call
* addRelation( from, to, weight ) with weight<0.
*/
virtual EVENT::LCCollection * createLCCollection() ;
EVENT::LCCollection * createLCCollection() ;

protected:

LCRelationNavigator() ;

virtual void initialize( const EVENT::LCCollection* col ) ;
void initialize( const EVENT::LCCollection* col ) ;

void removeRelation(EVENT::LCObject * from, EVENT::LCObject * to, RelMap& map ) ;
void addRelation(EVENT::LCObject * from, EVENT::LCObject * to, float weight, RelMap& map) ;
Expand Down
15 changes: 15 additions & 0 deletions src/cpp/include/UTIL/ReconstructedParticleTools.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef UTIL_ReconstructedParticleTools_H
#define UTIL_ReconstructedParticleTools_H 1

#include "EVENT/ReconstructedParticle.h"
#include "EVENT/Track.h"

namespace UTIL{
/** Extract the leading track (sorted by momentum) in case of multiple tracks attached to a single ReconstructedParticle.
* @author Bohdan Dudar
* @version August 2022
*/
EVENT::Track* getLeadingTrack(const EVENT::ReconstructedParticle* particle);
}

#endif
25 changes: 25 additions & 0 deletions src/cpp/include/UTIL/TrackTools.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef UTIL_TrackTools_H
#define UTIL_TrackTools_H 1

#include "EVENT/Track.h"
#include <array>
#include <cmath>

namespace UTIL{
/** Extract track momentum from its track parameters and magnetic field
* @author Bohdan Dudar
* @version August 2022
*/
template <typename TrackLikeT>
std::array<double, 3> getTrackMomentum(const TrackLikeT* track, double bz){
double omega = track->getOmega();
if (omega == 0.) return {0., 0., 0.};
double phi = track->getPhi();
double tanL = track->getTanLambda();
double c_light = 299.792458; // mm/ns
double pt = (1e-6 * c_light * bz) / std::abs(omega);
return {pt*std::cos(phi), pt*std::sin(phi), pt*tanL};
}
}

#endif
17 changes: 15 additions & 2 deletions src/cpp/src/TESTS/test_tracks.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "UTIL/Operators.h"
#include "UTIL/LCIterator.h"
#include "UTIL/TrackTools.h"

#include <sstream>
#include <assert.h>
Expand Down Expand Up @@ -218,6 +219,19 @@ int main(int /*argc*/, char** /*argv*/ ){
ss << " ref[" << k << "] " ;
MYTEST( ref[k] , float(k+1) , ss.str() ) ;
}

//Test of the getTrackMomentum
std::array<double, 3> trkRecoMomentum = UTIL::getTrackMomentum(trk, 3.5);
std::array<double, 3> trkTrueMomentum = {0., 0., 0.};
if (trk->getOmega() != 0.){
double trkPx = (1e-6 * 299.792458 * 3.5) / std::abs((i+j) * .1)*std::cos((i+j) * .3);
double trkPy = (1e-6 * 299.792458 * 3.5) / std::abs((i+j) * .1)*std::sin((i+j) * .3);
double trkPz = (1e-6 * 299.792458 * 3.5) / std::abs((i+j) * .1)*(i+j) * .2;
trkTrueMomentum = {trkPx, trkPy, trkPz};
}
std::stringstream failMsg;
failMsg << " getTrackMomentum : [" << trkRecoMomentum << "] != [" << trkTrueMomentum << "]";
MYTEST(approxEqArray(trkRecoMomentum, trkTrueMomentum), failMsg.str());

++j ;
}
Expand Down Expand Up @@ -438,9 +452,8 @@ int main(int /*argc*/, char** /*argv*/ ){
} catch( Exception &e ){
MYTEST.FAILED( e.what() );
}

return 0;

}

//=============================================================================

15 changes: 14 additions & 1 deletion src/cpp/src/TESTS/test_trackstate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "EVENT/TrackState.h"
#include "IMPL/TrackStateImpl.h"
#include "UTIL/TrackTools.h"

//#include "UTIL/Operators.h"

Expand Down Expand Up @@ -41,7 +42,8 @@ int main(int /*argc*/, char** /*argv*/ ){
MYTEST( a.getPhi(), float( .0 ), "getPhi" ) ;
MYTEST( a.getOmega(), float( .0 ), "getOmega" ) ;


// Omega == 0, will yield zero momentum as there is no reasonable value in that case
MYTEST( getTrackMomentum(&a, 3.5), std::array<double, 3>{0, 0, 0}, "getTrackMomentum" );

MYTEST.LOG( "test constructor with arguments" );

Expand All @@ -65,6 +67,17 @@ int main(int /*argc*/, char** /*argv*/ ){
MYTEST( b.getPhi(), float( .2 ), "getPhi" ) ;
MYTEST( b.getOmega(), float( .3 ), "getOmega" ) ;

const std::array<double, 3> trueMomentum = {
1e-6 * 299.7922458 * 3.5 / .3 * std::cos(.2), // pt * cos(phi), with pt = c * B / omega
1e-6 * 299.7922458 * 3.5 / .3 * std::sin(.2), // pt * sin(phi)
1e-6 * 299.7922458 * 3.5 / .3 * .5 // pt * tanL
};

std::stringstream failMsg;
const auto tsMomentum = getTrackMomentum(&b, 3.5);
failMsg << " getTrackMomentum : [" << tsMomentum << "] != << [" << trueMomentum << "]";
MYTEST(approxEqArray(tsMomentum, trueMomentum), failMsg.str());



MYTEST.LOG( "test default copy constructor" );
Expand Down
42 changes: 38 additions & 4 deletions src/cpp/src/TESTS/tutil.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,22 @@
#include <sstream>
#include <stdio.h>
#include <stdlib.h>
#include <array>
#include <limits>
#include <cmath>

template<typename T, size_t N>
std::ostream& operator<<(std::ostream& os, const std::array<T, N>& arr) {
if constexpr (N == 0) {
return os << "[]";
}

os << "[" << arr[0];
for (size_t i = 1; i < N; ++i) {
os << ", " << arr[i];
}
return os << "]";
}

class TEST{

Expand Down Expand Up @@ -34,10 +50,10 @@ class TEST{
return ;
}

// void operator()(bool cond, const std::string msg) {
// if ( ! cond ) FAILED( msg ) ;
// return ;
// }
void operator()(bool cond, const std::string msg) {
if ( ! cond ) FAILED( msg ) ;
return ;
}

void FAILED( const std::string& msg ){

Expand All @@ -61,3 +77,21 @@ class TEST{
std::string _testname;
std::ostream& _out;
};

bool approxEqual(double lhs, double rhs) {
// Following a similar, but slightly simplified approach as Catch2::Approx here
constexpr double epsilon = std::numeric_limits<float>::epsilon() * 100;
const double margin = std::fabs(lhs) * epsilon;
return (lhs + margin >= rhs) && (rhs + margin >= lhs);
}

template<typename T, size_t N, typename ApproxComp=decltype(approxEqual)>
bool approxEqArray(const std::array<T, N>& arr1, const std::array<T, N>& arr2, ApproxComp&& comp=approxEqual) {
for (size_t i = 0; i < N; ++i) {
if (!comp(arr1[i], arr2[i])) {
return false;
}
}
return true;
}

10 changes: 10 additions & 0 deletions src/cpp/src/UTIL/LCCollectionTools.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#include "UTIL/LCCollectionTools.h"

namespace UTIL{
int getElementIndex(const EVENT::LCObject* item, EVENT::LCCollection* collection){
for(int i=0; i < collection->getNumberOfElements(); ++i){
if ( item == collection->getElementAt(i) ) return i;
}
return -1;
}
}
50 changes: 50 additions & 0 deletions src/cpp/src/UTIL/LCRelationNavigator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,56 @@ namespace UTIL{
return _rMap[ to ].second ;
}

auto getMaxWeightIt(const EVENT::FloatVec& weights, const std::string& weightType) {
if (weightType == "track") {
std::max_element(weights.begin(), weights.end(), [](float a, float b) {
return (int(a) % 10000) / 1000. < (int(b) % 10000) / 1000.;
});
} else if (weightType == "cluster") {
std::max_element(weights.begin(), weights.end(), [](float a, float b) {
return (int(a) / 10000) / 1000. < (int(b) / 10000) /1000.;
});
}
return std::max_element(weights.begin(), weights.end());
}

const EVENT::LCObject* LCRelationNavigator::getRelatedToMaxWeightObject(EVENT::LCObject* from, const std::string& weightType) const {
const auto& objects = getRelatedToObjects(from);
if ( objects.empty() ) return nullptr;

const auto& weights = getRelatedToWeights(from);
const auto maxWeightIt = getMaxWeightIt(weights, weightType);
int i = std::distance(weights.begin(), maxWeightIt);
return objects[i];
}

const EVENT::LCObject* LCRelationNavigator::getRelatedFromMaxWeightObject(EVENT::LCObject* to, const std::string& weightType) const {
const auto& objects = getRelatedToObjects(to);
if ( objects.empty() ) return nullptr;

const auto& weights = getRelatedToWeights(to);
const auto maxWeightIt = getMaxWeightIt(weights, weightType);

int i = std::distance(weights.begin(), maxWeightIt);
return objects[i];
}

float LCRelationNavigator::getRelatedToMaxWeight(EVENT::LCObject* from, const std::string& weightType) const {
const auto& objects = getRelatedToObjects(from);
if ( objects.empty() ) return 0.;

const auto& weights = getRelatedToWeights(from);
return *getMaxWeightIt(weights, weightType);
}

float LCRelationNavigator::getRelatedFromMaxWeight(EVENT::LCObject* to, const std::string& weightType) const {
const auto& objects = getRelatedToObjects(to);
if ( objects.empty() ) return 0.;

const auto& weights = getRelatedToWeights(to);
return *getMaxWeightIt(weights, weightType);
}

void LCRelationNavigator::addRelation(EVENT::LCObject * from,
EVENT::LCObject * to,
float weight) {
Expand Down
16 changes: 16 additions & 0 deletions src/cpp/src/UTIL/ReconstructedParticleTools.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include "UTIL/ReconstructedParticleTools.h"
#include <algorithm>
#include "UTIL/TrackTools.h"
#include <cmath>

namespace UTIL{
EVENT::Track* getLeadingTrack(const EVENT::ReconstructedParticle* particle){
const EVENT::TrackVec& tracks = particle->getTracks();
if ( tracks.empty() ) return nullptr;
// compare momentum w/o using its common scale factor (which includes magnetic field)
auto sortByMomentum = [](const EVENT::Track* a, const EVENT::Track* b) { return std::hypot(1., a->getTanLambda())/std::abs(a->getOmega()) < std::hypot(1., b->getTanLambda())/std::abs(b->getOmega()); };
auto* leadingTrack = *(std::max_element(tracks.begin(), tracks.end(), sortByMomentum));

return leadingTrack;
}
}