#8042: ttnn multi-chip support for tensor cloning
cfjchu committed May 7, 2024
1 parent 048c081 commit 7c23ea8
Showing 3 changed files with 41 additions and 23 deletions.
tests/ttnn/unit_tests/test_multi_device.py (13 additions, 0 deletions)
@@ -447,3 +447,16 @@ def test_slicing(device_mesh):
     tensor = ttnn.to_device(tensor, device_mesh)
     tensor = tensor[:, :, :, :1]
     assert all([device_tensor.shape == tensor.shape for device_tensor in ttnn.get_device_tensors(tensor)])
+
+
+def test_clone(device_mesh):
+    results_11BH = ttnn.from_torch(
+        torch.randn(1, 1, 32, 128),
+        dtype=ttnn.bfloat16,
+        layout=ttnn.TILE_LAYOUT,
+        device=device_mesh,
+        mesh_mapper=ReplicateTensorToMesh(device_mesh),
+    )
+    results_11BH = ttnn.to_device(results_11BH, device_mesh)
+    results_11BH = ttnn.clone(results_11BH, dtype=ttnn.bfloat8_b, memory_config=ttnn.L1_MEMORY_CONFIG)
+    print(results_11BH)
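
The new test replicates a 1x1x32x128 bfloat16 tensor across every device in the mesh, then clones it with a cast to bfloat8_b and an L1 memory config; the final print also exercises the to_string change in the next file. A hedged follow-up sketch of an explicit per-device check, not part of this commit, assuming the same device_mesh fixture and that DeviceMesh.get_num_devices() and the per-shard dtype attribute behave as in other ttnn tests:

def test_clone_per_device(device_mesh):  # hypothetical companion test, for illustration
    replicated = ttnn.from_torch(
        torch.randn(1, 1, 32, 128),
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        device=device_mesh,
        mesh_mapper=ReplicateTensorToMesh(device_mesh),
    )
    cloned = ttnn.clone(replicated, dtype=ttnn.bfloat8_b, memory_config=ttnn.L1_MEMORY_CONFIG)
    # clone should fan out to one shard per device, each with the requested dtype
    shards = ttnn.get_device_tensors(cloned)
    assert len(shards) == device_mesh.get_num_devices()
    assert all(shard.dtype == ttnn.bfloat8_b for shard in shards)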
tt_eager/tensor/tensor_impl.hpp (21 additions, 22 deletions)
@@ -915,32 +915,31 @@ inline std::string to_string(const Tensor& tensor, std::optional<DataType> original_dtype
         return to_string<T>(to_host<T>(tensor));
     }
 
-    if (dtype == DataType::BFLOAT8_B and original_dtype == std::nullopt) {
-        // Convert to FLOAT32 tensor before printing
-        auto input_packed_data = owned_buffer::get_as<uint32_t>(tensor).get();
-        auto input_float_data =
-            unpack_bfp8_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false);
-        auto input_float_buffer = owned_buffer::create<float>(std::move(input_float_data));
-        auto float_tensor =
-            Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout());
-        return to_string<float>(float_tensor, tensor.get_dtype());
-    }
-
-    if (dtype == DataType::BFLOAT4_B and original_dtype == std::nullopt) {
-        // Convert to FLOAT32 tensor before printing
-        auto input_packed_data = owned_buffer::get_as<uint32_t>(tensor).get();
-        auto input_float_data =
-            unpack_bfp4_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false);
-        auto input_float_buffer = owned_buffer::create<float>(std::move(input_float_data));
-        auto float_tensor =
-            Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout());
-        return to_string<float>(float_tensor, tensor.get_dtype());
-    }
-
     return std::visit(
         [&](auto&& storage) -> std::string {
             using StorageType = std::decay_t<decltype(storage)>;
             if constexpr (std::is_same_v<StorageType, OwnedStorage>) {
+                if (dtype == DataType::BFLOAT8_B and original_dtype == std::nullopt) {
+                    // Convert to FLOAT32 tensor before printing
+                    auto input_packed_data = owned_buffer::get_as<uint32_t>(tensor).get();
+                    auto input_float_data =
+                        unpack_bfp8_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false);
+                    auto input_float_buffer = owned_buffer::create<float>(std::move(input_float_data));
+                    auto float_tensor =
+                        Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout());
+                    return to_string<float>(float_tensor, tensor.get_dtype());
+                }
+
+                if (dtype == DataType::BFLOAT4_B and original_dtype == std::nullopt) {
+                    // Convert to FLOAT32 tensor before printing
+                    auto input_packed_data = owned_buffer::get_as<uint32_t>(tensor).get();
+                    auto input_float_data =
+                        unpack_bfp4_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false);
+                    auto input_float_buffer = owned_buffer::create<float>(std::move(input_float_data));
+                    auto float_tensor =
+                        Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout());
+                    return to_string<float>(float_tensor, tensor.get_dtype());
+                }
                 const auto buffer = owned_buffer::get_as<T>(storage.buffer);
                 return detail::to_string(buffer, shape, dtype, layout);
             } else if constexpr (std::is_same_v<StorageType, BorrowedStorage>) {
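
The BFLOAT8_B/BFLOAT4_B unpack logic moves inside the OwnedStorage branch of the std::visit: a tensor attached to a device mesh does not carry a single owned host buffer at the top level, so the old unconditional owned_buffer::get_as call was only valid for single-device host tensors. Scoping the float32 conversion to the branch that actually owns a host buffer lets printing proceed shard by shard. A minimal Python sketch of the path this protects, assuming the device_mesh fixture from the test file:

# Printing a BFLOAT8_B tensor replicated over the mesh; each shard that
# reaches the OwnedStorage branch is unpacked to float32 for display.
tensor = ttnn.from_torch(
    torch.randn(1, 1, 32, 32),
    dtype=ttnn.bfloat8_b,
    layout=ttnn.TILE_LAYOUT,
    mesh_mapper=ReplicateTensorToMesh(device_mesh),
)
print(tensor)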
tt_eager/tt_dnn/op_library/copy/copy_op.cpp (7 additions, 1 deletion)
@@ -87,7 +87,13 @@ Tensor copy(const Tensor& src_tensor, const Tensor& dst_tensor) {
 }
 
 Tensor clone(const Tensor& input, const MemoryConfig& output_mem_config, std::optional<const DataType> output_dtype) {
-    return operation::run(Copy{output_mem_config, output_dtype.value_or(input.get_dtype())}, {input}).at(0);
+    std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input}))};
+    operation::launch_op(
+        [output_mem_config, output_dtype] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) mutable -> std::vector<Tensor> {
+            const auto& input = input_tensors.at(0);
+            return operation::run(Copy{output_mem_config, output_dtype.value_or(input.get_dtype())}, {input});
+        }, {input}, output_tensors);
+    return output_tensors.at(0);
 }
 
 Tensor typecast(const Tensor& input_tensor, const DataType& dtype, const MemoryConfig& output_mem_config ) {
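
clone now routes through operation::launch_op: a placeholder output tensor is created on the workers returned by operation::get_workers_for_op_output, and the captured lambda runs the Copy op per worker, which is what lets a single clone call fan out across all devices in a mesh. From the caller's side the interface is unchanged; a hedged single-device sketch for contrast with the mesh test above, assuming a standard pytest device fixture:

# clone behaves as before on one device; only the dispatch path changed.
tensor = ttnn.from_torch(
    torch.randn(1, 1, 32, 32),
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=device,
)
copy = ttnn.clone(tensor, memory_config=ttnn.DRAM_MEMORY_CONFIG)
assert copy.shape == tensor.shape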
