diff --git a/src/larktools/ebnf_grammar.py b/src/larktools/ebnf_grammar.py index 8432ceb..7227a60 100644 --- a/src/larktools/ebnf_grammar.py +++ b/src/larktools/ebnf_grammar.py @@ -17,7 +17,9 @@ assign_var: VARNAME "=" arith_expr + variable: VARNAME ("[" INDEX "]")* VARNAME: LETTER (LETTER | DIGIT)* + INDEX: INT // Adopted from the calculator example at // https://lark-parser.readthedocs.io/en/stable/examples/calc.html @@ -33,7 +35,7 @@ multiplication: product "*" atom division: product "/" atom - atom: INT | VARNAME | neg_atom | bracketed_arith_expr + atom: INT | variable | neg_atom | bracketed_arith_expr neg_atom: "-" atom bracketed_arith_expr: "(" arith_expr ")" diff --git a/src/larktools/evaluation.py b/src/larktools/evaluation.py index d480b68..eceec1c 100644 --- a/src/larktools/evaluation.py +++ b/src/larktools/evaluation.py @@ -85,9 +85,8 @@ def eval_atom(node, env): child_name = get_name(child) if child_name == "INT": return int(get_value(child)) - elif child_name == "VARNAME": - varname = get_value(child) - return env[varname] + elif child_name == "variable": + return eval_variable(child, env) elif child_name == "neg_atom": return eval_neg_atom(child, env) elif child_name == "bracketed_arith_expr": @@ -109,3 +108,16 @@ def eval_bracketed_arith_expr(node, env): child = get_children(node)[0] assert get_name(child) == "arith_expr" return eval_arith_expr(child, env) + + +def eval_variable(node, env): + children = get_children(node) + assert get_name(children[0]) == "VARNAME" + varname = get_value(children[0]) + value = env[varname] + if len(children) > 1: + for ch in children[1:]: + assert get_name(ch) == "INDEX" + idx = int(get_value(ch)) + value = value[idx] + return value diff --git a/tests/test_arithmetic.py b/tests/test_arithmetic.py index 8ceda89..14bf12d 100644 --- a/tests/test_arithmetic.py +++ b/tests/test_arithmetic.py @@ -12,15 +12,15 @@ def __init__(self): self.parser = Lark(grammar, parser="lalr", start="arith_expr") self.parse = self.parser.parse - def parse_and_eval(self, expression: str, env: Optional[Union[None, dict]] = None) -> Union[int, float]: + def parse_and_eval(self, expression: str, env: Optional[dict] = None) -> Union[int, float]: tree = self.parse(expression) res = eval_arith_expr(tree, {} if env is None else env) return res -def _parse_and_assert(expression: str, expected: Union[int, float]) -> None: +def _parse_and_assert(expression: str, expected: Union[int, float], env: Optional[dict] = None) -> None: parser = ArithParser() - res = parser.parse_and_eval(expression) + res = parser.parse_and_eval(expression, env) assert expected == res def test_integer_addition(): @@ -38,3 +38,21 @@ def test_float_addition(): _parse_and_assert("3.00000001 + 5.2", 8.20000001) _parse_and_assert("5e3 + 1.23E-2", 5000.00123) + +def test_simple_variable_evaluation(): + _parse_and_assert("x", 10, {"x": 10}) + + +def test_array_variable_evaluation(): + array = [i for i in range(5)] + _parse_and_assert("x[2]", 2, {"x": array}) + + +def test_nested_array_evaluation(): + nested_array = [1, [2, 3], 4] + _parse_and_assert("x[1][0]", 2, {"x": nested_array}) + + +def test_nested_array_in_arith_expr(): + nested_array = [1, [2, 3], 4] + _parse_and_assert("(x[1][0] + 2)/2", 2, {"x": nested_array})