diff --git a/mlir/include/gap/mlir/views.hpp b/mlir/include/gap/mlir/views.hpp index 14fbd7b..9dc05b6 100644 --- a/mlir/include/gap/mlir/views.hpp +++ b/mlir/include/gap/mlir/views.hpp @@ -19,10 +19,24 @@ #endif #endif +#include + namespace gap::mlir::views { auto regions(operation op) -> decltype(std::views::all(op->getRegions())); auto blocks(operation op) -> decltype(std::views::join(regions(op))); auto operations(operation op) -> decltype(std::views::join(blocks(op))); + template< typename T > + auto isa = std::views::filter(::gap::mlir::isa< T >); + + template< typename T > + auto cast = std::views::transform(::gap::mlir::cast< T >); + + template< typename T > + auto dyn_cast = std::views::transform(::gap::mlir::dyn_cast< T >); + + template< typename T > + auto filter_cast = isa< T > | cast< T >; + } // namespace gap::mlir::views \ No newline at end of file diff --git a/test/mlir/views.cpp b/test/mlir/views.cpp index 69609b7..43870b6 100644 --- a/test/mlir/views.cpp +++ b/test/mlir/views.cpp @@ -45,6 +45,7 @@ namespace gap::test }; using namespace gap::mlir::views; + TEST_SUITE("views") { TEST_CASE_FIXTURE(MLIRTestFixture, "operation views") { @@ -57,6 +58,17 @@ namespace gap::test CHECK(std::ranges::distance(operations(mod->getOperation())) == 1); } + TEST_CASE_FIXTURE(MLIRTestFixture, "cast views") { + auto rets = operations(fn) | isa< ::mlir::func::ReturnOp >; + CHECK(std::ranges::distance(rets) == 3); + + auto rets_casted = operations(fn) | filter_cast< ::mlir::func::ReturnOp >; + CHECK(std::ranges::distance(rets_casted) == 3); + + auto fns = operations(mod->getOperation()) | isa< ::mlir::func::FuncOp >; + CHECK(std::ranges::distance(fns) == 1); + } + } // TEST_SUITE("views") } // namespace gap::test