diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInference.h b/compiler/luci/service/include/luci/Service/CircleShapeInference.h index ee069d882e2..b9be778a908 100644 --- a/compiler/luci/service/include/luci/Service/CircleShapeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleShapeInference.h @@ -112,7 +112,7 @@ class Algorithm final : public luci::CircleNodeVisitor // loco::TensorShape visit(const luci::CirclePow *node) final; // loco::TensorShape visit(const luci::CirclePRelu *node) final; loco::TensorShape visit(const luci::CircleQuantize *node) final; - // loco::TensorShape visit(const luci::CircleRange *node) final; + loco::TensorShape visit(const luci::CircleRange *node) final; // loco::TensorShape visit(const luci::CircleRank *node) final; // loco::TensorShape visit(const luci::CircleReduceAny *node) final; // loco::TensorShape visit(const luci::CircleReduceMax *node) final; diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp index 6ae15f70494..4fa8405a5bc 100644 --- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp @@ -911,49 +911,6 @@ loco::NodeShape infer_p_relu(const luci::CirclePRelu *node) return loco::NodeShape{output_shape}; } -loco::NodeShape infer_range(const luci::CircleRange *node) -{ - loco::TensorShape output_shape; - output_shape.rank(1); - - auto start_node = dynamic_cast(node->start()); - auto limit_node = dynamic_cast(node->limit()); - auto delta_node = dynamic_cast(node->delta()); - - if (start_node == nullptr || limit_node == nullptr || delta_node == nullptr) - { - return use_own(node); - } - - double start = 0, limit = 0, delta = 0; - -#define GET_RANGE_PARAM(DT) \ - start = start_node->scalar
(); \ - limit = limit_node->scalar
(); \ - delta = delta_node->scalar
(); - - switch (start_node->dtype()) - { - case loco::DataType::FLOAT32: - GET_RANGE_PARAM(loco::DataType::FLOAT32) - break; - case loco::DataType::S32: - GET_RANGE_PARAM(loco::DataType::S32) - break; - default: - INTERNAL_EXN("Range data type not supported"); - } - -#undef GET_RANGE_PARAM - - if (delta == 0) - INTERNAL_EXN("Delta can not be zero"); - - output_shape.dim(0) = ceil((limit - start) / delta); - - return loco::NodeShape{output_shape}; -} - loco::NodeShape infer_reshape(const luci::CircleReshape *node) { LOGGER(l); @@ -2104,8 +2061,6 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor namespace luci { @@ -24,4 +29,66 @@ luci::CircleNode *CloneNodeLet::visit(const luci::CircleRange *) return _graph->nodes()->create(); } +namespace sinf +{ + +loco::TensorShape Algorithm::visit(const luci::CircleRange *node) +{ + loco::TensorShape output_shape; + output_shape.rank(1); + + auto start_node = dynamic_cast(node->start()); + auto limit_node = dynamic_cast(node->limit()); + auto delta_node = dynamic_cast(node->delta()); + + if (start_node == nullptr || limit_node == nullptr || delta_node == nullptr) + { + // We use shape from the node itself + loco::TensorShape shape; + shape.rank(node->rank()); + for (uint32_t r = 0; r < node->rank(); ++r) + { + // TODO remove this copy from `use_own(node);` + // Shape inference rules in this file did not consider unknown dimension. + // If some node has unknown dimension, 0 is inserted and wrong shape + // inference was done as a result. + // To fix this, new shape inference algorithm is being implemented. + // Until new inference algorithm is fully implemented, unknown dimension + // would be represented as 1 along with TFLite expression. + shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1; + } + return shape; + } + + double start = 0, limit = 0, delta = 0; + +#define GET_RANGE_PARAM(DT) \ + start = start_node->scalar
(); \ + limit = limit_node->scalar
(); \ + delta = delta_node->scalar
(); + + switch (start_node->dtype()) + { + case loco::DataType::FLOAT32: + GET_RANGE_PARAM(loco::DataType::FLOAT32) + break; + case loco::DataType::S32: + GET_RANGE_PARAM(loco::DataType::S32) + break; + default: + INTERNAL_EXN("Range data type not supported"); + } + +#undef GET_RANGE_PARAM + + if (delta == 0) + INTERNAL_EXN("Delta can not be zero"); + + output_shape.dim(0) = ceil((limit - start) / delta); + + return output_shape; +} + +} // namespace sinf + } // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleRange.test.cpp b/compiler/luci/service/src/Nodes/CircleRange.test.cpp index b2fb296177a..b67d287d1ab 100644 --- a/compiler/luci/service/src/Nodes/CircleRange.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleRange.test.cpp @@ -15,6 +15,7 @@ */ #include "luci/Service/CircleNodeClone.h" +#include "luci/Service/CircleShapeInference.h" #include @@ -31,3 +32,66 @@ TEST(CloneNodeTest, clone_Range) auto cloned_range = dynamic_cast(cloned); ASSERT_NE(nullptr, cloned_range); } + +TEST(ShapeRuleTest, range_const_param) +{ + luci::CircleConst start, limit, delta; + luci::CircleRange range; + + start.dtype(loco::DataType::S32); + start.size(1); + start.at(0) = 0; + start.shape_status(luci::ShapeStatus::VALID); + + limit.dtype(loco::DataType::S32); + limit.size(1); + limit.at(0) = 10; + limit.shape_status(luci::ShapeStatus::VALID); + + delta.dtype(loco::DataType::S32); + delta.size(1); + delta.at(0) = 2; + delta.shape_status(luci::ShapeStatus::VALID); + + range.start(&start); + range.limit(&limit); + range.delta(&delta); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&range, shape)); + ASSERT_EQ(1, shape.rank()); + ASSERT_TRUE(shape.dim(0).known()); + ASSERT_EQ(5, shape.dim(0).value()); +} + +TEST(ShapeRuleTest, range_zero_delta_NEG) +{ + luci::CircleConst start, limit, delta; + luci::CircleRange range; + + start.dtype(loco::DataType::S32); + start.size(1); + start.at(0) = 0; + start.shape_status(luci::ShapeStatus::VALID); + + limit.dtype(loco::DataType::S32); + limit.size(1); + limit.at(0) = 10; + limit.shape_status(luci::ShapeStatus::VALID); + + delta.dtype(loco::DataType::S32); + delta.size(1); + delta.at(0) = 0; + delta.shape_status(luci::ShapeStatus::VALID); + + range.start(&start); + range.limit(&limit); + range.delta(&delta); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_ANY_THROW(shape_inf_rule.infer(&range, shape)); +}