Skip to content

Commit

Permalink
mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
mmatera committed Jan 6, 2025
1 parent bc6f88b commit 850f6f3
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 81 deletions.
60 changes: 33 additions & 27 deletions mathics/core/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
Support for Set and SetDelayed, and other assignment-like builtins
"""

from typing import Callable, List, Optional, Tuple
from typing import Callable, List, Optional, Tuple, cast

from mathics.core.atoms import Atom
from mathics.core.attributes import A_PROTECTED
from mathics.core.builtin import Builtin
from mathics.core.definitions import Definitions
from mathics.core.definitions import Definition, Definitions
from mathics.core.element import BaseElement
from mathics.core.evaluation import Evaluation
from mathics.core.expression import Expression
Expand Down Expand Up @@ -40,7 +40,7 @@ def build_rulopc(optval: BaseElement) -> Rule:
)


def get_symbol_list(expr: BaseElement, error_callback: Callable) -> Optional[List[str]]:
def get_symbol_list(expr: Expression, error_callback: Callable) -> Optional[List[str]]:
"""
If ``expr`` is of the form ``List[Symbol___]`` returns a list
with the names of the symbols as elements.
Expand Down Expand Up @@ -105,19 +105,24 @@ def get_symbol_values(
if not name:
evaluation.message(func_name, "sym", symbol, 1)
return None
if position in ("default",):
definition = evaluation.definitions.get_definition(name)
else:
definition = evaluation.definitions.get_user_definition(name)
definitions = evaluation.definitions
definition = (
definitions.get_definition(name)
if position in ("default",)
else definitions.get_user_definition(name)
)
if definition is None:
return ListExpression()

elements = []
for rule in definition.get_values_list(position):
for rule in cast(Definition, definition).get_values_list(position):
if isinstance(rule, Rule):
pattern = rule.pattern
if pattern.has_form("HoldPattern", 1):
pattern = pattern.expr
expr_pattern = pattern.expr
else:
pattern = Expression(SymbolHoldPattern, pattern.expr)
elements.append(Expression(SymbolRuleDelayed, pattern, rule.replace))
expr_pattern = Expression(SymbolHoldPattern, pattern.expr)
elements.append(Expression(SymbolRuleDelayed, expr_pattern, rule.replace))
return ListExpression(*elements)


Expand Down Expand Up @@ -162,8 +167,9 @@ def repl_pattern_by_symbol(expr: BaseElement) -> BaseElement:
if len(elements) == 0:
return expr

headname = expr.get_head_name()
if headname == "System`Pattern":
head = expr.get_head()
head_name = head.get_name()
if head_name == "System`Pattern":
return elements[0]

changed = False
Expand All @@ -174,8 +180,7 @@ def repl_pattern_by_symbol(expr: BaseElement) -> BaseElement:
changed = True
new_elements.append(element)
if changed:
# TODO check this: headname is a str not a Symbol.
return Expression(headname, *new_elements)
return Expression(head, *new_elements)
return expr


Expand Down Expand Up @@ -223,32 +228,33 @@ def rejected_because_protected(

def unroll_conditions(lhs: BaseElement) -> Tuple[BaseElement, Optional[Expression]]:
"""
If lhs is a nested `Condition` expression,
If `element` is a nested `Condition` expression,
gather all the conditions in a single one, and returns a tuple
with the lhs stripped from the conditions and the shallow condition.
with the `element` stripped from the conditions and the shallow condition.
If there is not any condition, returns the lhs and None
"""
if isinstance(lhs, Symbol):
return lhs, None

name, lhs_elements = lhs.get_head_name(), lhs.get_elements()
expr: Expression = cast(Expression, lhs)
name, lhs_elements = expr.get_head_name(), expr.get_elements()
conditions = []
# This handle the case of many successive conditions:
# f[x_]/; cond1 /; cond2 ... -> f[x_]/; And[cond1, cond2, ...]
while name == "System`Condition" and len(lhs.elements) == 2:
while name == "System`Condition" and len(lhs_elements) == 2:
conditions.append(lhs_elements[1])
lhs = lhs_elements[0]
if isinstance(lhs, Atom):
if isinstance(expr, Atom):
break
name, lhs_elements = lhs.get_head_name(), lhs.elements
expr = cast(Expression, lhs)
name, lhs_elements = expr.get_head_name(), expr.elements
if len(conditions) == 0:
return lhs, None
if len(conditions) > 1:
condition = Expression(SymbolAnd, *conditions)
else:
condition = conditions[0]

condition: BaseElement = (
Expression(SymbolAnd, *conditions) if len(conditions) > 1 else conditions[0]
)
condition = Expression(SymbolCondition, lhs, condition)
# lhs._format_cache = None
return lhs, condition


Expand All @@ -262,7 +268,7 @@ def unroll_patterns(
if isinstance(lhs, Atom):
return lhs, rhs
name = lhs.get_head_name()
lhs_elements = lhs.elements
lhs_elements = cast(Expression, lhs).elements
if name == "System`Pattern":
lhs = lhs_elements[1]
rulerepl = (lhs_elements[0], repl_pattern_by_symbol(lhs))
Expand Down
Loading

0 comments on commit 850f6f3

Please sign in to comment.