diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index d4eed1ec83f..8a38c87cd00 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -118,10 +118,8 @@ ptxla_cc_library( ":layout_manager", ":shape_builder", ":shape_helper", - "//torch_xla/csrc/runtime:async_task", "//torch_xla/csrc/runtime", "//torch_xla/csrc/runtime:stablehlo_helper", - "//torch_xla/csrc/runtime:unique", "//torch_xla/csrc/runtime:xla_util", "@com_google_absl//absl/hash", "@com_google_absl//absl/memory", diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp index 7723d6d95d9..9959d46f8a2 100644 --- a/torch_xla/csrc/debug_util.cpp +++ b/torch_xla/csrc/debug_util.cpp @@ -1,6 +1,7 @@ #include "torch_xla/csrc/debug_util.h" #include +#include #include #include @@ -17,7 +18,6 @@ #include "torch_xla/csrc/ir_dump_util.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" -#include "torch_xla/csrc/runtime/unique.h" #include "torch_xla/csrc/xla_graph_executor.h" namespace torch_xla { @@ -61,7 +61,7 @@ std::string DebugUtil::GetTensorsGraphHlo( absl::Span tensors, const std::vector* indices, bool dump_stablehlo) { std::vector root_values; - runtime::util::Unique unique_device; + torch::lazy::Unique unique_device; if (indices != nullptr) { for (auto index : *indices) { const XLATensorPtr& tensor = tensors[index]; @@ -91,7 +91,7 @@ std::string DebugUtil::GetTensorsGraphInfo( std::vector root_nodes; std::vector root_values; std::vector root_hashes; - runtime::util::Unique unique_device; + torch::lazy::Unique unique_device; if (indices != nullptr) { for (auto index : *indices) { const XLATensorPtr& tensor = tensors[index]; diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index b19dc0e717d..85e1e1557a6 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -273,15 +273,6 @@ cc_library( ], ) -cc_library( - name = "unique", - hdrs = ["unique.h"], - deps = [ - ":debug_macros", - "@com_google_absl//absl/types:optional", - ], -) - cc_library( name = "util", hdrs = ["util.h"], diff --git a/torch_xla/csrc/runtime/unique.h b/torch_xla/csrc/runtime/unique.h deleted file mode 100644 index f50e24320d9..00000000000 --- a/torch_xla/csrc/runtime/unique.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef XLA_CLIENT_UNIQUE_H_ -#define XLA_CLIENT_UNIQUE_H_ - -#include -#include - -#include "absl/types/optional.h" -#include "torch_xla/csrc/runtime/debug_macros.h" - -namespace torch_xla { -namespace runtime { -namespace util { - -// Helper class to allow tracking zero or more things, which should be forcibly -// be one only thing. -template > -class Unique { - public: - std::pair set(const T& value) { - if (value_) { - XLA_CHECK(C()(*value_, value)) - << "'" << *value_ << "' vs '" << value << "'"; - return std::pair(false, *value_); - } - value_ = value; - return std::pair(true, *value_); - } - - operator bool() const { return value_.has_value(); } - operator const T&() const { return *value_; } - const T& operator*() const { return *value_; } - const T* operator->() const { return value_.operator->(); } - - std::set AsSet() const { - std::set vset; - if (value_.has_value()) { - vset.insert(*value_); - } - return vset; - } - - private: - absl::optional value_; -}; - -} // namespace util -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_UNIQUE_H_ diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index e14b11882a7..a5dc91d27ce 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -38,7 +38,6 @@ #include "torch_xla/csrc/runtime/pjrt_computation_client.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/thread_pool.h" -#include "torch_xla/csrc/runtime/unique.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/torch_util.h" diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 39d866358ac..9337f779b4f 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -47,7 +48,6 @@ #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/thread_pool.h" -#include "torch_xla/csrc/runtime/unique.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/shape_helper.h" #include "torch_xla/csrc/tensor_util.h" @@ -534,7 +534,7 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors( const std::vector& tensors, const SyncTensorsConfig& config) { tsl::profiler::TraceMe activity("CollectSyncTensors", tsl::profiler::TraceMeLevel::kInfo); - runtime::util::Unique unique_device; + torch::lazy::Unique unique_device; for (size_t i = 0; i < tensors.size(); ++i) { unique_device.set(tensors[i]->GetDevice()); }