Fix precision issue of pow(int, float) #6103

Merged
merged 3 commits into master from qihqi/pow on Dec 15, 2023

Conversation

qihqi (Collaborator) commented Dec 12, 2023

fixes #5887

qihqi changed the title from Qihqi/pow to Fix precision issue of pow(int, float) on Dec 12, 2023
qihqi requested a review from JackCaoG on December 12, 2023 at 00:08

The review thread below is attached to this hunk:
auto* xla_node = dynamic_cast<XlaNode*>(node.get());
at::ScalarType dtype =
TorchTypeFromXlaType(xla_node->xla_shape().element_type());
return input->CreateFrom(node, dtype);

Collaborator:
I think you can do something similar to input->CreateFrom(node, /*logical_element_type=*/nullptr) to make sure logical_element_type is not being inherited from input; it will then by default use the xla_shape.
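
A minimal sketch of what the suggested change looks like in context (the surrounding names follow the excerpt above; treating the second CreateFrom argument as an optional logical element type is an assumption based on the reviewer's comment):

// Sketch of the suggestion: pass no explicit logical element type so
// it is not inherited from `input`; the result then defaults to the
// element type of the node's xla_shape.
return input->CreateFrom(node, /*logical_element_type=*/nullptr);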


Collaborator (Author):
done.

JackCaoG (Collaborator) left a comment:

LGTM if CI is green.

wonjoolee95 (Collaborator) left a comment:

LGTM. Unit tests CI should be fixed with a rebase.

qihqi force-pushed the qihqi/pow branch 8 times, most recently from 2c8e384 to e2bafc2, on December 13, 2023 at 21:47

The commit message summarizes the fix:

Currently we cast the float scalar to an int32 tensor (to match input1). This is incorrect: the power is then computed in integer arithmetic and produces incorrect results.

The fixed version casts both operands to float and does the math in float, so the return value is a float tensor instead of an int tensor.
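
As a rough numeric illustration of the issue (plain C++ arithmetic standing in for the tensor op, not the actual XLA lowering):

#include <cmath>
#include <cstdio>

int main() {
  // Buggy path described above: the float exponent is truncated once
  // the computation is forced into integer types, so pow(2, 0.5)
  // collapses to 2^0 = 1.
  int wrong = static_cast<int>(std::pow(2, static_cast<int>(0.5)));
  // Fixed path: both operands are floats, the math happens in float,
  // and the result stays a float (~1.41421, i.e. sqrt(2)).
  double right = std::pow(2.0, 0.5);
  std::printf("wrong=%d right=%f\n", wrong, right);
  return 0;
}
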
qihqi merged commit e500129 into master on Dec 15, 2023
20 checks passed
qihqi added commits that referenced this pull request on Dec 15, 2023
qihqi deleted the qihqi/pow branch on April 29, 2024 at 21:18

Successfully merging this pull request may close these issues.

[Core ATen Opset] Lower aten_pow_Tensor_Tensor