Skip to content

Commit

Permalink
Use TSL threadpool
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Nov 14, 2023
1 parent 3b34cb2 commit bf4ff6f
Show file tree
Hide file tree
Showing 8 changed files with 15 additions and 196 deletions.
2 changes: 1 addition & 1 deletion test/cpp/test_replication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void TestSingleReplication(
tensors_data[i])},
device_strings[i], exec_options);
};
torch_xla::runtime::env::ScheduleIoClosure(
torch_xla::runtime::Schedule(
mwait.Completer(std::move(executor)));
}
mwait.Wait();
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ cc_library(
deps = [
":metrics",
":tf_logging",
"@tsl//tsl/platform:env"
],
)

Expand Down
6 changes: 3 additions & 3 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ PjRtComputationClient::ExecuteComputation(
});
};

env::ScheduleIoClosure(util::MultiWait::Completer(mwait, std::move(lockfn)));
Schedule(util::MultiWait::Completer(mwait, std::move(lockfn)));

TF_VLOG(1) << "Returning " << datas.size() << " results";
return datas;
Expand Down Expand Up @@ -690,7 +690,7 @@ PjRtComputationClient::ExecuteReplicated(
}
argument_handles[i] = std::move(buffers);
};
env::ScheduleIoClosure(util::MultiWait::Completer(
Schedule(util::MultiWait::Completer(
mwait_argument, std::move(buffer_converter)));
}
mwait_argument->Wait();
Expand Down Expand Up @@ -772,7 +772,7 @@ PjRtComputationClient::ExecuteReplicated(
TF_VLOG(3) << "ExecuteReplicated returned_future->OnReady finished";
});
};
env::ScheduleIoClosure(util::MultiWait::Completer(mwait, std::move(lockfn)));
Schedule(util::MultiWait::Completer(mwait, std::move(lockfn)));

TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results "
<< "with dimensions [" << absl::StrJoin(dims, ",") << "].";
Expand Down
169 changes: 5 additions & 164 deletions torch_xla/csrc/runtime/thread_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,177 +7,18 @@

#include "torch_xla/csrc/runtime/metrics.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "tsl/platform/env.h"
#include "tsl/platform/threadpool.h"

namespace torch_xla {
namespace runtime {
namespace env {
namespace {

class ThreadPool {
public:
explicit ThreadPool(size_t num_threads) {
threads_.reserve(num_threads);
for (size_t i = 0; i < num_threads; ++i) {
threads_.emplace_back([this]() { Worker(); });
}
}

~ThreadPool() {
{
std::lock_guard<std::mutex> lock(mutex_);
exiting_ = true;
cv_.notify_all();
}
for (auto& thread : threads_) {
thread.join();
}
}

void Schedule(std::function<void()> closure) {
// If we have more work scheduled than waiting worker threads, just schedule
// it on a separate thread. This prevents tricky thread-pool-size-deadlocks
// caused by an undersized thread pool and closures that end up doing sync
// waits on the pool threads.
bool scheduled = false;
{
std::lock_guard<std::mutex> lock(mutex_);
if (work_.size() < waiting_) {
work_.emplace_back(std::move(closure));
scheduled = true;
}
}
if (scheduled) {
cv_.notify_one();
} else {
ScheduleOnThread(std::move(closure));
}
}

private:
void Worker() {
while (true) {
std::function<void()> closure = GetWork();
if (closure == nullptr) {
break;
}
try {
closure();
} catch (const std::exception& ex) {
XLA_COUNTER("ThreadPoolException", 1);
TF_LOG(ERROR) << "Exception from running thread pool closure: "
<< ex.what();
}
}
}

void ScheduleOnThread(std::function<void()> closure) {
std::thread thread(std::move(closure));
thread.detach();
}

std::function<void()> GetWork() {
std::unique_lock<std::mutex> lock(mutex_);
++waiting_;
cv_.wait(lock, [this] { return exiting_ || !work_.empty(); });
--waiting_;
if (work_.empty()) {
return nullptr;
}
std::function<void()> closure(std::move(work_.front()));
work_.pop_front();
return closure;
}

std::vector<std::thread> threads_;
std::mutex mutex_;
std::condition_variable cv_;
bool exiting_ = false;
std::deque<std::function<void()>> work_;
size_t waiting_ = 0;
};

ThreadPool* GetThreadPool() {
void Schedule(std::function<void()> fn) {
static size_t num_threads = sys_util::GetEnvInt(
"XLA_THREAD_POOL_SIZE", std::thread::hardware_concurrency());
static ThreadPool* pool = new ThreadPool(num_threads);
return pool;
}

ThreadPool* GetIoThreadPool() {
static size_t num_threads = sys_util::GetEnvInt(
"XLA_IO_THREAD_POOL_SIZE", std::thread::hardware_concurrency());
static ThreadPool* pool = new ThreadPool(num_threads);
return pool;
}

} // namespace

class Completion::Data {
public:
void Wait() {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return completed_; });
if (exptr_ != nullptr) {
std::rethrow_exception(exptr_);
}
}

static std::function<void()> GetCompleter(std::shared_ptr<Data> data,
std::function<void()> closure) {
auto closure_wrapper = [closure = std::move(closure), data]() {
std::exception_ptr exptr;
try {
closure();
} catch (...) {
exptr = std::current_exception();
}
data->Complete(exptr);
};
return closure_wrapper;
}

private:
void Complete(std::exception_ptr exptr) {
std::lock_guard<std::mutex> lock(mutex_);
exptr_ = std::move(exptr);
completed_ = true;
cv_.notify_all();
}

std::mutex mutex_;
std::condition_variable cv_;
bool completed_ = false;
std::exception_ptr exptr_;
};

Completion::Completion(std::shared_ptr<Data> data) : data_(std::move(data)) {}

Completion::~Completion() {}

void Completion::Wait() { data_->Wait(); }

void ScheduleClosure(std::function<void()> closure) {
GetThreadPool()->Schedule(std::move(closure));
}

void ScheduleIoClosure(std::function<void()> closure) {
GetIoThreadPool()->Schedule(std::move(closure));
}

Completion ScheduleClosureWithCompletion(std::function<void()> closure) {
auto data = std::make_shared<Completion::Data>();
GetThreadPool()->Schedule(
Completion::Data::GetCompleter(data, std::move(closure)));
return Completion(std::move(data));
}

Completion ScheduleIoClosureWithCompletion(std::function<void()> closure) {
auto data = std::make_shared<Completion::Data>();
GetIoThreadPool()->Schedule(
Completion::Data::GetCompleter(data, std::move(closure)));
return Completion(std::move(data));
static tsl::thread::ThreadPool pool(tsl::Env::Default(), "pytorchxla", num_threads);
pool.Schedule(std::move(fn));
}

} // namespace env
} // namespace runtime
} // namespace torch_xla
25 changes: 1 addition & 24 deletions torch_xla/csrc/runtime/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,14 @@
#define XLA_CLIENT_THREAD_POOL_H_

#include <functional>
#include <memory>
#include <thread>

namespace torch_xla {
namespace runtime {
namespace env {

class Completion {
public:
class Data;

explicit Completion(std::shared_ptr<Data> data);

~Completion();

void Wait();

private:
std::shared_ptr<Data> data_;
};

// Schedules a closure to be run. The closure should not block waiting for other
// events.
void ScheduleClosure(std::function<void()> closure);
Completion ScheduleClosureWithCompletion(std::function<void()> closure);

// Schedules a closure which might wait for IO or other events/conditions.
void ScheduleIoClosure(std::function<void()> closure);
Completion ScheduleIoClosureWithCompletion(std::function<void()> closure);
void Schedule(std::function<void()> fn);

} // namespace env
} // namespace runtime
} // namespace torch_xla

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ void CopyTensors(const void* src_buffer, const xla::Shape& src_shape,
SlicedCopy<SType, DType>(dest_shape.dimensions(), src_data, src_strides,
dest_data, dest_strides, iter_dims, parts[i]);
};
runtime::env::ScheduleClosure(
runtime::Schedule(
runtime::util::MultiWait::Completer(mwait, std::move(copy_fn)));
}
mwait->Wait();
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
}
};

runtime::env::ScheduleIoClosure(async->mwait.Completer(std::move(syncfn)));
runtime::Schedule(async->mwait.Completer(std::move(syncfn)));

return placeholders;
}
Expand Down Expand Up @@ -1029,7 +1029,7 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph(
}
};

runtime::env::ScheduleIoClosure(async->mwait.Completer(std::move(syncfn)));
runtime::Schedule(async->mwait.Completer(std::move(syncfn)));
return async;
}

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ ShardingUtil::InputHandler(
arguments_by_device[device_i][argument_i] = shard;
}
};
runtime::env::ScheduleIoClosure(
runtime::Schedule(
runtime::util::MultiWait::Completer(mwait, std::move(argument_setter)));
}
mwait->Wait();
Expand Down

0 comments on commit bf4ff6f

Please sign in to comment.