-
Notifications
You must be signed in to change notification settings - Fork 487
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Transfer data directly to the device (#5772)
* Transfer data directly to the device (#5752) * Remove `populate_fn` from `TensorSource` * Make TensorSource an interface * Re-enable pjrt_computation_client_test * server -> device * add comment * fix outbound data metric * formatting * implement byte_strides in TensorSource * more formatting * remove extra deps * add missing deps * Revert "server -> device" This reverts commit 6384516. * Use `at::Tensor`'s layout for byte strides * Downcast at::Tensor if required * formatting * Simplify AtenSource * fix build * formatting * fix typo that makes us ignore input type * Revert "Simplify AtenSource" This reverts commit 4225deb. * Skip hanging test * fix gil deadlock * formatting
- Loading branch information
1 parent
e60428d
commit 05a3cdd
Showing
10 changed files
with
222 additions
and
151 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
#ifndef XLA_CLIENT_TENSOR_SOURCE_H_ | ||
#define XLA_CLIENT_TENSOR_SOURCE_H_ | ||
|
||
#include <ATen/Tensor.h> | ||
#include <torch/csrc/lazy/core/metrics.h> | ||
|
||
#include <vector> | ||
|
||
#include "torch_xla/csrc/dtype.h" | ||
#include "torch_xla/csrc/runtime/debug_macros.h" | ||
#include "xla/literal.h" | ||
#include "xla/shape.h" | ||
#include "xla/shape_util.h" | ||
|
||
namespace torch_xla { | ||
namespace runtime { | ||
|
||
// Owns a contiguous block of data with the shape and layout matching `shape()`. | ||
class TensorSource { | ||
public: | ||
TensorSource(std::string device) : device_(std::move(device)){}; | ||
|
||
virtual const void* data() const = 0; | ||
|
||
virtual const xla::Shape& shape() const = 0; | ||
|
||
const std::string& device() const { return device_; } | ||
|
||
virtual std::vector<int64_t> byte_strides() const { | ||
std::vector<int64_t> byte_strides(shape().dimensions_size()); | ||
XLA_CHECK_OK( | ||
xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides))); | ||
return byte_strides; | ||
} | ||
|
||
virtual std::vector<int64_t> dimensions() const { | ||
auto dimensions = shape().dimensions(); | ||
return {dimensions.begin(), dimensions.end()}; | ||
} | ||
|
||
virtual xla::PrimitiveType primitive_type() const { | ||
return shape().element_type(); | ||
} | ||
|
||
private: | ||
std::string device_; | ||
}; | ||
|
||
class AtenSource : public TensorSource { | ||
public: | ||
AtenSource(const at::Tensor& tensor, xla::Shape shape, std::string device) | ||
: TensorSource(std::move(device)), shape_(std::move(shape)) { | ||
at::ScalarType target_torch_type = TorchTypeFromXlaType(primitive_type()); | ||
if (target_torch_type != tensor.type().scalarType()) { | ||
TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1); | ||
tensor_ = std::move(tensor.to(target_torch_type).contiguous()); | ||
} else { | ||
tensor_ = std::move(tensor.contiguous()); | ||
} | ||
} | ||
|
||
const void* data() const override { return tensor_.const_data_ptr(); } | ||
|
||
const xla::Shape& shape() const override { return shape_; } | ||
|
||
std::vector<int64_t> byte_strides() const override { | ||
std::vector<int64_t> strides; | ||
for (auto& stride : tensor_.strides()) { | ||
strides.push_back(stride * tensor_.itemsize()); | ||
} | ||
return strides; | ||
} | ||
|
||
std::vector<int64_t> dimensions() const override { | ||
auto sizes = tensor_.sizes(); | ||
return {sizes.begin(), sizes.end()}; | ||
} | ||
|
||
private: | ||
at::Tensor tensor_; | ||
xla::Shape shape_; | ||
}; | ||
|
||
class LiteralSource : public TensorSource { | ||
public: | ||
LiteralSource(xla::Literal literal, std::string device) | ||
: TensorSource(std::move(device)), literal_(std::move(literal)) {} | ||
|
||
const void* data() const override { return literal_.untyped_data(); } | ||
|
||
const xla::Shape& shape() const override { return literal_.shape(); } | ||
|
||
private: | ||
xla::Literal literal_; | ||
}; | ||
|
||
} // namespace runtime | ||
} // namespace torch_xla | ||
|
||
#endif // XLA_CLIENT_COMPUTATION_CLIENT_H_ |
Oops, something went wrong.