Skip to content

Commit

Permalink
#0: Treat empty shape as scalar
Browse files Browse the repository at this point in the history
  • Loading branch information
sminakov-tt committed Oct 22, 2024
1 parent 0efa98a commit 02a7f64
Show file tree
Hide file tree
Showing 4 changed files with 1 addition and 12 deletions.
1 change: 0 additions & 1 deletion tests/ttnn/unit_tests/gtests/tensor/test_create_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,5 @@ INSTANTIATE_TEST_SUITE_P(
CreateTensorParams{.shape=ttnn::SimpleShape({0, 0, 0, 0})},
CreateTensorParams{.shape=ttnn::SimpleShape({0, 1, 32, 32})},
CreateTensorParams{.shape=ttnn::SimpleShape({0})},
CreateTensorParams{.shape=ttnn::SimpleShape({})}
)
);
6 changes: 1 addition & 5 deletions ttnn/cpp/ttnn/tensor/tensor_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,14 @@ uint32_t element_size_bytes(DataType dtype) {
}

uint32_t get_page_size(DataType dtype, Layout layout, uint32_t total_size_bytes, const ttnn::SimpleShape& shape, const std::optional<Tile>& tile) {
if (shape.rank() == 0) {
return 1;
}

uint32_t W = shape[-1];
uint32_t page_size = 0;
const auto tile_HW = tile.has_value() ? tile->get_tile_hw() : constants::TILE_HW;
const auto bfloat8b_tile_HW = tile.has_value() ? tile_HW + 64 : constants::BFLOAT8_B_TILE_HW;
const auto bfloat4b_tile_HW = tile.has_value() ? tile_HW / 2 + 64 : constants::BFLOAT4_B_TILE_HW;
switch (layout) {
case Layout::ROW_MAJOR: {
uint32_t size_of_element = element_size_bytes(dtype);
uint32_t W = shape.rank() == 0 ? 1 : shape[-1];
page_size = W * size_of_element;
} break;
case Layout::TILE: {
Expand Down
3 changes: 0 additions & 3 deletions ttnn/cpp/ttnn/tensor/tensor_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, const std::

// TODO: Remove this once we switch to SimpleShape .volume()
static std::size_t compute_volume(const tt::tt_metal::LegacyShape& shape) {
if (shape.rank() == 0) {
return 0;
}
size_t volume = 1;
for (auto index = 0; index < shape.rank(); index++) {
volume *= shape[index];
Expand Down
3 changes: 0 additions & 3 deletions ttnn/cpp/ttnn/tensor/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,6 @@ uint32_t& SimpleShape::operator[](int32_t index) {
}

uint64_t SimpleShape::volume() const {
if (value.empty()) {
return 0;
}
return std::accumulate(this->value.begin(), this->value.end(),
uint64_t{1}, std::multiplies<uint64_t>());
}
Expand Down

0 comments on commit 02a7f64

Please sign in to comment.