diff --git a/src/pyft/statements.py b/src/pyft/statements.py index f89fce3..7f1ad86 100644 --- a/src/pyft/statements.py +++ b/src/pyft/statements.py @@ -250,7 +250,7 @@ def decode(directive): kind = directive[17:].lstrip(' ').split('(')[0].strip() return table, kind - def updateStmt(stmt, table, kind, extraindent, parent, scopePath): + def updateStmt(stmt, table, kind, extraindent, parent, scope): """ Updates the statement given the table dictionnary '(:, :)' is replaced by '(JI, JK)' if table.keys() is ['JI', 'JK'] @@ -259,7 +259,7 @@ def updateStmt(stmt, table, kind, extraindent, parent, scopePath): :param kind: kind of mnh directives: 'array' or 'where' or None if transformation is not governed by mnh directive - :param scopePath: current scope path + :param scope: current scope """ def addExtra(node, extra): @@ -297,14 +297,14 @@ def addExtra(node, extra): # We loop on named-E nodes (and not directly on array-R nodes to prevent using # the costly getParent) for namedE in stmt.findall('.//{*}R-LT/..'): - self.arrayR2parensR(namedE, table, scopePath) # Replace slices by variable + scope.arrayR2parensR(namedE, table) # Replace slices by variable for cnt in stmt.findall('.//{*}cnt'): addExtra(cnt, extraindent) # Add indentation after continuation characters elif tag(stmt) == 'if-stmt': logging.warning("An if statement is inside a code section " + "transformed in DO loop in {f}".format(f=self.getFileName())) # Update the statement contained in the action node - updateStmt(stmt.find('./{*}action-stmt')[0], table, kind, 0, stmt, scopePath) + updateStmt(stmt.find('./{*}action-stmt')[0], table, kind, 0, stmt, scope) elif tag(stmt) == 'if-construct': logging.warning("An if construct is inside a code section " + "transformed in DO loop in {f}".format(f=self.getFileName())) @@ -313,7 +313,7 @@ def addExtra(node, extra): for child in ifBlock: # Loop over each statement inside the block if tag(child) not in ('if-then-stmt', 'else-if-stmt', 'else-stmt', 'end-if-stmt'): - updateStmt(child, table, kind, extraindent, ifBlock, scopePath) + updateStmt(child, table, kind, extraindent, ifBlock, scope) else: # Update indentation because the loop is here and not in recur addExtra(child, extraindent) @@ -326,11 +326,11 @@ def addExtra(node, extra): stmt.text = 'IF (' + stmt.text.split('(', 1)[1] # Update the action part updateStmt(stmt.find('./{*}action-stmt')[0], table, kind, - extraindent, stmt, scopePath) + extraindent, stmt, scope) mask = stmt.find('./{*}mask-E') mask.tag = f'{{{NAMESPACE}}}condition-E' # rename the condition tag for namedE in mask.findall('.//{*}R-LT/..'): - self.arrayR2parensR(namedE, table, scopePath) # Replace slices by variable + scope.arrayR2parensR(namedE, table) # Replace slices by variable for cnt in stmt.findall('.//{*}cnt'): addExtra(cnt, extraindent) # Add indentation after continuation characters elif tag(stmt) == 'where-construct': @@ -375,12 +375,12 @@ def addExtra(node, extra): mask.tail += ' THEN' for namedE in mask.findall('.//{*}R-LT/..'): # Replace slices by variable in the condition - self.arrayR2parensR(namedE, table, scopePath) + scope.arrayR2parensR(namedE, table) for cnt in child.findall('.//{*}cnt'): # Add indentation spaces after continuation characters addExtra(cnt, extraindent) else: - updateStmt(child, table, kind, extraindent, whereBlock, scopePath) + updateStmt(child, table, kind, extraindent, whereBlock, scope) else: raise PYFTError('Unexpected tag found in mnh_expand ' + 'directives: {t}'.format(t=tag(stmt))) @@ -400,11 +400,10 @@ def closeLoop(loopdesc): toremove = [] # list of nodes to remove newVarList = [] # list of new variables - def recur(elem, scopePath): + def recur(elem, scope): inMnh = False # are we in a DO loop created by a mnh directive inEverywhere = False # are we in a created DO loop (except if done with mnh directive) tailSave = {} # Save tail before transformation (to retrieve original indentation) - scopePath = self.getScopePath(elem) if self.isScopeNode(elem) else scopePath for ie, sElem in enumerate(list(elem)): # we loop on elements in the natural order if tag(sElem) == 'C' and sElem.text.lstrip(' ').startswith('!$mnh_expand') and \ useMnhExpand: @@ -461,7 +460,7 @@ def recur(elem, scopePath): toremove.append((elem, sElem)) # we remove it from its old place inner.insert(-1, sElem) # Insert first in the DO loop # then update, providing new parent in argument - updateStmt(sElem, table, kind, extraindent, inner, scopePath) + updateStmt(sElem, table, kind, extraindent, inner, scope) elif everywhere and tag(sElem) in ('a-stmt', 'if-stmt', 'where-stmt', 'where-construct'): @@ -499,7 +498,7 @@ def recur(elem, scopePath): if arr is not None: # In this case we transform the if statement into an if-construct self.changeIfStatementsInIfConstructs(singleItem=sElem, parent=elem) - recur(sElem, scopePath) # to transform the content of the if + recur(sElem, scope) # to transform the content of the if arr = None # to do nothing more on this node elif tag(sElem) == 'where-stmt': arr = sElem.find('./{*}mask-E//{*}named-E/{*}R-LT/{*}array-R/../..') @@ -530,8 +529,7 @@ def recur(elem, scopePath): else: # Guess a variable name if arr is not None: - newtable, varNew = self.getScopeNode(scopePath).findArrayBounds( - arr, loopVar, newVarList) + newtable, varNew = scope.findArrayBounds(arr, loopVar, newVarList) for var in varNew: var['new'] = True if var not in newVarList: @@ -567,7 +565,7 @@ def recur(elem, scopePath): toremove.append((elem, sElem)) # we remove it from its old place inner.insert(-1, sElem) # Insert first in the DO loop # then update, providing new parent in argument - updateStmt(sElem, table, kind, extraindent, inner, scopePath) + updateStmt(sElem, table, kind, extraindent, inner, scope) if not reuseLoop: # Prevent from reusing this DO loop inEverywhere = closeLoop(inEverywhere) @@ -576,10 +574,11 @@ def recur(elem, scopePath): inEverywhere = closeLoop(inEverywhere) # close loop if needed if len(sElem) >= 1: # Iteration - recur(sElem, scopePath) + recur(sElem, scope) inEverywhere = closeLoop(inEverywhere) - recur(self, self.getScopePath(self)) + for scope in self.getScopes(excludeContains=True): + recur(scope, scope) # First, element insertion by reverse order (in order to keep the insertion index correct) for elem, outer, ie in toinsert[::-1]: elem.insert(ie, outer) @@ -749,7 +748,7 @@ def setPRESENTby(node, var, val): for namedE in callStmt.findall('./{*}arg-spec/{*}arg/{*}named-E'): # Replace slices by indexes if any if namedE.find('./{*}R-LT'): - self.arrayR2parensR(namedE, table, mainScope.path) + mainScope.arrayR2parensR(namedE, table) # Deep copy the object to possibly modify the original one multiple times node = copy.deepcopy(subContained) diff --git a/src/pyft/variables.py b/src/pyft/variables.py index e444c66..f67bdfb 100644 --- a/src/pyft/variables.py +++ b/src/pyft/variables.py @@ -1032,7 +1032,7 @@ def findIndexArrayBounds(self, arr, index, loopVar): return None @debugDecor - def arrayR2parensR(self, namedE, table, scopePath): + def arrayR2parensR(self, namedE, table): """ Transform a array-R into a parens-R node by replacing slices by variables In 'A(:)', the ':' is in a array-R node whereas in 'A(JL)', 'JL' is in a parens-R node. @@ -1040,7 +1040,6 @@ def arrayR2parensR(self, namedE, table, scopePath): :param namedE: a named-E node :param table: dictionnary returned by the decode function :param varList: None or a VarList object in which varaibles are searched for - :param scopePath: scope path in which nameE is """ # Before A(:): # A @@ -1063,7 +1062,6 @@ def arrayR2parensR(self, namedE, table, scopePath): # # - scope = self.getScopeNode(scopePath, excludeContains=True) nodeRLT = namedE.find('./{*}R-LT') arrayR = nodeRLT.find('./{*}array-R') # Not always in first position, eg: ICED%XRTMIN(:) if arrayR is not None: @@ -1099,14 +1097,14 @@ def arrayR2parensR(self, namedE, table, scopePath): # A(2:15) or A(:15) or A(2:) if lower is None: # lower bound not defined, getting lower declared bound for this array - lower = scope.varList.findVar(n2name(namedE.find('{*}N')), - array=True)['as'][ivar][0] + lower = self.varList.findVar(n2name(namedE.find('{*}N')), + array=True)['as'][ivar][0] if lower is None: lower = '1' # default fortran lower bound elif upper is None: # upper bound not defined, getting upper declared bound for this array - upper = scope.varList.findVar(n2name(namedE.find('{*}N')), - array=True)['as'][ivar][1] + upper = self.varList.findVar(n2name(namedE.find('{*}N')), + array=True)['as'][ivar][1] # If the DO loop starts from JI=I1 and goes to JI=I2; and array # bounds are J1:J2 # We compute J1-I1+JI and J2-I2+JI and they should be the same