diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index b242636c5ca..3b4e7d98cc0 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -31,7 +31,6 @@
     "linalg.lu_solve",
     "linalg.matrix_norm",
     "linalg.matrix_power",
-    "linalg.tensorsolve",
     "masked.median",
     "max_pool2d_with_indices_backward",
     "nn.functional.adaptive_avg_pool3d",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jtorch.py b/experimental/torch_xla2/torch_xla2/ops/jtorch.py
index e82621bec3b..b6f4b1ca9d3 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jtorch.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jtorch.py
@@ -376,3 +376,25 @@ def lu_solve(b, LU_data, LU_pivots, **kwargs):
   _pivots = LU_pivots - 1
   x = jax.scipy.linalg.lu_solve((LU_data, _pivots), b)
   return x
+
+@register_function(torch.linalg.tensorsolve)
+def linalg_tensorsolve(A, b, dims=None):
+  # examples:
+  # A = torch.randn(2, 3, 6), b = torch.randn(3, 2)
+  # A = torch.randn(2, 3, 6), b = torch.randn(2, 3) -> torch.Size([3, 6])
+  # A = torch.randn(9, 2, 6, 3) b = torch.randn(6, 3) -> torch.Size([6, 3])
+  # A = torch.randn(9, 2, 3, 6) b = torch.randn(6, 3) -> torch.Size([3, 6])
+  # A = torch.randn(18, 6, 3) b = torch.randn(18) -> torch.Size([6, 3])
+  # A = torch.randn(3, 8, 4, 6) b = torch.randn(4, 6) -> torch.Size([4, 6])
+  # A = torch.randn(3, 8, 1, 2, 2, 6) b = torch.randn(3, 4, 2) -> torch.Size([2, 2, 6])
+
+  # torch allows b to be shaped differently, especially when axes are moved
+  # using dims; jax then raises, e.g.:
+  # ValueError: After moving axes to end, leading shape of a must match shape of b. got a.shape=(3, 2, 6), b.shape=(2, 3)
+  # So we handle the moveaxis here and force b's shape to match what jax expects.
+  if dims is not None:
+    A = jnp.moveaxis(A, dims, len(dims) * (A.ndim - 1,))
+    dims = None
+  if A.shape[:b.ndim] != b.shape:
+    b = jnp.reshape(b, A.shape[:b.ndim])
+  return jnp.linalg.tensorsolve(A, b, axes=dims)
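
A usage sketch of the shape handling described in the comments above: minimal and self-contained, assuming only jax is installed; the shapes and the helper name tensorsolve_like_torch are illustrative and not part of torch_xla2. It shows the two normalization steps the wrapper performs before delegating to jnp.linalg.tensorsolve: moving the dims axes of A to the end, and reshaping b when its layout differs from A's leading axes. The moveaxis destination is spelled here as distinct trailing positions, intended to match what the diff's repeated-index destination achieves.

import jax
import jax.numpy as jnp


def tensorsolve_like_torch(A, b, dims=None):
  # Sketch of the normalization: move the `dims` axes of A to the end
  # ourselves, then make b's shape match A's leading axes as jax expects.
  if dims is not None:
    A = jnp.moveaxis(A, dims, tuple(range(A.ndim - len(dims), A.ndim)))
  if A.shape[:b.ndim] != b.shape:
    b = jnp.reshape(b, A.shape[:b.ndim])
  return jnp.linalg.tensorsolve(A, b)


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))

# Case 1: dims given, as in torch.linalg.tensorsolve(A, b, dims=(0, 2)).
# After moving axes 0 and 2 to the end, A has shape (4, 3, 2, 6, 4), which
# matches b's shape (4, 3, 2); the solution has shape (6, 4).
A = jax.random.normal(key_a, (6, 4, 4, 3, 2))
b = jax.random.normal(key_b, (4, 3, 2))
print(tensorsolve_like_torch(A, b, dims=(0, 2)).shape)  # (6, 4)

# Case 2: b laid out differently from A's leading axes (the first example in
# the comments above); without the reshape, jax raises the quoted ValueError.
A = jax.random.normal(key_a, (2, 3, 6))
b = jax.random.normal(key_b, (3, 2))
print(tensorsolve_like_torch(A, b).shape)  # (6,)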