Skip to content

Commit

Permalink
documentation: (Toy) use generic parser in Toy
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Dec 8, 2024
1 parent 189b337 commit 2e83bc4
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 81 deletions.
2 changes: 1 addition & 1 deletion docs/Toy/toy/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
ToyAcceleratorInstructionFunctions,
)
from .frontend.ir_gen import IRGen
from .frontend.parser import Parser as ToyParser
from .frontend.parser import ToyParser as ToyParser
from .interpreter import Interpreter, ToyFunctions

parser = argparse.ArgumentParser(description="Process Toy file")
Expand Down
4 changes: 2 additions & 2 deletions docs/Toy/toy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

from .dialects import toy
from .frontend.ir_gen import IRGen
from .frontend.parser import Parser
from .frontend.parser import ToyParser
from .rewrites.inline_toy import InlineToyPass
from .rewrites.lower_toy_affine import LowerToAffinePass
from .rewrites.shape_inference import ShapeInferencePass
Expand All @@ -58,7 +58,7 @@ def context() -> MLContext:

def parse_toy(program: str, ctx: MLContext | None = None) -> ModuleOp:
mlir_gen = IRGen()
module_ast = Parser(Path("in_memory"), program).parseModule()
module_ast = ToyParser(Path("in_memory"), program).parseModule()
module_op = mlir_gen.ir_gen_module(module_ast)
return module_op

Expand Down
18 changes: 1 addition & 17 deletions docs/Toy/toy/frontend/lexer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import re
from enum import Enum, auto
from pathlib import Path
from string import hexdigits
from typing import TypeAlias, cast

from xdsl.utils.exceptions import ParseError
from xdsl.utils.lexer import Input, Lexer, Position, Span, Token
from xdsl.utils.lexer import Lexer, Position, Span, Token


class ToyTokenKind(Enum):
Expand Down Expand Up @@ -179,18 +178,3 @@ def _lex_number(self, start_pos: Position) -> ToyToken:
if match is not None:
return self._form_token(ToyTokenKind.NUMBER, start_pos)
return self._form_token(ToyTokenKind.NUMBER, start_pos)


def tokenize(file: Path, program: str | None = None):
if program is None:
with open(file) as f:
program = f.read()

toy_lexer = ToyLexer(Input(program, str(file)))

tokens = [toy_lexer.lex()]

while tokens[-1].kind != ToyTokenKind.EOF:
tokens.append(toy_lexer.lex())

return tokens
79 changes: 26 additions & 53 deletions docs/Toy/toy/frontend/parser.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from pathlib import Path
from typing import NoReturn, cast
from typing import cast

from xdsl.utils.exceptions import ParseError
from xdsl.parser import GenericParser, ParserState
from xdsl.utils.lexer import Input

from .lexer import (
ToyToken,
ToyTokenKind,
tokenize,
)
from .lexer import ToyLexer, ToyToken, ToyTokenKind
from .location import loc
from .toy_ast import (
BinaryExprAST,
Expand All @@ -26,21 +23,13 @@
)


class Parser:
file: Path
program: str
tokens: list[ToyToken]
pos: int

class ToyParser(GenericParser[ToyTokenKind]):
def __init__(self, file: Path, program: str):
self.file = file
self.program = program
self.tokens = tokenize(file, program)
self.pos = 0
super().__init__(ParserState(ToyLexer(Input(program, str(file)))))

def getToken(self):
"""Returns current token in parser"""
return self.tokens[self.pos]
return self._current_token

def getTokenPrecedence(self) -> int:
"""Returns precedence if the current token is a binary operation, -1 otherwise"""
Expand All @@ -49,15 +38,15 @@ def getTokenPrecedence(self) -> int:
"+": 20,
"*": 40,
}
op = self.getToken().text
op = self._current_token.text

return PRECEDENCE.get(op, -1)

def peek(self, pattern: str | ToyTokenKind) -> ToyToken | None:
"""
Returns token matching pattern or None
"""
token = self.getToken()
token = self._current_token

if isinstance(pattern, str):
if token.text == pattern:
Expand All @@ -74,30 +63,24 @@ def check(self, pattern: str | ToyTokenKind) -> bool:
return self.peek(pattern) is not None

def pop(self) -> ToyToken:
self.pos += 1
return self.tokens[self.pos - 1]
return self._consume_token()

def pop_pattern(self, pattern: str) -> ToyToken:
"""
Verifies that the current token fits the pattern,
raises ParseError otherwise
"""
token = self.peek(pattern)
if token is None:
self.parseError(f"'{pattern}'")
self.pos += 1
token = self._consume_token()
if token.text != pattern:
self.raise_error(f"Expected '{pattern}'", token.span.start, token.span.end)
return token

def pop_token(self, tokenType: ToyTokenKind) -> ToyToken:
"""
Verifies that the current token is of expected type,
raises ParseError otherwise
"""
token = self.peek(tokenType)
if token is None:
self.parseError(tokenType)
self.pos += 1
return token
return self._consume_token(tokenType)

def parseModule(self):
"""
Expand Down Expand Up @@ -157,7 +140,7 @@ def parseTensorLiteralExpr(self) -> LiteralExprAST | NumberExprAST:
values.append(self.parseTensorLiteralExpr())
else:
if not self.check(ToyTokenKind.NUMBER):
self.parseError("<num> or [", "in literal expression")
self.raise_error("Expected <num> or [ in literal expression")
values.append(self.parseNumberExpr())

# End of this list on ']'
Expand All @@ -177,16 +160,16 @@ def parseTensorLiteralExpr(self) -> LiteralExprAST | NumberExprAST:
if any(type(val) is LiteralExprAST for val in values):
allTensors = all(type(val) is LiteralExprAST for val in values)
if not allTensors:
self.parseError(
"uniform well-nested dimensions", "inside literal expression"
self.raise_error(
"Expected uniform well-nested dimensions inside literal expression"
)

tensor_values = cast(list[LiteralExprAST], values)
first = tensor_values[0].dims
allEqual = all(val.dims == first for val in tensor_values)
if not allEqual:
self.parseError(
"uniform well-nested dimensions", "inside literal expression"
self.raise_error(
"Expected uniform well-nested dimensions inside literal expression"
)

dims += first
Expand Down Expand Up @@ -224,7 +207,7 @@ def parseIdentifierExpr(self):
if name.text == "print":
# It can be a builtin call to print
if len(args) != 1:
self.parseError("<single arg>", "as argument to print()")
self.raise_error("Expected <single arg> as argument to print()")

return PrintExprAST(loc(name), args[0])

Expand All @@ -238,7 +221,7 @@ def parsePrimary(self) -> ExprAST | None:
::= parenexpr
::= tensorliteral
"""
current = self.tokens[self.pos]
current = self._current_token
if current.kind == ToyTokenKind.IDENTIFIER:
return self.parseIdentifierExpr()
elif current.kind == ToyTokenKind.NUMBER:
Expand All @@ -252,7 +235,7 @@ def parsePrimary(self) -> ExprAST | None:
elif current.text == "}":
return None
else:
self.parseError("expression or one of `;`, `}`")
self.raise_error("Expected expression or one of `;`, `}`")

def parsePrimaryNotNone(self) -> ExprAST:
"""
Expand All @@ -262,7 +245,7 @@ def parsePrimaryNotNone(self) -> ExprAST:
::= parenexpr
::= tensorliteral
"""
current = self.tokens[self.pos]
current = self._current_token
if current.kind == ToyTokenKind.IDENTIFIER:
return self.parseIdentifierExpr()
elif current.kind == ToyTokenKind.NUMBER:
Expand All @@ -272,7 +255,7 @@ def parsePrimaryNotNone(self) -> ExprAST:
elif current.text == "[":
return self.parseTensorLiteralExpr()
else:
self.parseError("expression")
self.raise_error("Expected expression")

def parseBinOpRHS(self, exprPrec: int, lhs: ExprAST) -> ExprAST:
"""
Expand All @@ -297,7 +280,7 @@ def parseBinOpRHS(self, exprPrec: int, lhs: ExprAST) -> ExprAST:
rhs = self.parsePrimary()

if rhs is None:
self.parseError("expression", "to complete binary operator")
self.raise_error("Expected expression to complete binary operator")

# If BinOp binds less tightly with rhs than the operator after rhs, let
# the pending operator take rhs as its lhs.
Expand All @@ -323,8 +306,7 @@ def parseType(self):

while token := self.pop_token(ToyTokenKind.NUMBER):
shape.append(int(token.span.text))
if self.check(">"):
self.pop()
if self.parse_optional_characters(">"):
break
self.pop_pattern(",")

Expand Down Expand Up @@ -420,12 +402,3 @@ def parseDefinition(self):
proto = self.parsePrototype()
block = self.parseBlock()
return FunctionAST(proto.loc, proto, block)

def parseError(self, expected: str | ToyTokenKind, context: str = "") -> NoReturn:
"""
Helper function to signal errors while parsing, it takes an argument
indicating the expected token and another argument giving more context.
Location is retrieved from the lexer to enrich the error message.
"""
token = self.getToken()
raise ParseError(token.span, context)
6 changes: 3 additions & 3 deletions docs/Toy/toy/tests/test_ir_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

from ..dialects import toy
from ..frontend.ir_gen import IRGen
from ..frontend.parser import Parser
from ..frontend.parser import ToyParser


def test_convert_ast():
ast_toy = Path("docs/Toy/examples/ast.toy")

with open(ast_toy) as f:
parser = Parser(ast_toy, f.read())
parser = ToyParser(ast_toy, f.read())

module_ast = parser.parseModule()

Expand Down Expand Up @@ -63,7 +63,7 @@ def test_convert_scalar():
scalar_toy = Path("docs/Toy/examples/scalar.toy")

with open(scalar_toy) as f:
parser = Parser(scalar_toy, f.read())
parser = ToyParser(scalar_toy, f.read())

module_ast = parser.parseModule()

Expand Down
12 changes: 7 additions & 5 deletions docs/Toy/toy/tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import pytest

from xdsl.utils.exceptions import ParseError

from ..frontend.location import Location
from ..frontend.parser import ParseError, Parser
from ..frontend.parser import ToyParser
from ..frontend.toy_ast import (
BinaryExprAST,
CallExprAST,
Expand All @@ -24,7 +26,7 @@ def test_parse_ast():
ast_toy = Path("docs/Toy/examples/ast.toy")

with open(ast_toy) as f:
parser = Parser(ast_toy, f.read())
parser = ToyParser(ast_toy, f.read())

parsed_module_ast = parser.parseModule()

Expand Down Expand Up @@ -179,16 +181,16 @@ def loc(line: int, col: int) -> Location:

def test_parse_error():
program = "def("
parser = Parser(Path(), program)
with pytest.raises(ParseError):
parser = ToyParser(Path(), program)
with pytest.raises(ParseError, match="Expected expression"):
parser.parseIdentifierExpr()


def test_parse_scalar():
ast_toy = Path("docs/Toy/examples/scalar.toy")

with open(ast_toy) as f:
parser = Parser(ast_toy, f.read())
parser = ToyParser(ast_toy, f.read())

parsed_module_ast = parser.parseModule()

Expand Down

0 comments on commit 2e83bc4

Please sign in to comment.