diff --git a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp
index f948be785c3..d0757378ea6 100644
--- a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp
+++ b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp
@@ -62,7 +62,9 @@ Tensor to_layout_impl_on_device(
     if (!requires_padding_change(tensor_arg, layout)) {
         if (layout == ttnn::ROW_MAJOR_LAYOUT) {
-            TT_FATAL(!dtype.has_value(), "dtype cannot be specified when converting to ROW_MAJOR_LAYOUT!");
+            TT_FATAL(
+                !dtype.has_value() || dtype.value() == tensor_arg.dtype(),
+                "dtype cannot differ from the tensor's dtype when converting to ROW_MAJOR_LAYOUT!");
             return ttnn::untilize(tensor_arg, output_memory_config, use_multicore_untilize);
         }
         return ttnn::tilize(tensor_arg, output_memory_config, dtype, use_multicore_tilize);
     }
@@ -71,7 +73,9 @@ Tensor to_layout_impl_on_device(
     auto tensor_shape = tensor_arg.get_logical_shape();

     if (layout == ttnn::ROW_MAJOR_LAYOUT) {
-        TT_FATAL(!dtype.has_value(), "dtype cannot be specified when converting to ROW_MAJOR_LAYOUT!");
+        TT_FATAL(
+            !dtype.has_value() || dtype.value() == tensor_arg.dtype(),
+            "dtype cannot differ from the tensor's dtype when converting to ROW_MAJOR_LAYOUT!");

         if (tensor_arg.is_sharded()) {
             const auto memory_config = tensor_arg.memory_config();
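
For context, here is a minimal, self-contained sketch of the guard behavior this change introduces. The `DataType` enum and `row_major_dtype_ok` helper below are illustrative stand-ins, not names from the PR; the real code uses `ttnn::DataType` and checks the condition with `TT_FATAL` inside `to_layout_impl_on_device`.

```cpp
#include <cassert>
#include <iostream>
#include <optional>

// Hypothetical stand-in for ttnn::DataType, for illustration only.
enum class DataType { BFLOAT16, FLOAT32 };

// Mirrors the relaxed condition from the diff: a dtype argument is now
// accepted when it is absent OR equal to the tensor's existing dtype.
bool row_major_dtype_ok(std::optional<DataType> requested, DataType tensor_dtype) {
    return !requested.has_value() || requested.value() == tensor_dtype;
}

int main() {
    const DataType tensor_dtype = DataType::BFLOAT16;

    // Previously the only accepted case: no dtype specified at all.
    assert(row_major_dtype_ok(std::nullopt, tensor_dtype));

    // Newly accepted: dtype specified but identical to the tensor's dtype,
    // i.e. a redundant request that cannot require a cast during untilize.
    assert(row_major_dtype_ok(DataType::BFLOAT16, tensor_dtype));

    // Still rejected (TT_FATAL would fire): requesting a dtype change
    // while converting to ROW_MAJOR_LAYOUT.
    assert(!row_major_dtype_ok(DataType::FLOAT32, tensor_dtype));

    std::cout << "guard behaves as expected\n";
}
```

The design point of the change: the old assertion rejected any explicit dtype, even a no-op one matching the tensor, which made callers that always pass a dtype fail spuriously on the ROW_MAJOR path; the new condition only rejects requests that would actually require a dtype conversion.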