diff --git a/test/cpp/BUILD b/test/cpp/BUILD index 2e796f516ee..93bf02895ff 100644 --- a/test/cpp/BUILD +++ b/test/cpp/BUILD @@ -75,10 +75,10 @@ ptxla_cc_test( srcs = ["test_replication.cpp"], deps = [ ":cpp_test_util", + ":thread_pool", ":torch_xla_test", "//torch_xla/csrc/runtime:runtime", "//torch_xla/csrc/runtime:debug_macros", - "//torch_xla/csrc/runtime:thread_pool", "//torch_xla/csrc:tensor", "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest_main", diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp index 3f88df6e3f5..af38ecec1c0 100644 --- a/test/cpp/test_replication.cpp +++ b/test/cpp/test_replication.cpp @@ -10,7 +10,7 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/runtime.h" -#include "torch_xla/csrc/runtime/thread_pool.h" +#include "torch_xla/csrc/thread_pool.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/torch_util.h" #include "xla/client/xla_builder.h" @@ -70,7 +70,7 @@ void TestSingleReplication( device_strings[i], exec_options); mwait.DecrementCount(); }; - torch_xla::runtime::Schedule(std::move(executor)); + torch_xla::thread::Schedule(std::move(executor)); } mwait.Wait(); diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 635da87fc9b..b18014ab2df 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -271,7 +271,6 @@ ptxla_cc_library( "//torch_xla/csrc/runtime:metrics_reader", "//torch_xla/csrc/runtime:profiler", "//torch_xla/csrc/runtime:sys_util", - "//torch_xla/csrc/runtime:thread_pool", "//torch_xla/csrc/runtime:util", "//torch_xla/csrc/runtime:xla_coordinator", "//torch_xla/csrc/runtime:xla_util", @@ -320,6 +319,16 @@ cc_library( ], ) +cc_library( + name = "thread_pool", + srcs = ["thread_pool.cc"], + hdrs = ["thread_pool.h"], + deps = [ + "//torch_xla/csrc/runtime:sys_util", + "@tsl//tsl/platform:env" + ], +) + ptxla_cc_library( name = "unwrap_data", srcs = ["unwrap_data.cpp"], diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 0ba4fd1b297..a1ae74bdef0 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -47,7 +47,6 @@ #include "torch_xla/csrc/runtime/profiler.h" #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/sys_util.h" -#include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/runtime/util.h" #include "torch_xla/csrc/runtime/xla_coordinator.h" #include "torch_xla/csrc/runtime/xla_util.h" diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 395dd0433c5..d705ea0bdc5 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -86,8 +86,8 @@ cc_library( ":stablehlo_helper", ":tensor_source", ":tf_logging", - ":thread_pool", ":xla_coordinator", + "//torch_xla/csrc:thread_pool", "@xla//xla:literal", "@xla//xla:shape_util", "@xla//xla/client:xla_computation", @@ -270,17 +270,6 @@ cc_library( ], ) -cc_library( - name = "thread_pool", - srcs = ["thread_pool.cc"], - hdrs = ["thread_pool.h"], - deps = [ - ":metrics", - ":tf_logging", - "@tsl//tsl/platform:env" - ], -) - cc_library( name = "tensor_source", hdrs = ["tensor_source.h"], diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 6712394ccfb..a189fff44f4 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -8,6 +8,7 @@ #include "absl/synchronization/blocking_counter.h" #include "absl/types/span.h" #include "pjrt_computation_client.h" +#include "torch_xla/csrc/thread_pool.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/env_vars.h" @@ -15,7 +16,6 @@ #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/tensor_source.h" #include "torch_xla/csrc/runtime/tf_logging.h" -#include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/runtime/xla_coordinator.h" #include "tsl/profiler/lib/traceme.h" #include "xla/client/xla_builder.h" @@ -619,7 +619,7 @@ PjRtComputationClient::ExecuteComputation( } CreateDataHandlesCounter()->AddValue(datas.size()); - Schedule(std::move([&, this, device, + thread::Schedule(std::move([&, this, device, returned_future = std::move(*returned_future), timed]() mutable { TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " @@ -689,7 +689,7 @@ PjRtComputationClient::ExecuteReplicated( argument_handles[i] = std::move(buffers); mwait.DecrementCount(); }; - Schedule(std::move(buffer_converter)); + thread::Schedule(std::move(buffer_converter)); } mwait.Wait(); } @@ -746,7 +746,7 @@ PjRtComputationClient::ExecuteReplicated( } } - Schedule(std::move([&, this, returned_futures = std::move(*returned_futures), + thread::Schedule(std::move([&, this, returned_futures = std::move(*returned_futures), timed]() mutable { // Grab the shared lock and block the `WaitDeviceOps` until buffer is // ready. Since this is the SPMD code path. There is no points to grab diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 410b9eace5b..aa10417bf28 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -38,7 +38,6 @@ #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/pjrt_computation_client.h" #include "torch_xla/csrc/runtime/sys_util.h" -#include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/torch_util.h" diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index e598e4c0b3e..f59c4579672 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -22,7 +22,7 @@ #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/tf_logging.h" -#include "torch_xla/csrc/runtime/thread_pool.h" +#include "torch_xla/csrc/thread_pool.h" #include "torch_xla/csrc/runtime/util.h" #include "torch_xla/csrc/torch_util.h" #include "torch_xla/csrc/xla_backend_impl.h" @@ -373,7 +373,7 @@ void CopyTensors(const void* src_buffer, const xla::Shape& src_shape, dest_data, dest_strides, iter_dims, parts[i]); mwait.DecrementCount(); }; - runtime::Schedule(std::move(copy_fn)); + thread::Schedule(std::move(copy_fn)); } mwait.Wait(); } diff --git a/torch_xla/csrc/runtime/thread_pool.cc b/torch_xla/csrc/thread_pool.cc similarity index 54% rename from torch_xla/csrc/runtime/thread_pool.cc rename to torch_xla/csrc/thread_pool.cc index a51916dd2a9..e440afce7bd 100644 --- a/torch_xla/csrc/runtime/thread_pool.cc +++ b/torch_xla/csrc/thread_pool.cc @@ -1,25 +1,21 @@ -#include "torch_xla/csrc/runtime/thread_pool.h" +#include "torch_xla/csrc/thread_pool.h" -#include -#include -#include -#include +#include -#include "torch_xla/csrc/runtime/metrics.h" -#include "torch_xla/csrc/runtime/tf_logging.h" +#include "torch_xla/csrc/runtime/sys_util.h" #include "tsl/platform/env.h" #include "tsl/platform/threadpool.h" namespace torch_xla { -namespace runtime { +namespace thread { void Schedule(std::function fn) { - static size_t num_threads = sys_util::GetEnvInt( + static size_t num_threads = torch_xla::runtime::sys_util::GetEnvInt( "XLA_THREAD_POOL_SIZE", std::thread::hardware_concurrency()); static tsl::thread::ThreadPool pool(tsl::Env::Default(), "pytorchxla", num_threads); pool.Schedule(std::move(fn)); } -} // namespace runtime +} // namespace thread } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/thread_pool.h b/torch_xla/csrc/thread_pool.h similarity index 87% rename from torch_xla/csrc/runtime/thread_pool.h rename to torch_xla/csrc/thread_pool.h index 16262edbd1c..22074e6886f 100644 --- a/torch_xla/csrc/runtime/thread_pool.h +++ b/torch_xla/csrc/thread_pool.h @@ -4,13 +4,13 @@ #include namespace torch_xla { -namespace runtime { +namespace thread { // Schedules a closure to be run. The closure should not block waiting for other // events. void Schedule(std::function fn); -} // namespace runtime +} // namespace thread } // namespace torch_xla #endif // XLA_CLIENT_THREAD_POOL_H_ diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 9f29f36ef6b..1ab8a577130 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -48,7 +48,7 @@ #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/sys_util.h" -#include "torch_xla/csrc/runtime/thread_pool.h" +#include "torch_xla/csrc/thread_pool.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/shape_helper.h" #include "torch_xla/csrc/tensor_util.h" @@ -757,7 +757,7 @@ XLAGraphExecutor::ExecuteComputationWithBarrier( } }; - runtime::Schedule(async->mwait.Completer(std::move(syncfn))); + thread::Schedule(async->mwait.Completer(std::move(syncfn))); return placeholders; } @@ -1029,7 +1029,7 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( } }; - runtime::Schedule(async->mwait.Completer(std::move(syncfn))); + thread::Schedule(async->mwait.Completer(std::move(syncfn))); return async; } diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 6d1ec6853a8..b3317b51df3 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -15,7 +15,7 @@ #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/runtime.h" -#include "torch_xla/csrc/runtime/thread_pool.h" +#include "torch_xla/csrc/thread_pool.h" #include "torch_xla/csrc/tensor.h" #include "torch_xla/csrc/tensor_methods.h" #include "torch_xla/csrc/tensor_util.h" @@ -341,7 +341,7 @@ ShardingUtil::InputHandler( } mwait.DecrementCount(); }; - runtime::Schedule(std::move(argument_setter)); + thread::Schedule(std::move(argument_setter)); } mwait.Wait(); return arguments_by_device;