From 1e25da99bded8b2f0598f3f447d57850ec20e63e Mon Sep 17 00:00:00 2001 From: Brian Liu Date: Mon, 18 Nov 2024 21:03:11 +0000 Subject: [PATCH] #0: Clang format ttnn/cpp/pybind11/pytensor.cpp --- ttnn/cpp/pybind11/pytensor.cpp | 743 +++++++++++++++++---------------- 1 file changed, 379 insertions(+), 364 deletions(-) diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index 1bd5e3ea56e..c0964d46861 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -2,23 +2,22 @@ // // SPDX-License-Identifier: Apache-2.0 -#include -#include - - +#include #include #include -#include + +#include +#include #include "tensor.hpp" -#include "ttnn/tensor/host_buffer/types.hpp" -#include "ttnn/tensor/tensor_impl.hpp" -#include "ttnn/run_operation.hpp" -#include "tt_metal/tools/profiler/op_profiler.hpp" #include "tt_metal/graph/graph_tracking.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_metal/tools/profiler/op_profiler.hpp" #include "ttnn/core.hpp" +#include "ttnn/run_operation.hpp" +#include "ttnn/tensor/host_buffer/types.hpp" #include "ttnn/tensor/tensor.hpp" -#include "tt_metal/host_api.hpp" +#include "ttnn/tensor/tensor_impl.hpp" namespace py = pybind11; @@ -33,21 +32,21 @@ namespace detail { void log_external_operation( std::size_t operation_id, std::size_t device_operation_id, - const operation::ExternalOperation& operation, - const std::vector& input_tensors) { + const operation::ExternalOperation &operation, + const std::vector &input_tensors) { tt::log_debug(tt::LogOp, "Launching External Operation: \"{}\"", operation.get_type_name()); auto attributes = operation.attributes(); if (not attributes.empty()) { tt::log_debug(tt::LogOp, "Attributes:"); - for (auto&& [name, value] : attributes) { + for (auto &&[name, value] : attributes) { tt::log_debug(tt::LogOp, "\t{} = {}", name, value); } } tt::log_debug(tt::LogOp, "Input std::vector:"); for (auto index = 0; index < input_tensors.size(); index++) { - const auto& tensor = input_tensors[index]; + const auto &tensor = input_tensors[index]; tt::log_debug(tt::LogOp, "\t{}: {}", index, tensor); } @@ -58,14 +57,19 @@ void log_external_operation( void log_external_operation( std::size_t operation_id, std::size_t device_operation_id, - const operation::ExternalOperation& operation, - const std::vector& input_tensors) {} + const operation::ExternalOperation &operation, + const std::vector &input_tensors) {} #endif template -Tensor create_owned_tensor(T* data_ptr, size_t num_elements, tt::stl::Span shape, DataType data_type, Layout layout, const std::optional& optional_tile = std::nullopt) -{ +Tensor create_owned_tensor( + T *data_ptr, + size_t num_elements, + tt::stl::Span shape, + DataType data_type, + Layout layout, + const std::optional &optional_tile = std::nullopt) { auto data = std::vector(data_ptr, data_ptr + num_elements); auto buffer = owned_buffer::create(std::move(data)); auto storage = OwnedStorage{std::move(buffer)}; @@ -73,7 +77,10 @@ Tensor create_owned_tensor(T* data_ptr, size_t num_elements, tt::stl::Span optional_data_type = std::nullopt, const std::optional& optional_tile = std::nullopt, bool enable_borrow = true) { + const py::handle &torch_tensor, + std::optional optional_data_type = std::nullopt, + const std::optional &optional_tile = std::nullopt, + bool enable_borrow = true) { py::object torch = py::module_::import("torch"); if (not py::isinstance(torch_tensor, torch.attr("Tensor"))) { TT_THROW("The argument must be of type torch.Tensor!"); @@ -103,10 +110,10 @@ Tensor convert_torch_tensor_to_tt_tensor( data_type = DataType::UINT32; } else if (torch_dtype.equal(torch.attr("int32"))) { data_type = DataType::INT32; - } else if (torch_dtype.equal(torch.attr("int16"))) { + } else if (torch_dtype.equal(torch.attr("int16"))) { // TODO(arakhmati): add DataType::INT16? data_type = DataType::UINT16; - } else if (torch_dtype.equal(torch.attr("uint8"))) { + } else if (torch_dtype.equal(torch.attr("uint8"))) { data_type = DataType::UINT8; } else { TT_THROW("Unsupported DataType: {}", std::string(py::repr(torch_dtype))); @@ -222,45 +229,27 @@ Tensor convert_torch_tensor_to_tt_tensor( auto data_ptr = reinterpret_cast(torch_data_ptr); auto data = std::vector(data_ptr, data_ptr + num_elements); auto buffer = owned_buffer::create(std::move(data)); - auto tensor = Tensor( - OwnedStorage{buffer}, - shape, - DataType::FLOAT32, - Layout::ROW_MAJOR, - optional_tile) - .to(Layout::TILE); + auto tensor = Tensor(OwnedStorage{buffer}, shape, DataType::FLOAT32, Layout::ROW_MAJOR, optional_tile) + .to(Layout::TILE); auto output_float_data = owned_buffer::get_as(tensor).get(); auto output_packed_data = pack_fp32_vec_as_bfp8_tiles( output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tensor.get_tile()); auto output_buffer = owned_buffer::create(std::move(output_packed_data)); return Tensor( - std::move(OwnedStorage{std::move(output_buffer)}), - shape, - data_type, - Layout::TILE, - tensor.get_tile()); + std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tensor.get_tile()); } case DataType::BFLOAT4_B: { auto data_ptr = reinterpret_cast(torch_data_ptr); auto data = std::vector(data_ptr, data_ptr + num_elements); auto buffer = owned_buffer::create(std::move(data)); - auto tensor = Tensor( - OwnedStorage{buffer}, - shape, - DataType::FLOAT32, - Layout::ROW_MAJOR, - optional_tile) - .to(Layout::TILE); + auto tensor = Tensor(OwnedStorage{buffer}, shape, DataType::FLOAT32, Layout::ROW_MAJOR, optional_tile) + .to(Layout::TILE); auto output_float_data = owned_buffer::get_as(tensor).get(); auto output_packed_data = pack_fp32_vec_as_bfp4_tiles( output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tensor.get_tile()); auto output_buffer = owned_buffer::create(std::move(output_packed_data)); return Tensor( - std::move(OwnedStorage{std::move(output_buffer)}), - shape, - data_type, - Layout::TILE, - tensor.get_tile()); + std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tensor.get_tile()); } default: { TT_THROW("Unsupported DataType: {}", data_type); @@ -270,7 +259,9 @@ Tensor convert_torch_tensor_to_tt_tensor( } Tensor convert_numpy_tensor_to_tt_tensor( - const py::handle &np_tensor, std::optional optional_data_type = std::nullopt, const std::optional& optional_tile = std::nullopt) { + const py::handle &np_tensor, + std::optional optional_data_type = std::nullopt, + const std::optional &optional_tile = std::nullopt) { py::object np = py::module_::import("numpy"); if (not py::isinstance(np_tensor, np.attr("ndarray"))) { TT_THROW("The tensor must be of type numpy.ndarray!"); @@ -397,45 +388,27 @@ Tensor convert_numpy_tensor_to_tt_tensor( auto data_ptr = reinterpret_cast(np_data_ptr); auto data = std::vector(data_ptr, data_ptr + num_elements); auto buffer = owned_buffer::create(std::move(data)); - auto tensor = Tensor( - OwnedStorage{buffer}, - shape, - DataType::FLOAT32, - Layout::ROW_MAJOR, - optional_tile) - .to(Layout::TILE); + auto tensor = Tensor(OwnedStorage{buffer}, shape, DataType::FLOAT32, Layout::ROW_MAJOR, optional_tile) + .to(Layout::TILE); auto output_float_data = owned_buffer::get_as(tensor).get(); auto output_packed_data = pack_fp32_vec_as_bfp8_tiles( output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tensor.get_tile()); auto output_buffer = owned_buffer::create(std::move(output_packed_data)); return Tensor( - std::move(OwnedStorage{std::move(output_buffer)}), - shape, - data_type, - Layout::TILE, - tensor.get_tile()); + std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tensor.get_tile()); } case DataType::BFLOAT4_B: { auto data_ptr = reinterpret_cast(np_data_ptr); auto data = std::vector(data_ptr, data_ptr + num_elements); auto buffer = owned_buffer::create(std::move(data)); - auto tensor = Tensor( - OwnedStorage{buffer}, - shape, - DataType::FLOAT32, - Layout::ROW_MAJOR, - optional_tile) - .to(Layout::TILE); + auto tensor = Tensor(OwnedStorage{buffer}, shape, DataType::FLOAT32, Layout::ROW_MAJOR, optional_tile) + .to(Layout::TILE); auto output_float_data = owned_buffer::get_as(tensor).get(); auto output_packed_data = pack_fp32_vec_as_bfp4_tiles( output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tensor.get_tile()); auto output_buffer = owned_buffer::create(std::move(output_packed_data)); return Tensor( - std::move(OwnedStorage{std::move(output_buffer)}), - shape, - data_type, - Layout::TILE, - tensor.get_tile()); + std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tensor.get_tile()); } default: { TT_THROW("Unsupported DataType: {}", data_type); @@ -445,8 +418,12 @@ Tensor convert_numpy_tensor_to_tt_tensor( } Tensor convert_python_tensor_to_tt_tensor( - const py::handle &tensor, std::optional optional_data_type = std::nullopt, const std::optional& optional_tile = std::nullopt, bool enable_borrow = true) { - GraphTracker::instance().track_function_start("tt::tt_metal::detail::convert_python_tensor_to_tt_tensor", tensor, optional_data_type, enable_borrow); + const py::handle &tensor, + std::optional optional_data_type = std::nullopt, + const std::optional &optional_tile = std::nullopt, + bool enable_borrow = true) { + GraphTracker::instance().track_function_start( + "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor", tensor, optional_data_type, enable_borrow); py::object torch = py::module_::import("torch"); py::object np = py::module_::import("numpy"); if (py::isinstance(tensor, torch.attr("Tensor"))) { @@ -464,8 +441,13 @@ Tensor convert_python_tensor_to_tt_tensor( } } -Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optional data_type, const std::optional tile, const std::unordered_map& strategy) { - GraphTracker::instance().track_function_start("tt::tt_metal::detail::convert_python_tensors_to_tt_tensors", tensor_shards, data_type, strategy); +Tensor convert_python_tensors_to_tt_tensors( + py::list tensor_shards, + std::optional data_type, + const std::optional tile, + const std::unordered_map &strategy) { + GraphTracker::instance().track_function_start( + "tt::tt_metal::detail::convert_python_tensors_to_tt_tensors", tensor_shards, data_type, strategy); std::vector tt_shards; for (const auto &shard : tensor_shards) { tt_shards.push_back(detail::convert_python_tensor_to_tt_tensor(shard, data_type, tile, false)); @@ -473,283 +455,304 @@ Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optiona std::vector host_owned_buffers; std::vector host_owned_shapes; for (const auto &shard : tt_shards) { - TT_ASSERT(std::holds_alternative(shard.get_storage()), "Unexpected type {}", tt::stl::get_active_type_name_in_variant(shard.get_storage())); + TT_ASSERT( + std::holds_alternative(shard.get_storage()), + "Unexpected type {}", + tt::stl::get_active_type_name_in_variant(shard.get_storage())); host_owned_buffers.push_back(std::get(shard.get_storage()).buffer); host_owned_shapes.push_back(shard.shape()); } auto distributed_tensor_config = get_distributed_tensor_config(strategy); auto storage = MultiDeviceHostStorage{distributed_tensor_config, std::move(host_owned_buffers), host_owned_shapes}; - auto output = Tensor(std::move(storage), tt_shards.at(0).get_legacy_shape(), tt_shards.at(0).get_dtype(), tt_shards.at(0).get_layout(), tt_shards.at(0).get_tile()); + auto output = Tensor( + std::move(storage), + tt_shards.at(0).get_legacy_shape(), + tt_shards.at(0).get_dtype(), + tt_shards.at(0).get_layout(), + tt_shards.at(0).get_tile()); output = tt::tt_metal::set_tensor_id(output); GraphTracker::instance().track_function_end(output); return output; } - OwnedBuffer create_owned_buffer_from_vector_of_floats(std::vector&& data, DataType data_type) { - switch (data_type) { - case DataType::BFLOAT8_B: { - auto uint32_vector = pack_fp32_vec_as_bfp8_tiles(data, /*row_major_input=*/false, /*is_exp_a=*/false); - return owned_buffer::create(std::move(uint32_vector)); - } - case DataType::BFLOAT4_B: { - auto uint32_vector = pack_fp32_vec_as_bfp4_tiles(data, /*row_major_input=*/false, /*is_exp_a=*/false); - return owned_buffer::create(std::move(uint32_vector)); - } - case DataType::FLOAT32: { - return owned_buffer::create(std::move(data)); - } - case DataType::BFLOAT16: { - std::vector<::bfloat16> bfloat16_data(data.size()); - std::transform( - std::begin(data), std::end(data), - std::begin(bfloat16_data), - [](float value) { return ::bfloat16(value); } - ); - return owned_buffer::create<::bfloat16>(std::move(bfloat16_data)); - } - default: { - TT_THROW("Cannot create a host buffer!"); - } +OwnedBuffer create_owned_buffer_from_vector_of_floats(std::vector &&data, DataType data_type) { + switch (data_type) { + case DataType::BFLOAT8_B: { + auto uint32_vector = pack_fp32_vec_as_bfp8_tiles(data, /*row_major_input=*/false, /*is_exp_a=*/false); + return owned_buffer::create(std::move(uint32_vector)); } - } - - py::object convert_tt_tensor_to_torch_tensor(const Tensor& tt_tensor) { - GraphTracker::instance().track_function_start("tt::tt_metal::detail::convert_tt_tensor_to_torch_tensor", tt_tensor); - TT_ASSERT(tt_tensor.storage_type() == StorageType::OWNED or tt_tensor.storage_type() == StorageType::BORROWED); - - using namespace pybind11::literals; - py::object torch = py::module_::import("torch"); - auto frombuffer = torch.attr("frombuffer"); - auto buffer = std::visit( - [](auto &&storage) -> std::variant { - using T = std::decay_t; - if constexpr (std::is_same_v) { - return storage.buffer; - } - else if constexpr (std::is_same_v) { - TT_THROW("Device tensor cannot be converted to torch"); - } else if constexpr (std::is_same_v) { - return storage.buffer; - } else if constexpr (std::is_same_v) { - TT_THROW("Tensor with MultiDeviceStorage cannot be converted to torch"); - } else if constexpr (std::is_same_v) { - TT_THROW("Tensor MultiDeviceHostStorage cannot be converted to torch directly. Use composer(..) functionality."); - } else { - raise_unsupported_storage(); - } - }, - tt_tensor.get_storage()); - - const auto& tile = tt_tensor.get_tile(); - auto tt_dtype = tt_tensor.get_dtype(); - if (tt_dtype == DataType::BFLOAT8_B) { - TT_ASSERT(std::holds_alternative(buffer), "Unexpected type {}", tt::stl::get_active_type_name_in_variant(buffer)); - auto uint32_data = std::get>(std::get(buffer)).get(); - auto float_unpacked_data = unpack_bfp8_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); - auto input_float_buffer = owned_buffer::create(std::move(float_unpacked_data)); - auto float_tensor = Tensor( - OwnedStorage{input_float_buffer}, - tt_tensor.get_shape(), - DataType::FLOAT32, - tt_tensor.get_layout(), - tt_tensor.get_tile()) - .to(Layout::ROW_MAJOR); - auto output_float_data = owned_buffer::get_as(float_tensor).get(); - buffer = owned_buffer::create(std::move(output_float_data)); - tt_dtype = DataType::FLOAT32; + case DataType::BFLOAT4_B: { + auto uint32_vector = pack_fp32_vec_as_bfp4_tiles(data, /*row_major_input=*/false, /*is_exp_a=*/false); + return owned_buffer::create(std::move(uint32_vector)); } - if (tt_dtype == DataType::BFLOAT4_B) { - TT_ASSERT(std::holds_alternative(buffer), "Unexpected type {}", tt::stl::get_active_type_name_in_variant(buffer)); - auto uint32_data = std::get>(std::get(buffer)).get(); - auto float_unpacked_data = unpack_bfp4_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); - auto input_float_buffer = owned_buffer::create(std::move(float_unpacked_data)); - auto float_tensor = Tensor( - OwnedStorage{input_float_buffer}, - tt_tensor.get_shape(), - DataType::FLOAT32, - tt_tensor.get_layout(), - tt_tensor.get_tile()) - .to(Layout::ROW_MAJOR); - auto output_float_data = owned_buffer::get_as(float_tensor).get(); - buffer = owned_buffer::create(std::move(output_float_data)); - tt_dtype = DataType::FLOAT32; + case DataType::FLOAT32: { + return owned_buffer::create(std::move(data)); } - - const auto tt_dtype_to_torch_dtype = std::map { - {DataType::UINT8, torch.attr("uint8")}, - {DataType::UINT16, torch.attr("int16")}, // TODO(arakhmati): add DataType::INT16 - {DataType::INT32, torch.attr("int32")}, - {DataType::UINT32, torch.attr("int32")}, // TODO(arakhmati): add DataType::INT32 - {DataType::FLOAT32, torch.attr("float32")}, - {DataType::BFLOAT16, torch.attr("bfloat16")}, - }; - auto torch_dtype = tt_dtype_to_torch_dtype.at(tt_dtype); - - auto shape = tt_tensor.get_legacy_shape(); - auto torch_shape = std::vector(std::begin(shape), std::end(shape)); - auto tensor = [&](){ - if(tt_tensor.volume() == 0) { - auto pytorch_empty = torch.attr("empty"); - auto logical_shape = tt_tensor.get_logical_shape(); - auto view = logical_shape.view(); - std::vector shape_vector(view.begin(), view.end()); - return pytorch_empty(shape_vector, "dtype"_a=torch_dtype); - } - return frombuffer(buffer, "dtype"_a=torch_dtype); - }(); - tensor = tensor.attr("reshape")(torch_shape); - tensor = tensor.attr("contiguous")(); - if (tt_tensor.storage_type() == StorageType::BORROWED) { - tensor = tensor.attr("clone")(); + case DataType::BFLOAT16: { + std::vector<::bfloat16> bfloat16_data(data.size()); + std::transform(std::begin(data), std::end(data), std::begin(bfloat16_data), [](float value) { + return ::bfloat16(value); + }); + return owned_buffer::create<::bfloat16>(std::move(bfloat16_data)); + } + default: { + TT_THROW("Cannot create a host buffer!"); } - GraphTracker::instance().track_function_end(tensor); - return tensor; } +} - py::object convert_tt_tensor_to_numpy_tensor(const Tensor &tt_tensor) { - GraphTracker::instance().track_function_start("tt::tt_metal::detail::convert_tt_tensor_to_torch_tensor", tt_tensor); - TT_ASSERT(tt_tensor.storage_type() == StorageType::OWNED or tt_tensor.storage_type() == StorageType::BORROWED); - - using namespace pybind11::literals; - py::object np = py::module_::import("numpy"); - auto frombuffer = np.attr("frombuffer"); - - auto buffer = std::visit( - [](auto &&storage) -> std::variant { - using T = std::decay_t; - if constexpr (std::is_same_v) { - return storage.buffer; - } else if constexpr (std::is_same_v) { - TT_THROW("Device tensor cannot be converted to numpy"); - } else if constexpr (std::is_same_v) { - return storage.buffer; - } else if constexpr (std::is_same_v) { - TT_THROW("Device tensor cannot be converted to numpy"); - } else if constexpr (std::is_same_v) { - TT_THROW("Device tensor cannot be converted to torch"); - } else { - raise_unsupported_storage(); - } - }, - tt_tensor.get_storage()); - - const auto& tile = tt_tensor.get_tile(); - auto tt_dtype = tt_tensor.get_dtype(); - if (tt_dtype == DataType::BFLOAT8_B) { - TT_ASSERT(std::holds_alternative(buffer), "Unexpected type {}", tt::stl::get_active_type_name_in_variant(buffer)); - auto uint32_data = std::get>(std::get(buffer)).get(); - auto float_unpacked_data = - unpack_bfp8_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); - auto input_float_buffer = owned_buffer::create(std::move(float_unpacked_data)); - auto float_tensor = Tensor( - OwnedStorage{input_float_buffer}, - tt_tensor.get_shape(), - DataType::FLOAT32, - tt_tensor.get_layout(), - tt_tensor.get_tile()) - .to(Layout::ROW_MAJOR); - auto output_float_data = owned_buffer::get_as(float_tensor).get(); - buffer = owned_buffer::create(std::move(output_float_data)); - tt_dtype = DataType::FLOAT32; - } - if (tt_dtype == DataType::BFLOAT4_B) { - TT_ASSERT(std::holds_alternative(buffer), "Unexpected type {}", tt::stl::get_active_type_name_in_variant(buffer)); - auto uint32_data = std::get>(std::get(buffer)).get(); - auto float_unpacked_data = - unpack_bfp4_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); - auto input_float_buffer = owned_buffer::create(std::move(float_unpacked_data)); - auto float_tensor = Tensor( - OwnedStorage{input_float_buffer}, - tt_tensor.get_shape(), - DataType::FLOAT32, - tt_tensor.get_layout(), - tt_tensor.get_tile()) - .to(Layout::ROW_MAJOR); - auto output_float_data = owned_buffer::get_as(float_tensor).get(); - buffer = owned_buffer::create(std::move(output_float_data)); - tt_dtype = DataType::FLOAT32; - } +py::object convert_tt_tensor_to_torch_tensor(const Tensor &tt_tensor) { + GraphTracker::instance().track_function_start("tt::tt_metal::detail::convert_tt_tensor_to_torch_tensor", tt_tensor); + TT_ASSERT(tt_tensor.storage_type() == StorageType::OWNED or tt_tensor.storage_type() == StorageType::BORROWED); - const auto tt_dtype_to_np_dtype = std::map{ - {DataType::UINT8, np.attr("ubyte")}, - {DataType::UINT16, np.attr("int16")}, // TODO(arakhmati): add DataType::INT16 - {DataType::INT32, np.attr("int32")}, - {DataType::UINT32, np.attr("int32")}, // TODO(arakhmati): add DataType::INT32 - {DataType::FLOAT32, np.attr("float32")}, - }; - auto np_dtype = tt_dtype_to_np_dtype.at(tt_dtype); - - auto shape = tt_tensor.get_legacy_shape(); - auto np_shape = std::vector(std::begin(shape), std::end(shape)); - auto tensor = frombuffer(buffer, "dtype"_a = np_dtype); - tensor = tensor.attr("reshape")(np_shape); - tensor = np.attr("ascontiguousarray")(tensor); - GraphTracker::instance().track_function_end(tensor); - return tensor; + using namespace pybind11::literals; + py::object torch = py::module_::import("torch"); + auto frombuffer = torch.attr("frombuffer"); + auto buffer = std::visit( + [](auto &&storage) -> std::variant { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return storage.buffer; + } else if constexpr (std::is_same_v) { + TT_THROW("Device tensor cannot be converted to torch"); + } else if constexpr (std::is_same_v) { + return storage.buffer; + } else if constexpr (std::is_same_v) { + TT_THROW("Tensor with MultiDeviceStorage cannot be converted to torch"); + } else if constexpr (std::is_same_v) { + TT_THROW( + "Tensor MultiDeviceHostStorage cannot be converted to torch directly. Use composer(..) " + "functionality."); + } else { + raise_unsupported_storage(); + } + }, + tt_tensor.get_storage()); + + const auto &tile = tt_tensor.get_tile(); + auto tt_dtype = tt_tensor.get_dtype(); + if (tt_dtype == DataType::BFLOAT8_B) { + TT_ASSERT( + std::holds_alternative(buffer), + "Unexpected type {}", + tt::stl::get_active_type_name_in_variant(buffer)); + auto uint32_data = std::get>(std::get(buffer)).get(); + auto float_unpacked_data = + unpack_bfp8_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); + auto input_float_buffer = owned_buffer::create(std::move(float_unpacked_data)); + auto float_tensor = Tensor( + OwnedStorage{input_float_buffer}, + tt_tensor.get_shape(), + DataType::FLOAT32, + tt_tensor.get_layout(), + tt_tensor.get_tile()) + .to(Layout::ROW_MAJOR); + auto output_float_data = owned_buffer::get_as(float_tensor).get(); + buffer = owned_buffer::create(std::move(output_float_data)); + tt_dtype = DataType::FLOAT32; + } + if (tt_dtype == DataType::BFLOAT4_B) { + TT_ASSERT( + std::holds_alternative(buffer), + "Unexpected type {}", + tt::stl::get_active_type_name_in_variant(buffer)); + auto uint32_data = std::get>(std::get(buffer)).get(); + auto float_unpacked_data = + unpack_bfp4_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); + auto input_float_buffer = owned_buffer::create(std::move(float_unpacked_data)); + auto float_tensor = Tensor( + OwnedStorage{input_float_buffer}, + tt_tensor.get_shape(), + DataType::FLOAT32, + tt_tensor.get_layout(), + tt_tensor.get_tile()) + .to(Layout::ROW_MAJOR); + auto output_float_data = owned_buffer::get_as(float_tensor).get(); + buffer = owned_buffer::create(std::move(output_float_data)); + tt_dtype = DataType::FLOAT32; } - auto parse_external_operation( - const py::function &external_operation, - const py::args &args, - const py::kwargs &kwargs, - std::optional function_name_override = std::nullopt) { - std::string function_name; - if (function_name_override.has_value()) { - function_name = function_name_override.value(); - } else { - function_name = py::cast(external_operation.attr("__qualname__")); + const auto tt_dtype_to_torch_dtype = std::map{ + {DataType::UINT8, torch.attr("uint8")}, + {DataType::UINT16, torch.attr("int16")}, // TODO(arakhmati): add DataType::INT16 + {DataType::INT32, torch.attr("int32")}, + {DataType::UINT32, torch.attr("int32")}, // TODO(arakhmati): add DataType::INT32 + {DataType::FLOAT32, torch.attr("float32")}, + {DataType::BFLOAT16, torch.attr("bfloat16")}, + }; + auto torch_dtype = tt_dtype_to_torch_dtype.at(tt_dtype); + + auto shape = tt_tensor.get_legacy_shape(); + auto torch_shape = std::vector(std::begin(shape), std::end(shape)); + auto tensor = [&]() { + if (tt_tensor.volume() == 0) { + auto pytorch_empty = torch.attr("empty"); + auto logical_shape = tt_tensor.get_logical_shape(); + auto view = logical_shape.view(); + std::vector shape_vector(view.begin(), view.end()); + return pytorch_empty(shape_vector, "dtype"_a = torch_dtype); } + return frombuffer(buffer, "dtype"_a = torch_dtype); + }(); + tensor = tensor.attr("reshape")(torch_shape); + tensor = tensor.attr("contiguous")(); + if (tt_tensor.storage_type() == StorageType::BORROWED) { + tensor = tensor.attr("clone")(); + } + GraphTracker::instance().track_function_end(tensor); + return tensor; +} + +py::object convert_tt_tensor_to_numpy_tensor(const Tensor &tt_tensor) { + GraphTracker::instance().track_function_start("tt::tt_metal::detail::convert_tt_tensor_to_torch_tensor", tt_tensor); + TT_ASSERT(tt_tensor.storage_type() == StorageType::OWNED or tt_tensor.storage_type() == StorageType::BORROWED); - std::vector input_tensors; - tt::stl::reflection::Attributes attributes; - - auto process_name_and_value = [&function_name, &input_tensors, &attributes]( - const auto &name, const auto &value) { - py::object torch = py::module_::import("torch"); - py::object ttnn = py::module_::import("ttnn"); - if (py::isinstance(value)) { - // TODO(arakhmati): figure out how to handle this without causing extra memory usage - // auto tensor = py::cast(value); - // input_tensors.push_back(tensor); - } else if (py::isinstance(value, ttnn.attr("Tensor"))) { - // TODO(arakhmati): figure out how to handle this without causing extra memory usage - // auto tensor = py::cast(value.attr("value")); - // input_tensors.push_back(tensor); - } else if (py::isinstance(value, torch.attr("nn").attr("Module"))) { - // do nothing - } else if (py::isinstance(value, torch.attr("Tensor"))) { - // TODO(arakhmati): figure out how to handle this without causing extra memory usage - // auto tensor = detail::convert_torch_tensor_to_tt_tensor(value); - // input_tensors.push_back(tensor); + using namespace pybind11::literals; + py::object np = py::module_::import("numpy"); + auto frombuffer = np.attr("frombuffer"); + + auto buffer = std::visit( + [](auto &&storage) -> std::variant { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return storage.buffer; + } else if constexpr (std::is_same_v) { + TT_THROW("Device tensor cannot be converted to numpy"); + } else if constexpr (std::is_same_v) { + return storage.buffer; + } else if constexpr (std::is_same_v) { + TT_THROW("Device tensor cannot be converted to numpy"); + } else if constexpr (std::is_same_v) { + TT_THROW("Device tensor cannot be converted to torch"); } else { - // TODO(MO): Exclude tensor data as it is not an attribute - //attributes.push_back({name, fmt::format("{}", value)}); + raise_unsupported_storage(); } - }; + }, + tt_tensor.get_storage()); + + const auto &tile = tt_tensor.get_tile(); + auto tt_dtype = tt_tensor.get_dtype(); + if (tt_dtype == DataType::BFLOAT8_B) { + TT_ASSERT( + std::holds_alternative(buffer), + "Unexpected type {}", + tt::stl::get_active_type_name_in_variant(buffer)); + auto uint32_data = std::get>(std::get(buffer)).get(); + auto float_unpacked_data = + unpack_bfp8_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); + auto input_float_buffer = owned_buffer::create(std::move(float_unpacked_data)); + auto float_tensor = Tensor( + OwnedStorage{input_float_buffer}, + tt_tensor.get_shape(), + DataType::FLOAT32, + tt_tensor.get_layout(), + tt_tensor.get_tile()) + .to(Layout::ROW_MAJOR); + auto output_float_data = owned_buffer::get_as(float_tensor).get(); + buffer = owned_buffer::create(std::move(output_float_data)); + tt_dtype = DataType::FLOAT32; + } + if (tt_dtype == DataType::BFLOAT4_B) { + TT_ASSERT( + std::holds_alternative(buffer), + "Unexpected type {}", + tt::stl::get_active_type_name_in_variant(buffer)); + auto uint32_data = std::get>(std::get(buffer)).get(); + auto float_unpacked_data = + unpack_bfp4_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); + auto input_float_buffer = owned_buffer::create(std::move(float_unpacked_data)); + auto float_tensor = Tensor( + OwnedStorage{input_float_buffer}, + tt_tensor.get_shape(), + DataType::FLOAT32, + tt_tensor.get_layout(), + tt_tensor.get_tile()) + .to(Layout::ROW_MAJOR); + auto output_float_data = owned_buffer::get_as(float_tensor).get(); + buffer = owned_buffer::create(std::move(output_float_data)); + tt_dtype = DataType::FLOAT32; + } - auto arg_index = 0; - for (const auto &value : args) { - auto name = fmt::format("arg_{}", arg_index++); - process_name_and_value(name, value); - } + const auto tt_dtype_to_np_dtype = std::map{ + {DataType::UINT8, np.attr("ubyte")}, + {DataType::UINT16, np.attr("int16")}, // TODO(arakhmati): add DataType::INT16 + {DataType::INT32, np.attr("int32")}, + {DataType::UINT32, np.attr("int32")}, // TODO(arakhmati): add DataType::INT32 + {DataType::FLOAT32, np.attr("float32")}, + }; + auto np_dtype = tt_dtype_to_np_dtype.at(tt_dtype); + + auto shape = tt_tensor.get_legacy_shape(); + auto np_shape = std::vector(std::begin(shape), std::end(shape)); + auto tensor = frombuffer(buffer, "dtype"_a = np_dtype); + tensor = tensor.attr("reshape")(np_shape); + tensor = np.attr("ascontiguousarray")(tensor); + GraphTracker::instance().track_function_end(tensor); + return tensor; +} + +auto parse_external_operation( + const py::function &external_operation, + const py::args &args, + const py::kwargs &kwargs, + std::optional function_name_override = std::nullopt) { + std::string function_name; + if (function_name_override.has_value()) { + function_name = function_name_override.value(); + } else { + function_name = py::cast(external_operation.attr("__qualname__")); + } + + std::vector input_tensors; + tt::stl::reflection::Attributes attributes; - for (const auto &[name, value] : kwargs) { - process_name_and_value(py::cast(name), value); + auto process_name_and_value = [&function_name, &input_tensors, &attributes](const auto &name, const auto &value) { + py::object torch = py::module_::import("torch"); + py::object ttnn = py::module_::import("ttnn"); + if (py::isinstance(value)) { + // TODO(arakhmati): figure out how to handle this without causing extra memory usage + // auto tensor = py::cast(value); + // input_tensors.push_back(tensor); + } else if (py::isinstance(value, ttnn.attr("Tensor"))) { + // TODO(arakhmati): figure out how to handle this without causing extra memory usage + // auto tensor = py::cast(value.attr("value")); + // input_tensors.push_back(tensor); + } else if (py::isinstance(value, torch.attr("nn").attr("Module"))) { + // do nothing + } else if (py::isinstance(value, torch.attr("Tensor"))) { + // TODO(arakhmati): figure out how to handle this without causing extra memory usage + // auto tensor = detail::convert_torch_tensor_to_tt_tensor(value); + // input_tensors.push_back(tensor); + } else { + // TODO(MO): Exclude tensor data as it is not an attribute + // attributes.push_back({name, fmt::format("{}", value)}); } + }; - auto operation = tt::tt_metal::operation::ExternalOperation{function_name, attributes}; - return std::make_tuple(operation, input_tensors); + auto arg_index = 0; + for (const auto &value : args) { + auto name = fmt::format("arg_{}", arg_index++); + process_name_and_value(name, value); } -} // namespace detail + for (const auto &[name, value] : kwargs) { + process_name_and_value(py::cast(name), value); + } + + auto operation = tt::tt_metal::operation::ExternalOperation{function_name, attributes}; + return std::make_tuple(operation, input_tensors); +} + +} // namespace detail void pytensor_module_types(py::module &m_tensor) { using tt::tt_metal::LegacyShape; - // Tensor constructors that accept device and .to(device) function use keep alive call policy to communicate that Device needs to outlive Tensor. - // This is because when tensors on device are destroyed they need to deallocate their buffers via device. - // keep_alive increases the ref count of the Device object being passed into the constructor and .to() function. - // For additional info see: https://pybind11.readthedocs.io/en/stable/advanced/functions.html#keep-alive + // Tensor constructors that accept device and .to(device) function use keep alive call policy to communicate that + // Device needs to outlive Tensor. This is because when tensors on device are destroyed they need to deallocate + // their buffers via device. keep_alive increases the ref count of the Device object being passed into the + // constructor and .to() function. For additional info see: + // https://pybind11.readthedocs.io/en/stable/advanced/functions.html#keep-alive auto pyTensor = py::class_(m_tensor, "Tensor", R"doc( Class constructor supports tensors of rank 4. @@ -789,12 +792,14 @@ void pytensor_module(py::module &m_tensor) { "decorate_external_operation", [](const py::function &function, std::optional function_name) -> py::function { return py::cpp_function(std::function([function, function_name]( - const py::args &args, const py::kwargs &kwargs) { + const py::args &args, const py::kwargs &kwargs) { ZoneScopedN("TT_DNN_FALLBACK_OP"); uint32_t device_operation_id = ttnn::CoreIDs::instance().fetch_and_increment_device_operation_id(); - auto [operation, input_tensors] = detail::parse_external_operation(function, args, kwargs, function_name); + auto [operation, input_tensors] = + detail::parse_external_operation(function, args, kwargs, function_name); GraphTracker::instance().track_function_start(operation.get_type_name(), args, kwargs); - detail::log_external_operation(ttnn::CoreIDs::instance().get_python_operation_id(), device_operation_id, operation, input_tensors); + detail::log_external_operation( + ttnn::CoreIDs::instance().get_python_operation_id(), device_operation_id, operation, input_tensors); auto output = function(*args, **kwargs); @@ -823,10 +828,10 @@ void pytensor_module(py::module &m_tensor) { pyTensor.def(py::init()) .def( py::init<>([](std::vector &&data, - const std::array &shape, - DataType data_type, - Layout layout, - const std::optional &tile) { + const std::array &shape, + DataType data_type, + Layout layout, + const std::optional &tile) { auto owned_buffer = detail::create_owned_buffer_from_vector_of_floats(std::move(data), data_type); return Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout, tile); }), @@ -863,11 +868,11 @@ void pytensor_module(py::module &m_tensor) { )doc") .def( py::init<>([](std::vector &&data, - const std::array &shape, - DataType data_type, - Layout layout, - Device *device, - const std::optional &tile) { + const std::array &shape, + DataType data_type, + Layout layout, + Device *device, + const std::optional &tile) { auto owned_buffer = detail::create_owned_buffer_from_vector_of_floats(std::move(data), data_type); auto tensor = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout, tile); return tensor.to(device, MemoryConfig{}); @@ -916,12 +921,12 @@ void pytensor_module(py::module &m_tensor) { )doc") .def( py::init<>([](std::vector &&data, - const std::array &shape, - DataType data_type, - Layout layout, - Device *device, - const MemoryConfig &memory_config, - const std::optional &tile) { + const std::array &shape, + DataType data_type, + Layout layout, + Device *device, + const MemoryConfig &memory_config, + const std::optional &tile) { auto owned_buffer = detail::create_owned_buffer_from_vector_of_floats(std::move(data), data_type); auto tensor = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout, tile); return tensor.to(device, memory_config); @@ -975,9 +980,9 @@ void pytensor_module(py::module &m_tensor) { )doc") .def( py::init<>([](const py::object &tensor, - std::optional data_type, - const std::unordered_map &strategy, - const std::optional &tile) { + std::optional data_type, + const std::unordered_map &strategy, + const std::optional &tile) { if (py::isinstance(tensor)) { return detail::convert_python_tensors_to_tt_tensors(tensor, data_type, tile, strategy); } @@ -1006,11 +1011,11 @@ void pytensor_module(py::module &m_tensor) { )doc") .def( py::init<>([](const py::object &python_tensor, - std::optional data_type, - Device *device, - Layout layout, - const MemoryConfig &mem_config, - const std::optional &tile) { + std::optional data_type, + Device *device, + Layout layout, + const MemoryConfig &mem_config, + const std::optional &tile) { auto tensor = detail::convert_python_tensor_to_tt_tensor(python_tensor, data_type, tile); auto layout_tensor = tensor.to(layout); return layout_tensor.to(device, mem_config); @@ -1204,7 +1209,7 @@ void pytensor_module(py::module &m_tensor) { )doc") .def( "to", - py::overload_cast(&Tensor::to, py::const_), + py::overload_cast(&Tensor::to, py::const_), py::arg("target_layout").noconvert(), py::arg("mesh_device") = nullptr, R"doc( @@ -1229,9 +1234,11 @@ void pytensor_module(py::module &m_tensor) { .def( "pad", [](const Tensor &self, - const std::array &output_tensor_shape, - const std::array &input_tensor_start, - float pad_value) { return self.pad(output_tensor_shape, ttnn::SimpleShape(input_tensor_start), pad_value); }, + const std::array &output_tensor_shape, + const std::array &input_tensor_start, + float pad_value) { + return self.pad(output_tensor_shape, ttnn::SimpleShape(input_tensor_start), pad_value); + }, R"doc( Pad TT Tensor with given pad value ``arg2``. @@ -1301,8 +1308,8 @@ void pytensor_module(py::module &m_tensor) { .def( "unpad", [](const Tensor &self, - const std::array &output_tensor_start, - const std::array &output_tensor_end) { + const std::array &output_tensor_start, + const std::array &output_tensor_end) { return self.unpad(ttnn::SimpleShape(output_tensor_start), ttnn::SimpleShape(output_tensor_end)); }, R"doc( @@ -1506,7 +1513,9 @@ void pytensor_module(py::module &m_tensor) { )doc") .def( // TODO: Rename to physical_volume - "volume", [](const Tensor &self) { return self.volume(); }, R"doc( + "volume", + [](const Tensor &self) { return self.volume(); }, + R"doc( Get the volume of the tensor. .. code-block:: python @@ -1516,7 +1525,9 @@ void pytensor_module(py::module &m_tensor) { )doc") .def( // TODO: Rename to volume - "logical_volume", [](const Tensor &self) { return self.get_logical_volume(); }, R"doc( + "logical_volume", + [](const Tensor &self) { return self.get_logical_volume(); }, + R"doc( Get the logical volume of the tensor. .. code-block:: python @@ -1533,10 +1544,10 @@ void pytensor_module(py::module &m_tensor) { storage_type = tt_tensor.storage_type() )doc") - .def( - "device", - [](const Tensor &self) { return self.device(); }, - R"doc( + .def( + "device", + [](const Tensor &self) { return self.device(); }, + R"doc( Get the device of the tensor. .. code-block:: python @@ -1544,11 +1555,11 @@ void pytensor_module(py::module &m_tensor) { device = tt_tensor.device() )doc", - py::return_value_policy::reference) - .def( - "devices", - [](const Tensor &self) { return self.get_workers(); }, - R"doc( + py::return_value_policy::reference) + .def( + "devices", + [](const Tensor &self) { return self.get_workers(); }, + R"doc( Get devices tensor is mapped on to. .. code-block:: python @@ -1556,7 +1567,7 @@ void pytensor_module(py::module &m_tensor) { devices = tt_tensor.devices() )doc", - py::return_value_policy::reference) + py::return_value_policy::reference) .def( "to_torch", [](const Tensor &self) -> py::object { return detail::convert_tt_tensor_to_torch_tensor(self); }, @@ -1703,7 +1714,9 @@ void pytensor_module(py::module &m_tensor) { )doc") .def( "reshape", - [](Tensor &self, int N, int C, int H, int W) { return self.reshape(infer_dims_for_reshape(self, ttnn::SmallVector{N, C, H, W})); }, + [](Tensor &self, int N, int C, int H, int W) { + return self.reshape(infer_dims_for_reshape(self, ttnn::SmallVector{N, C, H, W})); + }, R"doc( Reshapes TT tensor @@ -1723,7 +1736,9 @@ void pytensor_module(py::module &m_tensor) { )doc") .def( "reshape", - [](Tensor &self, const ttnn::SmallVector &shape) -> Tensor { return self.reshape(infer_dims_for_reshape(self, shape)); }, + [](Tensor &self, const ttnn::SmallVector &shape) -> Tensor { + return self.reshape(infer_dims_for_reshape(self, shape)); + }, R"doc( Reshapes TT tensor