Skip to content

Commit

Permalink
mlir: Add casted views.
Browse files Browse the repository at this point in the history
  • Loading branch information
xlauko committed Jul 1, 2024
1 parent a765d86 commit fbb640f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
14 changes: 14 additions & 0 deletions mlir/include/gap/mlir/views.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,24 @@
#endif
#endif

#include <gap/mlir/functors.hpp>

namespace gap::mlir::views
{
auto regions(operation op) -> decltype(std::views::all(op->getRegions()));
auto blocks(operation op) -> decltype(std::views::join(regions(op)));
auto operations(operation op) -> decltype(std::views::join(blocks(op)));

template< typename T >
auto isa = std::views::filter(::gap::mlir::isa< T >);

template< typename T >
auto cast = std::views::transform(::gap::mlir::cast< T >);

template< typename T >
auto dyn_cast = std::views::transform(::gap::mlir::dyn_cast< T >);

template< typename T >
auto filter_cast = isa< T > | cast< T >;

} // namespace gap::mlir::views
12 changes: 12 additions & 0 deletions test/mlir/views.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ namespace gap::test
};

using namespace gap::mlir::views;

TEST_SUITE("views") {

TEST_CASE_FIXTURE(MLIRTestFixture, "operation views") {
Expand All @@ -57,6 +58,17 @@ namespace gap::test
CHECK(std::ranges::distance(operations(mod->getOperation())) == 1);
}

TEST_CASE_FIXTURE(MLIRTestFixture, "cast views") {
auto rets = operations(fn) | isa< ::mlir::func::ReturnOp >;
CHECK(std::ranges::distance(rets) == 3);

auto rets_casted = operations(fn) | filter_cast< ::mlir::func::ReturnOp >;
CHECK(std::ranges::distance(rets_casted) == 3);

auto fns = operations(mod->getOperation()) | isa< ::mlir::func::FuncOp >;
CHECK(std::ranges::distance(fns) == 1);
}

} // TEST_SUITE("views")

} // namespace gap::test

0 comments on commit fbb640f

Please sign in to comment.