From 349958fb1d2d117f1c19d8924e2a2f55d0d45a31 Mon Sep 17 00:00:00 2001 From: David Huang Date: Wed, 16 Oct 2024 13:09:21 -0700 Subject: [PATCH] Fix op info test for linalg.lu_factor and linalg.ldl_factor (#8263) --- experimental/torch_xla2/test/test_ops.py | 4 --- .../torch_xla2/torch_xla2/ops/jaten.py | 33 +++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index bcc5b05e81c..f63f018f9de 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -28,12 +28,8 @@ "linalg.cholesky", "linalg.cholesky_ex", "linalg.det", - "linalg.ldl_factor", - "linalg.ldl_factor_ex", "linalg.ldl_solve", "linalg.lstsq", - "linalg.lu_factor", - "linalg.lu_factor_ex", "linalg.lu_solve", "linalg.matrix_norm", "linalg.matrix_power", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 2535338fd2a..d3cdb23a56c 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2415,6 +2415,30 @@ def _aten_linalg_eigh(A, UPLO='L'): return jnp.linalg.eigh(A, UPLO) +@op(torch.ops.aten.linalg_ldl_factor_ex) +def _aten_linalg_ldl_factor_ex(A, hermitian=False, check_errors=False): + # TODO: Replace with native LDL when available: + # https://github.com/jax-ml/jax/issues/12779 + # TODO: Not tested for complex inputs. Does not support hermitian=True + pivots = jnp.broadcast_to( + jnp.arange(1, A.shape[-1]+1, dtype=jnp.int32), A.shape[:-1] + ) + info = jnp.zeros(A.shape[:-2], jnp.int32) + C = jnp.linalg.cholesky(A) + if C.size == 0: + return C, pivots, info + + # Fill diagonals of stacked matrices + @functools.partial(jnp.vectorize, signature='(k,k),(k,k)->(k,k)') + def fill_diagonal_batch(x, y): + return jnp.fill_diagonal(x, jnp.diag(y), inplace=False) + + D = C * jnp.eye(C.shape[-1], dtype=A.dtype) + LD = C @ jnp.linalg.inv(D) + LD = fill_diagonal_batch(LD, D*D) + return LD, pivots, info + + @op(torch.ops.aten.linalg_lu) def _aten_linalg_lu(A, pivot=True, out=None): dtype = A.dtype @@ -2445,6 +2469,15 @@ def perm_to_P(perm): return P,L,U +@op(torch.ops.aten.linalg_lu_factor_ex) +def _aten_linalg_lu_factor_ex(A, pivot=True, check_errors=False): + lu, pivots, _ = jax.lax.linalg.lu(A) + # PT pivots vector is 1-indexed + pivots = pivots + 1 + info = jnp.zeros(A.shape[:-2], jnp.int32) + return lu, pivots, info + + @op(torch.ops.aten.gcd) def _aten_gcd(input, other): return jnp.gcd(input, other)