Skip to content

Commit

Permalink
Add support for linalg.lstsq (#8261)
Browse files Browse the repository at this point in the history
  • Loading branch information
matinehAkhlaghinia authored Oct 16, 2024
1 parent 349958f commit 32afdbb
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
"linalg.cholesky_ex",
"linalg.det",
"linalg.ldl_solve",
"linalg.lstsq",
"linalg.lu_solve",
"linalg.matrix_norm",
"linalg.matrix_power",
Expand Down
59 changes: 59 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2415,6 +2415,65 @@ def _aten_linalg_eigh(A, UPLO='L'):
return jnp.linalg.eigh(A, UPLO)


@op(torch.ops.aten.linalg_lstsq)
def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'):
input_dtype = A.dtype

m = A.shape[-2]
n = A.shape[-1]

is_batched = A.ndim > 2

if is_batched:

batch_shape = jnp.broadcast_shapes(A.shape[:-2], B.shape[:-2])
batch_size = int(np.prod(batch_shape))
A_reshaped = A.reshape((batch_size,) + A.shape[-2:])
B_reshaped = B.reshape((batch_size,) + B.shape[-2:])

X, residuals, rank, singular_values = jax.vmap(jnp.linalg.lstsq, in_axes=(0, 0))(A_reshaped, B_reshaped, rcond=rcond)

X = X.reshape(batch_shape + X.shape[-2:])

if driver in ['gelsd', 'gelsy', 'gelss']:
rank = rank.reshape(batch_shape)
else:
rank = jnp.array([], dtype=jnp.int64)

full_rank = jnp.all(rank == n)
if driver == 'gelsy' or m <= n or (not full_rank):
residuals = jnp.array([], dtype=input_dtype)
else:
residuals = residuals.reshape(batch_shape + residuals.shape[-1:])

if driver in ['gelsd', 'gelss']:
singular_values = singular_values.reshape(batch_shape + singular_values.shape[-1:])
else:
singular_values = jnp.array([], dtype=input_dtype)

else:

X, residuals, rank, singular_values = jnp.linalg.lstsq(A, B, rcond=rcond)

if driver not in ['gelsd', 'gelsy', 'gelss']:
rank = jnp.array([], dtype=jnp.int64)

rank_value = None
if rank.size > 0:
rank_value = int(rank.item())
rank = jnp.array(rank_value, dtype=jnp.int64)

# When driver is ‘gels’, assume that A is full-rank.
full_rank = driver == 'gels' or rank_value == n
if driver == 'gelsy' or m <= n or (not full_rank):
residuals = jnp.array([], dtype=input_dtype)

if driver not in ['gelsd', 'gelss']:
singular_values = jnp.array([], dtype=input_dtype)

return X, residuals, rank, singular_values


@op(torch.ops.aten.linalg_ldl_factor_ex)
def _aten_linalg_ldl_factor_ex(A, hermitian=False, check_errors=False):
# TODO: Replace with native LDL when available:
Expand Down

0 comments on commit 32afdbb

Please sign in to comment.