From 99f99f29aa09b0789611fef869d3512846bf0f7c Mon Sep 17 00:00:00 2001 From: David Huang Date: Wed, 9 Oct 2024 06:16:05 +0000 Subject: [PATCH] fix inv and inv_ex --- experimental/torch_xla2/test/test_ops.py | 2 -- experimental/torch_xla2/torch_xla2/ops/jaten.py | 11 +++++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index aa07381cfda..f46940611a2 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -29,8 +29,6 @@ "linalg.cholesky", "linalg.cholesky_ex", "linalg.det", - "linalg.inv", - "linalg.inv_ex", "linalg.ldl_factor", "linalg.ldl_factor_ex", "linalg.ldl_solve", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index f671e039839..48762862f01 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2333,6 +2333,12 @@ def _aten_histc(input, bins=100, min=0, max=0): return hist +# Used by some linalg functions to raise an exception +# when check_errors == True. This is currently a no-op. +@op(torch.ops.aten._linalg_check_errors) +def _aten_linalg_check_errors(A, api_name, is_matrix): + ... + @op(torch.ops.aten.hypot) def _aten_hypot(input, other): return jnp.hypot(input, other) @@ -2358,6 +2364,11 @@ def _aten_linalg_eig(A): def _aten_linalg_eigh(A, UPLO='L'): return jnp.linalg.eigh(A, UPLO) +@op(torch.ops.aten.linalg_inv_ex) +def _aten_linalg_inv_ex(A): + + return jnp.linalg.inv(A), jnp.zeros(A.shape[:-2], jnp.int32) + @op(torch.ops.aten.linalg_lu) def _aten_linalg_lu(A, pivot=True, out=None):