
Commit

Remove the tt-train stop gap, add more mesh device creation tests, drop auto launch for full / full like
omilyutin-tt committed Nov 26, 2024
1 parent 427b1e8 commit 2426a8d
Showing 7 changed files with 198 additions and 97 deletions.
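In short: ttnn::full, ttnn::full_like, and ttnn::arange now register as plain operations rather than auto-launch ops, and creation calls accept a MeshDevice directly, producing a tensor replicated across the mesh. A minimal sketch of the pattern the new tests exercise, assuming only a MeshDevice* mesh_device obtained from a multi-device fixture such as T3kMultiDeviceFixture (headers and types as in the test file below):

    // Sketch only; it mirrors the Full test below rather than defining a new API.
    const Tensor mesh_replicated_tensor = ttnn::full(
        ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
        /*fill_value=*/42,
        DataType::BFLOAT16,
        Layout::ROW_MAJOR,
        std::ref(*mesh_device),  // the mesh itself, not mesh_device->get_device(0)
        MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});
    // Per the assertions below: MULTI_DEVICE storage, one worker per mesh
    // device, and a distributed config holding the ReplicateTensor alternative.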
@@ -3,6 +3,7 @@
//
// SPDX-License-Identifier: Apache-2.0

+#include <optional>
#include <variant>

#include "buffers/buffer_constants.hpp"
@@ -24,11 +25,11 @@ using ::tt::tt_metal::TensorMemoryLayout;

class MultiDeviceTensorCreationTest : public T3kMultiDeviceFixture, public ::testing::WithParamInterface<bool> {};

-TEST_P(MultiDeviceTensorCreationTest, CreateEmpty) {
+TEST_P(MultiDeviceTensorCreationTest, Empty) {
MeshDevice* mesh_device = this->mesh_device_.get();
mesh_device->enable_async(GetParam());

-const auto mesh_replicated_tensor = ttnn::empty(
+const Tensor mesh_replicated_tensor = ttnn::empty(
ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
DataType::BFLOAT16,
Layout::ROW_MAJOR,
@@ -39,7 +40,133 @@ TEST_P(MultiDeviceTensorCreationTest, CreateEmpty) {
EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices());

const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor);
EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(distributed_tensor_config));
}

+TEST_P(MultiDeviceTensorCreationTest, EmptyLike) {
+    MeshDevice* mesh_device = this->mesh_device_.get();
+    mesh_device->enable_async(GetParam());
+
+    ASSERT_FALSE(mesh_device->get_devices().empty());
+
+    const Tensor tensor = ttnn::empty(
+        ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
+        DataType::BFLOAT16,
+        Layout::ROW_MAJOR,
+        mesh_device->get_devices().at(0),
+        MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});
+
+    EXPECT_EQ(tensor.storage_type(), StorageType::DEVICE);
+    EXPECT_EQ(tensor.get_workers().size(), 1);
+
+    const Tensor mesh_replicated_tensor = ttnn::empty_like(
+        tensor,
+        DataType::BFLOAT16,
+        Layout::ROW_MAJOR,
+        *mesh_device,
+        MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});
+
+    EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE);
+    EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices());
+
+    const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor);
+    EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(distributed_tensor_config));
+}
+
+TEST_P(MultiDeviceTensorCreationTest, Full) {
+    MeshDevice* mesh_device = this->mesh_device_.get();
+    mesh_device->enable_async(GetParam());
+
+    const Tensor mesh_replicated_tensor = ttnn::full(
+        ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
+        /*fill_value=*/42,
+        DataType::BFLOAT16,
+        Layout::ROW_MAJOR,
+        std::ref(*mesh_device),
+        MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});
+
+    EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE);
+    EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices());
+    EXPECT_EQ(mesh_replicated_tensor.shape(), ttnn::SimpleShape({32, 32}));
+    EXPECT_EQ(mesh_replicated_tensor.dtype(), DataType::BFLOAT16);
+    EXPECT_EQ(mesh_replicated_tensor.layout(), Layout::ROW_MAJOR);
+
+    const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor);
+    EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(distributed_tensor_config));
+}
+
+TEST_P(MultiDeviceTensorCreationTest, FullLike) {
+    MeshDevice* mesh_device = this->mesh_device_.get();
+    mesh_device->enable_async(GetParam());
+
+    ASSERT_FALSE(mesh_device->get_devices().empty());
+
+    Tensor tensor = ttnn::empty(
+        ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
+        DataType::BFLOAT16,
+        Layout::ROW_MAJOR,
+        mesh_device->get_devices().at(0),
+        MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});
+
+    EXPECT_EQ(tensor.storage_type(), StorageType::DEVICE);
+    EXPECT_EQ(tensor.get_workers().size(), 1);
+
+    Tensor mesh_replicated_tensor = ttnn::full_like(
+        tensor,
+        /*fill_value=*/42,
+        /*dtype=*/std::nullopt,
+        /*layout=*/std::nullopt,
+        std::ref(*mesh_device));
+
+    EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE);
+    EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices());
+    EXPECT_EQ(mesh_replicated_tensor.shape(), tensor.shape());
+    EXPECT_EQ(mesh_replicated_tensor.dtype(), tensor.dtype());
+    EXPECT_EQ(mesh_replicated_tensor.layout(), tensor.layout());
+
+    const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor);
+    EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(distributed_tensor_config));
+}
+
+TEST_P(MultiDeviceTensorCreationTest, FullLikeWithOptTensor) {
+    MeshDevice* mesh_device = this->mesh_device_.get();
+    mesh_device->enable_async(GetParam());
+
+    ASSERT_FALSE(mesh_device->get_devices().empty());
+
+    Tensor tensor = ttnn::empty(
+        ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
+        DataType::BFLOAT16,
+        Layout::ROW_MAJOR,
+        mesh_device->get_devices().at(0),
+        MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});
+
+    EXPECT_EQ(tensor.storage_type(), StorageType::DEVICE);
+    EXPECT_EQ(tensor.get_workers().size(), 1);
+
+    Tensor opt_output = ttnn::empty(
+        ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
+        DataType::BFLOAT16,
+        Layout::ROW_MAJOR,
+        mesh_device,
+        MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});
+
+    Tensor mesh_replicated_tensor = ttnn::full_like(
+        tensor,
+        /*fill_value=*/42,
+        /*dtype=*/std::nullopt,
+        /*layout=*/std::nullopt,
+        /*device=*/std::nullopt,
+        /*memory_config=*/std::nullopt,
+        opt_output);
+
+    EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE);
+    EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices());
+    EXPECT_EQ(mesh_replicated_tensor.shape(), tensor.shape());
+    EXPECT_EQ(mesh_replicated_tensor.dtype(), tensor.dtype());
+    EXPECT_EQ(mesh_replicated_tensor.layout(), tensor.layout());
+
+    const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor);
+    EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(distributed_tensor_config));
+}

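A note on the FullLikeWithOptTensor case above: when a preallocated output tensor is supplied, ttnn::full_like takes std::nullopt for both device and memory config and still yields a mesh-replicated result. The following condensed sketch restates that test; tensor and opt_output are set up exactly as there, and the interpretation of the nullopt arguments is inferred from the test's expectations rather than documented in this diff:

    // Setup condensed from FullLikeWithOptTensor: tensor is a single-device
    // input, opt_output a preallocated tensor on the mesh.
    Tensor mesh_replicated_tensor = ttnn::full_like(
        tensor,
        /*fill_value=*/42,
        /*dtype=*/std::nullopt,    // inherited from tensor
        /*layout=*/std::nullopt,   // inherited from tensor
        /*device=*/std::nullopt,   // not needed: opt_output already lives on the mesh
        /*memory_config=*/std::nullopt,
        opt_output);
    // The result matches tensor's shape, dtype, and layout and is replicated
    // across all mesh devices, per the EXPECT_EQ checks in the test.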
140 changes: 57 additions & 83 deletions tests/ttnn/unit_tests/operations/test_creation.py
@@ -32,62 +32,6 @@ def test_zeros_like(device, input_shape):
assert torch.allclose(torch_output_tensor, output_tensor)


-@pytest.mark.parametrize(
-    "input_shape",
-    [
-        [32, 32],
-        [5, 96, 64],
-    ],
-)
-def test_zeros_like_bf8b(device, input_shape):
-    torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16)
-    torch_output_tensor = torch.zeros_like(torch_input_tensor)
-
-    input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)
-    output_tensor = ttnn.zeros_like(input_tensor)
-    assert ttnn.is_tensor_storage_on_device(output_tensor)
-    output_tensor = ttnn.from_device(output_tensor)
-    output_tensor = ttnn.to_torch(output_tensor).to(torch.bfloat16)
-
-    assert_with_pcc(torch_output_tensor, output_tensor, 0.9999)
-    assert torch.allclose(torch_output_tensor, output_tensor)
-
-
-@pytest.mark.parametrize(
-    "input_shape",
-    [
-        [32, 32],
-        [5, 96, 64],
-    ],
-)
-@pytest.mark.parametrize(
-    "layout",
-    [ttnn.Layout.ROW_MAJOR, ttnn.Layout.TILE],
-)
-def test_zeros_like_opt(device, layout, input_shape):
-    torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16)
-    torch_output_tensor = torch.zeros_like(torch_input_tensor)
-    opt_tensor = torch.ones(input_shape, dtype=torch.bfloat16)
-    opt_tensor = ttnn.from_torch(
-        opt_tensor, ttnn.bfloat16, layout=layout, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
-    )
-
-    input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout)
-    input_tensor = ttnn.to_device(input_tensor, device)
-
-    cq_id = 0
-    pages_before = ttnn._ttnn.reports.get_buffer_pages()
-    ttnn.zeros_like(input_tensor, optional_tensor=opt_tensor, queue_id=cq_id)
-    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
-
-    assert ttnn.is_tensor_storage_on_device(opt_tensor)
-    opt_tensor = ttnn.from_device(opt_tensor)
-    opt_tensor = ttnn.to_torch(opt_tensor)
-
-    assert_with_pcc(torch_output_tensor, opt_tensor, 0.9999)
-    assert torch.allclose(torch_output_tensor, opt_tensor)


@pytest.mark.parametrize(
"input_shape",
[
@@ -110,35 +54,13 @@ def test_ones_like(device, input_shape):
assert torch.allclose(torch_output_tensor, output_tensor)


-@pytest.mark.parametrize(
-    "input_shape",
-    [
-        [32, 32],
-        [5, 96, 64],
-    ],
-)
-def test_ones_like_bf8b(device, input_shape):
-    torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16)
-    torch_output_tensor = torch.ones_like(torch_input_tensor)
-
-    input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)
-    input_tensor = ttnn.to_device(input_tensor, device)
-    output_tensor = ttnn.ones_like(input_tensor)
-    assert ttnn.is_tensor_storage_on_device(output_tensor)
-    output_tensor = ttnn.from_device(output_tensor)
-    output_tensor = ttnn.to_torch(output_tensor).to(torch.bfloat16)
-
-    assert_with_pcc(torch_output_tensor, output_tensor, 0.9999)
-    assert torch.allclose(torch_output_tensor, output_tensor)


@pytest.mark.parametrize(
"input_shape",
[[32, 32], [5, 96, 64], [1, 2, 64, 64], [1, 2, 4, 64, 64]],
)
@pytest.mark.parametrize(
"fill_value",
-[-5, 3, 15, 25],
+[-5.25, 0, 1.0],
)
def test_full_like(device, input_shape, fill_value):
torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16)
@@ -161,7 +83,7 @@ def test_full_like(device, input_shape, fill_value):
)
@pytest.mark.parametrize(
"fill_value",
-[-5, 3, 15, 25],
+[-5.25, 0, 1.0],
)
def test_full_like_bf8b(device, input_shape, fill_value):
torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16)
@@ -187,7 +109,7 @@ def test_full_like_bf8b(device, input_shape, fill_value):
)
@pytest.mark.parametrize(
"fill_value",
-[-5, 3, 15, 25],
+[-5.25, 0, 1.0],
)
@pytest.mark.parametrize(
"layout",
@@ -286,6 +208,7 @@ def test_full(device, input_shape, fill_value, layout):
[
[32, 32],
[5, 96, 64],
+[1, 50257],
],
)
@pytest.mark.parametrize(
@@ -314,6 +237,34 @@ def test_full_with_opt_tensor(device, input_shape, layout, fill_value):
assert torch.allclose(torch_tensor, opt_tensor)


+@pytest.mark.parametrize(
+    "input_shape",
+    [
+        [32, 32],
+        [5, 96, 64],
+        [1, 50257],
+    ],
+)
+@pytest.mark.parametrize(
+    "fill_value",
+    [-5.25, 0, 1.0],
+)
+@pytest.mark.parametrize(
+    "layout",
+    [ttnn.Layout.ROW_MAJOR, ttnn.Layout.TILE],
+)
+def test_full_multi_device(mesh_device, input_shape, fill_value, layout):
+    torch_tensor = torch.full(input_shape, dtype=torch.bfloat16, fill_value=fill_value)
+
+    tensor = ttnn.full(input_shape, device=mesh_device, fill_value=fill_value, layout=layout)
+    assert ttnn.is_tensor_storage_on_device(tensor)
+    output_tensors = ttnn.to_torch(tensor, mesh_composer=ttnn.ListMeshToTensor(mesh_device))
+
+    for output_tensor in output_tensors:
+        assert_with_pcc(torch_tensor, output_tensor, 0.9999)
+        assert torch.allclose(torch_tensor, output_tensor)
+

@pytest.mark.parametrize(
"start",
[4, 8, 16, 32],
@@ -403,7 +354,6 @@ def test_empty_multi_device(mesh_device, input_shapes):
)
def test_empty_like(device, input_shapes):
torch_input_tensor = torch.ones((input_shapes), dtype=torch.bfloat16)
-torch_output_tensor = torch.empty(torch_input_tensor.shape, dtype=torch.bfloat16)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT)
input_tensor = ttnn.to_device(input_tensor, device)
@@ -412,4 +362,28 @@ def test_empty_like(device, input_shapes):
output_tensor = ttnn.from_device(output_tensor)
output_tensor = ttnn.to_torch(output_tensor)

-assert list(torch_output_tensor.shape) == list(output_tensor.shape)
+assert list(torch_input_tensor.shape) == list(output_tensor.shape)


+@pytest.mark.parametrize(
+    "input_shapes",
+    [
+        [2, 1, 4, 4],  # 256x256
+        [2, 1280, 8, 8],
+        [2, 640, 16, 16],
+        [2, 1280, 8, 8],  # 512x512
+        [2, 1280, 16, 16],
+        [2, 1280, 16, 16],
+    ],
+)
+def test_empty_like_multi_device(mesh_device, input_shapes):
+    torch_input_tensor = torch.empty((input_shapes), dtype=torch.bfloat16)
+
+    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT)
+    input_tensor = ttnn.to_device(input_tensor, mesh_device)
+    output_tensor = ttnn.empty_like(input_tensor, layout=ttnn.TILE_LAYOUT)
+    output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
+    output_tensor = ttnn.from_device(output_tensor)
+    output_tensors = ttnn.to_torch(output_tensor, mesh_composer=ttnn.ListMeshToTensor(mesh_device))
+    for output_tensor in output_tensors:
+        assert list(torch_input_tensor.shape) == list(output_tensor.shape)
5 changes: 2 additions & 3 deletions tt-train/sources/ttml/core/tt_tensor_utils.cpp
@@ -166,11 +166,10 @@ tt::tt_metal::Tensor full(
(padded[2] + additional_padding_h),
(padded[3] + additional_padding_w),
});
-// temporary solution to avoid using the device, and use only MeshDevice in highlevel api
-return ttnn::full(padded_shape, value, dtype, Layout::TILE, std::ref(*device->get_device(0)));
+return ttnn::full(padded_shape, value, dtype, Layout::TILE, std::ref(*device));
}
// if not padding available, we can just create a tensor with the given shape
-return ttnn::full(shape, value, dtype, Layout::TILE, std::ref(*device->get_device(0)));
+return ttnn::full(shape, value, dtype, Layout::TILE, std::ref(*device));
}

tt::tt_metal::Tensor zeros(const ttnn::Shape& shape, ttnn::distributed::MeshDevice* device, DataType dtype) {
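This hunk is the stop-gap removal named in the commit title: tt-train's full helper previously pinned tensor creation to one physical device via device->get_device(0), and now hands the MeshDevice straight to ttnn::full. A simplified sketch of the resulting helper follows; the padding branch is elided, and the signature is assumed by analogy with the zeros helper visible above, not copied from the file:

    // Simplified sketch; the real helper also handles the padded-shape branch.
    tt::tt_metal::Tensor full(
        const ttnn::Shape& shape, float value, ttnn::distributed::MeshDevice* device, DataType dtype) {
        // Previously std::ref(*device->get_device(0)), the single-device stop gap.
        return ttnn::full(shape, value, dtype, Layout::TILE, std::ref(*device));
    }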
8 changes: 3 additions & 5 deletions ttnn/cpp/ttnn/operations/creation.hpp
@@ -424,22 +424,20 @@ struct Arange {
} // namespace creation
} // namespace operations

-constexpr auto full =
-ttnn::decorators::register_operation_with_auto_launch_op<"ttnn::full", ttnn::operations::creation::Full>();
+constexpr auto full = ttnn::decorators::register_operation<"ttnn::full", ttnn::operations::creation::Full>();
constexpr auto zeros = ttnn::decorators::register_operation<"ttnn::zeros", ttnn::operations::creation::Zeros>();
constexpr auto ones = ttnn::decorators::register_operation<"ttnn::ones", ttnn::operations::creation::Ones>();
constexpr auto empty = ttnn::decorators::register_operation<"ttnn::empty", ttnn::operations::creation::Empty>();

constexpr auto full_like =
-ttnn::decorators::register_operation_with_auto_launch_op<"ttnn::full_like", ttnn::operations::creation::FullLike>();
+ttnn::decorators::register_operation<"ttnn::full_like", ttnn::operations::creation::FullLike>();
constexpr auto zeros_like =
ttnn::decorators::register_operation<"ttnn::zeros_like", ttnn::operations::creation::ZerosLike>();
constexpr auto ones_like =
ttnn::decorators::register_operation<"ttnn::ones_like", ttnn::operations::creation::OnesLike>();
constexpr auto empty_like =
ttnn::decorators::register_operation<"ttnn::empty_like", ttnn::operations::creation::EmptyLike>();

-constexpr auto arange =
-ttnn::decorators::register_operation_with_auto_launch_op<"ttnn::arange", ttnn::operations::creation::Arange>();
+constexpr auto arange = ttnn::decorators::register_operation<"ttnn::arange", ttnn::operations::creation::Arange>();

} // namespace ttnn
