Skip to content

Commit

Permalink
[tt-train]Rremoved slow odd option for fp16 host->device conversion. (#…
Browse files Browse the repository at this point in the history
…15470)

### Problem description
Before @sminakov-tt fix we could not move fp16 vector with odd number of
elements in a last dimension.
### What's changed
Now always using fast path.

### Checklist
- [x] Post commit CI passes
- [x] Blackhole Post commit (if applicable)
- [x] Model regression CI testing passes (if applicable)
- [x] Device performance regression CI testing passes (if applicable)
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
dmakoviichuk-tt authored Nov 26, 2024
1 parent 364e3fd commit 0390e0c
Showing 1 changed file with 2 additions and 15 deletions.
17 changes: 2 additions & 15 deletions tt-train/sources/ttml/core/tt_tensor_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
#include <fmt/color.h>

#include <algorithm>
#include <core/ttnn_all_includes.hpp>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>
#include <stdexcept>
#include <core/ttnn_all_includes.hpp>

namespace {

Expand Down Expand Up @@ -197,15 +197,6 @@ tt::tt_metal::Tensor from_vector<float, DataType::BFLOAT16>(
// remove possible paddings from the shape (it conflicts with ROW MAJOR)
auto output = tt::tt_metal::Tensor(OwnedStorage{owned_buffer}, logical_shape, data_type, Layout::ROW_MAJOR);

auto to_device_odd_slow = [&]() {
if (layout == Layout::TILE) {
output = ttnn::to_layout(output, layout, std::nullopt, output_mem_config, device);
}

output = ttnn::to_device(output, device, output_mem_config);
return output;
};

auto to_device_even_fast = [&]() {
output = ttnn::to_device(output, device, output_mem_config);
if (layout == Layout::TILE) {
Expand All @@ -215,11 +206,7 @@ tt::tt_metal::Tensor from_vector<float, DataType::BFLOAT16>(
return output;
};

if (shape[-1] % 2 == 1) {
output = to_device_odd_slow();
} else {
output = to_device_even_fast();
}
output = to_device_even_fast();

return output;
}
Expand Down

0 comments on commit 0390e0c

Please sign in to comment.