-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from CodeVisionaries/refactor_evaluations
Refactoring the evaluation implementation to make things more extensible
- Loading branch information
Showing
4 changed files
with
134 additions
and
151 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,155 +1,131 @@ | ||
from lark import Lark | ||
from typing import Callable | ||
|
||
from functools import reduce | ||
from lark import Lark, Token | ||
|
||
from .ebnf_grammar import grammar | ||
from .tree_utils import ( | ||
is_rule, | ||
is_terminal, | ||
get_name, | ||
get_children, | ||
get_first_child, | ||
get_value, | ||
) | ||
|
||
|
||
def eval_arith_expr(node, env): | ||
child = get_children(node)[0] | ||
child_name = get_name(child) | ||
assert child_name == "sum" | ||
return eval_sum(child, env) | ||
|
||
|
||
def eval_sum(node, env): | ||
# we know there is only one child because | ||
# of formal grammar definition | ||
child = get_children(node)[0] | ||
child_name = get_name(child) | ||
if child_name == "product": | ||
return eval_product(child, env) | ||
elif child_name == "addition": | ||
return eval_addition(child, env) | ||
elif child_name == "subtraction": | ||
return eval_subtraction(child, env) | ||
|
||
|
||
def eval_product(node, env): | ||
child = get_children(node)[0] | ||
child_name = get_name(child) | ||
if child_name == "atom": | ||
return eval_atom(child, env) | ||
elif child_name == "multiplication": | ||
return eval_multiplication(child, env) | ||
elif child_name == "division": | ||
return eval_division(child, env) | ||
|
||
|
||
def eval_addition(node, env): | ||
child1 = get_children(node)[0] | ||
child2 = get_children(node)[1] | ||
assert get_name(child1) == "sum" | ||
assert get_name(child2) == "product" | ||
res1 = eval_sum(child1, env) | ||
res2 = eval_product(child2, env) | ||
return res1 + res2 | ||
|
||
|
||
def eval_subtraction(node, env): | ||
child1 = get_children(node)[0] | ||
child2 = get_children(node)[1] | ||
assert get_name(child1) == "sum" | ||
assert get_name(child2) == "product" | ||
res1 = eval_sum(child1, env) | ||
res2 = eval_product(child2, env) | ||
return res1 - res2 | ||
|
||
|
||
def eval_multiplication(node, env): | ||
child1 = get_children(node)[0] | ||
child2 = get_children(node)[1] | ||
assert get_name(child1) == "product" | ||
assert get_name(child2) == "atom" | ||
res1 = eval_product(child1, env) | ||
res2 = eval_atom(child2, env) | ||
return res1 * res2 | ||
|
||
|
||
def eval_division(node, env): | ||
child1 = get_children(node)[0] | ||
child2 = get_children(node)[1] | ||
assert get_name(child1) == "product" | ||
assert get_name(child2) == "atom" | ||
res1 = eval_product(child1, env) | ||
res2 = eval_atom(child2, env) | ||
return res1 / res2 | ||
|
||
|
||
def eval_atom(node, env): | ||
child = get_children(node)[0] | ||
child_name = get_name(child) | ||
if child_name == "INT": | ||
return int(get_value(child)) | ||
elif child_name == "SIGNED_FLOAT": | ||
return float(get_value(child)) | ||
elif child_name == "variable": | ||
return eval_variable(child, env) | ||
elif child_name == "neg_atom": | ||
return eval_neg_atom(child, env) | ||
elif child_name == "bracketed_arith_expr": | ||
return eval_bracketed_arith_expr(child, env) | ||
|
||
|
||
def eval_neg_atom(node, env): | ||
# the "-" character appearing in the production rule is | ||
# filtered out by lark by default because it is a constant | ||
# character. Therefore, it doesn't appear among the child nodes | ||
child = get_children(node)[0] | ||
assert get_name(child) == "atom" | ||
return (-eval_atom(child, env)) | ||
|
||
|
||
def eval_bracketed_arith_expr(node, env): | ||
# same here, the constant characters "(" and ")" | ||
# are filtered out and don't appear as child nodes | ||
child = get_children(node)[0] | ||
assert get_name(child) == "arith_expr" | ||
return eval_arith_expr(child, env) | ||
|
||
def eval_line(node, env): | ||
# this is the content of a single line of input | ||
child = get_children(node)[0] | ||
child_name = get_name(child) | ||
if child_name == "arith_expr": | ||
return eval_arith_expr(child, env) | ||
elif child_name == "assignment": | ||
return eval_assignment(child, env) | ||
|
||
def eval_multi_line_block(node, env): | ||
# this can be either an arithmetic expression or | ||
# composed lines | ||
children = get_children(node) | ||
for child in children: | ||
child_name = get_name(child) | ||
assert child_name == "line" | ||
res = eval_line(child, env) | ||
return res | ||
|
||
|
||
def eval_assignment(node, env): | ||
# assign result of an expression to a variable | ||
child1, child2 = get_children(node)[0:2] | ||
assert get_name(child1) == "VARNAME" | ||
assert get_name(child2) == "arith_expr" | ||
|
||
varname = get_value(child1) | ||
env[varname] = eval_arith_expr(child2, env) | ||
return env[varname] | ||
|
||
|
||
def eval_variable(node, env): | ||
children = get_children(node) | ||
assert get_name(children[0]) == "VARNAME" | ||
varname = get_value(children[0]) | ||
value = env[varname] | ||
if len(children) > 1: | ||
for ch in children[1:]: | ||
assert get_name(ch) == "INDEX" | ||
idx = int(get_value(ch)) | ||
value = value[idx] | ||
return value | ||
def instantiate_eval_tree(lark_node): | ||
node_name = get_name(lark_node) | ||
# tunnel thorugh dummy nodes, e.g. sum, product | ||
while node_name not in INV_NODE_MAP: | ||
if isinstance(lark_node, Token): | ||
raise AttributeError( | ||
f"`{lark_node.type}` is a terminal node" | ||
"and doesn't have a node class associated with it." | ||
) | ||
if len(lark_node.children) != 1: | ||
raise IndexError( | ||
"Nodes without associated node class " | ||
"must have exactly one child. However, " | ||
f"the node `{get_name(lark_node)}` has " | ||
f"{len(lark_node.children)} children." | ||
) | ||
lark_node = lark_node.children[0] | ||
node_name = get_name(lark_node) | ||
|
||
return INV_NODE_MAP[node_name](lark_node) | ||
|
||
|
||
class RootNode: | ||
def __init__(self, lark_node): | ||
self._children = [ | ||
instantiate_eval_tree(n) for n in lark_node.children | ||
] | ||
|
||
def __call__(self, env): | ||
results = [n(env) for n in self._children] | ||
return results[-1] | ||
|
||
|
||
class AssignNode: | ||
def __init__(self, lark_node): | ||
self._varname = get_value(lark_node.children[0]) | ||
self._expr = instantiate_eval_tree(lark_node.children[1]) | ||
|
||
def __call__(self, env): | ||
rhs = self._expr(env) | ||
env[self._varname] = rhs | ||
# TODO: Should the full assignment evaluate to the expression on rhs? | ||
# At the moment necessary to satisy some test assumptions. | ||
return rhs | ||
|
||
|
||
class VariableNode: | ||
def __init__(self, lark_node): | ||
self._varname = get_value(lark_node.children[0]) | ||
self._index_nodes = [ | ||
instantiate_eval_tree(n) for n in lark_node.children[1:] | ||
] | ||
|
||
def __call__(self, env): | ||
index_values = [n(env) for n in self._index_nodes] | ||
return reduce( | ||
lambda lst, idx: lst[idx], index_values, env[self._varname] | ||
) | ||
|
||
|
||
class NumberNode: | ||
def __init__(self, lark_node): | ||
node_name = get_name(lark_node) | ||
self._value = { | ||
"SIGNED_FLOAT": float, "INT": int, "INDEX": int, | ||
}[node_name](get_value(lark_node)) | ||
|
||
def __call__(self, env): | ||
return self._value | ||
|
||
|
||
class MappedOperatorNode: | ||
def __init__(self, lark_node, op_map): | ||
node_name = get_name(lark_node) | ||
self._children = [ | ||
instantiate_eval_tree(n) for n in lark_node.children | ||
] | ||
self._func = op_map[node_name] | ||
|
||
def __call__(self, env): | ||
results = [n(env) for n in self._children] | ||
return self._func(results) | ||
|
||
|
||
class UnaryOperatorNode(MappedOperatorNode): | ||
def __init__(self, lark_node): | ||
super().__init__( | ||
lark_node, | ||
op_map={"neg_atom": lambda x: -x[0]} | ||
) | ||
|
||
|
||
class BinaryOperatorNode(MappedOperatorNode): | ||
def __init__(self, lark_node): | ||
super().__init__( | ||
lark_node, | ||
op_map = { | ||
"addition": lambda x: sum(x), | ||
"subtraction": lambda x: x[0] - x[1], | ||
"multiplication": lambda x: x[0] * x[1], | ||
"division": lambda x: x[0] / x[1], | ||
} | ||
) | ||
|
||
|
||
NODE_MAP = { | ||
RootNode: ("multi_line_block",), | ||
AssignNode: ("assignment",), | ||
UnaryOperatorNode: ("neg_atom",), | ||
BinaryOperatorNode: ("addition", "subtraction", "multiplication", "division"), | ||
VariableNode: ("variable", "varname"), | ||
NumberNode: ("INT", "SIGNED_INT", "FLOAT", "SIGNED_FLOAT", "INDEX"), | ||
} | ||
|
||
INV_NODE_MAP = {k: v for v in NODE_MAP for k in NODE_MAP[v]} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters