Revert "Fix ttnn.from_torch for 0D/1D tensors with tile layout (#15882)…
Browse files Browse the repository at this point in the history
…" (#16222)

### Ticket
Link to Github Issue

### Problem description
A recently merged PR caused a number of unexpected failures for T3K ResNet,
Stable Diffusion, and Forge. Reverting the PR for now; we will investigate
the issues more closely.

### What's changed
This reverts commit 444b0dc.

### Checklist
- [x] [Post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12426434675)
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [ ] New/Existing tests provide coverage for changes
sminakov-tt authored Dec 20, 2024
1 parent 2336024 commit 2edaca7
Showing 26 changed files with 396 additions and 411 deletions.
2 changes: 1 addition & 1 deletion tests/tt_eager/ops/test_tilize_zero_padding.cpp
@@ -46,7 +46,7 @@ int main(int argc, char** argv) {
log_debug(LogTest, "Moving src data to host to validate");
Tensor host_a = a.cpu(); // Move tensor a to host to validate
// TODO: Update when tensor.pad_to_tile() function is added
-auto padded_shape = a.get_padded_shape();
+auto padded_shape = a.get_legacy_shape();
padded_shape[2] = round_up(padded_shape[2], TILE_HEIGHT);
padded_shape[3] = round_up(padded_shape[3], TILE_WIDTH);
Tensor padded_host_a = host_a.pad(padded_shape, ttnn::SimpleShape{0, 0, 0, 0}, 0);
@@ -49,7 +49,7 @@ int main(int argc, char** argv) {
Tensor host_a = a.cpu(); // Move tensor a to host to validate
Tensor g = Tensor(host_a.get_storage(), shape, DataType::BFLOAT16, Layout::ROW_MAJOR);
// TODO: Update when tensor.pad_to_tile() function is added
-auto padded_shape = g.get_padded_shape();
+auto padded_shape = g.get_legacy_shape();
padded_shape[2] = round_up(padded_shape[2], TILE_HEIGHT);
padded_shape[3] = round_up(padded_shape[3], TILE_WIDTH);
Tensor padded_g = g.pad(padded_shape, ttnn::SimpleShape{0, 0, 0, 0}, 0);
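For context, both hunks above restore the same tile-alignment arithmetic: the last two dimensions are rounded up to the tile size before padding. A minimal Python sketch of that computation, assuming the standard 32x32 tile (the constants and shape here are illustrative):

```python
TILE_HEIGHT = 32  # assumed standard tile dimensions, for illustration only
TILE_WIDTH = 32

def round_up(value: int, multiple: int) -> int:
    # mirrors tt::round_up: smallest multiple of `multiple` >= `value`
    return ((value + multiple - 1) // multiple) * multiple

shape = [1, 1, 46, 100]  # hypothetical unaligned shape
padded_shape = list(shape)
padded_shape[2] = round_up(shape[2], TILE_HEIGHT)  # 46 -> 64
padded_shape[3] = round_up(shape[3], TILE_WIDTH)   # 100 -> 128
print(padded_shape)  # [1, 1, 64, 128]
```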
4 changes: 2 additions & 2 deletions tests/ttnn/unit_tests/operations/test_matmul.py
@@ -1230,14 +1230,14 @@ def test_matmul_with_matched_width_height(device, m_size, k_size, n_size):
def test_matmul_with_matched_width_height_from_1D(device, k_size, n_size):
torch.manual_seed(0)

-torch_input_tensor_a = torch.rand((1, k_size), dtype=torch.bfloat16)
+torch_input_tensor_a = torch.rand((k_size), dtype=torch.bfloat16)
torch_input_tensor_b = torch.rand((k_size, n_size), dtype=torch.bfloat16)
torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b)

input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
output = input_tensor_a @ input_tensor_b
-output = ttnn.to_torch(output)
+output = ttnn.to_torch(output, torch_rank=1)

assert len(output.shape) == len(torch_output_tensor.shape)
assert output.shape == torch_output_tensor.shape
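The restored test depends on PyTorch's rank-1 matmul semantics: a vector multiplied by a matrix contracts the shared axis and yields a rank-1 result, which is why the round trip asks `ttnn.to_torch` for `torch_rank=1`. A plain-PyTorch sketch of the expected shapes (the sizes are hypothetical, in the spirit of the test parametrization):

```python
import torch

k_size, n_size = 64, 96  # hypothetical sizes
a = torch.rand((k_size,), dtype=torch.bfloat16)         # rank-1 vector
b = torch.rand((k_size, n_size), dtype=torch.bfloat16)  # rank-2 matrix
out = torch.matmul(a, b)
assert out.shape == (n_size,)  # vector @ matrix contracts k and stays rank-1
```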
19 changes: 0 additions & 19 deletions tests/ttnn/unit_tests/test_to_and_from_torch.py
@@ -84,22 +84,3 @@ def test_from_torch_large(device):
x_tensor = ttnn.from_torch(torch_x, layout=ttnn.TILE_LAYOUT)
x_tensor = ttnn.to_torch(x_tensor)
assert torch.allclose(torch_x, x_tensor)


-@pytest.mark.parametrize(
-    "shape",
-    [
-        (),
-        (1),
-        (2),
-        (0),
-    ],
-)
-@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
-@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
-def test_to_for_01_rank(shape, layout, dtype):
-    torch_input_tensor = torch.rand(shape, dtype=dtype)
-    tensor = ttnn.from_torch(torch_input_tensor, layout=layout)
-    # Conversion in the opposite direction is not yet supported
-    # torch_output_tensor = ttnn.to_torch(tensor)
-    # assert torch.allclose(torch_input_tensor, torch_output_tensor)
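For reference, the deleted test only exercised the `ttnn.from_torch` direction for rank-0 and rank-1 inputs; the reverse conversion was still unsupported, hence the commented-out assertions. The shape quirks it relied on can be checked in plain PyTorch (a hedged sketch, no device required):

```python
import torch

# (1), (2), and (0) are plain ints in Python, not tuples, so the removed
# test covered one rank-0 shape and three rank-1 shapes
for shape in [(), 1, 2, 0]:
    t = torch.rand(shape, dtype=torch.float32)
    print(shape, t.dim(), tuple(t.shape))
# ()  -> dim 0, shape ()
# 1   -> dim 1, shape (1,)
# 2   -> dim 1, shape (2,)
# 0   -> dim 1, shape (0,)
```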
22 changes: 11 additions & 11 deletions ttnn/cpp/pybind11/pytensor.cpp
@@ -115,30 +115,31 @@ Tensor convert_float_vector_to_tt_tensor(
Layout::TILE,
layout);
}
-auto result_cpu_spec = TensorSpec(
-    ttnn::SimpleShape(shape), TensorLayout(data_type, PageConfig(Layout::TILE, tile), MemoryConfig{}));
auto owned_buffer = create_owned_buffer_from_vector_of_floats(std::move(data), DataType::FLOAT32);
auto float_tensor = Tensor(OwnedStorage{owned_buffer}, shape, DataType::FLOAT32, Layout::ROW_MAJOR, tile);
-if (result_cpu_spec.logical_shape() != result_cpu_spec.padded_shape()) {
-    float_tensor =
-        tensor_ops::tensor_pad(float_tensor, result_cpu_spec.padded_shape(), ttnn::SimpleShape{0, 0, 0, 0}, 0);
+auto tile_val = tile.value_or(Tile());
+if (shape[2] % tile_val.get_height() != 0 || shape[3] % tile_val.get_width() != 0) {
+    auto padded_shape = shape;
+    padded_shape[2] = tt::round_up(shape[2], tile_val.get_height());
+    padded_shape[3] = tt::round_up(shape[3], tile_val.get_width());
+
+    float_tensor = tensor_ops::tensor_pad(
+        float_tensor, LegacyShape(shape, padded_shape), ttnn::SimpleShape{0, 0, 0, 0}, 0);
}
auto output_float_data = owned_buffer::get_as<float>(float_tensor.to(Layout::TILE)).get();
auto output_packed_data =
data_type == DataType::BFLOAT8_B
? pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile)
: pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile);
auto output_buffer = owned_buffer::create<uint32_t>(std::move(output_packed_data));
-auto tensor = Tensor(std::move(OwnedStorage{std::move(output_buffer)}), result_cpu_spec);
+auto tensor = Tensor(std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tile);
if (device) {
return tensor.to(device, memory_config.value_or(MemoryConfig{}));
}
return tensor;
}
-auto result_cpu_spec = TensorSpec(
-    ttnn::SimpleShape(shape), TensorLayout(data_type, PageConfig(Layout::ROW_MAJOR, tile), MemoryConfig{}));
auto owned_buffer = create_owned_buffer_from_vector_of_floats(std::move(data), data_type);
-auto tensor = Tensor(OwnedStorage{owned_buffer}, result_cpu_spec).to(layout);
+auto tensor = Tensor(OwnedStorage{owned_buffer}, shape, data_type, Layout::ROW_MAJOR, tile).to(layout);
if (device) {
return tensor.to(device, memory_config.value_or(MemoryConfig{}));
}
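The key behavioral change in this hunk is how padding is triggered: the reverted code compared a `TensorSpec`'s logical and padded shapes, while the restored code checks the last two dimensions against the tile size directly. A rough Python rendering of the restored condition (the function name and default tile sizes are illustrative, not the actual API):

```python
def needs_tile_padding(shape, tile_height=32, tile_width=32):
    # mirrors: shape[2] % tile_val.get_height() != 0 || shape[3] % tile_val.get_width() != 0
    return shape[2] % tile_height != 0 or shape[3] % tile_width != 0

assert needs_tile_padding([1, 1, 46, 100])       # not tile-aligned -> pad first
assert not needs_tile_padding([1, 1, 64, 128])   # already aligned -> tilize directly
```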
@@ -1211,8 +1212,7 @@ void pytensor_module(py::module& m_tensor) {
const std::array<uint32_t, 4>& output_tensor_shape,
const std::array<uint32_t, 4>& input_tensor_start,
float pad_value) {
-return self.pad(
-    ttnn::SimpleShape(output_tensor_shape), ttnn::SimpleShape(input_tensor_start), pad_value);
+return self.pad(output_tensor_shape, ttnn::SimpleShape(input_tensor_start), pad_value);
},
R"doc(
Pad TT Tensor with given pad value ``arg2``.
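Going by the binding above, the restored Python-facing `pad` takes a four-element output shape, a start offset, and a pad value, with the output shape now passed through directly rather than wrapped in `ttnn::SimpleShape` on the C++ side. A hedged usage sketch (the tensor itself, its shape, and the exact call site are assumptions for illustration):

```python
# hypothetical host tensor of shape [1, 1, 46, 100]; pad it up to tile
# alignment, placing the original data at the origin and zero-filling the rest
padded = tensor.pad([1, 1, 64, 128], [0, 0, 0, 0], 0.0)
```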
