Skip to content

Commit

Permalink
remove debug prints in lu_unpack implementation (#8282)
Browse files Browse the repository at this point in the history
  • Loading branch information
barney-s authored Oct 18, 2024
1 parent 2d73a5f commit cec5052
Showing 1 changed file with 0 additions and 8 deletions.
8 changes: 0 additions & 8 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -4844,22 +4844,17 @@ def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
tile_shape[-1] = 1
tile_shape[-2] = 1
P = jnp.tile(identity2d, tile_shape)
#print("debug: start permutation matrix:", P)

# closure to be called for each input 2D matrix.
def _lu_unpack_2d(p, pivot):
jax.debug.print("unpack2d: {} {} {}", p , pivot, pivot.size)
_pivot = pivot - 1 # pivots are offset by 1 in jax
indices = jnp.array([*range(n)], dtype=jnp.int32)
def update_indices(i, _indices):
#jax.debug.print("fori <<: {} {} {} {}", i, _indices, _pivot, p)
tmp = _indices[i]
_indices = _indices.at[i].set(_indices[_pivot[i]])
_indices = _indices.at[_pivot[i]].set(tmp)
#jax.debug.print("fori >>: {} {} {} {}", i, _indices, _pivot, p)
return _indices
indices = jax.lax.fori_loop(0, _pivot.size, update_indices, indices)
#jax.debug.print("indices {}", indices)
p = p[jnp.array(indices)]
p = jnp.transpose(p)
return p
Expand Down Expand Up @@ -4888,12 +4883,9 @@ def update_indices(i, _indices):

# reshape result back to P's shape
newRetshape = (*P.shape[:-2], unpackedP.shape[-2], unpackedP.shape[-1])
#print("newshape: {} {}", newRetshape, unpackedP.shape, dim)
P = unpackedP.reshape(newRetshape)
#print("permutation after: ", P)
else:
# emulate pytroch behavior: return empty tensors
P = torch.empty(torch.Size([0]))

#print("debug output:", P, L, U)
return P, L, U

0 comments on commit cec5052

Please sign in to comment.