Fix precision issue of pow(int, float) #6103
Conversation
torch_xla/csrc/tensor_methods.cpp (Outdated)
auto* xla_node = dynamic_cast<XlaNode*>(node.get());
at::ScalarType dtype =
    TorchTypeFromXlaType(xla_node->xla_shape().element_type());
return input->CreateFrom(node, dtype);
I think you can do something similar to input->CreateFrom(node, /*logical_element_type=*/nullptr) to make sure logical_element_type is not being inherited from input; it will then by default use the xla_shape.
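A minimal sketch of what the suggestion might look like, assuming CreateFrom accepts an optional logical element type (e.g. an empty c10::optional) and falls back to the dtype implied by the node's xla_shape when none is given; this is not verified against the current pytorch/xla headers:

    // Hedged sketch: passing an empty optional is assumed to stop the result
    // from inheriting input's logical_element_type, so the result dtype is
    // taken from the node's xla_shape instead.
    return input->CreateFrom(node, /*logical_element_type=*/c10::nullopt);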
done.
LGTM if CI is green.
LGTM. Unit tests CI should be fixed with a rebase.
Force-pushed from 2c8e384 to e2bafc2
Currently we cast the float scalar to an int32 tensor (to match input1). This is incorrect because the power is then computed in integer arithmetic and produces wrong results. The fixed version casts both operands to float and does the math in float; the return value is a float tensor instead of an int tensor.
fixes #5887
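The rounding problem described above can be illustrated with plain C++ arithmetic. This is only a standalone illustration, not code from the PR: truncating the float exponent to int and computing the power in integer arithmetic loses the fractional part, while promoting both operands to float gives the expected result.

    #include <cmath>
    #include <cstdio>

    int main() {
      int base = 2;
      float exponent = 2.5f;
      // Old behavior (conceptually): the float exponent is cast to int to
      // match the int input, so pow is evaluated as 2^2 = 4.
      int in_int = static_cast<int>(std::pow(base, static_cast<int>(exponent)));
      // Fixed behavior: both operands are promoted to float, so pow is
      // evaluated as 2^2.5 ~= 5.657 and the result stays a float.
      float in_float = std::pow(static_cast<float>(base), exponent);
      std::printf("int path: %d, float path: %f\n", in_int, in_float);
      return 0;
    }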