diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index f63f018f9de..dcae8b84e14 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -29,7 +29,6 @@ "linalg.cholesky_ex", "linalg.det", "linalg.ldl_solve", - "linalg.lstsq", "linalg.lu_solve", "linalg.matrix_norm", "linalg.matrix_power", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index d3cdb23a56c..ce7e90d0d8f 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2415,6 +2415,65 @@ def _aten_linalg_eigh(A, UPLO='L'): return jnp.linalg.eigh(A, UPLO) +@op(torch.ops.aten.linalg_lstsq) +def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'): + input_dtype = A.dtype + + m = A.shape[-2] + n = A.shape[-1] + + is_batched = A.ndim > 2 + + if is_batched: + + batch_shape = jnp.broadcast_shapes(A.shape[:-2], B.shape[:-2]) + batch_size = int(np.prod(batch_shape)) + A_reshaped = A.reshape((batch_size,) + A.shape[-2:]) + B_reshaped = B.reshape((batch_size,) + B.shape[-2:]) + + X, residuals, rank, singular_values = jax.vmap(jnp.linalg.lstsq, in_axes=(0, 0))(A_reshaped, B_reshaped, rcond=rcond) + + X = X.reshape(batch_shape + X.shape[-2:]) + + if driver in ['gelsd', 'gelsy', 'gelss']: + rank = rank.reshape(batch_shape) + else: + rank = jnp.array([], dtype=jnp.int64) + + full_rank = jnp.all(rank == n) + if driver == 'gelsy' or m <= n or (not full_rank): + residuals = jnp.array([], dtype=input_dtype) + else: + residuals = residuals.reshape(batch_shape + residuals.shape[-1:]) + + if driver in ['gelsd', 'gelss']: + singular_values = singular_values.reshape(batch_shape + singular_values.shape[-1:]) + else: + singular_values = jnp.array([], dtype=input_dtype) + + else: + + X, residuals, rank, singular_values = jnp.linalg.lstsq(A, B, rcond=rcond) + + if driver not in ['gelsd', 'gelsy', 'gelss']: + rank = jnp.array([], dtype=jnp.int64) + + rank_value = None + if rank.size > 0: + rank_value = int(rank.item()) + rank = jnp.array(rank_value, dtype=jnp.int64) + + # When driver is ‘gels’, assume that A is full-rank. + full_rank = driver == 'gels' or rank_value == n + if driver == 'gelsy' or m <= n or (not full_rank): + residuals = jnp.array([], dtype=input_dtype) + + if driver not in ['gelsd', 'gelss']: + singular_values = jnp.array([], dtype=input_dtype) + + return X, residuals, rank, singular_values + + @op(torch.ops.aten.linalg_ldl_factor_ex) def _aten_linalg_ldl_factor_ex(A, hermitian=False, check_errors=False): # TODO: Replace with native LDL when available: