Skip to content

Commit

Permalink
refactor: use hipo::banklist for mutation (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
c-dilks authored Dec 6, 2023
1 parent 4e6c314 commit e230917
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 81 deletions.
12 changes: 6 additions & 6 deletions src/algorithms/clas12/event_builder_filter/EventBuilderFilter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace iguana::clas12 {

}

void EventBuilderFilter::Start(bank_index_cache_t &index_cache) {
void EventBuilderFilter::Start(bank_index_cache_t& index_cache) {

// define options, their default values, and cache them
CacheOption("pids", std::set<int>{11, 211}, o_pids);
Expand All @@ -23,18 +23,18 @@ namespace iguana::clas12 {
}


void EventBuilderFilter::Run(bank_vec_t banks) {
void EventBuilderFilter::Run(hipo::banklist& banks) {

// get the banks
auto particleBank = GetBank(banks, b_particle, "REC::Particle");
auto caloBank = GetBank(banks, b_calo, "REC::Calorimeter"); // TODO: remove
auto& particleBank = GetBank(banks, b_particle, "REC::Particle");
// auto& caloBank = GetBank(banks, b_calo, "REC::Calorimeter"); // TODO: remove

// dump the bank
ShowBank(particleBank, Logger::Header("INPUT PARTICLES"));

// filter the input bank for requested PDG code(s)
for(int row = 0; row < particleBank->getRows(); row++) {
auto pid = particleBank->getInt("pid", row);
for(int row = 0; row < particleBank.getRows(); row++) {
auto pid = particleBank.getInt("pid", row);
auto accept = Filter(pid);
if(!accept)
MaskRow(particleBank, row);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ namespace iguana::clas12 {
~EventBuilderFilter() {}

void Start() override { Algorithm::Start(); }
void Start(bank_index_cache_t &index_cache) override;
void Run(bank_vec_t banks) override;
void Start(bank_index_cache_t& index_cache) override;
void Run(hipo::banklist& banks) override;
void Stop() override;

bool Filter(int pid);

private:

/// `bank_vec_t` indices
/// `hipo::banklist` indices
int b_particle, b_calo; // TODO: remove calorimeter

/// configuration options
Expand Down
2 changes: 1 addition & 1 deletion src/iguana/Iguana.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
namespace iguana {

Iguana::Iguana() {
algo_map.insert({clas12_EventBuilderFilter, std::make_shared<clas12::EventBuilderFilter>()});
algo_map.insert({clas12_EventBuilderFilter, std::move(std::make_unique<clas12::EventBuilderFilter>())});
}

}
2 changes: 1 addition & 1 deletion src/iguana/Iguana.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace iguana {
};

// TODO: make private
std::unordered_map<Iguana::algo, std::shared_ptr<Algorithm>> algo_map;
std::unordered_map<Iguana::algo, std::unique_ptr<Algorithm>> algo_map;

};
}
38 changes: 19 additions & 19 deletions src/services/Algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
namespace iguana {

Algorithm::Algorithm(std::string name) : m_name(name) {
m_log = std::make_shared<Logger>(m_name);
m_log = std::make_unique<Logger>(m_name);
}

void Algorithm::Start() {
Expand All @@ -19,14 +19,14 @@ namespace iguana {
m_log->Debug("User set option '{}' = {}", key, PrintOptionValue(key));
}

std::shared_ptr<Logger> Algorithm::Log() {
std::unique_ptr<Logger>& Algorithm::Log() {
return m_log;
}

void Algorithm::CacheBankIndex(bank_index_cache_t index_cache, int &idx, std::string bankName) {
void Algorithm::CacheBankIndex(bank_index_cache_t index_cache, int& idx, std::string bankName) {
try {
idx = index_cache.at(bankName);
} catch(const std::out_of_range &o) {
} catch(const std::out_of_range& o) {
Throw(fmt::format("required input bank '{}' not found; cannot `Start` algorithm '{}'", bankName, m_name));
}
m_log->Debug("cached index of bank '{}' is {}", bankName, idx);
Expand All @@ -50,39 +50,39 @@ namespace iguana {
return "UNKNOWN";
}

bank_ptr Algorithm::GetBank(bank_vec_t banks, int idx, std::string expectedBankName) {
bank_ptr result;
hipo::bank& Algorithm::GetBank(hipo::banklist& banks, int idx, std::string expectedBankName) {
try {
result = banks.at(idx);
} catch(const std::out_of_range &o) {
auto& result = banks.at(idx);
if(expectedBankName != "" && result.getSchema().getName() != expectedBankName) {
Throw(fmt::format("expected input bank '{}' at index={}; got bank named '{}'", expectedBankName, idx, result.getSchema().getName()));
}
return result;
} catch(const std::out_of_range& o) {
Throw(fmt::format("required input bank '{}' not found; cannot `Run` algorithm '{}'", expectedBankName, m_name));
}
if(expectedBankName != "" && result->getSchema().getName() != expectedBankName) {
Throw(fmt::format("expected input bank '{}' at index={}; got bank named '{}'", expectedBankName, idx, result->getSchema().getName()));
}
return result;
throw std::runtime_error("GetBank failed"); // avoid `-Wreturn-type` warning
}

void Algorithm::MaskRow(bank_ptr bank, int row) {
void Algorithm::MaskRow(hipo::bank& bank, int row) {
// TODO: need https://github.com/gavalian/hipo/issues/35
// until then, just set the PID to -1
bank->putInt("pid", row, -1);
bank.putInt("pid", row, -1);
}

void Algorithm::ShowBanks(bank_vec_t banks, std::string message, Logger::Level level) {
void Algorithm::ShowBanks(hipo::banklist& banks, std::string message, Logger::Level level) {
if(m_log->GetLevel() <= level) {
if(message != "")
m_log->Print(level, message);
for(auto bank : banks)
bank->show();
for(auto& bank : banks)
bank.show();
}
}

void Algorithm::ShowBank(bank_ptr bank, std::string message, Logger::Level level) {
void Algorithm::ShowBank(hipo::bank& bank, std::string message, Logger::Level level) {
if(m_log->GetLevel() <= level) {
if(message != "")
m_log->Print(level, message);
bank->show();
bank.show();
}
}

Expand Down
38 changes: 19 additions & 19 deletions src/services/Algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ namespace iguana {

/// Initialize an algorithm before any events are processed
/// @param index_cache The `Run` method will use these indices to access banks
virtual void Start(bank_index_cache_t &index_cache) = 0;
virtual void Start(bank_index_cache_t& index_cache) = 0;

/// Run an algorithm
/// @param banks the set of banks to process
virtual void Run(bank_vec_t banks) = 0;
virtual void Run(hipo::banklist& banks) = 0;

/// Finalize an algorithm after all events are processed
virtual void Stop() = 0;
Expand All @@ -38,27 +38,27 @@ namespace iguana {

/// Get the logger
/// @return the logger used by this algorithm
std::shared_ptr<Logger> Log();
std::unique_ptr<Logger>& Log();

protected:

/// Cache the index of a bank in a `bank_vec_t`; throws an exception if the bank is not found
/// @param index_cache the relation between bank name and `bank_vec_t` index
/// @param idx a reference to the `bank_vec_t` index of the bank
/// Cache the index of a bank in a `hipo::banklist`; throws an exception if the bank is not found
/// @param index_cache the relation between bank name and `hipo::banklist` index
/// @param idx a reference to the `hipo::banklist` index of the bank
/// @param bankName the name of the bank
void CacheBankIndex(bank_index_cache_t index_cache, int &idx, std::string bankName);
void CacheBankIndex(bank_index_cache_t index_cache, int& idx, std::string bankName) noexcept(false);

/// Cache an option specified by the user, and define its default value
/// @param key the name of the option
/// @param def the default value
/// @param val reference to the value of the option, to be cached by `Start`
template <typename OPTION_TYPE>
void CacheOption(std::string key, OPTION_TYPE def, OPTION_TYPE &val) {
void CacheOption(std::string key, OPTION_TYPE def, OPTION_TYPE& val) {
bool get_error = false;
if(auto it{m_opt.find(key)}; it != m_opt.end()) { // cache the user's option value
try { // get the expected type
val = std::get<OPTION_TYPE>(it->second);
} catch(const std::bad_variant_access &ex1) {
} catch(const std::bad_variant_access& ex1) {
m_log->Error("user option '{}' set to '{}', which is the wrong type...", key, PrintOptionValue(key));
get_error = true;
val = def;
Expand All @@ -79,33 +79,33 @@ namespace iguana {
/// @return the string value and its type
std::string PrintOptionValue(std::string key);

/// Get the pointer to a bank from a `bank_vec_t`; optionally checks if the bank name matches the expectation
/// @param banks the `bank_vec_t` from which to get the specified bank
/// Get the pointer to a bank from a `hipo::banklist`; optionally checks if the bank name matches the expectation
/// @param banks the `hipo::banklist` from which to get the specified bank
/// @param idx the index of `banks` of the specified bank
/// @param expectedBankName if specified, checks that the specified bank has this name
/// @return the modified `bank_vec_t`
bank_ptr GetBank(bank_vec_t banks, int idx, std::string expectedBankName="");
/// @return the modified `hipo::banklist`
hipo::bank& GetBank(hipo::banklist& banks, int idx, std::string expectedBankName="") noexcept(false);

/// Mask a row, setting all items to zero
/// @param bank the bank to modify
/// @param row the row to blank
void MaskRow(bank_ptr bank, int row);
void MaskRow(hipo::bank& bank, int row);

/// Dump all banks in a `bank_vec_t`
/// Dump all banks in a `hipo::banklist`
/// @param banks the banks to show
/// @param message optionally print a header message
/// @param level the log level
void ShowBanks(bank_vec_t banks, std::string message="", Logger::Level level=Logger::trace);
void ShowBanks(hipo::banklist& banks, std::string message="", Logger::Level level=Logger::trace);

/// Dump a single bank
/// @param bank the bank to show
/// @param message optionally print a header message
/// @param level the log level
void ShowBank(bank_ptr bank, std::string message="", Logger::Level level=Logger::trace);
void ShowBank(hipo::bank& bank, std::string message="", Logger::Level level=Logger::trace);

/// Stop the algorithm and throw a runtime exception
/// @param message the error message
void Throw(std::string message);
void Throw(std::string message) noexcept(false);

/// algorithm name
std::string m_name;
Expand All @@ -114,7 +114,7 @@ namespace iguana {
std::vector<std::string> m_requiredBanks;

/// Logger
std::shared_ptr<Logger> m_log;
std::unique_ptr<Logger> m_log;

/// Configuration options
options_t m_opt;
Expand Down
2 changes: 1 addition & 1 deletion src/services/Logger.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace iguana {
}

void Logger::SetLevel(std::string lev) {
for(auto &[lev_i, lev_n] : m_level_names) {
for(auto& [lev_i, lev_n] : m_level_names) {
if(lev == lev_n) {
SetLevel(lev_i);
return;
Expand Down
8 changes: 1 addition & 7 deletions src/services/TypeDefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,7 @@

namespace iguana {

/// pointer to a HIPO bank
using bank_ptr = std::shared_ptr<hipo::bank>;

/// ordered list of HIPO bank pointers
using bank_vec_t = std::vector<bank_ptr>;

/// association between HIPO bank name and its index in a `bank_vec_t`
/// association between HIPO bank name and its index in a `hipo::banklist`
using bank_index_cache_t = std::unordered_map<std::string, int>;

/// option value variant type
Expand Down
40 changes: 19 additions & 21 deletions src/tests/run_banks.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#include "iguana/Iguana.h"
#include <hipo4/reader.h>

void printParticles(std::string prefix, iguana::bank_ptr b) {
void printParticles(std::string prefix, hipo::bank& b) {
std::vector<int> pids;
for(int row=0; row<b->getRows(); row++)
pids.push_back(b->getInt("pid", row));
for(int row=0; row<b.getRows(); row++)
pids.push_back(b.getInt("pid", row));
fmt::print("{}: {}\n", prefix, fmt::join(pids, ", "));
}

Expand All @@ -20,7 +20,7 @@ int main(int argc, char **argv) {
* use the test algorithm directly
*/
iguana::Iguana I;
auto algo = I.algo_map.at(iguana::Iguana::clas12_EventBuilderFilter);
auto& algo = I.algo_map.at(iguana::Iguana::clas12_EventBuilderFilter);
algo->Log()->SetLevel("trace");
// algo->Log()->DisableStyle();
algo->SetOption("pids", std::set<int>{11, 211, -211});
Expand All @@ -31,26 +31,24 @@ int main(int argc, char **argv) {
/////////////////////////////////////////////////////

// read input file
hipo::reader reader;
reader.open(inFileName.c_str());

// get bank schema
/* TODO: users should not have to do this; this is a workaround until
* the pattern `hipo::event::getBank("REC::Particle")` is possible
*/
hipo::dictionary factory;
reader.readDictionary(factory);
auto particleBank = std::make_shared<hipo::bank>(factory.getSchema("REC::Particle"));
auto caloBank = std::make_shared<hipo::bank>(factory.getSchema("REC::Calorimeter")); // TODO: remove when not needed (this is for testing)
hipo::reader reader(inFileName.c_str());

// set banks
hipo::banklist banks = reader.getBanks({
"REC::Particle",
"REC::Calorimeter"
});
enum banks_enum { // TODO: make this nicer
b_particle,
b_calo
};

// event loop
hipo::event event;
int iEvent = 0;
while(reader.next(event) && (iEvent++ < numEvents || numEvents == 0)) {
event.getStructure(*particleBank);
printParticles("PIDS BEFORE algo->Run() ", particleBank);
algo->Run({particleBank, caloBank});
printParticles("PIDS AFTER algo->Run() ", particleBank);
while(reader.next(banks) && (iEvent++ < numEvents || numEvents == 0)) {
printParticles("PIDS BEFORE algo->Run() ", banks.at(b_particle));
algo->Run(banks);
printParticles("PIDS AFTER algo->Run() ", banks.at(b_particle));
}

/////////////////////////////////////////////////////
Expand Down
7 changes: 4 additions & 3 deletions src/tests/run_rows.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@

int main(int argc, char **argv) {

/* DISABLED until `run_banks` is more stable
// parse arguments
int argi = 1;
std::string inFileName = argc > argi ? std::string(argv[argi++]) : "data.hipo";
int numEvents = argc > argi ? std::stoi(argv[argi++]) : 3;
// start the algorithm
auto algo = std::make_shared<iguana::clas12::EventBuilderFilter>();
algo->SetOption("pids", std::set<int>{11, 211, -211});
Expand All @@ -20,9 +23,6 @@ int main(int argc, char **argv) {
reader.open(inFileName.c_str());
// get bank schema
/* TODO: users should not have to do this; this is a workaround until
* the pattern `hipo::event::getBank("REC::Particle")` is possible
*/
hipo::dictionary factory;
reader.readDictionary(factory);
auto particleBank = std::make_shared<hipo::bank>(factory.getSchema("REC::Particle"));
Expand All @@ -43,5 +43,6 @@ int main(int argc, char **argv) {
/////////////////////////////////////////////////////
algo->Stop();
*/
return 0;
}

0 comments on commit e230917

Please sign in to comment.