Skip to content

Commit

Permalink
core: add ability to parse MLIR IR that drops dialect names from oper…
Browse files Browse the repository at this point in the history
…ations (#1840)

Only parsing, no printing. This unlocks us being able to parse MLIR IR
directly output from the `mlir-opt` command, without generic syntax or
"local scope". I needed this to unlock our experiments for the linalg to
frep project, since we use MLIR as the input, and our version of MLIR
had different names for some properties, which messed up the generic
syntax. I think that the main xdsl project would benefit from these
changes also.

In my understanding, this is more permissive than MLIR itself, that has
specific rules for which ops can drop names and which ones cannot. I'm
not sure whether we want to be as strict as them on this.
  • Loading branch information
superlopuh authored Dec 6, 2023
1 parent add5935 commit bc1b0b0
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 1 deletion.
13 changes: 13 additions & 0 deletions tests/filecheck/mlir-conversion/with-mlir/scope.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: mlir-opt %s | xdsl-opt | filecheck %s

module {
func.func public @my_func() {
return
}
}

// CHECK: builtin.module {
// CHECK-NEXT: func.func public @my_func() {
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: }
20 changes: 20 additions & 0 deletions tests/filecheck/parser-printer/scope.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// RUN: XDSL_ROUNDTRIP
// RUN: XDSL_GENERIC_ROUNDTRIP

module {
func.func public @my_func() {
return
}
}

// CHECK: builtin.module {
// CHECK-NEXT: func.func public @my_func() {
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK-GENERIC: "builtin.module"() ({
// CHECK-GENERIC-NEXT: "func.func"() <{"sym_name" = "my_func", "function_type" = () -> (), "sym_visibility" = "public"}> ({
// CHECK-GENERIC-NEXT: "func.return"() : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
7 changes: 7 additions & 0 deletions tests/test_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,3 +751,10 @@ def test_get_attr_or_prop():
assert a.get_attr_or_prop("prop") == StringAttr("prop")
assert a.get_attr_or_prop("attr_and_prop") == StringAttr("prop")
assert a.get_attr_or_prop("none") is None


def test_dialect_name():
class MyOperation(Operation):
name = "dialect.op"

assert MyOperation.dialect_name() == "dialect"
4 changes: 4 additions & 0 deletions xdsl/ir/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,10 @@ def emit_error(
diagnostic.add_message(self, message)
diagnostic.raise_exception(message, self, exception_type, underlying_error)

@classmethod
def dialect_name(cls) -> str:
return cls.name.split(".")[0]

def __eq__(self, other: object) -> bool:
return self is other

Expand Down
6 changes: 5 additions & 1 deletion xdsl/parser/base_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@ class ParserState:

lexer: Lexer
current_token: Token
dialect_stack: list[str]

def __init__(self, lexer: Lexer):
def __init__(self, lexer: Lexer, dialect_stack: list[str] | None = None):
if dialect_stack is None:
dialect_stack = ["builtin"]
self.lexer = lexer
self.current_token = lexer.lex()
self.dialect_stack = dialect_stack


_AnyInvT = TypeVar("_AnyInvT")
Expand Down
11 changes: 11 additions & 0 deletions xdsl/parser/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,14 +632,20 @@ def parse_operation(self) -> Operation:
if (op_name := self._parse_optional_token(Token.Kind.BARE_IDENT)) is not None:
# Custom operation format
op_type = self._get_op_by_name(op_name.text)
dialect_name = op_type.dialect_name()
self._parser_state.dialect_stack.append(dialect_name)
op = op_type.parse(self)
self._parser_state.dialect_stack.pop()
else:
# Generic operation format
op_name = self.expect(
self.parse_optional_str_literal, "operation name expected"
)
op_type = self._get_op_by_name(op_name)
dialect_name = op_type.dialect_name()
self._parser_state.dialect_stack.append(dialect_name)
op = self._parse_generic_operation(op_type)
self._parser_state.dialect_stack.pop()

n_bound_results = sum(r[1] for r in bound_results)
if (n_bound_results != 0) and (len(op.results) != n_bound_results):
Expand Down Expand Up @@ -670,6 +676,11 @@ def _get_op_by_name(self, name: str) -> type[Operation]:
if op_type is not None:
return op_type

for dialect_name in reversed(self._parser_state.dialect_stack):
op_type = self.ctx.get_optional_op(f"{dialect_name}.{name}")
if op_type is not None:
return op_type

self.raise_error(f"unregistered operation {name}!")

def _parse_op_result(self) -> tuple[Span, int]:
Expand Down

0 comments on commit bc1b0b0

Please sign in to comment.