# [GPU] Add different rank case in is_valid_fusion() for gemm - eltwise unfusing (#28309)

### Details:
- Add a different-rank case to `is_valid_fusion()` for gemm - eltwise unfusing, covering the case below where the gemm output and the fused eltwise operand have different ranks (e.g. a rank-3 gemm output of shape [1, 1024, 1024] with a rank-2 eltwise operand of shape [1, 1024]):

![image](https://github.com/user-attachments/assets/8d71c1a7-7cf2-46a7-b03f-920045ba9165)


### Tickets:
 - 160213
wilson-seok authored Jan 9, 2025
1 parent 7618310 commit d59210e
Showing 2 changed files with 43 additions and 9 deletions.
**`src/plugins/intel_gpu/src/graph/primitive_inst.cpp`** (13 changes: 11 additions & 2 deletions)

```diff
@@ -2660,12 +2660,21 @@ bool primitive_inst::is_valid_fusion() const {

     // Check if broadcast happens more than single axis.
     // Current gemm_tiled_opt kernel FUSED_OP_LOAD macro cannot support broadcast on dynamic dimension.
-    if (_node->is_type<gemm>() && can_broadcast == true && merged_shape.rank().get_length() == outer_dep_pshape.rank().get_length()) {
+    if (_node->is_type<gemm>() && can_broadcast == true && merged_shape.rank().get_length() >= outer_dep_pshape.rank().get_length()) {
         uint8_t broadcast_more_than_single_axis = 0;
+        auto updated_outer_dep_pshape = ov::PartialShape(outer_dep_pshape);
+
+        // Update outer_dep_pshape to merged_shape rank
+        if (merged_shape.rank().get_length() > outer_dep_pshape.rank().get_length()) {
+            updated_outer_dep_pshape.insert(updated_outer_dep_pshape.begin(),
+                                            merged_shape.rank().get_length() - outer_dep_pshape.rank().get_length(),
+                                            ov::Dimension(1));
+        }
+
         for (int64_t i = 0; i < merged_shape.rank().get_length(); i++) {
-            if (merged_shape.get_shape().at(i) != outer_dep_pshape.get_shape().at(i))
+            if (merged_shape[i] != updated_outer_dep_pshape[i])
                 broadcast_more_than_single_axis++;
         }
+
         if (broadcast_more_than_single_axis > 1)
             can_broadcast = false;
     }
```
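For context, the new branch first aligns the eltwise operand to the gemm output rank by prepending size-1 dimensions, then counts how many axes actually broadcast. Below is a minimal standalone sketch of that check, assuming plain static shapes as `std::vector<int64_t>` instead of `ov::PartialShape`; the helper name `broadcast_on_at_most_one_axis` and the example shapes are illustrative, not part of the patch:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper mirroring the patched check: left-pad the lower-rank
// eltwise shape with 1s up to the gemm output rank (as the patch does when
// building updated_outer_dep_pshape), then count the axes that broadcast.
static bool broadcast_on_at_most_one_axis(const std::vector<int64_t>& merged_shape,
                                          std::vector<int64_t> eltwise_shape) {
    // The patched branch only applies when merged rank >= eltwise rank.
    if (merged_shape.size() < eltwise_shape.size())
        return false;

    // Rank alignment: prepend 1s so both shapes have the same rank.
    eltwise_shape.insert(eltwise_shape.begin(),
                         merged_shape.size() - eltwise_shape.size(), 1);

    uint8_t broadcast_axes = 0;
    for (std::size_t i = 0; i < merged_shape.size(); i++) {
        if (merged_shape[i] != eltwise_shape[i])
            broadcast_axes++;
    }

    // More than one broadcast axis is unsupported by the fused gemm kernel,
    // so the eltwise would be unfused in that case.
    return broadcast_axes <= 1;
}

int main() {
    // Rank-3 gemm output [1, 1024, 1024] vs rank-2 eltwise operand [1, 1024]:
    // padded to [1, 1, 1024], one broadcast axis -> fusion stays valid.
    std::cout << broadcast_on_at_most_one_axis({1, 1024, 1024}, {1, 1024}) << '\n';  // 1
    // Eltwise operand [1, 1]: padded to [1, 1, 1], two broadcast axes -> unfuse.
    std::cout << broadcast_on_at_most_one_axis({1, 1024, 1024}, {1, 1}) << '\n';     // 0
}
```

Run against the shapes used in the new tests, the rank-2 operand `{1, 1024}` keeps the fusion valid against the rank-3 output, while `{1, 1}` broadcasts on two axes and forces the eltwise to be unfused.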
The second changed file updates the `DynamicUnfusions` GPU functional test (32 additions & 7 deletions) to parameterize the MatMul transposes and add different-rank cases:

```diff
@@ -15,16 +15,20 @@ namespace {
 using ov::test::InputShape;

 using DynamicUnfusionsParams = std::tuple<std::vector<InputShape>,  // input shapes
+                                          bool,                     // Matmul transpose a
+                                          bool,                     // Matmul transpose b
                                           ov::element::Type>;       // input precision

 class DynamicUnfusions : public testing::WithParamInterface<DynamicUnfusionsParams>,
                          virtual public ov::test::SubgraphBaseTest {
 public:
     static std::string getTestCaseName(testing::TestParamInfo<DynamicUnfusionsParams> obj) {
         std::vector<InputShape> input_shapes;
+        bool transpose_a;
+        bool transpose_b;
         ov::element::Type input_precision;

-        std::tie(input_shapes, input_precision) = obj.param;
+        std::tie(input_shapes, transpose_a, transpose_b, input_precision) = obj.param;

         std::ostringstream result;
         result << "IS=(";
@@ -42,18 +46,22 @@ class DynamicUnfusions : public testing::WithParamInterface<DynamicUnfusionsPara
             }
             result << ")_";
         }
+        result << "transpose_a=" << transpose_a << "_";
+        result << "transpose_b=" << transpose_b << "_";
         result << "input_precision=" << input_precision;
         return result.str();
     }

 protected:
     std::shared_ptr<ov::Model> init_subgraph(std::vector<ov::PartialShape>& input_shapes,
+                                             bool transpose_a,
+                                             bool transpose_b,
                                              const ov::element::Type input_precision) {
         auto input0 = std::make_shared<ov::op::v0::Parameter>(input_precision, input_shapes[0]);
         auto input1 = std::make_shared<ov::op::v0::Parameter>(input_precision, input_shapes[1]);
         auto input2 = std::make_shared<ov::op::v0::Parameter>(input_precision, input_shapes[2]);

-        auto matmul = std::make_shared<ov::op::v0::MatMul>(input0, input1);
+        auto matmul = std::make_shared<ov::op::v0::MatMul>(input0, input1, transpose_a, transpose_b);
         auto mul = std::make_shared<ov::op::v1::Multiply>(matmul, input2);

         matmul->set_friendly_name("MatMul");
@@ -66,14 +74,16 @@ class DynamicUnfusions : public testing::WithParamInterface<DynamicUnfusionsPara
         targetDevice = ov::test::utils::DEVICE_GPU;

         std::vector<InputShape> input_shapes;
+        bool transpose_a;
+        bool transpose_b;
         ov::element::Type input_precision;

-        std::tie(input_shapes, input_precision) = GetParam();
+        std::tie(input_shapes, transpose_a, transpose_b, input_precision) = GetParam();

         init_input_shapes(input_shapes);

         inType = outType = input_precision;
-        function = init_subgraph(inputDynamicShapes, input_precision);
+        function = init_subgraph(inputDynamicShapes, transpose_a, transpose_b, input_precision);
     }
 };
```
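The two test-suite instantiations below pin the new transpose parameters with `::testing::Values`: the same-rank suite runs with `transpose_a = transpose_b = false`, while the different-rank suite sets `transpose_b = true`, so the unfusing check is exercised on both plain and transposed-weight gemms.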

```diff
@@ -83,13 +93,28 @@ TEST_P(DynamicUnfusions, Inference) {

 const std::vector<ov::element::Type> input_precisions = {ov::element::f32};

-const std::vector<std::vector<InputShape>> input_shapes_dyn = {
+const std::vector<std::vector<InputShape>> input_shapes_same_rank_fusing_dyn = {
     {{{1024, -1}, {{1024, 1024}}}, {{-1, 1024}, {{1024, 1024}}}, {{1, -1}, {{1, 1}}}},
     {{{1024, -1}, {{1024, 1024}}}, {{-1, 1024}, {{1024, 1024}}}, {{1, -1}, {{1, 1024}}}},
 };
+const std::vector<std::vector<InputShape>> input_shapes_diff_rank_fusing_dyn = {
+    {{{1024, -1}, {{1024, 1024}}}, {{-1, 1024}, {{1024, 1024}}}, {{1, -1}, {{1, 1}}}},
+    {{{-1, -1, 1024}, {{1, 1024, 1024}}}, {{-1, 1024}, {{1024, 1024}}}, {{1, -1}, {{1, 1024}}}},
+};
+
+INSTANTIATE_TEST_SUITE_P(DynamicUnfusions_basic_same_rank,
+                         DynamicUnfusions,
+                         ::testing::Combine(::testing::ValuesIn(input_shapes_same_rank_fusing_dyn),
+                                            ::testing::Values(false),
+                                            ::testing::Values(false),
+                                            ::testing::ValuesIn(input_precisions)),
+                         DynamicUnfusions::getTestCaseName);

-INSTANTIATE_TEST_SUITE_P(DynamicUnfusions_basic,
+INSTANTIATE_TEST_SUITE_P(DynamicUnfusions_basic_diff_rank,
                          DynamicUnfusions,
-                         ::testing::Combine(::testing::ValuesIn(input_shapes_dyn),
+                         ::testing::Combine(::testing::ValuesIn(input_shapes_diff_rank_fusing_dyn),
+                                            ::testing::Values(false),
+                                            ::testing::Values(true),
                                             ::testing::ValuesIn(input_precisions)),
                          DynamicUnfusions::getTestCaseName);
 }  // namespace
```
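To run only these suites locally, the standard gtest filter applies, e.g. `--gtest_filter='*DynamicUnfusions_basic_same_rank*:*DynamicUnfusions_basic_diff_rank*'` on the GPU functional-test binary (commonly named `ov_gpu_func_tests`, though the exact name and path depend on the build configuration).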
