Skip to content

Commit

Permalink
#0: fix memory config and kernel for Blackhole
Browse files Browse the repository at this point in the history
  • Loading branch information
sjameelTT committed Dec 24, 2024
1 parent 1c1c4a5 commit 6b52aea
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,13 @@ def test_transpose_hw_sharded_tiled_n_cores(device, n, c, h, w):
def test_transpose_hw_rm(shape, device):
    """Round-trip a row-major bfloat16 tensor through ttnn.transpose(dims 2, 3)
    and compare against torch.Tensor.transpose(2, 3) with PCC >= 0.9999.

    Args:
        shape: 4D input shape (supplied by a pytest parametrize decorator
            outside this view — TODO confirm).
        device: ttnn device fixture the tensor is placed on.
    """
    torch_input = torch.randn(shape, dtype=torch.bfloat16)
    torch_output = torch_input.transpose(2, 3)
    # NOTE(review): memory_config=L1 is the point of this change — per the
    # commit message, the default (DRAM) placement broke this case on
    # Blackhole. Keep the input tensor in L1.
    tt_input = ttnn.from_torch(
        torch_input,
        dtype=ttnn.DataType.BFLOAT16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device,
        memory_config=ttnn.L1_MEMORY_CONFIG,
    )
    tt_output = ttnn.transpose(tt_input, 2, 3)
    tt_output = ttnn.to_torch(tt_output)
    assert_with_pcc(torch_output, tt_output, 0.9999)
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ struct ToMemoryConfig {
std::optional<ttnn::DataType> dtype = std::nullopt) {
// Temporary until we see why buffer data not being populated
const auto original_shape = tensor.get_shape();

const auto original_memory_config = ttnn::get_memory_config(tensor);
if (original_memory_config.has_value() && original_memory_config.value() == memory_config) {
return tensor;
Expand Down
7 changes: 2 additions & 5 deletions ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp"
#include "tt_metal/common/constants.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.hpp"

#include "fold.hpp"

Expand Down Expand Up @@ -63,8 +62,7 @@ std::vector<Tensor> fold_with_transpose_(
tt::log_debug("pad_output: {}", pad_output.shape());

// transpose
auto transpose_hw_output = operation::run(Transpose{TransposeOpDim::WH, L1_mem_config, 0.0f}, {pad_output})
.at(0); // ttnn::transpose(pad_output, 2, 3, L1_mem_config);
auto transpose_hw_output = ttnn::transpose(pad_output, 2, 3, L1_mem_config);

tt::log_debug("transpose_hw_output: {}", transpose_hw_output.shape());

Expand All @@ -82,8 +80,7 @@ std::vector<Tensor> fold_with_transpose_(
tt::log_debug("reshape_hc_output: {}", reshape_hc_output.shape());

// transpose
auto transpose_hw_output2 =
operation::run(Transpose{TransposeOpDim::WH, L1_mem_config, 0.0f}, {reshape_hc_output}).at(0);
auto transpose_hw_output2 = ttnn::transpose(reshape_hc_output, 2, 3, L1_mem_config);

tt::log_debug("transpose_hw_output2: {}", transpose_hw_output2.shape());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ PermuteDeviceOperation::tensor_return_value_t PermuteDeviceOperation::create_out
}
auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
const auto& input_tensor = tensor_args.input_tensor;
return create_device_tensor(output_shape, input_tensor.dtype(), input_tensor.layout(), input_tensor.device());
return create_device_tensor(
output_shape,
input_tensor.dtype(),
input_tensor.layout(),
input_tensor.device(),
operation_attributes.output_mem_config);
}

std::tuple<PermuteDeviceOperation::operation_attributes_t, PermuteDeviceOperation::tensor_args_t>
Expand Down

0 comments on commit 6b52aea

Please sign in to comment.