Fix type promotion for pow. (pytorch#6745)
ysiraichi authored Mar 18, 2024
1 parent 790e5c8 commit 902aa50
Showing 4 changed files with 92 additions and 39 deletions.
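
This commit makes the XLA lowering of torch.pow follow PyTorch's type-promotion rules: all three pow overloads now derive the result dtype from at::result_type instead of hard-coding an F32 scalar exponent. A minimal repro sketch of the intended behavior (not part of the commit; assumes a working torch_xla install and that xm.xla_device() targets an XLA backend):

# Hedged sketch: the XLA result dtype should match CPU under standard type
# promotion. Before this fix, the hard-coded F32 scalar exponent could
# promote e.g. a bfloat16 base to float32 on the XLA side.
import torch
import torch_xla.core.xla_model as xm

x = torch.arange(10, dtype=torch.bfloat16)
r_cpu = torch.pow(x, 3.0)                      # bfloat16 on CPU
r_xla = torch.pow(x.to(xm.xla_device()), 3.0)  # should match after this commit
assert r_cpu.dtype == r_xla.cpu().dtype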
28 changes: 28 additions & 0 deletions test/test_operations.py
@@ -1867,6 +1867,34 @@ def test_patched_linear_1D_bias(self):
     self.assertTrue(
         torch.allclose(linear.bias.grad.cpu(), linear_cpu.bias.grad))
 
+  def test_pow_dtype_promotion(self):
+
+    def test(dtype):
+
+      def foo(x):
+        return torch.pow(x, 3.0)
+
+      x = torch.arange(10).to(dtype)
+      r = foo(x)
+
+      device = xm.xla_device()
+      Xx = x.to(device)
+      Xr = foo(Xx)
+
+      self.assertEqual(r, Xr.cpu())
+
+    test_dtypes = [
+        torch.bfloat16,
+        torch.float16,
+        torch.float32,
+        torch.float64,
+        torch.cfloat,
+        torch.cdouble,
+    ]
+
+    for dtype in test_dtypes:
+      test(dtype)
+
 
 class MNISTComparator(nn.Module):
 
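To run just the new test, a plain pytest invocation should work (an assumption about the local setup, not part of the commit; test_operations.py is unittest-based, which pytest collects):

    pytest test/test_operations.py -k test_pow_dtype_promotion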
27 changes: 21 additions & 6 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -219,6 +219,15 @@ at::Tensor DoBinaryOp(const at::Tensor& self, const at::Scalar& other,
   return bridge::AtenFromXlaTensor(result);
 }
 
+template <typename B>
+at::Tensor DoBinaryOp(const at::Scalar& self, const at::Tensor& other,
+                      const B& bin_op) {
+  at::ScalarType dtype = at::result_type(self, other);
+  XLATensorPtr other_tensor = bridge::GetXlaTensor(other);
+  XLATensorPtr result = bin_op(self, other_tensor, dtype);
+  return bridge::AtenFromXlaTensor(result);
+}
+
 template <typename B>
 at::Tensor DoBinaryOpWithoutPromo(const at::Tensor& self,
                                   const at::Tensor& other, const B& bin_op) {
@@ -2302,22 +2311,28 @@ at::Tensor XLANativeFunctions::permute_copy(const at::Tensor& self,
 at::Tensor XLANativeFunctions::pow(const at::Tensor& self,
                                    const at::Scalar& exponent) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::pow(bridge::GetXlaTensor(self), exponent));
+  XLATensorPtr (*method_pow)(const XLATensorPtr&, const at::Scalar&,
+                             c10::optional<at::ScalarType>) =
+      tensor_methods::pow;
+  return DoBinaryOp(self, exponent, method_pow);
 }
 
 at::Tensor XLANativeFunctions::pow(const at::Tensor& self,
                                    const at::Tensor& exponent) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  return bridge::AtenFromXlaTensor(tensor_methods::pow(
-      bridge::GetXlaTensor(self), bridge::GetXlaTensor(exponent)));
+  XLATensorPtr (*method_pow)(const XLATensorPtr&, const XLATensorPtr&,
+                             c10::optional<at::ScalarType>) =
+      tensor_methods::pow;
+  return DoBinaryOp(self, exponent, method_pow);
 }
 
 at::Tensor XLANativeFunctions::pow(const at::Scalar& self,
                                    const at::Tensor& exponent) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::pow(self, bridge::GetXlaTensor(exponent)));
+  XLATensorPtr (*method_pow)(const at::Scalar&, const XLATensorPtr&,
+                             c10::optional<at::ScalarType>) =
+      tensor_methods::pow;
+  return DoBinaryOp(self, exponent, method_pow);
 }
 
 at::Tensor XLANativeFunctions::_prelu_kernel(const at::Tensor& self,
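Note on the method_pow locals above: tensor_methods::pow is now an overload set, so the bare name cannot bind to DoBinaryOp's deduced template parameter B. The explicitly typed function pointer selects the intended overload; DoBinaryOp (including the Scalar/Tensor variant added above) then computes the promoted dtype with at::result_type and forwards it as the logical_element_type argument.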
64 changes: 34 additions & 30 deletions torch_xla/csrc/tensor_methods.cpp
@@ -2179,37 +2179,41 @@ XLATensorPtr permute(const XLATensorPtr& input,
       torch::lazy::MakeNode<Permute>(input->GetIrValue(), dimensions));
 }
 
-XLATensorPtr pow(const XLATensorPtr& input, const at::Scalar& exponent) {
+XLATensorPtr pow(const XLATensorPtr& input, const at::Scalar& exponent,
+                 c10::optional<at::ScalarType> logical_element_type) {
   // We want to pass exponent_node as a constant to give XLA more room to
-  // optimize
-  const torch::lazy::BackendDevice& device = input->GetDevice();
-  auto xla_type = MakeXlaPrimitiveType(GetScalarType(exponent), &device);
-  // Float scalar literal in Python defaults to F64. But we want to produce
-  // F32 as this is the default Pytorch behavior.
-  if (xla_type == xla::PrimitiveType::F64) {
-    xla_type = xla::PrimitiveType::F32;
-  }
-  torch::lazy::Value exp_node = ScalarOp(exponent, xla_type);
-  torch::lazy::NodePtr node = Pow(input->GetIrValue(), exp_node);
-  return input->CreateFrom(node, /*logical_element_type=*/c10::nullopt);
-}
-
-XLATensorPtr pow(const XLATensorPtr& input, const XLATensorPtr& exponent) {
-  torch::lazy::NodePtr node = Pow(input->GetIrValue(), exponent->GetIrValue());
-  return input->CreateFrom(node, /*logical_element_type=*/c10::nullopt);
-}
-
-XLATensorPtr pow(const at::Scalar& input, const XLATensorPtr& exponent) {
-  const torch::lazy::BackendDevice& device = exponent->GetDevice();
-  torch::lazy::Value input_node = XLAGraphExecutor::Get()->GetIrValueForScalar(
-      input, MakeXlaPrimitiveType(GetScalarType(input), &device), device);
-  torch::lazy::NodePtr pow_node = Pow(input_node, exponent->GetIrValue());
-  at::ScalarType input_dtype = GetScalarType(input);
-  at::ScalarType exp_dtype = exponent->dtype();
-  at::ScalarType promoted_dtype =
-      MaybeUpcastToHostTorchType(XlaHelpers::PromoteType(
-          XlaTypeFromTorchType(input_dtype), XlaTypeFromTorchType(exp_dtype)));
-  return exponent->CreateFrom(pow_node, promoted_dtype);
+  // optimize.
+  at::ScalarType type =
+      logical_element_type
+          ? *logical_element_type
+          : at::result_type(bridge::AtenFromXlaTensor(input), exponent);
+  return input->CreateFrom(
+      Pow(input->GetIrValue(),
+          ScalarOp(exponent, MakeXlaPrimitiveType(type, &input->GetDevice()))),
+      type);
+}
+
+XLATensorPtr pow(const XLATensorPtr& input, const XLATensorPtr& exponent,
+                 c10::optional<at::ScalarType> logical_element_type) {
+  at::ScalarType type =
+      logical_element_type
+          ? *logical_element_type
+          : at::result_type(bridge::AtenFromXlaTensor(input),
+                            bridge::AtenFromXlaTensor(exponent));
+  return input->CreateFrom(Pow(input->GetIrValue(), exponent->GetIrValue()),
+                           type);
+}
+
+XLATensorPtr pow(const at::Scalar& input, const XLATensorPtr& exponent,
+                 c10::optional<at::ScalarType> logical_element_type) {
+  at::ScalarType type =
+      logical_element_type
+          ? *logical_element_type
+          : at::result_type(input, bridge::AtenFromXlaTensor(exponent));
+  return exponent->CreateFrom(
+      Pow(ScalarOp(input, MakeXlaPrimitiveType(type, &exponent->GetDevice())),
+          exponent->GetIrValue()),
+      type);
 }
 
 XLATensorPtr prelu(const XLATensorPtr& input, const XLATensorPtr& weight) {
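The behavioral fix is in the fallback branch: when logical_element_type is not supplied, the result dtype now comes from at::result_type, replacing the old F64-to-F32 downcast that ignored the base tensor's dtype. A short sketch of the promotion rule being delegated to (plain PyTorch, no XLA required; the comments state standard promotion facts):

import torch

x = torch.arange(10, dtype=torch.bfloat16)
print(torch.result_type(x, 3.0))  # torch.bfloat16: a float scalar keeps the tensor's float dtype
print(torch.result_type(x, 3))    # torch.bfloat16: an int scalar never promotes a float tensor
i = torch.arange(10)              # int64
print(torch.result_type(i, 3.0))  # torch.float32: a float scalar lifts an int tensor to the default float dtype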
12 changes: 9 additions & 3 deletions torch_xla/csrc/tensor_methods.h
@@ -692,9 +692,15 @@ void optimization_barrier_(std::vector<XLATensorPtr>& tensors);
 // Permute the dimensions of this tensor according to the given permutation.
 XLATensorPtr permute(const XLATensorPtr& input, absl::Span<const int64_t> dims);
 
-XLATensorPtr pow(const XLATensorPtr& input, const at::Scalar& exponent);
-XLATensorPtr pow(const XLATensorPtr& input, const XLATensorPtr& exponent);
-XLATensorPtr pow(const at::Scalar& input, const XLATensorPtr& exponent);
+XLATensorPtr pow(
+    const XLATensorPtr& input, const at::Scalar& exponent,
+    c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
+XLATensorPtr pow(
+    const XLATensorPtr& input, const XLATensorPtr& exponent,
+    c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
+XLATensorPtr pow(
+    const at::Scalar& input, const XLATensorPtr& exponent,
+    c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
 
 XLATensorPtr prelu(const XLATensorPtr& input, const XLATensorPtr& weight);
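Defaulting logical_element_type to c10::nullopt keeps these declarations source-compatible: existing callers that pass no dtype get the at::result_type-based promotion shown in tensor_methods.cpp.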
