Skip to content

Commit

Permalink
feat: EB filter with compact mode
Browse files Browse the repository at this point in the history
  • Loading branch information
c-dilks committed Nov 21, 2023
1 parent 8189668 commit b9d2aef
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 53 deletions.
47 changes: 24 additions & 23 deletions src/algorithms/clas12/event_builder_filter/EventBuilderFilter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,43 @@ namespace iguana::clas12 {

// check the input banks existence
if(MissingInputBanks(inBanks, {"particles"}))
ThrowRun();
Throw("missing input banks");

// define the output schemata and banks
BankMap outBanks = {
{ "particles", hipo::bank(inBanks.at("particles").getSchema()) }
};

// set number of output rows
// filter the input bank for requested PDG code(s)
std::set<int> acceptedRows;
for(int row = 0; row < inBanks.at("particles").getRows(); row++) {
auto pid = inBanks.at("particles").get("pid", row);
auto accept = m_opt.pids.contains(pid);
if(accept) acceptedRows.insert(row);
m_log->Debug("input PID {} -- accept = {}", pid, accept);
}

// fill the output bank
switch(m_opt.mode) {

case EventBuilderFilterOptions::Modes::blank:
outBanks.at("particles").setRows(inBanks.at("particles").getRows());
for(int row = 0; row < inBanks.at("particles").getRows(); row++) {
if(acceptedRows.contains(row))
CopyBankRow(inBanks.at("particles"), row, outBanks.at("particles"), row);
else
BlankRow(outBanks.at("particles"), row);
}
break;

case EventBuilderFilterOptions::Modes::compact:
outBanks.at("particles").setRows(inBanks.at("particles").getRows()); // FIXME
outBanks.at("particles").setRows(acceptedRows.size());
for(int row = 0; auto acceptedRow : acceptedRows)
CopyBankRow(inBanks.at("particles"), acceptedRow, outBanks.at("particles"), row++);
break;
}

// filter the input bank for requested PDG code(s)
int outRow = -1;
for(int inRow = 0; inRow < inBanks.at("particles").getRows(); inRow++) {
auto inPid = inBanks.at("particles").get("pid",inRow);

if(m_opt.pids.contains(inPid)) {
m_log->Debug("input PID {} -- accept", inPid);
CopyBankRow(
inBanks.at("particles"),
outBanks.at("particles"),
m_opt.mode == EventBuilderFilterOptions::Modes::blank ? inRow : outRow++
);
}

else {
m_log->Debug("input PID {} -- reject", inPid);
if(m_opt.mode == EventBuilderFilterOptions::blank)
BlankRow(outBanks.at("particles"), inRow);
}
default:
Throw("unknown 'mode' option");

}

Expand Down
16 changes: 8 additions & 8 deletions src/services/Algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace iguana {

bool Algorithm::MissingInputBanks(BankMap banks, std::set<std::string> keys) {
for(auto key : keys) {
if(banks.find(key) == banks.end()) {
if(!banks.contains(key)) {
m_log->Error("Algorithm '{}' is missing the input bank '{}'", m_name, key);
m_log->Error(" => the following input banks are required by '{}':", m_name);
for(auto k : keys)
Expand All @@ -19,11 +19,11 @@ namespace iguana {
return false;
}

void Algorithm::CopyBankRow(hipo::bank srcBank, hipo::bank destBank, int row) {
void Algorithm::CopyBankRow(hipo::bank srcBank, int srcRow, hipo::bank destBank, int destRow) {
// TODO: check srcBank.getSchema() == destBank.getSchema()
for(int item = 0; item < srcBank.getSchema().getEntries(); item++) {
auto val = srcBank.get(item, row);
destBank.put(item, row, val);
auto val = srcBank.get(item, srcRow);
destBank.put(item, destRow, val);
}
}

Expand All @@ -33,10 +33,6 @@ namespace iguana {
}
}

void Algorithm::ThrowRun() {
throw std::runtime_error(fmt::format("Algorithm '{}' cannot `Run`", m_name));
}

void Algorithm::ShowBanks(BankMap banks, std::string message, Logger::Level level) {
if(m_log->GetLevel() <= level) {
m_log->Print(level, message);
Expand All @@ -52,4 +48,8 @@ namespace iguana {
ShowBanks(outBanks, "===== OUTPUT BANKS =====", level);
}

void Algorithm::Throw(std::string message) {
throw std::runtime_error(fmt::format("CRITICAL ERROR: {}; Algorithm '{}' stopped!", message, m_name));
}

}
12 changes: 7 additions & 5 deletions src/services/Algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,16 @@ namespace iguana {

/// Copy a row from one bank to another, assuming their schemata are equivalent
/// @param srcBank the source bank
/// @param srcRow the row in `srcBank` to copy from
/// @param destBank the destination bank
/// @param row the row to copy from `srcBank` to `destBank`
void CopyBankRow(hipo::bank srcBank, hipo::bank destBank, int row);
/// @param destRow the row in `destBank` to copy to
void CopyBankRow(hipo::bank srcBank, int srcRow, hipo::bank destBank, int destRow);

/// Blank a row, setting all items to zero
/// @param bank the bank to modify
/// @param row the row to blank
void BlankRow(hipo::bank bank, int row);

/// Throw a runtime exception when calling `Run`
void ThrowRun();

/// Dump all banks in a BankMap
/// @param banks the banks to show
/// @param message optionally print a header message
Expand All @@ -64,6 +62,10 @@ namespace iguana {
/// @param level the log level
void ShowBanks(BankMap inBanks, BankMap outBanks, Logger::Level level=Logger::trace);

/// Stop the algorithm and throw a runtime exception
/// @param message the error message
void Throw(std::string message);

/// algorithm name
std::string m_name;

Expand Down
30 changes: 13 additions & 17 deletions src/services/Logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,25 @@ namespace iguana {
void SetLevel(Level lev);
Level GetLevel();

template <typename... VALUES> void Trace(std::string msg, VALUES... vals) { Print(trace, msg, vals...); }
template <typename... VALUES> void Debug(std::string msg, VALUES... vals) { Print(debug, msg, vals...); }
template <typename... VALUES> void Info(std::string msg, VALUES... vals) { Print(info, msg, vals...); }
template <typename... VALUES> void Warn(std::string msg, VALUES... vals) { Print(warn, msg, vals...); }
template <typename... VALUES> void Error(std::string msg, VALUES... vals) { Print(error, msg, vals...); }
template <typename... VALUES> void Trace(std::string message, VALUES... vals) { Print(trace, message, vals...); }
template <typename... VALUES> void Debug(std::string message, VALUES... vals) { Print(debug, message, vals...); }
template <typename... VALUES> void Info(std::string message, VALUES... vals) { Print(info, message, vals...); }
template <typename... VALUES> void Warn(std::string message, VALUES... vals) { Print(warn, message, vals...); }
template <typename... VALUES> void Error(std::string message, VALUES... vals) { Print(error, message, vals...); }

template <typename... VALUES>
void Print(Level lev, std::string msg, VALUES... vals) {
void Print(Level lev, std::string message, VALUES... vals) {
if(lev >= m_level) {
auto level_name_it = m_level_names.find(lev);
if(level_name_it == m_level_names.end()) {
Warn("Logger::Print called with unknown log level '{}'; printing as error instead", static_cast<int>(lev)); // FIXME: static_cast -> fmt::underlying, but needs new version of fmt
Error(msg, vals...);
} else {
if(m_level_names.contains(lev)) {
auto prefix = fmt::format("[{}] [{}] ", m_level_names.at(lev), m_name);
fmt::print(
lev >= warn ? stderr : stdout,
fmt::format(
"[{}] [{}] {}\n",
level_name_it->second,
m_name,
fmt::format(msg, vals...)
)
fmt::runtime(prefix + message + "\n"),
vals...
);
} else {
Warn("Logger::Print called with unknown log level '{}'; printing as error instead", static_cast<int>(lev)); // FIXME: static_cast -> fmt::underlying, but needs new version of fmt
Error(message, vals...);
}
}
}
Expand Down

0 comments on commit b9d2aef

Please sign in to comment.