Skip to content

Commit

Permalink
Move threadpool namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Nov 14, 2023
1 parent 48e6333 commit b688515
Show file tree
Hide file tree
Showing 12 changed files with 33 additions and 41 deletions.
2 changes: 1 addition & 1 deletion test/cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ ptxla_cc_test(
srcs = ["test_replication.cpp"],
deps = [
":cpp_test_util",
":thread_pool",
":torch_xla_test",
"//torch_xla/csrc/runtime:runtime",
"//torch_xla/csrc/runtime:debug_macros",
"//torch_xla/csrc/runtime:thread_pool",
"//torch_xla/csrc:tensor",
"@com_google_absl//absl/synchronization",
"@com_google_googletest//:gtest_main",
Expand Down
4 changes: 2 additions & 2 deletions test/cpp/test_replication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/thread_pool.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/torch_util.h"
#include "xla/client/xla_builder.h"
Expand Down Expand Up @@ -70,7 +70,7 @@ void TestSingleReplication(
device_strings[i], exec_options);
mwait.DecrementCount();
};
-torch_xla::runtime::Schedule(std::move(executor));
+torch_xla::thread::Schedule(std::move(executor));
}
mwait.Wait();

Expand Down
11 changes: 10 additions & 1 deletion torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,6 @@ ptxla_cc_library(
"//torch_xla/csrc/runtime:metrics_reader",
"//torch_xla/csrc/runtime:profiler",
"//torch_xla/csrc/runtime:sys_util",
"//torch_xla/csrc/runtime:thread_pool",
"//torch_xla/csrc/runtime:util",
"//torch_xla/csrc/runtime:xla_coordinator",
"//torch_xla/csrc/runtime:xla_util",
Expand Down Expand Up @@ -320,6 +319,16 @@ cc_library(
],
)

+cc_library(
+name = "thread_pool",
+srcs = ["thread_pool.cc"],
+hdrs = ["thread_pool.h"],
+deps = [
+"//torch_xla/csrc/runtime:sys_util",
+"@tsl//tsl/platform:env"
+],
+)

ptxla_cc_library(
name = "unwrap_data",
srcs = ["unwrap_data.cpp"],
Expand Down
1 change: 0 additions & 1 deletion torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
#include "torch_xla/csrc/runtime/profiler.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "torch_xla/csrc/runtime/xla_util.h"
Expand Down
13 changes: 1 addition & 12 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ cc_library(
":stablehlo_helper",
":tensor_source",
":tf_logging",
":thread_pool",
":xla_coordinator",
"//torch_xla/csrc:thread_pool",
"@xla//xla:literal",
"@xla//xla:shape_util",
"@xla//xla/client:xla_computation",
Expand Down Expand Up @@ -270,17 +270,6 @@ cc_library(
],
)

-cc_library(
-name = "thread_pool",
-srcs = ["thread_pool.cc"],
-hdrs = ["thread_pool.h"],
-deps = [
-":metrics",
-":tf_logging",
-"@tsl//tsl/platform:env"
-],
-)

cc_library(
name = "tensor_source",
hdrs = ["tensor_source.h"],
Expand Down
8 changes: 4 additions & 4 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
#include "absl/synchronization/blocking_counter.h"
#include "absl/types/span.h"
#include "pjrt_computation_client.h"
#include "torch_xla/csrc/thread_pool.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/profiler.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "tsl/profiler/lib/traceme.h"
#include "xla/client/xla_builder.h"
Expand Down Expand Up @@ -619,7 +619,7 @@ PjRtComputationClient::ExecuteComputation(
}
CreateDataHandlesCounter()->AddValue(datas.size());

-Schedule(std::move([&, this, device,
+thread::Schedule(std::move([&, this, device,
returned_future = std::move(*returned_future),
timed]() mutable {
TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for "
Expand Down Expand Up @@ -689,7 +689,7 @@ PjRtComputationClient::ExecuteReplicated(
argument_handles[i] = std::move(buffers);
mwait.DecrementCount();
};
-Schedule(std::move(buffer_converter));
+thread::Schedule(std::move(buffer_converter));
}
mwait.Wait();
}
Expand Down Expand Up @@ -746,7 +746,7 @@ PjRtComputationClient::ExecuteReplicated(
}
}

-Schedule(std::move([&, this, returned_futures = std::move(*returned_futures),
+thread::Schedule(std::move([&, this, returned_futures = std::move(*returned_futures),
timed]() mutable {
// Grab the shared lock and block the `WaitDeviceOps` until buffer is
// ready. Since this is the SPMD code path. There is no points to grab
Expand Down
1 change: 0 additions & 1 deletion torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/runtime/xla_util.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/torch_util.h"
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/tensor_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/thread_pool.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/torch_util.h"
#include "torch_xla/csrc/xla_backend_impl.h"
Expand Down Expand Up @@ -373,7 +373,7 @@ void CopyTensors(const void* src_buffer, const xla::Shape& src_shape,
dest_data, dest_strides, iter_dims, parts[i]);
mwait.DecrementCount();
};
-runtime::Schedule(std::move(copy_fn));
+thread::Schedule(std::move(copy_fn));
}
mwait.Wait();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/thread_pool.h"

-#include <condition_variable>
-#include <deque>
-#include <exception>
-#include <mutex>
+#include <functional>

-#include "torch_xla/csrc/runtime/metrics.h"
-#include "torch_xla/csrc/runtime/tf_logging.h"
+#include "torch_xla/csrc/runtime/sys_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/threadpool.h"

namespace torch_xla {
-namespace runtime {
+namespace thread {

void Schedule(std::function<void()> fn) {
-static size_t num_threads = sys_util::GetEnvInt(
+static size_t num_threads = torch_xla::runtime::sys_util::GetEnvInt(
"XLA_THREAD_POOL_SIZE", std::thread::hardware_concurrency());
static tsl::thread::ThreadPool pool(tsl::Env::Default(), "pytorchxla",
num_threads);
pool.Schedule(std::move(fn));
}

-} // namespace runtime
+} // namespace thread
} // namespace torch_xla
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
#include <functional>

namespace torch_xla {
-namespace runtime {
+namespace thread {

// Schedules a closure to be run. The closure should not block waiting for other
// events.
void Schedule(std::function<void()> fn);

-} // namespace runtime
+} // namespace thread
} // namespace torch_xla

#endif // XLA_CLIENT_THREAD_POOL_H_
6 changes: 3 additions & 3 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/thread_pool.h"
#include "torch_xla/csrc/runtime/xla_util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/tensor_util.h"
Expand Down Expand Up @@ -757,7 +757,7 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
}
};

-runtime::Schedule(async->mwait.Completer(std::move(syncfn)));
+thread::Schedule(async->mwait.Completer(std::move(syncfn)));

return placeholders;
}
Expand Down Expand Up @@ -1029,7 +1029,7 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph(
}
};

-runtime::Schedule(async->mwait.Completer(std::move(syncfn)));
+thread::Schedule(async->mwait.Completer(std::move(syncfn)));
return async;
}

Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/thread_pool.h"
#include "torch_xla/csrc/tensor.h"
#include "torch_xla/csrc/tensor_methods.h"
#include "torch_xla/csrc/tensor_util.h"
Expand Down Expand Up @@ -341,7 +341,7 @@ ShardingUtil::InputHandler(
}
mwait.DecrementCount();
};
-runtime::Schedule(std::move(argument_setter));
+thread::Schedule(std::move(argument_setter));
}
mwait.Wait();
return arguments_by_device;
Expand Down

0 comments on commit b688515

Please sign in to comment.