#7443: optimize serialization/deserialization of multi-device tensors
Previously, when using the native ttnn multi-chip APIs, weight tensors
that were assigned to be replicated across multiple devices required
serializing/deserializing one tensor replica per device. Now only a
single tensor should ever be serialized/deserialized, and the same
serialized tensor on disk should work across any number of devices.

To enable this, part of these changes includes migrating how we
distribute tensors to multi-device into C++. This also allows
MultiDeviceHostStorage to alias a single C++ Buffer's data via shared_ptr.
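
The replication case boils down to one physical host buffer with many aliasing views. A minimal, self-contained C++ sketch of that idea (HostBuffer here is a hypothetical stand-in for owned_buffer::Buffer, not the actual tt-metal type):

#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

// A host-side buffer that shares ownership of its payload.
template <typename T>
struct HostBuffer {
    std::shared_ptr<std::vector<T>> data;
    std::size_t size() const { return data->size(); }
};

int main() {
    // One physical allocation of "weights".
    auto weights = std::make_shared<std::vector<float>>(1024, 1.0f);

    // Replicate across 8 devices: 8 views, still 1 allocation.
    const std::size_t num_devices = 8;
    std::vector<HostBuffer<float>> replicas;
    for (std::size_t i = 0; i < num_devices; ++i) {
        replicas.push_back(HostBuffer<float>{weights});
    }

    std::cout << replicas.front().size() << " floats per view, "
              << weights.use_count() << " owners\n";  // 1024 floats, 9 owners
}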
cfjchu committed Apr 13, 2024
1 parent 00cb77e commit 6a41653
Showing 16 changed files with 209 additions and 63 deletions.
4 changes: 4 additions & 0 deletions tt_eager/tensor/owned_buffer.hpp
@@ -20,6 +20,10 @@ struct Buffer {
shared_vector_(shared_vector),
pointer_for_faster_access_(shared_vector->data()),
size_(shared_vector->size()) {}
explicit Buffer(const std::shared_ptr<std::vector<T>>& shared_vector) :
shared_vector_(shared_vector),
pointer_for_faster_access_(shared_vector->data()),
size_(shared_vector->size()) {}

const std::size_t size() const { return this->size_; }

101 changes: 73 additions & 28 deletions tt_eager/tensor/serialization.cpp
@@ -7,6 +7,7 @@
#include <fstream>
#include <iostream>
#include <string>
#include <type_traits>

#include "tensor/borrowed_buffer_functions.hpp"
#include "tensor/owned_buffer_functions.hpp"
@@ -47,18 +48,33 @@ void dump_borrowed_storage(ofstream& output_stream, const BorrowedStorage& stora
void dump_multi_device_host_storage(ofstream& output_stream, const MultiDeviceHostStorage& storage) {
std::size_t num_buffers = storage.buffers.size();
output_stream.write(reinterpret_cast<const char*>(&num_buffers), sizeof(std::size_t));
for (const auto& buffer : storage.buffers) {
output_stream.write(reinterpret_cast<const char*>(&storage.strategy), sizeof(DistributedTensorConfig));

if (std::holds_alternative<ReplicateTensor>(storage.strategy)) {
std::visit(
[&output_stream]<typename T>(const owned_buffer::Buffer<T>& generic_buffer) {
const auto buffer = owned_buffer::get_as<T>(generic_buffer);
auto size = buffer.size();
output_stream.write(reinterpret_cast<const char*>(&size), sizeof(size));
output_stream.write(reinterpret_cast<const char*>(buffer.begin()), sizeof(T) * size);
}, buffer
}, storage.buffers.at(0)
);
}
for (const auto& shape : storage.shapes) {
output_stream.write(reinterpret_cast<const char*>(&shape), sizeof(Shape));
output_stream.write(reinterpret_cast<const char*>(&storage.shapes.at(0)), sizeof(Shape));

} else {
for (const auto& buffer : storage.buffers) {
std::visit(
[&output_stream]<typename T>(const owned_buffer::Buffer<T>& generic_buffer) {
const auto buffer = owned_buffer::get_as<T>(generic_buffer);
auto size = buffer.size();
output_stream.write(reinterpret_cast<const char*>(&size), sizeof(size));
output_stream.write(reinterpret_cast<const char*>(buffer.begin()), sizeof(T) * size);
}, buffer
);
}
for (const auto& shape : storage.shapes) {
output_stream.write(reinterpret_cast<const char*>(&shape), sizeof(Shape));
}
}
}

@@ -73,29 +89,47 @@ OwnedStorage load_owned_storage(ifstream& input_stream) {
}

template<typename T>
MultiDeviceHostStorage load_multi_device_host_storage(ifstream& input_stream) {
MultiDeviceHostStorage load_multi_device_host_storage(ifstream& input_stream, DeviceMesh* device_mesh) {
std::size_t num_buffers = 0;
DistributedTensorConfig strategy;
input_stream.read(reinterpret_cast<char*>(&num_buffers), sizeof(std::size_t));
input_stream.read(reinterpret_cast<char*>(&strategy), sizeof(DistributedTensorConfig));

std::vector<OwnedBuffer> buffers;

for (std::size_t i = 0; i < num_buffers; ++i) {
std::vector<Shape> shapes;
if (std::holds_alternative<ReplicateTensor>(strategy)) {
std::size_t size = 0;
input_stream.read(reinterpret_cast<char*>(&size), sizeof(std::size_t));

auto buffer = owned_buffer::create<T>(size);
input_stream.read(reinterpret_cast<char*>(buffer.begin()), sizeof(T) * size);

buffers.push_back(std::move(buffer));
}
std::vector<Shape> shapes;
for (std::size_t i = 0; i < num_buffers; ++i) {
auto shape = Shape{};
input_stream.read(reinterpret_cast<char*>(buffer.begin()), sizeof(T) * size);
input_stream.read(reinterpret_cast<char*>(&shape), sizeof(Shape));
buffers.push_back(buffer);
shapes.push_back(shape);

for (std::size_t i = 1; i < device_mesh->num_devices(); ++i) {
buffers.push_back(owned_buffer::Buffer<T>{buffer.get_ptr()});
shapes.push_back(shape);
}

} else {
for (std::size_t i = 0; i < num_buffers; ++i) {
std::size_t size = 0;
input_stream.read(reinterpret_cast<char*>(&size), sizeof(std::size_t));

auto buffer = owned_buffer::create<T>(size);
input_stream.read(reinterpret_cast<char*>(buffer.begin()), sizeof(T) * size);

buffers.push_back(std::move(buffer));
}
for (std::size_t i = 0; i < num_buffers; ++i) {
auto shape = Shape{};
input_stream.read(reinterpret_cast<char*>(&shape), sizeof(Shape));
shapes.push_back(shape);
}
}

return {buffers, shapes};
return {strategy, buffers, shapes};
}


@@ -121,28 +155,32 @@ OwnedStorage load_owned_storage(ifstream& input_stream, DataType data_type) {
}


MultiDeviceHostStorage load_multi_device_host_storage(ifstream& input_stream, DataType data_type) {
MultiDeviceHostStorage load_multi_device_host_storage(ifstream& input_stream, DataType data_type, DeviceMesh *device_mesh) {
if (data_type == DataType::UINT32 or data_type == DataType::BFLOAT8_B) {
using T = std::uint32_t;
return load_multi_device_host_storage<T>(input_stream);
return load_multi_device_host_storage<T>(input_stream, device_mesh);
} else if (data_type == DataType::UINT16) {
using T = std::uint16_t;
return load_multi_device_host_storage<T>(input_stream);
return load_multi_device_host_storage<T>(input_stream, device_mesh);
} else if (data_type == DataType::FLOAT32) {
using T = float;
return load_multi_device_host_storage<T>(input_stream);
return load_multi_device_host_storage<T>(input_stream, device_mesh);
} else if (data_type == DataType::BFLOAT16) {
using T = bfloat16;
return load_multi_device_host_storage<T>(input_stream);
return load_multi_device_host_storage<T>(input_stream, device_mesh);
} else {
TT_THROW("Unsupported DataType");
}
}


Storage load_storage(ifstream& input_stream, DataType data_type, StorageType storage_type) {
template <typename T>
Storage load_storage(ifstream& input_stream, DataType data_type, StorageType storage_type, T device) {
if (storage_type == StorageType::MULTI_DEVICE_HOST) {
return load_multi_device_host_storage(input_stream, data_type);
if constexpr (std::is_same_v<T, DeviceMesh*>) {
return load_multi_device_host_storage(input_stream, data_type, device);
} else {
TT_THROW("DeviceMesh is required for MULTI_DEVICE_HOST storage");
}
} else {
return load_owned_storage(input_stream, data_type);
}
@@ -208,7 +246,8 @@ void dump_tensor(const std::string& file_name, const Tensor& tensor) {
tensor_to_dump.get_storage());
}

Tensor load_tensor(const std::string& file_name, Device* device) {
template<typename T>
Tensor load_tensor(const std::string& file_name, T device) {
ifstream input_stream(file_name, ios::in | ios::binary);
if (not input_stream) {
throw std::runtime_error(fmt::format("Cannot open \"{}\"", file_name));
@@ -219,8 +258,10 @@ Tensor load_tensor(const std::string& file_name, Device* device) {
if (read_sentinel == detail::SENTINEL_VALUE) {
std::uint8_t version_id;
input_stream.read(reinterpret_cast<char*>(&version_id), sizeof(version_id));
if (version_id != VERSION_ID) {
throw std::runtime_error(fmt::format("Unsupported version_id: {}", version_id));

// Allow only backward compatible versions
if (version_id > VERSION_ID) {
throw std::runtime_error(fmt::format("Serialized tensor with version_id: {}. Loader version: {}", version_id, VERSION_ID));
}
auto shape = Shape{};
DataType data_type;
@@ -242,7 +283,7 @@ Tensor load_tensor(const std::string& file_name, Device* device) {
}
}

auto storage = detail::load_storage(input_stream, data_type, storage_type);
auto storage = detail::load_storage(input_stream, data_type, storage_type, device);

auto tensor = Tensor(std::move(storage), shape, data_type, layout);
if (device != nullptr) {
@@ -271,6 +312,10 @@ Tensor load_tensor(const std::string& file_name, Device* device) {
}
}

// Explicit instantiations
template Tensor load_tensor<Device*>(const std::string&, Device*);
template Tensor load_tensor<DeviceMesh*>(const std::string&, DeviceMesh*);

} // namespace tt_metal

} // namespace tt
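
For context, a self-contained sketch of the serialize-once/fan-out-on-load behavior implemented above (standalone C++ with placeholder types, not the tt-metal code; the real format also stores shapes and the full DistributedTensorConfig variant):

#include <cstdint>
#include <iostream>
#include <memory>
#include <sstream>
#include <vector>

enum class Strategy : std::uint8_t { Replicate, Shard };

// Write the strategy tag and a single payload; replicas are never duplicated on disk.
void dump_replicated(std::ostream& out, const std::vector<float>& payload) {
    auto strategy = Strategy::Replicate;
    std::size_t size = payload.size();
    out.write(reinterpret_cast<const char*>(&strategy), sizeof(strategy));
    out.write(reinterpret_cast<const char*>(&size), sizeof(size));
    out.write(reinterpret_cast<const char*>(payload.data()), sizeof(float) * size);
}

// Read the single payload once, then hand out one aliasing view per device.
std::vector<std::shared_ptr<std::vector<float>>> load_replicated(std::istream& in, std::size_t num_devices) {
    Strategy strategy{};
    std::size_t size = 0;
    in.read(reinterpret_cast<char*>(&strategy), sizeof(strategy));
    in.read(reinterpret_cast<char*>(&size), sizeof(size));
    auto payload = std::make_shared<std::vector<float>>(size);
    in.read(reinterpret_cast<char*>(payload->data()), sizeof(float) * size);
    return std::vector<std::shared_ptr<std::vector<float>>>(num_devices, payload);
}

int main() {
    std::stringstream file;                        // stand-in for the file on disk
    dump_replicated(file, std::vector<float>(16, 0.5f));
    auto views_2chip = load_replicated(file, 2);   // same bytes work for 2 devices...
    file.seekg(0);
    auto views_8chip = load_replicated(file, 8);   // ...or 8
    std::cout << views_2chip.size() << " vs " << views_8chip.size() << "\n";  // 2 vs 8
}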
4 changes: 3 additions & 1 deletion tt_eager/tensor/serialization.hpp
@@ -13,7 +13,9 @@ namespace tt {
namespace tt_metal {

void dump_tensor(const std::string& file_name, const Tensor& tensor);
Tensor load_tensor(const std::string& file_name, Device* device = nullptr);

template <typename T>
Tensor load_tensor(const std::string& file_name, T device = nullptr);

} // namespace tt_metal

2 changes: 1 addition & 1 deletion tt_eager/tensor/tensor.cpp
@@ -338,7 +338,7 @@ Tensor Tensor::to(Device *target_device, const MemoryConfig &mem_config) const {
Tensor Tensor::to(DeviceMesh *device_mesh, const MemoryConfig &mem_config) const {
ZoneScoped;
auto all_workers = device_mesh->get_devices();
auto workers = std::vector<Device*>(all_workers.begin(), all_workers.begin() + num_buffers_in_tensor(*this));
auto workers = std::vector<Device*>(all_workers.begin(), all_workers.end());
TT_FATAL(validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)");
Tensor multi_device_tensor = Tensor(workers);
uint32_t device_tensor_ref_count = multi_device_tensor.tensor_attributes->record_main_thread_ref_count();
2 changes: 1 addition & 1 deletion tt_eager/tensor/tensor_impl.cpp
@@ -204,7 +204,7 @@ Tensor to_layout_bfloat8_b(const Tensor &tensor, Layout target_layout) {
output_buffers.push_back(output_uint32_buffer);
}
return Tensor(
std::move(MultiDeviceHostStorage{output_buffers, storage.shapes}),
std::move(MultiDeviceHostStorage{storage.strategy, output_buffers, storage.shapes}),
tensor.get_legacy_shape(),
DataType::BFLOAT8_B,
target_layout
2 changes: 1 addition & 1 deletion tt_eager/tensor/tensor_impl.hpp
@@ -502,7 +502,7 @@ inline Tensor to_layout(const Tensor& tensor, Layout target_layout) {
output_buffers.push_back(output_buffer);
output_shapes.push_back(storage.shapes[i]);
}
return MultiDeviceHostStorage{output_buffers, output_shapes};
return MultiDeviceHostStorage{storage.strategy, output_buffers, output_shapes};
} else if constexpr (std::is_same_v<StorageType, DeviceStorage>) {
TT_THROW("Device storage isn't supported");
} else if constexpr (std::is_same_v<StorageType, MultiDeviceStorage>) {
19 changes: 15 additions & 4 deletions tt_eager/tensor/tensor_utils.cpp
@@ -249,8 +249,19 @@ std::vector<Tensor> get_tensors_from_multi_device_storage(const Tensor& multi_de
return tensors;
}

DistributedTensorConfig get_distributed_tensor_config_from_tensor(const Tensor& tensor) {
if (tensor.storage_type() == StorageType::MULTI_DEVICE) {
const auto& tensor_storage = std::get<MultiDeviceStorage>(tensor.get_storage());
return tensor_storage.strategy;
}
else if (tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) {
const auto& tensor_storage = std::get<MultiDeviceHostStorage>(tensor.get_storage());
return tensor_storage.strategy;
}
TT_THROW("Tensor is not a multi-device tensor");
}

Tensor create_multi_device_tensor(const std::vector<Tensor>& tensors, StorageType storage_type) {
Tensor create_multi_device_tensor(const std::vector<Tensor>& tensors, StorageType storage_type, const DistributedTensorConfig& strategy) {
if (tensors.empty()) {
TT_THROW("Cannot create multi-device tensor with empty tensor list");
}
@@ -264,7 +275,7 @@ Tensor create_multi_device_tensor(const std::vector<Tensor>& tensors, StorageTyp
shapes.insert({device->id(), tensor.get_legacy_shape()});
}
return Tensor{
MultiDeviceStorage{device_buffers, shapes},
MultiDeviceStorage{strategy, device_buffers, shapes},
tensors.at(0).get_legacy_shape(),
tensors.at(0).get_dtype(),
tensors.at(0).get_layout()
@@ -277,7 +288,7 @@ Tensor create_multi_device_tensor(const std::vector<Tensor>& tensors, StorageTyp
shapes.push_back(tensor.get_legacy_shape());
}
return Tensor{
MultiDeviceHostStorage{owned_buffers, shapes},
MultiDeviceHostStorage{strategy, owned_buffers, shapes},
tensors.at(0).get_legacy_shape(),
tensors.at(0).get_dtype(),
tensors.at(0).get_layout()
@@ -292,7 +303,7 @@ Tensor transform(const Tensor& tensor, std::function<Tensor(const Tensor&)> tran
std::vector<Tensor> output_tensors(input_tensors.size());
std::transform(input_tensors.begin(), input_tensors.end(), output_tensors.begin(),
[&](const auto& device_tensor) { return transform_func(device_tensor); });
return create_multi_device_tensor(output_tensors, tensor.storage_type());
return create_multi_device_tensor(output_tensors, tensor.storage_type(), get_distributed_tensor_config_from_tensor(tensor));
}

void apply(const Tensor& tensor, std::function<void(const Tensor&)> callable) {
2 changes: 2 additions & 0 deletions tt_eager/tensor/tensor_utils.hpp
@@ -132,6 +132,8 @@ inline bool any_tensor_on_multi_device(const std::vector<ttnn::Tensor>& tensors)
return false;
}

DistributedTensorConfig get_distributed_tensor_config_from_tensor(const Tensor& tensor);

} // namespace tt_metal

} // namespace tt
30 changes: 30 additions & 0 deletions tt_eager/tensor/types.cpp
@@ -12,6 +12,26 @@ namespace tt {

namespace tt_metal {

static DistributedTensorConfig create_shard_distributed_tensor_config(const std::unordered_map<std::string, std::string>& metadata) {
return ShardTensor(std::stoi(metadata.at("shard_dim")));
}
static DistributedTensorConfig create_replicate_distributed_tensor_config(const std::unordered_map<std::string, std::string>& metadata) {
return ReplicateTensor{};
}

DistributedTensorConfig get_distributed_tensor_config(const std::unordered_map<std::string, std::string>& metadata) {
if (auto it = metadata.find("strategy"); it != metadata.end()) {
const std::string& strategy = it->second;
if (strategy == "shard") {
return create_shard_distributed_tensor_config(metadata);

} else if (strategy == "replicate") {
return create_replicate_distributed_tensor_config(metadata);
}
}
TT_THROW("Unsupported DistributedTensorConfig strategy:");
}


tt::DataFormat datatype_to_dataformat_converter(tt::tt_metal::DataType datatype) {
switch (datatype) {
@@ -133,6 +153,16 @@ const uint32_t Shape::get_normalized_index(std::int64_t index) const {
return normalized_index;
}

bool operator==(const ReplicateTensor&, const ReplicateTensor&) {
return true; // All instances are considered equal because there are no data members.
}
bool operator==(const AllGatherTensor&, const AllGatherTensor&) {
return true; // All instances are considered equal because there are no data members.
}
bool operator==(const ShardTensor& lhs, const ShardTensor& rhs) {
return lhs.shard_dimension == rhs.shard_dimension; // Equal if they have the same shard_dimension.
}

bool operator==(const Shape& shape_a, const Shape& shape_b) {
if (shape_a.rank() != shape_b.rank()) {
return false;
