From 3a010a718a09963c91eda19fde80a7cdd9fe3d2d Mon Sep 17 00:00:00 2001
From: Artem Yerofieiev <169092593+ayerofieiev-tt@users.noreply.github.com>
Date: Fri, 6 Sep 2024 18:19:47 -0700
Subject: [PATCH] Replace assert with fatal in Shape::get_normalized_index
 (#12352)

---
 .../device/tilize_with_val_padding_op.cpp     | 32 +++++++++++--------
 ttnn/cpp/ttnn/tensor/types.cpp                |  4 ++--
 2 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_op.cpp
index d6faf96336a..e041d0ed6d2 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_op.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_op.cpp
@@ -12,27 +12,33 @@ namespace ttnn::operations::data_movement {
 
 void TilizeWithValPadding::validate(const std::vector<Tensor>& input_tensors) const {
     const auto& input_tensor_a = input_tensors.at(0);
+    const auto& input_shape = input_tensor_a.get_legacy_shape();
     TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE, "Operands need to be on device!");
     TT_FATAL(input_tensor_a.buffer() != nullptr, "Operands need to be allocated in buffers on device!");
     TT_FATAL(input_tensor_a.get_layout() == Layout::ROW_MAJOR, "Can only tilize row major data");
-    TT_FATAL(input_tensor_a.get_dtype() == DataType::BFLOAT16);
+    TT_FATAL(input_tensor_a.get_dtype() == DataType::BFLOAT16, "Can only tilize bfloat16 tensors");
+    TT_FATAL(input_shape.rank() >= 2, "Input tensor must be of rank >= 2, but its shape is {}", input_shape);
 
-    TT_FATAL(input_tensor_a.get_legacy_shape()[0] <= this->output_tensor_shape[0]);
-    TT_FATAL(input_tensor_a.get_legacy_shape()[1] <= this->output_tensor_shape[1]);
-    TT_FATAL(input_tensor_a.get_legacy_shape()[2] <= this->output_tensor_shape[2]);
-    TT_FATAL(input_tensor_a.get_legacy_shape()[3] <= this->output_tensor_shape[3]);
-    uint32_t num_rows = this->output_tensor_shape[2];
-    uint32_t inner_dim = this->output_tensor_shape[3];
-    TT_FATAL(num_rows % TILE_HEIGHT == 0, "Output shape must be tilizable");
-    TT_FATAL(inner_dim % TILE_WIDTH == 0, "Output shape must be tilizable");
+    for (auto i = 0; i < input_shape.rank(); i++) {
+        TT_FATAL(input_shape[i] <= this->output_tensor_shape[i],
+            "Output tensor shape {} must be greater than or equal to input shape {} in each dimension, but is smaller in dimension {}",
+            this->output_tensor_shape, input_shape, i);
+    }
+
+    uint32_t num_rows = this->output_tensor_shape[-2];
+    uint32_t inner_dim = this->output_tensor_shape[-1];
+    TT_FATAL(inner_dim % TILE_WIDTH == 0 && num_rows % TILE_HEIGHT == 0,
+        "To be tilizable output tensor shape {} must be divisible by tile size ({}, {})",
+        output_tensor_shape, TILE_WIDTH, TILE_HEIGHT);
+
     if (input_tensor_a.memory_config().is_sharded()) {
-        TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED);
-        TT_FATAL(this->output_mem_config.memory_layout == input_tensor_a.memory_config().memory_layout);
+        TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED, "Input tensor must be width sharded");
+        TT_FATAL(this->output_mem_config.memory_layout == input_tensor_a.memory_config().memory_layout, "Output tensor must have the same memory layout as input tensor");
         for (uint32_t i = 0; i < input_tensor_a.get_legacy_shape().rank(); i++) {
-            if (i != input_tensor_a.get_legacy_shape().rank() - 2) {
-                TT_FATAL(input_tensor_a.get_legacy_shape()[i] == this->output_tensor_shape[i]);
+            if (i != input_shape.rank() - 2) {
+                TT_FATAL(input_shape[i] == this->output_tensor_shape[i]);
             }
         }
     }
 }
 
diff --git a/ttnn/cpp/ttnn/tensor/types.cpp b/ttnn/cpp/ttnn/tensor/types.cpp
index e0209c84df8..627cb115955 100644
--- a/ttnn/cpp/ttnn/tensor/types.cpp
+++ b/ttnn/cpp/ttnn/tensor/types.cpp
@@ -66,7 +66,7 @@ Padding::Padding(const std::vector<std::pair<uint32_t, uint32_t>>& pad_dimensions, PadValue pad_v
 const uint32_t Padding::get_normalized_index(std::int64_t index) const {
     std::int64_t rank = static_cast<std::int64_t>(this->rank_);
     std::uint64_t normalized_index = index >= 0 ? index : rank + index;
-    TT_ASSERT(
+    TT_FATAL(
         normalized_index >= 0 and normalized_index < rank,
         fmt::format(
             "Index is out of bounds for the rank, should be between 0 and {} however is {}",
@@ -164,7 +164,7 @@ const Shape Shape::without_padding() const {
 const uint32_t Shape::get_normalized_index(std::int64_t index) const {
     std::int64_t rank = static_cast<std::int64_t>(this->rank_);
     std::uint64_t normalized_index = index >= 0 ? index : rank + index;
-    TT_ASSERT(
+    TT_FATAL(
         normalized_index >= 0 and normalized_index < rank,
         fmt::format(
             "Index is out of bounds for the rank, should be between 0 and {} however is {}",
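
The point of the change hinges on how the two macros differ. In tt-metal, TT_ASSERT is a debug-style check, while TT_FATAL is evaluated in every build and raises an error on failure. The sketch below illustrates that contrast with stand-in macros; SKETCH_ASSERT and SKETCH_FATAL are illustrative names, not the real tt-metal definitions, and the assumption that the assert-style check vanishes in release builds (like standard assert under NDEBUG) is exactly that, an assumption.

```cpp
#include <cassert>
#include <iostream>
#include <stdexcept>

// Stand-in macros, illustrative only (NOT the real tt-metal definitions).
// Assumed model: the assert-style check is compiled out of release builds,
// while the fatal-style check is evaluated in every build and throws.
#ifdef NDEBUG
    #define SKETCH_ASSERT(cond, msg) ((void)0)               // release: no-op
#else
    #define SKETCH_ASSERT(cond, msg) assert((cond) && (msg)) // debug: aborts
#endif

#define SKETCH_FATAL(cond, msg)                     \
    do {                                            \
        if (!(cond)) throw std::runtime_error(msg); \
    } while (0)

int main() {
    const int rank = 4;
    const int index = 7;  // out of bounds for a rank-4 shape

    // Compile with -DNDEBUG to see this line check nothing at all
    // (in a debug build it aborts here instead):
    SKETCH_ASSERT(index < rank, "index out of bounds");

    // This check fires in every build type:
    try {
        SKETCH_FATAL(index < rank, "index out of bounds");
    } catch (const std::exception& e) {
        std::cout << "caught: " << e.what() << '\n';
    }
    return 0;
}
```

Under this model, an out-of-bounds shape index that previously slipped through release builds unchecked now fails loudly with a descriptive message, which is what the patch is after.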
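For a concrete picture of what get_normalized_index guards, here is a standalone sketch of the same normalization rule: a negative index counts back from the end of the shape, so for rank r the valid inputs are [-r, r). This is also why the rewritten tilize validation can use output_tensor_shape[-2] and output_tensor_shape[-1] for the row and inner dimensions regardless of rank. The function name and the signed arithmetic below are illustrative, not the ttnn API.

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>

// Illustrative sketch of negative-index normalization. Signed arithmetic is
// used so the lower-bound check is meaningful for indices like -5 on rank 4.
std::uint32_t normalized_index(std::int64_t index, std::int64_t rank) {
    std::int64_t normalized = index >= 0 ? index : rank + index;
    if (normalized < 0 || normalized >= rank) {
        throw std::runtime_error(
            "Index is out of bounds for the rank, should be between 0 and " +
            std::to_string(rank - 1) + " however is " + std::to_string(normalized));
    }
    return static_cast<std::uint32_t>(normalized);
}

int main() {
    std::cout << normalized_index(-1, 4) << '\n';  // 3: shape[-1] is the innermost dim
    std::cout << normalized_index(-2, 4) << '\n';  // 2: shape[-2] is the row dim
    try {
        normalized_index(-5, 4);                   // rank + index = -1: rejected
    } catch (const std::exception& e) {
        std::cout << e.what() << '\n';
    }
    return 0;
}
```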