diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index e3ed7539d31..61d15f6c2f5 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -236,6 +236,11 @@ def _aten_index_select(x, dim, index): return jnp.take(x, index, dim) +@op(torch.ops.aten.cholesky) +def _aten_cholesky(input, upper=False): + return jax.scipy.linalg.cholesky(input, lower=(not upper)) + + # aten.igammac @op(torch.ops.aten.igammac) def _aten_igammac(input, other):