From 04169436da874d8ce71d9211f70047d29e29a4e9 Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 18:47:17 +0100 Subject: [PATCH] symbolic rewrites support linear terms Co-authored-by: Tuomas Laakkonen --- zxlive/custom_rule.py | 64 +++++++++++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index a86fd0c9..c3c81abf 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -114,30 +114,62 @@ def to_proof_action(self) -> "ProofAction": return ProofAction(self.name, self.matcher, self, MATCHES_VERTICES, self.description) -def get_var(v): +def get_linear(v): if not isinstance(v, Poly): raise ValueError("Not a symbolic parameter") - if len(v.terms) != 1: - raise ValueError("Only single-term symbolic parameters are supported") - if len(v.terms[0][1].vars) != 1: - raise ValueError("Only single-variable symbolic parameters are supported") - return v.terms[0][1].vars[0][0] + if len(v.terms) > 2 or len(v.free_vars()) > 1: + raise ValueError("Only linear symbolic parameters are supported") + if len(v.terms) == 0: + return 1, None, 0 + elif len(v.terms) == 1: + if len(v.terms[0][1].vars) > 0: + var_term = v.terms[0] + const = 0 + else: + const = v.terms[0][0] + return 1, None, const + else: + if len(v.terms[0][1].vars) > 0: + var_term = v.terms[0] + const = v.terms[1][0] + else: + var_term = v.terms[1] + const = v.terms[0][0] + coeff = var_term[0] + var, power = var_term[1].vars[0] + if power != 1: + raise ValueError("Only linear symbolic parameters are supported") + return coeff, var, const + def match_symbolic_parameters(match, left, right): params = {} left_phase = left.nodes.data('phase', default=0) right_phase = right.nodes.data('phase', default=0) + + def check_phase_equality(v): + if left_phase[v] != right_phase[match[v]]: + raise ValueError("Parameters do not match") + + def update_params(v, var, coeff, const): + var_value = (right_phase[match[v]] - const) / coeff + if var in params and params[var] != var_value: + raise ValueError("Symbolic parameters do not match") + params[var] = var_value + for v in left.nodes(): if isinstance(left_phase[v], Poly): - if get_var(left_phase[v]) in params: - if params[get_var(left_phase[v])] != right_phase[match[v]]: - raise ValueError("Symbolic parameters do not match") - else: - params[get_var(left_phase[v])] = right_phase[match[v]] - elif left_phase[v] != right_phase[match[v]]: - raise ValueError("Parameters do not match") + coeff, var, const = get_linear(left_phase[v]) + if var is None: + check_phase_equality(v) + continue + update_params(v, var, coeff, const) + else: + check_phase_equality(v) + return params + def filter_matchings_if_symbolic_compatible(matchings, left, right): new_matchings = [] for matching in matchings: @@ -221,9 +253,11 @@ def check_rule(rule: CustomRule, show_error: bool = True) -> bool: return False for vertex in rule.lhs_graph.vertices(): if isinstance(rule.lhs_graph.phase(vertex), Poly): - if len(rule.lhs_graph.phase(vertex).free_vars()) > 1: + try: + get_linear(rule.lhs_graph.phase(vertex)) + except ValueError as e: if show_error: from .dialogs import show_error_msg - show_error_msg("Warning!", "Only one symbolic parameter per vertex is supported on the left-hand side.") + show_error_msg("Warning!", str(e)) return False return True