diff --git a/src/larktools/ebnf_grammar.py b/src/larktools/ebnf_grammar.py index ac90bbc..a46a0ef 100644 --- a/src/larktools/ebnf_grammar.py +++ b/src/larktools/ebnf_grammar.py @@ -13,9 +13,8 @@ // // The top level rule at which the matching/expansion process starts is named "start". - start: assign_var + start: multi_line_block - assign_var: VARNAME "=" multi_line_block VARNAME: LETTER (LETTER | DIGIT)* @@ -25,9 +24,11 @@ // https://lark-parser.readthedocs.io/en/stable/tree_construction.html - line: arith_expr - multi_line_block: (line _NL? | _NL )* + line: arith_expr | assignment + + assignment: VARNAME "=" arith_expr + arith_expr: sum sum: product | addition | subtraction diff --git a/src/larktools/evaluation.py b/src/larktools/evaluation.py index c22c0d8..22b0c41 100644 --- a/src/larktools/evaluation.py +++ b/src/larktools/evaluation.py @@ -116,6 +116,8 @@ def eval_line(node, env): child_name = get_name(child) if child_name == "arith_expr": return eval_arith_expr(child, env) + elif child_name == "assignment": + return eval_assignment(child, env) def eval_multi_line_block(node, env): # this can be either an arithmetic expression or @@ -127,3 +129,15 @@ def eval_multi_line_block(node, env): res = eval_line(child, env) return res +def eval_assignment(node, env): + # assign result of an expression to a variable + child1, child2 = get_children(node)[0:2] + assert get_name(child1) == "VARNAME" + assert get_name(child2) == "arith_expr" + + varname = get_value(child1) + env[varname] = eval_arith_expr(child2, env) + return env[varname] + + + diff --git a/tests/test_syntax.py b/tests/test_syntax.py index b158516..d43753d 100644 --- a/tests/test_syntax.py +++ b/tests/test_syntax.py @@ -18,9 +18,9 @@ def parse_and_eval(self, expression: str, env: Optional[Union[None, dict]] = Non return res -def _parse_and_assert(expression: str, expected: Union[int, float]) -> None: +def _parse_and_assert(expression: str, expected: Union[int, float], env: Optional[Union[None, dict]] = None) -> None: parser = SyntaxParser() - res = parser.parse_and_eval(expression) + res = parser.parse_and_eval(expression, env) assert expected == res def test_multi_line(): @@ -30,3 +30,24 @@ def test_multi_line(): _parse_and_assert("5+5\n3+4\n1+2", 3) _parse_and_assert("\n\n5\n\n3\n8", 8) + +def test_assignment(): + _parse_and_assert("a=5", 5) + _parse_and_assert("z=1+2+3", 6) + _parse_and_assert("y=(1+2+3)", 6) + +def test_assignment_env_variable(): + # check env variables is set + env = {"a":1} + _parse_and_assert("a=3", 3, env=env) + assert env["a"] == 3 + + _parse_and_assert("y = x + 3", 20, env={"x":17, "i":123}) + _parse_and_assert("y = x + i", 20, env={"x":17, "i":3}) + +def test_assign_multiline(): + _parse_and_assert("x=3 \n y=4 \n z=x+y",7) + _parse_and_assert("x=1 \n z = x + y", 3, env={"y":2}) + + +