Skip to content

Commit

Permalink
Fix compilation of test_create_tensor.cpp (#13506)
Browse files Browse the repository at this point in the history
#0: Fix compilation of test_create_tensor.cpp
  • Loading branch information
ayerofieiev-tt authored Oct 5, 2024
1 parent e400377 commit 3d33e8d
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tests/ttnn/unit_tests/gtests/tensor/test_create_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

#include "ttnn_test_fixtures.hpp"

void run_create_tensor_test(tt::tt_metal::Device* device, ttnn::Shape input_shape) {
void run_create_tensor_test(tt::tt_metal::Device* device, ttnn::SimpleShape input_shape) {
MemoryConfig mem_cfg = MemoryConfig{
.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED,
.buffer_type = BufferType::DRAM,
Expand All @@ -32,11 +32,12 @@ void run_create_tensor_test(tt::tt_metal::Device* device, ttnn::Shape input_shap
host_data[i] = 1;
}

auto input_buffer = ttnn::allocate_buffer_on_device(input_buf_size_datums * datum_size_bytes, device, input_shape, dtype, Layout::TILE, mem_cfg);
ttnn::Shape shape(input_shape.as_vector());
auto input_buffer = ttnn::allocate_buffer_on_device(input_buf_size_datums * datum_size_bytes, device, shape, dtype, Layout::TILE, mem_cfg);

auto input_storage = tt::tt_metal::DeviceStorage{input_buffer};

Tensor input_tensor = Tensor(input_storage, input_shape, dtype, Layout::TILE);
Tensor input_tensor = Tensor(input_storage, shape, dtype, Layout::TILE);
tt::log_debug("input_data: \n {}", input_tensor.write_to_string());

ttnn::write_buffer(io_cq, input_tensor, {host_data});
Expand All @@ -51,7 +52,7 @@ void run_create_tensor_test(tt::tt_metal::Device* device, ttnn::Shape input_shap
}

struct CreateTensorParams {
ttnn::Shape shape;
ttnn::SimpleShape shape;
};

class CreateTensorTest : public ttnn::TTNNFixtureWithDevice, public ::testing::WithParamInterface<CreateTensorParams> {};
Expand All @@ -65,6 +66,6 @@ INSTANTIATE_TEST_SUITE_P(
CreateTensorTestWithShape,
CreateTensorTest,
::testing::Values(
CreateTensorParams{.shape=ttnn::Shape({1, 1, 32, 32})}
CreateTensorParams{.shape=ttnn::SimpleShape({1, 1, 32, 32})}
)
);

0 comments on commit 3d33e8d

Please sign in to comment.