-
Notifications
You must be signed in to change notification settings - Fork 487
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Transfer data directly to the device (#5752)
* Remove `populate_fn` from `TensorSource` * Make TensorSource an interface * Re-enable pjrt_computation_client_test * server -> device * add comment * fix outbound data metric * formatting * implement byte_strides in TensorSource * more formatting * remove extra deps * add missing deps * Revert "server -> device" This reverts commit 6384516.
- Loading branch information
1 parent
b20a082
commit 5ca36cf
Showing
8 changed files
with
158 additions
and
128 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
#ifndef XLA_CLIENT_TENSOR_SOURCE_H_ | ||
#define XLA_CLIENT_TENSOR_SOURCE_H_ | ||
|
||
#include <ATen/Tensor.h> | ||
|
||
#include <vector> | ||
|
||
#include "torch_xla/csrc/runtime/debug_macros.h" | ||
#include "xla/literal.h" | ||
#include "xla/shape.h" | ||
#include "xla/shape_util.h" | ||
|
||
namespace torch_xla { | ||
namespace runtime { | ||
|
||
// Owns a contiguous block of data with the shape and layout matching `shape()`. | ||
class TensorSource { | ||
public: | ||
TensorSource(std::string device) : device_(std::move(device)){}; | ||
|
||
virtual const void* data() const = 0; | ||
|
||
virtual const xla::Shape& shape() const = 0; | ||
|
||
const std::string& device() const { return device_; } | ||
|
||
std::vector<int64_t> byte_strides() const { | ||
std::vector<int64_t> byte_strides(shape().dimensions_size()); | ||
XLA_CHECK_OK( | ||
xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides))); | ||
return byte_strides; | ||
} | ||
|
||
private: | ||
std::string device_; | ||
}; | ||
|
||
class AtenSource : public TensorSource { | ||
public: | ||
AtenSource(const at::Tensor& tensor, xla::Shape shape, std::string device) | ||
: TensorSource(std::move(device)), | ||
tensor_(std::move(tensor.contiguous())), | ||
shape_(std::move(shape)) {} | ||
|
||
const void* data() const override { return tensor_.const_data_ptr(); } | ||
|
||
const xla::Shape& shape() const override { return shape_; } | ||
|
||
private: | ||
at::Tensor tensor_; | ||
xla::Shape shape_; | ||
}; | ||
|
||
class LiteralSource : public TensorSource { | ||
public: | ||
LiteralSource(xla::Literal literal, std::string device) | ||
: TensorSource(std::move(device)), literal_(std::move(literal)) {} | ||
|
||
const void* data() const override { return literal_.untyped_data(); } | ||
|
||
const xla::Shape& shape() const override { return literal_.shape(); } | ||
|
||
private: | ||
xla::Literal literal_; | ||
}; | ||
|
||
} // namespace runtime | ||
} // namespace torch_xla | ||
|
||
#endif // XLA_CLIENT_COMPUTATION_CLIENT_H_ |
Oops, something went wrong.