From 6a416530da2401188956efd2d96ba0aac01ec3ae Mon Sep 17 00:00:00 2001
From: Joseph Chu
Date: Sat, 13 Apr 2024 01:03:35 +0000
Subject: [PATCH] #7443: optimize serialization/deserialization of multi-device tensors

Previously, when using the native ttnn multi-chip APIs, weight tensors that
were assigned to be replicated across multiple devices required
serializing/deserializing every tensor replica. Now only a single tensor is
ever serialized/deserialized, and the same serialized tensor on disk works
across different numbers of devices.

To enable this, part of these changes migrates how we distribute a tensor to
multiple devices into C++. This also allows MultiDeviceHostStorage to alias a
single C++ Buffer via shared_ptr.
---
 tt_eager/tensor/owned_buffer.hpp              |   4 +
 tt_eager/tensor/serialization.cpp             | 101 +++++++++++++-----
 tt_eager/tensor/serialization.hpp             |   4 +-
 tt_eager/tensor/tensor.cpp                    |   2 +-
 tt_eager/tensor/tensor_impl.cpp               |   2 +-
 tt_eager/tensor/tensor_impl.hpp               |   2 +-
 tt_eager/tensor/tensor_utils.cpp              |  19 +++-
 tt_eager/tensor/tensor_utils.hpp              |   2 +
 tt_eager/tensor/types.cpp                     |  30 ++++++
 tt_eager/tensor/types.hpp                     |  38 +++++--
 tt_eager/tt_dnn/op_library/run_operation.cpp  |   4 +-
 .../tt_lib/csrc/tt_lib_bindings_tensor.cpp    |  14 ++-
 .../csrc/tt_lib_bindings_tensor_pytensor.cpp  |  10 +-
 ttnn/cpp/ttnn/multi_device.hpp                |   4 +-
 ttnn/ttnn/multi_device.py                     |  14 +++
 ttnn/ttnn/operations/core.py                  |  22 +++-
 16 files changed, 209 insertions(+), 63 deletions(-)

diff --git a/tt_eager/tensor/owned_buffer.hpp b/tt_eager/tensor/owned_buffer.hpp
index 3e59961c0fe..076e16f4d2e 100644
--- a/tt_eager/tensor/owned_buffer.hpp
+++ b/tt_eager/tensor/owned_buffer.hpp
@@ -20,6 +20,10 @@ struct Buffer {
         shared_vector_(shared_vector),
         pointer_for_faster_access_(shared_vector->data()),
         size_(shared_vector->size()) {}
+    explicit Buffer(const std::shared_ptr<std::vector<T>>& shared_vector) :
+        shared_vector_(shared_vector),
+        pointer_for_faster_access_(shared_vector->data()),
+        size_(shared_vector->size()) {}

     const std::size_t size() const { return this->size_; }

diff --git a/tt_eager/tensor/serialization.cpp b/tt_eager/tensor/serialization.cpp
index c358ae863a0..cb3241f0bac 100644
--- a/tt_eager/tensor/serialization.cpp
+++ b/tt_eager/tensor/serialization.cpp
@@ -7,6 +7,7 @@
 #include
 #include
 #include
+#include

 #include "tensor/borrowed_buffer_functions.hpp"
 #include "tensor/owned_buffer_functions.hpp"
@@ -47,18 +48,33 @@ void dump_borrowed_storage(ofstream& output_stream, const BorrowedStorage& stora
 void dump_multi_device_host_storage(ofstream& output_stream, const MultiDeviceHostStorage& storage) {
     std::size_t num_buffers = storage.buffers.size();
     output_stream.write(reinterpret_cast<const char*>(&num_buffers), sizeof(std::size_t));
-    for (const auto& buffer : storage.buffers) {
+    output_stream.write(reinterpret_cast<const char*>(&storage.strategy), sizeof(DistributedTensorConfig));
+
+    if (std::holds_alternative<ReplicateTensor>(storage.strategy)) {
         std::visit(
             [&output_stream](const owned_buffer::Buffer<T>& generic_buffer) {
                 const auto buffer = owned_buffer::get_as<T>(generic_buffer);
                 auto size = buffer.size();
                 output_stream.write(reinterpret_cast<const char*>(&size), sizeof(size));
                 output_stream.write(reinterpret_cast<const char*>(buffer.begin()), sizeof(T) * size);
-            }, buffer
+            }, storage.buffers.at(0)
         );
-    }
-    for (const auto& shape : storage.shapes) {
-        output_stream.write(reinterpret_cast<const char*>(&shape), sizeof(Shape));
+        output_stream.write(reinterpret_cast<const char*>(&storage.shapes.at(0)), sizeof(Shape));
+
+    } else {
+        for (const auto& buffer : storage.buffers) {
+            std::visit(
+                [&output_stream](const owned_buffer::Buffer<T>& generic_buffer) {
+                    const auto buffer = owned_buffer::get_as<T>(generic_buffer);
+                    auto size = buffer.size();
+                    output_stream.write(reinterpret_cast<const char*>(&size), sizeof(size));
+                    output_stream.write(reinterpret_cast<const char*>(buffer.begin()), sizeof(T) * size);
+                }, buffer
+            );
+        }
+        for (const auto& shape : storage.shapes) {
+            output_stream.write(reinterpret_cast<const char*>(&shape), sizeof(Shape));
+        }
     }
 }

@@ -73,29 +89,47 @@ OwnedStorage load_owned_storage(ifstream& input_stream) {
 }

 template <typename T>
-MultiDeviceHostStorage load_multi_device_host_storage(ifstream& input_stream) {
+MultiDeviceHostStorage load_multi_device_host_storage(ifstream& input_stream, DeviceMesh* device_mesh) {
     std::size_t num_buffers = 0;
+    DistributedTensorConfig strategy;
     input_stream.read(reinterpret_cast<char*>(&num_buffers), sizeof(std::size_t));
+    input_stream.read(reinterpret_cast<char*>(&strategy), sizeof(DistributedTensorConfig));

     std::vector<OwnedBuffer> buffers;
-
-    for (std::size_t i = 0; i < num_buffers; ++i) {
+    std::vector<Shape> shapes;
+    if (std::holds_alternative<ReplicateTensor>(strategy)) {
         std::size_t size = 0;
         input_stream.read(reinterpret_cast<char*>(&size), sizeof(std::size_t));
-
         auto buffer = owned_buffer::create<T>(size);
-        input_stream.read(reinterpret_cast<char*>(buffer.begin()), sizeof(T) * size);
-
-        buffers.push_back(std::move(buffer));
-    }
-    std::vector<Shape> shapes;
-    for (std::size_t i = 0; i < num_buffers; ++i) {
         auto shape = Shape{};
+        input_stream.read(reinterpret_cast<char*>(buffer.begin()), sizeof(T) * size);
         input_stream.read(reinterpret_cast<char*>(&shape), sizeof(Shape));
+        buffers.push_back(buffer);
         shapes.push_back(shape);
+
+        for (std::size_t i = 1; i < device_mesh->num_devices(); ++i) {
+            buffers.push_back(owned_buffer::Buffer<T>{buffer.get_ptr()});
+            shapes.push_back(shape);
+        }
+
+    } else {
+        for (std::size_t i = 0; i < num_buffers; ++i) {
+            std::size_t size = 0;
+            input_stream.read(reinterpret_cast<char*>(&size), sizeof(std::size_t));
+
+            auto buffer = owned_buffer::create<T>(size);
+            input_stream.read(reinterpret_cast<char*>(buffer.begin()), sizeof(T) * size);
+
+            buffers.push_back(std::move(buffer));
+        }
+        for (std::size_t i = 0; i < num_buffers; ++i) {
+            auto shape = Shape{};
+            input_stream.read(reinterpret_cast<char*>(&shape), sizeof(Shape));
+            shapes.push_back(shape);
+        }
     }
-    return {buffers, shapes};
+    return {strategy, buffers, shapes};
 }

@@ -121,28 +155,32 @@ OwnedStorage load_owned_storage(ifstream& input_stream, DataType data_type) {
 }

-MultiDeviceHostStorage load_multi_device_host_storage(ifstream& input_stream, DataType data_type) {
+MultiDeviceHostStorage load_multi_device_host_storage(ifstream& input_stream, DataType data_type, DeviceMesh *device_mesh) {
     if (data_type == DataType::UINT32 or data_type == DataType::BFLOAT8_B) {
         using T = std::uint32_t;
-        return load_multi_device_host_storage<T>(input_stream);
+        return load_multi_device_host_storage<T>(input_stream, device_mesh);
     } else if (data_type == DataType::UINT16) {
         using T = std::uint16_t;
-        return load_multi_device_host_storage<T>(input_stream);
+        return load_multi_device_host_storage<T>(input_stream, device_mesh);
     } else if (data_type == DataType::FLOAT32) {
         using T = float;
-        return load_multi_device_host_storage<T>(input_stream);
+        return load_multi_device_host_storage<T>(input_stream, device_mesh);
     } else if (data_type == DataType::BFLOAT16) {
         using T = bfloat16;
-        return load_multi_device_host_storage<T>(input_stream);
+        return load_multi_device_host_storage<T>(input_stream, device_mesh);
     } else {
         TT_THROW("Unsupported DataType");
     }
 }

-
-Storage load_storage(ifstream& input_stream, DataType data_type, StorageType storage_type) {
+template <typename T>
+Storage load_storage(ifstream& input_stream, DataType data_type, StorageType storage_type, T device) {
     if (storage_type == StorageType::MULTI_DEVICE_HOST) {
-        return load_multi_device_host_storage(input_stream, data_type);
+        if constexpr (std::is_same_v<T, DeviceMesh*>) {
+            return load_multi_device_host_storage(input_stream, data_type, device);
+        } else {
+            TT_THROW("DeviceMesh is required for MULTI_DEVICE_HOST storage");
+        }
     } else {
         return load_owned_storage(input_stream, data_type);
     }
 }

@@ -208,7 +246,8 @@ void dump_tensor(const std::string& file_name, const Tensor& tensor) {
         tensor_to_dump.get_storage());
 }

-Tensor load_tensor(const std::string& file_name, Device* device) {
+template <typename T>
+Tensor load_tensor(const std::string& file_name, T device) {
     ifstream input_stream(file_name, ios::in | ios::binary);
     if (not input_stream) {
         throw std::runtime_error(fmt::format("Cannot open \"{}\"", file_name));
     }
@@ -219,8 +258,10 @@ Tensor load_tensor(const std::string& file_name, Device* device) {
     if (read_sentinel == detail::SENTINEL_VALUE) {
         std::uint8_t version_id;
         input_stream.read(reinterpret_cast<char*>(&version_id), sizeof(version_id));
-        if (version_id != VERSION_ID) {
-            throw std::runtime_error(fmt::format("Unsupported version_id: {}", version_id));
+
+        // Allow only backward compatible versions
+        if (version_id > VERSION_ID) {
+            throw std::runtime_error(fmt::format("Serialized tensor with version_id: {}. Loader version: {}", version_id, VERSION_ID));
         }
         auto shape = Shape{};
         DataType data_type;
@@ -242,7 +283,7 @@ Tensor load_tensor(const std::string& file_name, Device* device) {
             }
         }

-        auto storage = detail::load_storage(input_stream, data_type, storage_type);
+        auto storage = detail::load_storage(input_stream, data_type, storage_type, device);
         auto tensor = Tensor(std::move(storage), shape, data_type, layout);

         if (device != nullptr) {
@@ -271,6 +312,10 @@ Tensor load_tensor(const std::string& file_name, Device* device) {
     }
 }

+// Explicit instantiations
+template Tensor load_tensor(const std::string&, Device*);
+template Tensor load_tensor(const std::string&, DeviceMesh*);
+
 }  // namespace tt_metal

 }  // namespace tt
diff --git a/tt_eager/tensor/serialization.hpp b/tt_eager/tensor/serialization.hpp
index eb5cd0802e3..1f0138f9e85 100644
--- a/tt_eager/tensor/serialization.hpp
+++ b/tt_eager/tensor/serialization.hpp
@@ -13,7 +13,9 @@ namespace tt {
 namespace tt_metal {

 void dump_tensor(const std::string& file_name, const Tensor& tensor);
-Tensor load_tensor(const std::string& file_name, Device* device = nullptr);
+
+template <typename T>
+Tensor load_tensor(const std::string& file_name, T device = nullptr);

 }  // namespace tt_metal
diff --git a/tt_eager/tensor/tensor.cpp b/tt_eager/tensor/tensor.cpp
index 201064b9573..48be3a46317 100644
--- a/tt_eager/tensor/tensor.cpp
+++ b/tt_eager/tensor/tensor.cpp
@@ -338,7 +338,7 @@ Tensor Tensor::to(Device *target_device, const MemoryConfig &mem_config) const {
 Tensor Tensor::to(DeviceMesh *device_mesh, const MemoryConfig &mem_config) const {
     ZoneScoped;
     auto all_workers = device_mesh->get_devices();
-    auto workers = std::vector<Device*>(all_workers.begin(), all_workers.begin() + num_buffers_in_tensor(*this));
+    auto workers = std::vector<Device*>(all_workers.begin(), all_workers.end());
     TT_FATAL(validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)");
     Tensor multi_device_tensor = Tensor(workers);
     uint32_t device_tensor_ref_count = multi_device_tensor.tensor_attributes->record_main_thread_ref_count();
diff --git a/tt_eager/tensor/tensor_impl.cpp b/tt_eager/tensor/tensor_impl.cpp
index adedfe6530d..163de3b219a 100644
--- a/tt_eager/tensor/tensor_impl.cpp
+++ b/tt_eager/tensor/tensor_impl.cpp
@@ -204,7 +204,7 @@ Tensor to_layout_bfloat8_b(const Tensor &tensor, Layout target_layout) {
             output_buffers.push_back(output_uint32_buffer);
         }
         return Tensor(
-            std::move(MultiDeviceHostStorage{output_buffers, storage.shapes}),
+            std::move(MultiDeviceHostStorage{storage.strategy, output_buffers, storage.shapes}),
             tensor.get_legacy_shape(),
             DataType::BFLOAT8_B,
             target_layout
diff --git a/tt_eager/tensor/tensor_impl.hpp b/tt_eager/tensor/tensor_impl.hpp
index 54b9d367196..0e1bfd9699a 100644
--- a/tt_eager/tensor/tensor_impl.hpp
+++ b/tt_eager/tensor/tensor_impl.hpp
@@ -502,7 +502,7 @@ inline Tensor to_layout(const Tensor& tensor, Layout target_layout) {
                 output_buffers.push_back(output_buffer);
                 output_shapes.push_back(storage.shapes[i]);
             }
-            return MultiDeviceHostStorage{output_buffers, output_shapes};
+            return MultiDeviceHostStorage{storage.strategy, output_buffers, output_shapes};
         } else if constexpr (std::is_same_v<T, DeviceStorage>) {
             TT_THROW("Device storage isn't supported");
         } else if constexpr (std::is_same_v<T, BorrowedStorage>) {
diff --git a/tt_eager/tensor/tensor_utils.cpp b/tt_eager/tensor/tensor_utils.cpp
index 3c30f2e8aa2..2b6130421ea 100644
--- a/tt_eager/tensor/tensor_utils.cpp
+++ b/tt_eager/tensor/tensor_utils.cpp
@@ -249,8 +249,19 @@ std::vector get_tensors_from_multi_device_storage(const Tensor& multi_de
     return tensors;
 }

+DistributedTensorConfig get_distributed_tensor_config_from_tensor(const Tensor& tensor) {
+    if (tensor.storage_type() == StorageType::MULTI_DEVICE) {
+        const auto& tensor_storage = std::get<MultiDeviceStorage>(tensor.get_storage());
+        return tensor_storage.strategy;
+    }
+    else if (tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) {
+        const auto& tensor_storage = std::get<MultiDeviceHostStorage>(tensor.get_storage());
+        return tensor_storage.strategy;
+    }
+    TT_THROW("Tensor is not a multi-device tensor");
+}

-Tensor create_multi_device_tensor(const std::vector<Tensor>& tensors, StorageType storage_type) {
+Tensor create_multi_device_tensor(const std::vector<Tensor>& tensors, StorageType storage_type, const DistributedTensorConfig& strategy) {
     if (tensors.empty()) {
         TT_THROW("Cannot create multi-device tensor with empty tensor list");
     }
@@ -264,7 +275,7 @@ Tensor create_multi_device_tensor(const std::vector& tensors, StorageTyp
             shapes.insert({device->id(), tensor.get_legacy_shape()});
         }
         return Tensor{
-            MultiDeviceStorage{device_buffers, shapes},
+            MultiDeviceStorage{strategy, device_buffers, shapes},
             tensors.at(0).get_legacy_shape(),
             tensors.at(0).get_dtype(),
             tensors.at(0).get_layout()
@@ -277,7 +288,7 @@ Tensor create_multi_device_tensor(const std::vector& tensors, StorageTyp
             shapes.push_back(tensor.get_legacy_shape());
         }
         return Tensor{
-            MultiDeviceHostStorage{owned_buffers, shapes},
+            MultiDeviceHostStorage{strategy, owned_buffers, shapes},
             tensors.at(0).get_legacy_shape(),
             tensors.at(0).get_dtype(),
             tensors.at(0).get_layout()
@@ -292,7 +303,7 @@ Tensor transform(const Tensor& tensor, std::function tran
     std::vector<Tensor> output_tensors(input_tensors.size());
     std::transform(input_tensors.begin(), input_tensors.end(), output_tensors.begin(), [&](const auto& device_tensor) { return transform_func(device_tensor); });
-    return create_multi_device_tensor(output_tensors, tensor.storage_type());
+    return create_multi_device_tensor(output_tensors, tensor.storage_type(), get_distributed_tensor_config_from_tensor(tensor));
 }

 void apply(const Tensor& tensor, std::function<void(const Tensor&)> callable) {
diff --git a/tt_eager/tensor/tensor_utils.hpp b/tt_eager/tensor/tensor_utils.hpp
index 1db262810ce..d9150ec5d37 100644
--- a/tt_eager/tensor/tensor_utils.hpp
+++ b/tt_eager/tensor/tensor_utils.hpp
@@ -132,6 +132,8 @@ inline bool any_tensor_on_multi_device(const std::vector& tensors)
     return false;
 }

+DistributedTensorConfig get_distributed_tensor_config_from_tensor(const Tensor& tensor);
+
 }  // namespace tt_metal

 }  // namespace tt
diff --git a/tt_eager/tensor/types.cpp b/tt_eager/tensor/types.cpp
index 007daf55052..dab116397e2 100644
--- a/tt_eager/tensor/types.cpp
+++ b/tt_eager/tensor/types.cpp
@@ -12,6 +12,26 @@ namespace tt {

 namespace tt_metal {

+static DistributedTensorConfig create_shard_distributed_tensor_config(const std::unordered_map<std::string, std::string>& metadata) {
+    return ShardTensor(std::stoi(metadata.at("shard_dim")));
+}
+static DistributedTensorConfig create_replicate_distributed_tensor_config(const std::unordered_map<std::string, std::string>& metadata) {
+    return ReplicateTensor{};
+}
+
+DistributedTensorConfig get_distributed_tensor_config(const std::unordered_map<std::string, std::string>& metadata) {
+    if (auto it = metadata.find("strategy"); it != metadata.end()) {
+        const std::string& strategy = it->second;
+        if (strategy == "shard") {
+            return create_shard_distributed_tensor_config(metadata);
+
+        } else if (strategy == "replicate") {
+            return create_replicate_distributed_tensor_config(metadata);
+        }
+    }
+    TT_THROW("Unsupported DistributedTensorConfig strategy:");
+}
+
 tt::DataFormat datatype_to_dataformat_converter(tt::tt_metal::DataType datatype) {
     switch (datatype) {
@@ -133,6 +153,16 @@ const uint32_t Shape::get_normalized_index(std::int64_t index) const {
     return normalized_index;
 }

+bool operator==(const ReplicateTensor&, const ReplicateTensor&) {
+    return true;  // All instances are considered equal because there are no data members.
+}
+bool operator==(const AllGatherTensor&, const AllGatherTensor&) {
+    return true;  // All instances are considered equal because there are no data members.
+}
+bool operator==(const ShardTensor& lhs, const ShardTensor& rhs) {
+    return lhs.shard_dimension == rhs.shard_dimension;  // Equal if they have the same shard_dimension.
+}
+
 bool operator==(const Shape& shape_a, const Shape& shape_b) {
     if (shape_a.rank() != shape_b.rank()) {
         return false;
diff --git a/tt_eager/tensor/types.hpp b/tt_eager/tensor/types.hpp
index 61b59c530d1..8cdf9ca3e4d 100644
--- a/tt_eager/tensor/types.hpp
+++ b/tt_eager/tensor/types.hpp
@@ -21,7 +21,7 @@ namespace tt {

 namespace tt_metal {

-static constexpr std::uint8_t VERSION_ID = 2;
+static constexpr std::uint8_t VERSION_ID = 3;

 enum class Layout { ROW_MAJOR = 0, TILE = 1, INVALID = 2 };

@@ -44,6 +44,21 @@ enum class StorageType {
     MULTI_DEVICE_HOST,  // host storage for multi-device context
 };

+struct AllGatherTensor{};
+bool operator==(const AllGatherTensor&, const AllGatherTensor&);
+struct ReplicateTensor {};
+bool operator==(const ReplicateTensor&, const ReplicateTensor&);
+struct ShardTensor {
+    int shard_dimension;
+    ShardTensor(int shard_dimension) : shard_dimension(shard_dimension) {}
+};
+bool operator==(const ShardTensor& lhs, const ShardTensor& rhs);
+
+// DistributedTensorConfig is a variant of different ways in which a tensor can be distributed across devices.
+using DistributedTensorConfig = std::variant<ShardTensor, ReplicateTensor, AllGatherTensor>;
+DistributedTensorConfig get_distributed_tensor_config(const std::unordered_map<std::string, std::string>& metadata);
+
+
 tt::DataFormat datatype_to_dataformat_converter(DataType datatype);

 static constexpr std::size_t MAX_NUM_DIMENSIONS = 8;

@@ -318,36 +333,40 @@ struct BorrowedStorage {
     const auto attribute_values() const { return std::make_tuple(); }
 };

-  struct MultiDeviceHostStorage {
+struct MultiDeviceHostStorage {
+    DistributedTensorConfig strategy;
     std::vector<OwnedBuffer> buffers;
     std::vector<Shape> shapes;
     std::mutex mtx;
     MultiDeviceHostStorage() = default;
-    MultiDeviceHostStorage(std::vector<OwnedBuffer> buffers_, std::vector<Shape> shapes_) : buffers(buffers_), shapes(shapes_) {}
+    MultiDeviceHostStorage(DistributedTensorConfig strategy_, std::vector<OwnedBuffer> buffers_, std::vector<Shape> shapes_) : strategy(strategy_), buffers(buffers_), shapes(shapes_) {}
     MultiDeviceHostStorage(MultiDeviceHostStorage &&other) {
         buffers = other.buffers;
         shapes = other.shapes;
     }
     MultiDeviceHostStorage(const MultiDeviceHostStorage &other) {
+        strategy = other.strategy;
         buffers = other.buffers;
         shapes = other.shapes;
     }
     MultiDeviceHostStorage &operator=(const MultiDeviceHostStorage &other) {
+        strategy = other.strategy;
         buffers = other.buffers;
         shapes = other.shapes;
         return *this;
     }
     MultiDeviceHostStorage &operator=( MultiDeviceHostStorage &&other) {
+        strategy = other.strategy;
         buffers = other.buffers;
         shapes = other.shapes;
         return *this;
     }
     bool operator == (const MultiDeviceHostStorage& other) {
-        return this->buffers == other.buffers and this->shapes == other.shapes;
+        return this->strategy == other.strategy and this->buffers == other.buffers and this->shapes == other.shapes;
     }

     static constexpr auto attribute_names = std::make_tuple();
@@ -372,7 +391,7 @@ struct BorrowedStorage {
         TT_FATAL(device->id() < shapes.size(), "Buffer not found for device " + std::to_string(device->id()));
         return shapes[device->id()];
     }
-    
+
     uint32_t num_buffers() {
         std::lock_guard lock(mtx);
         return buffers.size();
     }
 };

@@ -380,34 +399,39 @@
 struct MultiDeviceStorage {
+    DistributedTensorConfig strategy;
     std::unordered_map<int, DeviceBuffer> buffers;
     std::unordered_map<int, Shape> shapes;
     mutable std::mutex mtx;
     MultiDeviceStorage() = default;
-    MultiDeviceStorage(std::unordered_map<int, DeviceBuffer> buffers_, std::unordered_map<int, Shape> shapes_) : buffers(buffers_), shapes(shapes_) {}
+    MultiDeviceStorage(DistributedTensorConfig strategy_, std::unordered_map<int, DeviceBuffer> buffers_, std::unordered_map<int, Shape> shapes_) : strategy(strategy_), buffers(buffers_), shapes(shapes_) {}
     MultiDeviceStorage(MultiDeviceStorage &&other) {
+        strategy = other.strategy;
         buffers = other.buffers;
         shapes = other.shapes;
     }
     MultiDeviceStorage(const MultiDeviceStorage &other) {
+        strategy = other.strategy;
         buffers = other.buffers;
         shapes = other.shapes;
     }
     MultiDeviceStorage &operator=(const MultiDeviceStorage &other) {
+        strategy = other.strategy;
         buffers = other.buffers;
         shapes = other.shapes;
         return *this;
     }
     MultiDeviceStorage &operator=( MultiDeviceStorage &&other) {
+        strategy = other.strategy;
         buffers = other.buffers;
         shapes = other.shapes;
         return *this;
     }
     bool operator == (const MultiDeviceStorage& other) {
-        return this->buffers == other.buffers and this->shapes == other.shapes;
+        return this->strategy == other.strategy and this->buffers == other.buffers and this->shapes == other.shapes;
     }

     const MemoryConfig memory_config() const {
diff --git a/tt_eager/tt_dnn/op_library/run_operation.cpp b/tt_eager/tt_dnn/op_library/run_operation.cpp
index b87621b8479..4f9324c615e 100644
--- a/tt_eager/tt_dnn/op_library/run_operation.cpp
+++ b/tt_eager/tt_dnn/op_library/run_operation.cpp
@@ -313,7 +313,7 @@ OutputTensors run_multi_device_operation(
         if constexpr(std::is_same_v<OutputTensors, std::vector<Tensor>>){
             multi_device_output_tensors.push_back(
                 Tensor{
-                    MultiDeviceStorage{buffers, shapes},
+                    MultiDeviceStorage{get_distributed_tensor_config_from_tensor(input_tensors[0]), buffers, shapes},
                     per_device_output_tensors[devices[0]][i].get_legacy_shape(),
                     per_device_output_tensors[devices[0]][i].get_dtype(),
                     per_device_output_tensors[devices[0]][i].get_layout()
@@ -323,7 +323,7 @@ OutputTensors run_multi_device_operation(
         else if constexpr(std::is_same_v<OutputTensors, std::vector<std::optional<Tensor>>>){
             multi_device_output_tensors.push_back(
                 Tensor{
-                    MultiDeviceStorage{buffers, shapes},
+                    MultiDeviceStorage{get_distributed_tensor_config_from_tensor(input_tensors[0]), buffers, shapes},
                     per_device_output_tensors[devices[0]][i].value().get_legacy_shape(),
                     per_device_output_tensors[devices[0]][i].value().get_dtype(),
                     per_device_output_tensors[devices[0]][i].value().get_layout()
diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp
index 9f194fc6020..dc06a68844b 100644
--- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp
+++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp
@@ -772,14 +772,12 @@ void TensorModule(py::module &m_tensor) {
         )doc"
     );

-    m_tensor.def(
-        "load_tensor",
-        &load_tensor,
-        py::arg("file_name"),
-        py::arg("device") = nullptr,
-        R"doc(
-            Load tensor to file
-        )doc");
+    m_tensor.def("load_tensor",
+        static_cast<Tensor (*)(const std::string&, Device*)>(&load_tensor),
+        py::arg("file_name"), py::arg("device") = nullptr, R"doc(Load tensor to file)doc");
+    m_tensor.def("load_tensor",
+        static_cast<Tensor (*)(const std::string&, DeviceMesh*)>(&load_tensor),
+        py::arg("file_name"), py::arg("device") = nullptr, R"doc(Load tensor to file)doc");

     m_tensor.def(
         "num_cores_to_corerange_set",
diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp
index cd36062247a..dd5b58311ae 100644
--- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp
+++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp
@@ -324,7 +324,7 @@ Tensor convert_python_tensor_to_tt_tensor(
     }
 }

-Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optional<DataType> data_type) {
+Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optional<DataType> data_type, const std::unordered_map<std::string, std::string>& strategy) {
     std::vector<Tensor> tt_shards;
     for (const auto &shard : tensor_shards) {
         tt_shards.push_back(detail::convert_python_tensor_to_tt_tensor(shard, data_type, false));
@@ -335,7 +335,8 @@ Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optiona
         host_owned_buffers.push_back(std::get<OwnedStorage>(shard.get_storage()).buffer);
         host_owned_shapes.push_back(shard.get_legacy_shape());
     }
-    auto storage = MultiDeviceHostStorage(std::move(host_owned_buffers), host_owned_shapes);
+    auto distributed_tensor_config = get_distributed_tensor_config(strategy);
+    auto storage = MultiDeviceHostStorage{distributed_tensor_config, std::move(host_owned_buffers), host_owned_shapes};

     return Tensor(std::move(storage), tt_shards.at(0).get_legacy_shape(), tt_shards.at(0).get_dtype(), Layout::ROW_MAJOR);
 }
@@ -763,14 +764,15 @@ Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optiona
                 )
             )doc")
         .def(
-            py::init<>([](const py::object &tensor, std::optional<DataType> data_type) {
+            py::init<>([](const py::object &tensor, std::optional<DataType> data_type, const std::unordered_map<std::string, std::string>& strategy) {
                 if (py::isinstance<py::list>(tensor)) {
-                    return detail::convert_python_tensors_to_tt_tensors(tensor, data_type);
+                    return detail::convert_python_tensors_to_tt_tensors(tensor, data_type, strategy);
                 }
                 return detail::convert_python_tensor_to_tt_tensor(tensor, data_type);
             }),
             py::arg("tensor"),
             py::arg("data_type") = std::nullopt,
+            py::arg("strategy") = std::unordered_map<std::string, std::string>(),
             py::return_value_policy::move,
             R"doc(
                 +--------------+------------------------+
diff --git a/ttnn/cpp/ttnn/multi_device.hpp b/ttnn/cpp/ttnn/multi_device.hpp
index 410910270a4..562fd6d97f0 100644
--- a/ttnn/cpp/ttnn/multi_device.hpp
+++ b/ttnn/cpp/ttnn/multi_device.hpp
@@ -73,7 +73,7 @@ Tensor aggregate_as_tensor(std::vector& tensor_shards)
             host_owned_buffers.push_back(std::get<OwnedStorage>(shard.get_storage()).buffer);
             shapes.push_back(shard.get_legacy_shape());
         }
-        auto storage = MultiDeviceHostStorage{std::move(host_owned_buffers), shapes};
+        auto storage = MultiDeviceHostStorage{AllGatherTensor(), std::move(host_owned_buffers), shapes};
         return Tensor(std::move(storage), tensor_shards.at(0).get_legacy_shape(), tensor_shards.at(0).get_dtype(), tensor_shards.at(0).get_layout());
     } else {
         std::unordered_map<int, Shape> shapes;
@@ -83,7 +83,7 @@ Tensor aggregate_as_tensor(std::vector& tensor_shards)
             device_buffers.insert({device->id(), std::get<DeviceStorage>(shard.get_storage()).buffer});
             shapes.insert({device->id(), shard.get_legacy_shape()});
         }
-        auto storage = MultiDeviceStorage{std::move(device_buffers), shapes};
+        auto storage = MultiDeviceStorage{AllGatherTensor(), std::move(device_buffers), shapes};
         return Tensor(std::move(storage), tensor_shards.at(0).get_legacy_shape(), tensor_shards.at(0).get_dtype(), tensor_shards.at(0).get_layout());
     }
 }
diff --git a/ttnn/ttnn/multi_device.py b/ttnn/ttnn/multi_device.py
index d17ccaadaae..1096baa23f8 100644
--- a/ttnn/ttnn/multi_device.py
+++ b/ttnn/ttnn/multi_device.py
@@ -79,6 +79,9 @@ def __init__(self, device_mesh):
     def map(self, tensor: torch.tensor):
         raise NotImplementedError("Subclasses must implement this method")

+    def config(self):
+        raise NotImplementedError("Subclasses must implement this method")
+

 class MeshToTensor:
     """
@@ -103,6 +106,12 @@ def map(self, tensor: torch.tensor) -> Dict[int, ttnn.Tensor]:
         self.device_id_to_tensor = {i: input_tensor for i, input_tensor in enumerate(sliced_tensors)}
         return self.device_id_to_tensor

+    def config(self):
+        return {
+            "strategy": "shard",
+            "shard_dim": f"{self.shard_dim}",
+        }
+

 class ReplicateTensorToMesh(TensorToMesh):
     def __init__(self, device_mesh: DeviceMesh):
@@ -112,6 +121,11 @@ def map(self, tensor: torch.tensor):
         self.device_id_to_tensor = {i: tensor for i in range(self.device_mesh.get_num_devices())}
         return self.device_id_to_tensor

+    def config(self):
+        return {
+            "strategy": "replicate",
+        }
+

 class ConcatMeshToTensor(MeshToTensor):
     def __init__(self, device_mesh: DeviceMesh, dim: int):
diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py
index 4d15655a702..e668d4d6d06 100644
--- a/ttnn/ttnn/operations/core.py
+++ b/ttnn/ttnn/operations/core.py
@@ -278,7 +278,7 @@ def impl(tensor, dtype, mesh_mapper):
         if mesh_mapper:
             device_id_to_shard_ranges = mesh_mapper.map(tensor)
             shards = list(device_id_to_shard_ranges.values())
-            return ttl.tensor.Tensor(shards, dtype)
+            return ttl.tensor.Tensor(shards, dtype, mesh_mapper.config())
         return ttl.tensor.Tensor(tensor, dtype)

     tensor = ttl.tensor.decorate_external_operation(impl, function_name="(ttnn) from_torch")(tensor, dtype, mesh_mapper)
@@ -1012,11 +1012,24 @@ def from_torch_and_dump(tensor, dtype, layout, cache_file_name):
         ttnn.dump_tensor(cache_file_name, tensor)
         return tensor

-    storage_type = f"_multi_device_{device.get_num_devices()}" if mesh_mapper else ""
+    def dispatch_to_device_on_load(device) -> bool:
+        return isinstance(device, ttnn.DeviceMesh)
+
+    if isinstance(mesh_mapper, ttnn.ReplicateTensorToMesh):
+        storage_type = f"_multi_device" if mesh_mapper else ""
+    elif mesh_mapper:
+        storage_type = f"_multi_device_{device.get_num_devices()}"
+    else:
+        storage_type = ""
+
     cache_file_name = f"{cache_file_name}{storage_type}_dtype_{dtype_name}_layout_{layout_name}.bin"

     try:
-        tensor = ttnn.load_tensor(cache_file_name)
+        tensor = (
+            ttnn.load_tensor(cache_file_name, device=device)
+            if dispatch_to_device_on_load(device)
+            else ttnn.load_tensor(cache_file_name)
+        )
         if tuple(tensor.shape) != tuple(tensor.shape):
             logger.warning(
                 f"Cached file {cache_file_name} has shape {tensor.shape}, expected {tensor.shape}, regenerating cache"
             )
         logger.debug(f"Loaded cache for {cache_file_name} of shape {tensor.shape}")
     except (FileNotFoundError, RuntimeError):
         tensor = from_torch_and_dump(tensor, dtype, layout, cache_file_name)
-    tensor = ttnn.to_device(tensor, device, memory_config=memory_config)
+    if not dispatch_to_device_on_load(device):
+        tensor = ttnn.to_device(tensor, device, memory_config=memory_config)
     return tensor
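
Reviewer note: the sketch below is not part of the patch; it is a minimal usage example of the behaviour the commit message describes, assuming an already-opened `device_mesh` handle. The weight shape and the `weights.bin` path are illustrative; only APIs touched above are used (`ReplicateTensorToMesh` and its new `config()`, `ttnn.from_torch(..., mesh_mapper=...)`, `ttnn.dump_tensor`, and `ttnn.load_tensor(..., device=device_mesh)`).

```python
import torch
import ttnn

# Assumes `device_mesh` is an already-opened ttnn.DeviceMesh; shape and file
# name below are only illustrative.
torch_weight = torch.rand(1, 1, 32, 32, dtype=torch.bfloat16)

# ReplicateTensorToMesh.config() reports {"strategy": "replicate"}, which is
# forwarded into the C++ DistributedTensorConfig, so dump_tensor writes a
# single replica instead of one buffer per device.
replicated = ttnn.from_torch(
    torch_weight,
    dtype=ttnn.bfloat16,
    mesh_mapper=ttnn.ReplicateTensorToMesh(device_mesh),
)
ttnn.dump_tensor("weights.bin", replicated)

# Loading re-expands the single buffer by aliasing it (shared_ptr) once per
# device in the mesh.
reloaded = ttnn.load_tensor("weights.bin", device=device_mesh)
```

Since the file no longer encodes the number of replicas, a cache written on one mesh size should load onto a mesh with a different device count.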