Revert "remove unique"
This reverts commit ebe4567.
will-cromar committed Nov 10, 2023
1 parent ab212cf commit 0146c1b
Showing 6 changed files with 84 additions and 21 deletions.
2 changes: 2 additions & 0 deletions torch_xla/csrc/BUILD
@@ -118,8 +118,10 @@ ptxla_cc_library(
":layout_manager",
":shape_builder",
":shape_helper",
"//torch_xla/csrc/runtime:async_task",
"//torch_xla/csrc/runtime",
"//torch_xla/csrc/runtime:stablehlo_helper",
"//torch_xla/csrc/runtime:unique",
"//torch_xla/csrc/runtime:xla_util",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
32 changes: 17 additions & 15 deletions torch_xla/csrc/debug_util.cpp
Expand Up @@ -6,7 +6,6 @@
#include <fstream>
#include <iostream>
#include <mutex>
#include <optional>
#include <sstream>
#include <unordered_set>

@@ -18,6 +17,7 @@
#include "torch_xla/csrc/ir_dump_util.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/unique.h"
#include "torch_xla/csrc/xla_graph_executor.h"

namespace torch_xla {
@@ -61,28 +61,28 @@ std::string DebugUtil::GetTensorsGraphHlo(
absl::Span<const XLATensorPtr> tensors, const std::vector<size_t>* indices,
bool dump_stablehlo) {
std::vector<torch::lazy::Value> root_values;
std::optional<torch::lazy::BackendDevice> device;
runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
if (indices != nullptr) {
for (auto index : *indices) {
const XLATensorPtr& tensor = tensors[index];
torch::lazy::Value ir_value = tensor->CurrentIrValue();
if (ir_value) {
root_values.push_back(std::move(ir_value));
device = tensor->GetDevice();
unique_device.set(tensor->GetDevice());
}
}
} else {
for (auto& tensor : tensors) {
torch::lazy::Value ir_value = tensor->CurrentIrValue();
if (ir_value) {
root_values.push_back(std::move(ir_value));
device = tensor->GetDevice();
unique_device.set(tensor->GetDevice());
}
}
}
return DumpUtil::ToHlo(root_values,
device.value_or(bridge::GetCurrentDevice()),
EmitMode::kStableHloReadable);
return DumpUtil::ToHlo(
root_values, unique_device ? *unique_device : bridge::GetCurrentDevice(),
EmitMode::kStableHloReadable);
}

std::string DebugUtil::GetTensorsGraphInfo(
@@ -91,7 +91,7 @@ std::string DebugUtil::GetTensorsGraphInfo(
std::vector<const torch::lazy::Node*> root_nodes;
std::vector<torch::lazy::Value> root_values;
std::vector<torch::lazy::hash_t> root_hashes;
std::optional<torch::lazy::BackendDevice> device;
runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
if (indices != nullptr) {
for (auto index : *indices) {
const XLATensorPtr& tensor = tensors[index];
@@ -100,7 +100,7 @@
root_nodes.push_back(ir_value.node.get());
root_hashes.push_back(ir_value.hash());
root_values.push_back(std::move(ir_value));
device = tensor->GetDevice();
unique_device.set(tensor->GetDevice());
}
}
} else {
@@ -110,7 +110,7 @@
root_nodes.push_back(ir_value.node.get());
root_hashes.push_back(ir_value.hash());
root_values.push_back(std::move(ir_value));
device = tensor->GetDevice();
unique_device.set(tensor->GetDevice());
}
}
}
@@ -137,12 +137,14 @@ std::string DebugUtil::GetTensorsGraphInfo(
} else if (format == GraphFormat::kDot) {
graph_str = DumpUtil::ToDot(root_nodes);
} else if (format == GraphFormat::kHlo) {
graph_str = DumpUtil::ToHlo(root_values,
device.value_or(bridge::GetCurrentDevice()));
graph_str = DumpUtil::ToHlo(root_values, unique_device
? *unique_device
: bridge::GetCurrentDevice());
} else if (format == GraphFormat::kStableHlo) {
graph_str = DumpUtil::ToHlo(root_values,
device.value_or(bridge::GetCurrentDevice()),
EmitMode::kStableHloReadable);
graph_str = DumpUtil::ToHlo(
root_values,
unique_device ? *unique_device : bridge::GetCurrentDevice(),
EmitMode::kStableHloReadable);
} else {
XLA_ERROR() << "Invalid graph format: " << format;
}
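A note on what this hunk restores: with the reverted std::optional version, the last tensor holding a live IR value simply overwrote device, so tensors spread across mismatched devices went unnoticed when dumping the graph. runtime::util::Unique::set() instead checks, via XLA_CHECK (see unique.h below), that every device it is given equals the first one, so mixing devices in a single dump fails loudly.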
9 changes: 9 additions & 0 deletions torch_xla/csrc/runtime/BUILD
@@ -273,6 +273,15 @@ cc_library(
],
)

cc_library(
name = "unique",
hdrs = ["unique.h"],
deps = [
":debug_macros",
"@com_google_absl//absl/types:optional",
],
)

cc_library(
name = "util",
hdrs = ["util.h"],
50 changes: 50 additions & 0 deletions torch_xla/csrc/runtime/unique.h
@@ -0,0 +1,50 @@
#ifndef XLA_CLIENT_UNIQUE_H_
#define XLA_CLIENT_UNIQUE_H_

#include <functional>
#include <set>

#include "absl/types/optional.h"
#include "torch_xla/csrc/runtime/debug_macros.h"

namespace torch_xla {
namespace runtime {
namespace util {

// Helper class that tracks zero or more values which are all required to be
// the same, i.e. effectively a single unique value.
template <typename T, typename C = std::equal_to<T>>
class Unique {
public:
std::pair<bool, const T&> set(const T& value) {
if (value_) {
XLA_CHECK(C()(*value_, value))
<< "'" << *value_ << "' vs '" << value << "'";
return std::pair<bool, const T&>(false, *value_);
}
value_ = value;
return std::pair<bool, const T&>(true, *value_);
}

operator bool() const { return value_.has_value(); }
operator const T&() const { return *value_; }
const T& operator*() const { return *value_; }
const T* operator->() const { return value_.operator->(); }

std::set<T> AsSet() const {
std::set<T> vset;
if (value_.has_value()) {
vset.insert(*value_);
}
return vset;
}

private:
absl::optional<T> value_;
};

} // namespace util
} // namespace runtime
} // namespace torch_xla

#endif // XLA_CLIENT_UNIQUE_H_
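
For reference, a minimal usage sketch of the restored helper (hypothetical and not part of this commit; it assumes the header is compiled inside the torch_xla tree and substitutes a plain std::string for the torch::lazy::BackendDevice used in the hunks above):

#include <iostream>
#include <string>

#include "torch_xla/csrc/runtime/unique.h"

int main() {
  torch_xla::runtime::util::Unique<std::string> unique_device;

  // The first set() stores the value and reports {true, value}.
  auto first = unique_device.set("TPU:0");
  std::cout << std::boolalpha << first.first << std::endl;  // true

  // Later calls must pass an equal value; a different one would trip the
  // XLA_CHECK inside set().
  unique_device.set("TPU:0");

  if (unique_device) {                         // operator bool()
    std::cout << *unique_device << std::endl;  // TPU:0
  }
  return 0;
}

In debug_util.cpp and xla_graph_executor.cpp the same pattern is applied to torch::lazy::BackendDevice to assert that every tensor in a collection lives on a single device.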
2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor.cpp
@@ -15,7 +15,6 @@
#include <exception>
#include <functional>
#include <mutex>
#include <optional>
#include <set>
#include <stdexcept>
#include <unordered_set>
@@ -39,6 +38,7 @@
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/runtime/unique.h"
#include "torch_xla/csrc/runtime/xla_util.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/torch_util.h"
10 changes: 5 additions & 5 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -17,7 +17,6 @@
#include <fstream>
#include <functional>
#include <mutex>
#include <optional>
#include <set>
#include <stdexcept>
#include <unordered_map>
@@ -48,6 +47,7 @@
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/runtime/unique.h"
#include "torch_xla/csrc/runtime/xla_util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/tensor_util.h"
@@ -534,12 +534,12 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
const std::vector<XLATensorPtr>& tensors, const SyncTensorsConfig& config) {
tsl::profiler::TraceMe activity("CollectSyncTensors",
tsl::profiler::TraceMeLevel::kInfo);
std::optional<torch::lazy::BackendDevice> device;
runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
for (size_t i = 0; i < tensors.size(); ++i) {
device = tensors[i]->GetDevice();
unique_device.set(tensors[i]->GetDevice());
}
SyncTensorCollection coll;
if (!device) {
if (!unique_device) {
return coll;
}

@@ -552,7 +552,7 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
// graph with on/off force_ltc_data should not match, hash wise.
coll.hash = torch::lazy::MHash(config.force_ltc_data);
coll.config = config;
coll.device = *device;
coll.device = *unique_device;
coll.indices.reserve(tensors.size());
for (size_t i = 0; i < tensors.size(); ++i) {
if (tensor_ids.insert(tensors[i]->GetUniqueId()).second &&
