Skip to content

Commit

Permalink
Fix op info test for linalg.lu_factor and linalg.ldl_factor (pytorch#…
Browse files Browse the repository at this point in the history
  • Loading branch information
dvhg authored Oct 16, 2024
1 parent ee388f6 commit 349958f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
4 changes: 0 additions & 4 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,8 @@
"linalg.cholesky",
"linalg.cholesky_ex",
"linalg.det",
"linalg.ldl_factor",
"linalg.ldl_factor_ex",
"linalg.ldl_solve",
"linalg.lstsq",
"linalg.lu_factor",
"linalg.lu_factor_ex",
"linalg.lu_solve",
"linalg.matrix_norm",
"linalg.matrix_power",
Expand Down
33 changes: 33 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2415,6 +2415,30 @@ def _aten_linalg_eigh(A, UPLO='L'):
return jnp.linalg.eigh(A, UPLO)


@op(torch.ops.aten.linalg_ldl_factor_ex)
def _aten_linalg_ldl_factor_ex(A, hermitian=False, check_errors=False):
# TODO: Replace with native LDL when available:
# https://github.com/jax-ml/jax/issues/12779
# TODO: Not tested for complex inputs. Does not support hermitian=True
pivots = jnp.broadcast_to(
jnp.arange(1, A.shape[-1]+1, dtype=jnp.int32), A.shape[:-1]
)
info = jnp.zeros(A.shape[:-2], jnp.int32)
C = jnp.linalg.cholesky(A)
if C.size == 0:
return C, pivots, info

# Fill diagonals of stacked matrices
@functools.partial(jnp.vectorize, signature='(k,k),(k,k)->(k,k)')
def fill_diagonal_batch(x, y):
return jnp.fill_diagonal(x, jnp.diag(y), inplace=False)

D = C * jnp.eye(C.shape[-1], dtype=A.dtype)
LD = C @ jnp.linalg.inv(D)
LD = fill_diagonal_batch(LD, D*D)
return LD, pivots, info


@op(torch.ops.aten.linalg_lu)
def _aten_linalg_lu(A, pivot=True, out=None):
dtype = A.dtype
Expand Down Expand Up @@ -2445,6 +2469,15 @@ def perm_to_P(perm):
return P,L,U


@op(torch.ops.aten.linalg_lu_factor_ex)
def _aten_linalg_lu_factor_ex(A, pivot=True, check_errors=False):
lu, pivots, _ = jax.lax.linalg.lu(A)
# PT pivots vector is 1-indexed
pivots = pivots + 1
info = jnp.zeros(A.shape[:-2], jnp.int32)
return lu, pivots, info


@op(torch.ops.aten.gcd)
def _aten_gcd(input, other):
return jnp.gcd(input, other)
Expand Down

0 comments on commit 349958f

Please sign in to comment.