Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft : [logo/pass] Refactor RemoveDeadNodeWithQueryPass #14051

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/circle2circle-dredd-recipe-test/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,5 @@ Add(REGRESS_ONNX_Conv_BN_Relu6_001 PASS

Add(REGRESS_ONNX_Mul_Mul_000 PASS
convert_nchw_to_nhwc)

Add(RemoveDeadNodeWithQueryPass_000 PASS substitute_pack_to_reshape)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you should change test title

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m not sure if this title is correct.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think REGRESS_Issue_13863 would be ok.

10 changes: 6 additions & 4 deletions compiler/logo/src/Passes/RemoveDeadNodeWithQueryPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,17 @@ bool RemoveDeadNodeWithQueryPass::run(loco::Graph *g)
}

// Find the nodes that should not be dead node in candidates
for (auto node : candidates)
for (auto it = candidates.begin(); it != candidates.end();)
{
if (auto service = node->dialect()->service<DeadNodeQueryService>())
if (auto service = (*it)->dialect()->service<DeadNodeQueryService>())
{
if (!service->isDeadNode(node))
if (!service->isDeadNode(*it))
{
candidates.erase(node);
it = candidates.erase(it);
continue;
}
}
++it;
}

for (auto node : candidates)
Expand Down
252 changes: 252 additions & 0 deletions res/TensorFlowLiteRecipes/RemoveDeadNodeWithQueryPass_000/test.recipe
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
operand {
name: "serving_default_input:0"
type: FLOAT32
shape {
dim: 1
dim: 32
dim: 32
dim: 3
}
quant {
quantized_dimension: 0
}
is_variable: false
}
operand {
name: "Const"
type: FLOAT32
shape {
}
filler {
tag: "explicit"
arg: "2"
}
quant {
quantized_dimension: 0
}
is_variable: false
}
operand {
name: "Const_1"
type: FLOAT32
shape {
}
filler {
tag: "explicit"
arg: "4"
}
quant {
quantized_dimension: 0
}
is_variable: false
}
operand {
name: "model/tf.split/split/split_dim"
type: INT32
shape {
}
filler {
tag: "explicit"
arg: "1"
}
quant {
quantized_dimension: 0
}
is_variable: false
}
operand {
name: "model/flatten/Const"
type: INT32
shape {
dim: 2
}
filler {
tag: "explicit"
arg: "-1"
arg: "3072"
}
quant {
quantized_dimension: 0
}
is_variable: false
}
operand {
name: "PartitionedCall:3"
type: FLOAT32
shape {
dim: 1
dim: 1
dim: 32
dim: 32
dim: 3
}
quant {
quantized_dimension: 0
}
is_variable: false
}
operand {
name: "model/flatten/Reshape"
type: FLOAT32
shape {
dim: 1
dim: 3072
}
quant {
quantized_dimension: 0
}
is_variable: false
}
operand {
name: "PartitionedCall:0"
type: FLOAT32
shape {
dim: 1
dim: 1024
}
quant {
quantized_dimension: 0
}
is_variable: false
}
operand {
name: "model/tf.split/split"
type: FLOAT32
shape {
dim: 1
dim: 1024
}
quant {
quantized_dimension: 0
}
is_variable: false
}
operand {
name: "model/tf.split/split1"
type: FLOAT32
shape {
dim: 1
dim: 1024
}
quant {
quantized_dimension: 0
}
is_variable: false
}
operand {
name: "model/tf.compat.v1.math.scalar_mul_1/Mul"
type: FLOAT32
shape {
dim: 1
dim: 3072
}
quant {
quantized_dimension: 0
}
is_variable: false
}
operand {
name: "PartitionedCall:2"
type: FLOAT32
shape {
dim: 1
dim: 1
dim: 3072
}
quant {
quantized_dimension: 0
}
is_variable: false
}
operand {
name: "model/tf.compat.v1.math.scalar_mul/Mul"
type: FLOAT32
shape {
dim: 1
dim: 3072
}
quant {
quantized_dimension: 0
}
is_variable: false
}
operand {
name: "PartitionedCall:1"
type: FLOAT32
shape {
dim: 1
dim: 1
dim: 3072
}
quant {
quantized_dimension: 0
}
is_variable: false
}
operation {
type: "Pack"
input: "serving_default_input:0"
output: "PartitionedCall:3"
pack_options {
values_count: 1
axis: 0
}
}
operation {
type: "Reshape"
input: "serving_default_input:0"
input: "model/flatten/Const"
output: "model/flatten/Reshape"
}
operation {
type: "Split"
input: "model/tf.split/split/split_dim"
input: "model/flatten/Reshape"
output: "PartitionedCall:0"
output: "model/tf.split/split"
output: "model/tf.split/split1"
split_options {
num_splits: 3
}
}
operation {
type: "Mul"
input: "model/flatten/Reshape"
input: "Const_1"
output: "model/tf.compat.v1.math.scalar_mul_1/Mul"
mul_options {
activation: NONE
}
}
operation {
type: "Pack"
input: "model/tf.compat.v1.math.scalar_mul_1/Mul"
output: "PartitionedCall:2"
pack_options {
values_count: 1
axis: 0
}
}
operation {
type: "Mul"
input: "model/flatten/Reshape"
input: "Const"
output: "model/tf.compat.v1.math.scalar_mul/Mul"
mul_options {
activation: NONE
}
}
operation {
type: "Pack"
input: "model/tf.compat.v1.math.scalar_mul/Mul"
output: "PartitionedCall:1"
pack_options {
values_count: 1
axis: 0
}
}
input: "serving_default_input:0"
output: "PartitionedCall:2"
output: "PartitionedCall:1"
output: "PartitionedCall:3"
output: "PartitionedCall:0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Verify that the pack operation has been successfully removed
# Check that the reshape operation exists (substitute_pack_to_reshape pass applied)

RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1

RULE "NO_PACK" $(op_count PACK) '=' 0
RULE "RESHAPE_EXIST" $(op_count RESHAPE) '=' 4