diff --git a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp index d1cb04375a7c49..ce7803978a05f4 100644 --- a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp +++ b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp @@ -2660,12 +2660,21 @@ bool primitive_inst::is_valid_fusion() const { // Check if broadcast happens more than single axis. // Current gemm_tiled_opt kernel FUSED_OP_LOAD macro cannot support broadcast on dynamic dimension. - if (_node->is_type() && can_broadcast == true && merged_shape.rank().get_length() == outer_dep_pshape.rank().get_length()) { + if (_node->is_type() && can_broadcast == true && merged_shape.rank().get_length() >= outer_dep_pshape.rank().get_length()) { uint8_t broadcast_more_than_single_axis = 0; + auto updated_outer_dep_pshape = ov::PartialShape(outer_dep_pshape); + + // Update outer_dep_pshape to merged_shape rank + if (merged_shape.rank().get_length() > outer_dep_pshape.rank().get_length()) { + updated_outer_dep_pshape.insert(updated_outer_dep_pshape.begin(), + merged_shape.rank().get_length() - outer_dep_pshape.rank().get_length(), ov::Dimension(1)); + } + for (int64_t i = 0; i < merged_shape.rank().get_length(); i++) { - if (merged_shape.get_shape().at(i) != outer_dep_pshape.get_shape().at(i)) + if (merged_shape[i] != updated_outer_dep_pshape[i]) broadcast_more_than_single_axis++; } + if (broadcast_more_than_single_axis > 1) can_broadcast = false; } diff --git a/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/dynamic_unfusion.cpp b/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/dynamic_unfusion.cpp index 1cc079a10b82f6..04c053c6dbd1c3 100644 --- a/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/dynamic_unfusion.cpp +++ b/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/dynamic_unfusion.cpp @@ -15,6 +15,8 @@ namespace { using ov::test::InputShape; using DynamicUnfusionsParams = std::tuple, // input shapes + bool, // Matmul transpose a + bool, // Matmul transpose b ov::element::Type>; // input precision class DynamicUnfusions : public testing::WithParamInterface, @@ -22,9 +24,11 @@ class DynamicUnfusions : public testing::WithParamInterface obj) { std::vector input_shapes; + bool transpose_a; + bool transpose_b; ov::element::Type input_precision; - std::tie(input_shapes, input_precision) = obj.param; + std::tie(input_shapes, transpose_a, transpose_b, input_precision) = obj.param; std::ostringstream result; result << "IS=("; @@ -42,18 +46,22 @@ class DynamicUnfusions : public testing::WithParamInterface input_shapes; + bool transpose_a; + bool transpose_b; ov::element::Type input_precision; - std::tie(input_shapes, input_precision) = GetParam(); + std::tie(input_shapes, transpose_a, transpose_b, input_precision) = GetParam(); init_input_shapes(input_shapes); inType = outType = input_precision; - function = init_subgraph(inputDynamicShapes, input_precision); + function = init_subgraph(inputDynamicShapes, transpose_a, transpose_b, input_precision); } }; @@ -83,13 +93,28 @@ TEST_P(DynamicUnfusions, Inference) { const std::vector input_precisions = {ov::element::f32}; -const std::vector> input_shapes_dyn = { +const std::vector> input_shapes_same_rank_fusing_dyn = { {{{1024, -1}, {{1024, 1024}}}, {{-1, 1024}, {{1024, 1024}}}, {{1, -1}, {{1, 1}}}}, + {{{1024, -1}, {{1024, 1024}}}, {{-1, 1024}, {{1024, 1024}}}, {{1, -1}, {{1, 1024}}}}, }; +const std::vector> input_shapes_diff_rank_fusing_dyn = { + {{{1024, -1}, {{1024, 1024}}}, {{-1, 1024}, {{1024, 1024}}}, {{1, -1}, {{1, 1}}}}, + {{{-1, -1, 1024}, {{1, 1024, 1024}}}, {{-1, 1024}, {{1024, 1024}}}, {{1, -1}, {{1, 1024}}}}, +}; + +INSTANTIATE_TEST_SUITE_P(DynamicUnfusions_basic_same_rank, + DynamicUnfusions, + ::testing::Combine(::testing::ValuesIn(input_shapes_same_rank_fusing_dyn), + ::testing::Values(false), + ::testing::Values(false), + ::testing::ValuesIn(input_precisions)), + DynamicUnfusions::getTestCaseName); -INSTANTIATE_TEST_SUITE_P(DynamicUnfusions_basic, +INSTANTIATE_TEST_SUITE_P(DynamicUnfusions_basic_diff_rank, DynamicUnfusions, - ::testing::Combine(::testing::ValuesIn(input_shapes_dyn), + ::testing::Combine(::testing::ValuesIn(input_shapes_diff_rank_fusing_dyn), + ::testing::Values(false), + ::testing::Values(true), ::testing::ValuesIn(input_precisions)), DynamicUnfusions::getTestCaseName); } // namespace