diff --git a/docs/Toy/toy/__main__.py b/docs/Toy/toy/__main__.py index 8ea2e45063..0626dedb44 100644 --- a/docs/Toy/toy/__main__.py +++ b/docs/Toy/toy/__main__.py @@ -22,7 +22,7 @@ ToyAcceleratorInstructionFunctions, ) from .frontend.ir_gen import IRGen -from .frontend.parser import Parser as ToyParser +from .frontend.parser import ToyParser as ToyParser from .interpreter import Interpreter, ToyFunctions parser = argparse.ArgumentParser(description="Process Toy file") diff --git a/docs/Toy/toy/compiler.py b/docs/Toy/toy/compiler.py index 4923ac27a1..4729dee117 100644 --- a/docs/Toy/toy/compiler.py +++ b/docs/Toy/toy/compiler.py @@ -35,7 +35,7 @@ from .dialects import toy from .frontend.ir_gen import IRGen -from .frontend.parser import Parser +from .frontend.parser import ToyParser from .rewrites.inline_toy import InlineToyPass from .rewrites.lower_toy_affine import LowerToAffinePass from .rewrites.shape_inference import ShapeInferencePass @@ -58,7 +58,7 @@ def context() -> MLContext: def parse_toy(program: str, ctx: MLContext | None = None) -> ModuleOp: mlir_gen = IRGen() - module_ast = Parser(Path("in_memory"), program).parseModule() + module_ast = ToyParser(Path("in_memory"), program).parseModule() module_op = mlir_gen.ir_gen_module(module_ast) return module_op diff --git a/docs/Toy/toy/frontend/lexer.py b/docs/Toy/toy/frontend/lexer.py index e33d7bd670..c814e9face 100644 --- a/docs/Toy/toy/frontend/lexer.py +++ b/docs/Toy/toy/frontend/lexer.py @@ -1,11 +1,10 @@ import re from enum import Enum, auto -from pathlib import Path from string import hexdigits from typing import TypeAlias, cast from xdsl.utils.exceptions import ParseError -from xdsl.utils.lexer import Input, Lexer, Position, Span, Token +from xdsl.utils.lexer import Lexer, Position, Span, Token class ToyTokenKind(Enum): @@ -179,18 +178,3 @@ def _lex_number(self, start_pos: Position) -> ToyToken: if match is not None: return self._form_token(ToyTokenKind.NUMBER, start_pos) return self._form_token(ToyTokenKind.NUMBER, start_pos) - - -def tokenize(file: Path, program: str | None = None): - if program is None: - with open(file) as f: - program = f.read() - - toy_lexer = ToyLexer(Input(program, str(file))) - - tokens = [toy_lexer.lex()] - - while tokens[-1].kind != ToyTokenKind.EOF: - tokens.append(toy_lexer.lex()) - - return tokens diff --git a/docs/Toy/toy/frontend/parser.py b/docs/Toy/toy/frontend/parser.py index 983ecda24e..058ac2bced 100644 --- a/docs/Toy/toy/frontend/parser.py +++ b/docs/Toy/toy/frontend/parser.py @@ -1,13 +1,10 @@ from pathlib import Path -from typing import NoReturn, cast +from typing import cast -from xdsl.utils.exceptions import ParseError +from xdsl.parser import GenericParser, ParserState +from xdsl.utils.lexer import Input -from .lexer import ( - ToyToken, - ToyTokenKind, - tokenize, -) +from .lexer import ToyLexer, ToyToken, ToyTokenKind from .location import loc from .toy_ast import ( BinaryExprAST, @@ -26,21 +23,13 @@ ) -class Parser: - file: Path - program: str - tokens: list[ToyToken] - pos: int - +class ToyParser(GenericParser[ToyTokenKind]): def __init__(self, file: Path, program: str): - self.file = file - self.program = program - self.tokens = tokenize(file, program) - self.pos = 0 + super().__init__(ParserState(ToyLexer(Input(program, str(file))))) def getToken(self): """Returns current token in parser""" - return self.tokens[self.pos] + return self._current_token def getTokenPrecedence(self) -> int: """Returns precedence if the current token is a binary operation, -1 otherwise""" @@ -49,7 +38,7 @@ def getTokenPrecedence(self) -> int: "+": 20, "*": 40, } - op = self.getToken().text + op = self._current_token.text return PRECEDENCE.get(op, -1) @@ -57,7 +46,7 @@ def peek(self, pattern: str | ToyTokenKind) -> ToyToken | None: """ Returns token matching pattern or None """ - token = self.getToken() + token = self._current_token if isinstance(pattern, str): if token.text == pattern: @@ -74,18 +63,16 @@ def check(self, pattern: str | ToyTokenKind) -> bool: return self.peek(pattern) is not None def pop(self) -> ToyToken: - self.pos += 1 - return self.tokens[self.pos - 1] + return self._consume_token() def pop_pattern(self, pattern: str) -> ToyToken: """ Verifies that the current token fits the pattern, raises ParseError otherwise """ - token = self.peek(pattern) - if token is None: - self.parseError(f"'{pattern}'") - self.pos += 1 + token = self._consume_token() + if token.text != pattern: + self.raise_error(f"Expected '{pattern}'", token.span.start, token.span.end) return token def pop_token(self, tokenType: ToyTokenKind) -> ToyToken: @@ -93,11 +80,7 @@ def pop_token(self, tokenType: ToyTokenKind) -> ToyToken: Verifies that the current token is of expected type, raises ParseError otherwise """ - token = self.peek(tokenType) - if token is None: - self.parseError(tokenType) - self.pos += 1 - return token + return self._consume_token(tokenType) def parseModule(self): """ @@ -157,7 +140,7 @@ def parseTensorLiteralExpr(self) -> LiteralExprAST | NumberExprAST: values.append(self.parseTensorLiteralExpr()) else: if not self.check(ToyTokenKind.NUMBER): - self.parseError(" or [", "in literal expression") + self.raise_error("Expected or [ in literal expression") values.append(self.parseNumberExpr()) # End of this list on ']' @@ -177,16 +160,16 @@ def parseTensorLiteralExpr(self) -> LiteralExprAST | NumberExprAST: if any(type(val) is LiteralExprAST for val in values): allTensors = all(type(val) is LiteralExprAST for val in values) if not allTensors: - self.parseError( - "uniform well-nested dimensions", "inside literal expression" + self.raise_error( + "Expected uniform well-nested dimensions inside literal expression" ) tensor_values = cast(list[LiteralExprAST], values) first = tensor_values[0].dims allEqual = all(val.dims == first for val in tensor_values) if not allEqual: - self.parseError( - "uniform well-nested dimensions", "inside literal expression" + self.raise_error( + "Expected uniform well-nested dimensions inside literal expression" ) dims += first @@ -224,7 +207,7 @@ def parseIdentifierExpr(self): if name.text == "print": # It can be a builtin call to print if len(args) != 1: - self.parseError("", "as argument to print()") + self.raise_error("Expected as argument to print()") return PrintExprAST(loc(name), args[0]) @@ -238,7 +221,7 @@ def parsePrimary(self) -> ExprAST | None: ::= parenexpr ::= tensorliteral """ - current = self.tokens[self.pos] + current = self._current_token if current.kind == ToyTokenKind.IDENTIFIER: return self.parseIdentifierExpr() elif current.kind == ToyTokenKind.NUMBER: @@ -252,7 +235,7 @@ def parsePrimary(self) -> ExprAST | None: elif current.text == "}": return None else: - self.parseError("expression or one of `;`, `}`") + self.raise_error("Expected expression or one of `;`, `}`") def parsePrimaryNotNone(self) -> ExprAST: """ @@ -262,7 +245,7 @@ def parsePrimaryNotNone(self) -> ExprAST: ::= parenexpr ::= tensorliteral """ - current = self.tokens[self.pos] + current = self._current_token if current.kind == ToyTokenKind.IDENTIFIER: return self.parseIdentifierExpr() elif current.kind == ToyTokenKind.NUMBER: @@ -272,7 +255,7 @@ def parsePrimaryNotNone(self) -> ExprAST: elif current.text == "[": return self.parseTensorLiteralExpr() else: - self.parseError("expression") + self.raise_error("Expected expression") def parseBinOpRHS(self, exprPrec: int, lhs: ExprAST) -> ExprAST: """ @@ -297,7 +280,7 @@ def parseBinOpRHS(self, exprPrec: int, lhs: ExprAST) -> ExprAST: rhs = self.parsePrimary() if rhs is None: - self.parseError("expression", "to complete binary operator") + self.raise_error("Expected expression to complete binary operator") # If BinOp binds less tightly with rhs than the operator after rhs, let # the pending operator take rhs as its lhs. @@ -323,8 +306,7 @@ def parseType(self): while token := self.pop_token(ToyTokenKind.NUMBER): shape.append(int(token.span.text)) - if self.check(">"): - self.pop() + if self.parse_optional_characters(">"): break self.pop_pattern(",") @@ -420,12 +402,3 @@ def parseDefinition(self): proto = self.parsePrototype() block = self.parseBlock() return FunctionAST(proto.loc, proto, block) - - def parseError(self, expected: str | ToyTokenKind, context: str = "") -> NoReturn: - """ - Helper function to signal errors while parsing, it takes an argument - indicating the expected token and another argument giving more context. - Location is retrieved from the lexer to enrich the error message. - """ - token = self.getToken() - raise ParseError(token.span, context) diff --git a/docs/Toy/toy/tests/test_ir_gen.py b/docs/Toy/toy/tests/test_ir_gen.py index 9423bbdf1b..0130e6c53a 100644 --- a/docs/Toy/toy/tests/test_ir_gen.py +++ b/docs/Toy/toy/tests/test_ir_gen.py @@ -6,14 +6,14 @@ from ..dialects import toy from ..frontend.ir_gen import IRGen -from ..frontend.parser import Parser +from ..frontend.parser import ToyParser def test_convert_ast(): ast_toy = Path("docs/Toy/examples/ast.toy") with open(ast_toy) as f: - parser = Parser(ast_toy, f.read()) + parser = ToyParser(ast_toy, f.read()) module_ast = parser.parseModule() @@ -63,7 +63,7 @@ def test_convert_scalar(): scalar_toy = Path("docs/Toy/examples/scalar.toy") with open(scalar_toy) as f: - parser = Parser(scalar_toy, f.read()) + parser = ToyParser(scalar_toy, f.read()) module_ast = parser.parseModule() diff --git a/docs/Toy/toy/tests/test_parser.py b/docs/Toy/toy/tests/test_parser.py index 478fb102ed..599ffd9896 100644 --- a/docs/Toy/toy/tests/test_parser.py +++ b/docs/Toy/toy/tests/test_parser.py @@ -2,8 +2,10 @@ import pytest +from xdsl.utils.exceptions import ParseError + from ..frontend.location import Location -from ..frontend.parser import ParseError, Parser +from ..frontend.parser import ToyParser from ..frontend.toy_ast import ( BinaryExprAST, CallExprAST, @@ -24,7 +26,7 @@ def test_parse_ast(): ast_toy = Path("docs/Toy/examples/ast.toy") with open(ast_toy) as f: - parser = Parser(ast_toy, f.read()) + parser = ToyParser(ast_toy, f.read()) parsed_module_ast = parser.parseModule() @@ -179,8 +181,8 @@ def loc(line: int, col: int) -> Location: def test_parse_error(): program = "def(" - parser = Parser(Path(), program) - with pytest.raises(ParseError): + parser = ToyParser(Path(), program) + with pytest.raises(ParseError, match="Expected expression"): parser.parseIdentifierExpr() @@ -188,7 +190,7 @@ def test_parse_scalar(): ast_toy = Path("docs/Toy/examples/scalar.toy") with open(ast_toy) as f: - parser = Parser(ast_toy, f.read()) + parser = ToyParser(ast_toy, f.read()) parsed_module_ast = parser.parseModule()