Revert "Simplify AtenSource"
This reverts commit 4225deb.
will-cromar committed Nov 8, 2023
1 parent 40933b6 commit df239db
Showing 4 changed files with 38 additions and 45 deletions.
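For quick orientation, here is a condensed sketch of the TensorSource base class as it reads after this revert, assembled from the tensor_source.h hunks below. It is not a complete listing: headers, the device()/device_ plumbing, and the AtenSource/LiteralSource subclasses are elided.

  // Sketch only: after the revert, subclasses supply shape() and data(),
  // and the layout queries are derived from the xla::Shape by default.
  class TensorSource {
   public:
    virtual const void* data() const = 0;
    virtual const xla::Shape& shape() const = 0;

    virtual std::vector<int64_t> byte_strides() const {
      std::vector<int64_t> byte_strides(shape().dimensions_size());
      XLA_CHECK_OK(
          xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides)));
      return byte_strides;
    }

    virtual std::vector<int64_t> dimensions() const {
      auto dimensions = shape().dimensions();
      return {dimensions.begin(), dimensions.end()};
    }

    virtual xla::PrimitiveType primitive_type() const {
      return shape().element_type();
    }
  };

Accordingly, the callers in tensor_util.cpp and xla_sharding_util.cpp below go back to computing an xla::Shape via CreateComputationShapeFromTensor and passing it to AtenSource, rather than passing only a target primitive type.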
5 changes: 3 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -297,6 +297,8 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
   for (auto& tensor : tensors) {
     xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor->device());

+    total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape());
+
     std::shared_ptr<xla::PjRtBuffer> buffer =
         std::move(client_
                       ->BufferFromHostBuffer(
@@ -308,8 +310,7 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
                       .value());

     ComputationClient::DataPtr data =
-        std::make_shared<PjRtData>(tensor->device(), buffer);
-    total_size += xla::ShapeUtil::ByteSizeOf(data->shape());
+        std::make_shared<PjRtData>(tensor->device(), tensor->shape(), buffer);
     datas.push_back(data);
   }
   OutboundDataMetric()->AddSample(total_size);
57 changes: 26 additions & 31 deletions torch_xla/csrc/runtime/tensor_source.h
@@ -22,23 +22,34 @@ class TensorSource {

   virtual const void* data() const = 0;

-  virtual xla::PrimitiveType primitive_type() const = 0;
+  virtual const xla::Shape& shape() const = 0;

-  virtual std::vector<int64_t> dimensions() const = 0;
+  const std::string& device() const { return device_; }

-  virtual std::vector<int64_t> byte_strides() const = 0;
+  virtual std::vector<int64_t> byte_strides() const {
+    std::vector<int64_t> byte_strides(shape().dimensions_size());
+    XLA_CHECK_OK(
+        xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides)));
+    return byte_strides;
+  }

-  const std::string& device() const { return device_; }
+  virtual std::vector<int64_t> dimensions() const {
+    auto dimensions = shape().dimensions();
+    return {dimensions.begin(), dimensions.end()};
+  }
+
+  virtual xla::PrimitiveType primitive_type() const {
+    return shape().element_type();
+  }

  private:
   std::string device_;
 };

 class AtenSource : public TensorSource {
  public:
-  AtenSource(const at::Tensor& tensor, xla::PrimitiveType target_type,
-             std::string device)
-      : TensorSource(std::move(device)), target_type_(target_type) {
+  AtenSource(const at::Tensor& tensor, xla::Shape shape, std::string device)
+      : TensorSource(std::move(device)), shape_(std::move(shape)) {
     at::ScalarType target_torch_type = TorchTypeFromXlaType(primitive_type());
     if (target_torch_type != tensor.type().scalarType()) {
       TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1);
@@ -50,12 +61,7 @@ class AtenSource : public TensorSource {

   const void* data() const override { return tensor_.const_data_ptr(); }

-  xla::PrimitiveType primitive_type() const override { return target_type_; }
-
-  std::vector<int64_t> dimensions() const override {
-    auto sizes = tensor_.sizes();
-    return {sizes.begin(), sizes.end()};
-  }
+  const xla::Shape& shape() const override { return shape_; }

   std::vector<int64_t> byte_strides() const override {
     std::vector<int64_t> strides;
@@ -65,9 +71,14 @@ class AtenSource : public TensorSource {
     return strides;
   }

+  std::vector<int64_t> dimensions() const override {
+    auto sizes = tensor_.sizes();
+    return {sizes.begin(), sizes.end()};
+  }
+
  private:
   at::Tensor tensor_;
-  xla::PrimitiveType target_type_;
+  xla::Shape shape_;
 };

 class LiteralSource : public TensorSource {
@@ -77,23 +88,7 @@ class LiteralSource : public TensorSource {

   const void* data() const override { return literal_.untyped_data(); }

-  const xla::Shape& shape() const { return literal_.shape(); }
-
-  xla::PrimitiveType primitive_type() const override {
-    return shape().element_type();
-  }
-
-  std::vector<int64_t> dimensions() const override {
-    auto dimensions = shape().dimensions();
-    return {dimensions.begin(), dimensions.end()};
-  }
-
-  std::vector<int64_t> byte_strides() const override {
-    std::vector<int64_t> byte_strides(shape().dimensions_size());
-    XLA_CHECK_OK(
-        xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides)));
-    return byte_strides;
-  }
+  const xla::Shape& shape() const override { return literal_.shape(); }

  private:
   xla::Literal literal_;
14 changes: 6 additions & 8 deletions torch_xla/csrc/tensor_util.cpp
@@ -479,8 +479,8 @@ torch::lazy::BackendDataPtr TensorToXlaData(
   }

   std::vector<std::shared_ptr<const runtime::TensorSource>> source_tensors;
-  source_tensors.push_back(std::make_shared<runtime::AtenSource>(
-      tensor, shape.element_type(), device.toString()));
+  source_tensors.push_back(
+      std::make_shared<runtime::AtenSource>(tensor, shape, device.toString()));

   auto handles =
       runtime::GetComputationClient()->TransferToServer(source_tensors);
@@ -705,10 +705,9 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
   std::vector<std::shared_ptr<const runtime::TensorSource>> source_tensors;
   for (size_t i = 0; i < tensors.size(); ++i) {
     torch::lazy::BackendDevice device = ParseDeviceString(devices[i]);
+    xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device);
     source_tensors.push_back(std::make_shared<runtime::AtenSource>(
-        tensors[i],
-        MaybeDowncastForDevice(tensors[i].type().scalarType(), device),
-        devices[i]));
+        tensors[i], std::move(shape), devices[i]));
   }
   return WrapXlaData(
       runtime::GetComputationClient()->TransferToServer(source_tensors));
@@ -725,6 +724,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
   std::vector<runtime::ComputationClient::DataPtr> handles;
   for (size_t i = 0; i < tensors.size(); ++i) {
     torch::lazy::BackendDevice device = ParseDeviceString(devices[i]);
+    xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device);

     std::vector<std::shared_ptr<const runtime::TensorSource>>
         source_tensors;  // in
@@ -744,9 +744,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
             local_shards, local_devices, shardings[i]));
       } else {
         source_tensors.push_back(std::make_shared<runtime::AtenSource>(
-            tensors[i],
-            MaybeDowncastForDevice(tensors[i].type().scalarType(), device),
-            devices[i]));
+            tensors[i], std::move(shape), devices[i]));
         new_handles =
             runtime::GetComputationClient()->TransferToServer(source_tensors);
       }
7 changes: 3 additions & 4 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -727,11 +727,10 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
   }
   for (int64_t j = 0; j < devices.size(); ++j) {
     auto shard_device = ParseDeviceString(devices[j]);
+    auto shard_shape =
+        CreateComputationShapeFromTensor(local_shards[j], &shard_device);
     source_tensors.push_back(std::make_shared<runtime::AtenSource>(
-        local_shards[j],
-        MaybeDowncastForDevice(local_shards[j].type().scalarType(),
-                               shard_device),
-        devices[j]));
+        local_shards[j], shard_shape, devices[j]));
   }
   return runtime::GetComputationClient()->TransferShardsToServer(
       source_tensors, GetVirtualDevice().toString(), global_shape, sharding);
