#0: Add support for 0-volume and 1-volume tensors for ttnn::add (#14611)
* #0: Add support for 0-volume and 1-volume tensors for ttnn::add

* #0: Fixed an assert

* #0: Tests pass in debug

* #0: Cover more 0-volume and 1-volume cases

* #0: Replace TT_ASSERT with TT_FATAL

* #0: Remove 0 size check for circular buffer
sminakov-tt authored Nov 6, 2024 · 1 parent fcacbef · commit 179f62b
Showing 12 changed files with 84 additions and 98 deletions.
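A minimal usage sketch of what the commit enables (illustrative only, not part of the commit; it assumes a device opened with ttnn.open_device and follows the from_torch/add/to_torch pattern used in the tests below):

```python
import torch
import ttnn

# Illustrative sketch: 0-volume and 1-volume tensors now flow through ttnn.add.
# Assumes device 0 is available on the host.
device = ttnn.open_device(device_id=0)

empty = ttnn.from_torch(torch.tensor([], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
one = ttnn.from_torch(torch.tensor([2.0], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)

# Broadcasting a 1-element tensor against an empty one yields an empty result.
result = ttnn.to_torch(ttnn.add(empty, one))
print(result.numel())  # 0

ttnn.close_device(device)
```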
78 changes: 53 additions & 25 deletions tests/ttnn/unit_tests/operations/eltwise/test_add.py
@@ -11,7 +11,13 @@


@pytest.mark.parametrize(
"shapes", [[[63, 1, 4], [1, 9, 4]], [[13600, 1, 4], [1, 9, 4]], [[1, 16, 6, 64, 64], [1, 16, 1, 64, 64]]]
"shapes",
[
[[63, 1, 4], [1, 9, 4]],
[[13600, 1, 4], [1, 9, 4]],
[[1, 16, 6, 64, 64], [1, 16, 1, 64, 64]],
[[63, 1, 4], [1, 1, 1]],
],
)
def test_non_4D_channel_bcast(device, shapes):
torch.manual_seed(0)
@@ -34,7 +40,7 @@ def test_non_4D_channel_bcast(device, shapes):


@pytest.mark.parametrize("scalar", [3])
@pytest.mark.parametrize("size", [64])
@pytest.mark.parametrize("size", [64, 1, 0])
def test_add_1D_tensor_and_scalar(device, scalar, size):
torch.manual_seed(0)

@@ -49,11 +55,10 @@ def test_add_1D_tensor_and_scalar(device, scalar, size):
assert output_tensor.shape == (size,)


@pytest.mark.parametrize("h", [32])
@pytest.mark.parametrize("w", [64])
def test_add_2D_tensors(device, h, w):
torch_input_tensor_a = torch.rand((h, w), dtype=torch.bfloat16)
torch_input_tensor_b = torch.rand((h, w), dtype=torch.bfloat16)
@pytest.mark.parametrize("hw", [(32, 64), (1, 1), (0, 0)])
def test_add_2D_tensors(device, hw):
torch_input_tensor_a = torch.rand(hw, dtype=torch.bfloat16)
torch_input_tensor_b = torch.rand(hw, dtype=torch.bfloat16)
torch_output_tensor = torch.add(torch_input_tensor_a, torch_input_tensor_b)

input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
@@ -64,11 +69,10 @@ def test_add_2D_tensors(device, h, w):
assert_with_pcc(torch_output_tensor, output, 0.9999)


@pytest.mark.parametrize("h", [32])
@pytest.mark.parametrize("w", [64])
def test_add_2D_tensors_with_program_cache(device, h, w, use_program_cache):
torch_input_tensor_a = torch.rand((h, w), dtype=torch.bfloat16)
torch_input_tensor_b = torch.rand((h, w), dtype=torch.bfloat16)
@pytest.mark.parametrize("hw", [(32, 64), (1, 1), (0, 0)])
def test_add_2D_tensors_with_program_cache(device, hw, use_program_cache):
torch_input_tensor_a = torch.rand(hw, dtype=torch.bfloat16)
torch_input_tensor_b = torch.rand(hw, dtype=torch.bfloat16)
torch_output_tensor = torch.add(torch_input_tensor_a, torch_input_tensor_b)

input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
@@ -79,11 +83,10 @@ def test_add_2D_tensors_with_program_cache(device, h, w, use_program_cache):
assert_with_pcc(torch_output_tensor, output, 0.9999)


@pytest.mark.parametrize("h", [32])
@pytest.mark.parametrize("w", [64])
@pytest.mark.parametrize("hw", [(32, 64), (1, 1), (0, 0)])
@pytest.mark.parametrize("scalar", [0.42])
def test_add_scalar(device, h, w, scalar):
torch_input_tensor_a = torch.rand((h, w), dtype=torch.bfloat16)
def test_add_scalar(device, hw, scalar):
torch_input_tensor_a = torch.rand(hw, dtype=torch.bfloat16)
torch_output_tensor = scalar + torch_input_tensor_a

input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
@@ -93,11 +96,10 @@ def test_add_scalar(device, h, w, scalar):
assert_with_pcc(torch_output_tensor, output, 0.9999)


@pytest.mark.parametrize("h", [32])
@pytest.mark.parametrize("w", [64])
@pytest.mark.parametrize("hw", [(32, 64), (1, 1), (0, 0)])
@pytest.mark.parametrize("scalar", [0.42])
def test_reverse_add_scalar(device, h, w, scalar):
torch_input_tensor_a = torch.rand((h, w), dtype=torch.bfloat16)
def test_reverse_add_scalar(device, hw, scalar):
torch_input_tensor_a = torch.rand(hw, dtype=torch.bfloat16)
torch_output_tensor = scalar + torch_input_tensor_a

input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
@@ -107,11 +109,10 @@ def test_reverse_add_scalar(device, h, w, scalar):
assert_with_pcc(torch_output_tensor, output, 0.9999)


@pytest.mark.parametrize("h", [32])
@pytest.mark.parametrize("w", [64])
def test_add_4D_tensors(device, h, w):
torch_input_tensor_a = torch.rand((5, 64, h, w), dtype=torch.bfloat16)
torch_input_tensor_b = torch.rand((5, 64, h, w), dtype=torch.bfloat16)
@pytest.mark.parametrize("hw", [(32, 64), (1, 1), (0, 0)])
def test_add_4D_tensors(device, hw):
torch_input_tensor_a = torch.rand((5, 64, hw[0], hw[1]), dtype=torch.bfloat16)
torch_input_tensor_b = torch.rand((5, 64, hw[0], hw[1]), dtype=torch.bfloat16)
torch_output_tensor = torch.add(torch_input_tensor_a, torch_input_tensor_b)

input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
@@ -487,3 +488,30 @@ def test_add_with_block_sharding(device, input_a_sharded, input_b_sharded, out_s
output_tensor = ttnn.to_torch(output_tensor)
assert ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.99988
assert output_tensor.shape == shape


@pytest.mark.parametrize(
"data",
[
([], [], []),
([1], [2], [3]),
([1], [], []),
([], [1], []),
([1, 2], [3], [4, 5]),
([1], [2, 3], [3, 4]),
([1, 2], [3, 4], [4, 6]),
],
)
@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG])
def test_01_volume_tensors(device, data, memory_config):
(a, b, c_golden) = data
a = torch.BFloat16Tensor(a)
b = torch.BFloat16Tensor(b)
assert torch.add(a, b).tolist() == c_golden

ttnn_a = ttnn.from_torch(a, layout=ttnn.TILE_LAYOUT, device=device, memory_config=memory_config)
ttnn_b = ttnn.from_torch(b, layout=ttnn.TILE_LAYOUT, device=device, memory_config=memory_config)
ttnn_c = ttnn.add(ttnn_a, ttnn_b)
c = ttnn.to_torch(ttnn_c).reshape((-1))

assert c.tolist() == c_golden
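For reference, a short torch-only sketch (not part of the commit) of the broadcasting semantics the golden values above rely on: a size-1 dimension stretches to match the other operand, while a size-0 dimension propagates and produces an empty result.

```python
import torch

a = torch.tensor([1.0])          # shape (1,)
b = torch.tensor([])             # shape (0,)
print(torch.add(a, b).tolist())  # [] -- matches the ([1], [], []) case

c = torch.tensor([1.0, 2.0])     # shape (2,)
d = torch.tensor([3.0])          # shape (1,)
print(torch.add(c, d).tolist())  # [4.0, 5.0] -- matches ([1, 2], [3], [4, 5])
```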
2 changes: 0 additions & 2 deletions tt_metal/common/test_tiles.hpp
@@ -47,7 +47,6 @@ std::vector<T> convert_to_tile_layout(
auto face_HW = face_H * face_W;
bool transpose_face = transpose_within_face.has_value() ? transpose_within_face.value() : false;
bool transpose_face_order = transpose_of_faces.has_value() ? transpose_of_faces.value() : false;
TT_ASSERT(data.size() / tile_HW > 0);
TT_ASSERT(data.size() % tile_HW == 0);
int num_tiles = data.size() / tile_HW;
for(int tile_idx = 0; tile_idx < num_tiles; tile_idx++) {
@@ -137,7 +136,6 @@ std::vector<T> convert_to_flat_layout(
auto num_faces_row = tile_H / face_H;
bool transpose_face = transpose_within_face.has_value() ? transpose_within_face.value() : false;
bool transpose_face_order = transpose_of_faces.has_value() ? transpose_of_faces.value() : false;
TT_ASSERT(data.size() / tile_HW > 0);
TT_ASSERT(data.size() % tile_HW == 0);
int num_tiles = data.size() / tile_HW;
for(int tile_idx = 0; tile_idx < num_tiles; tile_idx++) {
4 changes: 2 additions & 2 deletions tt_metal/common/work_split.cpp
@@ -155,10 +155,10 @@ std::tuple<uint32_t, CoreRangeSet, CoreRangeSet, CoreRangeSet, uint32_t, uint32_

CoreRangeSet core_group_1;
CoreRangeSet core_group_2;
uint32_t units_per_core_group_1 = units_to_divide / target_num_cores;
uint32_t units_per_core_group_1 = target_num_cores == 0 ? 0 : units_to_divide / target_num_cores;
uint32_t units_per_core_group_2 = 0;
// Evenly divided units to all target cores
if (units_to_divide % target_num_cores == 0) {
if (target_num_cores == 0 || units_to_divide % target_num_cores == 0) {
core_group_1 = all_cores;
// Uneven division of units across cores
// This case should only be hit when there are more units of work than a full grid of cores
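A standalone sketch of the guarded arithmetic above (a simplified model of the division/modulo logic only, not the real tt_metal helper or its CoreRangeSet handling): with zero units of work no cores are targeted, so the per-core division has to short-circuit instead of dividing by zero.

```python
def split_units(units_to_divide: int, target_num_cores: int):
    # Guarded division: zero target cores means zero units per core.
    units_per_core_group_1 = 0 if target_num_cores == 0 else units_to_divide // target_num_cores
    units_per_core_group_2 = 0
    evenly_divided = target_num_cores == 0 or units_to_divide % target_num_cores == 0
    return units_per_core_group_1, units_per_core_group_2, evenly_divided

print(split_units(0, 0))   # (0, 0, True) -- the 0-volume case that used to divide by zero
print(split_units(64, 8))  # (8, 0, True)
```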
3 changes: 0 additions & 3 deletions tt_metal/impl/buffers/circular_buffer.cpp
@@ -29,9 +29,6 @@ CircularBuffer::CircularBuffer(const CoreRangeSet &core_ranges, const CircularBu
core_ranges_(core_ranges),
config_(config),
locally_allocated_address_(std::nullopt) {
if (this->config_.total_size() == 0) {
TT_THROW("Circular Buffer Config Error: Circular buffer size cannot be 0 B");
}

for (uint8_t buffer_index = 0; buffer_index < NUM_CIRCULAR_BUFFERS; buffer_index++) {
std::optional<DataFormat> data_format_spec = this->config_.data_formats().at(buffer_index);
@@ -25,11 +25,12 @@ struct BlockSplit {
};

inline BlockSplit split_blocks_for_tilize(CoreCoord grid_size, uint32_t nblocks) {
const uint32_t nblocks_per_core = std::ceil(static_cast<float>(nblocks) / (grid_size.x * grid_size.y));
const uint32_t ncores = std::ceil(static_cast<float>(nblocks) / nblocks_per_core);
const uint32_t nblocks_per_core_cliff = nblocks % nblocks_per_core;
size_t grid_area = grid_size.x * grid_size.y;
const uint32_t nblocks_per_core = grid_area == 0 ? 1 : std::ceil(static_cast<float>(nblocks) / grid_area);
const uint32_t ncores = nblocks_per_core == 0 ? nblocks : std::ceil(static_cast<float>(nblocks) / nblocks_per_core);
const uint32_t nblocks_per_core_cliff = nblocks_per_core == 0 ? 0 : nblocks % nblocks_per_core;
const uint32_t ncores_x = grid_size.x;
const uint32_t ncores_y = std::ceil(static_cast<float>(ncores) / ncores_x);
const uint32_t ncores_y = ncores_x == 0 ? 0 : std::ceil(static_cast<float>(ncores) / ncores_x);
const uint32_t ncores_x_cliff = ncores - (ncores_y - 1) * ncores_x;

std::set<CoreRange> core_range, cliff_core_range;
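A standalone sketch of the guarded block-split arithmetic (again only the math is modeled here, under the assumption that the CoreRange bookkeeping is unchanged): every division now degrades gracefully when the grid or the block count is zero.

```python
import math

def split_blocks(grid_x: int, grid_y: int, nblocks: int):
    grid_area = grid_x * grid_y
    # Each division is guarded so an empty grid or zero blocks means "no work".
    nblocks_per_core = 1 if grid_area == 0 else math.ceil(nblocks / grid_area)
    ncores = nblocks if nblocks_per_core == 0 else math.ceil(nblocks / nblocks_per_core)
    nblocks_per_core_cliff = 0 if nblocks_per_core == 0 else nblocks % nblocks_per_core
    ncores_y = 0 if grid_x == 0 else math.ceil(ncores / grid_x)
    return nblocks_per_core, ncores, nblocks_per_core_cliff, ncores_y

print(split_blocks(8, 8, 0))    # (0, 0, 0, 0) -- zero blocks means no cores are used
print(split_blocks(8, 8, 100))  # (2, 50, 0, 7)
```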
@@ -17,11 +17,6 @@ void UntilizeWithUnpadding::validate(const std::vector<Tensor>& input_tensors) c
TT_FATAL(input_tensor_a.get_layout() == Layout::TILE, "Can only untilize tile major data");

TT_FATAL(input_tensor_a.volume() % tt::constants::TILE_HW == 0, "Error");
for (uint32_t i = 0; i < input_tensor_a.get_legacy_shape().rank(); i++) {
TT_FATAL(input_tensor_a.get_legacy_shape()[i] > 0, "Error");
TT_FATAL(this->output_tensor_end[i] < input_tensor_a.get_legacy_shape()[i], "Error");
}

TT_FATAL(((this->output_tensor_end[-1] + 1) % 2 == 0), "Can only unpad to row major tensor of even width");

if (input_tensor_a.memory_config().is_sharded()) {
@@ -219,7 +219,7 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_interleaved(
Device* device = a.device();
CoreCoord grid_size = device->compute_with_storage_grid_size();

uint32_t num_blocks = a.volume() / input_shape[-1] / TILE_HEIGHT;
uint32_t num_blocks = input_shape[-1] == 0 ? 0 : a.volume() / input_shape[-1] / TILE_HEIGHT;
uint32_t num_tiles_per_row = a.get_legacy_shape()[-1] / TILE_WIDTH;

auto [ncores, all_cores, core_range, core_range_cliff, nblocks_per_core, nblocks_per_core_cliff] =
@@ -14,13 +14,13 @@ namespace ttnn::operations::binary {
BinaryDeviceOperation::program_factory_t BinaryDeviceOperation::select_program_factory(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
ZoneScopedN("BinaryDeviceOperation::select_program_factory");
const auto& input_shape_a = tensor_args.input_tensor_a.tensor_attributes->shape;
const auto input_shape_a = tensor_args.input_tensor_a.get_logical_shape();

if (operation_attributes.scalar.has_value()) {
return BroadcastHeightAndWidthMultiCore{};
}

const auto& input_shape_b = tensor_args.input_tensor_b->tensor_attributes->shape;
const auto input_shape_b = tensor_args.input_tensor_b->get_logical_shape();

auto height_a = input_shape_a[-2];
auto width_a = input_shape_a[-1];
@@ -120,15 +120,15 @@ void BinaryDeviceOperation::validate_on_program_cache_hit(
const auto& input_tensor_a = tensor_args.input_tensor_a;
const auto& output_tensor = tensor_args.output_tensor;

const auto& input_shape_a = input_tensor_a.get_shape();
const auto& input_shape_a = input_tensor_a.get_logical_shape();

auto batch_size_0_a = input_shape_a.rank() >= 4 ? input_shape_a[-4] : 1;
auto batch_size_1_a = input_shape_a.rank() >= 3 ? input_shape_a[-3] : 1;
auto height_a = input_shape_a[-2];
auto width_a = input_shape_a[-1];

const auto input_shape_b =
tensor_args.input_tensor_b.has_value() ? tensor_args.input_tensor_b->get_shape() : ttnn::Shape{1, 1};
tensor_args.input_tensor_b.has_value() ? tensor_args.input_tensor_b->get_logical_shape() : ttnn::SimpleShape{1, 1};
auto batch_size_0_b = input_shape_b.rank() >= 4 ? input_shape_b[-4] : 1;
auto batch_size_1_b = input_shape_b.rank() >= 3 ? input_shape_b[-3] : 1;
auto height_b = input_shape_b[-2];
@@ -145,14 +145,9 @@ void BinaryDeviceOperation::validate_on_program_cache_hit(
batch_size_1_a > batch_size_1_b and batch_size_1_b == 1,
"ttnn::operations::binary::BinaryDeviceOperation: batch size mismatch");
}
if (height_a != height_b) {
TT_ASSERT(
height_a > height_b and height_b == 1, "ttnn::operations::binary::BinaryDeviceOperation: height mismatch");
}
if (width_a != width_b) {
TT_ASSERT(
width_a > width_b and width_b == 1, "ttnn::operations::binary::BinaryDeviceOperation: width mismatch");
}

TT_FATAL(height_a == height_b || height_a == 1 || height_b == 1, "ttnn::operations::binary::BinaryDeviceOperation: height mismatch");
TT_FATAL(width_a == width_b || width_a == 1 || width_b == 1, "ttnn::operations::binary::BinaryDeviceOperation: width mismatch");
}

BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_output_shapes(
@@ -165,46 +160,31 @@ BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_outpu
const int rank_b = input_shape_b.rank();
const int larger_rank = std::max(rank_a, rank_b);

// -------------------------------------------------------------------------
// This lambda function computes the broadcasted output shape between two tensors.
// It follows the broadcasting rules to determine the shape of the result
// when performing binary operations on tensors of potentially different shapes and ranks.
//
// Broadcasting Rules Overview:
// - If the two tensors have different ranks, we virtually pad the smaller-rank tensor's shape
// with ones on the left (i.e., higher-order dimensions) until both shapes have the same length.
// - For each dimension (starting from the rightmost), the sizes are compatible if:
// - They are equal, or
// - One of them is 1 (the dimension can be broadcast to match the other size).
// - The result dimension is the maximum of the two sizes.
//
// Key Points:
// - Negative indexing simplifies dimension alignment from the right (least significant dimensions),
// thats essential for correct broadcasting.
// - By defaulting to 1 for missing dimensions, we correctly handle tensors of different ranks.
// - The use of 'std::max' ensures that when one of the dimensions is 1, the other dimension size
// is used, adhering to broadcasting rules. Important! Code assumes that shapes are validated beforehand.
// - The lambda is reused for both logical shapes and padded shapes, ensuring consistency.
// -------------------------------------------------------------------------
auto compute_broadcasted_output = [rank_a, rank_b, larger_rank](const auto& shape_a, const auto& shape_b) {
SmallVector<uint32_t> output_shape(larger_rank, 1);
for (int i = -1; i >= -larger_rank; --i) {
auto dim_a = (i >= -rank_a) ? shape_a[i] : 1;
auto dim_b = (i >= -rank_b) ? shape_b[i] : 1;
output_shape[i + larger_rank] = std::max(dim_a, dim_b);
if (dim_a != 1 && dim_b != 1) {
TT_FATAL(dim_a == dim_b, "Incompatible dimensions {} and {}", dim_a, dim_b);
output_shape[i + larger_rank] = dim_a;
} else {
// One of the dimension is one, calculating the other one
output_shape[i + larger_rank] = dim_a + dim_b - 1;
}
}
return output_shape;
};

const auto logical_shape_a = input_shape_a.logical_shape();
const auto logical_shape_b = input_shape_b.logical_shape();
const auto output_shape = compute_broadcasted_output(logical_shape_a, logical_shape_b);

const auto padded_shape_a = input_shape_a.padded_shape();
const auto padded_shape_b = input_shape_b.padded_shape();
const auto output_shape_with_tile_padding = compute_broadcasted_output(padded_shape_a, padded_shape_b);

return ttnn::Shape(output_shape, output_shape_with_tile_padding);
return ttnn::SimpleShape(compute_broadcasted_output(logical_shape_a, logical_shape_b));
}
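A pure-Python restatement of the broadcast-shape computation above (a sketch for illustration; it assumes shapes are plain lists and that validation has already run, as the original comment notes): dimensions align from the right, missing dimensions count as 1, and when one side is 1 the expression dim_a + dim_b - 1 picks the other size, which also maps a size-0 dimension to 0.

```python
def broadcast_shape(shape_a, shape_b):
    rank = max(len(shape_a), len(shape_b))
    out = [1] * rank
    for i in range(-1, -rank - 1, -1):
        # Missing higher-order dimensions default to 1.
        dim_a = shape_a[i] if -i <= len(shape_a) else 1
        dim_b = shape_b[i] if -i <= len(shape_b) else 1
        if dim_a != 1 and dim_b != 1:
            assert dim_a == dim_b, f"Incompatible dimensions {dim_a} and {dim_b}"
            out[i] = dim_a
        else:
            out[i] = dim_a + dim_b - 1  # one side is 1: picks the other size, or 0
    return out

print(broadcast_shape([1, 16, 6, 64, 64], [1, 16, 1, 64, 64]))  # [1, 16, 6, 64, 64]
print(broadcast_shape([1], [0]))                                 # [0]
```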

BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_output_tensors(
@@ -47,7 +47,7 @@ struct BinaryDeviceOperation {
std::optional<Tensor> input_tensor_b;
std::optional<Tensor> output_tensor;
};
using shape_return_value_t = ttnn::Shape;
using shape_return_value_t = ttnn::SimpleShape;
using tensor_return_value_t = Tensor;

struct ElementWiseMultiCore {
13 changes: 0 additions & 13 deletions ttnn/cpp/ttnn/operations/numpy/functions.hpp
@@ -57,19 +57,6 @@ static Tensor full(
const MemoryConfig& output_mem_config = MemoryConfig{
.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED},
std::optional<Tensor> optional_output_tensor = std::nullopt) {
if (layout == Layout::TILE) {
if (shape.rank() < 2) {
TT_THROW("TILE layout requires rank >= 2");
}
TT_FATAL(
shape[-1] % tt::constants::TILE_WIDTH == 0,
"TILE layout requires width dimension to be multiple of 32");

TT_FATAL(
shape[-2] % tt::constants::TILE_HEIGHT == 0,
"TILE layout requires height dimension to be multiple of 32");
}

constexpr DataType data_type = detail::get_data_type<T>();
auto owned_buffer = tt::tt_metal::owned_buffer::create<T>(tt::tt_metal::compute_volume(shape));
std::fill(std::begin(owned_buffer), std::end(owned_buffer), value);
6 changes: 3 additions & 3 deletions ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp
@@ -10,10 +10,10 @@ namespace {
namespace CMAKE_UNIQUE_NAMESPACE {

size_t round_up(size_t value, size_t multiple) {
TT_FATAL(multiple != 0, "round_up: multiple must not be 0");
if (multiple == 0) {
return value;
}

// can be faster if multiple is power of 2
// return (value + multiple - 1) & ~(multiple - 1);
return ((value + multiple - 1) / multiple) * multiple;
};
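The same relaxation in standalone form (a sketch; only the arithmetic is reproduced): a multiple of 0 now passes the value through unchanged, which is what the zero-sized dimension of a 0-volume tensor needs.

```python
def round_up(value: int, multiple: int) -> int:
    # A multiple of 0 is a no-op instead of a fatal error.
    if multiple == 0:
        return value
    return ((value + multiple - 1) // multiple) * multiple

print(round_up(0, 0))    # 0  -- previously this tripped the TT_FATAL
print(round_up(70, 32))  # 96
```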

8 changes: 4 additions & 4 deletions ttnn/ttnn/operations/core.py
@@ -232,10 +232,10 @@ def _golden_function(tensor, *, torch_rank=None, **kwargs):
if torch_rank is None:
return tensor

while len(tensor.shape) != torch_rank:
while len(tensor.shape) > torch_rank:
if tensor.shape[0] != 1:
raise RuntimeError("ttnn: Unable to squeeze to desired rank!")
tensor = tensor.squeeze()
tensor = tensor.squeeze(0)
return tensor


@@ -304,10 +304,10 @@ def to_torch(
tensor = tensor[slices]

if torch_rank is not None:
while len(tensor.shape) != torch_rank:
while len(tensor.shape) > torch_rank:
if tensor.shape[0] != 1:
raise RuntimeError("ttnn: Unable to squeeze to desired rank!")
tensor = tensor.squeeze()
tensor = tensor.squeeze(0)

return TorchTensor(tensor)
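A torch-only illustration (not from the commit) of why squeeze(0) replaces the bare squeeze() in both loops: squeeze() drops every size-1 dimension at once, so a (1, 1) tensor would jump straight to rank 0 and overshoot a target rank of 1, whereas squeeze(0) removes exactly one leading dimension per iteration.

```python
import torch

t = torch.zeros(1, 1)
print(t.squeeze().shape)   # torch.Size([]) -- rank 0, overshoots torch_rank=1
print(t.squeeze(0).shape)  # torch.Size([1]) -- one leading dim removed per loop step
```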

