Skip to content

Commit

Permalink
Downcast at::Tensor if required
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Nov 7, 2023
1 parent 3f04304 commit 4105f45
Showing 1 changed file with 49 additions and 33 deletions.
82 changes: 49 additions & 33 deletions torch_xla/csrc/runtime/tensor_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define XLA_CLIENT_TENSOR_SOURCE_H_

#include <ATen/Tensor.h>
#include <torch/csrc/lazy/core/metrics.h>

#include <vector>

Expand All @@ -13,6 +14,45 @@
namespace torch_xla {
namespace runtime {

// TODO: consolidate with the XLA<->ATen type mapping maintained elsewhere.
//
// Maps an XLA primitive type to the ATen scalar type used to hold its data.
// Unsigned 16/32/64-bit XLA types map to the signed ATen type of the same
// width, since ATen has no unsigned equivalents for those widths.
//
// Declared `inline` (rather than in an anonymous namespace) because this
// lives in a header: an anonymous namespace would give every including
// translation unit its own distinct copy of the function.
//
// Errors out via XLA_ERROR() for XLA types with no ATen counterpart
// (e.g. tuple/opaque/token types).
inline at::ScalarType TensorTypeFromXlaType(xla::PrimitiveType xla_type) {
  switch (xla_type) {
    case xla::PrimitiveType::BF16:
      return at::ScalarType::BFloat16;
    case xla::PrimitiveType::F16:
      return at::ScalarType::Half;
    case xla::PrimitiveType::F32:
      return at::ScalarType::Float;
    case xla::PrimitiveType::F64:
      return at::ScalarType::Double;
    case xla::PrimitiveType::PRED:
      return at::ScalarType::Bool;
    case xla::PrimitiveType::U8:
      return at::ScalarType::Byte;
    case xla::PrimitiveType::S8:
      return at::ScalarType::Char;
    case xla::PrimitiveType::S16:
    case xla::PrimitiveType::U16:
      return at::ScalarType::Short;
    case xla::PrimitiveType::S32:
    case xla::PrimitiveType::U32:
      return at::ScalarType::Int;
    case xla::PrimitiveType::S64:
    case xla::PrimitiveType::U64:
      return at::ScalarType::Long;
    case xla::PrimitiveType::C64:
      return at::ScalarType::ComplexFloat;
    case xla::PrimitiveType::C128:
      return at::ScalarType::ComplexDouble;
    default:
      XLA_ERROR() << "XLA type not supported: " << xla_type;
  }
}

// Owns a contiguous block of data with the shape and layout matching `shape()`.
class TensorSource {
public:
Expand Down Expand Up @@ -48,8 +88,15 @@ class AtenSource : public TensorSource {
public:
// Wraps `tensor` as the data source for an XLA literal of shape `shape` on
// `device`. If the tensor's dtype does not match the dtype implied by the
// target XLA shape, the tensor is converted (typically a downcast, e.g.
// F64 -> F32 on devices without double support) before being made
// contiguous; otherwise only a contiguous copy/view is taken.
AtenSource(const at::Tensor& tensor, xla::Shape shape, std::string device)
    : TensorSource(std::move(device)), shape_(std::move(shape)) {
  at::ScalarType target_torch_type = TensorTypeFromXlaType(primitive_type());
  // `tensor.scalar_type()` replaces the deprecated
  // `tensor.type().scalarType()` accessor.
  if (target_torch_type != tensor.scalar_type()) {
    // Track conversions so unexpected downcasts show up in metrics.
    TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1);
    // Reuse the already-computed target type; `.to()`/`.contiguous()`
    // return rvalues, so no std::move is needed (it would be a
    // pessimizing/redundant move).
    tensor_ = tensor.to(target_torch_type).contiguous();
  } else {
    tensor_ = tensor.contiguous();
  }
}

// Returns a read-only pointer to the (contiguous) tensor storage; valid
// only while this AtenSource — and thus `tensor_` — is alive.
const void* data() const override { return tensor_.const_data_ptr(); }

Expand All @@ -68,37 +115,6 @@ class AtenSource : public TensorSource {
return {sizes.begin(), sizes.end()};
}

// xla::PrimitiveType primitive_type() const override {
// switch (tensor_.type().scalarType()) {
// case at::ScalarType::Double:
// return xla::PrimitiveType::F64;
// case at::ScalarType::Float:
// return xla::PrimitiveType::F32;
// case at::ScalarType::BFloat16:
// return xla::PrimitiveType::BF16;
// case at::ScalarType::Half:
// return xla::PrimitiveType::F16;
// case at::ScalarType::Bool:
// return xla::PrimitiveType::PRED;
// case at::ScalarType::Byte:
// return xla::PrimitiveType::U8;
// case at::ScalarType::Char:
// return xla::PrimitiveType::S8;
// case at::ScalarType::Short:
// return xla::PrimitiveType::S16;
// case at::ScalarType::Int:
// return xla::PrimitiveType::S32;
// case at::ScalarType::Long:
// return xla::PrimitiveType::S64;
// case at::ScalarType::ComplexFloat:
// return xla::PrimitiveType::C64;
// case at::ScalarType::ComplexDouble:
// return xla::PrimitiveType::C128;
// default:
// XLA_ERROR() << "Type not supported: " << tensor_.type().scalarType();
// }
// }

private:
at::Tensor tensor_;
xla::Shape shape_;
Expand Down

0 comments on commit 4105f45

Please sign in to comment.