#0: switch from Span to SmallVector to stop UB
sjameelTT committed Dec 16, 2024
1 parent 14829c6 commit 0cdec3d
Showing 1 changed file with 5 additions and 13 deletions.
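For context on the title: tt::stl::Span is a non-owning view (a pointer plus a length), so if the permute order is built in a temporary container and a Span over it is kept alive longer — for example captured by value in a callback that runs after the temporary has been destroyed — any later read through that Span is undefined behavior. The sketch below illustrates the general hazard with standard-library stand-ins (std::span and std::vector instead of the tt-metal types); the helper names are hypothetical, and the deferred-callback scenario is an assumption about how the dangling view could arise, not something stated in the diff.

    // Hypothetical illustration of the dangling-view hazard (std::span stands in for
    // tt::stl::Span, std::vector for ttnn::SmallVector). Requires C++20 for <span>.
    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <span>
    #include <vector>

    // Captures the span by value: only a pointer and a length are copied, not the elements.
    std::function<void()> make_deferred_bad(std::span<const uint32_t> dims) {
        return [dims] {
            for (auto d : dims) std::cout << d << ' ';  // UB if the backing storage is gone
            std::cout << '\n';
        };
    }

    // Captures an owning copy of the elements, so running the callback later is safe.
    std::function<void()> make_deferred_ok(std::vector<uint32_t> dims) {
        return [dims = std::move(dims)] {
            for (auto d : dims) std::cout << d << ' ';
            std::cout << '\n';
        };
    }

    int main() {
        std::function<void()> later;
        {
            std::vector<uint32_t> order{0, 2, 1, 3};  // some permute order
            later = make_deferred_bad(order);         // span points into `order`
        }                                             // `order` is destroyed here
        // later();  // would read freed memory: undefined behavior

        later = make_deferred_ok({0, 2, 1, 3});       // the vector parameter owns its data
        later();                                      // safe: prints "0 2 1 3"
    }

With the patched signatures the dims travel as a ttnn::SmallVector, which copies its elements wherever it is captured or stored by value, rather than as a borrowed pointer/length pair.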
18 changes: 5 additions & 13 deletions ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp
@@ -27,19 +27,13 @@ inline bool is_on_device(const Tensor& t) {

ttnn::Tensor permute_impl(
const ttnn::Tensor& a,
- const tt::stl::Span<const uint32_t>& dims,
+ const ttnn::SmallVector<uint32_t>& dims,
const MemoryConfig& output_mem_config,
const std::optional<float>& pad_value) {
using ttnn::operations::experimental::auto_format::AutoFormat;
- Device* device;
-
- // Get the device
- if (a.storage_type() != StorageType::DEVICE) {
- device = AutoFormat::GetDefaultDevice();
- TT_ASSERT(device != nullptr, "Requires setting default device if no inputs to op are on device");
- } else {
- device = a.device();
- }
+ Device* device = a.device();

if (a.get_shape().rank() > 4) {
auto input = a.get_layout() == Layout::TILE
@@ -56,9 +50,6 @@ ttnn::Tensor permute_impl(
TT_FATAL(dims.size() == 4, "Only 4D tensor are supported for permute.");
uint32_t N = dims[0], C = dims[1], H = dims[2], W = dims[3];

- // Convert tensor back to original
- auto input_shape = a.get_logical_shape();

auto formatted_input_tensor = a;
// WH and CN should be supported without typecast
bool wh = N == 0 && C == 1 && H == 3 && W == 2;
@@ -134,13 +125,14 @@
} else {
TT_ASSERT(false, "Illegal permute args");
}
+ // Convert tensor back to original dtype if typecast was performed
output = typecast ? ttnn::typecast(output, DataType::BFLOAT8_B) : output;
return output;
}

ttnn::Tensor permute_launch(
const ttnn::Tensor& a,
- tt::stl::Span<const uint32_t> dims,
+ const ttnn::SmallVector<uint32_t>& dims,
const MemoryConfig& output_mem_config,
const std::optional<float>& pad_value) {
std::vector<ttnn::Tensor> output_tensors = {ttnn::Tensor(operation::get_workers_for_op_output({a}))};
@@ -190,7 +182,6 @@ ttnn::Tensor ExecutePermute::invoke(
return ttnn::to_memory_config(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}

- const auto input_layout = input_tensor.get_layout();
auto adjust_order = [](tt::stl::Span<const uint32_t> dims) {
ttnn::SmallVector<uint32_t> new_order;
TT_FATAL(dims.size() <= 4, "Minimum rank of tensor required is 4");
@@ -206,6 +197,7 @@ ttnn::Tensor ExecutePermute::invoke(
auto itensor = (input_tensor.get_logical_shape().rank() < 4) ? ttnn::unsqueeze_to_4D(input_tensor) : input_tensor;
auto iorder = normalized_dims.size() < 4 ? adjust_order(normalized_dims) : normalized_dims;

+ const auto input_layout = input_tensor.get_layout();
auto output_tensor =
detail::permute_launch(itensor, iorder, memory_config.value_or(input_tensor.memory_config()), pad_value);
output_tensor = ttnn::to_layout(output_tensor, input_layout, std::nullopt, std::nullopt, (Device*)nullptr);
