Skip to content

Commit

Permalink
#13707: Replace LegacyShape with SimpleShape
Browse files Browse the repository at this point in the history
  • Loading branch information
Aswinmcw committed Oct 16, 2024
1 parent aa97cfb commit c0c23ba
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,10 @@ void AllGather::validate(const std::vector<Tensor> &input_tensors) const {
}
}

std::vector<tt::tt_metal::LegacyShape> AllGather::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
auto shape = input_tensors[0].get_legacy_shape();
std::vector<ttnn::SimpleShape> AllGather::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
auto shape = input_tensors[0].get_padded_shape(); // TODO: Replace with get_logical_shape()
shape[this->dim] *= this->ring_size;
return std::vector<tt::tt_metal::LegacyShape>(input_tensors.size(), shape);
return std::vector<ttnn::SimpleShape>(input_tensors.size(), shape);
}

std::vector<Tensor> AllGather::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ struct AllGather {
const ccl::Topology topology;

void validate(const std::vector<Tensor> &input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ void AllGatherMatmul::validate(const std::vector<Tensor> &input_tensors, const s
}
}

std::vector<tt::tt_metal::LegacyShape> AllGatherMatmul::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
std::vector<ttnn::SimpleShape> AllGatherMatmul::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {

// All Gather shape
tt::tt_metal::LegacyShape all_gather_output_shape = this->all_gather_struct.compute_output_shapes({input_tensors[0]})[0];
tt::tt_metal::LegacyShape datacopy_output_shape = all_gather_output_shape;
ttnn::SimpleShape all_gather_output_shape = this->all_gather_struct.compute_output_shapes({input_tensors[0]})[0];
ttnn::SimpleShape datacopy_output_shape = all_gather_output_shape;


// Matmul shape
tt::tt_metal::LegacyShape matmul_output_shapes = this->matmul_struct.compute_output_shapes({input_tensors[1], input_tensors[2]})[0];
ttnn::SimpleShape matmul_output_shapes = this->matmul_struct.compute_output_shapes({input_tensors[1], input_tensors[2]})[0];

return {all_gather_output_shape, matmul_output_shapes, datacopy_output_shape};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct AllGatherMatmul {

/* General */
void validate(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor>& input_tensors,
Expand Down
16 changes: 5 additions & 11 deletions ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
using namespace tt;
using namespace tt::constants;
using namespace tt::tt_metal;
using tt::tt_metal::LegacyShape;
using ttnn::operations::unary::UnaryWithParam;

namespace {
Expand Down Expand Up @@ -1325,29 +1324,24 @@ void Matmul::validate(
chosen_program_config);
}

std::vector<tt::tt_metal::LegacyShape> Matmul::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
const tt::tt_metal::LegacyShape& input_shape_a = input_tensors.at(0).get_legacy_shape();
const tt::tt_metal::LegacyShape& input_shape_b = input_tensors.at(1).get_legacy_shape();
std::vector<ttnn::SimpleShape> Matmul::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
const ttnn::SimpleShape input_shape_a = input_tensors.at(0).get_logical_shape();
const ttnn::SimpleShape input_shape_b = input_tensors.at(1).get_logical_shape();
const uint32_t a_rank = input_shape_a.rank();
const uint32_t b_rank = input_shape_b.rank();
const uint32_t out_rank = std::max(a_rank, b_rank);
const uint32_t rank_difference = out_rank - a_rank;
tt::tt_metal::LegacyShape output_shape = (b_rank > a_rank) ? input_shape_b : input_shape_a;
auto dimensions_pads = std::vector<Padding::PadDimension>();
ttnn::SimpleShape output_shape = (b_rank > a_rank) ? input_shape_b : input_shape_a;

for (auto index = 0; index < rank_difference; index++) {
TT_FATAL(input_shape_b[index] == 1, "When in1 rank greater than in0 rank front dimensions need to be 1");
output_shape[index] = input_shape_b[index];
dimensions_pads.push_back(input_shape_b.padding()[index]);
}
for (auto index = 0; index < a_rank - 1; index++) {
output_shape[rank_difference + index] = input_shape_a[index];
dimensions_pads.push_back(input_shape_a.padding()[index]);
}
output_shape[-1] = input_shape_b[-1];
dimensions_pads.push_back(input_shape_b.padding()[b_rank - 1]);
const auto padding = Padding(dimensions_pads, Padding::PadValue::Any);
return {tt::tt_metal::LegacyShape(output_shape, padding)};
return {std::move(output_shape)};
}

std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ struct Matmul {
void validate(
const std::vector<Tensor> &input_tensors,
const std::vector<std::optional<const Tensor>> &optional_input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes_dram_sharded(
const std::vector<Tensor> &input_tensors, uint32_t N_unpadded) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
Expand Down

0 comments on commit c0c23ba

Please sign in to comment.