Skip to content

Commit

Permalink
fix inv and inv_ex
Browse files Browse the repository at this point in the history
  • Loading branch information
dvhg committed Oct 9, 2024
1 parent 07d0823 commit 99f99f2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
2 changes: 0 additions & 2 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
"linalg.cholesky",
"linalg.cholesky_ex",
"linalg.det",
"linalg.inv",
"linalg.inv_ex",
"linalg.ldl_factor",
"linalg.ldl_factor_ex",
"linalg.ldl_solve",
Expand Down
11 changes: 11 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2333,6 +2333,12 @@ def _aten_histc(input, bins=100, min=0, max=0):
return hist


# Used by some linalg functions to raise an exception
# when check_errors == True. This is currently a no-op.
@op(torch.ops.aten._linalg_check_errors)
def _aten_linalg_check_errors(A, api_name, is_matrix):
...

@op(torch.ops.aten.hypot)
def _aten_hypot(input, other):
return jnp.hypot(input, other)
Expand All @@ -2358,6 +2364,11 @@ def _aten_linalg_eig(A):
def _aten_linalg_eigh(A, UPLO='L'):
return jnp.linalg.eigh(A, UPLO)

@op(torch.ops.aten.linalg_inv_ex)
def _aten_linalg_inv_ex(A):

return jnp.linalg.inv(A), jnp.zeros(A.shape[:-2], jnp.int32)


@op(torch.ops.aten.linalg_lu)
def _aten_linalg_lu(A, pivot=True, out=None):
Expand Down

0 comments on commit 99f99f2

Please sign in to comment.