Skip to content

Commit

Permalink
Remove scopePath argument of arrayR2parensR
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastienRietteMTO committed Oct 14, 2024
1 parent 000ca1d commit abe1413
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 26 deletions.
37 changes: 18 additions & 19 deletions src/pyft/statements.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def decode(directive):
kind = directive[17:].lstrip(' ').split('(')[0].strip()
return table, kind

def updateStmt(stmt, table, kind, extraindent, parent, scopePath):
def updateStmt(stmt, table, kind, extraindent, parent, scope):
"""
Updates the statement given the table dictionnary '(:, :)' is replaced by '(JI, JK)' if
table.keys() is ['JI', 'JK']
Expand All @@ -259,7 +259,7 @@ def updateStmt(stmt, table, kind, extraindent, parent, scopePath):
:param kind: kind of mnh directives: 'array' or 'where'
or None if transformation is not governed by
mnh directive
:param scopePath: current scope path
:param scope: current scope
"""

def addExtra(node, extra):
Expand Down Expand Up @@ -297,14 +297,14 @@ def addExtra(node, extra):
# We loop on named-E nodes (and not directly on array-R nodes to prevent using
# the costly getParent)
for namedE in stmt.findall('.//{*}R-LT/..'):
self.arrayR2parensR(namedE, table, scopePath) # Replace slices by variable
scope.arrayR2parensR(namedE, table) # Replace slices by variable
for cnt in stmt.findall('.//{*}cnt'):
addExtra(cnt, extraindent) # Add indentation after continuation characters
elif tag(stmt) == 'if-stmt':
logging.warning("An if statement is inside a code section " +
"transformed in DO loop in {f}".format(f=self.getFileName()))
# Update the statement contained in the action node
updateStmt(stmt.find('./{*}action-stmt')[0], table, kind, 0, stmt, scopePath)
updateStmt(stmt.find('./{*}action-stmt')[0], table, kind, 0, stmt, scope)
elif tag(stmt) == 'if-construct':
logging.warning("An if construct is inside a code section " +
"transformed in DO loop in {f}".format(f=self.getFileName()))
Expand All @@ -313,7 +313,7 @@ def addExtra(node, extra):
for child in ifBlock: # Loop over each statement inside the block
if tag(child) not in ('if-then-stmt', 'else-if-stmt',
'else-stmt', 'end-if-stmt'):
updateStmt(child, table, kind, extraindent, ifBlock, scopePath)
updateStmt(child, table, kind, extraindent, ifBlock, scope)
else:
# Update indentation because the loop is here and not in recur
addExtra(child, extraindent)
Expand All @@ -326,11 +326,11 @@ def addExtra(node, extra):
stmt.text = 'IF (' + stmt.text.split('(', 1)[1]
# Update the action part
updateStmt(stmt.find('./{*}action-stmt')[0], table, kind,
extraindent, stmt, scopePath)
extraindent, stmt, scope)
mask = stmt.find('./{*}mask-E')
mask.tag = f'{{{NAMESPACE}}}condition-E' # rename the condition tag
for namedE in mask.findall('.//{*}R-LT/..'):
self.arrayR2parensR(namedE, table, scopePath) # Replace slices by variable
scope.arrayR2parensR(namedE, table) # Replace slices by variable
for cnt in stmt.findall('.//{*}cnt'):
addExtra(cnt, extraindent) # Add indentation after continuation characters
elif tag(stmt) == 'where-construct':
Expand Down Expand Up @@ -375,12 +375,12 @@ def addExtra(node, extra):
mask.tail += ' THEN'
for namedE in mask.findall('.//{*}R-LT/..'):
# Replace slices by variable in the condition
self.arrayR2parensR(namedE, table, scopePath)
scope.arrayR2parensR(namedE, table)
for cnt in child.findall('.//{*}cnt'):
# Add indentation spaces after continuation characters
addExtra(cnt, extraindent)
else:
updateStmt(child, table, kind, extraindent, whereBlock, scopePath)
updateStmt(child, table, kind, extraindent, whereBlock, scope)
else:
raise PYFTError('Unexpected tag found in mnh_expand ' +
'directives: {t}'.format(t=tag(stmt)))
Expand All @@ -400,11 +400,10 @@ def closeLoop(loopdesc):
toremove = [] # list of nodes to remove
newVarList = [] # list of new variables

def recur(elem, scopePath):
def recur(elem, scope):
inMnh = False # are we in a DO loop created by a mnh directive
inEverywhere = False # are we in a created DO loop (except if done with mnh directive)
tailSave = {} # Save tail before transformation (to retrieve original indentation)
scopePath = self.getScopePath(elem) if self.isScopeNode(elem) else scopePath
for ie, sElem in enumerate(list(elem)): # we loop on elements in the natural order
if tag(sElem) == 'C' and sElem.text.lstrip(' ').startswith('!$mnh_expand') and \
useMnhExpand:
Expand Down Expand Up @@ -461,7 +460,7 @@ def recur(elem, scopePath):
toremove.append((elem, sElem)) # we remove it from its old place
inner.insert(-1, sElem) # Insert first in the DO loop
# then update, providing new parent in argument
updateStmt(sElem, table, kind, extraindent, inner, scopePath)
updateStmt(sElem, table, kind, extraindent, inner, scope)

elif everywhere and tag(sElem) in ('a-stmt', 'if-stmt', 'where-stmt',
'where-construct'):
Expand Down Expand Up @@ -499,7 +498,7 @@ def recur(elem, scopePath):
if arr is not None:
# In this case we transform the if statement into an if-construct
self.changeIfStatementsInIfConstructs(singleItem=sElem, parent=elem)
recur(sElem, scopePath) # to transform the content of the if
recur(sElem, scope) # to transform the content of the if
arr = None # to do nothing more on this node
elif tag(sElem) == 'where-stmt':
arr = sElem.find('./{*}mask-E//{*}named-E/{*}R-LT/{*}array-R/../..')
Expand Down Expand Up @@ -530,8 +529,7 @@ def recur(elem, scopePath):
else:
# Guess a variable name
if arr is not None:
newtable, varNew = self.getScopeNode(scopePath).findArrayBounds(
arr, loopVar, newVarList)
newtable, varNew = scope.findArrayBounds(arr, loopVar, newVarList)
for var in varNew:
var['new'] = True
if var not in newVarList:
Expand Down Expand Up @@ -567,7 +565,7 @@ def recur(elem, scopePath):
toremove.append((elem, sElem)) # we remove it from its old place
inner.insert(-1, sElem) # Insert first in the DO loop
# then update, providing new parent in argument
updateStmt(sElem, table, kind, extraindent, inner, scopePath)
updateStmt(sElem, table, kind, extraindent, inner, scope)
if not reuseLoop:
# Prevent from reusing this DO loop
inEverywhere = closeLoop(inEverywhere)
Expand All @@ -576,10 +574,11 @@ def recur(elem, scopePath):
inEverywhere = closeLoop(inEverywhere) # close loop if needed
if len(sElem) >= 1:
# Iteration
recur(sElem, scopePath)
recur(sElem, scope)
inEverywhere = closeLoop(inEverywhere)

recur(self, self.getScopePath(self))
for scope in self.getScopes(excludeContains=True):
recur(scope, scope)
# First, element insertion by reverse order (in order to keep the insertion index correct)
for elem, outer, ie in toinsert[::-1]:
elem.insert(ie, outer)
Expand Down Expand Up @@ -749,7 +748,7 @@ def setPRESENTby(node, var, val):
for namedE in callStmt.findall('./{*}arg-spec/{*}arg/{*}named-E'):
# Replace slices by indexes if any
if namedE.find('./{*}R-LT'):
self.arrayR2parensR(namedE, table, mainScope.path)
mainScope.arrayR2parensR(namedE, table)

# Deep copy the object to possibly modify the original one multiple times
node = copy.deepcopy(subContained)
Expand Down
12 changes: 5 additions & 7 deletions src/pyft/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,15 +1032,14 @@ def findIndexArrayBounds(self, arr, index, loopVar):
return None

@debugDecor
def arrayR2parensR(self, namedE, table, scopePath):
def arrayR2parensR(self, namedE, table):
"""
Transform a array-R into a parens-R node by replacing slices by variables
In 'A(:)', the ':' is in a array-R node whereas in 'A(JL)', 'JL' is in a parens-R node.
Both the array-R and the parens-R nodes are inside a R-LT node
:param namedE: a named-E node
:param table: dictionnary returned by the decode function
:param varList: None or a VarList object in which varaibles are searched for
:param scopePath: scope path in which nameE is
"""
# Before A(:): <f:named-E>
# <f:N><f:n>A</f:n></f:N>
Expand All @@ -1063,7 +1062,6 @@ def arrayR2parensR(self, namedE, table, scopePath):
# </f:R-LT>
# </f:named-E>

scope = self.getScopeNode(scopePath, excludeContains=True)
nodeRLT = namedE.find('./{*}R-LT')
arrayR = nodeRLT.find('./{*}array-R') # Not always in first position, eg: ICED%XRTMIN(:)
if arrayR is not None:
Expand Down Expand Up @@ -1099,14 +1097,14 @@ def arrayR2parensR(self, namedE, table, scopePath):
# A(2:15) or A(:15) or A(2:)
if lower is None:
# lower bound not defined, getting lower declared bound for this array
lower = scope.varList.findVar(n2name(namedE.find('{*}N')),
array=True)['as'][ivar][0]
lower = self.varList.findVar(n2name(namedE.find('{*}N')),
array=True)['as'][ivar][0]
if lower is None:
lower = '1' # default fortran lower bound
elif upper is None:
# upper bound not defined, getting upper declared bound for this array
upper = scope.varList.findVar(n2name(namedE.find('{*}N')),
array=True)['as'][ivar][1]
upper = self.varList.findVar(n2name(namedE.find('{*}N')),
array=True)['as'][ivar][1]
# If the DO loop starts from JI=I1 and goes to JI=I2; and array
# bounds are J1:J2
# We compute J1-I1+JI and J2-I2+JI and they should be the same
Expand Down

0 comments on commit abe1413

Please sign in to comment.