Skip to content

Commit

Permalink
Replace assert with fatal in Shape::get_normalized_index (#12352)
Browse files Browse the repository at this point in the history
  • Loading branch information
ayerofieiev-tt authored Sep 7, 2024
1 parent aa866b6 commit 3a010a7
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,33 @@ namespace ttnn::operations::data_movement {

void TilizeWithValPadding::validate(const std::vector<Tensor>& input_tensors) const {
const auto& input_tensor_a = input_tensors.at(0);
const auto& input_shape = input_tensor_a.get_legacy_shape();
TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE, "Operands need to be on device!");
TT_FATAL(input_tensor_a.buffer() != nullptr, "Operands need to be allocated in buffers on device!");
TT_FATAL(input_tensor_a.get_layout() == Layout::ROW_MAJOR, "Can only tilize row major data");
TT_FATAL(input_tensor_a.get_dtype() == DataType::BFLOAT16);
TT_FATAL(input_tensor_a.get_dtype() == DataType::BFLOAT16, "Can only tilize bfloat16 tensors");
TT_FATAL(input_shape.rank() >= 2, "Input tensor must be of rank >2, but its shape is {}", input_shape);

TT_FATAL(input_tensor_a.get_legacy_shape()[0] <= this->output_tensor_shape[0]);
TT_FATAL(input_tensor_a.get_legacy_shape()[1] <= this->output_tensor_shape[1]);
TT_FATAL(input_tensor_a.get_legacy_shape()[2] <= this->output_tensor_shape[2]);
TT_FATAL(input_tensor_a.get_legacy_shape()[3] <= this->output_tensor_shape[3]);

uint32_t num_rows = this->output_tensor_shape[2];
uint32_t inner_dim = this->output_tensor_shape[3];
TT_FATAL(num_rows % TILE_HEIGHT == 0, "Output shape must be tilizable");
TT_FATAL(inner_dim % TILE_WIDTH == 0, "Output shape must be tilizable");
for (auto i = 0; i < input_shape.rank(); i++) {
TT_FATAL(input_shape[i] <= this->output_tensor_shape[i],
"Output tensor shape {} must be greater than or equal to input shape {} in each dimension, but is smaller in dimension {}",
this->output_tensor_shape, input_shape, i);
}

uint32_t num_rows = this->output_tensor_shape[-1];
uint32_t inner_dim = this->output_tensor_shape[-2];
TT_FATAL(inner_dim % TILE_WIDTH == 0 && num_rows % TILE_HEIGHT == 0,
"To be tilizable output tensor shape {} must be divisible by tile size ({}, {})",
output_tensor_shape, TILE_WIDTH, TILE_HEIGHT);


if (input_tensor_a.memory_config().is_sharded()) {
TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED);
TT_FATAL(this->output_mem_config.memory_layout == input_tensor_a.memory_config().memory_layout);
TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED, "Input tensor must be width sharded");
TT_FATAL(this->output_mem_config.memory_layout == input_tensor_a.memory_config().memory_layout, "Output tensor must have the same memory layout as input tensor");
for (uint32_t i = 0; i < input_tensor_a.get_legacy_shape().rank(); i++) {
if (i != input_tensor_a.get_legacy_shape().rank() - 2) {
TT_FATAL(input_tensor_a.get_legacy_shape()[i] == this->output_tensor_shape[i]);
if (i != input_shape.rank() - 2) {
TT_FATAL(input_shape[i] == this->output_tensor_shape[i]);
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/tensor/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Padding::Padding(const std::vector<PadDimension>& pad_dimensions, PadValue pad_v
const uint32_t Padding::get_normalized_index(std::int64_t index) const {
std::int64_t rank = static_cast<std::int64_t>(this->rank_);
std::uint64_t normalized_index = index >= 0 ? index : rank + index;
TT_ASSERT(
TT_FATAL(
normalized_index >= 0 and normalized_index < rank,
fmt::format(
"Index is out of bounds for the rank, should be between 0 and {} however is {}",
Expand Down Expand Up @@ -164,7 +164,7 @@ const Shape Shape::without_padding() const {
const uint32_t Shape::get_normalized_index(std::int64_t index) const {
std::int64_t rank = static_cast<std::int64_t>(this->rank_);
std::uint64_t normalized_index = index >= 0 ? index : rank + index;
TT_ASSERT(
TT_FATAL(
normalized_index >= 0 and normalized_index < rank,
fmt::format(
"Index is out of bounds for the rank, should be between 0 and {} however is {}",
Expand Down

0 comments on commit 3a010a7

Please sign in to comment.