Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add event builder filter algorithm #9

Merged
merged 8 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ project(
'iguana',
'cpp',
version: '0.0.0',
license: 'LGPLv3'
license: 'LGPLv3',
default_options: [ 'cpp_std=c++20' ],
)

project_inc = include_directories('src')
Expand All @@ -20,6 +21,6 @@ fmt_dep = dependency('fmt')
hipo_dep = dependency('hipo4', method: 'cmake', cmake_args: '-DCMAKE_PREFIX_PATH=' + get_option('hipo'))

subdir('src/services')
subdir('src/algorithms/clas12/fiducial_cuts')
subdir('src/algorithms')
subdir('src/iguana')
subdir('src/tests')
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#include "EventBuilderFilter.h"

namespace iguana::clas12 {

void EventBuilderFilter::Start() {
m_log->Debug("START {}", m_name);

// set configuration
m_log->SetLevel(Logger::Level::trace);
m_opt.mode = EventBuilderFilterOptions::Modes::blank;
m_opt.pids = {11, 211, -211};
}


Algorithm::BankMap EventBuilderFilter::Run(Algorithm::BankMap inBanks) {
m_log->Debug("RUN {}", m_name);

// check the input banks existence
if(MissingInputBanks(inBanks, {"particles"}))
Throw("missing input banks");

// define the output schemata and banks
BankMap outBanks = {
{ "particles", hipo::bank(inBanks.at("particles").getSchema()) }
};

// filter the input bank for requested PDG code(s)
std::set<int> acceptedRows;
for(int row = 0; row < inBanks.at("particles").getRows(); row++) {
auto pid = inBanks.at("particles").get("pid", row);
auto accept = m_opt.pids.contains(pid);
if(accept) acceptedRows.insert(row);
m_log->Debug("input PID {} -- accept = {}", pid, accept);
}

// fill the output bank
switch(m_opt.mode) {

case EventBuilderFilterOptions::Modes::blank:
outBanks.at("particles").setRows(inBanks.at("particles").getRows());
for(int row = 0; row < inBanks.at("particles").getRows(); row++) {
if(acceptedRows.contains(row))
CopyBankRow(inBanks.at("particles"), row, outBanks.at("particles"), row);
else
BlankRow(outBanks.at("particles"), row);
}
break;

case EventBuilderFilterOptions::Modes::compact:
outBanks.at("particles").setRows(acceptedRows.size());
for(int row = 0; auto acceptedRow : acceptedRows)
CopyBankRow(inBanks.at("particles"), acceptedRow, outBanks.at("particles"), row++);
break;

default:
Throw("unknown 'mode' option");

}

// dump the banks and return the output
ShowBanks(inBanks, outBanks);
return outBanks;
}


void EventBuilderFilter::Stop() {
m_log->Debug("STOP {}", m_name);
}

}
30 changes: 30 additions & 0 deletions src/algorithms/clas12/event_builder_filter/EventBuilderFilter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include "services/Algorithm.h"

namespace iguana::clas12 {

class EventBuilderFilterOptions {
public:
enum Modes { blank, compact };
Modes mode = blank;
std::set<int> pids = {11, 211};
};


class EventBuilderFilter : public Algorithm {

public:
EventBuilderFilter() : Algorithm("event_builder_filter") {}
~EventBuilderFilter() {}

void Start() override;
Algorithm::BankMap Run(Algorithm::BankMap inBanks) override;
void Stop() override;

private:
EventBuilderFilterOptions m_opt;

};

}
3 changes: 3 additions & 0 deletions src/algorithms/clas12/event_builder_filter/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Event Builder Filter

Filters a particle bank for specific Event Builder PDGs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
algo_headers = [
'FiducialCuts.h',
'clas12/event_builder_filter/EventBuilderFilter.h',
]

algo_sources = [
'FiducialCuts.cc',
'clas12/event_builder_filter/EventBuilderFilter.cc',
]

algo_lib = shared_library(
Expand Down
2 changes: 1 addition & 1 deletion src/iguana/Arbiter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
namespace iguana {

Arbiter::Arbiter() {
algo_map.insert({clas12_FiducialCuts, std::make_shared<clas12::FiducialCuts>()});
algo_map.insert({clas12_EventBuilderFilter, std::make_shared<clas12::EventBuilderFilter>()});
}

}
4 changes: 2 additions & 2 deletions src/iguana/Arbiter.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <memory>

// TODO: avoid listing the algos
#include "algorithms/clas12/fiducial_cuts/FiducialCuts.h"
#include "algorithms/clas12/event_builder_filter/EventBuilderFilter.h"

namespace iguana {

Expand All @@ -18,7 +18,7 @@ namespace iguana {
// TODO: avoid listing the algos
// TODO: who should own the algorithm instances: Arbiter or the user?
enum algo {
clas12_FiducialCuts
clas12_EventBuilderFilter
};

// TODO: make private
Expand Down
1 change: 1 addition & 0 deletions src/iguana/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ iguana_lib = shared_library(
'Iguana',
iguana_sources,
include_directories: project_inc,
dependencies: [ fmt_dep, hipo_dep ],
link_with: [ algo_lib, services_lib ],
install: true,
install_dir: project_lib_install_dir,
Expand Down
50 changes: 48 additions & 2 deletions src/services/Algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,54 @@

namespace iguana {

Algorithm::Algorithm(std::string name) {
m_log = std::make_shared<Logger>(name);
Algorithm::Algorithm(std::string name) : m_name(name) {
m_log = std::make_shared<Logger>(m_name);
}

bool Algorithm::MissingInputBanks(BankMap banks, std::set<std::string> keys) {
for(auto key : keys) {
if(!banks.contains(key)) {
m_log->Error("Algorithm '{}' is missing the input bank '{}'", m_name, key);
m_log->Error(" => the following input banks are required by '{}':", m_name);
for(auto k : keys)
m_log->Error(" - {}", k);
return true;
}
}
return false;
}

void Algorithm::CopyBankRow(hipo::bank srcBank, int srcRow, hipo::bank destBank, int destRow) {
// TODO: check srcBank.getSchema() == destBank.getSchema()
for(int item = 0; item < srcBank.getSchema().getEntries(); item++) {
auto val = srcBank.get(item, srcRow);
destBank.put(item, destRow, val);
}
}

void Algorithm::BlankRow(hipo::bank bank, int row) {
for(int item = 0; item < bank.getSchema().getEntries(); item++) {
bank.put(item, row, 0);
}
}

void Algorithm::ShowBanks(BankMap banks, std::string message, Logger::Level level) {
if(m_log->GetLevel() <= level) {
m_log->Print(level, message);
for(auto [key,bank] : banks) {
m_log->Print(level, "BANK: '{}'", key);
bank.show();
}
}
}

void Algorithm::ShowBanks(BankMap inBanks, BankMap outBanks, Logger::Level level) {
ShowBanks(inBanks, "===== INPUT BANKS =====", level);
ShowBanks(outBanks, "===== OUTPUT BANKS =====", level);
}

void Algorithm::Throw(std::string message) {
throw std::runtime_error(fmt::format("CRITICAL ERROR: {}; Algorithm '{}' stopped!", message, m_name));
}

}
60 changes: 58 additions & 2 deletions src/services/Algorithm.h
Original file line number Diff line number Diff line change
@@ -1,19 +1,75 @@
#pragma once

#include "Logger.h"
#include <hipo4/bank.h>
#include <set>

namespace iguana {

class Algorithm {

public:

using BankMap = std::unordered_map<std::string, hipo::bank>;

/// Algorithm base class constructor
/// @param name the unique name for a derived class instance
Algorithm(std::string name);

/// Algorithm base class destructor
virtual ~Algorithm() {}

/// Initialize an algorithm before any events are processed
virtual void Start() = 0;
virtual int Run(int a, int b) = 0;

/// Run an algorithm
/// @param inBanks the set of input banks
/// @return a set of output banks
virtual BankMap Run(BankMap inBanks) = 0;

/// Finalize an algorithm after all events are processed
virtual void Stop() = 0;
virtual ~Algorithm() {}

protected:

/// Check if `banks` contains all keys `keys`; this is useful for checking algorithm inputs are complete.
/// @param banks the set of (key,bank) pairs to check
/// @keys the required keys
/// @return true if `banks` is missing any keys in `keys`
bool MissingInputBanks(BankMap banks, std::set<std::string> keys);

/// Copy a row from one bank to another, assuming their schemata are equivalent
/// @param srcBank the source bank
/// @param srcRow the row in `srcBank` to copy from
/// @param destBank the destination bank
/// @param destRow the row in `destBank` to copy to
void CopyBankRow(hipo::bank srcBank, int srcRow, hipo::bank destBank, int destRow);

/// Blank a row, setting all items to zero
/// @param bank the bank to modify
/// @param row the row to blank
void BlankRow(hipo::bank bank, int row);

/// Dump all banks in a BankMap
/// @param banks the banks to show
/// @param message optionally print a header message
/// @param level the log level
void ShowBanks(BankMap banks, std::string message="", Logger::Level level=Logger::trace);

/// Dump all input and output banks
/// @param inBanks the input banks
/// @param outBanks the output banks
/// @param level the log level
void ShowBanks(BankMap inBanks, BankMap outBanks, Logger::Level level=Logger::trace);

/// Stop the algorithm and throw a runtime exception
/// @param message the error message
void Throw(std::string message);

/// algorithm name
std::string m_name;

/// Logger
std::shared_ptr<Logger> m_log;
};
}
4 changes: 4 additions & 0 deletions src/services/Logger.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,8 @@ namespace iguana {
Debug("Logger '{}' set to '{}'", m_name, m_level_names.at(m_level));
}

Logger::Level Logger::GetLevel() {
return m_level;
}

}
31 changes: 14 additions & 17 deletions src/services/Logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,27 @@ namespace iguana {
~Logger() {}

void SetLevel(Level lev);
Level GetLevel();

template <typename... VALUES> void Trace(std::string msg, VALUES... vals) { Print(trace, msg, vals...); }
template <typename... VALUES> void Debug(std::string msg, VALUES... vals) { Print(debug, msg, vals...); }
template <typename... VALUES> void Info(std::string msg, VALUES... vals) { Print(info, msg, vals...); }
template <typename... VALUES> void Warn(std::string msg, VALUES... vals) { Print(warn, msg, vals...); }
template <typename... VALUES> void Error(std::string msg, VALUES... vals) { Print(error, msg, vals...); }
template <typename... VALUES> void Trace(std::string message, VALUES... vals) { Print(trace, message, vals...); }
template <typename... VALUES> void Debug(std::string message, VALUES... vals) { Print(debug, message, vals...); }
template <typename... VALUES> void Info(std::string message, VALUES... vals) { Print(info, message, vals...); }
template <typename... VALUES> void Warn(std::string message, VALUES... vals) { Print(warn, message, vals...); }
template <typename... VALUES> void Error(std::string message, VALUES... vals) { Print(error, message, vals...); }

template <typename... VALUES>
void Print(Level lev, std::string msg, VALUES... vals) {
void Print(Level lev, std::string message, VALUES... vals) {
if(lev >= m_level) {
auto level_name_it = m_level_names.find(lev);
if(level_name_it == m_level_names.end()) {
Warn("Logger::Print called with unknown log level '{}'; printing as error instead", static_cast<int>(lev)); // FIXME: static_cast -> fmt::underlying, but needs new version of fmt
Error(msg, vals...);
} else {
if(m_level_names.contains(lev)) {
auto prefix = fmt::format("[{}] [{}] ", m_level_names.at(lev), m_name);
fmt::print(
lev >= warn ? stderr : stdout,
fmt::format(
"[{}] [{}] {}\n",
level_name_it->second,
m_name,
fmt::format(msg, vals...)
)
fmt::runtime(prefix + message + "\n"),
vals...
);
} else {
Warn("Logger::Print called with unknown log level '{}'; printing as error instead", static_cast<int>(lev)); // FIXME: static_cast -> fmt::underlying, but needs new version of fmt
Error(message, vals...);
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/services/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ services_lib = shared_library(
'IguanaServices',
services_sources,
include_directories: project_inc,
dependencies: fmt_dep,
dependencies: [ fmt_dep, hipo_dep ],
install: true,
install_dir: project_lib_install_dir,
install_rpath: project_lib_rpath,
)

services_dep = declare_dependency(
dependencies: fmt_dep
dependencies: [ fmt_dep, hipo_dep ]
)

install_headers(services_headers, subdir : meson.project_name())
Loading