Skip to content

Commit

Permalink
[Core][Channel] Refactor BroadcastChannel
Browse files Browse the repository at this point in the history
Add unit tests for broadcast channel and migrate channel
ref husky-team#228
  • Loading branch information
ClydeZhao committed Feb 7, 2017
1 parent 881d944 commit 0dce7ef
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 132 deletions.
17 changes: 7 additions & 10 deletions core/accessor_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,15 @@ class AccessorStore {
template <typename CollectT>
static std::vector<Accessor<CollectT>>* create_accessor(size_t channel_id, size_t local_id,
size_t num_local_threads) {
// double-checked locking
std::lock_guard<std::mutex> lock(accessors_map_mutex);
if (accessors_map.find(channel_id) == accessors_map.end()) {
std::lock_guard<std::mutex> lock(accessors_map_mutex);
if (accessors_map.find(channel_id) == accessors_map.end()) {
AccessorSet<CollectT>* accessor_set = new AccessorSet<CollectT>();
accessor_set->data.resize(num_local_threads);
for (auto& i : accessor_set->data) {
i.init(num_local_threads);
}
AccessorStore::num_local_threads.insert(std::make_pair(channel_id, num_local_threads));
accessors_map.insert(std::make_pair(channel_id, accessor_set));
AccessorSet<CollectT>* accessor_set = new AccessorSet<CollectT>();
accessor_set->data.resize(num_local_threads);
for (auto& i : accessor_set->data) {
i.init(num_local_threads);
}
AccessorStore::num_local_threads.insert({channel_id, num_local_threads});
accessors_map.insert({channel_id, accessor_set});
}
auto& data = dynamic_cast<AccessorSet<CollectT>*>(accessors_map[channel_id])->data;
// data[local_id].init();
Expand Down
39 changes: 8 additions & 31 deletions core/channel/broadcast_channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,13 @@ using base::BinStream;
template <typename KeyT, typename ValueT>
class BroadcastChannel : public ChannelBase {
public:
explicit BroadcastChannel(ChannelSource* src) : src_ptr_(src) {
// TODO(yuzhen): Should be careful, maybe need to deregister every time?
src_ptr_->register_outchannel(channel_id_, this);
}
BroadcastChannel() = default;

~BroadcastChannel() override {
// Make sure to invoke inc_progress_ before destructor
if (need_leave_accessor_)
leave_accessor();
AccessorStore::remove_accessor(channel_id_);
src_ptr_->deregister_outchannel(channel_id_);
}

BroadcastChannel(const BroadcastChannel&) = delete;
Expand All @@ -53,7 +49,7 @@ class BroadcastChannel : public ChannelBase {
BroadcastChannel(BroadcastChannel&&) = default;
BroadcastChannel& operator=(BroadcastChannel&&) = default;

void customized_setup() override {
void buffer_accessor_setup() {
broadcast_buffer_.resize(worker_info_->get_largest_tid() + 1);
accessor_ = AccessorStore::create_accessor<std::unordered_map<KeyT, ValueT>>(
channel_id_, local_id_, worker_info_->get_num_local_workers());
Expand Down Expand Up @@ -92,17 +88,9 @@ class BroadcastChannel : public ChannelBase {

void set_clear_dict(bool clear) { clear_dict_each_progress_ = clear; }

void prepare() override {}

void in(BinStream& bin) override {}

void out() override {
flush();
prepare_broadcast();
}
std::unordered_map<KeyT, ValueT>& get_local_dict() { return (*accessor_)[local_id_].storage(); }

/// This method is only useful without list_execute
void flush() {
void send() override {
this->inc_progress();
int start = global_id_;
for (int i = 0; i < broadcast_buffer_.size(); ++i) {
Expand All @@ -116,17 +104,17 @@ class BroadcastChannel : public ChannelBase {
this->worker_info_->get_pids());
}

/// This method is only useful without list_execute
void prepare_broadcast() {
void recv() override {
// Check whether need to leave accessor_ (last round's accessor_)
if (need_leave_accessor_)
leave_accessor();
need_leave_accessor_ = true;

auto& local_dict = (*accessor_)[local_id_].storage();
while (mailbox_->poll(channel_id_, progress_)) {
auto bin = mailbox_->recv(channel_id_, progress_);
process_bin(bin, local_dict);
if (bin_stream_processor_ != nullptr) {
bin_stream_processor_(&bin);
}
}
(*accessor_)[local_id_].commit();
}
Expand All @@ -144,17 +132,6 @@ class BroadcastChannel : public ChannelBase {
}
}

void process_bin(BinStream& bin, std::unordered_map<KeyT, ValueT>& local_dict) {
while (bin.size() != 0) {
KeyT key;
ValueT value;
bin >> key >> value;
local_dict[key] = value;
}
}

ChannelSource* src_ptr_;

bool clear_dict_each_progress_ = false;
bool need_leave_accessor_ = false;
std::vector<BinStream> broadcast_buffer_;
Expand Down
71 changes: 44 additions & 27 deletions core/channel/broadcast_channel_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,19 @@ class Obj {
};

// Create broadcast without setting
template <typename KeyT, typename ValueT>
BroadcastChannel<KeyT, ValueT> create_broadcast_channel(ChannelSource& src_list) {
BroadcastChannel<KeyT, ValueT> broadcast_channel(&src_list);
return broadcast_channel;
template <typename KeyT, typename MsgT>
BroadcastChannel<KeyT, MsgT> create_broadcast_channel() {
auto ch = BroadcastChannel<KeyT, MsgT>();
ch.set_bin_stream_processor([&](base::BinStream* bin_stream) {
auto& local_dict = ch.get_local_dict();
while (bin_stream->size() != 0) {
KeyT key;
MsgT value;
*bin_stream >> key >> value;
local_dict[key] = value;
}
});
return ch;
}

TEST_F(TestBroadcastChannel, Create) {
Expand All @@ -64,8 +73,10 @@ TEST_F(TestBroadcastChannel, Create) {
ObjList<Obj> src_list;

// BroadcastChannel
auto broadcast_channel = create_broadcast_channel<int, int>(src_list);
auto broadcast_channel = create_broadcast_channel<int, int>();

broadcast_channel.setup(0, 0, workerinfo, &mailbox);
broadcast_channel.buffer_accessor_setup();
}

TEST_F(TestBroadcastChannel, Broadcast) {
Expand All @@ -91,34 +102,36 @@ TEST_F(TestBroadcastChannel, Broadcast) {
ObjList<Obj> src_list;

// BroadcastChannel
auto broadcast_channel = create_broadcast_channel<int, std::string>(src_list);
auto broadcast_channel = create_broadcast_channel<int, std::string>();

broadcast_channel.setup(0, 0, workerinfo, &mailbox);
broadcast_channel.buffer_accessor_setup();

// broadcast
// Round 1
broadcast_channel.broadcast(23, "abc");
broadcast_channel.broadcast(45, "bbb");
broadcast_channel.flush();
broadcast_channel.out();

broadcast_channel.prepare_broadcast();
broadcast_channel.in();
EXPECT_EQ(broadcast_channel.get(23), "abc");
EXPECT_EQ(broadcast_channel.get(45), "bbb");

// Round 2
broadcast_channel.broadcast(23, "a");
broadcast_channel.broadcast(45, "b");
broadcast_channel.flush();
broadcast_channel.out();

broadcast_channel.prepare_broadcast();
broadcast_channel.in();
EXPECT_EQ(broadcast_channel.get(23), "a");
EXPECT_EQ(broadcast_channel.get(45), "b");

// Round 3
broadcast_channel.broadcast(23, "c");
broadcast_channel.broadcast(45, "d");
broadcast_channel.flush();
broadcast_channel.out();

broadcast_channel.prepare_broadcast();
broadcast_channel.in();
EXPECT_EQ(broadcast_channel.get(23), "c");
EXPECT_EQ(broadcast_channel.get(45), "d");
}
Expand Down Expand Up @@ -146,28 +159,30 @@ TEST_F(TestBroadcastChannel, BroadcastClearDict) {
ObjList<Obj> src_list;

// BroadcastChannel
auto broadcast_channel = create_broadcast_channel<int, std::string>(src_list);
auto broadcast_channel = create_broadcast_channel<int, std::string>();

broadcast_channel.setup(0, 0, workerinfo, &mailbox);
broadcast_channel.buffer_accessor_setup();

// broadcast
// Round 1
broadcast_channel.broadcast(23, "abc");
broadcast_channel.flush();
broadcast_channel.out();

broadcast_channel.prepare_broadcast();
broadcast_channel.in();
EXPECT_EQ(broadcast_channel.get(23), "abc");

// Round 2
broadcast_channel.flush();
broadcast_channel.prepare_broadcast();
broadcast_channel.out();
broadcast_channel.in();
EXPECT_EQ(broadcast_channel.get(23), "abc"); // Last round result remain valid

// set clear dict
broadcast_channel.set_clear_dict(true);

// Round 3
broadcast_channel.flush();
broadcast_channel.prepare_broadcast();
broadcast_channel.out();
broadcast_channel.in();
EXPECT_EQ(broadcast_channel.find(23), false); // Last round result is invalid
}

Expand Down Expand Up @@ -201,33 +216,35 @@ TEST_F(TestBroadcastChannel, MultiThread) {
// ObjList Setup
ObjList<Obj> src_list;

// BroacastChannel
auto broadcast_channel = create_broadcast_channel<int, std::string>(src_list);
// BroadcastChannel
auto broadcast_channel = create_broadcast_channel<int, std::string>();
broadcast_channel.setup(0, 0, workerinfo, &mailbox_0);
broadcast_channel.buffer_accessor_setup();

// broadcast
// Round 1
broadcast_channel.broadcast(23, "abc");
broadcast_channel.flush();
broadcast_channel.out();

broadcast_channel.prepare_broadcast();
broadcast_channel.in();
EXPECT_EQ(broadcast_channel.get(23), "abc");
EXPECT_EQ(broadcast_channel.get(12), "ddd");
});
std::thread th2 = std::thread([&]() {
// ObjList Setup
ObjList<Obj> src_list;

// BroacastChannel
auto broadcast_channel = create_broadcast_channel<int, std::string>(src_list);
// BroadcastChannel
auto broadcast_channel = create_broadcast_channel<int, std::string>();
broadcast_channel.setup(1, 1, workerinfo, &mailbox_1);
broadcast_channel.buffer_accessor_setup();

// broadcast
// Round 1
broadcast_channel.broadcast(12, "ddd");
broadcast_channel.flush();
broadcast_channel.out();

broadcast_channel.prepare_broadcast();
broadcast_channel.in();
EXPECT_EQ(broadcast_channel.get(23), "abc");
EXPECT_EQ(broadcast_channel.get(12), "ddd");
});
Expand Down
13 changes: 8 additions & 5 deletions core/channel/channel_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@ namespace husky {

thread_local int ChannelBase::max_channel_id_ = 0;

ChannelBase::ChannelBase() : channel_id_(max_channel_id_), progress_(0) {
max_channel_id_ += 1;
}
ChannelBase::ChannelBase() : channel_id_(max_channel_id_), progress_(0) { max_channel_id_ += 1; }

void ChannelBase::inc_progress() {
progress_ += 1;
void ChannelBase::setup(size_t local_id, size_t global_id, const WorkerInfo& worker_info, LocalMailbox* mailbox) {
set_local_id(local_id);
set_global_id(global_id);
set_worker_info(worker_info);
set_mailbox(mailbox);
}

void ChannelBase::inc_progress() { progress_ += 1; }

} // namespace husky
21 changes: 11 additions & 10 deletions core/channel/channel_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class ChannelBase {
void set_worker_info(const WorkerInfo& worker_info) { worker_info_.reset(new WorkerInfo(worker_info)); }
void set_mailbox(LocalMailbox* mailbox) { mailbox_ = mailbox; }

// Setup API for unit test
void setup(size_t local_id, size_t global_id, const WorkerInfo& worker_info, LocalMailbox* mailbox);

// Top-level APIs

virtual void in() {
Expand All @@ -60,30 +63,28 @@ class ChannelBase {

virtual void recv() {
// A simple default synchronous implementation
if(mailbox_ == nullptr)
if (mailbox_ == nullptr)
throw base::HuskyException("Local mailbox not set, and thus cannot use the recv() method.");

while(mailbox_->poll(channel_id_, progress_)) {
while (mailbox_->poll(channel_id_, progress_)) {
base::BinStream bin_stream = mailbox_->recv(channel_id_, progress_);
if(bin_stream_processor_ != nullptr)
if (bin_stream_processor_ != nullptr)
bin_stream_processor_(&bin_stream);
}
};

virtual void post_recv() {};
virtual void pre_send() {};
virtual void send() {};
virtual void post_send() {};
virtual void post_recv(){};
virtual void pre_send(){};
virtual void send(){};
virtual void post_send(){};

// Third-level APIs (invoked by its upper level)

void set_bin_stream_processor(std::function<void(base::BinStream*)> bin_stream_processor) {
bin_stream_processor_ = bin_stream_processor;
}

std::function<void(base::BinStream*)> get_bin_stream_processor() {
return bin_stream_processor_;
}
std::function<void(base::BinStream*)> get_bin_stream_processor() { return bin_stream_processor_; }

void inc_progress();

Expand Down
Loading

0 comments on commit 0dce7ef

Please sign in to comment.