[GPU] Add different rank case in is_valid_fusion() for gemm - eltwise unfusing #28309

Merged
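
primitive_inst::is_valid_fusion() previously ran its single-axis-broadcast check for gemm only when the fused eltwise input had the same rank as the gemm output, so a lower-rank input broadcasting on multiple axes escaped the check. This change pads the lower-rank shape with leading 1-dimensions up to the gemm output rank before counting mismatched axes, and rejects (unfuses) fusions that would broadcast on more than one axis, which the gemm_tiled_opt kernel's FUSED_OP_LOAD macro cannot handle on dynamic dimensions. The GPU functional tests gain MatMul transpose parameters and a different-rank shape set to cover the new path.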
13 changes: 11 additions & 2 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -2660,12 +2660,21 @@ bool primitive_inst::is_valid_fusion() const {

     // Check if broadcast happens more than single axis.
     // Current gemm_tiled_opt kernel FUSED_OP_LOAD macro cannot support broadcast on dynamic dimension.
-    if (_node->is_type<gemm>() && can_broadcast == true && merged_shape.rank().get_length() == outer_dep_pshape.rank().get_length()) {
+    if (_node->is_type<gemm>() && can_broadcast == true && merged_shape.rank().get_length() >= outer_dep_pshape.rank().get_length()) {
         uint8_t broadcast_more_than_single_axis = 0;
+        auto updated_outer_dep_pshape = ov::PartialShape(outer_dep_pshape);
+
+        // Update outer_dep_pshape to merged_shape rank
+        if (merged_shape.rank().get_length() > outer_dep_pshape.rank().get_length()) {
+            updated_outer_dep_pshape.insert(updated_outer_dep_pshape.begin(),
+                                            merged_shape.rank().get_length() - outer_dep_pshape.rank().get_length(), ov::Dimension(1));
+        }
+
         for (int64_t i = 0; i < merged_shape.rank().get_length(); i++) {
-            if (merged_shape.get_shape().at(i) != outer_dep_pshape.get_shape().at(i))
+            if (merged_shape[i] != updated_outer_dep_pshape[i])
                 broadcast_more_than_single_axis++;
         }

         if (broadcast_more_than_single_axis > 1)
             can_broadcast = false;
     }
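
The hunk above right-aligns the lower-rank eltwise shape against the gemm output, numpy-style, before counting broadcast axes. For illustration, a minimal standalone sketch of that check, using plain std::vector<int64_t> with static dimensions in place of ov::PartialShape (the function and variable names here are ours, not from the OpenVINO sources):

#include <cstdint>
#include <iostream>
#include <vector>

// Returns true when the fused eltwise input broadcasts against the gemm
// output on at most one axis (the case the kernel can still fuse).
// Precondition: fused_input.size() <= gemm_output.size().
bool broadcasts_on_single_axis_only(std::vector<int64_t> fused_input,
                                    const std::vector<int64_t>& gemm_output) {
    // Pad the lower-rank shape with leading 1s up to the gemm output rank.
    fused_input.insert(fused_input.begin(), gemm_output.size() - fused_input.size(), 1);

    int broadcast_axes = 0;
    for (size_t i = 0; i < gemm_output.size(); ++i)
        if (fused_input[i] != gemm_output[i])
            ++broadcast_axes;
    return broadcast_axes <= 1;
}

int main() {
    // Mirrors the new different-rank test case: {1, 1024} pads to
    // {1, 1, 1024}; only axis 1 differs, so the fusion stays valid.
    std::cout << broadcasts_on_single_axis_only({1, 1024}, {1, 1024, 1024}) << "\n";  // 1
    // {1, 1} vs {1024, 1024}: both axes broadcast, so the eltwise is unfused.
    std::cout << broadcasts_on_single_axis_only({1, 1}, {1024, 1024}) << "\n";        // 0
}

The functional test changes below exercise both situations: the same-rank cases that were already handled, and the new different-rank cases.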
@@ -15,16 +15,20 @@ namespace {
 using ov::test::InputShape;

 using DynamicUnfusionsParams = std::tuple<std::vector<InputShape>,  // input shapes
+                                          bool,                     // Matmul transpose a
+                                          bool,                     // Matmul transpose b
                                           ov::element::Type>;       // input precision

 class DynamicUnfusions : public testing::WithParamInterface<DynamicUnfusionsParams>,
                          virtual public ov::test::SubgraphBaseTest {
 public:
     static std::string getTestCaseName(testing::TestParamInfo<DynamicUnfusionsParams> obj) {
         std::vector<InputShape> input_shapes;
+        bool transpose_a;
+        bool transpose_b;
         ov::element::Type input_precision;

-        std::tie(input_shapes, input_precision) = obj.param;
+        std::tie(input_shapes, transpose_a, transpose_b, input_precision) = obj.param;

         std::ostringstream result;
         result << "IS=(";
@@ -42,18 +46,22 @@ class DynamicUnfusions : public testing::WithParamInterface<DynamicUnfusionsPara
             }
             result << ")_";
         }
+        result << "transpose_a=" << transpose_a << "_";
+        result << "transpose_b=" << transpose_b << "_";
         result << "input_precision=" << input_precision;
         return result.str();
     }

 protected:
     std::shared_ptr<ov::Model> init_subgraph(std::vector<ov::PartialShape>& input_shapes,
+                                             bool transpose_a,
+                                             bool transpose_b,
                                              const ov::element::Type input_precision) {
         auto input0 = std::make_shared<ov::op::v0::Parameter>(input_precision, input_shapes[0]);
         auto input1 = std::make_shared<ov::op::v0::Parameter>(input_precision, input_shapes[1]);
         auto input2 = std::make_shared<ov::op::v0::Parameter>(input_precision, input_shapes[2]);

-        auto matmul = std::make_shared<ov::op::v0::MatMul>(input0, input1);
+        auto matmul = std::make_shared<ov::op::v0::MatMul>(input0, input1, transpose_a, transpose_b);
         auto mul = std::make_shared<ov::op::v1::Multiply>(matmul, input2);

         matmul->set_friendly_name("MatMul");
@@ -66,14 +74,16 @@ class DynamicUnfusions : public testing::WithParamInterface<DynamicUnfusionsPara
         targetDevice = ov::test::utils::DEVICE_GPU;

         std::vector<InputShape> input_shapes;
+        bool transpose_a;
+        bool transpose_b;
         ov::element::Type input_precision;

-        std::tie(input_shapes, input_precision) = GetParam();
+        std::tie(input_shapes, transpose_a, transpose_b, input_precision) = GetParam();

         init_input_shapes(input_shapes);

         inType = outType = input_precision;
-        function = init_subgraph(inputDynamicShapes, input_precision);
+        function = init_subgraph(inputDynamicShapes, transpose_a, transpose_b, input_precision);
     }
 };

@@ -83,13 +93,28 @@ TEST_P(DynamicUnfusions, Inference) {

 const std::vector<ov::element::Type> input_precisions = {ov::element::f32};

-const std::vector<std::vector<InputShape>> input_shapes_dyn = {
+const std::vector<std::vector<InputShape>> input_shapes_same_rank_fusing_dyn = {
     {{{1024, -1}, {{1024, 1024}}}, {{-1, 1024}, {{1024, 1024}}}, {{1, -1}, {{1, 1}}}},
     {{{1024, -1}, {{1024, 1024}}}, {{-1, 1024}, {{1024, 1024}}}, {{1, -1}, {{1, 1024}}}},
 };
+const std::vector<std::vector<InputShape>> input_shapes_diff_rank_fusing_dyn = {
+    {{{1024, -1}, {{1024, 1024}}}, {{-1, 1024}, {{1024, 1024}}}, {{1, -1}, {{1, 1}}}},
+    {{{-1, -1, 1024}, {{1, 1024, 1024}}}, {{-1, 1024}, {{1024, 1024}}}, {{1, -1}, {{1, 1024}}}},
+};
+
+INSTANTIATE_TEST_SUITE_P(DynamicUnfusions_basic_same_rank,
+                         DynamicUnfusions,
+                         ::testing::Combine(::testing::ValuesIn(input_shapes_same_rank_fusing_dyn),
+                                            ::testing::Values(false),
+                                            ::testing::Values(false),
+                                            ::testing::ValuesIn(input_precisions)),
+                         DynamicUnfusions::getTestCaseName);

-INSTANTIATE_TEST_SUITE_P(DynamicUnfusions_basic,
+INSTANTIATE_TEST_SUITE_P(DynamicUnfusions_basic_diff_rank,
                          DynamicUnfusions,
-                         ::testing::Combine(::testing::ValuesIn(input_shapes_dyn),
+                         ::testing::Combine(::testing::ValuesIn(input_shapes_diff_rank_fusing_dyn),
+                                            ::testing::Values(false),
+                                            ::testing::Values(true),
+                                            ::testing::ValuesIn(input_precisions)),
                          DynamicUnfusions::getTestCaseName);
 }  // namespace
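
To run only these suites locally, standard GoogleTest filtering applies; assuming the GPU functional test binary is named ov_gpu_func_tests (the exact name can vary by OpenVINO version and build layout):

./ov_gpu_func_tests --gtest_filter='*DynamicUnfusions*'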