diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 3396eec2c80..3ace00ea495 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -72,7 +72,6 @@ "special.scaled_modified_bessel_k1", "special.spherical_bessel_j0", "special.zeta", - "svd_lowrank", "unfold_copy", "unfold", "randint", @@ -122,6 +121,7 @@ "linalg.pinv": (8e-1, 2e0), "linalg.svd": (1e0, 1e0), "svd": (1e0, 1e0), + "svd_lowrank": (1e0, 1e0), "matrix_exp": (2e-1, 2e-4), "cdist": (5e1, 3e0)} diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 6160a61fb6d..9b24b075cec 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -4559,8 +4559,8 @@ def _aten__linalg_slogdet(input): # torch.linalg.svd @op(torch.ops.aten._linalg_svd) -def _aten__linalg_svd(a, full_matrices=True): - return jnp.linalg.svd(a, full_matrices=full_matrices) +def _aten__linalg_svd(a, full_matrices=False, **kwargs): + return jnp.linalg.svd(a, full_matrices=full_matrices, **kwargs) # torch.linalg.pinv diff --git a/experimental/torch_xla2/torch_xla2/ops/jtorch.py b/experimental/torch_xla2/torch_xla2/ops/jtorch.py index c8b49c35c24..110fa05dbf7 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jtorch.py +++ b/experimental/torch_xla2/torch_xla2/ops/jtorch.py @@ -355,8 +355,8 @@ def linalg_solve_ex(a, b): return res, info @register_function(torch.linalg.svd) -def linalg_svd(a, full_matrices=True, **kwargs): - return jaten._aten__linalg_svd(a, full_matrices=full_matrices, **kwargs) +def linalg_svd(a, full_matrices=True): + return jaten._aten__linalg_svd(a, full_matrices=full_matrices) @register_function(torch.linalg.matrix_power) def matrix_power(A, n, *, out=None):