Skip to content

Commit

Permalink
[Core][Channel] Refactor migrate and broadcast channel (#3)
Browse files Browse the repository at this point in the history
* [Core][Channel] refactor migrate channel

* [Core][Channel] Refactor BroadcastChannel
Add unit tests for broadcast channel and migrate channel
ref husky-team#228
  • Loading branch information
ClydeZhao authored and ddmbr committed Feb 8, 2017
1 parent 0b24cf2 commit 16a1f82
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 170 deletions.
17 changes: 7 additions & 10 deletions core/accessor_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,15 @@ class AccessorStore {
template <typename CollectT>
static std::vector<Accessor<CollectT>>* create_accessor(size_t channel_id, size_t local_id,
size_t num_local_threads) {
// double-checked locking
std::lock_guard<std::mutex> lock(accessors_map_mutex);
if (accessors_map.find(channel_id) == accessors_map.end()) {
std::lock_guard<std::mutex> lock(accessors_map_mutex);
if (accessors_map.find(channel_id) == accessors_map.end()) {
AccessorSet<CollectT>* accessor_set = new AccessorSet<CollectT>();
accessor_set->data.resize(num_local_threads);
for (auto& i : accessor_set->data) {
i.init(num_local_threads);
}
AccessorStore::num_local_threads.insert(std::make_pair(channel_id, num_local_threads));
accessors_map.insert(std::make_pair(channel_id, accessor_set));
AccessorSet<CollectT>* accessor_set = new AccessorSet<CollectT>();
accessor_set->data.resize(num_local_threads);
for (auto& i : accessor_set->data) {
i.init(num_local_threads);
}
AccessorStore::num_local_threads.insert({channel_id, num_local_threads});
accessors_map.insert({channel_id, accessor_set});
}
auto& data = dynamic_cast<AccessorSet<CollectT>*>(accessors_map[channel_id])->data;
// data[local_id].init();
Expand Down
39 changes: 8 additions & 31 deletions core/channel/broadcast_channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,13 @@ using base::BinStream;
template <typename KeyT, typename ValueT>
class BroadcastChannel : public ChannelBase {
public:
explicit BroadcastChannel(ChannelSource* src) : src_ptr_(src) {
// TODO(yuzhen): Should be careful, maybe need to deregister every time?
src_ptr_->register_outchannel(channel_id_, this);
}
BroadcastChannel() = default;

~BroadcastChannel() override {
// Make sure to invoke inc_progress_ before destructor
if (need_leave_accessor_)
leave_accessor();
AccessorStore::remove_accessor(channel_id_);
src_ptr_->deregister_outchannel(channel_id_);
}

BroadcastChannel(const BroadcastChannel&) = delete;
Expand All @@ -53,7 +49,7 @@ class BroadcastChannel : public ChannelBase {
BroadcastChannel(BroadcastChannel&&) = default;
BroadcastChannel& operator=(BroadcastChannel&&) = default;

void customized_setup() override {
void buffer_accessor_setup() {
broadcast_buffer_.resize(worker_info_->get_largest_tid() + 1);
accessor_ = AccessorStore::create_accessor<std::unordered_map<KeyT, ValueT>>(
channel_id_, local_id_, worker_info_->get_num_local_workers());
Expand Down Expand Up @@ -92,17 +88,9 @@ class BroadcastChannel : public ChannelBase {

void set_clear_dict(bool clear) { clear_dict_each_progress_ = clear; }

void prepare() override {}

void in(BinStream& bin) override {}

void out() override {
flush();
prepare_broadcast();
}
std::unordered_map<KeyT, ValueT>& get_local_dict() { return (*accessor_)[local_id_].storage(); }

/// This method is only useful without list_execute
void flush() {
void send() override {
this->inc_progress();
int start = global_id_;
for (int i = 0; i < broadcast_buffer_.size(); ++i) {
Expand All @@ -116,17 +104,17 @@ class BroadcastChannel : public ChannelBase {
this->worker_info_->get_pids());
}

/// This method is only useful without list_execute
void prepare_broadcast() {
void recv() override {
// Check whether need to leave accessor_ (last round's accessor_)
if (need_leave_accessor_)
leave_accessor();
need_leave_accessor_ = true;

auto& local_dict = (*accessor_)[local_id_].storage();
while (mailbox_->poll(channel_id_, progress_)) {
auto bin = mailbox_->recv(channel_id_, progress_);
process_bin(bin, local_dict);
if (bin_stream_processor_ != nullptr) {
bin_stream_processor_(&bin);
}
}
(*accessor_)[local_id_].commit();
}
Expand All @@ -144,17 +132,6 @@ class BroadcastChannel : public ChannelBase {
}
}

void process_bin(BinStream& bin, std::unordered_map<KeyT, ValueT>& local_dict) {
while (bin.size() != 0) {
KeyT key;
ValueT value;
bin >> key >> value;
local_dict[key] = value;
}
}

ChannelSource* src_ptr_;

bool clear_dict_each_progress_ = false;
bool need_leave_accessor_ = false;
std::vector<BinStream> broadcast_buffer_;
Expand Down
71 changes: 44 additions & 27 deletions core/channel/broadcast_channel_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,19 @@ class Obj {
};

// Create broadcast without setting
template <typename KeyT, typename ValueT>
BroadcastChannel<KeyT, ValueT> create_broadcast_channel(ChannelSource& src_list) {
BroadcastChannel<KeyT, ValueT> broadcast_channel(&src_list);
return broadcast_channel;
template <typename KeyT, typename MsgT>
BroadcastChannel<KeyT, MsgT> create_broadcast_channel() {
auto ch = BroadcastChannel<KeyT, MsgT>();
ch.set_bin_stream_processor([&](base::BinStream* bin_stream) {
auto& local_dict = ch.get_local_dict();
while (bin_stream->size() != 0) {
KeyT key;
MsgT value;
*bin_stream >> key >> value;
local_dict[key] = value;
}
});
return ch;
}

TEST_F(TestBroadcastChannel, Create) {
Expand All @@ -64,8 +73,10 @@ TEST_F(TestBroadcastChannel, Create) {
ObjList<Obj> src_list;

// BroadcastChannel
auto broadcast_channel = create_broadcast_channel<int, int>(src_list);
auto broadcast_channel = create_broadcast_channel<int, int>();

broadcast_channel.setup(0, 0, workerinfo, &mailbox);
broadcast_channel.buffer_accessor_setup();
}

TEST_F(TestBroadcastChannel, Broadcast) {
Expand All @@ -91,34 +102,36 @@ TEST_F(TestBroadcastChannel, Broadcast) {
ObjList<Obj> src_list;

// BroadcastChannel
auto broadcast_channel = create_broadcast_channel<int, std::string>(src_list);
auto broadcast_channel = create_broadcast_channel<int, std::string>();

broadcast_channel.setup(0, 0, workerinfo, &mailbox);
broadcast_channel.buffer_accessor_setup();

// broadcast
// Round 1
broadcast_channel.broadcast(23, "abc");
broadcast_channel.broadcast(45, "bbb");
broadcast_channel.flush();
broadcast_channel.out();

broadcast_channel.prepare_broadcast();
broadcast_channel.in();
EXPECT_EQ(broadcast_channel.get(23), "abc");
EXPECT_EQ(broadcast_channel.get(45), "bbb");

// Round 2
broadcast_channel.broadcast(23, "a");
broadcast_channel.broadcast(45, "b");
broadcast_channel.flush();
broadcast_channel.out();

broadcast_channel.prepare_broadcast();
broadcast_channel.in();
EXPECT_EQ(broadcast_channel.get(23), "a");
EXPECT_EQ(broadcast_channel.get(45), "b");

// Round 3
broadcast_channel.broadcast(23, "c");
broadcast_channel.broadcast(45, "d");
broadcast_channel.flush();
broadcast_channel.out();

broadcast_channel.prepare_broadcast();
broadcast_channel.in();
EXPECT_EQ(broadcast_channel.get(23), "c");
EXPECT_EQ(broadcast_channel.get(45), "d");
}
Expand Down Expand Up @@ -146,28 +159,30 @@ TEST_F(TestBroadcastChannel, BroadcastClearDict) {
ObjList<Obj> src_list;

// BroadcastChannel
auto broadcast_channel = create_broadcast_channel<int, std::string>(src_list);
auto broadcast_channel = create_broadcast_channel<int, std::string>();

broadcast_channel.setup(0, 0, workerinfo, &mailbox);
broadcast_channel.buffer_accessor_setup();

// broadcast
// Round 1
broadcast_channel.broadcast(23, "abc");
broadcast_channel.flush();
broadcast_channel.out();

broadcast_channel.prepare_broadcast();
broadcast_channel.in();
EXPECT_EQ(broadcast_channel.get(23), "abc");

// Round 2
broadcast_channel.flush();
broadcast_channel.prepare_broadcast();
broadcast_channel.out();
broadcast_channel.in();
EXPECT_EQ(broadcast_channel.get(23), "abc"); // Last round result remain valid

// set clear dict
broadcast_channel.set_clear_dict(true);

// Round 3
broadcast_channel.flush();
broadcast_channel.prepare_broadcast();
broadcast_channel.out();
broadcast_channel.in();
EXPECT_EQ(broadcast_channel.find(23), false); // Last round result is invalid
}

Expand Down Expand Up @@ -201,33 +216,35 @@ TEST_F(TestBroadcastChannel, MultiThread) {
// ObjList Setup
ObjList<Obj> src_list;

// BroacastChannel
auto broadcast_channel = create_broadcast_channel<int, std::string>(src_list);
// BroadcastChannel
auto broadcast_channel = create_broadcast_channel<int, std::string>();
broadcast_channel.setup(0, 0, workerinfo, &mailbox_0);
broadcast_channel.buffer_accessor_setup();

// broadcast
// Round 1
broadcast_channel.broadcast(23, "abc");
broadcast_channel.flush();
broadcast_channel.out();

broadcast_channel.prepare_broadcast();
broadcast_channel.in();
EXPECT_EQ(broadcast_channel.get(23), "abc");
EXPECT_EQ(broadcast_channel.get(12), "ddd");
});
std::thread th2 = std::thread([&]() {
// ObjList Setup
ObjList<Obj> src_list;

// BroacastChannel
auto broadcast_channel = create_broadcast_channel<int, std::string>(src_list);
// BroadcastChannel
auto broadcast_channel = create_broadcast_channel<int, std::string>();
broadcast_channel.setup(1, 1, workerinfo, &mailbox_1);
broadcast_channel.buffer_accessor_setup();

// broadcast
// Round 1
broadcast_channel.broadcast(12, "ddd");
broadcast_channel.flush();
broadcast_channel.out();

broadcast_channel.prepare_broadcast();
broadcast_channel.in();
EXPECT_EQ(broadcast_channel.get(23), "abc");
EXPECT_EQ(broadcast_channel.get(12), "ddd");
});
Expand Down
13 changes: 8 additions & 5 deletions core/channel/channel_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@ namespace husky {

thread_local int ChannelBase::max_channel_id_ = 0;

ChannelBase::ChannelBase() : channel_id_(max_channel_id_), progress_(0) {
max_channel_id_ += 1;
}
ChannelBase::ChannelBase() : channel_id_(max_channel_id_), progress_(0) { max_channel_id_ += 1; }

void ChannelBase::inc_progress() {
progress_ += 1;
void ChannelBase::setup(size_t local_id, size_t global_id, const WorkerInfo& worker_info, LocalMailbox* mailbox) {
set_local_id(local_id);
set_global_id(global_id);
set_worker_info(worker_info);
set_mailbox(mailbox);
}

void ChannelBase::inc_progress() { progress_ += 1; }

} // namespace husky
3 changes: 3 additions & 0 deletions core/channel/channel_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class ChannelBase {
virtual void set_worker_info(const WorkerInfo& worker_info) { worker_info_.reset(new WorkerInfo(worker_info)); }
void set_mailbox(LocalMailbox* mailbox) { mailbox_ = mailbox; }

// Setup API for unit test
void setup(size_t local_id, size_t global_id, const WorkerInfo& worker_info, LocalMailbox* mailbox);

// Top-level APIs

virtual void in() {
Expand Down
38 changes: 30 additions & 8 deletions core/channel/channel_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,41 @@ class ChannelStore : public ChannelStoreBase {

// Create MigrateChannel
template <typename ObjT>
static MigrateChannel<ObjT>& create_migrate_channel(ObjList<ObjT>& src_list, ObjList<ObjT>& dst_list,
const std::string& name = "") {
auto& ch = ChannelStoreBase::create_migrate_channel<ObjT>(src_list, dst_list, name);
setup(ch);
static auto* create_migrate_channel(ObjList<ObjT>* src_list, ObjList<ObjT>* dst_list,
const std::string& name = "") {
auto* ch = ChannelStoreBase::create_migrate_channel<ObjT>(*src_list, *dst_list, name);
common_setup(ch);
ch->set_obj_list(src_list);
ch->buffer_setup();
ch->set_bin_stream_processor([=](base::BinStream* bin_stream) {
while (bin_stream->size() != 0) {
ObjT obj;
*bin_stream >> obj;
auto idx = dst_list->add_object(std::move(obj));
dst_list->process_attribute(*bin_stream, idx);
}
if (dst_list->get_num_del() * 2 > dst_list->get_vector_size())
dst_list->deletion_finalize();
});
return ch;
}

// Create BroadcastChannel
template <typename KeyT, typename MsgT>
static BroadcastChannel<KeyT, MsgT>& create_broadcast_channel(ChannelSource& src_list,
const std::string& name = "") {
auto& ch = ChannelStoreBase::create_broadcast_channel<KeyT, MsgT>(src_list, name);
setup(ch);
static BroadcastChannel<KeyT, MsgT>& create_broadcast_channel(const std::string& name = "") {
auto* ch = ChannelStoreBase::create_broadcast_channel<KeyT, MsgT>(name);
common_setup(ch);
ch->buffer_accessor_setup();
auto& local_dict = ch->get_local_dict();
ch->set_bin_stream_processor([=](base::BinStream* bin_stream) {
auto& local_dict = ch->get_local_dict();
while (bin_stream->size() != 0) {
KeyT key;
MsgT value;
*bin_stream >> key >> value;
local_dict[key] = value;
}
});
return ch;
}

Expand Down
Loading

0 comments on commit 16a1f82

Please sign in to comment.