Skip to content

Commit

Permalink
JPipelineArrow uses PlaceRef's
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanwbrei committed Dec 27, 2023
1 parent 4447bf6 commit 3836c18
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 111 deletions.
55 changes: 52 additions & 3 deletions src/libraries/JANA/Engine/JArrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ class JArrow {
virtual size_t get_pending();

// TODO: Get rid of me
virtual size_t get_threshold() { return 0; }
virtual size_t get_threshold();

virtual void set_threshold(size_t /* threshold */) {}
virtual void set_threshold(size_t /* threshold */);

void attach(JArrow* downstream) {
m_listeners.push_back(downstream);
Expand Down Expand Up @@ -128,6 +128,8 @@ struct PlaceRefBase {
size_t max_item_count = 1;

virtual size_t get_pending() { return 0; }
virtual size_t get_threshold() { return 0; }
virtual void set_threshold(size_t) {}
};

template <typename T>
Expand All @@ -138,6 +140,14 @@ struct PlaceRef : public PlaceRefBase {
parent->attach(this);
}

PlaceRef(JArrow* parent, bool is_input, size_t min_item_count, size_t max_item_count) {
assert(parent != nullptr);
parent->attach(this);
this->is_input = is_input;
this->min_item_count = min_item_count;
this->max_item_count = max_item_count;
}

PlaceRef(JArrow* parent, JMailbox<T*>* queue, bool is_input, size_t min_item_count, size_t max_item_count) {
assert(parent != nullptr);
parent->attach(this);
Expand All @@ -158,7 +168,16 @@ struct PlaceRef : public PlaceRefBase {
this->max_item_count = max_item_count;
}

// TODO: We can get de-virtualize this if we go the parameter pack route
void set_queue(JMailbox<T*>* queue) {
this->place_ref = queue;
this->is_queue = true;
}

void set_pool(JPool<T>* pool) {
this->place_ref = pool;
this->is_queue = false;
}

size_t get_pending() override {
if (is_input && is_queue) {
auto queue = static_cast<JMailbox<T*>*>(place_ref);
Expand All @@ -167,6 +186,21 @@ struct PlaceRef : public PlaceRefBase {
return 0;
}

size_t get_threshold() override {
if (is_input && is_queue) {
auto queue = static_cast<JMailbox<T*>*>(place_ref);
return queue->get_threshold();
}
return -1;
}

void set_threshold(size_t threshold) override {
if (is_input && is_queue) {
auto queue = static_cast<JMailbox<T*>*>(place_ref);
queue->set_threshold(threshold);
}
}

bool pull(Data<T>& data) {
if (is_input) { // Actually pull the data
if (is_queue) {
Expand Down Expand Up @@ -238,5 +272,20 @@ inline size_t JArrow::get_pending() {
return sum;
}

inline size_t JArrow::get_threshold() {
size_t result = -1;
for (PlaceRefBase* place : m_places) {
result = std::min(result, place->get_threshold());
}
return result;

}

inline void JArrow::set_threshold(size_t threshold) {
for (PlaceRefBase* place : m_places) {
place->set_threshold(threshold);
}
}


#endif // GREENFIELD_ARROW_H
27 changes: 4 additions & 23 deletions src/libraries/JANA/Engine/JJunctionArrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,6 @@ class JJunctionArrow : public JArrow {
{
}

size_t get_pending() final {
// This is actually used by JScheduler for better or for worse
size_t first_pending = first_input.is_queue ? (static_cast<JMailbox<FirstT*>*>(first_input.place_ref))->size() : 0;
size_t second_pending = second_input.is_queue ? (static_cast<JMailbox<SecondT*>*>(second_input.place_ref))->size() : 0;
return first_pending + second_pending;
};

size_t get_threshold() final {
// TODO: Is this even meaningful? Only used in JArrowSummary I think -- Maybe get rid of this eventually?
return 0;
}

void set_threshold(size_t) final { }


bool try_pull_all(Data<FirstT>& fi, Data<FirstT>& fo, Data<SecondT>& si, Data<SecondT>& so) {

bool success;
Expand Down Expand Up @@ -86,14 +71,10 @@ class JJunctionArrow : public JArrow {

auto start_total_time = std::chrono::steady_clock::now();

Data<FirstT> first_input_data;
Data<FirstT> first_output_data;
Data<SecondT> second_input_data;
Data<SecondT> second_output_data;
first_input_data.location_id = location_id;
first_output_data.location_id = location_id;
second_input_data.location_id = location_id;
second_output_data.location_id = location_id;
Data<FirstT> first_input_data {location_id};
Data<FirstT> first_output_data {location_id};
Data<SecondT> second_input_data {location_id};
Data<SecondT> second_output_data {location_id};

bool success = try_pull_all(first_input_data, first_output_data, second_input_data, second_output_data);
if (success) {
Expand Down
120 changes: 35 additions & 85 deletions src/libraries/JANA/Engine/JPipelineArrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@
template <typename DerivedT, typename MessageT>
class JPipelineArrow : public JArrow {
private:
JMailbox<MessageT*>* m_input_queue;
JMailbox<MessageT*>* m_output_queue;
JPool<MessageT>* m_pool;
PlaceRef<MessageT> m_input {this, true, 1, 1};
PlaceRef<MessageT> m_output {this, false, 1, 1};

public:

JPipelineArrow(std::string name,
bool is_parallel,
bool is_source,
Expand All @@ -25,106 +23,58 @@ class JPipelineArrow : public JArrow {
JMailbox<MessageT*>* output_queue,
JPool<MessageT>* pool
)
: JArrow(std::move(name), is_parallel, is_source, is_sink),
m_input_queue(input_queue),
m_output_queue(output_queue),
m_pool(pool)
{
}

size_t get_pending() final { return (m_input_queue == nullptr) ? 0 : m_input_queue->size(); };

size_t get_threshold() final { return (m_input_queue == nullptr) ? 0 : m_input_queue->get_threshold(); }
: JArrow(std::move(name), is_parallel, is_source, is_sink) {

void set_threshold(size_t threshold) final { if (m_input_queue != nullptr) m_input_queue->set_threshold(threshold); }
if (input_queue == nullptr) {
assert(pool != nullptr);
m_input.set_pool(pool);
}
else {
m_input.set_queue(input_queue);
}
if (output_queue == nullptr) {
assert(pool != nullptr);
m_output.set_pool(pool);
}
else {
m_output.set_queue(output_queue);
}
}

void execute(JArrowMetrics& result, size_t location_id) final {

auto start_total_time = std::chrono::steady_clock::now();

// ===================================
// Reserve output before popping input
// ===================================
bool reserve_succeeded = true;
if (m_output_queue != nullptr) {
auto reserved_count = m_output_queue->reserve(1, location_id);
reserve_succeeded = (reserved_count != 0);
}
if (!reserve_succeeded) {
// Exit early!
auto end_total_time = std::chrono::steady_clock::now();
result.update(JArrowMetrics::Status::ComeBackLater, 0, 1, std::chrono::milliseconds(0), end_total_time - start_total_time);
return;
}

// =========
// Pop input
// =========
bool pop_succeeded = false;
MessageT* event;
if (m_input_queue == nullptr) {
// Obtain from pool
event = m_pool->get(location_id);
pop_succeeded = (event != nullptr);
}
else {
// Obtain from queue
size_t popped_count = m_input_queue->pop_and_reserve(&event, 1, 1, location_id);
pop_succeeded = (popped_count == 1);;
Data<MessageT> in_data {location_id};
Data<MessageT> out_data {location_id};

}
if (!pop_succeeded) {
// Exit early!
bool success = m_input.pull(in_data) && m_output.pull(out_data);
if (!success) {
m_input.revert(in_data);
m_output.revert(out_data);
// TODO: Test that revert works properly

auto end_total_time = std::chrono::steady_clock::now();
result.update(JArrowMetrics::Status::ComeBackLater, 0, 1, std::chrono::milliseconds(0), end_total_time - start_total_time);
return;
}


// ========================
// Process individual event
// ========================

auto start_processing_time = std::chrono::steady_clock::now();

bool process_succeeded = true;
JArrowMetrics::Status process_status = JArrowMetrics::Status::KeepGoing;
static_cast<DerivedT*>(this)->process(event, process_succeeded, process_status);
assert(in_data.item_count == 1);
MessageT* event = in_data.items[0];

auto start_processing_time = std::chrono::steady_clock::now();
static_cast<DerivedT*>(this)->process(event, process_succeeded, process_status);
auto end_processing_time = std::chrono::steady_clock::now();


// ==========
// Push event
// ==========
if (process_succeeded) {
// process() succeeded, so we push our event to the output queue/pool
if (m_output_queue != nullptr) {
// Push event to the output queue. This always succeeds due to reserve().
m_output_queue->push_and_unreserve(&event, 1, 1, location_id);
}
else {
// Push event to the output pool. This always succeeds.
m_pool->put(event, location_id);
}
if (m_input_queue != nullptr) {
m_input_queue->unreserve(1, location_id);
}
}
else {
// process() failed, so we return the event to the input queue/pool
if (m_input_queue != nullptr) {
// Return event to input queue. This always succeeds due to pop_and_reserve().
m_input_queue->push_and_unreserve(&event, 1, 1, location_id);
}
else {
// Return event to input pool. This always succeeds.
m_pool->put(event, location_id);
}
if (m_output_queue != nullptr) {
m_output_queue->unreserve(1, location_id);
}
in_data.item_count = 0;
out_data.item_count = 1;
out_data.items[0] = event;
}
m_input.push(in_data);
m_output.push(out_data);

// Publish metrics
auto end_total_time = std::chrono::steady_clock::now();
Expand Down

0 comments on commit 3836c18

Please sign in to comment.