Skip to content

Commit

Permalink
Unit test pass
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Nov 25, 2024
1 parent 39e67b5 commit 72da142
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
9 changes: 6 additions & 3 deletions experimental/torch_xla2/test/test_libraries.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import unittest
import jax
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.library import Library, impl, impl_abstract
import torch_xla2
from torch_xla2 import tensor
import torch_xla2.export
from torch_xla2.ops import jaten
from torch_xla2.ops import jlibrary

Expand Down Expand Up @@ -56,6 +54,7 @@ class LibraryTest(unittest.TestCase):

def setUp(self):
torch.manual_seed(0)
torch_xla2.default_env().config.use_torch_native_for_cpu_tensor = False

def test_basic_sdpa_library(self):

Expand All @@ -78,3 +77,7 @@ def forward(self, q,k,v):
## stablehlo.composite ops.
self.assertIn("call @mylib.scaled_dot_product_attention", module_str)
self.assertIn("call @mylib.softmax", module_str)


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion experimental/torch_xla2/test/test_tf_integration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import jax
import os
import tempfile
import numpy as np
import tensorflow as tf
import torch
import torch.nn.functional as F
Expand Down
2 changes: 1 addition & 1 deletion experimental/torch_xla2/torch_xla2/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def _to_copy(self, the_tensor, new_dtype, new_device):
arr = jax.device_put(arr, jax_device)
else:
with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
return torch_tensor.to(new_device)
return the_tensor.to(new_device)

return XLATensor2(arr, self)

Expand Down

0 comments on commit 72da142

Please sign in to comment.