Skip to content

Commit

Permalink
tenstorrent#13084: fix return vector optional tensor with launch_op (t…
Browse files Browse the repository at this point in the history
…enstorrent#13085)

tenstorrent#13084: fix return vector optional tensor with launch_op
  • Loading branch information
hschoi4448 authored Sep 30, 2024
1 parent dd3810e commit 517ff55
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 60 deletions.
60 changes: 44 additions & 16 deletions tests/ttnn/unit_tests/operations/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc
from tests.ttnn.utils_for_testing import assert_equal


@pytest.mark.parametrize("height", [64])
Expand All @@ -23,7 +23,7 @@ def test_example(device, height, width):
output_tensor = ttnn.prim.example(input_tensor)
output_tensor = ttnn.to_torch(output_tensor)

assert_with_pcc(torch_output_tensor, output_tensor, 0.99)
assert_equal(torch_output_tensor, output_tensor)


@pytest.mark.parametrize("height", [64])
Expand All @@ -38,38 +38,66 @@ def test_composite_example(device, height, width):
output_tensor = ttnn.composite_example(input_tensor)
output_tensor = ttnn.to_torch(output_tensor)

assert_with_pcc(torch_output_tensor, output_tensor, 0.99)
assert_equal(torch_output_tensor, output_tensor)


@pytest.mark.parametrize("height", [64])
@pytest.mark.parametrize("width", [128])
def test_example_multiple_return(device, height, width):
@pytest.mark.parametrize("return_outputs", [[False, True], [True, False], [True, True]])
def test_example_multiple_return(device, height, width, return_outputs):
torch.manual_seed(0)

return_output1, return_output2 = return_outputs

# run torch
torch_input_tensor = torch.rand((height, width), dtype=torch.bfloat16)
torch_output_tensor = torch_input_tensor

# run TT
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output1, output2 = ttnn.prim.example_multiple_return(input_tensor)
output_tensor1 = ttnn.to_torch(output1)
output_tensor2 = ttnn.to_torch(output2)
output1, output2 = ttnn.prim.example_multiple_return(
input_tensor, return_output1=return_output1, return_output2=return_output2
)

if return_output1:
output_tensor1 = ttnn.to_torch(output1)
assert_equal(torch_output_tensor, output_tensor1)
else:
assert output1 == None

assert_with_pcc(torch_output_tensor, output_tensor1, 0.99)
assert_with_pcc(torch_output_tensor, output_tensor2, 0.99)
if return_output2:
output_tensor2 = ttnn.to_torch(output2)
assert_equal(torch_output_tensor, output_tensor2)
else:
assert output2 == None


@pytest.mark.parametrize("height", [64])
@pytest.mark.parametrize("width", [128])
def test_composite_example_multiple_return(device, height, width):
@pytest.mark.parametrize("return_outputs", [[False, True], [True, False], [True, True]])
def test_composite_example_multiple_return(device, height, width, return_outputs):
torch.manual_seed(0)

return_output1, return_output2 = return_outputs

# run torch
torch_input_tensor = torch.rand((height, width), dtype=torch.bfloat16)
torch_output_tensor = torch_input_tensor

# run TT
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output1, output2 = ttnn.composite_example_multiple_return(input_tensor)
output_tensor1 = ttnn.to_torch(output1)
output_tensor2 = ttnn.to_torch(output2)

assert_with_pcc(torch_output_tensor, output_tensor1, 0.99)
assert_with_pcc(torch_output_tensor, output_tensor2, 0.99)
output1, output2 = ttnn.composite_example_multiple_return(
input_tensor, return_output1=return_output1, return_output2=return_output2
)

if return_output1:
output_tensor1 = ttnn.to_torch(output1)
assert_equal(torch_output_tensor, output_tensor1)
else:
assert output1 == None

if return_output2:
output_tensor2 = ttnn.to_torch(output2)
assert_equal(torch_output_tensor, output_tensor2)
else:
assert output2 == None
11 changes: 7 additions & 4 deletions ttnn/cpp/ttnn/decorators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,15 @@ struct registered_operation_t {
return output_tensors;
} else if constexpr (std::is_same_v<execute_on_worker_thread_return_t, OptionalTensors>) {
// convert tensor to optional tensor
std::vector<std::optional<Tensor>> ret;

auto size = output_tensors.size();
ret.reserve(size);
std::vector<std::optional<Tensor>> ret(size);

auto return_flags = operation_t::create_async_return_flag(std::forward<decltype(args)>(args)...);

for (uint32_t i = 0 ; i < size; i++) {
ret.push_back(output_tensors.at(i));
if (return_flags.at(i)) {
ret[i] = output_tensors.at(i);
}
}
return ret;
} else if constexpr (detail::is_homogenous_tuple<execute_on_worker_thread_return_t, Tensor>()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,17 @@ ExampleMultipleReturnDeviceOperation::program_factory_t ExampleMultipleReturnDev
}

void ExampleMultipleReturnDeviceOperation::validate_on_program_cache_miss(
const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {}
const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {
validate_on_program_cache_hit(attributes, tensor_args);
}

void ExampleMultipleReturnDeviceOperation::validate_on_program_cache_hit(
const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {}
const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {
TT_FATAL(attributes.return_output1 || attributes.return_output2,
"At least one output must be returned. return_output1 = {}, return_output2 = {} ",
attributes.return_output1,
attributes.return_output2);
}

ExampleMultipleReturnDeviceOperation::shape_return_value_t ExampleMultipleReturnDeviceOperation::compute_output_shapes(
const operation_attributes_t&, const tensor_args_t& tensor_args) {
Expand All @@ -29,6 +36,9 @@ ExampleMultipleReturnDeviceOperation::tensor_return_value_t ExampleMultipleRetur
auto output1_shape = output1_shape_opt.value();
auto output2_shape = output2_shape_opt.value();

auto return_output1 = operation_attributes.return_output1;
auto return_output2 = operation_attributes.return_output2;

const auto& input_tensor = tensor_args.input_tensor;
auto output1 = create_device_tensor(
output1_shape,
Expand All @@ -42,14 +52,20 @@ ExampleMultipleReturnDeviceOperation::tensor_return_value_t ExampleMultipleRetur
input_tensor.tensor_attributes->layout,
input_tensor.device());

return {output1, output2};

std::vector<std::optional<Tensor>> ret(2);

if (return_output1) ret[0] = output1;
if (return_output2) ret[1] = output2;

return ret;
}


std::tuple<ExampleMultipleReturnDeviceOperation::operation_attributes_t, ExampleMultipleReturnDeviceOperation::tensor_args_t>
ExampleMultipleReturnDeviceOperation::invoke(const Tensor& input_tensor) {
ExampleMultipleReturnDeviceOperation::invoke(const Tensor& input_tensor, bool return_output1, bool return_output2) {
return {
operation_attributes_t{true, 42},
operation_attributes_t{true, 42, return_output1, return_output2},
tensor_args_t{input_tensor}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ namespace ttnn::operations::examples {
struct ExampleMultipleReturnDeviceOperation {
// Define the operation attributes. This is it to store all variables needed by operations that aren't tensors
struct operation_attributes_t {
bool attribute;
int some_other_attribute;
bool attribute = true;
int some_other_attribute = 42;
uint32_t return_output1 = true;
uint32_t return_output2 = true;
};

// Define the tensor arguments. This is it to store all tensors passed in and/or out of the operation
Expand Down Expand Up @@ -106,7 +108,7 @@ struct ExampleMultipleReturnDeviceOperation {
// The user will be able to call the operation using `tensor_return_value_t output = ttnn::prim::example(input_tensor)` after the op is registered
// Keep in mind that the the overload with `queue_id` argument will be added automatically for primitive operations
// So, the user can also call this operation using `tensor_return_value_t output = ttnn::prim::example(queue_id, input_tensor)`
static std::tuple<operation_attributes_t, tensor_args_t> invoke(const Tensor& input_tensor);
static std::tuple<operation_attributes_t, tensor_args_t> invoke(const Tensor& input_tensor, bool return_output1, bool return_output2);

// Optional methods

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ void kernel_main() {
constexpr bool dst_is_dram1 = get_compile_time_arg_val(1) == 1;
constexpr bool dst_is_dram2 = get_compile_time_arg_val(2) == 1;

#ifdef OUT_SHARDED
cb_wait_front(cb_id_out, num_tiles);
#else

// single-tile ublocks
constexpr uint32_t onetile = 1;
const uint32_t tile_bytes = get_tile_size(cb_id_out);
Expand All @@ -36,23 +32,21 @@ void kernel_main() {
.data_format = data_format
};

#ifdef BACKWARDS
uint32_t end_id = start_id - num_tiles;
for (uint32_t i = start_id; i != end_id; -- i) {
#else
uint32_t end_id = start_id + num_tiles;
for (uint32_t i = start_id; i < end_id; ++ i) {
#endif
cb_wait_front(cb_id_out, onetile);

uint32_t l1_read_addr = get_read_ptr(cb_id_out);
noc_async_write_tile(i, s1, l1_read_addr);
noc_async_write_barrier();
if (dst_addr1 != 0) {
noc_async_write_tile(i, s1, l1_read_addr);
noc_async_write_barrier();
}

noc_async_write_tile(i, s2, l1_read_addr);
noc_async_write_barrier();
if (dst_addr2 != 0) {
noc_async_write_tile(i, s2, l1_read_addr);
noc_async_write_barrier();
}

cb_pop_front(cb_id_out, onetile);
}
#endif
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@ ExampleMultipleReturnDeviceOperation::SingleCore::cached_program_t ExampleMultip

const auto& input_tensor = tensor_args.input_tensor;

auto output_tensor1 = tensor_return_value.at(0).value();
auto output_tensor2 = tensor_return_value.at(1).value();
auto output_tensor1 = tensor_return_value.at(0);
auto output_tensor2 = tensor_return_value.at(1);

auto src_buffer = input_tensor.buffer();
auto dst_buffer1 = output_tensor1.buffer();
auto dst_buffer2 = output_tensor2.buffer();

tt::tt_metal::Program program{};

tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype());
uint32_t single_tile_size = tt::tt_metal::detail::TileSize(cb_data_format);
tt::DataFormat cb_data_format_output = tt::tt_metal::datatype_to_dataformat_converter(output_tensor1.get_dtype());

auto output_dtype = output_tensor1.has_value() ? output_tensor1.value().get_dtype() : output_tensor2.value().get_dtype();
tt::DataFormat cb_data_format_output = tt::tt_metal::datatype_to_dataformat_converter(output_dtype);
uint32_t single_tile_size_output = tt::tt_metal::detail::TileSize(cb_data_format_output);

uint32_t num_tiles = input_tensor.volume() / tt::constants::TILE_HW;
Expand Down Expand Up @@ -56,8 +56,9 @@ ExampleMultipleReturnDeviceOperation::SingleCore::cached_program_t ExampleMultip

bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0;
std::vector<uint32_t> reader_compile_time_args = {(uint32_t)src_is_dram};
bool dst_is_dram1 = dst_buffer1->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0;
bool dst_is_dram2 = dst_buffer2->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0;

bool dst_is_dram1 = output_tensor1.has_value() ? (output_tensor1.value().buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0) : 0;
bool dst_is_dram2 = output_tensor2.has_value() ? (output_tensor2.value().buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0) : 0;
std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t)output_cb_index, (std::uint32_t)dst_is_dram1, (std::uint32_t)dst_is_dram2};

tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel(
Expand Down Expand Up @@ -117,8 +118,10 @@ ExampleMultipleReturnDeviceOperation::SingleCore::cached_program_t ExampleMultip
tt::tt_metal::SetRuntimeArgs(
program, unary_reader_kernel_id, core, {src_buffer->address(), num_tiles_per_core, num_tiles_written});

auto dst_buffer1_address = output_tensor1.has_value() ? output_tensor1.value().buffer()->address() : 0;
auto dst_buffer2_address = output_tensor2.has_value() ? output_tensor2.value().buffer()->address() : 0;
tt::tt_metal::SetRuntimeArgs(
program, unary_writer_kernel_id, core, {dst_buffer1->address(), dst_buffer2->address(), num_tiles_per_core, num_tiles_written});
program, unary_writer_kernel_id, core, {dst_buffer1_address, dst_buffer2_address, num_tiles_per_core, num_tiles_written});
num_tiles_written += num_tiles_per_core;
}

Expand All @@ -137,11 +140,12 @@ void ExampleMultipleReturnDeviceOperation::SingleCore::override_runtime_argument
auto& unary_writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id;

const auto& input_tensor = tensor_args.input_tensor;
auto output_tensor1 = tensor_return_value.at(0).value();
auto output_tensor2 = tensor_return_value.at(0);
auto output_tensor1 = tensor_return_value.at(0);
auto output_tensor2 = tensor_return_value.at(1);

auto src_buffer = input_tensor.buffer();
auto dst_buffer = output_tensor1.buffer();
auto dst_buffer1 = output_tensor1.has_value() ? output_tensor1.value().buffer() : 0;
auto dst_buffer2 = output_tensor2.has_value() ? output_tensor2.value().buffer() : 0;

{
auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, CoreCoord{0, 0});
Expand All @@ -150,9 +154,11 @@ void ExampleMultipleReturnDeviceOperation::SingleCore::override_runtime_argument

{
auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, CoreCoord{0, 0});
runtime_args[0] = dst_buffer->address();
if (output_tensor1.has_value()) {
runtime_args[0] = dst_buffer1->address();
}
if (output_tensor2.has_value()) {
// do something
runtime_args[1] = dst_buffer2->address();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

namespace ttnn::operations::examples {

std::vector<std::optional<Tensor>> CompositeExampleMutipleReturnOperation::invoke(const Tensor& input_tensor) {
return prim::example_multiple_return(input_tensor);
std::vector<std::optional<Tensor>> CompositeExampleMutipleReturnOperation::invoke(const Tensor& input_tensor, bool return_output1, bool return_output2) {
return prim::example_multiple_return(input_tensor, return_output1, return_output2);
}

std::vector<Tensor> CompositeExampleMutipleReturnOperation::create_async_output_tensors(
Expand All @@ -19,4 +19,9 @@ std::vector<Tensor> CompositeExampleMutipleReturnOperation::create_async_output_
Tensor(operation::get_workers_for_op_output({input_tensor}))};
}

std::vector<bool> CompositeExampleMutipleReturnOperation::create_async_return_flag(const Tensor& input_tensor, bool return_output1, bool return_output2) {

return {return_output1, return_output2};
}

} // namespace ttnn::operations::examples
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@ namespace ttnn::operations::examples {
struct CompositeExampleMutipleReturnOperation {
// The user will be able to call this method as `Tensor output = ttnn::composite_example(input_tensor)` after the op
// is registered
static std::vector<std::optional<Tensor>> invoke(const Tensor& input_tensor);
static std::vector<std::optional<Tensor>> invoke(const Tensor& input_tensor, bool return_output1, bool return_output2);

static std::vector<Tensor> create_async_output_tensors(
const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs);

// The parameters of this function must be identical to those of invoke.
static std::vector<bool> create_async_return_flag(
const Tensor& input_tensor, bool return_output1, bool return_output2
);
};

} // namespace ttnn::operations::examples
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,21 @@ void bind_example_multiple_return_operation(py::module& module) {
module,
ttnn::prim::example_multiple_return,
R"doc(example_multiple_return(input_tensor: ttnn.Tensor) -> std::vector<std::optional<ttnn.Tensor>>)doc",
ttnn::pybind_arguments_t{py::arg("input_tensor")});
ttnn::pybind_arguments_t{
py::arg("input_tensor"),
py::arg("return_output1"),
py::arg("return_output2")
});

bind_registered_operation(
module,
ttnn::composite_example_multiple_return,
R"doc(composite_example_multiple_return(input_tensor: ttnn.Tensor) -> std::vector<std::optional<Tensor>>)doc",
ttnn::pybind_arguments_t{py::arg("input_tensor")});
ttnn::pybind_arguments_t{
py::arg("input_tensor"),
py::arg("return_output1"),
py::arg("return_output2")
});
}

} // namespace ttnn::operations::examples
6 changes: 6 additions & 0 deletions ttnn/cpp/ttnn/run_operation_inl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,12 @@ void launch_op(
if (!output_tensor || !local_tensor) {
continue;
}

// The return type is vector<optional<Tensor>>, and this refers to the case where the i-th value is nullopt.
if (output_tensor->tensor_attributes.use_count() != 0 && local_tensor->tensor_attributes.use_count() == 0) {
continue;
}

if (std::holds_alternative<OwnedStorage>(local_tensor->tensor_attributes->storage)) {
TT_ASSERT(
output_tensor->tensor_attributes->dynamic_storage,
Expand Down

0 comments on commit 517ff55

Please sign in to comment.