diff --git a/tests/ttnn/unit_tests/test_to_and_from_torch.py b/tests/ttnn/unit_tests/test_to_and_from_torch.py index 8bd62c6a5fd..4b84f8ea120 100644 --- a/tests/ttnn/unit_tests/test_to_and_from_torch.py +++ b/tests/ttnn/unit_tests/test_to_and_from_torch.py @@ -76,3 +76,10 @@ def test_to_and_from_2D(height, width, dtype, layout): if dtype == ttnn.bfloat8_b: allclose_kwargs["atol"] = 1e-2 assert torch.allclose(torch_input_tensor, torch_output_tensor, **allclose_kwargs) + + +def test_from_torch_large(device): + torch_x = torch.rand((2048, 1024, 32, 32), dtype=torch.bfloat16) + x_tensor = ttnn.from_torch(torch_x, layout=ttnn.TILE_LAYOUT) + x_tensor = ttnn.to_torch(x_tensor) + assert torch.allclose(torch_x, x_tensor) diff --git a/tt_metal/common/test_tiles.hpp b/tt_metal/common/test_tiles.hpp index 8fac4cef897..50674abc39d 100644 --- a/tt_metal/common/test_tiles.hpp +++ b/tt_metal/common/test_tiles.hpp @@ -207,18 +207,18 @@ inline std::vector untilize_nchw(const BufferType& in, tt::stl::Span tilize_nchw(const BufferType& in_rowmajor, tt::stl::Spa return tilized_result; } - int H = shape[shape.size() - 2], W = shape[shape.size() - 1]; - auto batch_size = 1; - for (int i = 0; i < shape.size() - 2; i++) { + uint32_t H = shape[shape.size() - 2], W = shape[shape.size() - 1]; + uint64_t batch_size = 1; + for (uint32_t i = 0; i < shape.size() - 2; i++) { batch_size *= shape[i]; } - int input_volume = batch_size * H * W; + uint64_t input_volume = batch_size * H * W; auto tile_H = tile_shape.has_value() ? tile_shape.value()[0] : tt::constants::TILE_HEIGHT; auto tile_W = tile_shape.has_value() ? tile_shape.value()[1] : tt::constants::TILE_WIDTH; - int OH = round_up_to_tile(H, tile_H); - int OW = round_up_to_tile(W, tile_W); + uint32_t OH = round_up_to_tile(H, tile_H); + uint32_t OW = round_up_to_tile(W, tile_W); tilized_result.resize(batch_size * OH * OW); std::fill(tilized_result.begin(), tilized_result.end(), 0); - int out_index = 0; + uint64_t out_index = 0; for (auto batch_index = 0; batch_index < batch_size; batch_index++) { - for (int hs = 0; hs < H; hs += tile_H) { - for (int ws = 0; ws < W; ws += tile_W) { - for (int ht = 0; ht < tile_H; ht++) { - for (int wt = 0; wt < tile_W; wt++) { + for (auto hs = 0; hs < H; hs += tile_H) { + for (auto ws = 0; ws < W; ws += tile_W) { + for (auto ht = 0; ht < tile_H; ht++) { + for (auto wt = 0; wt < tile_W; wt++) { auto w = wt + ws; auto h = ht + hs; auto in_offs = w + h * W + batch_index * H * W; auto val = (w >= W || h >= H || in_offs >= input_volume) ? 0 : in_rowmajor[in_offs]; - int out_w = (out_index % OW); - int out_h = (out_index / OW) % OH; + auto out_w = (out_index % OW); + auto out_h = (out_index / OW) % OH; TT_ASSERT(w < OW); TT_ASSERT(h < OH); - int out_offs = out_w + out_h * OW + batch_index * OH * OW; + auto out_offs = out_w + out_h * OW + batch_index * OH * OW; tilized_result[out_offs] = val; out_index++; }