Skip to content

Commit

Permalink
Fix unsigned arithmetic bugs in reshape ops (#16253)
Browse files Browse the repository at this point in the history
  • Loading branch information
mrakitaTT authored Dec 24, 2024
1 parent be6ed46 commit d30b8c2
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ ttnn::Tensor convert_tensor_to_rm_reshape_convert_back_to_orig_layout(
//This function turns ND -> MD into 2D->MD for row major and 3D->MD for tiled using a 0 cost view
const auto layout = tensor.get_layout();
const auto tensor_shape = tensor.get_shape();
TT_FATAL((tensor_shape.rank()!=0), "can't do reshape from rank 0 tensor");
TT_FATAL((tensor_shape.rank() != 0), "Can't do reshape from rank 0 tensor");
if(layout == ttnn::ROW_MAJOR_LAYOUT)
{
//Collapse into the second last dimension
uint32_t second_dim = 1;
for (int i=0; i <tensor_shape.rank()-1; i++)
for (int64_t i = 0; i < static_cast<int64_t>(tensor_shape.rank()) - 1; ++i)
{
second_dim = second_dim * tensor_shape[i];
}
Expand All @@ -97,7 +97,7 @@ ttnn::Tensor convert_tensor_to_rm_reshape_convert_back_to_orig_layout(
{
uint32_t third_dim = 1;
//Collapse into the third last dimension
for (int i=0; i <tensor_shape.rank()-2; i++)
for (int64_t i = 0; i < static_cast<int64_t>(tensor_shape.rank()) - 2; ++i)
{
third_dim = third_dim * tensor_shape[i];
}
Expand All @@ -120,7 +120,7 @@ ttnn::Tensor convert_tensor_to_rm_reshape_convert_back_to_orig_layout(
pad_value
);
}
TT_FATAL(false, "layout is neither tile nor row major");
TT_FATAL(false, "Layout is neither tile nor row major");

}

Expand All @@ -136,9 +136,9 @@ ttnn::Tensor fix_shape_and_perform_reshape_on_3D_TILE(
{
//This function turns a TILE 3D->MD into an equivalent 3D->3D conversion and then turns the 3D output back to MD using a 0 cost view
//Collapse into the third last dimension
TT_FATAL((shape.rank()!=0), "can't do reshape to rank 0 tensor");
TT_FATAL((shape.rank() != 0), "Can't do reshape to rank 0 tensor");
uint32_t third_dim = 1;
for (int i=0; i <shape.rank()-2; i++)
for (int64_t i = 0; i < static_cast<int64_t>(shape.rank()) - 2; ++i)
{
third_dim = third_dim * shape[i];
}
Expand Down Expand Up @@ -170,10 +170,10 @@ ttnn::Tensor fix_shape_and_perform_reshape_on_2D_RM(
)
{
//This function turns a RM 2D->MD into an equivalent 2D->2D conversion and then turns the 2D output back to MD using a 0 cost view
TT_FATAL((shape.rank()!=0), "can't do reshape to rank 0 tensor");
TT_FATAL((shape.rank() != 0), "Can't do reshape to rank 0 tensor");
//Collapse into the second last dimension
uint32_t second_dim = 1;
for (int i=0; i <shape.rank()-1; i++)
for (int64_t i = 0; i < static_cast<int64_t>(shape.rank()) - 1; ++i)
{
second_dim = second_dim * shape[i];
}
Expand Down Expand Up @@ -241,8 +241,8 @@ ttnn::Tensor perform_reshape_on_2D_RM(

ttnn::Shape tiling_reshape_corrector(const ttnn::Shape& shape, const uint32_t tile_first_dim, const uint32_t tile_second_dim) {
//Apply the correct padding metadata to the target shape
auto padded = shape.with_tile_padding();
auto rank = shape.rank();
ttnn::Shape padded = shape.with_tile_padding();
int64_t rank = shape.rank();
const int8_t correction_1 =(tile_first_dim - (int)padded[-1] % tile_first_dim) % tile_first_dim;
if(rank == 1)
{
Expand Down

0 comments on commit d30b8c2

Please sign in to comment.