fix rebase issues
will-cromar committed Nov 28, 2023
1 parent 02f13f9 commit 40ab61d
Showing 6 changed files with 248 additions and 217 deletions.
31 changes: 21 additions & 10 deletions torch_xla/csrc/runtime/BUILD
@@ -83,19 +83,15 @@ cc_library(
":computation_client",
":debug_macros",
":env_vars",
":multi_wait",
":initialize_pjrt",
":operation_manager",
":stablehlo_helper",
":tf_logging",
":thread_pool",
"@xla//xla:literal",
"@xla//xla:shape_util",
"@xla//xla/client:xla_computation",
"@xla//xla/pjrt/distributed",
"@xla//xla/pjrt/gpu:se_gpu_pjrt_client",
"@xla//xla/service:gpu_plugin",
"@xla//xla/pjrt:pjrt_client",
"@xla//xla/pjrt:tfrt_cpu_pjrt_client",
"@xla//xla/pjrt:pjrt_c_api_client",
"@xla//xla/python/ifrt",
"@xla//xla/python/pjrt_ifrt",
"@tsl//tsl/profiler/lib:traceme",
@@ -117,6 +113,7 @@ cc_library(
":computation_client",
":debug_macros",
":env_vars",
":initialize_pjrt",
":operation_manager",
":profiler",
":stablehlo_helper",
@@ -128,11 +125,7 @@ cc_library(
"@xla//xla:shape_util",
"@xla//xla/client:xla_computation",
"@xla//xla/pjrt/distributed",
"@xla//xla/pjrt/gpu:se_gpu_pjrt_client",
"@xla//xla/service:gpu_plugin",
"@xla//xla/pjrt:pjrt_client",
"@xla//xla/pjrt:tfrt_cpu_pjrt_client",
"@xla//xla/pjrt:pjrt_c_api_client",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@tsl//tsl/profiler/lib:traceme",
"@tsl//tsl/platform/cloud:gcs_file_system",
@@ -174,6 +167,24 @@ cc_library(
hdrs = ["env_vars.h"],
)

cc_library(
name = "initialize_pjrt",
srcs = ["initialize_pjrt.cc"],
hdrs = ["initialize_pjrt.h"],
deps = [
":debug_macros",
":env_vars",
":profiler",
":sys_util",
":tf_logging",
":xla_coordinator",
"@xla//xla/service:gpu_plugin",
"@xla//xla/pjrt/gpu:se_gpu_pjrt_client",
"@xla//xla/pjrt:tfrt_cpu_pjrt_client",
"@xla//xla/pjrt:pjrt_c_api_client",
],
)
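
For orientation, the constructor change later in this diff calls `InitializePjRt(device_type)` and moves the result into `xla::ifrt::PjRtClient::Create`, which suggests the new target exposes roughly the interface sketched below. This is an assumption inferred from the call site, not the actual contents of `initialize_pjrt.h`:

```cpp
// Hypothetical sketch of torch_xla/csrc/runtime/initialize_pjrt.h, inferred
// from its call sites in this diff; everything except the function name is a guess.
#pragma once

#include <memory>
#include <string>

#include "xla/pjrt/pjrt_client.h"

namespace torch_xla {
namespace runtime {

// Builds a PjRtClient for the requested device type (e.g. "CPU", "TPU"),
// consolidating backend setup that was previously duplicated across the
// PjRt and IFRT computation clients.
std::unique_ptr<xla::PjRtClient> InitializePjRt(const std::string& device_type);

}  // namespace runtime
}  // namespace torch_xla
```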

cc_library(
name = "metrics_analysis",
srcs = ["metrics_analysis.cc"],
138 changes: 55 additions & 83 deletions torch_xla/csrc/runtime/ifrt_computation_client.cc
@@ -9,22 +9,19 @@
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/multi_wait.h"
#include "torch_xla/csrc/runtime/initialize_pjrt.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "tsl/profiler/lib/traceme.h"
#include "xla/client/xla_builder.h"
#include "xla/client/xla_computation.h"
#include "xla/layout_util.h"
#include "xla/literal.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/pjrt/pjrt_c_api_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "xla/python/ifrt/compiler.h"
#include "xla/python/ifrt/memory.h"
#include "xla/python/ifrt/sharding.h"
@@ -98,26 +95,7 @@ std::vector<std::string> IfrtComputationClient::PjRtDevicesToString(

IfrtComputationClient::IfrtComputationClient() {
std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, "");
if (device_type == "CPU") {
TF_VLOG(1) << "Initializing PjRt CPU client...";
bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true);
int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1);
client_ = xla::ifrt::PjRtClient::Create(
std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value()));
} else if (device_type == "TPU" || device_type == "TPU_C_API") {
TF_VLOG(1) << "Initializing TFRT TPU client...";
XLA_CHECK_OK(pjrt::LoadPjrtPlugin(
"tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so")));
tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu");
XLA_CHECK(tpu_status.ok());
client_ = xla::ifrt::PjRtClient::Create(
std::move(xla::GetCApiClient("TPU").value()));
} else {
XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice,
device_type);
}

XLA_CHECK(client_.get() != nullptr);
client_ = xla::ifrt::PjRtClient::Create(std::move(InitializePjRt(device_type)));

// PjRtDevice IDs are not guaranteed to be dense, so we need to track
// a device's global ordinal separately from its device ID. Order the
@@ -130,10 +108,38 @@ IfrtComputationClient::IfrtComputationClient() {
global_ordinals_[device->id()] = global_ordinals_.size();
std::string device_str = PjRtDeviceToString(device);
string_to_device_.emplace(device_str, device);
device_locks_.emplace(device_str, std::make_unique<std::shared_mutex>());
}
// manually create the device_locks for SPMD device
device_locks_.emplace(spmd_device_str, std::make_unique<std::shared_mutex>());

auto tracked_devices = GetLocalDevices();
tracked_devices.emplace_back(spmd_device_str);
operation_manager_ = std::move(OperationManager(std::move(tracked_devices)));
}
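
The deleted CPU and TPU branches above are what the shared helper presumably absorbs. Below is a minimal sketch of an `InitializePjRt` implementation reconstructed from those deleted lines; the surrounding structure and return type are assumptions, and only the individual calls appear in the removed code:

```cpp
// Sketch only: reconstructed from the branches deleted above. The includes
// mirror the deps declared in the new BUILD target; exact contents may differ.
#include "torch_xla/csrc/runtime/initialize_pjrt.h"

#include "absl/strings/str_format.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/pjrt/pjrt_c_api_client.h"
#include "xla/pjrt/tfrt_cpu_pjrt_client.h"

namespace torch_xla {
namespace runtime {

std::unique_ptr<xla::PjRtClient> InitializePjRt(const std::string& device_type) {
  if (device_type == "CPU") {
    TF_VLOG(1) << "Initializing PjRt CPU client...";
    bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true);
    int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1);
    return std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value());
  } else if (device_type == "TPU" || device_type == "TPU_C_API") {
    TF_VLOG(1) << "Initializing TFRT TPU client...";
    XLA_CHECK_OK(pjrt::LoadPjrtPlugin(
        "tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so")));
    XLA_CHECK_OK(pjrt::InitializePjrtPlugin("tpu"));
    return std::move(xla::GetCApiClient("TPU").value());
  }
  XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice,
                                 device_type);
  return nullptr;  // Fallback; XLA_ERROR reports the failure above.
}

}  // namespace runtime
}  // namespace torch_xla
```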

IfrtComputationClient::~IfrtComputationClient() {
// In the GPU case, the PjRtClient depends on the DistributedRuntimeClient
// tracked in XlaCoordinator, so the PjRtClient must be destroyed first.
client_ = nullptr;
coordinator_ = nullptr;
}

bool IfrtComputationClient::CoordinatorInitialized() const {
return coordinator_ != nullptr;
}

void IfrtComputationClient::InitializeCoordinator(int global_rank,
int world_size,
std::string master_addr,
std::string port) {
XLA_CHECK(coordinator_ == nullptr)
<< "Can only initialize the XlaCoordinator once.";
coordinator_ = std::make_unique<XlaCoordinator>(global_rank, world_size,
master_addr, port);
}

XlaCoordinator& IfrtComputationClient::GetCoordinator() {
XLA_CHECK(coordinator_ != nullptr)
<< "XlaCoordinator has not been initialized";
return *coordinator_;
}

void IfrtComputationClient::IfrtData::Assign(
@@ -222,43 +228,35 @@ std::optional<xla::OpSharding> IfrtComputationClient::GetDataSharding(
}

std::vector<ComputationClient::DataPtr> IfrtComputationClient::TransferToServer(
absl::Span<const TensorSource> tensors) {
absl::Span<const std::shared_ptr<const TensorSource>> tensors) {
metrics::TimedSection timed(TransferToServerMetric());
tsl::profiler::TraceMe activity("IfrtComputationClient::TransferToServer",
tsl::profiler::TraceMeLevel::kInfo);
std::vector<ComputationClient::DataPtr> datas;
datas.reserve(tensors.size());
int64_t total_size = 0;
for (auto& tensor : tensors) {
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor.device);
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor->device());

auto literal = std::make_shared<xla::Literal>(tensor.shape);
tensor.populate_fn(tensor, literal->untyped_data(), literal->size_bytes());
std::vector<int64_t> byte_strides(literal->shape().dimensions_size());
XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal->shape(),
absl::MakeSpan(byte_strides)));
total_size += literal->size_bytes();
total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape());

// Avoid use-after-free on `literal` due to unsequenced move and use.
xla::Literal* literal_pointer = literal.get();
tsl::RCReference<xla::ifrt::Array> buffer =
client_
->MakeArrayFromHostBuffer(
literal_pointer->untyped_data(),
xla::ifrt::ToDType(literal_pointer->shape().element_type())
.value(),
xla::ifrt::Shape(literal_pointer->shape().dimensions()),
byte_strides,
tensor->data(),
xla::ifrt::ToDType(tensor->primitive_type()).value(),
xla::ifrt::Shape(tensor->dimensions()),
tensor->byte_strides(),
// TODO: what is MemoryKind?
xla::ifrt::SingleDeviceSharding::Create(
pjrt_device, xla::ifrt::MemoryKind()),
xla::PjRtClient::HostBufferSemantics::
kImmutableUntilTransferCompletes,
[literal{std::move(literal)}]() { /* frees literal */ })
[tensor]() { /* frees tensor */ })
.value();

ComputationClient::DataPtr data =
std::make_shared<IfrtData>(tensor.device, tensor.shape, buffer);
std::make_shared<IfrtData>(tensor->device(), tensor->shape(), buffer);
datas.push_back(data);
}
OutboundDataMetric()->AddSample(total_size);
@@ -268,8 +266,8 @@ std::vector<ComputationClient::DataPtr> IfrtComputationClient::TransferToServer(
}
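
The refactored `TransferToServer` above reads tensors through accessor methods on a shared `TensorSource` rather than the old `populate_fn` callback. Inferred purely from the call sites in this hunk, the interface looks roughly like the sketch below; the real `TensorSource` declaration lives elsewhere in the repo and may differ:

```cpp
// Assumed shape of TensorSource, inferred from the accessors used above
// (data(), shape(), device(), primitive_type(), dimensions(), byte_strides()).
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "xla/shape.h"

class TensorSource {
 public:
  virtual ~TensorSource() = default;

  virtual const void* data() const = 0;           // host buffer to upload
  virtual const xla::Shape& shape() const = 0;    // full XLA shape
  virtual const std::string& device() const = 0;  // target device string
  virtual std::vector<int64_t> byte_strides() const = 0;

  xla::PrimitiveType primitive_type() const { return shape().element_type(); }
  absl::Span<const int64_t> dimensions() const { return shape().dimensions(); }
};
```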

ComputationClient::DataPtr IfrtComputationClient::TransferShardsToServer(
absl::Span<const TensorSource> tensor_shards, std::string device,
xla::Shape shape, xla::OpSharding sharding) {
absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
std::string device, xla::Shape shape, xla::OpSharding sharding) {
tsl::profiler::TraceMe activity(
"IfrtComputationClient::TransferShardsToServer",
tsl::profiler::TraceMeLevel::kInfo);
@@ -367,7 +365,7 @@ tsl::RCReference<xla::ifrt::Array> IfrtComputationClient::ReplicateShardedData(
auto sharded_results =
ExecuteReplicated(*computations.front(), {{handle_but_not_const}},
GetLocalDevices(), execute_options);
auto replicated_output = std::dynamic_pointer_cast<IfrtData>(sharded_results[0][0])->buffer->FullyReplicatedShard(xla::ifrt::ArrayCopySemantics::kAlwaysCopy);
auto replicated_output = std::dynamic_pointer_cast<IfrtData>(sharded_results[0])->buffer->FullyReplicatedShard(xla::ifrt::ArrayCopySemantics::kAlwaysCopy);
// TODO: sanity check outputs
return *replicated_output;
}
@@ -575,10 +573,10 @@ IfrtComputationClient::ExecuteComputation(
// return datas;
}

std::vector<std::vector<ComputationClient::DataPtr>>
std::vector<ComputationClient::DataPtr>
IfrtComputationClient::ExecuteReplicated(
const ComputationClient::Computation& computation,
const std::vector<std::vector<ComputationClient::DataPtr>>& arguments,
const absl::Span<const ComputationClient::DataPtr> arguments,
// TODO: devices isn't doing anything helpful here
absl::Span<const std::string> devices,
const ExecuteReplicatedOptions& options) {
@@ -597,9 +595,9 @@ IfrtComputationClient::ExecuteReplicated(
// << "ExecuteReplicated over " << devices.size() << " devices, but "
// << arguments.size() << " arguments devices.";
// TODO: parallelize again if necessary
std::vector<tsl::RCReference<xla::ifrt::Array>> argument_handles(arguments[0].size());
for (int32_t i = 0; i < arguments[0].size(); ++i) {
auto ifrt_data = std::dynamic_pointer_cast<IfrtData>(arguments[0][i]);
std::vector<tsl::RCReference<xla::ifrt::Array>> argument_handles(arguments.size());
for (int32_t i = 0; i < arguments.size(); ++i) {
auto ifrt_data = std::dynamic_pointer_cast<IfrtData>(arguments[i]);
argument_handles[i] = ifrt_data->buffer;
}

@@ -683,37 +681,11 @@ xla::PjRtDevice* IfrtComputationClient::StringToPjRtDevice(
return pjrt_device;
}

std::shared_lock<std::shared_mutex> IfrtComputationClient::lock_device_shared(
const std::string& device) {
std::shared_lock lock(*device_locks_[device]);
return lock;
}

std::unique_lock<std::shared_mutex> IfrtComputationClient::lock_device(
const std::string& device) {
std::unique_lock lock(*device_locks_[device]);
return lock;
}

void IfrtComputationClient::WaitDeviceOps(
const std::vector<std::string>& devices) {
std::unordered_set<std::string> wait_devices;
if (!devices.empty()) {
for (auto& device_str : devices) {
wait_devices.insert(device_str);
}
} else {
for (auto& device_str : GetLocalDevices()) {
wait_devices.insert(device_str);
}
}
for (const std::string& device_str : wait_devices) {
TF_VLOG(3) << "Waiting for device execution for " << device_str
<< " to finish";
lock_device(device_str);
TF_VLOG(3) << "Waiting for device execution for " << device_str
<< " to finish.. Done";
}
absl::Span<const std::string> devices) {
TF_VLOG(3) << "Waiting for " << absl::StrJoin(devices, ", ");
operation_manager_.WaitForDevices(devices.empty() ? GetLocalDevices()
: devices);
}

std::map<std::string, Metric> IfrtComputationClient::GetMetrics() const {
35 changes: 21 additions & 14 deletions torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -10,6 +10,7 @@
#include "absl/types/span.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/operation_manager.h"
#include "torch_xla/csrc/runtime/util.h"
#include "xla/client/xla_computation.h"
#include "xla/literal.h"
@@ -26,6 +27,7 @@ namespace runtime {
class IfrtComputationClient : public ComputationClient {
public:
IfrtComputationClient();
~IfrtComputationClient();

DataPtr CreateDataPlaceholder(std::string device, xla::Shape shape) override;

@@ -39,17 +41,17 @@ class IfrtComputationClient : public ComputationClient {
std::optional<xla::OpSharding> GetDataSharding(DataPtr handle) override;

std::vector<DataPtr> TransferToServer(
absl::Span<const TensorSource> tensors) override;
absl::Span<const std::shared_ptr<const TensorSource>> tensors) override;

// Use XLA replication to re-assemble the sharded data.
// DataPtr ReplicateShardedData(const DataPtr& handle);

std::vector<xla::Literal> TransferFromServer(
absl::Span<const DataPtr> handles) override;

DataPtr TransferShardsToServer(absl::Span<const TensorSource> tensor_shards,
std::string device, xla::Shape shape,
xla::OpSharding sharding) override;
DataPtr TransferShardsToServer(
absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
std::string device, xla::Shape shape, xla::OpSharding sharding) override;

DataPtr CopyToDevice(DataPtr data, std::string dst) override;

@@ -61,9 +63,9 @@ class IfrtComputationClient : public ComputationClient {
const std::string& device,
const ExecuteComputationOptions& options) override;

std::vector<std::vector<DataPtr>> ExecuteReplicated(
std::vector<DataPtr> ExecuteReplicated(
const Computation& computation,
const std::vector<std::vector<DataPtr>>& arguments,
const absl::Span<const DataPtr> arguments,
absl::Span<const std::string> devices,
const ExecuteReplicatedOptions& options) override;

@@ -88,12 +90,18 @@ class IfrtComputationClient : public ComputationClient {

std::shared_ptr<std::vector<std::string>> GetReplicationDevices() override;

void PrepareToExit() override { return; };

void WaitDeviceOps(const std::vector<std::string>& devices) override;
void WaitDeviceOps(absl::Span<const std::string> devices) override;

std::map<std::string, Metric> GetMetrics() const override;

void InitializeCoordinator(int global_rank, int world_size,
std::string master_addr,
std::string port) override;

XlaCoordinator& GetCoordinator() override;

bool CoordinatorInitialized() const override;

// NOT IMPLEMENTED

MemoryInfo GetMemoryInfo(const std::string& device) override {
@@ -102,18 +110,17 @@ class IfrtComputationClient : public ComputationClient {

private:
std::shared_ptr<xla::ifrt::PjRtClient> client_;
std::unique_ptr<XlaCoordinator> coordinator_;
// global_ordinals_ tracks a map from PjRtDeviceId to the device's
// dense global ordinal.
std::unordered_map<int, int> global_ordinals_;
std::unordered_map<std::string, xla::PjRtDevice* const> string_to_device_;
std::shared_ptr<std::vector<std::string>> replication_devices_;
std::unordered_map<std::string, std::unique_ptr<std::shared_mutex>>
device_locks_;
OperationManager operation_manager_;
tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool(
tsl::Env::Default(), "ifrt", std::thread::hardware_concurrency());

xla::PjRtDevice* StringToPjRtDevice(const std::string& device);
std::shared_lock<std::shared_mutex> lock_device_shared(
const std::string& device);
std::unique_lock<std::shared_mutex> lock_device(const std::string& device);

std::string PjRtDeviceToString(xla::PjRtDevice* const device) const;
std::vector<std::string> PjRtDevicesToString(