From d47bf05034e9e4699321bee6fb36764e2e7cf115 Mon Sep 17 00:00:00 2001 From: Yihong Yu Date: Wed, 25 Sep 2024 14:42:33 -0700 Subject: [PATCH] add special case for tables --- core/rule_generator.py | 17 +++++++++++++---- tests/test_rule_generator.py | 2 +- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/core/rule_generator.py b/core/rule_generator.py index d8c3f1e..d27163f 100644 --- a/core/rule_generator.py +++ b/core/rule_generator.py @@ -839,17 +839,17 @@ def tables(pattern_json: str, rewrite_json: str) -> list: rewriteSet = defaultdict(list) for table in patternTables: - if type(table['value']) is str and type(table['name']) is str: + if type(table['value']) is str and type(table['name']) is str and table['name'] not in patternSet[table['value']]: patternSet[table['value']].append(table['name']) for table in rewriteTables: - if type(table['value']) is str and type(table['name']) is str: - rewriteSet[table['value']].append(table['name']) + if type(table['value']) is str and type(table['name']) is str and table['name'] not in rewriteSet[table['value']]: + rewriteSet[table['value']].append(table['name']) superSet = [] for patternValue, patternNames in patternSet.items(): rewriteNames = rewriteSet.get(patternValue, []) - # special case: + # special case 1: # if the patternTable ONLY has {'value': 'employee', 'name': 'employee'} # and the rewriteTable ONLY has {'value': 'employee', 'name': 'e1'}, # we replace 'employee' with 'e1' as table alias @@ -861,6 +861,15 @@ def tables(pattern_json: str, rewrite_json: str) -> list: # if len(patternNames) == 1 and len(rewriteNames) == 1 and patternNames[0] == patternValue: patternNames = rewriteNames + # special case 2: + # if the patternTable ONLY has {'value': 'employee', 'name': 'employee'} + # and the rewriteTable has multiple alias to the same table, e.g., {'value': 'employee', 'name': 'e1'}, {'value': 'employee', 'name': 'e2'} + # we directly append 'e1' and 'e2' as table alias to the superSet as IN MOST CASES patternTable's name should be included in rewriteTable's alias name + # the purpose is for the next step when we replace tables with variables + # + elif len(patternNames) == 1 and len(rewriteNames) != 0 and patternNames[0] == patternValue: + superSet += [{'value': patternValue, 'name': name} for name in rewriteNames] + continue else: patternNames += [name for name in rewriteNames if name not in patternNames] superSet += [{'value': patternValue, 'name': name} for name in patternNames] diff --git a/tests/test_rule_generator.py b/tests/test_rule_generator.py index 3a203cb..010ef35 100644 --- a/tests/test_rule_generator.py +++ b/tests/test_rule_generator.py @@ -2162,7 +2162,7 @@ def test_generate_general_rule_16(): q0_rule, q1_rule = unify_variable_names(rule['pattern'], rule['rewrite']) assert q0_rule== "SELECT , , , , , FROM WHERE IN (SELECT FROM WHERE = AND = ) ORDER BY , " - assert q1_rule == "SELECT ., ., ., ., ., . FROM JOIN ON . = . WHERE . = AND . = ORDER BY ., ." + assert q1_rule == "SELECT ., ., ., ., ., . FROM JOIN ON . = . WHERE . = AND . = ORDER BY ., ." def test_generate_general_rule_17():