diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..d5d90d1 --- /dev/null +++ b/Makefile @@ -0,0 +1,18 @@ +.PHONY: clean watch watch-grammar + +grammar: grammar/ev_lex.g4 grammar/ev_parse.g4 + antlr -Dlanguage=Python3 grammar/ev_lex.g4 grammar/ev_parse.g4 \ + -o evlang/antlr -lib grammar -Xexact-output-dir + +watch-grammar: + make watch WATCHMAKE=grammar + +watch: + while true; do \ + make $(WATCHMAKE); \ + fswatch -1 grammar/ev_lex.g4 grammar/ev_parse.g4; \ + done + + +clean: + rm -rf build diff --git a/evlang/__init__.py b/evlang/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/evlang/ast.py b/evlang/ast.py new file mode 100644 index 0000000..dd066ff --- /dev/null +++ b/evlang/ast.py @@ -0,0 +1,99 @@ +from __future__ import annotations +from dataclasses import dataclass +from enum import Enum +from typing import ( + List, +) + + +@dataclass +class SourceInfo: + startpos: int + endpos: int + + +@dataclass +class Literal: + class Kind(Enum): + STRING = "string" + DECIMAL_INT = "decimal_int" + HEX_INT = "hex_int" + BIN_INT = "bin_int" + DECIMAL_FLOAT = "decimal_float" + IDENTIFIER = "identifier" + + srcinfo: SourceInfo + kind: Kind + text: str + + def is_int(self) -> bool: + return self.kind in ( + Literal.Kind.DECIMAL_INT, + Literal.Kind.HEX_INT, + Literal.Kind.BIN_INT, + ) + + @staticmethod + def String(srcinfo: SourceInfo, text: str) -> Literal: + return Literal(srcinfo, Literal.Kind.STRING, text) + + @staticmethod + def DecimalInt(srcinfo: SourceInfo, text: str) -> Literal: + return Literal(srcinfo, Literal.Kind.DECIMAL_INT, text) + + @staticmethod + def HexInt(srcinfo: SourceInfo, text: str) -> Literal: + return Literal(srcinfo, Literal.Kind.HEX_INT, text) + + @staticmethod + def BinInt(srcinfo: SourceInfo, text: str) -> Literal: + return Literal(srcinfo, Literal.Kind.BIN_INT, text) + + @staticmethod + def DecimalFloat(srcinfo: SourceInfo, text: str) -> Literal: + return Literal(srcinfo, Literal.Kind.DECIMAL_FLOAT, text) + + @staticmethod + def Identifier(srcinfo: SourceInfo, text: str) -> Literal: + return Literal(srcinfo, Literal.Kind.IDENTIFIER, text) + + +@dataclass +class Operator: + srcinfo: SourceInfo + text: str + + +@dataclass +class Unary: + srcinfo: SourceInfo + operator: Operator + operand: Operand + + +Operand = Literal | Unary + + +@dataclass +class Call: + srcinfo: SourceInfo + expression: Expression + arguments: List[Expression] + + +@dataclass +class LabeledStatement: + srcinfo: SourceInfo + label: str + statement: SimpleStatement + + +Expression = Operand | Unary | Call +SimpleStatement = Expression +Statement = SimpleStatement | LabeledStatement + + +@dataclass +class Program: + srcinfo: SourceInfo + statements: List[Statement] diff --git a/evlang/parse.py b/evlang/parse.py new file mode 100644 index 0000000..5467151 --- /dev/null +++ b/evlang/parse.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import antlr4 +from antlr4.tree.Tree import INVALID_INTERVAL + +from evlang.antlr.ev_lex import ev_lex +from evlang.antlr.ev_parse import ev_parse +from evlang.antlr.ev_parseListener import ev_parseListener +from evlang.partial import partial, snake +from evlang.ast import ( + Call, + Expression, + LabeledStatement, + Literal, + Operator, + Program, + SourceInfo, + Statement, + Unary, +) + + +class Listener(ev_parseListener): + def __init__(self): + self.partial_program = partial(Program) + self.stack = [self.partial_program] + + def getsrcinfo(self, ctx): + interval = ctx.getSourceInterval() + if interval is INVALID_INTERVAL: + return None + return SourceInfo(startpos=interval[0], endpos=interval[1]) + + def push(self, args): + self.stack.append(args) + + def call(self, *args, **kwargs): + i = 0 + while self.stack and args or len(self.stack) >= 2 and i == 0: + try: + ret = self.stack[-1](*args, **kwargs) + except StopIteration: + if len(self.stack) == 1: + return + ret = self.stack[-1]() + else: + if ( + i != 0 + or ret == self.stack[-1] + or ret == None + or len(self.stack) == 1 + ): + return + self.stack.pop() + args, kwargs = (ret,), {} + i += 1 + + def enterProgram(self, ctx: ev_parse.ProgramContext): + self.call(self.getsrcinfo(ctx)) + + def enterStatementList(self, _): + self.push(snake(Statement)) + + def exitStatementList(self, _): + self.call() + + def enterLabeledStatement(self, ctx: ev_parse.LabeledStatementContext): + self.push( + partial(LabeledStatement)( + self.getsrcinfo(ctx), + ctx.label().IDENTIFIER().getText(), + ) + ) + + def enterOperatorUnary(self, ctx: ev_parse.OperatorUnaryContext): + self.push( + partial(Unary)( + self.getsrcinfo(ctx), + Operator( + self.getsrcinfo(ctx.OPERATOR()), + ctx.OPERATOR().getText(), + ), + ) + ) + + def enterExpressionList(self, _): + self.push(snake(Expression)) + + def exitExpressionList(self, _): + self.call() + + def enterCallExpression(self, ctx: ev_parse.CallExpressionContext): + self.push(partial(Call)(self.getsrcinfo(ctx))) + + def enterIdentifierOperand(self, ctx: ev_parse.IdentifierOperandContext): + self.call(Literal.Identifier(self.getsrcinfo(ctx), ctx.IDENTIFIER().getText())) + + def enterLiteral(self, ctx: ev_parse.LiteralContext): + self.call( + { + ev_parse.STRING_LIT: Literal.String, + ev_parse.DECIMAL_INT_LIT: Literal.DecimalInt, + ev_parse.HEX_INT_LIT: Literal.HexInt, + ev_parse.BIN_INT_LIT: Literal.BinInt, + ev_parse.DECIMAL_FLOAT_LIT: Literal.DecimalFloat, + }[ctx.children[0].symbol.type]( + self.getsrcinfo(ctx), + ctx.children[0].getText(), + ) + ) + + +def parse(src: str): + input_stream = antlr4.InputStream(src) + lexer = ev_lex(input_stream) + stream = antlr4.CommonTokenStream(lexer) + parser = ev_parse(stream) + tree = parser.program() + listener = Listener() + walker = antlr4.ParseTreeWalker() + walker.walk(listener, tree) + return listener.partial_program() diff --git a/evlang/partial.py b/evlang/partial.py new file mode 100644 index 0000000..13cfa06 --- /dev/null +++ b/evlang/partial.py @@ -0,0 +1,80 @@ +from __future__ import annotations +from typing import Callable, Generic, List, Type, TypeVar, overload +from inspect import Signature + +F = TypeVar("F", bound=Callable) +T = TypeVar("T") + + +class partial(Generic[F]): + def __init__(self, func: F, eager: bool = True, callcount: int | None = None): + self._sig = Signature.from_callable(func) + self._binding = self._sig.bind_partial() + self._emptycall = False + self._kwonly = False + self._func = func + self._eager = eager + self._callcount = callcount + self._calls = 0 + + @overload + def __call__(self: partial[Callable[..., T]]) -> T: + pass + + @overload + def __call__( + self: partial[Callable[..., T]], *args, **kwargs + ) -> T | partial[Callable[..., T]] | None: + pass + + def __call__( + self: partial[Callable[..., T]], *args, **kwargs + ) -> T | partial[Callable[..., T]] | None: + if args and self._kwonly: + raise TypeError("positional argument following keyword argument") + if kwargs: + self._kwonly = True + if not args and not kwargs: + if self._callcount is not None and self._calls < self._callcount: + raise TypeError("empty call before call count reached") + if self._emptycall: + raise TypeError("more than one empty call") + self._emptycall = True + binding = self._sig.bind(*self._binding.args, **self._binding.kwargs) + return self._func(*binding.args, **binding.kwargs) + if self._emptycall: + raise TypeError("cannot apply arguments after empty call") + if self._callcount is not None and self._calls == self._callcount: + raise TypeError("cannot apply arguments after empty call") + self._calls += 1 + partial_binding = self._sig.bind_partial( + *(self._binding.args + args), **({} | self._binding.kwargs | kwargs) + ) + try: + binding = self._sig.bind( + *(self._binding.args + args), **({} | self._binding.kwargs | kwargs) + ) + except TypeError: + self._binding = partial_binding + else: + self._binding = binding + if self._eager: + raise StopIteration() + if self._callcount is not None and self._calls == self._callcount: + raise StopIteration() + return self + + +E = TypeVar("E") + + +def snake(_: Type[E], capacity=None) -> partial[Callable[..., List[E]]]: + def args(*args: E) -> List[E]: + return list(args) + + return partial(args, eager=False, callcount=capacity) + + +s = snake(int) +s1 = s(1) +s1 = s() diff --git a/evlang/partial_test.py b/evlang/partial_test.py new file mode 100644 index 0000000..d9a19ab --- /dev/null +++ b/evlang/partial_test.py @@ -0,0 +1,55 @@ +from pytest import raises, mark +from evlang.partial import snake, partial +from evlang.ast import Literal, SourceInfo + + +@mark.parametrize( + ("stack", "f", "expected"), + [ + ( + partial(Literal), + lambda lit: lit(SourceInfo(0, 0))(Literal.Kind.STRING)("foo"), + Literal(SourceInfo(0, 0), Literal.Kind.STRING, "foo"), + ), + ( + partial(Literal), + lambda lit: lit(text="foo")(srcinfo=SourceInfo(0, 0))( + kind=Literal.Kind.STRING + ), + Literal(SourceInfo(0, 0), Literal.Kind.STRING, "foo"), + ), + ( + partial(Literal), + lambda lit: lit(text="foo")(text="foo")(srcinfo=SourceInfo(0, 0))( + kind=Literal.Kind.STRING + ), + Literal(SourceInfo(0, 0), Literal.Kind.STRING, "foo"), + ), + (snake(int, 3), lambda snek: snek(1)(2)(3), [1, 2, 3]), + ], +) +def test_stacks(stack, f, expected): + with raises(StopIteration): + f(stack) + assert expected == stack() + + +@mark.parametrize( + ("stack", "f", "ecls"), + [ + ( + partial(Literal), + lambda lit: lit(SourceInfo(0, 0))(kind=Literal.Kind.STRING)("foo"), + TypeError, + ), + ( + partial(Literal), + lambda lit: lit(), + TypeError, + ), + (snake(int, capacity=3), lambda snek: snek(1)(), TypeError), + ], +) +def test_stacks_raises(stack, f, ecls): + with raises(ecls): + f(stack) diff --git a/evlang/test_parse.py b/evlang/test_parse.py new file mode 100644 index 0000000..2ee0022 --- /dev/null +++ b/evlang/test_parse.py @@ -0,0 +1,149 @@ +from evlang.ast import ( + Call, + LabeledStatement, + Literal, + Operator, + Program, + SourceInfo, + Unary, +) +from evlang.parse import parse +import orjson +from pytest import mark, fixture +from pytest_mock import MockerFixture + +SI = SourceInfo(0, 0) + + +@fixture +def mock_sourceinfo(mocker: MockerFixture): + mocker.patch("evlang.parse.Listener.getsrcinfo", return_value=SI) + return mocker + + +@mark.parametrize( + ("src", "expected"), + [ + ( + "hello: world()", + Program( + SI, + [ + LabeledStatement( + SI, + "hello", + Call( + SI, + Literal.Identifier(SI, "world"), + [], + ), + ) + ], + ), + ), + ( + """ + world(bar(), foo())() + """, + Program( + SI, + [ + Call( + SI, + Call( + SI, + Literal.Identifier(SI, "world"), + [ + Call(SI, Literal.Identifier(SI, "bar"), []), + Call(SI, Literal.Identifier(SI, "foo"), []), + ], + ), + [], + ), + ], + ), + ), + ( + """ + 0x100001 + 0b0001 + -2.999e-100 + +1234 + -0xFFF + """, + Program( + SI, + [ + Literal.HexInt(SI, "0x100001"), + Literal.BinInt(SI, "0b0001"), + Unary( + SI, Operator(SI, "-"), Literal.DecimalFloat(SI, "2.999e-100") + ), + Unary(SI, Operator(SI, "+"), Literal.DecimalInt(SI, "1234")), + Unary(SI, Operator(SI, "-"), Literal.HexInt(SI, "0xFFF")), + ], + ), + ), + ( + """ + ev_z80r0101_obj00: + _TALKMSG('dp_scenario2%00-some-locale-label') + _END() + ev_z80r0101_obj01: + _GET_POKETCH() + _ADD_POKEMON(181, 50, 0, @247) + """, + Program( + SI, + [ + LabeledStatement( + SI, + "ev_z80r0101_obj00", + Call( + SI, + Literal.Identifier( + SI, + "_TALKMSG", + ), + [Literal.String(SI, "'dp_scenario2%00-some-locale-label'")], + ), + ), + Call( + SI, + Literal.Identifier(SI, "_END"), + [], + ), + LabeledStatement( + SI, + "ev_z80r0101_obj01", + Call( + SI, + Literal.Identifier( + SI, + "_GET_POKETCH", + ), + [], + ), + ), + Call( + SI, + Literal.Identifier(SI, "_ADD_POKEMON"), + [ + Literal.DecimalInt(SI, "181"), + Literal.DecimalInt(SI, "50"), + Literal.DecimalInt(SI, "0"), + Unary( + SI, + Operator(SI, "@"), + Literal.DecimalInt(SI, "247"), + ), + ], + ), + ], + ), + ), + ], +) +def test_parse(mock_sourceinfo, src: str, expected: Program): + program = parse(src) + assert expected == program diff --git a/grammar/ev_lex.g4 b/grammar/ev_lex.g4 new file mode 100644 index 0000000..cbcd8ee --- /dev/null +++ b/grammar/ev_lex.g4 @@ -0,0 +1,32 @@ +lexer grammar ev_lex; + +channels { + COMMENTS +} + +IDENTIFIER: '_'? [a-zA-Z] [_a-zA-Z0-9]*; + +STRING_LIT: '\u0027' ~'\u0027'* '\u0027'; + +HEX_INT_LIT: '0x' [0-9a-fA-F]+; +BIN_INT_LIT: '0b' [01]+; +DECIMAL_INT_LIT: [0-9]+ DECIMAL_EXPONENT?; +DECIMAL_EXPONENT: [eE] [0-9]+; + +DECIMAL_FLOAT_LIT: [0-9]+ '.' [0-9]* (DECIMAL_FLOAT_EXPONENT)?; +DECIMAL_FLOAT_EXPONENT: [eE] [+-]? [0-9]+; + +OPERATOR: PLUS | MINUS | AT | HASHTAG | DOLLAR | EXCLAM; + +COLON: ':'; +COMMA: ','; +PAREN_LEFT: '('; +PAREN_RIGHT: ')'; +EOL: [\r\n]+; +WS: [ \t] -> skip; +PLUS: '+'; +MINUS: '-'; +AT: '@'; +HASHTAG: '#'; +DOLLAR: '$'; +EXCLAM: '!'; \ No newline at end of file diff --git a/grammar/ev_parse.g4 b/grammar/ev_parse.g4 new file mode 100644 index 0000000..333432b --- /dev/null +++ b/grammar/ev_parse.g4 @@ -0,0 +1,28 @@ +parser grammar ev_parse; + +options { + tokenVocab = ev_lex; +} + +program: EOL* statementList EOL* EOF; +statementList: statement (EOL statement)* |; +statement: + expression # ExpressionStatement + | label expression # LabeledStatement; +expression: + operand # OperandExpression + | unary # UnaryExpression + | expression arguments # CallExpression; +unary: operand # OperandUnary | OPERATOR unary # OperatorUnary; +operand: + IDENTIFIER # IdentifierOperand + | literal # LiteralOperand; +expressionList: expression (COMMA expression)* |; +arguments: PAREN_LEFT expressionList? PAREN_RIGHT; +label: IDENTIFIER COLON EOL?; +literal: + STRING_LIT + | DECIMAL_INT_LIT + | HEX_INT_LIT + | BIN_INT_LIT + | DECIMAL_FLOAT_LIT; \ No newline at end of file