Skip to content

Commit

Permalink
Merge pull request #10 from CodeVisionaries/refactor_evaluations
Browse files Browse the repository at this point in the history
Refactoring the evaluation implementation to make things more extensible
  • Loading branch information
gschnabel authored Oct 29, 2024
2 parents ada8fee + 905b967 commit 439f0cc
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 151 deletions.
266 changes: 121 additions & 145 deletions src/larktools/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,155 +1,131 @@
from lark import Lark
from typing import Callable

from functools import reduce
from lark import Lark, Token

from .ebnf_grammar import grammar
from .tree_utils import (
is_rule,
is_terminal,
get_name,
get_children,
get_first_child,
get_value,
)


def eval_arith_expr(node, env):
child = get_children(node)[0]
child_name = get_name(child)
assert child_name == "sum"
return eval_sum(child, env)


def eval_sum(node, env):
# we know there is only one child because
# of formal grammar definition
child = get_children(node)[0]
child_name = get_name(child)
if child_name == "product":
return eval_product(child, env)
elif child_name == "addition":
return eval_addition(child, env)
elif child_name == "subtraction":
return eval_subtraction(child, env)


def eval_product(node, env):
child = get_children(node)[0]
child_name = get_name(child)
if child_name == "atom":
return eval_atom(child, env)
elif child_name == "multiplication":
return eval_multiplication(child, env)
elif child_name == "division":
return eval_division(child, env)


def eval_addition(node, env):
child1 = get_children(node)[0]
child2 = get_children(node)[1]
assert get_name(child1) == "sum"
assert get_name(child2) == "product"
res1 = eval_sum(child1, env)
res2 = eval_product(child2, env)
return res1 + res2


def eval_subtraction(node, env):
child1 = get_children(node)[0]
child2 = get_children(node)[1]
assert get_name(child1) == "sum"
assert get_name(child2) == "product"
res1 = eval_sum(child1, env)
res2 = eval_product(child2, env)
return res1 - res2


def eval_multiplication(node, env):
child1 = get_children(node)[0]
child2 = get_children(node)[1]
assert get_name(child1) == "product"
assert get_name(child2) == "atom"
res1 = eval_product(child1, env)
res2 = eval_atom(child2, env)
return res1 * res2


def eval_division(node, env):
child1 = get_children(node)[0]
child2 = get_children(node)[1]
assert get_name(child1) == "product"
assert get_name(child2) == "atom"
res1 = eval_product(child1, env)
res2 = eval_atom(child2, env)
return res1 / res2


def eval_atom(node, env):
child = get_children(node)[0]
child_name = get_name(child)
if child_name == "INT":
return int(get_value(child))
elif child_name == "SIGNED_FLOAT":
return float(get_value(child))
elif child_name == "variable":
return eval_variable(child, env)
elif child_name == "neg_atom":
return eval_neg_atom(child, env)
elif child_name == "bracketed_arith_expr":
return eval_bracketed_arith_expr(child, env)


def eval_neg_atom(node, env):
# the "-" character appearing in the production rule is
# filtered out by lark by default because it is a constant
# character. Therefore, it doesn't appear among the child nodes
child = get_children(node)[0]
assert get_name(child) == "atom"
return (-eval_atom(child, env))


def eval_bracketed_arith_expr(node, env):
# same here, the constant characters "(" and ")"
# are filtered out and don't appear as child nodes
child = get_children(node)[0]
assert get_name(child) == "arith_expr"
return eval_arith_expr(child, env)

def eval_line(node, env):
# this is the content of a single line of input
child = get_children(node)[0]
child_name = get_name(child)
if child_name == "arith_expr":
return eval_arith_expr(child, env)
elif child_name == "assignment":
return eval_assignment(child, env)

def eval_multi_line_block(node, env):
# this can be either an arithmetic expression or
# composed lines
children = get_children(node)
for child in children:
child_name = get_name(child)
assert child_name == "line"
res = eval_line(child, env)
return res


def eval_assignment(node, env):
# assign result of an expression to a variable
child1, child2 = get_children(node)[0:2]
assert get_name(child1) == "VARNAME"
assert get_name(child2) == "arith_expr"

varname = get_value(child1)
env[varname] = eval_arith_expr(child2, env)
return env[varname]


def eval_variable(node, env):
children = get_children(node)
assert get_name(children[0]) == "VARNAME"
varname = get_value(children[0])
value = env[varname]
if len(children) > 1:
for ch in children[1:]:
assert get_name(ch) == "INDEX"
idx = int(get_value(ch))
value = value[idx]
return value
def instantiate_eval_tree(lark_node):
node_name = get_name(lark_node)
# tunnel thorugh dummy nodes, e.g. sum, product
while node_name not in INV_NODE_MAP:
if isinstance(lark_node, Token):
raise AttributeError(
f"`{lark_node.type}` is a terminal node"
"and doesn't have a node class associated with it."
)
if len(lark_node.children) != 1:
raise IndexError(
"Nodes without associated node class "
"must have exactly one child. However, "
f"the node `{get_name(lark_node)}` has "
f"{len(lark_node.children)} children."
)
lark_node = lark_node.children[0]
node_name = get_name(lark_node)

return INV_NODE_MAP[node_name](lark_node)


class RootNode:
def __init__(self, lark_node):
self._children = [
instantiate_eval_tree(n) for n in lark_node.children
]

def __call__(self, env):
results = [n(env) for n in self._children]
return results[-1]


class AssignNode:
def __init__(self, lark_node):
self._varname = get_value(lark_node.children[0])
self._expr = instantiate_eval_tree(lark_node.children[1])

def __call__(self, env):
rhs = self._expr(env)
env[self._varname] = rhs
# TODO: Should the full assignment evaluate to the expression on rhs?
# At the moment necessary to satisy some test assumptions.
return rhs


class VariableNode:
def __init__(self, lark_node):
self._varname = get_value(lark_node.children[0])
self._index_nodes = [
instantiate_eval_tree(n) for n in lark_node.children[1:]
]

def __call__(self, env):
index_values = [n(env) for n in self._index_nodes]
return reduce(
lambda lst, idx: lst[idx], index_values, env[self._varname]
)


class NumberNode:
def __init__(self, lark_node):
node_name = get_name(lark_node)
self._value = {
"SIGNED_FLOAT": float, "INT": int, "INDEX": int,
}[node_name](get_value(lark_node))

def __call__(self, env):
return self._value


class MappedOperatorNode:
def __init__(self, lark_node, op_map):
node_name = get_name(lark_node)
self._children = [
instantiate_eval_tree(n) for n in lark_node.children
]
self._func = op_map[node_name]

def __call__(self, env):
results = [n(env) for n in self._children]
return self._func(results)


class UnaryOperatorNode(MappedOperatorNode):
def __init__(self, lark_node):
super().__init__(
lark_node,
op_map={"neg_atom": lambda x: -x[0]}
)


class BinaryOperatorNode(MappedOperatorNode):
def __init__(self, lark_node):
super().__init__(
lark_node,
op_map = {
"addition": lambda x: sum(x),
"subtraction": lambda x: x[0] - x[1],
"multiplication": lambda x: x[0] * x[1],
"division": lambda x: x[0] / x[1],
}
)


NODE_MAP = {
RootNode: ("multi_line_block",),
AssignNode: ("assignment",),
UnaryOperatorNode: ("neg_atom",),
BinaryOperatorNode: ("addition", "subtraction", "multiplication", "division"),
VariableNode: ("variable", "varname"),
NumberNode: ("INT", "SIGNED_INT", "FLOAT", "SIGNED_FLOAT", "INDEX"),
}

INV_NODE_MAP = {k: v for v in NODE_MAP for k in NODE_MAP[v]}
5 changes: 5 additions & 0 deletions src/larktools/tree_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def get_children(node):
"""Get a list of children from a rule node."""
return node.children

def get_first_child(node):
"""Get the first child from a rule node."""
children = get_children(node)
assert len(children) > 0
return children[0]

def get_value(node):
"""Get the value associated with a terminal node."""
Expand Down
9 changes: 5 additions & 4 deletions tests/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
from lark import Lark

from larktools.ebnf_grammar import grammar
from larktools.evaluation import eval_arith_expr
from larktools.evaluation import instantiate_eval_tree


class ArithParser:
def __init__(self):
self.parser = Lark(grammar, parser="lalr", start="arith_expr")
self.parse = self.parser.parse

def parse_and_eval(self, expression: str, env: Optional[dict] = None) -> Union[int, float]:
tree = self.parse(expression)
res = eval_arith_expr(tree, {} if env is None else env)
tree = self.parser.parse(expression)
eval_tree = instantiate_eval_tree(tree)
res = eval_tree({} if env is None else env)
return res


Expand All @@ -30,6 +30,7 @@ def _parse_and_assert_collection(tests: list[str, Union[int, float]]) -> None:
def test_integer_addition():
_parse_and_assert("3 + 5", 8)
_parse_and_assert("5 + 3", 8)
_parse_and_assert("5 + 3 + 1", 9)
_parse_and_assert("9999999999999999 + 555555555555555", 10555555555555554)

def test_integer_addition_neg():
Expand Down
5 changes: 3 additions & 2 deletions tests/test_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from lark import Lark

from larktools.ebnf_grammar import grammar
from larktools.evaluation import eval_multi_line_block
from larktools.evaluation import instantiate_eval_tree


class SyntaxParser:
Expand All @@ -14,7 +14,8 @@ def __init__(self):

def parse_and_eval(self, expression: str, env: Optional[Union[None, dict]] = None) -> Union[int, float]:
tree = self.parse(expression)
res = eval_multi_line_block(tree, {} if env is None else env)
eval_tree = instantiate_eval_tree(tree)
res = eval_tree({} if env is None else env)
return res


Expand Down

0 comments on commit 439f0cc

Please sign in to comment.