From 4105f454ee97aa209fef524b074ea5d5a42c9b63 Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Tue, 7 Nov 2023 18:29:43 +0000
Subject: [PATCH] Downcast at::Tensor if required

---
 torch_xla/csrc/runtime/tensor_source.h | 82 +++++++++++++++-----------
 1 file changed, 49 insertions(+), 33 deletions(-)

diff --git a/torch_xla/csrc/runtime/tensor_source.h b/torch_xla/csrc/runtime/tensor_source.h
index d444552dcba3..35f248cae6e9 100644
--- a/torch_xla/csrc/runtime/tensor_source.h
+++ b/torch_xla/csrc/runtime/tensor_source.h
@@ -2,6 +2,7 @@
 #define XLA_CLIENT_TENSOR_SOURCE_H_
 
 #include <ATen/Tensor.h>
+#include <torch/csrc/lazy/core/metrics.h>
 
 #include <vector>
 
@@ -13,6 +14,45 @@
 namespace torch_xla {
 namespace runtime {
 
+namespace {
+
+// TODO: consolidate
+at::ScalarType TensorTypeFromXlaType(xla::PrimitiveType xla_type) {
+  switch (xla_type) {
+    case xla::PrimitiveType::BF16:
+      return at::ScalarType::BFloat16;
+    case xla::PrimitiveType::F16:
+      return at::ScalarType::Half;
+    case xla::PrimitiveType::F32:
+      return at::ScalarType::Float;
+    case xla::PrimitiveType::F64:
+      return at::ScalarType::Double;
+    case xla::PrimitiveType::PRED:
+      return at::ScalarType::Bool;
+    case xla::PrimitiveType::U8:
+      return at::ScalarType::Byte;
+    case xla::PrimitiveType::S8:
+      return at::ScalarType::Char;
+    case xla::PrimitiveType::S16:
+    case xla::PrimitiveType::U16:
+      return at::ScalarType::Short;
+    case xla::PrimitiveType::S32:
+    case xla::PrimitiveType::U32:
+      return at::ScalarType::Int;
+    case xla::PrimitiveType::S64:
+    case xla::PrimitiveType::U64:
+      return at::ScalarType::Long;
+    case xla::PrimitiveType::C64:
+      return at::ScalarType::ComplexFloat;
+    case xla::PrimitiveType::C128:
+      return at::ScalarType::ComplexDouble;
+    default:
+      XLA_ERROR() << "XLA type not supported: " << xla_type;
+  }
+}
+
+}  // namespace
+
 // Owns a contiguous block of data with the shape and layout matching `shape()`.
 class TensorSource {
  public:
@@ -48,8 +88,15 @@ class AtenSource : public TensorSource {
  public:
   AtenSource(const at::Tensor& tensor, xla::Shape shape, std::string device)
       : TensorSource(std::move(device)),
-        tensor_(std::move(tensor.contiguous())),
-        shape_(std::move(shape)) {}
+        shape_(std::move(shape)) {
+    at::ScalarType target_torch_type = TensorTypeFromXlaType(primitive_type());
+    if (target_torch_type != tensor.scalar_type()) {
+      TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1);
+      tensor_ = tensor.to(target_torch_type).contiguous();
+    } else {
+      tensor_ = tensor.contiguous();
+    }
+  }
 
   const void* data() const override { return tensor_.const_data_ptr(); }
 
@@ -68,37 +115,6 @@ class AtenSource : public TensorSource {
     return {sizes.begin(), sizes.end()};
   }
 
-  // xla::PrimitiveType primitive_type() const override {
-  //   switch (tensor_.type().scalarType()) {
-  //     case at::ScalarType::Double:
-  //       return xla::PrimitiveType::F64;
-  //     case at::ScalarType::Float:
-  //       return xla::PrimitiveType::F32;
-  //     case at::ScalarType::BFloat16:
-  //       return xla::PrimitiveType::BF16;
-  //     case at::ScalarType::Half:
-  //       return xla::PrimitiveType::F16;
-  //     case at::ScalarType::Bool:
-  //       return xla::PrimitiveType::PRED;
-  //     case at::ScalarType::Byte:
-  //       return xla::PrimitiveType::U8;
-  //     case at::ScalarType::Char:
-  //       return xla::PrimitiveType::S8;
-  //     case at::ScalarType::Short:
-  //       return xla::PrimitiveType::S16;
-  //     case at::ScalarType::Int:
-  //       return xla::PrimitiveType::S32;
-  //     case at::ScalarType::Long:
-  //       return xla::PrimitiveType::S64;
-  //     case at::ScalarType::ComplexFloat:
-  //       return xla::PrimitiveType::C64;
-  //     case at::ScalarType::ComplexDouble:
-  //       return xla::PrimitiveType::C128;
-  //     default:
-  //       XLA_ERROR() << "Type not supported: " << tensor_.type().scalarType();
-  //   }
-  // }
-
  private:
   at::Tensor tensor_;
   xla::Shape shape_;
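
For context outside the diff: the new AtenSource constructor derives the torch dtype implied by the destination xla::Shape (via primitive_type()), converts the input tensor only when its dtype differs, and bumps the AtenSourceDowncasts counter so conversions are visible in metrics reports. Below is a minimal standalone sketch of that decision using only public ATen APIs; MaybeDowncastTo is a hypothetical helper name for illustration, not part of the patch.

#include <ATen/ATen.h>

// Hypothetical stand-in for the constructor body above: convert only when
// the source dtype differs from the target dtype. contiguous() is a no-op
// when the tensor is already contiguous, so the common path copies nothing.
at::Tensor MaybeDowncastTo(const at::Tensor& tensor,
                           at::ScalarType target_torch_type) {
  if (tensor.scalar_type() != target_torch_type) {
    return tensor.to(target_torch_type).contiguous();
  }
  return tensor.contiguous();
}

For example, MaybeDowncastTo(at::rand({2, 2}, at::kDouble), at::kFloat) performs a real downcast, while an already-float input passes through with at most a contiguity copy.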
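
One property of the new TensorTypeFromXlaType mapping worth noting: it is many-to-one. XLA's unsigned integer types U16, U32, and U64 map to the signed torch types Short, Int, and Long (ATen has no unsigned 16/32/64-bit dtypes to map to), so an unsigned XLA element type drives a width-preserving but signedness-reinterpreting conversion. A short sketch of how the helper is driven from a shape, assuming TensorTypeFromXlaType from this patch is in scope and that "xla/shape_util.h" is the right include path in this tree:

#include "xla/shape_util.h"

void Example() {
  // The constructor reaches the mapping through primitive_type(), which
  // reports the destination shape's element type.
  xla::Shape f32_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
  at::ScalarType t = TensorTypeFromXlaType(f32_shape.element_type());
  // t == at::ScalarType::Float; with xla::U32 the result would be
  // at::ScalarType::Int, the same as for xla::S32.
}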