#8815: Updating bcast to no longer use run with auto format #9360

Closed · wants to merge 1 commit
3 changes: 3 additions & 0 deletions tests/ttnn/unit_tests/test_model_preprocessing.py
@@ -326,7 +326,10 @@ def forward(self, x):
output_tensor = ttnn.to_device(output_tensor, device=device)
output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
output_tensor = ttnn.relu(output_tensor)
print(f"Output shape is {output_tensor.shape}")
output_tensor = ttnn.permute(output_tensor, (0, 3, 1, 2))
print(f"Expected shape after permute is ttnn.Shape([1, 128, 28, 28[32]]")
print(f"Actual shape is {output_tensor.shape}")
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.get_fallback_function(ttnn.reshape)(output_tensor, (-1, num_output_channels))
output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
2 changes: 2 additions & 0 deletions tt_eager/tensor/tensor.cpp
@@ -813,6 +813,8 @@ const Shape Tensor::strides() const { return detail::compute_strides(this->get_l

uint32_t Tensor::volume() const { return tt::tt_metal::compute_volume(this->get_legacy_shape()); }

uint32_t Tensor::intended_volume() const { return tt::tt_metal::compute_volume(this->get_shape()); }

Tensor create_device_tensor(
const Shape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config) {
ZoneScoped;
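Note: the new `intended_volume()` counts elements of the logical (unpadded) shape from `get_shape()`, while the existing `volume()` counts elements of the legacy, tile-padded shape. A minimal standalone sketch of the difference in plain C++, with made-up dimensions and TILE_HEIGHT = TILE_WIDTH = 32 as in tt-metal:

```cpp
#include <array>
#include <cstdint>
#include <iostream>

int main() {
    constexpr uint32_t TILE_HEIGHT = 32, TILE_WIDTH = 32;
    const std::array<uint32_t, 4> intended{1, 128, 28, 28};   // logical shape
    std::array<uint32_t, 4> padded = intended;                 // tile-padded ("legacy") view
    padded[2] = (padded[2] + TILE_HEIGHT - 1) / TILE_HEIGHT * TILE_HEIGHT;  // 28 -> 32
    padded[3] = (padded[3] + TILE_WIDTH - 1) / TILE_WIDTH * TILE_WIDTH;     // 28 -> 32

    auto volume = [](const std::array<uint32_t, 4>& s) {
        return static_cast<uint64_t>(s[0]) * s[1] * s[2] * s[3];
    };
    std::cout << "intended volume: " << volume(intended) << "\n";  // 1*128*28*28 = 100352
    std::cout << "padded volume:   " << volume(padded) << "\n";    // 1*128*32*32 = 131072
}
```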
1 change: 1 addition & 0 deletions tt_eager/tensor/tensor.hpp
@@ -311,6 +311,7 @@ struct Tensor {
StorageType storage_type() const;
const Shape strides() const;
uint32_t volume() const;
uint32_t intended_volume() const;

bool is_allocated() const;

18 changes: 18 additions & 0 deletions tt_eager/tt_dnn/op_library/auto_format.cpp
@@ -40,6 +40,24 @@ Tensor AutoFormat::move_tensor_to_mem_config(const Tensor& input, const MemoryCo
}
}

// This code is a workaround for cases where we need to remove autoformat but other dependent ops
// are not quite ready. So here we basically just put the tensor back on device.
// Used in backward_ops.cpp
// See: Remove auto format within permute_op.cpp #9404
Tensor AutoFormat::move_tensor_to_device_and_pad(const Tensor& input, Device *device, Layout target_layout, std::optional<MemoryConfig> target_mem_config){
const auto intended_shape = input.get_shape();
const auto device_shape = input.get_legacy_shape();
const auto new_intended_shape = std::array<std::uint32_t, 4>{intended_shape[0], intended_shape[1], intended_shape[-2], intended_shape[-1]};
const auto new_device_shape = std::array<std::uint32_t, 4>{
device_shape[0],
device_shape[1],
(device_shape[-2] % TILE_HEIGHT != 0 ? (device_shape[-2] / TILE_HEIGHT + 1) * TILE_HEIGHT : device_shape[-2]),
(device_shape[-1] % TILE_WIDTH != 0 ? (device_shape[-1] / TILE_WIDTH + 1) * TILE_WIDTH : device_shape[-1])
};
const auto new_shape = tt_metal::Shape(new_intended_shape, new_device_shape);
return AutoFormat::format_input_tensor(input, device, new_shape, 0.0, target_layout, target_mem_config);
}

Tensor AutoFormat::format_input_tensor(
const Tensor& input,
Device* device,
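Note: the helper pairs the tensor's intended shape with a device shape whose last two dims are rounded up to tile multiples, then hands both to `format_input_tensor` with a zero pad value. A standalone sketch of the shape arithmetic only (no tt-metal calls, illustrative dimensions):

```cpp
#include <array>
#include <cstdint>
#include <iostream>

constexpr uint32_t TILE_HEIGHT = 32, TILE_WIDTH = 32;

// Same conditional round-up as the ternary expressions in move_tensor_to_device_and_pad.
uint32_t pad_dim(uint32_t d, uint32_t tile) {
    return d % tile != 0 ? (d / tile + 1) * tile : d;
}

int main() {
    const std::array<uint32_t, 4> intended{1, 2, 30, 70};   // e.g. an unpadded permute output
    std::array<uint32_t, 4> device_view = intended;
    device_view[2] = pad_dim(device_view[2], TILE_HEIGHT);  // 30 -> 32
    device_view[3] = pad_dim(device_view[3], TILE_WIDTH);   // 70 -> 96
    std::cout << "intended: [1, 2, 30, 70]\n";
    std::cout << "device:   [" << device_view[0] << ", " << device_view[1] << ", "
              << device_view[2] << ", " << device_view[3] << "]\n";
}
```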
14 changes: 10 additions & 4 deletions tt_eager/tt_dnn/op_library/auto_format.hpp
@@ -34,10 +34,10 @@ class AutoFormat {


static Shape pad_to_tile_shape(const Shape& unpadded_shape, bool pad_c=false, bool pad_n=false, bool pad_h=true, bool pad_w=true) {
auto n = pad_n ? round_up(unpadded_shape[0], TILE_HEIGHT) : unpadded_shape[0];
auto c = pad_c ? round_up(unpadded_shape[1], TILE_WIDTH) : unpadded_shape[1];
auto h = pad_h ? round_up(unpadded_shape[2], TILE_HEIGHT) : unpadded_shape[2];
auto w = pad_w ? round_up(unpadded_shape[3], TILE_WIDTH) : unpadded_shape[3];
auto n = pad_n ? round_up(unpadded_shape.rank() >= 4 ? unpadded_shape[-4] : 1, TILE_HEIGHT) : unpadded_shape.rank() >= 4 ? unpadded_shape[-4] : 1;
auto c = pad_c ? round_up(unpadded_shape.rank() >= 3 ? unpadded_shape[-3] : 1, TILE_WIDTH) : unpadded_shape.rank() >= 3 ? unpadded_shape[-3] : 1;
auto h = pad_h ? round_up(unpadded_shape[-2], TILE_HEIGHT) : unpadded_shape[-2];
auto w = pad_w ? round_up(unpadded_shape[-1], TILE_WIDTH) : unpadded_shape[-1];
Shape padded_shape = {n, c, h, w};
return padded_shape;
}
@@ -83,6 +83,12 @@ class AutoFormat {
return false;
}

// This code is a workaround for cases where we need to remove autoformat but other dependent ops
// are not quite ready. So here we basically just put the tensor back on device.
// Used in backward_ops.cpp
// See: Remove auto format within permute_op.cpp #9404
static Tensor move_tensor_to_device_and_pad(const Tensor& input, Device *device, Layout target_layout, std::optional<MemoryConfig> target_mem_config);

static Tensor move_tensor_to_device(const Tensor &input, Device * device, const MemoryConfig& mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

static Tensor move_tensor_to_mem_config(const Tensor &input, const MemoryConfig& mem_config);
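Note: `pad_to_tile_shape` now indexes from the end of the shape and substitutes 1 for missing leading dims, so rank-2 and rank-3 shapes no longer read out of range. A simplified standalone sketch of the default case (`pad_h` and `pad_w` only), with `round_up` mirroring tt-metal's helper:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr uint32_t TILE_HEIGHT = 32, TILE_WIDTH = 32;

uint32_t round_up(uint32_t v, uint32_t multiple) { return (v + multiple - 1) / multiple * multiple; }

// Default case of pad_to_tile_shape (pad_h = pad_w = true, pad_c = pad_n = false).
std::vector<uint32_t> pad_to_tile_shape(const std::vector<uint32_t>& shape) {
    const std::size_t rank = shape.size();
    const uint32_t n = rank >= 4 ? shape[rank - 4] : 1;  // shape[-4] if present, else 1
    const uint32_t c = rank >= 3 ? shape[rank - 3] : 1;  // shape[-3] if present, else 1
    const uint32_t h = round_up(shape[rank - 2], TILE_HEIGHT);
    const uint32_t w = round_up(shape[rank - 1], TILE_WIDTH);
    return {n, c, h, w};
}

int main() {
    for (uint32_t d : pad_to_tile_shape({28, 28}))       std::cout << d << ' ';  // 1 1 32 32
    std::cout << '\n';
    for (uint32_t d : pad_to_tile_shape({3, 28, 28}))    std::cout << d << ' ';  // 1 3 32 32
    std::cout << '\n';
    for (uint32_t d : pad_to_tile_shape({2, 3, 40, 70})) std::cout << d << ' ';  // 2 3 64 96
    std::cout << '\n';
}
```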
10 changes: 10 additions & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -2170,6 +2170,11 @@ std::vector<Tensor> _prod_bw(
std::vector<int64_t> after_permute_dims = {0, 2, 3, 1};
Tensor tensor_1 = permute(tensor_1_temp, after_permute_dims, output_mem_config);
Tensor tensor_2 = permute(temp, after_permute_dims, output_mem_config);

// put the tensor back on device because permute throws it off device
// See: Remove auto format within permute_op.cpp #9404
tensor_2 = AutoFormat::move_tensor_to_device_and_pad(tensor_2, tensor_1.device(), tensor_1.get_layout(), tensor_1.memory_config());

after_permute_dims = {0, 3, 1, 2};
Tensor result = permute(
bcast(tensor_1, tensor_2, BcastOpMath::MUL, BcastOpDim::W, output_mem_config),
@@ -2202,6 +2207,11 @@ std::vector<Tensor> _prod_bw(
std::vector<int64_t> after_permute_dims = {3, 1, 2, 0};
Tensor tensor_1 = permute(tensor_1_temp, after_permute_dims, output_mem_config);
Tensor tensor_2 = permute(temp, after_permute_dims, output_mem_config);

// put the tensor back on device because permute throws it off device
// See: Remove auto format within permute_op.cpp #9404
tensor_2 = AutoFormat::move_tensor_to_device_and_pad(tensor_2, tensor_1.device(), tensor_1.get_layout(), tensor_1.memory_config());

Tensor result = permute(
bcast(tensor_1, tensor_2, BcastOpMath::MUL, BcastOpDim::W, output_mem_config),
after_permute_dims,
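Note: the `bcast` call these hunks feed is a W-dim broadcast multiply; `tensor_2` is assumed to have logical width 1 and to be multiplied across the width of `tensor_1`. A plain-C++ sketch of that assumed elementwise behaviour, with illustrative values (not code from the PR):

```cpp
#include <iostream>
#include <vector>

int main() {
    const int H = 2, W = 4;
    const std::vector<float> a = {1, 2, 3, 4,
                                  5, 6, 7, 8};   // H x W
    const std::vector<float> b = {10, 100};      // H x 1, broadcast along W

    std::vector<float> out(H * W);
    for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
            out[h * W + w] = a[h * W + w] * b[h];  // row h of a scaled by b[h]

    for (float v : out) std::cout << v << ' ';     // 10 20 30 40 500 600 700 800
    std::cout << '\n';
}
```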
4 changes: 2 additions & 2 deletions tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp
@@ -73,7 +73,7 @@ inline Tensor bcast(
const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) {

std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a}))};
operation::launch_with_autoformat(
operation::launch_op(
[bcast_op, bcast_dim, output_mem_config] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
using tt::constants::TILE_HEIGHT;
using tt::constants::TILE_WIDTH;
@@ -109,7 +109,7 @@ inline Tensor bcast(
input_tensor_b.get_legacy_shape()[-1] == TILE_WIDTH));
}
}
return operation::run_with_autoformat(
return operation::run(
EltwiseBinaryBroadcast{bcast_op, bcast_dim, output_mem_config}, {input_tensor_a, input_tensor_b});
}, {input_tensor_a, input_tensor_b}, output_tensors);
return output_tensors.at(0);
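Note: switching from `launch_with_autoformat`/`run_with_autoformat` to `launch_op`/`run` means bcast no longer pads or unpads its inputs, so callers are expected to hand it tile-laid-out, tile-padded tensors (hence the device-and-pad workaround in backward_ops.cpp above). A hypothetical caller-side guard sketching that invariant, not code from the PR:

```cpp
#include <array>
#include <cstdint>
#include <iostream>

constexpr uint32_t TILE_HEIGHT = 32, TILE_WIDTH = 32;

// Hypothetical guard: with run() instead of run_with_autoformat(), the last two
// dims of the (legacy, padded) shape must already be tile multiples.
bool is_tile_padded(const std::array<uint32_t, 4>& legacy_shape) {
    return legacy_shape[2] % TILE_HEIGHT == 0 && legacy_shape[3] % TILE_WIDTH == 0;
}

int main() {
    std::cout << is_tile_padded({1, 128, 32, 32}) << '\n';  // 1: ready for bcast
    std::cout << is_tile_padded({1, 128, 28, 28}) << '\n';  // 0: pad/format first
}
```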
@@ -22,8 +22,8 @@ namespace tt_metal {
operation::ProgramWithCallbacks bcast_sharded_h(const Tensor &a, const Tensor &b, const Tensor& output, BcastOpMath bcast_math/*, BcastOpDim bcast_dim*/){
const auto ashape = a.get_legacy_shape();
const auto bshape = b.get_legacy_shape();
uint32_t N = ashape[0], C = ashape[1], H = ashape[2], W = ashape[3];
uint32_t bN = bshape[0], bC = bshape[1], bH = bshape[2], bW = bshape[3];
uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1, C = ashape.rank() >= 3 ? ashape[-3] : 1, H = ashape[-2], W = ashape[-1];
uint32_t bN = bshape.rank() >= 4 ? bshape[-4] : 1, bC = bshape.rank() >= 3 ? bshape[-3] : 1, bH = bshape[-2], bW = bshape[-1];
uint32_t NC = N*C;


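Note: the sharded kernel applies the same rank-default pattern as `pad_to_tile_shape`; a missing N or C dim counts as 1, so the NC product it iterates over stays correct for 2-, 3-, and 4-D inputs. A short standalone sketch with illustrative shapes:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Rank-aware N/C/H/W extraction: shape[-k] maps to index rank - k.
void print_nchw(const std::vector<uint32_t>& shape) {
    const std::size_t rank = shape.size();
    const uint32_t N = rank >= 4 ? shape[rank - 4] : 1;
    const uint32_t C = rank >= 3 ? shape[rank - 3] : 1;
    const uint32_t H = shape[rank - 2];
    const uint32_t W = shape[rank - 1];
    std::cout << "N=" << N << " C=" << C << " H=" << H << " W=" << W
              << " NC=" << N * C << '\n';
}

int main() {
    print_nchw({64, 128});        // N=1 C=1 H=64 W=128 NC=1
    print_nchw({8, 64, 128});     // N=1 C=8 H=64 W=128 NC=8
    print_nchw({2, 8, 64, 128});  // N=2 C=8 H=64 W=128 NC=16
}
```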