
Commit

#0: Fix end_to_end_tests to use get_legacy_shape() instead of shape() because of @arakhmati's recent Tensor class uplift
tt-rkim committed Mar 5, 2024
1 parent c3cbcc3 commit 9a1eb34
Showing 1 changed file with 7 additions and 7 deletions.
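
For context, the fix is a mechanical rename at each call site: after the Tensor class uplift, these assertions read the shape via get_legacy_shape() rather than shape(). A minimal sketch of the updated pattern, assuming the tt_lib API exactly as it appears in the diff below; the device handle would come from a test fixture such as first_grayskull_device and its setup is elided here, so this is illustrative rather than standalone-runnable:

import torch
import tt_lib

# Build a tiled bfloat16 device tensor, mirroring the construction in the test.
x = torch.randn(5, 3, 96, 64)
xtt = (
    tt_lib.tensor.Tensor(x, tt_lib.tensor.DataType.BFLOAT16)
    .to(tt_lib.tensor.Layout.TILE)
    .to(device)  # `device` is assumed to come from the first_grayskull_device fixture
)

# Pre-uplift accessor, removed by this commit:
#   assert xtt.shape() == [5, 3, 96, 64]
# Post-uplift accessor, added by this commit:
assert xtt.get_legacy_shape() == [5, 3, 96, 64]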
14 changes: 7 additions & 7 deletions tests/end_to_end_tests/test_unit_ops.py
@@ -20,55 +20,55 @@ def test_tile_major_reshape_sweep(reset_seeds, first_grayskull_device):

xtt = tt_lib.tensor.Tensor(x, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)
xtt = tt_lib.tensor.reshape(xtt, 5, 3, 96, 64)
- assert xtt.shape() == [5, 3, 96, 64]
+ assert xtt.get_legacy_shape() == [5, 3, 96, 64]
xtt_host = xtt.cpu()
tt_got_back = xtt_host.to(tt_lib.tensor.Layout.ROW_MAJOR).to_torch()
x = x.reshape([5, 3, 96, 64])
eq = torch.equal(x, tt_got_back)
assert eq

xtt = tt_lib.tensor.reshape(xtt, 3, 5, 64, 96)
- assert xtt.shape() == [3, 5, 64, 96]
+ assert xtt.get_legacy_shape() == [3, 5, 64, 96]
xtt_host = xtt.cpu()
tt_got_back = xtt_host.to(tt_lib.tensor.Layout.ROW_MAJOR).to_torch()
x = x.reshape([3, 5, 64, 96])
eq = torch.equal(x, tt_got_back)
assert eq

xtt = tt_lib.tensor.reshape(xtt, -1, 5, 96, 64)
- assert xtt.shape() == [3, 5, 96, 64]
+ assert xtt.get_legacy_shape() == [3, 5, 96, 64]
xtt_host = xtt.cpu()
tt_got_back = xtt_host.to(tt_lib.tensor.Layout.ROW_MAJOR).to_torch()
x = x.reshape([3, 5, 96, 64])
eq = torch.equal(x, tt_got_back)
assert eq

xtt = tt_lib.tensor.reshape(xtt, 3, -1, 64, 96)
- assert xtt.shape() == [3, 5, 64, 96]
+ assert xtt.get_legacy_shape() == [3, 5, 64, 96]
xtt_host = xtt.cpu()
tt_got_back = xtt_host.to(tt_lib.tensor.Layout.ROW_MAJOR).to_torch()
x = x.reshape([3, 5, 64, 96])
eq = torch.equal(x, tt_got_back)
assert eq

xtt = tt_lib.tensor.reshape(xtt, 3, 5, -1, 64)
- assert xtt.shape() == [3, 5, 96, 64]
+ assert xtt.get_legacy_shape() == [3, 5, 96, 64]
xtt_host = xtt.cpu()
tt_got_back = xtt_host.to(tt_lib.tensor.Layout.ROW_MAJOR).to_torch()
x = x.reshape([3, 5, 96, 64])
eq = torch.equal(x, tt_got_back)
assert eq

xtt = tt_lib.tensor.reshape(xtt, 3, 5, 64, -1)
- assert xtt.shape() == [3, 5, 64, 96]
+ assert xtt.get_legacy_shape() == [3, 5, 64, 96]
xtt_host = xtt.cpu()
tt_got_back = xtt_host.to(tt_lib.tensor.Layout.ROW_MAJOR).to_torch()
x = x.reshape([3, 5, 64, 96])
eq = torch.equal(x, tt_got_back)
assert eq

xtt = tt_lib.tensor.reshape(xtt, 3, 5, 32, -1)
- assert xtt.shape() == [3, 5, 32, 96 * 2]
+ assert xtt.get_legacy_shape() == [3, 5, 32, 96 * 2]
xtt_host = xtt.cpu()
tt_got_back = xtt_host.to(tt_lib.tensor.Layout.ROW_MAJOR).to_torch()
x = x.reshape([3, 5, 32, 96 * 2])
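A note on the -1 arguments in the reshape calls above: the assertions indicate that tt_lib.tensor.reshape follows the usual convention of inferring a single -1 dimension from the preserved element count. A quick sketch of the arithmetic behind the final assertion, [3, 5, 32, 96 * 2]:

# Inferring the last dimension of reshape(xtt, 3, 5, 32, -1) on a [3, 5, 96, 64] tensor.
total = 3 * 5 * 96 * 64    # element count, preserved by reshape
fixed = 3 * 5 * 32         # product of the explicitly specified dimensions
inferred = total // fixed  # the dimension the -1 stands for
assert inferred == 192 == 96 * 2  # matches the asserted shape [3, 5, 32, 96 * 2]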
