diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 8a38c87cd00..d4eed1ec83f 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -118,8 +118,10 @@ ptxla_cc_library( ":layout_manager", ":shape_builder", ":shape_helper", + "//torch_xla/csrc/runtime:async_task", "//torch_xla/csrc/runtime", "//torch_xla/csrc/runtime:stablehlo_helper", + "//torch_xla/csrc/runtime:unique", "//torch_xla/csrc/runtime:xla_util", "@com_google_absl//absl/hash", "@com_google_absl//absl/memory", diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp index e601034790b..7723d6d95d9 100644 --- a/torch_xla/csrc/debug_util.cpp +++ b/torch_xla/csrc/debug_util.cpp @@ -6,7 +6,6 @@ #include #include #include -#include #include #include @@ -18,6 +17,7 @@ #include "torch_xla/csrc/ir_dump_util.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" +#include "torch_xla/csrc/runtime/unique.h" #include "torch_xla/csrc/xla_graph_executor.h" namespace torch_xla { @@ -61,14 +61,14 @@ std::string DebugUtil::GetTensorsGraphHlo( absl::Span tensors, const std::vector* indices, bool dump_stablehlo) { std::vector root_values; - std::optional device; + runtime::util::Unique unique_device; if (indices != nullptr) { for (auto index : *indices) { const XLATensorPtr& tensor = tensors[index]; torch::lazy::Value ir_value = tensor->CurrentIrValue(); if (ir_value) { root_values.push_back(std::move(ir_value)); - device = tensor->GetDevice(); + unique_device.set(tensor->GetDevice()); } } } else { @@ -76,13 +76,13 @@ std::string DebugUtil::GetTensorsGraphHlo( torch::lazy::Value ir_value = tensor->CurrentIrValue(); if (ir_value) { root_values.push_back(std::move(ir_value)); - device = tensor->GetDevice(); + unique_device.set(tensor->GetDevice()); } } } - return DumpUtil::ToHlo(root_values, - device.value_or(bridge::GetCurrentDevice()), - EmitMode::kStableHloReadable); + return DumpUtil::ToHlo( + root_values, unique_device ? 
*unique_device : bridge::GetCurrentDevice(), + EmitMode::kStableHloReadable); } std::string DebugUtil::GetTensorsGraphInfo( @@ -91,7 +91,7 @@ std::string DebugUtil::GetTensorsGraphInfo( std::vector root_nodes; std::vector root_values; std::vector root_hashes; - std::optional device; + runtime::util::Unique unique_device; if (indices != nullptr) { for (auto index : *indices) { const XLATensorPtr& tensor = tensors[index]; @@ -100,7 +100,7 @@ std::string DebugUtil::GetTensorsGraphInfo( root_nodes.push_back(ir_value.node.get()); root_hashes.push_back(ir_value.hash()); root_values.push_back(std::move(ir_value)); - device = tensor->GetDevice(); + unique_device.set(tensor->GetDevice()); } } } else { @@ -110,7 +110,7 @@ std::string DebugUtil::GetTensorsGraphInfo( root_nodes.push_back(ir_value.node.get()); root_hashes.push_back(ir_value.hash()); root_values.push_back(std::move(ir_value)); - device = tensor->GetDevice(); + unique_device.set(tensor->GetDevice()); } } } @@ -137,12 +137,14 @@ std::string DebugUtil::GetTensorsGraphInfo( } else if (format == GraphFormat::kDot) { graph_str = DumpUtil::ToDot(root_nodes); } else if (format == GraphFormat::kHlo) { - graph_str = DumpUtil::ToHlo(root_values, - device.value_or(bridge::GetCurrentDevice())); + graph_str = DumpUtil::ToHlo(root_values, unique_device + ? *unique_device + : bridge::GetCurrentDevice()); } else if (format == GraphFormat::kStableHlo) { - graph_str = DumpUtil::ToHlo(root_values, - device.value_or(bridge::GetCurrentDevice()), - EmitMode::kStableHloReadable); + graph_str = DumpUtil::ToHlo( + root_values, + unique_device ? 
*unique_device : bridge::GetCurrentDevice(), + EmitMode::kStableHloReadable); } else { XLA_ERROR() << "Invalid graph format: " << format; } diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 85e1e1557a6..b19dc0e717d 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -273,6 +273,15 @@ cc_library( ], ) +cc_library( + name = "unique", + hdrs = ["unique.h"], + deps = [ + ":debug_macros", + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "util", hdrs = ["util.h"], diff --git a/torch_xla/csrc/runtime/unique.h b/torch_xla/csrc/runtime/unique.h new file mode 100644 index 00000000000..f50e24320d9 --- /dev/null +++ b/torch_xla/csrc/runtime/unique.h @@ -0,0 +1,50 @@ +#ifndef XLA_CLIENT_UNIQUE_H_ +#define XLA_CLIENT_UNIQUE_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "torch_xla/csrc/runtime/debug_macros.h" + +namespace torch_xla { +namespace runtime { +namespace util { + +// Helper class to allow tracking zero or more things, which are forced to be +// one and the same thing. 
+template > +class Unique { + public: + std::pair set(const T& value) { + if (value_) { + XLA_CHECK(C()(*value_, value)) + << "'" << *value_ << "' vs '" << value << "'"; + return std::pair(false, *value_); + } + value_ = value; + return std::pair(true, *value_); + } + + operator bool() const { return value_.has_value(); } + operator const T&() const { return *value_; } + const T& operator*() const { return *value_; } + const T* operator->() const { return value_.operator->(); } + + std::set AsSet() const { + std::set vset; + if (value_.has_value()) { + vset.insert(*value_); + } + return vset; + } + + private: + absl::optional value_; +}; + +} // namespace util +} // namespace runtime +} // namespace torch_xla + +#endif // XLA_CLIENT_UNIQUE_H_ diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 85378de0f8b..e14b11882a7 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -15,7 +15,6 @@ #include #include #include -#include #include #include #include @@ -39,6 +38,7 @@ #include "torch_xla/csrc/runtime/pjrt_computation_client.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/thread_pool.h" +#include "torch_xla/csrc/runtime/unique.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/torch_util.h" diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 659dbfa8834..39d866358ac 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -17,7 +17,6 @@ #include #include #include -#include #include #include #include @@ -48,6 +47,7 @@ #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/thread_pool.h" +#include "torch_xla/csrc/runtime/unique.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/shape_helper.h" #include "torch_xla/csrc/tensor_util.h" @@ -534,12 +534,12 @@ 
XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors( const std::vector& tensors, const SyncTensorsConfig& config) { tsl::profiler::TraceMe activity("CollectSyncTensors", tsl::profiler::TraceMeLevel::kInfo); - std::optional device; + runtime::util::Unique unique_device; for (size_t i = 0; i < tensors.size(); ++i) { - device = tensors[i]->GetDevice(); + unique_device.set(tensors[i]->GetDevice()); } SyncTensorCollection coll; - if (!device) { + if (!unique_device) { return coll; } @@ -552,7 +552,7 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors( // graph with on/off force_ltc_data should not match, hash wise. coll.hash = torch::lazy::MHash(config.force_ltc_data); coll.config = config; - coll.device = *device; + coll.device = *unique_device; coll.indices.reserve(tensors.size()); for (size_t i = 0; i < tensors.size(); ++i) { if (tensor_ids.insert(tensors[i]->GetUniqueId()).second &&