-
Notifications
You must be signed in to change notification settings - Fork 487
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable cross entropy loss for xla autocast with FP32 precision (#7992) (
- Loading branch information
1 parent
0f645b2
commit 940bee4
Showing
2 changed files
with
30 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import os | ||
import re | ||
import torch | ||
import torch_xla | ||
import torch_xla.core.xla_model as xm | ||
import unittest | ||
|
||
device = xm.xla_device() | ||
|
||
|
||
class TestAutocastXla(unittest.TestCase): | ||
|
||
def test_cross_entropy_loss(self): | ||
data = torch.randn(16, 10).to(torch.bfloat16).to(device) | ||
target = torch.randn(16, 10).to(torch.bfloat16).to(device) | ||
with torch.autocast("xla"): | ||
loss = torch.nn.CrossEntropyLoss()(data, target) | ||
hlo = torch_xla._XLAC._get_xla_tensors_hlo([loss]) | ||
self.assertTrue( | ||
re.search(rf".*convert.*f32.*convert.*bf16", hlo) is not None) | ||
|
||
self.assertTrue( | ||
re.search(rf".*exponential.*f32.*exponential.*f32", hlo) is not None) | ||
|
||
self.assertTrue(re.search(rf".*log.*f32.*log.*f32", hlo) is not None) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters