#9302: Add additional composite ops
eyonland committed Jun 12, 2024
1 parent 2a4a22c commit 80800fc
Showing 14 changed files with 158 additions and 80 deletions.
3 changes: 3 additions & 0 deletions tests/ttnn/unit_tests/test_model_preprocessing.py
@@ -326,7 +326,10 @@ def forward(self, x):
output_tensor = ttnn.to_device(output_tensor, device=device)
output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
output_tensor = ttnn.relu(output_tensor)
print(f"Output shape is {output_tensor.shape}")
output_tensor = ttnn.permute(output_tensor, (0, 3, 1, 2))
print(f"Expected shape after permute is ttnn.Shape([1, 128, 28, 28[32]]")
print(f"Actual shape is {output_tensor.shape}")
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.get_fallback_function(ttnn.reshape)(output_tensor, (-1, num_output_channels))
output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
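The 28[32] in the expected shape is how ttnn prints a logical dimension of 28 that is padded to 32 for tile layout. A minimal sketch of the same logical/padded pairing on the C++ side, using the dual-shape constructor that appears later in this commit (the values here are hypothetical):

// Sketch, not part of the commit: the logical dims and the tile-padded dims
// behind a printed shape like ttnn.Shape([1, 128, 28, 28[32]]).
const auto logical_dims = std::array<std::uint32_t, 4>{1, 128, 28, 28};
const auto padded_dims  = std::array<std::uint32_t, 4>{1, 128, 28, 32};
const auto shape = tt_metal::Shape(logical_dims, padded_dims);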
2 changes: 2 additions & 0 deletions tt_eager/tensor/tensor.cpp
@@ -813,6 +813,8 @@ const Shape Tensor::strides() const { return detail::compute_strides(this->get_l

uint32_t Tensor::volume() const { return tt::tt_metal::compute_volume(this->get_legacy_shape()); }

uint32_t Tensor::intended_volume() const { return tt::tt_metal::compute_volume(this->get_shape()); }

Tensor create_device_tensor(
const Shape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config) {
ZoneScoped;
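A brief note on the new accessor: volume() computes the product over the padded legacy shape, while intended_volume() computes it over the logical shape, so the two differ for tile-padded tensors. A worked sketch with hypothetical dimensions:

// Sketch, hypothetical values: a tensor with logical shape {1, 1, 28, 28}
// stored tile-padded as {1, 1, 32, 32}.
// volume()          -> 1 * 1 * 32 * 32 = 1024   (padded legacy shape)
// intended_volume() -> 1 * 1 * 28 * 28 = 784    (logical shape)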
1 change: 1 addition & 0 deletions tt_eager/tensor/tensor.hpp
@@ -311,6 +311,7 @@ struct Tensor {
StorageType storage_type() const;
const Shape strides() const;
uint32_t volume() const;
uint32_t intended_volume() const;

bool is_allocated() const;

18 changes: 18 additions & 0 deletions tt_eager/tt_dnn/op_library/auto_format.cpp
@@ -40,6 +40,24 @@ Tensor AutoFormat::move_tensor_to_mem_config(const Tensor& input, const MemoryCo
}
}

// This is a workaround for cases where autoformat is being removed but dependent ops
// are not yet ready: it simply moves the tensor back onto the device and pads it to tile dimensions.
// Used in backward_ops.cpp
// See: Remove auto format within permute_op.cpp #9404
Tensor AutoFormat::move_tensor_to_device_and_pad(const Tensor& input, Device *device, Layout target_layout, std::optional<MemoryConfig> target_mem_config){
const auto intended_shape = input.get_shape();
const auto device_shape = input.get_legacy_shape();
const auto new_intended_shape = std::array<std::uint32_t, 4>{intended_shape[0], intended_shape[1], intended_shape[-2], intended_shape[-1]};
const auto new_device_shape = std::array<std::uint32_t, 4>{
device_shape[0],
device_shape[1],
(device_shape[-2] % TILE_HEIGHT != 0 ? (device_shape[-2] / TILE_HEIGHT + 1) * TILE_HEIGHT : device_shape[-2]),
(device_shape[-1] % TILE_WIDTH != 0 ? (device_shape[-1] / TILE_WIDTH + 1) * TILE_WIDTH : device_shape[-1])
};
const auto new_shape = tt_metal::Shape(new_intended_shape, new_device_shape);
return AutoFormat::format_input_tensor(input, device, new_shape, 0.0, target_layout, target_mem_config);
}

Tensor AutoFormat::format_input_tensor(
const Tensor& input,
Device* device,
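The ternary expressions in move_tensor_to_device_and_pad round the last two device dims up to the next tile multiple. A standalone sketch of that idiom; the pad_to_tile helper below is hypothetical and only mirrors the inline expressions above:

// Hypothetical helper (sketch): round a dimension up to the next multiple of the tile extent.
uint32_t pad_to_tile(uint32_t dim, uint32_t tile_extent) {
    return dim % tile_extent != 0 ? (dim / tile_extent + 1) * tile_extent : dim;
}
// e.g. pad_to_tile(28, 32) == 32, pad_to_tile(64, 32) == 64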
14 changes: 10 additions & 4 deletions tt_eager/tt_dnn/op_library/auto_format.hpp
@@ -34,10 +34,10 @@ class AutoFormat {


static Shape pad_to_tile_shape(const Shape& unpadded_shape, bool pad_c=false, bool pad_n=false, bool pad_h=true, bool pad_w=true) {
auto n = pad_n ? round_up(unpadded_shape[0], TILE_HEIGHT) : unpadded_shape[0];
auto c = pad_c ? round_up(unpadded_shape[1], TILE_WIDTH) : unpadded_shape[1];
auto h = pad_h ? round_up(unpadded_shape[2], TILE_HEIGHT) : unpadded_shape[2];
auto w = pad_w ? round_up(unpadded_shape[3], TILE_WIDTH) : unpadded_shape[3];
auto n = pad_n ? round_up(unpadded_shape.rank() >= 4 ? unpadded_shape[-4] : 1, TILE_HEIGHT) : unpadded_shape.rank() >= 4 ? unpadded_shape[-4] : 1;
auto c = pad_c ? round_up(unpadded_shape.rank() >= 3 ? unpadded_shape[-3] : 1, TILE_WIDTH) : unpadded_shape.rank() >= 3 ? unpadded_shape[-3] : 1;
auto h = pad_h ? round_up(unpadded_shape[-2], TILE_HEIGHT) : unpadded_shape[-2];
auto w = pad_w ? round_up(unpadded_shape[-1], TILE_WIDTH) : unpadded_shape[-1];
Shape padded_shape = {n, c, h, w};
return padded_shape;
}
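With the rank checks and negative indexing, shapes of rank less than 4 now resolve their missing leading dims to 1 instead of indexing out of range. A worked sketch under the default pad flags (pad_h = pad_w = true, pad_c = pad_n = false) and 32x32 tiles:

// Sketch: a rank-3 unpadded shape {2, 28, 28} resolves to n = 1 (missing dim),
// c = 2, h = 28, w = 28, so pad_to_tile_shape returns {1, 2, 32, 32}.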
@@ -83,6 +83,12 @@ class AutoFormat {
return false;
}

// This is a workaround for cases where autoformat is being removed but dependent ops
// are not yet ready: it simply moves the tensor back onto the device and pads it to tile dimensions.
// Used in backward_ops.cpp
// See: Remove auto format within permute_op.cpp #9404
static Tensor move_tensor_to_device_and_pad(const Tensor& input, Device *device, Layout target_layout, std::optional<MemoryConfig> target_mem_config);

static Tensor move_tensor_to_device(const Tensor &input, Device * device, const MemoryConfig& mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

static Tensor move_tensor_to_mem_config(const Tensor &input, const MemoryConfig& mem_config);
10 changes: 10 additions & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -2170,6 +2170,11 @@ std::vector<Tensor> _prod_bw(
std::vector<int64_t> after_permute_dims = {0, 2, 3, 1};
Tensor tensor_1 = permute(tensor_1_temp, after_permute_dims, output_mem_config);
Tensor tensor_2 = permute(temp, after_permute_dims, output_mem_config);

// put the tensor back on device because permute can leave its result off device
// See: Remove auto format within permute_op.cpp #9404
tensor_2 = AutoFormat::move_tensor_to_device_and_pad(tensor_2, tensor_1.device(), tensor_1.get_layout(), tensor_1.memory_config());

after_permute_dims = {0, 3, 1, 2};
Tensor result = permute(
bcast(tensor_1, tensor_2, BcastOpMath::MUL, BcastOpDim::W, output_mem_config),
@@ -2202,6 +2207,11 @@ std::vector<Tensor> _prod_bw(
std::vector<int64_t> after_permute_dims = {3, 1, 2, 0};
Tensor tensor_1 = permute(tensor_1_temp, after_permute_dims, output_mem_config);
Tensor tensor_2 = permute(temp, after_permute_dims, output_mem_config);

// put the tensor back on device because permute can leave its result off device
// See: Remove auto format within permute_op.cpp #9404
tensor_2 = AutoFormat::move_tensor_to_device_and_pad(tensor_2, tensor_1.device(), tensor_1.get_layout(), tensor_1.memory_config());

Tensor result = permute(
bcast(tensor_1, tensor_2, BcastOpMath::MUL, BcastOpDim::W, output_mem_config),
after_permute_dims,
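Both hunks above apply the same re-staging pattern: permute can leave its result off device, so the tensor is moved back to the producing tensor's device and tile-padded before the following bcast. A hedged sketch of that pattern with an explicit storage-type guard (the guard is an assumption, not part of this commit):

// Sketch, guard is hypothetical: only re-stage when the tensor actually left the device.
if (tensor_2.storage_type() != StorageType::DEVICE) {
    tensor_2 = AutoFormat::move_tensor_to_device_and_pad(
        tensor_2, tensor_1.device(), tensor_1.get_layout(), tensor_1.memory_config());
}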
4 changes: 2 additions & 2 deletions tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp
@@ -73,7 +73,7 @@ inline Tensor bcast(
const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) {

std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a}))};
operation::launch_with_autoformat(
operation::launch_op(
[bcast_op, bcast_dim, output_mem_config] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
using tt::constants::TILE_HEIGHT;
using tt::constants::TILE_WIDTH;
@@ -109,7 +109,7 @@ inline Tensor bcast(
input_tensor_b.get_legacy_shape()[-1] == TILE_WIDTH));
}
}
return operation::run_with_autoformat(
return operation::run(
EltwiseBinaryBroadcast{bcast_op, bcast_dim, output_mem_config}, {input_tensor_a, input_tensor_b});
}, {input_tensor_a, input_tensor_b}, output_tensors);
return output_tensors.at(0);
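With launch_with_autoformat and run_with_autoformat replaced by launch_op and run, bcast no longer pads or re-lays-out its inputs, so callers are expected to hand it device-resident, tile-aligned tensors. A rough sketch of what a caller might do beforehand (the identifiers input and device are placeholders):

// Sketch, placeholder identifiers: stage and pad an input explicitly
// instead of relying on bcast's former autoformat path.
Tensor staged = AutoFormat::move_tensor_to_device_and_pad(input, device, Layout::TILE, std::nullopt);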
@@ -22,8 +22,8 @@ namespace tt_metal {
operation::ProgramWithCallbacks bcast_sharded_h(const Tensor &a, const Tensor &b, const Tensor& output, BcastOpMath bcast_math/*, BcastOpDim bcast_dim*/){
const auto ashape = a.get_legacy_shape();
const auto bshape = b.get_legacy_shape();
uint32_t N = ashape[0], C = ashape[1], H = ashape[2], W = ashape[3];
uint32_t bN = bshape[0], bC = bshape[1], bH = bshape[2], bW = bshape[3];
uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1, C = ashape.rank() >= 3 ? ashape[-3] : 1, H = ashape[-2], W = ashape[-1];
uint32_t bN = bshape.rank() >= 4 ? bshape[-4] : 1, bC = bshape.rank() >= 3 ? bshape[-3] : 1, bH = bshape[-2], bW = bshape[-1];
uint32_t NC = N*C;


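This uses the same rank-aware indexing as pad_to_tile_shape above: missing leading dims default to 1. A worked sketch with a hypothetical rank-3 shape:

// Sketch: for ashape = {2, 32, 64} (rank 3), N = 1 (missing dim), C = 2,
// H = 32, W = 64, so NC = 2.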