core: rename MLIR-specific lexing constructs (#3592)
A cleanup PR renaming MLIR-specific lexing constructs: `Lexer` → `MLIRLexer`, `Token` → `MLIRToken`, and `Kind` → `MLIRTokenKind`. The rename was split into two steps to keep the diff on the previous PR smaller.
superlopuh authored Dec 9, 2024
1 parent e84955d commit 09f495a
Showing 12 changed files with 297 additions and 279 deletions.
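For downstream users the rename is mechanical at call sites. A minimal before/after sketch (names taken from the diffs below; the input string is an arbitrary example, not from the repo):

    # Before this commit (sketch):
    #     from xdsl.utils.mlir_lexer import Kind, Lexer
    #     lexer = Lexer(Input("%value", "<unknown>"))
    # After this commit:
    from xdsl.utils.lexer import Input
    from xdsl.utils.mlir_lexer import MLIRLexer, MLIRTokenKind

    lexer = MLIRLexer(Input("%value", "<unknown>"))
    token = lexer.lex()
    assert token.kind is MLIRTokenKind.PERCENT_IDENT  # '%' introduces a value identifier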
6 changes: 3 additions & 3 deletions bench/parser/bench_lexer.py
@@ -11,15 +11,15 @@
 from collections.abc import Iterable
 
 from xdsl.utils.lexer import Input
-from xdsl.utils.mlir_lexer import Kind, Lexer
+from xdsl.utils.mlir_lexer import MLIRLexer, MLIRTokenKind
 
 
 def lex_file(file: Input):
     """
     Lex the given file
     """
-    lexer = Lexer(file)
-    while lexer.lex().kind is not Kind.EOF:
+    lexer = MLIRLexer(file)
+    while lexer.lex().kind is not MLIRTokenKind.EOF:
         pass
 
 
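For context, a hypothetical way to drive this benchmark by hand. The timing harness, helper name, and sample source are assumptions; only `Input`, `MLIRLexer`, and `MLIRTokenKind` come from the diff above:

    import timeit

    from xdsl.utils.lexer import Input
    from xdsl.utils.mlir_lexer import MLIRLexer, MLIRTokenKind


    def count_tokens(text: str) -> int:
        """Lex `text` to EOF and return the number of tokens (hypothetical helper)."""
        lexer = MLIRLexer(Input(text, "<bench>"))
        count = 0
        while lexer.lex().kind is not MLIRTokenKind.EOF:
            count += 1
        return count


    # Time 1000 lexing passes over a small sample.
    print(timeit.timeit(lambda: count_tokens('%0 = "test.op"() : () -> i32'), number=1000))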
92 changes: 46 additions & 46 deletions tests/test_lexer.py
@@ -2,18 +2,18 @@
 
 from xdsl.utils.exceptions import ParseError
 from xdsl.utils.lexer import Input
-from xdsl.utils.mlir_lexer import Kind, Lexer, Token
+from xdsl.utils.mlir_lexer import MLIRLexer, MLIRToken, MLIRTokenKind
 
 
-def get_token(input: str) -> Token:
+def get_token(input: str) -> MLIRToken:
     file = Input(input, "<unknown>")
-    lexer = Lexer(file)
+    lexer = MLIRLexer(file)
     token = lexer.lex()
     return token
 
 
 def assert_single_token(
-    input: str, expected_kind: Kind, expected_text: str | None = None
+    input: str, expected_kind: MLIRTokenKind, expected_text: str | None = None
 ):
     if expected_text is None:
         expected_text = input
@@ -26,37 +26,37 @@ def assert_single_token(
 
 def assert_token_fail(input: str):
     file = Input(input, "<unknown>")
-    lexer = Lexer(file)
+    lexer = MLIRLexer(file)
     with pytest.raises(ParseError):
         lexer.lex()
 
 
 @pytest.mark.parametrize(
     "text,kind",
     [
-        ("->", Kind.ARROW),
-        (":", Kind.COLON),
-        (",", Kind.COMMA),
-        ("...", Kind.ELLIPSIS),
-        ("=", Kind.EQUAL),
-        (">", Kind.GREATER),
-        ("{", Kind.L_BRACE),
-        ("(", Kind.L_PAREN),
-        ("[", Kind.L_SQUARE),
-        ("<", Kind.LESS),
-        ("-", Kind.MINUS),
-        ("+", Kind.PLUS),
-        ("?", Kind.QUESTION),
-        ("}", Kind.R_BRACE),
-        (")", Kind.R_PAREN),
-        ("]", Kind.R_SQUARE),
-        ("*", Kind.STAR),
-        ("|", Kind.VERTICAL_BAR),
-        ("{-#", Kind.FILE_METADATA_BEGIN),
-        ("#-}", Kind.FILE_METADATA_END),
+        ("->", MLIRTokenKind.ARROW),
+        (":", MLIRTokenKind.COLON),
+        (",", MLIRTokenKind.COMMA),
+        ("...", MLIRTokenKind.ELLIPSIS),
+        ("=", MLIRTokenKind.EQUAL),
+        (">", MLIRTokenKind.GREATER),
+        ("{", MLIRTokenKind.L_BRACE),
+        ("(", MLIRTokenKind.L_PAREN),
+        ("[", MLIRTokenKind.L_SQUARE),
+        ("<", MLIRTokenKind.LESS),
+        ("-", MLIRTokenKind.MINUS),
+        ("+", MLIRTokenKind.PLUS),
+        ("?", MLIRTokenKind.QUESTION),
+        ("}", MLIRTokenKind.R_BRACE),
+        (")", MLIRTokenKind.R_PAREN),
+        ("]", MLIRTokenKind.R_SQUARE),
+        ("*", MLIRTokenKind.STAR),
+        ("|", MLIRTokenKind.VERTICAL_BAR),
+        ("{-#", MLIRTokenKind.FILE_METADATA_BEGIN),
+        ("#-}", MLIRTokenKind.FILE_METADATA_END),
     ],
 )
-def test_punctuation(text: str, kind: Kind):
+def test_punctuation(text: str, kind: MLIRTokenKind):
     assert_single_token(text, kind)
 
 
@@ -69,7 +69,7 @@ def test_punctuation_fail(text: str):
     "text", ['""', '"@"', '"foo"', '"\\""', '"\\n"', '"\\\\"', '"\\t"']
 )
 def test_str_literal(text: str):
-    assert_single_token(text, Kind.STRING_LIT)
+    assert_single_token(text, MLIRTokenKind.STRING_LIT)
 
 
 @pytest.mark.parametrize("text", ['"', '"\\"', '"\\a"', '"\n"', '"\v"', '"\f"'])
@@ -82,7 +82,7 @@ def test_str_literal_fail(text: str):
 )
 def test_bare_ident(text: str):
     """bare-id ::= (letter|[_]) (letter|digit|[_$.])*"""
-    assert_single_token(text, Kind.BARE_IDENT)
+    assert_single_token(text, MLIRTokenKind.BARE_IDENT)
 
 
 @pytest.mark.parametrize(
@@ -109,7 +109,7 @@ def test_bare_ident(text: str):
 )
 def test_at_ident(text: str):
     """at-ident ::= `@` (bare-id | string-literal)"""
-    assert_single_token(text, Kind.AT_IDENT)
+    assert_single_token(text, MLIRTokenKind.AT_IDENT)
 
 
 @pytest.mark.parametrize(
@@ -129,10 +129,10 @@ def test_prefixed_ident(text: str):
     """percent-ident ::= `%` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*)"""
     """caret-ident ::= `^` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*)"""
     """exclamation-ident ::= `!` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*)"""
-    assert_single_token("#" + text, Kind.HASH_IDENT)
-    assert_single_token("%" + text, Kind.PERCENT_IDENT)
-    assert_single_token("^" + text, Kind.CARET_IDENT)
-    assert_single_token("!" + text, Kind.EXCLAMATION_IDENT)
+    assert_single_token("#" + text, MLIRTokenKind.HASH_IDENT)
+    assert_single_token("%" + text, MLIRTokenKind.PERCENT_IDENT)
+    assert_single_token("^" + text, MLIRTokenKind.CARET_IDENT)
+    assert_single_token("!" + text, MLIRTokenKind.EXCLAMATION_IDENT)
 
 
 @pytest.mark.parametrize("text", ["+", '""', "#", "%", "^", "!", "\n", ""])
@@ -155,46 +155,46 @@ def test_prefixed_ident_fail(text: str):
 )
 def test_prefixed_ident_split(text: str, expected: str):
     """Check that the prefixed identifier is split at the right character."""
-    assert_single_token("#" + text, Kind.HASH_IDENT, "#" + expected)
-    assert_single_token("%" + text, Kind.PERCENT_IDENT, "%" + expected)
-    assert_single_token("^" + text, Kind.CARET_IDENT, "^" + expected)
-    assert_single_token("!" + text, Kind.EXCLAMATION_IDENT, "!" + expected)
+    assert_single_token("#" + text, MLIRTokenKind.HASH_IDENT, "#" + expected)
+    assert_single_token("%" + text, MLIRTokenKind.PERCENT_IDENT, "%" + expected)
+    assert_single_token("^" + text, MLIRTokenKind.CARET_IDENT, "^" + expected)
+    assert_single_token("!" + text, MLIRTokenKind.EXCLAMATION_IDENT, "!" + expected)
 
 
 @pytest.mark.parametrize("text", ["0", "01", "123456789", "99", "0x1234", "0xabcdef"])
 def test_integer_literal(text: str):
-    assert_single_token(text, Kind.INTEGER_LIT)
+    assert_single_token(text, MLIRTokenKind.INTEGER_LIT)
 
 
 @pytest.mark.parametrize(
     "text,expected", [("0a", "0"), ("0xg", "0"), ("0xfg", "0xf"), ("0xf.", "0xf")]
 )
 def test_integer_literal_split(text: str, expected: str):
-    assert_single_token(text, Kind.INTEGER_LIT, expected)
+    assert_single_token(text, MLIRTokenKind.INTEGER_LIT, expected)
 
 
 @pytest.mark.parametrize(
     "text", ["0.", "1.", "0.2", "38.1243", "92.54e43", "92.5E43", "43.3e-54", "32.E+25"]
 )
 def test_float_literal(text: str):
-    assert_single_token(text, Kind.FLOAT_LIT)
+    assert_single_token(text, MLIRTokenKind.FLOAT_LIT)
 
 
 @pytest.mark.parametrize(
     "text,expected", [("3.9e", "3.9"), ("4.5e+", "4.5"), ("5.8e-", "5.8")]
 )
 def test_float_literal_split(text: str, expected: str):
-    assert_single_token(text, Kind.FLOAT_LIT, expected)
+    assert_single_token(text, MLIRTokenKind.FLOAT_LIT, expected)
 
 
 @pytest.mark.parametrize("text", ["0", " 0", "  0", "\n0", "\t0", "// Comment\n0"])
 def test_whitespace_skip(text: str):
-    assert_single_token(text, Kind.INTEGER_LIT, "0")
+    assert_single_token(text, MLIRTokenKind.INTEGER_LIT, "0")
 
 
 @pytest.mark.parametrize("text", ["", " ", "\n\n", "// Comment\n"])
 def test_eof(text: str):
-    assert_single_token(text, Kind.EOF, "")
+    assert_single_token(text, MLIRTokenKind.EOF, "")
 
 
 @pytest.mark.parametrize(
@@ -209,7 +209,7 @@ def test_eof(text: str):
 )
 def test_token_get_int_value(text: str, expected: int):
     token = get_token(text)
-    assert token.kind == Kind.INTEGER_LIT
+    assert token.kind == MLIRTokenKind.INTEGER_LIT
     assert token.kind.get_int_value(token.span) == expected
 
 
@@ -228,7 +228,7 @@ def test_token_get_int_value(text: str, expected: int):
 )
 def test_token_get_float_value(text: str, expected: float):
     token = get_token(text)
-    assert token.kind == Kind.FLOAT_LIT
+    assert token.kind == MLIRTokenKind.FLOAT_LIT
     assert token.kind.get_float_value(token.span) == expected
 
 
@@ -246,5 +246,5 @@ def test_token_get_float_value(text: str, expected: float):
 )
 def test_token_get_string_literal_value(text: str, expected: float):
     token = get_token(text)
-    assert token.kind == Kind.STRING_LIT
+    assert token.kind == MLIRTokenKind.STRING_LIT
     assert token.kind.get_string_literal_value(token.span) == expected
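One detail worth noting from these tests: the literal-value helpers (`get_int_value`, `get_float_value`, `get_string_literal_value`) are called on `token.kind` and take the token's span. A minimal sketch, assuming only the post-rename API exercised above:

    from xdsl.utils.lexer import Input
    from xdsl.utils.mlir_lexer import MLIRLexer, MLIRTokenKind

    token = MLIRLexer(Input("0x2A", "<unknown>")).lex()
    assert token.kind is MLIRTokenKind.INTEGER_LIT
    # The kind interprets the raw span text, here a hexadecimal literal.
    assert token.kind.get_int_value(token.span) == 42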
37 changes: 20 additions & 17 deletions tests/test_parser.py
@@ -30,7 +30,7 @@
 from xdsl.parser import Parser
 from xdsl.printer import Printer
 from xdsl.utils.exceptions import ParseError, VerifyException
-from xdsl.utils.mlir_lexer import Kind, PunctuationSpelling
+from xdsl.utils.mlir_lexer import MLIRTokenKind, PunctuationSpelling
 from xdsl.utils.str_enum import StrEnum
 
 # pyright: reportPrivateUsage=false
@@ -584,51 +584,54 @@ def test_parse_comma_separated_list_error_delimiters(
 
 
 @pytest.mark.parametrize(
-    "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().values())
+    "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().values())
 )
-def test_is_punctuation_true(punctuation: Kind):
+def test_is_punctuation_true(punctuation: MLIRTokenKind):
     assert punctuation.is_punctuation()
 
 
-@pytest.mark.parametrize("punctuation", [Kind.BARE_IDENT, Kind.EOF, Kind.INTEGER_LIT])
-def test_is_punctuation_false(punctuation: Kind):
+@pytest.mark.parametrize(
+    "punctuation",
+    [MLIRTokenKind.BARE_IDENT, MLIRTokenKind.EOF, MLIRTokenKind.INTEGER_LIT],
+)
+def test_is_punctuation_false(punctuation: MLIRTokenKind):
     assert not punctuation.is_punctuation()
 
 
 @pytest.mark.parametrize(
-    "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().values())
+    "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().values())
 )
-def test_is_spelling_of_punctuation_true(punctuation: Kind):
+def test_is_spelling_of_punctuation_true(punctuation: MLIRTokenKind):
     value = cast(PunctuationSpelling, punctuation.value)
-    assert Kind.is_spelling_of_punctuation(value)
+    assert MLIRTokenKind.is_spelling_of_punctuation(value)
 
 
 @pytest.mark.parametrize("punctuation", [">-", "o", "4", "$", "_", "@"])
 def test_is_spelling_of_punctuation_false(punctuation: str):
-    assert not Kind.is_spelling_of_punctuation(punctuation)
+    assert not MLIRTokenKind.is_spelling_of_punctuation(punctuation)
 
 
 @pytest.mark.parametrize(
-    "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().values())
+    "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().values())
 )
-def test_get_punctuation_kind(punctuation: Kind):
+def test_get_punctuation_kind(punctuation: MLIRTokenKind):
     value = cast(PunctuationSpelling, punctuation.value)
     assert punctuation.get_punctuation_kind_from_spelling(value) == punctuation
 
 
 @pytest.mark.parametrize(
-    "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().keys())
+    "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().keys())
 )
 def test_parse_punctuation(punctuation: PunctuationSpelling):
     parser = Parser(MLContext(), punctuation)
 
     res = parser.parse_punctuation(punctuation)
     assert res == punctuation
-    assert parser._parse_token(Kind.EOF, "").kind == Kind.EOF
+    assert parser._parse_token(MLIRTokenKind.EOF, "").kind == MLIRTokenKind.EOF
 
 
 @pytest.mark.parametrize(
-    "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().keys())
+    "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().keys())
 )
 def test_parse_punctuation_fail(punctuation: PunctuationSpelling):
     parser = Parser(MLContext(), "e +")
@@ -639,17 +642,17 @@ def test_parse_punctuation_fail(punctuation: PunctuationSpelling):
 
 
 @pytest.mark.parametrize(
-    "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().keys())
+    "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().keys())
 )
 def test_parse_optional_punctuation(punctuation: PunctuationSpelling):
     parser = Parser(MLContext(), punctuation)
     res = parser.parse_optional_punctuation(punctuation)
     assert res == punctuation
-    assert parser._parse_token(Kind.EOF, "").kind == Kind.EOF
+    assert parser._parse_token(MLIRTokenKind.EOF, "").kind == MLIRTokenKind.EOF
 
 
 @pytest.mark.parametrize(
-    "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().keys())
+    "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().keys())
 )
 def test_parse_optional_punctuation_fail(punctuation: PunctuationSpelling):
     parser = Parser(MLContext(), "e +")
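A short usage sketch of the punctuation helpers exercised above. The `MLContext` import path is an assumption (the diff does not show this file's full import list), and the optional variant returning None on a mismatch is inferred from the `*_fail` tests:

    from xdsl.context import MLContext  # assumed import path
    from xdsl.parser import Parser

    parser = Parser(MLContext(), "->")
    assert parser.parse_punctuation("->") == "->"

    # parse_optional_punctuation does not raise on a mismatch.
    assert Parser(MLContext(), "->").parse_optional_punctuation("(") is None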
4 changes: 2 additions & 2 deletions xdsl/ir/core.py
@@ -32,7 +32,7 @@
 
 from xdsl.traits import IsTerminator, NoTerminator, OpTrait, OpTraitInvT
 from xdsl.utils.exceptions import VerifyException
-from xdsl.utils.mlir_lexer import Lexer
+from xdsl.utils.mlir_lexer import MLIRLexer
 from xdsl.utils.str_enum import StrEnum
 
 # Used for cyclic dependencies in type hints
@@ -404,7 +404,7 @@ def _check_enum_constraints(
         raise TypeError("Only direct inheritance from EnumAttribute is allowed.")
 
     for v in enum_type:
-        if Lexer.bare_identifier_suffix_regex.fullmatch(v) is None:
+        if MLIRLexer.bare_identifier_suffix_regex.fullmatch(v) is None:
             raise ValueError(
                 "All StrEnum values of an EnumAttribute must be parsable as an identifer."
             )
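The renamed `MLIRLexer.bare_identifier_suffix_regex` is what gates `EnumAttribute` values in the hunk above. A sketch of the check in isolation; the sample values are chosen here for illustration, not taken from the repo:

    from xdsl.utils.mlir_lexer import MLIRLexer

    # Enum values must lex as bare identifiers: letters, digits, '_', '$', '.'.
    assert MLIRLexer.bare_identifier_suffix_regex.fullmatch("my_case.1") is not None
    # '-' is not part of a bare identifier, so this value would be rejected.
    assert MLIRLexer.bare_identifier_suffix_regex.fullmatch("my-case") is None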
14 changes: 7 additions & 7 deletions xdsl/irdl/declarative_assembly_format_parser.py
@@ -80,19 +80,19 @@
 )
 from xdsl.parser import BaseParser, ParserState
 from xdsl.utils.lexer import Input
-from xdsl.utils.mlir_lexer import Kind, Lexer, Token
+from xdsl.utils.mlir_lexer import MLIRLexer, MLIRToken, MLIRTokenKind
 
 
 @dataclass
-class FormatLexer(Lexer):
+class FormatLexer(MLIRLexer):
     """
     A lexer for the declarative assembly format.
     The differences with the MLIR lexer are the following:
     * It can parse '`' or '$' as tokens. The token will have the `BARE_IDENT` kind.
     * Bare identifiers may also may contain `-`.
     """
 
-    def lex(self) -> Token:
+    def lex(self) -> MLIRToken:
         """Lex a token from the input, and returns it."""
         # First, skip whitespaces
         self._consume_whitespace()
@@ -102,13 +102,13 @@ def lex(self) -> Token:
 
         # Handle end of file
         if current_char is None:
-            return self._form_token(Kind.EOF, start_pos)
+            return self._form_token(MLIRTokenKind.EOF, start_pos)
 
         # We parse '`', `\\` and '$' as a BARE_IDENT.
         # This is a hack to reuse the MLIR lexer.
         if current_char in ("`", "$", "\\", "^"):
             self._consume_chars()
-            return self._form_token(Kind.BARE_IDENT, start_pos)
+            return self._form_token(MLIRTokenKind.BARE_IDENT, start_pos)
         return super().lex()
 
     # Authorize `-` in bare identifier
@@ -168,7 +168,7 @@ def parse_format(self) -> FormatProgram:
         unambiguous and refer to all elements exactly once.
         """
         elements: list[FormatDirective] = []
-        while self._current_token.kind != Kind.EOF:
+        while self._current_token.kind != MLIRTokenKind.EOF:
             elements.append(self.parse_format_directive())
 
         self.add_reserved_attrs_to_directive(elements)
@@ -717,7 +717,7 @@ def parse_keyword_or_punctuation(self) -> FormatDirective:
         if self._current_token.kind.is_punctuation():
             punctuation = self._consume_token().text
             self.parse_characters("`")
-            assert Kind.is_spelling_of_punctuation(punctuation)
+            assert MLIRTokenKind.is_spelling_of_punctuation(punctuation)
             return PunctuationDirective(punctuation)
 
         # Identifier case
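To illustrate the `FormatLexer` behavior described in its docstring, a minimal sketch. It assumes `FormatLexer` is importable from this module and constructed like its `MLIRLexer` base; the format string is an arbitrary example:

    from xdsl.irdl.declarative_assembly_format_parser import FormatLexer
    from xdsl.utils.lexer import Input
    from xdsl.utils.mlir_lexer import MLIRTokenKind

    lexer = FormatLexer(Input("`(` $args `)`", "<format>"))
    token = lexer.lex()
    # Unlike the plain MLIR lexer, '`' is lexed as a token of kind BARE_IDENT.
    assert token.kind is MLIRTokenKind.BARE_IDENT
    assert token.text == "`"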
(The remaining 7 changed files are not shown in this view.)
