diff --git a/bench/parser/bench_lexer.py b/bench/parser/bench_lexer.py index ffa1c258c8..16f6cd36c9 100644 --- a/bench/parser/bench_lexer.py +++ b/bench/parser/bench_lexer.py @@ -11,15 +11,15 @@ from collections.abc import Iterable from xdsl.utils.lexer import Input -from xdsl.utils.mlir_lexer import Kind, Lexer +from xdsl.utils.mlir_lexer import MLIRLexer, MLIRTokenKind def lex_file(file: Input): """ Lex the given file """ - lexer = Lexer(file) - while lexer.lex().kind is not Kind.EOF: + lexer = MLIRLexer(file) + while lexer.lex().kind is not MLIRTokenKind.EOF: pass diff --git a/tests/test_lexer.py b/tests/test_lexer.py index 716cfa23a5..5f728ae5cd 100644 --- a/tests/test_lexer.py +++ b/tests/test_lexer.py @@ -2,18 +2,18 @@ from xdsl.utils.exceptions import ParseError from xdsl.utils.lexer import Input -from xdsl.utils.mlir_lexer import Kind, Lexer, Token +from xdsl.utils.mlir_lexer import MLIRLexer, MLIRToken, MLIRTokenKind -def get_token(input: str) -> Token: +def get_token(input: str) -> MLIRToken: file = Input(input, "") - lexer = Lexer(file) + lexer = MLIRLexer(file) token = lexer.lex() return token def assert_single_token( - input: str, expected_kind: Kind, expected_text: str | None = None + input: str, expected_kind: MLIRTokenKind, expected_text: str | None = None ): if expected_text is None: expected_text = input @@ -26,7 +26,7 @@ def assert_single_token( def assert_token_fail(input: str): file = Input(input, "") - lexer = Lexer(file) + lexer = MLIRLexer(file) with pytest.raises(ParseError): lexer.lex() @@ -34,29 +34,29 @@ def assert_token_fail(input: str): @pytest.mark.parametrize( "text,kind", [ - ("->", Kind.ARROW), - (":", Kind.COLON), - (",", Kind.COMMA), - ("...", Kind.ELLIPSIS), - ("=", Kind.EQUAL), - (">", Kind.GREATER), - ("{", Kind.L_BRACE), - ("(", Kind.L_PAREN), - ("[", Kind.L_SQUARE), - ("<", Kind.LESS), - ("-", Kind.MINUS), - ("+", Kind.PLUS), - ("?", Kind.QUESTION), - ("}", Kind.R_BRACE), - (")", Kind.R_PAREN), - ("]", Kind.R_SQUARE), - ("*", Kind.STAR), - ("|", Kind.VERTICAL_BAR), - ("{-#", Kind.FILE_METADATA_BEGIN), - ("#-}", Kind.FILE_METADATA_END), + ("->", MLIRTokenKind.ARROW), + (":", MLIRTokenKind.COLON), + (",", MLIRTokenKind.COMMA), + ("...", MLIRTokenKind.ELLIPSIS), + ("=", MLIRTokenKind.EQUAL), + (">", MLIRTokenKind.GREATER), + ("{", MLIRTokenKind.L_BRACE), + ("(", MLIRTokenKind.L_PAREN), + ("[", MLIRTokenKind.L_SQUARE), + ("<", MLIRTokenKind.LESS), + ("-", MLIRTokenKind.MINUS), + ("+", MLIRTokenKind.PLUS), + ("?", MLIRTokenKind.QUESTION), + ("}", MLIRTokenKind.R_BRACE), + (")", MLIRTokenKind.R_PAREN), + ("]", MLIRTokenKind.R_SQUARE), + ("*", MLIRTokenKind.STAR), + ("|", MLIRTokenKind.VERTICAL_BAR), + ("{-#", MLIRTokenKind.FILE_METADATA_BEGIN), + ("#-}", MLIRTokenKind.FILE_METADATA_END), ], ) -def test_punctuation(text: str, kind: Kind): +def test_punctuation(text: str, kind: MLIRTokenKind): assert_single_token(text, kind) @@ -69,7 +69,7 @@ def test_punctuation_fail(text: str): "text", ['""', '"@"', '"foo"', '"\\""', '"\\n"', '"\\\\"', '"\\t"'] ) def test_str_literal(text: str): - assert_single_token(text, Kind.STRING_LIT) + assert_single_token(text, MLIRTokenKind.STRING_LIT) @pytest.mark.parametrize("text", ['"', '"\\"', '"\\a"', '"\n"', '"\v"', '"\f"']) @@ -82,7 +82,7 @@ def test_str_literal_fail(text: str): ) def test_bare_ident(text: str): """bare-id ::= (letter|[_]) (letter|digit|[_$.])*""" - assert_single_token(text, Kind.BARE_IDENT) + assert_single_token(text, MLIRTokenKind.BARE_IDENT) @pytest.mark.parametrize( @@ -109,7 +109,7 @@ def 
test_bare_ident(text: str): ) def test_at_ident(text: str): """at-ident ::= `@` (bare-id | string-literal)""" - assert_single_token(text, Kind.AT_IDENT) + assert_single_token(text, MLIRTokenKind.AT_IDENT) @pytest.mark.parametrize( @@ -129,10 +129,10 @@ def test_prefixed_ident(text: str): """percent-ident ::= `%` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*)""" """caret-ident ::= `^` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*)""" """exclamation-ident ::= `!` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*)""" - assert_single_token("#" + text, Kind.HASH_IDENT) - assert_single_token("%" + text, Kind.PERCENT_IDENT) - assert_single_token("^" + text, Kind.CARET_IDENT) - assert_single_token("!" + text, Kind.EXCLAMATION_IDENT) + assert_single_token("#" + text, MLIRTokenKind.HASH_IDENT) + assert_single_token("%" + text, MLIRTokenKind.PERCENT_IDENT) + assert_single_token("^" + text, MLIRTokenKind.CARET_IDENT) + assert_single_token("!" + text, MLIRTokenKind.EXCLAMATION_IDENT) @pytest.mark.parametrize("text", ["+", '""', "#", "%", "^", "!", "\n", ""]) @@ -155,46 +155,46 @@ def test_prefixed_ident_fail(text: str): ) def test_prefixed_ident_split(text: str, expected: str): """Check that the prefixed identifier is split at the right character.""" - assert_single_token("#" + text, Kind.HASH_IDENT, "#" + expected) - assert_single_token("%" + text, Kind.PERCENT_IDENT, "%" + expected) - assert_single_token("^" + text, Kind.CARET_IDENT, "^" + expected) - assert_single_token("!" + text, Kind.EXCLAMATION_IDENT, "!" + expected) + assert_single_token("#" + text, MLIRTokenKind.HASH_IDENT, "#" + expected) + assert_single_token("%" + text, MLIRTokenKind.PERCENT_IDENT, "%" + expected) + assert_single_token("^" + text, MLIRTokenKind.CARET_IDENT, "^" + expected) + assert_single_token("!" + text, MLIRTokenKind.EXCLAMATION_IDENT, "!" 
+ expected) @pytest.mark.parametrize("text", ["0", "01", "123456789", "99", "0x1234", "0xabcdef"]) def test_integer_literal(text: str): - assert_single_token(text, Kind.INTEGER_LIT) + assert_single_token(text, MLIRTokenKind.INTEGER_LIT) @pytest.mark.parametrize( "text,expected", [("0a", "0"), ("0xg", "0"), ("0xfg", "0xf"), ("0xf.", "0xf")] ) def test_integer_literal_split(text: str, expected: str): - assert_single_token(text, Kind.INTEGER_LIT, expected) + assert_single_token(text, MLIRTokenKind.INTEGER_LIT, expected) @pytest.mark.parametrize( "text", ["0.", "1.", "0.2", "38.1243", "92.54e43", "92.5E43", "43.3e-54", "32.E+25"] ) def test_float_literal(text: str): - assert_single_token(text, Kind.FLOAT_LIT) + assert_single_token(text, MLIRTokenKind.FLOAT_LIT) @pytest.mark.parametrize( "text,expected", [("3.9e", "3.9"), ("4.5e+", "4.5"), ("5.8e-", "5.8")] ) def test_float_literal_split(text: str, expected: str): - assert_single_token(text, Kind.FLOAT_LIT, expected) + assert_single_token(text, MLIRTokenKind.FLOAT_LIT, expected) @pytest.mark.parametrize("text", ["0", " 0", " 0", "\n0", "\t0", "// Comment\n0"]) def test_whitespace_skip(text: str): - assert_single_token(text, Kind.INTEGER_LIT, "0") + assert_single_token(text, MLIRTokenKind.INTEGER_LIT, "0") @pytest.mark.parametrize("text", ["", " ", "\n\n", "// Comment\n"]) def test_eof(text: str): - assert_single_token(text, Kind.EOF, "") + assert_single_token(text, MLIRTokenKind.EOF, "") @pytest.mark.parametrize( @@ -209,7 +209,7 @@ def test_eof(text: str): ) def test_token_get_int_value(text: str, expected: int): token = get_token(text) - assert token.kind == Kind.INTEGER_LIT + assert token.kind == MLIRTokenKind.INTEGER_LIT assert token.kind.get_int_value(token.span) == expected @@ -228,7 +228,7 @@ def test_token_get_int_value(text: str, expected: int): ) def test_token_get_float_value(text: str, expected: float): token = get_token(text) - assert token.kind == Kind.FLOAT_LIT + assert token.kind == MLIRTokenKind.FLOAT_LIT assert token.kind.get_float_value(token.span) == expected @@ -246,5 +246,5 @@ def test_token_get_float_value(text: str, expected: float): ) def test_token_get_string_literal_value(text: str, expected: float): token = get_token(text) - assert token.kind == Kind.STRING_LIT + assert token.kind == MLIRTokenKind.STRING_LIT assert token.kind.get_string_literal_value(token.span) == expected diff --git a/tests/test_parser.py b/tests/test_parser.py index 13bbab6247..6f6933ab3f 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -30,7 +30,7 @@ from xdsl.parser import Parser from xdsl.printer import Printer from xdsl.utils.exceptions import ParseError, VerifyException -from xdsl.utils.mlir_lexer import Kind, PunctuationSpelling +from xdsl.utils.mlir_lexer import MLIRTokenKind, PunctuationSpelling from xdsl.utils.str_enum import StrEnum # pyright: reportPrivateUsage=false @@ -584,51 +584,54 @@ def test_parse_comma_separated_list_error_delimiters( @pytest.mark.parametrize( - "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().values()) + "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().values()) ) -def test_is_punctuation_true(punctuation: Kind): +def test_is_punctuation_true(punctuation: MLIRTokenKind): assert punctuation.is_punctuation() -@pytest.mark.parametrize("punctuation", [Kind.BARE_IDENT, Kind.EOF, Kind.INTEGER_LIT]) -def test_is_punctuation_false(punctuation: Kind): +@pytest.mark.parametrize( + "punctuation", + [MLIRTokenKind.BARE_IDENT, MLIRTokenKind.EOF, 
MLIRTokenKind.INTEGER_LIT], +) +def test_is_punctuation_false(punctuation: MLIRTokenKind): assert not punctuation.is_punctuation() @pytest.mark.parametrize( - "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().values()) + "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().values()) ) -def test_is_spelling_of_punctuation_true(punctuation: Kind): +def test_is_spelling_of_punctuation_true(punctuation: MLIRTokenKind): value = cast(PunctuationSpelling, punctuation.value) - assert Kind.is_spelling_of_punctuation(value) + assert MLIRTokenKind.is_spelling_of_punctuation(value) @pytest.mark.parametrize("punctuation", [">-", "o", "4", "$", "_", "@"]) def test_is_spelling_of_punctuation_false(punctuation: str): - assert not Kind.is_spelling_of_punctuation(punctuation) + assert not MLIRTokenKind.is_spelling_of_punctuation(punctuation) @pytest.mark.parametrize( - "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().values()) + "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().values()) ) -def test_get_punctuation_kind(punctuation: Kind): +def test_get_punctuation_kind(punctuation: MLIRTokenKind): value = cast(PunctuationSpelling, punctuation.value) assert punctuation.get_punctuation_kind_from_spelling(value) == punctuation @pytest.mark.parametrize( - "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().keys()) + "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().keys()) ) def test_parse_punctuation(punctuation: PunctuationSpelling): parser = Parser(MLContext(), punctuation) res = parser.parse_punctuation(punctuation) assert res == punctuation - assert parser._parse_token(Kind.EOF, "").kind == Kind.EOF + assert parser._parse_token(MLIRTokenKind.EOF, "").kind == MLIRTokenKind.EOF @pytest.mark.parametrize( - "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().keys()) + "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().keys()) ) def test_parse_punctuation_fail(punctuation: PunctuationSpelling): parser = Parser(MLContext(), "e +") @@ -639,17 +642,17 @@ def test_parse_punctuation_fail(punctuation: PunctuationSpelling): @pytest.mark.parametrize( - "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().keys()) + "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().keys()) ) def test_parse_optional_punctuation(punctuation: PunctuationSpelling): parser = Parser(MLContext(), punctuation) res = parser.parse_optional_punctuation(punctuation) assert res == punctuation - assert parser._parse_token(Kind.EOF, "").kind == Kind.EOF + assert parser._parse_token(MLIRTokenKind.EOF, "").kind == MLIRTokenKind.EOF @pytest.mark.parametrize( - "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().keys()) + "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().keys()) ) def test_parse_optional_punctuation_fail(punctuation: PunctuationSpelling): parser = Parser(MLContext(), "e +") diff --git a/xdsl/ir/core.py b/xdsl/ir/core.py index 19c2007a10..6ad8c61c38 100644 --- a/xdsl/ir/core.py +++ b/xdsl/ir/core.py @@ -32,7 +32,7 @@ from xdsl.traits import IsTerminator, NoTerminator, OpTrait, OpTraitInvT from xdsl.utils.exceptions import VerifyException -from xdsl.utils.mlir_lexer import Lexer +from xdsl.utils.mlir_lexer import MLIRLexer from xdsl.utils.str_enum import StrEnum # Used for cyclic dependencies in type hints @@ -404,7 +404,7 @@ def _check_enum_constraints( raise TypeError("Only direct inheritance from EnumAttribute is 
allowed.") for v in enum_type: - if Lexer.bare_identifier_suffix_regex.fullmatch(v) is None: + if MLIRLexer.bare_identifier_suffix_regex.fullmatch(v) is None: raise ValueError( "All StrEnum values of an EnumAttribute must be parsable as an identifer." ) diff --git a/xdsl/irdl/declarative_assembly_format_parser.py b/xdsl/irdl/declarative_assembly_format_parser.py index 105070c058..d6161c10a3 100644 --- a/xdsl/irdl/declarative_assembly_format_parser.py +++ b/xdsl/irdl/declarative_assembly_format_parser.py @@ -80,11 +80,11 @@ ) from xdsl.parser import BaseParser, ParserState from xdsl.utils.lexer import Input -from xdsl.utils.mlir_lexer import Kind, Lexer, Token +from xdsl.utils.mlir_lexer import MLIRLexer, MLIRToken, MLIRTokenKind @dataclass -class FormatLexer(Lexer): +class FormatLexer(MLIRLexer): """ A lexer for the declarative assembly format. The differences with the MLIR lexer are the following: @@ -92,7 +92,7 @@ class FormatLexer(Lexer): * Bare identifiers may also may contain `-`. """ - def lex(self) -> Token: + def lex(self) -> MLIRToken: """Lex a token from the input, and returns it.""" # First, skip whitespaces self._consume_whitespace() @@ -102,13 +102,13 @@ def lex(self) -> Token: # Handle end of file if current_char is None: - return self._form_token(Kind.EOF, start_pos) + return self._form_token(MLIRTokenKind.EOF, start_pos) # We parse '`', `\\` and '$' as a BARE_IDENT. # This is a hack to reuse the MLIR lexer. if current_char in ("`", "$", "\\", "^"): self._consume_chars() - return self._form_token(Kind.BARE_IDENT, start_pos) + return self._form_token(MLIRTokenKind.BARE_IDENT, start_pos) return super().lex() # Authorize `-` in bare identifier @@ -168,7 +168,7 @@ def parse_format(self) -> FormatProgram: unambiguous and refer to all elements exactly once. 
""" elements: list[FormatDirective] = [] - while self._current_token.kind != Kind.EOF: + while self._current_token.kind != MLIRTokenKind.EOF: elements.append(self.parse_format_directive()) self.add_reserved_attrs_to_directive(elements) @@ -717,7 +717,7 @@ def parse_keyword_or_punctuation(self) -> FormatDirective: if self._current_token.kind.is_punctuation(): punctuation = self._consume_token().text self.parse_characters("`") - assert Kind.is_spelling_of_punctuation(punctuation) + assert MLIRTokenKind.is_spelling_of_punctuation(punctuation) return PunctuationDirective(punctuation) # Identifier case diff --git a/xdsl/parser/affine_parser.py b/xdsl/parser/affine_parser.py index edf4b506a6..3c630aef64 100644 --- a/xdsl/parser/affine_parser.py +++ b/xdsl/parser/affine_parser.py @@ -9,7 +9,7 @@ ) from xdsl.parser.base_parser import BaseParser, ParserState from xdsl.utils.exceptions import ParseError -from xdsl.utils.mlir_lexer import Kind, Token +from xdsl.utils.mlir_lexer import MLIRToken, MLIRTokenKind class AffineParser(BaseParser): @@ -33,12 +33,12 @@ def _parse_primary(self, dims: list[str], syms: list[str]) -> AffineExpr: | `-` primary """ # Handle parentheses - if self._parse_optional_token(Kind.L_PAREN): + if self._parse_optional_token(MLIRTokenKind.L_PAREN): expr = self._parse_affine_expr(dims, syms) - self._parse_token(Kind.R_PAREN, "Expected closing parenthesis") + self._parse_token(MLIRTokenKind.R_PAREN, "Expected closing parenthesis") return expr # Handle bare id - if bare_id := self._parse_optional_token(Kind.BARE_IDENT): + if bare_id := self._parse_optional_token(MLIRTokenKind.BARE_IDENT): if bare_id.text in dims: return AffineExpr.dimension(dims.index(bare_id.text)) elif bare_id.text in syms: @@ -48,10 +48,10 @@ def _parse_primary(self, dims: list[str], syms: list[str]) -> AffineExpr: bare_id.span, f"Identifier not in space {bare_id.text}" ) # Handle integer literal - if int_lit := self._parse_optional_token(Kind.INTEGER_LIT): + if int_lit := self._parse_optional_token(MLIRTokenKind.INTEGER_LIT): return AffineExpr.constant(int_lit.kind.get_int_value(int_lit.span)) # Handle negative primary - if self._parse_optional_token(Kind.MINUS): + if self._parse_optional_token(MLIRTokenKind.MINUS): return -self._parse_primary(dims, syms) raise ParseError(self._current_token.span, "Expected primary expression") @@ -60,7 +60,7 @@ def _get_token_precedence(self) -> int: return self._BINOP_PRECEDENCE.get(self._current_token.text, -1) def _create_binop_expr( - self, lhs: AffineExpr, rhs: AffineExpr, binop: Token + self, lhs: AffineExpr, rhs: AffineExpr, binop: MLIRToken ) -> AffineExpr: match binop.text: case "+": @@ -175,12 +175,14 @@ def _parse_affine_space(self) -> tuple[list[str], list[str]]: """ def parse_id() -> str: - return self._parse_token(Kind.BARE_IDENT, "Expected identifier").text + return self._parse_token( + MLIRTokenKind.BARE_IDENT, "Expected identifier" + ).text # Parse dimensions dims = self.parse_comma_separated_list(self.Delimiter.PAREN, parse_id) # Parse optional symbols - if self._current_token.kind != Kind.L_SQUARE: + if self._current_token.kind != MLIRTokenKind.L_SQUARE: syms = [] else: syms = self.parse_comma_separated_list(self.Delimiter.SQUARE, parse_id) @@ -195,7 +197,7 @@ def parse_affine_map(self) -> AffineMap: # Parse affine space dims, syms = self._parse_affine_space() # Parse : delimiter - self._parse_token(Kind.ARROW, "Expected `->`") + self._parse_token(MLIRTokenKind.ARROW, "Expected `->`") # Parse list of affine expressions exprs = 
self._parse_multi_affine_expr(dims, syms) # Create map and return. @@ -209,7 +211,7 @@ def parse_affine_set(self) -> AffineSet: # Parse affine space dims, syms = self._parse_affine_space() # Parse : delimiter - self._parse_token(Kind.COLON, "Expected `:`") + self._parse_token(MLIRTokenKind.COLON, "Expected `:`") # Parse list of affine expressions constraints = self._parse_multi_affine_constaint(dims, syms) # Create map and return. diff --git a/xdsl/parser/attribute_parser.py b/xdsl/parser/attribute_parser.py index 98a390a1b5..2fa5c35b48 100644 --- a/xdsl/parser/attribute_parser.py +++ b/xdsl/parser/attribute_parser.py @@ -71,7 +71,7 @@ from xdsl.utils.exceptions import ParseError, VerifyException from xdsl.utils.isattr import isattr from xdsl.utils.lexer import Position, Span -from xdsl.utils.mlir_lexer import Kind, StringLiteral +from xdsl.utils.mlir_lexer import MLIRTokenKind, StringLiteral @dataclass @@ -113,7 +113,9 @@ def parse_optional_type(self) -> Attribute | None: | `{` dialect-attribute-contents+ `}` | [^[]<>(){}\0]+ """ - if (token := self._parse_optional_token(Kind.EXCLAMATION_IDENT)) is not None: + if ( + token := self._parse_optional_token(MLIRTokenKind.EXCLAMATION_IDENT) + ) is not None: return self._parse_extended_type_or_attribute(token.text[1:], True) return self._parse_optional_builtin_type() @@ -149,7 +151,7 @@ def parse_optional_attribute(self) -> Attribute | None: | `{` dialect-attribute-contents+ `}` | [^[]<>(){}\0]+ """ - if (token := self._parse_optional_token(Kind.HASH_IDENT)) is not None: + if (token := self._parse_optional_token(MLIRTokenKind.HASH_IDENT)) is not None: return self._parse_extended_type_or_attribute(token.text[1:], False) return self._parse_optional_builtin_attr() @@ -177,7 +179,7 @@ def _parse_attribute_entry(self) -> tuple[str, Attribute]: attribute_entry := (bare-id | string-literal) `=` attribute attribute := dialect-attribute | builtin-attribute """ - if (name := self._parse_optional_token(Kind.BARE_IDENT)) is not None: + if (name := self._parse_optional_token(MLIRTokenKind.BARE_IDENT)) is not None: name = name.span.text else: name = self.parse_optional_str_literal() @@ -222,7 +224,9 @@ def _parse_dialect_type_or_attribute_body( self.parse_punctuation("<") attr_name += ( "." - + self._parse_token(Kind.BARE_IDENT, "Expected attribute name.").text + + self._parse_token( + MLIRTokenKind.BARE_IDENT, "Expected attribute name." + ).text ) attr_def = self.ctx.get_optional_attr( attr_name, @@ -285,7 +289,7 @@ def _parse_extended_type_or_attribute( # dialect parser. if not is_pretty_name: attr_name_token = self._parse_token( - Kind.BARE_IDENT, "Expected attribute name." + MLIRTokenKind.BARE_IDENT, "Expected attribute name." 
) starting_opaque_pos = attr_name_token.span.end @@ -311,18 +315,18 @@ def _parse_unregistered_attr_body(self, start_pos: Position | None) -> str: start_pos = self.pos end_pos: Position = start_pos - symbols_stack: list[Kind] = [] + symbols_stack: list[MLIRTokenKind] = [] parentheses = { - Kind.GREATER: Kind.LESS, - Kind.R_PAREN: Kind.L_PAREN, - Kind.R_SQUARE: Kind.L_SQUARE, - Kind.R_BRACE: Kind.L_BRACE, + MLIRTokenKind.GREATER: MLIRTokenKind.LESS, + MLIRTokenKind.R_PAREN: MLIRTokenKind.L_PAREN, + MLIRTokenKind.R_SQUARE: MLIRTokenKind.L_SQUARE, + MLIRTokenKind.R_BRACE: MLIRTokenKind.L_BRACE, } parentheses_names = { - Kind.GREATER: "`>`", - Kind.R_PAREN: "`)`", - Kind.R_SQUARE: "`]`", - Kind.R_BRACE: "`}`", + MLIRTokenKind.GREATER: "`>`", + MLIRTokenKind.R_PAREN: "`)`", + MLIRTokenKind.R_SQUARE: "`]`", + MLIRTokenKind.R_BRACE: "`}`", } while True: # Opening a new parenthesis @@ -339,7 +343,7 @@ def _parse_unregistered_attr_body(self, start_pos: Position | None) -> str: # If we don't have any open parenthesis, either we end the parsing if # the parenthesis is a `>`, or we raise an error. if len(symbols_stack) == 0: - if token.kind == Kind.GREATER: + if token.kind == MLIRTokenKind.GREATER: end_pos = self.pos break self.raise_error( @@ -361,7 +365,7 @@ def _parse_unregistered_attr_body(self, start_pos: Position | None) -> str: continue # Checking for unexpected EOF - if self._parse_optional_token(Kind.EOF) is not None: + if self._parse_optional_token(MLIRTokenKind.EOF) is not None: self.raise_error( "Unexpected end of file before closing of attribute body!" ) @@ -380,7 +384,7 @@ def _parse_optional_builtin_parametrized_type(self) -> ParametrizedAttribute | N builtin-name ::= vector | memref | tensor | complex | tuple args ::= """ - if self._current_token.kind != Kind.BARE_IDENT: + if self._current_token.kind != MLIRTokenKind.BARE_IDENT: return None name = self._current_token.text @@ -401,7 +405,7 @@ def unimplemented() -> NoReturn: if name not in builtin_parsers: return None - self._consume_token(Kind.BARE_IDENT) + self._consume_token(MLIRTokenKind.BARE_IDENT) self.parse_punctuation("<", " after builtin name") # Get the parser for the type, falling back to the unimplemented warning @@ -418,8 +422,8 @@ def parse_shape_dimension(self, allow_dynamic: bool = True) -> int: Optionally allows to not parse `?` as -1. """ if self._current_token.kind not in ( - Kind.INTEGER_LIT, - Kind.QUESTION, + MLIRTokenKind.INTEGER_LIT, + MLIRTokenKind.QUESTION, ): if allow_dynamic: self.raise_error( @@ -438,7 +442,7 @@ def parse_shape_dimension(self, allow_dynamic: bool = True) -> int: # If the integer literal starts with `0x`, this is decomposed into # `0` and `x`. - int_token = self._consume_token(Kind.INTEGER_LIT) + int_token = self._consume_token(MLIRTokenKind.INTEGER_LIT) if int_token.text[:2] == "0x": self._resume_from(int_token.span.start + 1) return 0 @@ -451,7 +455,7 @@ def parse_shape_delimiter(self) -> None: characters, it will split the token. For instance, 'x1' will be split into 'x' and '1'. """ - if self._current_token.kind != Kind.BARE_IDENT: + if self._current_token.kind != MLIRTokenKind.BARE_IDENT: self.raise_error( "Expected 'x' in shape delimiter, got " f"{self._current_token.kind.name}" @@ -473,7 +477,10 @@ def parse_ranked_shape(self) -> tuple[list[int], Attribute]: each dimension is also required to be non-negative. 
""" dims: list[int] = [] - while self._current_token.kind in (Kind.INTEGER_LIT, Kind.QUESTION): + while self._current_token.kind in ( + MLIRTokenKind.INTEGER_LIT, + MLIRTokenKind.QUESTION, + ): dim = self.parse_shape_dimension() dims.append(dim) self.parse_shape_delimiter() @@ -543,7 +550,7 @@ def _parse_vector_attrs(self) -> AnyVectorType: dims: list[int] = [] num_scalable_dims = 0 # First, parse the static dimensions - while self._current_token.kind == Kind.INTEGER_LIT: + while self._current_token.kind == MLIRTokenKind.INTEGER_LIT: dims.append(self.parse_shape_dimension(allow_dynamic=False)) self.parse_shape_delimiter() @@ -619,7 +626,7 @@ def _parse_optional_builtin_attr(self) -> Attribute | None: def _parse_int_or_question(self, context_msg: str = "") -> int | Literal["?"]: """Parse either an integer literal, or a '?'.""" - if self._parse_optional_token(Kind.QUESTION) is not None: + if self._parse_optional_token(MLIRTokenKind.QUESTION) is not None: return "?" if (v := self.parse_optional_integer(allow_boolean=False)) is not None: return v @@ -632,7 +639,7 @@ def _parse_strided_layout_attr(self) -> Attribute: (`,` `offset` `:` integer-literal)? `>` """ # Parse stride list - self._parse_token(Kind.LESS, "Expected `<` after `strided`") + self._parse_token(MLIRTokenKind.LESS, "Expected `<` after `strided`") strides = self.parse_comma_separated_list( self.Delimiter.SQUARE, lambda: self._parse_int_or_question(" in stride list"), @@ -645,17 +652,19 @@ def _parse_strided_layout_attr(self) -> Attribute: strides = [None if stride == "?" else stride for stride in strides] # Case without offset - if self._parse_optional_token(Kind.GREATER) is not None: + if self._parse_optional_token(MLIRTokenKind.GREATER) is not None: return StridedLayoutAttr(strides) # Parse the optional offset self._parse_token( - Kind.COMMA, "Expected end of strided attribute or ',' for offset." + MLIRTokenKind.COMMA, "Expected end of strided attribute or ',' for offset." ) self.parse_keyword("offset", " after comma") - self._parse_token(Kind.COLON, "Expected ':' after 'offset'") + self._parse_token(MLIRTokenKind.COLON, "Expected ':' after 'offset'") offset = self._parse_int_or_question(" in stride offset") - self._parse_token(Kind.GREATER, "Expected '>' in end of stride attribute") + self._parse_token( + MLIRTokenKind.GREATER, "Expected '>' in end of stride attribute" + ) return StridedLayoutAttr(strides, None if offset == "?" else offset) def parse_optional_unit_attr(self) -> Attribute | None: @@ -663,7 +672,7 @@ def parse_optional_unit_attr(self) -> Attribute | None: Parse a value of `unit` type. 
unit-attribute ::= `unit` """ - if self._current_token.kind != Kind.BARE_IDENT: + if self._current_token.kind != MLIRTokenKind.BARE_IDENT: return None name = self._current_token.span.text @@ -675,7 +684,7 @@ def parse_optional_unit_attr(self) -> Attribute | None: return None def _parse_optional_builtin_parametrized_attr(self) -> Attribute | None: - if self._current_token.kind != Kind.BARE_IDENT: + if self._current_token.kind != MLIRTokenKind.BARE_IDENT: return None name = self._current_token.span parsers = { @@ -690,7 +699,7 @@ def _parse_optional_builtin_parametrized_attr(self) -> Attribute | None: if name.text not in parsers: return None - self._consume_token(Kind.BARE_IDENT) + self._consume_token(MLIRTokenKind.BARE_IDENT) return parsers[name.text]() def _parse_builtin_dense_attr_hex( @@ -1036,21 +1045,21 @@ def _parse_tensor_literal_element(self) -> _TensorLiteralElement: """ # boolean case if self._current_token.text == "true": - token = self._consume_token(Kind.BARE_IDENT) + token = self._consume_token(MLIRTokenKind.BARE_IDENT) return self._TensorLiteralElement(False, True, token.span) if self._current_token.text == "false": - token = self._consume_token(Kind.BARE_IDENT) + token = self._consume_token(MLIRTokenKind.BARE_IDENT) return self._TensorLiteralElement(False, False, token.span) # checking for negation - minus_token = self._parse_optional_token(Kind.MINUS) + minus_token = self._parse_optional_token(MLIRTokenKind.MINUS) # Integer and float case - if self._current_token.kind == Kind.FLOAT_LIT: - token = self._consume_token(Kind.FLOAT_LIT) + if self._current_token.kind == MLIRTokenKind.FLOAT_LIT: + token = self._consume_token(MLIRTokenKind.FLOAT_LIT) value = token.kind.get_float_value(token.span) - elif self._current_token.kind == Kind.INTEGER_LIT: - token = self._consume_token(Kind.INTEGER_LIT) + elif self._current_token.kind == MLIRTokenKind.INTEGER_LIT: + token = self._consume_token(MLIRTokenKind.INTEGER_LIT) value = token.kind.get_int_value(token.span) else: self.raise_error("Expected either a float, integer, or complex literal") @@ -1118,7 +1127,7 @@ def parse_optional_symbol_name(self) -> StringAttr | None: Parse an @-identifier if present, and return its name (without the '@') in a string attribute. """ - if (token := self._parse_optional_token(Kind.AT_IDENT)) is None: + if (token := self._parse_optional_token(MLIRTokenKind.AT_IDENT)) is None: return None assert len(token.text) > 1, "token should be at least 2 characters long" @@ -1152,11 +1161,11 @@ def _parse_optional_symref_attr(self) -> SymbolRefAttr | None: # Parse nested symbols refs: list[StringAttr] = [] - while self._current_token.kind == Kind.COLON: + while self._current_token.kind == MLIRTokenKind.COLON: # Parse `::`. As in MLIR, this require to backtrack if a single `:` is given. pos = self._current_token.span.start - self._consume_token(Kind.COLON) - if self._parse_optional_token(Kind.COLON) is None: + self._consume_token(MLIRTokenKind.COLON) + if self._parse_optional_token(MLIRTokenKind.COLON) is None: self._resume_from(pos) break @@ -1194,7 +1203,7 @@ def parse_optional_builtin_int_or_float_attr( return None # If no types are given, we take the default ones - if self._current_token.kind != Kind.COLON: + if self._current_token.kind != MLIRTokenKind.COLON: if isinstance(value, float): return FloatAttr(value, Float64Type()) return IntegerAttr(value, i64) @@ -1246,7 +1255,7 @@ def _parse_optional_string_attr(self) -> StringAttr | None: Parse a string attribute, if present. 
string-attr ::= string-literal """ - token = self._parse_optional_token(Kind.STRING_LIT) + token = self._parse_optional_token(MLIRTokenKind.STRING_LIT) return ( StringAttr(token.kind.get_string_literal_value(token.span)) if token is not None @@ -1282,7 +1291,7 @@ def parse_optional_function_type(self) -> FunctionType | None: function-type ::= type-list `->` (type | type-list) type-list ::= `(` `)` | `(` type (`,` type)* `)` """ - if self._current_token.kind != Kind.L_PAREN: + if self._current_token.kind != MLIRTokenKind.L_PAREN: return None # Parse the arguments @@ -1314,7 +1323,7 @@ def _parse_optional_builtin_dict_attr(self) -> DictionaryAttr | None: `dictionary-attr ::= `{` ( attribute-entry (`,` attribute-entry)* )? `}` `attribute-entry` := (bare-id | string-literal) `=` attribute """ - if self._current_token.kind != Kind.L_BRACE: + if self._current_token.kind != MLIRTokenKind.L_BRACE: return None param = DictionaryAttr.parse_parameter(self) return DictionaryAttr(param) @@ -1339,7 +1348,7 @@ def _parse_optional_integer_or_float_type(self) -> Attribute | None: integer-type ::= (`i` | `si` | `ui`) decimal-literal float-type ::= `f16` | `f32` | `f64` | `f80` | `f128` | `bf16` """ - if self._current_token.kind != Kind.BARE_IDENT: + if self._current_token.kind != MLIRTokenKind.BARE_IDENT: return None name = self._current_token.text diff --git a/xdsl/parser/base_parser.py b/xdsl/parser/base_parser.py index 98fe77017b..bacc28fe03 100644 --- a/xdsl/parser/base_parser.py +++ b/xdsl/parser/base_parser.py @@ -12,11 +12,11 @@ from xdsl.utils.exceptions import ParseError from xdsl.utils.lexer import Position, Span from xdsl.utils.mlir_lexer import ( - Kind, - Lexer, + MLIRLexer, + MLIRToken, + MLIRTokenKind, PunctuationSpelling, StringLiteral, - Token, ) from xdsl.utils.str_enum import StrEnum @@ -29,11 +29,11 @@ class ParserState: share the same position. """ - lexer: Lexer - current_token: Token + lexer: MLIRLexer + current_token: MLIRToken dialect_stack: list[str] - def __init__(self, lexer: Lexer, dialect_stack: list[str] | None = None): + def __init__(self, lexer: MLIRLexer, dialect_stack: list[str] | None = None): if dialect_stack is None: dialect_stack = ["builtin"] self.lexer = lexer @@ -115,12 +115,12 @@ def _resume_from(self, pos: Position | ParserState): self._parser_state = pos @property - def _current_token(self) -> Token: + def _current_token(self) -> MLIRToken: """Get the token that is currently being parsed. Do not consume the token.""" return self._parser_state.current_token @property - def lexer(self) -> Lexer: + def lexer(self) -> MLIRLexer: """The lexer used to parse the current input.""" return self._parser_state.lexer @@ -132,7 +132,7 @@ def pos(self) -> Position: """ return self._current_token.span.start - def _consume_token(self, expected_kind: Kind | None = None) -> Token: + def _consume_token(self, expected_kind: MLIRTokenKind | None = None) -> MLIRToken: """ Advance the lexer to the next token. Additionally check that the current token was of a specific kind, @@ -146,7 +146,7 @@ def _consume_token(self, expected_kind: Kind | None = None) -> Token: self._parser_state.current_token = self.lexer.lex() return consumed_token - def _parse_optional_token(self, expected_kind: Kind) -> Token | None: + def _parse_optional_token(self, expected_kind: MLIRTokenKind) -> MLIRToken | None: """ If the current token is of the expected kind, consume it and return it. Otherwise, return None. 
@@ -157,7 +157,7 @@ def _parse_optional_token(self, expected_kind: Kind) -> Token | None: return current_token return None - def _parse_token(self, expected_kind: Kind, error_msg: str) -> Token: + def _parse_token(self, expected_kind: MLIRTokenKind, error_msg: str) -> MLIRToken: """ Parse a specific token, and raise an error if it is not present. Returns the token that was parsed. @@ -168,7 +168,9 @@ def _parse_token(self, expected_kind: Kind, error_msg: str) -> Token: self._consume_token(expected_kind) return current_token - def _parse_optional_token_in(self, expected_kinds: Iterable[Kind]) -> Token | None: + def _parse_optional_token_in( + self, expected_kinds: Iterable[MLIRTokenKind] + ) -> MLIRToken | None: """Parse one of the expected tokens if present, and returns it.""" if self._current_token.kind not in expected_kinds: return None @@ -218,7 +220,7 @@ def parse_comma_separated_list( # Parse the list of elements elems = [parse()] - while self._parse_optional_token(Kind.COMMA) is not None: + while self._parse_optional_token(MLIRTokenKind.COMMA) is not None: elems.append(parse()) # Parse the closing bracket, if a delimiter was provided @@ -257,7 +259,7 @@ def parse_optional_comma_separated_list( # Parse the list of elements elems = [parse()] - while self._parse_optional_token(Kind.COMMA) is not None: + while self._parse_optional_token(MLIRTokenKind.COMMA) is not None: elems.append(parse()) # Parse the closing bracket @@ -283,7 +285,7 @@ def parse_optional_undelimited_comma_separated_list( # Parse the remaining elements elems = [first_elem] - while self._parse_optional_token(Kind.COMMA) is not None: + while self._parse_optional_token(MLIRTokenKind.COMMA) is not None: elems.append(parse()) return elems @@ -292,12 +294,12 @@ def parse_optional_boolean(self) -> bool | None: """ Parse a boolean, if present, with the format `true` or `false`. """ - if self._current_token.kind == Kind.BARE_IDENT: + if self._current_token.kind == MLIRTokenKind.BARE_IDENT: if self._current_token.text == "true": - self._consume_token(Kind.BARE_IDENT) + self._consume_token(MLIRTokenKind.BARE_IDENT) return True elif self._current_token.text == "false": - self._consume_token(Kind.BARE_IDENT) + self._consume_token(MLIRTokenKind.BARE_IDENT) return False return None @@ -326,10 +328,10 @@ def parse_optional_integer( # Parse negative numbers if required is_negative = False if allow_negative: - is_negative = self._parse_optional_token(Kind.MINUS) is not None + is_negative = self._parse_optional_token(MLIRTokenKind.MINUS) is not None # Parse the actual number - if (int_token := self._parse_optional_token(Kind.INTEGER_LIT)) is None: + if (int_token := self._parse_optional_token(MLIRTokenKind.INTEGER_LIT)) is None: if is_negative: self.raise_error("Expected integer literal after '-'") return None @@ -367,9 +369,9 @@ def parse_optional_float( """ is_negative = False if allow_negative: - is_negative = self._parse_optional_token(Kind.MINUS) is not None + is_negative = self._parse_optional_token(MLIRTokenKind.MINUS) is not None - if (value := self._parse_optional_token(Kind.FLOAT_LIT)) is not None: + if (value := self._parse_optional_token(MLIRTokenKind.FLOAT_LIT)) is not None: value = value.kind.get_float_value(value.span) return -value if is_negative else value @@ -395,7 +397,7 @@ def parse_optional_number( Can optionally parse 'true' or 'false' into 1 and 0. 
""" - is_negative = self._parse_optional_token(Kind.MINUS) is not None + is_negative = self._parse_optional_token(MLIRTokenKind.MINUS) is not None if ( value := self.parse_optional_integer( @@ -436,7 +438,7 @@ def parse_optional_str_literal(self) -> str | None: resolved. """ - if (token := self._parse_optional_token(Kind.STRING_LIT)) is None: + if (token := self._parse_optional_token(MLIRTokenKind.STRING_LIT)) is None: return None try: return token.kind.get_string_literal_value(token.span) @@ -463,7 +465,7 @@ def parse_optional_bytes_literal(self) -> bytes | None: resolved. """ - if (token := self._parse_optional_token(Kind.BYTES_LIT)) is None: + if (token := self._parse_optional_token(MLIRTokenKind.BYTES_LIT)) is None: return None return StringLiteral.from_span(token.span).bytes_contents @@ -484,7 +486,7 @@ def parse_optional_identifier(self) -> str | None: Parse an identifier, if present, with syntax: ident ::= (letter|[_]) (letter|digit|[_$.])* """ - if (token := self._parse_optional_token(Kind.BARE_IDENT)) is not None: + if (token := self._parse_optional_token(MLIRTokenKind.BARE_IDENT)) is not None: return token.text return None @@ -523,10 +525,10 @@ def parse_optional_keyword(self, keyword: str) -> str | None: """Parse a specific identifier if it is present""" if ( - self._current_token.kind == Kind.BARE_IDENT + self._current_token.kind == MLIRTokenKind.BARE_IDENT and self._current_token.text == keyword ): - self._consume_token(Kind.BARE_IDENT) + self._consume_token(MLIRTokenKind.BARE_IDENT) return keyword return None @@ -547,10 +549,10 @@ def parse_optional_punctuation( """ # This check is only necessary to catch errors made by users that # are not using pyright. - assert Kind.is_spelling_of_punctuation(punctuation), ( + assert MLIRTokenKind.is_spelling_of_punctuation(punctuation), ( "'parse_optional_punctuation' must be " "called with a valid punctuation" ) - kind = Kind.get_punctuation_kind_from_spelling(punctuation) + kind = MLIRTokenKind.get_punctuation_kind_from_spelling(punctuation) if self._parse_optional_token(kind) is not None: return punctuation return None @@ -564,10 +566,10 @@ def parse_punctuation( """ # This check is only necessary to catch errors made by users that # are not using pyright. 
- assert Kind.is_spelling_of_punctuation( + assert MLIRTokenKind.is_spelling_of_punctuation( punctuation ), "'parse_punctuation' must be called with a valid punctuation" - kind = Kind.get_punctuation_kind_from_spelling(punctuation) + kind = MLIRTokenKind.get_punctuation_kind_from_spelling(punctuation) self._parse_token(kind, f"Expected '{punctuation}'" + context_msg) return punctuation @@ -586,12 +588,12 @@ def parse_str_enum(self, enum_type: type[_EnumType]) -> _EnumType: def parse_optional_str_enum(self, enum_type: type[_EnumType]) -> _EnumType | None: """Parse a string enum value, if present.""" - if self._current_token.kind != Kind.BARE_IDENT: + if self._current_token.kind != MLIRTokenKind.BARE_IDENT: return None val = self._current_token.text if val not in enum_type.__members__.values(): return None - self._consume_token(Kind.BARE_IDENT) + self._consume_token(MLIRTokenKind.BARE_IDENT) return enum_type(val) diff --git a/xdsl/parser/core.py b/xdsl/parser/core.py index 2cd654de27..dbbbdf96f9 100644 --- a/xdsl/parser/core.py +++ b/xdsl/parser/core.py @@ -20,7 +20,7 @@ from xdsl.parser import AttrParser, ParserState, Position from xdsl.utils.exceptions import MultipleSpansParseError from xdsl.utils.lexer import Input, Span -from xdsl.utils.mlir_lexer import Kind, Lexer +from xdsl.utils.mlir_lexer import MLIRLexer, MLIRTokenKind @dataclass(eq=False) @@ -95,7 +95,7 @@ def __init__( input: str, name: str = "", ) -> None: - super().__init__(ParserState(Lexer(Input(input, name))), ctx) + super().__init__(ParserState(MLIRLexer(Input(input, name))), ctx) self.ssa_values = dict() self.blocks = dict() self.forward_block_references = dict() @@ -118,10 +118,10 @@ def parse_module(self, allow_implicit_module: bool = True) -> ModuleOp: else: parsed_ops: list[Operation] = [] - while self._current_token.kind != Kind.EOF: + while self._current_token.kind != MLIRTokenKind.EOF: if self._current_token.kind in ( - Kind.HASH_IDENT, - Kind.EXCLAMATION_IDENT, + MLIRTokenKind.HASH_IDENT, + MLIRTokenKind.EXCLAMATION_IDENT, ): self._parse_alias_def() continue @@ -159,7 +159,7 @@ def _parse_alias_def(self): """ if ( token := self._parse_optional_token_in( - [Kind.EXCLAMATION_IDENT, Kind.HASH_IDENT] + [MLIRTokenKind.EXCLAMATION_IDENT, MLIRTokenKind.HASH_IDENT] ) ) is None: self.raise_error("expected attribute name") @@ -191,13 +191,13 @@ def _parse_optional_block_arg_list(self, block: Block): value-id-and-type-list ::= value-id-and-type (`,` ssa-id-and-type)* block-arg-list ::= `(` value-id-and-type-list? `)` """ - if self._current_token.kind != Kind.L_PAREN: + if self._current_token.kind != MLIRTokenKind.L_PAREN: return None def parse_argument() -> None: """Parse a single block argument with its type.""" arg_name = self._parse_token( - Kind.PERCENT_IDENT, "block argument expected" + MLIRTokenKind.PERCENT_IDENT, "block argument expected" ).span self.parse_punctuation(":") arg_type = self.parse_attribute() @@ -226,7 +226,9 @@ def _parse_block(self) -> Block: block-id ::= caret-id block-arg-list ::= `(` ssa-id-and-type-list? `)` """ - name_token = self._parse_token(Kind.CARET_IDENT, " in block definition") + name_token = self._parse_token( + MLIRTokenKind.CARET_IDENT, " in block definition" + ) name = name_token.text[1:] if name not in self.blocks: @@ -255,12 +257,12 @@ def parse_optional_unresolved_operand(self) -> UnresolvedOperand | None: Parse an operand with format `%(#)?`, if present. The operand may be forward declared. 
""" - name_token = self._parse_optional_token(Kind.PERCENT_IDENT) + name_token = self._parse_optional_token(MLIRTokenKind.PERCENT_IDENT) if name_token is None: return None index = 0 - index_token = self._parse_optional_token(Kind.HASH_IDENT) + index_token = self._parse_optional_token(MLIRTokenKind.HASH_IDENT) if index_token is not None: if re.fullmatch(self._decimal_integer_regex, index_token.text[1:]) is None: self.raise_error( @@ -456,7 +458,7 @@ def parse_optional_argument( """ # The argument name - name_token = self._parse_optional_token(Kind.PERCENT_IDENT) + name_token = self._parse_optional_token(MLIRTokenKind.PERCENT_IDENT) if name_token is None: return None @@ -528,7 +530,7 @@ def parse_optional_region( # Check that the entry block has no label. # Since a multi-block region block must have a terminator, there isn't a # possibility of having an empty entry block, and thus parsing the label directly. - if self._current_token.kind == Kind.CARET_IDENT: + if self._current_token.kind == MLIRTokenKind.CARET_IDENT: self.raise_error("invalid block name in region with named arguments") # Set the block arguments in the context @@ -542,8 +544,8 @@ def parse_optional_region( # If no arguments was provided, parse the entry block if present. elif self._current_token.kind not in ( - Kind.CARET_IDENT, - Kind.R_BRACE, + MLIRTokenKind.CARET_IDENT, + MLIRTokenKind.R_BRACE, ): block = Block() self._parse_block_body(block) @@ -604,7 +606,7 @@ def parse_region_list(self) -> list[Region]: Parse a list of regions with format: regions-list ::= `(` region (`,` region)* `)` """ - if self._current_token.kind == Kind.L_PAREN: + if self._current_token.kind == MLIRTokenKind.L_PAREN: return self.parse_comma_separated_list( self.Delimiter.PAREN, self.parse_region, " in operation region list" ) @@ -668,9 +670,9 @@ def parse_optional_operation(self) -> Operation | None: properties ::= `<` dictionary-attribute `>` """ if self._current_token.kind not in ( - Kind.PERCENT_IDENT, - Kind.BARE_IDENT, - Kind.STRING_LIT, + MLIRTokenKind.PERCENT_IDENT, + MLIRTokenKind.BARE_IDENT, + MLIRTokenKind.STRING_LIT, ): return None return self.parse_operation() @@ -695,7 +697,9 @@ def parse_operation(self) -> Operation: op_loc = self._current_token.span bound_results = self._parse_op_result_list() - if (op_name := self._parse_optional_token(Kind.BARE_IDENT)) is not None: + if ( + op_name := self._parse_optional_token(MLIRTokenKind.BARE_IDENT) + ) is not None: # Custom operation format op_type = self._get_op_by_name(op_name.text) dialect_name = op_type.dialect_name() @@ -756,13 +760,13 @@ def _parse_op_result(self) -> tuple[Span, int]: value tuple (by default 1). """ value_token = self._parse_token( - Kind.PERCENT_IDENT, "Expected result SSA value!" + MLIRTokenKind.PERCENT_IDENT, "Expected result SSA value!" ) - if self._parse_optional_token(Kind.COLON) is None: + if self._parse_optional_token(MLIRTokenKind.COLON) is None: return (value_token.span, 1) size_token = self._parse_token( - Kind.INTEGER_LIT, "Expected SSA value tuple size" + MLIRTokenKind.INTEGER_LIT, "Expected SSA value tuple size" ) size = size_token.kind.get_int_value(size_token.span) return (value_token.span, size) @@ -774,7 +778,7 @@ def _parse_op_result_list(self) -> list[tuple[Span, int]]: Each result is a tuple of the span of the SSA value name (including the `%`), and the size of the value tuple (by default 1). 
""" - if self._current_token.kind == Kind.PERCENT_IDENT: + if self._current_token.kind == MLIRTokenKind.PERCENT_IDENT: res = self.parse_comma_separated_list( self.Delimiter.NONE, self._parse_op_result, " in operation result list" ) @@ -890,7 +894,7 @@ def parse_optional_successor(self) -> Block | None: Parse a successor with format: successor ::= caret-id """ - block_token = self._parse_optional_token(Kind.CARET_IDENT) + block_token = self._parse_optional_token(MLIRTokenKind.CARET_IDENT) if block_token is None: return None name = block_token.text[1:] @@ -912,7 +916,7 @@ def parse_optional_successors(self) -> list[Block] | None: successor-list ::= `[` successor (`,` successor)* `]` successor ::= caret-id """ - if self._current_token.kind != Kind.L_SQUARE: + if self._current_token.kind != MLIRTokenKind.L_SQUARE: return None return self.parse_successors() diff --git a/xdsl/printer.py b/xdsl/printer.py index 92b8d8d250..958c5ccf21 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -74,7 +74,7 @@ convert_f64_to_u64, ) from xdsl.utils.diagnostic import Diagnostic -from xdsl.utils.mlir_lexer import Lexer +from xdsl.utils.mlir_lexer import MLIRLexer indentNumSpaces = 2 @@ -447,7 +447,7 @@ def print_identifier_or_string_literal(self, string: str): Prints the provided string as an identifier if it is one, and as a string literal otherwise. """ - if Lexer.bare_identifier_regex.fullmatch(string) is None: + if MLIRLexer.bare_identifier_regex.fullmatch(string) is None: self.print_string_literal(string) return self.print_string(string) diff --git a/xdsl/utils/lexer.py b/xdsl/utils/lexer.py index c41f683496..5e6a4b3888 100644 --- a/xdsl/utils/lexer.py +++ b/xdsl/utils/lexer.py @@ -149,7 +149,7 @@ def __repr__(self): @dataclass -class GenericToken(Generic[TokenKindT]): +class Token(Generic[TokenKindT]): kind: TokenKindT span: Span @@ -161,7 +161,7 @@ def text(self): @dataclass -class AbstractLexer(Generic[TokenKindT], ABC): +class Lexer(Generic[TokenKindT], ABC): input: Input """Input that is currently being lexed.""" @@ -171,16 +171,14 @@ class AbstractLexer(Generic[TokenKindT], ABC): The position can be out of bounds, in which case the lexer is in EOF state. """ - def _form_token( - self, kind: TokenKindT, start_pos: Position - ) -> GenericToken[TokenKindT]: + def _form_token(self, kind: TokenKindT, start_pos: Position) -> Token[TokenKindT]: """ Return a token with the given kind, and the start position. """ - return GenericToken(kind, Span(start_pos, self.pos, self.input)) + return Token(kind, Span(start_pos, self.pos, self.input)) @abstractmethod - def lex(self) -> GenericToken[TokenKindT]: + def lex(self) -> Token[TokenKindT]: """ Lex a token from the input, and returns it. 
""" diff --git a/xdsl/utils/mlir_lexer.py b/xdsl/utils/mlir_lexer.py index 8552deb3eb..15561ff50f 100644 --- a/xdsl/utils/mlir_lexer.py +++ b/xdsl/utils/mlir_lexer.py @@ -7,7 +7,7 @@ from typing import Literal, TypeAlias, TypeGuard, cast, overload from xdsl.utils.exceptions import ParseError -from xdsl.utils.lexer import AbstractLexer, GenericToken, Position, Span +from xdsl.utils.lexer import Lexer, Position, Span, Token PunctuationSpelling: TypeAlias = Literal[ "->", @@ -87,7 +87,7 @@ def bytes_contents(self) -> bytes: return bytes(bytes_contents) -class Kind(Enum): +class MLIRTokenKind(Enum): # Markers EOF = object() @@ -134,50 +134,50 @@ class Kind(Enum): FILE_METADATA_END = "#-}" @staticmethod - def get_punctuation_spelling_to_kind_dict() -> dict[str, Kind]: + def get_punctuation_spelling_to_kind_dict() -> dict[str, MLIRTokenKind]: return { - "->": Kind.ARROW, - ":": Kind.COLON, - ",": Kind.COMMA, - "...": Kind.ELLIPSIS, - "=": Kind.EQUAL, - ">": Kind.GREATER, - "{": Kind.L_BRACE, - "(": Kind.L_PAREN, - "[": Kind.L_SQUARE, - "<": Kind.LESS, - "-": Kind.MINUS, - "+": Kind.PLUS, - "?": Kind.QUESTION, - "}": Kind.R_BRACE, - ")": Kind.R_PAREN, - "]": Kind.R_SQUARE, - "*": Kind.STAR, - "|": Kind.VERTICAL_BAR, - "{-#": Kind.FILE_METADATA_BEGIN, - "#-}": Kind.FILE_METADATA_END, + "->": MLIRTokenKind.ARROW, + ":": MLIRTokenKind.COLON, + ",": MLIRTokenKind.COMMA, + "...": MLIRTokenKind.ELLIPSIS, + "=": MLIRTokenKind.EQUAL, + ">": MLIRTokenKind.GREATER, + "{": MLIRTokenKind.L_BRACE, + "(": MLIRTokenKind.L_PAREN, + "[": MLIRTokenKind.L_SQUARE, + "<": MLIRTokenKind.LESS, + "-": MLIRTokenKind.MINUS, + "+": MLIRTokenKind.PLUS, + "?": MLIRTokenKind.QUESTION, + "}": MLIRTokenKind.R_BRACE, + ")": MLIRTokenKind.R_PAREN, + "]": MLIRTokenKind.R_SQUARE, + "*": MLIRTokenKind.STAR, + "|": MLIRTokenKind.VERTICAL_BAR, + "{-#": MLIRTokenKind.FILE_METADATA_BEGIN, + "#-}": MLIRTokenKind.FILE_METADATA_END, } def is_punctuation(self) -> bool: - punctuation_dict = Kind.get_punctuation_spelling_to_kind_dict() + punctuation_dict = MLIRTokenKind.get_punctuation_spelling_to_kind_dict() return self in punctuation_dict.values() @staticmethod def is_spelling_of_punctuation( spelling: str, ) -> TypeGuard[PunctuationSpelling]: - punctuation_dict = Kind.get_punctuation_spelling_to_kind_dict() + punctuation_dict = MLIRTokenKind.get_punctuation_spelling_to_kind_dict() return spelling in punctuation_dict.keys() @staticmethod def get_punctuation_kind_from_spelling( spelling: PunctuationSpelling, - ) -> Kind: - assert Kind.is_spelling_of_punctuation(spelling), ( + ) -> MLIRTokenKind: + assert MLIRTokenKind.is_spelling_of_punctuation(spelling), ( "Kind.get_punctuation_kind_from_spelling: spelling is not a " "valid punctuation spelling!" ) - return Kind.get_punctuation_spelling_to_kind_dict()[spelling] + return MLIRTokenKind.get_punctuation_spelling_to_kind_dict()[spelling] def get_int_value(self, span: Span): """ @@ -185,7 +185,7 @@ def get_int_value(self, span: Span): This function will raise an exception if the token is not an integer literal. """ - if self != Kind.INTEGER_LIT: + if self != MLIRTokenKind.INTEGER_LIT: raise ValueError("Token is not an integer literal!") if span.text[:2] in ["0x", "0X"]: return int(span.text, 16) @@ -197,7 +197,7 @@ def get_float_value(self, span: Span): This function will raise an exception if the token is not a float literal. 
""" - if self != Kind.FLOAT_LIT: + if self != MLIRTokenKind.FLOAT_LIT: raise ValueError("Token is not a float literal!") return float(span.text) @@ -208,16 +208,16 @@ def get_string_literal_value(self, span: Span) -> str: the string. This function will raise an exception if the token is not a string literal. """ - if self != Kind.STRING_LIT: + if self != MLIRTokenKind.STRING_LIT: raise ValueError("Token is not a string literal!") return StringLiteral.from_span(span).string_contents -Token = GenericToken[Kind] +MLIRToken = Token[MLIRTokenKind] @dataclass -class Lexer(AbstractLexer[Kind]): +class MLIRLexer(Lexer[MLIRTokenKind]): def _is_in_bounds(self, size: Position = 1) -> bool: """ Check if the current position is within the bounds of the input. @@ -265,13 +265,13 @@ def _consume_whitespace(self) -> None: """ self._consume_regex(self._whitespace_regex) - def _form_token(self, kind: Kind, start_pos: Position) -> Token: + def _form_token(self, kind: MLIRTokenKind, start_pos: Position) -> MLIRToken: """ Return a token with the given kind, and the start position. """ - return Token(kind, Span(start_pos, self.pos, self.input)) + return MLIRToken(kind, Span(start_pos, self.pos, self.input)) - def lex(self) -> Token: + def lex(self) -> MLIRToken: """ Lex a token from the input, and returns it. """ @@ -283,7 +283,7 @@ def lex(self) -> Token: # Handle end of file if current_char is None: - return self._form_token(Kind.EOF, start_pos) + return self._form_token(MLIRTokenKind.EOF, start_pos) # bare identifier if current_char.isalpha() or current_char == "_": @@ -291,20 +291,20 @@ def lex(self) -> Token: # single-char punctuation that are not part of a multi-char token single_char_punctuation = { - ":": Kind.COLON, - ",": Kind.COMMA, - "(": Kind.L_PAREN, - ")": Kind.R_PAREN, - "}": Kind.R_BRACE, - "[": Kind.L_SQUARE, - "]": Kind.R_SQUARE, - "<": Kind.LESS, - ">": Kind.GREATER, - "=": Kind.EQUAL, - "+": Kind.PLUS, - "*": Kind.STAR, - "?": Kind.QUESTION, - "|": Kind.VERTICAL_BAR, + ":": MLIRTokenKind.COLON, + ",": MLIRTokenKind.COMMA, + "(": MLIRTokenKind.L_PAREN, + ")": MLIRTokenKind.R_PAREN, + "}": MLIRTokenKind.R_BRACE, + "[": MLIRTokenKind.L_SQUARE, + "]": MLIRTokenKind.R_SQUARE, + "<": MLIRTokenKind.LESS, + ">": MLIRTokenKind.GREATER, + "=": MLIRTokenKind.EQUAL, + "+": MLIRTokenKind.PLUS, + "*": MLIRTokenKind.STAR, + "?": MLIRTokenKind.QUESTION, + "|": MLIRTokenKind.VERTICAL_BAR, } if current_char in single_char_punctuation: return self._form_token(single_char_punctuation[current_char], start_pos) @@ -316,26 +316,26 @@ def lex(self) -> Token: Span(start_pos, start_pos + 1, self.input), "Expected three consecutive '.' 
for an ellipsis", ) - return self._form_token(Kind.ELLIPSIS, start_pos) + return self._form_token(MLIRTokenKind.ELLIPSIS, start_pos) # '-' and '->' if current_char == "-": if self._peek_chars() == ">": self._consume_chars() - return self._form_token(Kind.ARROW, start_pos) - return self._form_token(Kind.MINUS, start_pos) + return self._form_token(MLIRTokenKind.ARROW, start_pos) + return self._form_token(MLIRTokenKind.MINUS, start_pos) # '{' and '{-#' if current_char == "{": if self._peek_chars(2) == "-#": self._consume_chars(2) - return self._form_token(Kind.FILE_METADATA_BEGIN, start_pos) - return self._form_token(Kind.L_BRACE, start_pos) + return self._form_token(MLIRTokenKind.FILE_METADATA_BEGIN, start_pos) + return self._form_token(MLIRTokenKind.L_BRACE, start_pos) # '#-}' if current_char == "#" and self._peek_chars(2) == "-}": self._consume_chars(2) - return self._form_token(Kind.FILE_METADATA_END, start_pos) + return self._form_token(MLIRTokenKind.FILE_METADATA_END, start_pos) # '@' and at-identifier if current_char == "@": @@ -360,7 +360,7 @@ def lex(self) -> Token: bare_identifier_regex = re.compile(r"[a-zA-Z_]" + IDENTIFIER_SUFFIX) bare_identifier_suffix_regex = re.compile(IDENTIFIER_SUFFIX) - def _lex_bare_identifier(self, start_pos: Position) -> Token: + def _lex_bare_identifier(self, start_pos: Position) -> MLIRToken: """ Lex a bare identifier with the following grammar: `bare-id ::= (letter|[_]) (letter|digit|[_$.])*` @@ -369,9 +369,9 @@ def _lex_bare_identifier(self, start_pos: Position) -> Token: """ self._consume_regex(self.bare_identifier_suffix_regex) - return self._form_token(Kind.BARE_IDENT, start_pos) + return self._form_token(MLIRTokenKind.BARE_IDENT, start_pos) - def _lex_at_ident(self, start_pos: Position) -> Token: + def _lex_at_ident(self, start_pos: Position) -> MLIRToken: """ Lex an at-identifier with the following grammar: `at-id ::= `@` (bare-id | string-literal)` @@ -389,12 +389,12 @@ def _lex_at_ident(self, start_pos: Position) -> Token: # bare identifier case if current_char.isalpha() or current_char == "_": token = self._lex_bare_identifier(start_pos) - return self._form_token(Kind.AT_IDENT, token.span.start) + return self._form_token(MLIRTokenKind.AT_IDENT, token.span.start) # literal string case if current_char == '"': token = self._lex_string_literal(start_pos) - return self._form_token(Kind.AT_IDENT, token.span.start) + return self._form_token(MLIRTokenKind.AT_IDENT, token.span.start) raise ParseError( Span(start_pos, self.pos, self.input), @@ -403,7 +403,7 @@ def _lex_at_ident(self, start_pos: Position) -> Token: _suffix_id = re.compile(r"([0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*)") - def _lex_prefixed_ident(self, start_pos: Position) -> Token: + def _lex_prefixed_ident(self, start_pos: Position) -> MLIRToken: """ Parsed the following prefixed identifiers: ``` @@ -424,16 +424,16 @@ def _lex_prefixed_ident(self, start_pos: Position) -> Token: ), "First prefixed identifier character must have been parsed" first_char = self.input.at(self.pos - 1) if first_char == "#": - kind = Kind.HASH_IDENT + kind = MLIRTokenKind.HASH_IDENT elif first_char == "!": - kind = Kind.EXCLAMATION_IDENT + kind = MLIRTokenKind.EXCLAMATION_IDENT elif first_char == "^": - kind = Kind.CARET_IDENT + kind = MLIRTokenKind.CARET_IDENT else: assert ( first_char == "%" ), "First prefixed identifier character must have been parsed correctly" - kind = Kind.PERCENT_IDENT + kind = MLIRTokenKind.PERCENT_IDENT match = self._consume_regex(self._suffix_id) if match is None: @@ -446,7 +446,7 @@ def 
_lex_prefixed_ident(self, start_pos: Position) -> Token: _unescaped_characters_regex = re.compile(r'[^"\\\n\v\f]*') - def _lex_string_literal(self, start_pos: Position) -> Token: + def _lex_string_literal(self, start_pos: Position) -> MLIRToken: """ Lex a string literal. The first character `"` is expected to have already been parsed. @@ -460,9 +460,9 @@ def _lex_string_literal(self, start_pos: Position) -> Token: # end of string literal if current_char == '"': if bytes_token: - return self._form_token(Kind.BYTES_LIT, start_pos) + return self._form_token(MLIRTokenKind.BYTES_LIT, start_pos) else: - return self._form_token(Kind.STRING_LIT, start_pos) + return self._form_token(MLIRTokenKind.STRING_LIT, start_pos) # newline character in string literal (not allowed) if current_char in ["\n", "\v", "\f"]: @@ -500,7 +500,7 @@ def _lex_string_literal(self, start_pos: Position) -> Token: _digits_star_regex = re.compile(r"[0-9]*") _fractional_suffix_regex = re.compile(r"\.[0-9]*([eE][+-]?[0-9]+)?") - def _lex_number(self, start_pos: Position) -> Token: + def _lex_number(self, start_pos: Position) -> MLIRToken: """ Lex a number literal, which is either a decimal or an hexadecimal. The first character is expected to have already been parsed. @@ -518,7 +518,7 @@ def _lex_number(self, start_pos: Position) -> Token: ): self._consume_chars(2) self._consume_regex(self._hexdigits_star_regex) - return self._form_token(Kind.INTEGER_LIT, start_pos) + return self._form_token(MLIRTokenKind.INTEGER_LIT, start_pos) # Decimal case self._consume_regex(self._digits_star_regex) @@ -526,5 +526,5 @@ def _lex_number(self, start_pos: Position) -> Token: # Check if we are lexing a floating point match = self._consume_regex(self._fractional_suffix_regex) if match is not None: - return self._form_token(Kind.FLOAT_LIT, start_pos) - return self._form_token(Kind.INTEGER_LIT, start_pos) + return self._form_token(MLIRTokenKind.FLOAT_LIT, start_pos) + return self._form_token(MLIRTokenKind.INTEGER_LIT, start_pos)
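
For reviewers, a minimal end-to-end sketch of the renamed API, assembled only from the call sites updated above (`MLIRLexer(Input(...))`, the `kind`/`text`/`span` members of `MLIRToken`, and `MLIRTokenKind.get_int_value`); the helper name `lex_all` and the sample MLIR snippet are made up for illustration:

from xdsl.utils.lexer import Input
from xdsl.utils.mlir_lexer import MLIRLexer, MLIRToken, MLIRTokenKind

def lex_all(text: str) -> list[MLIRToken]:
    # Collect every token up to, but not including, EOF,
    # mirroring the loop in bench/parser/bench_lexer.py.
    lexer = MLIRLexer(Input(text, "<example>"))  # Input(contents, name)
    tokens: list[MLIRToken] = []
    while (token := lexer.lex()).kind is not MLIRTokenKind.EOF:
        tokens.append(token)
    return tokens

# Hypothetical input; any MLIR text is lexed the same way.
for tok in lex_all("%0 = arith.constant 42 : i32"):
    print(tok.kind.name, repr(tok.text))
    if tok.kind == MLIRTokenKind.INTEGER_LIT:
        print("  value:", tok.kind.get_int_value(tok.span))

Since `MLIRToken` is just the alias `Token[MLIRTokenKind]`, with the generic `Token` and `Lexer` base classes now living in xdsl/utils/lexer.py, downstream code that only needs the generic machinery can depend on those names and stay dialect-agnostic, while MLIR-specific code imports the `MLIR`-prefixed names from xdsl/utils/mlir_lexer.py.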