Skip to content

Commit

Permalink
Merge pull request #35 from hsorby/main
Browse files Browse the repository at this point in the history
Improvements for logging and dealing with scaffold markers
  • Loading branch information
hsorby authored Dec 3, 2024
2 parents 68740f7 + 352e13e commit b723acd
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 59 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def readfile(filename, split=False):
requires = [
# minimal requirements listing
"cmlibs.maths >= 0.3",
"cmlibs.utils >= 0.6",
"cmlibs.utils >= 0.10",
"cmlibs.zinc >= 4.0"
]
readme.extend(['', 'License', '=======', '', '::', ''])
Expand Down
130 changes: 75 additions & 55 deletions src/scaffoldfitter/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,26 @@
findOrCreateFieldFiniteElement, findOrCreateFieldStoredMeshLocation, getUniqueFieldName, orphanFieldByName
from cmlibs.utils.zinc.finiteelement import evaluate_field_nodeset_range, findNodeWithName, getMaximumNodeIdentifier
from cmlibs.utils.zinc.general import ChangeManager
from cmlibs.utils.zinc.region import write_to_buffer, read_from_buffer
from cmlibs.zinc.context import Context
from cmlibs.zinc.element import Elementbasis, Elementfieldtemplate
from cmlibs.zinc.field import Field, FieldFindMeshLocation, FieldGroup
from cmlibs.zinc.result import RESULT_OK, RESULT_WARNING_PART_DONE

from scaffoldfitter.fitterexceptions import FitterModelCoordinateField
from scaffoldfitter.fitterstep import FitterStep
from scaffoldfitter.fitterstepconfig import FitterStepConfig
from scaffoldfitter.fitterstepfit import FitterStepFit


def _next_available_identifier(node_set, candidate):
node = node_set.findNodeByIdentifier(candidate)
while node.isValid():
candidate += 1
node = node_set.findNodeByIdentifier(candidate)
return candidate


class Fitter:

def __init__(self, zincModelFileName: str, zincDataFileName: str):
Expand Down Expand Up @@ -452,47 +463,34 @@ def _loadData(self):
if nodes.getSize() > 0:
datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS)
if datapoints.getSize() > 0:
maximumDatapointIdentifier = max(0, getMaximumNodeIdentifier(datapoints))
maximumNodeIdentifier = max(0, getMaximumNodeIdentifier(nodes))
# this assumes identifiers are in low ranges and can be improved if there is a problem:
identifierOffset = 100000
while (maximumDatapointIdentifier > identifierOffset) or (maximumNodeIdentifier > identifierOffset):
assert identifierOffset < 1000000000, "Invalid node and datapoint identifier ranges"
identifierOffset *= 10
while True:
# logic relies on datapoints being in identifier order
datapoint = datapoints.createNodeiterator().next()
identifier = datapoint.getIdentifier()
if identifier >= identifierOffset:
break
result = datapoint.setIdentifier(identifier + identifierOffset)
assert result == RESULT_OK, "Failed to offset datapoint identifier"
datapoint_iterator = datapoints.createNodeiterator()
datapoint = datapoint_iterator.next()
latest = 1
datapoint_new_identifier_map = {}
while datapoint.isValid():
identifier = _next_available_identifier(nodes, latest)
datapoint_new_identifier_map[identifier] = datapoint
latest = identifier + 1
datapoint = datapoint_iterator.next()

for new_identifier, datapoint in datapoint_new_identifier_map.items():
datapoint.setIdentifier(new_identifier)

# transfer nodes as datapoints to self._region
sir = self._rawDataRegion.createStreaminformationRegion()
srm = sir.createStreamresourceMemory()
sir.setResourceDomainTypes(srm, Field.DOMAIN_TYPE_NODES)
self._rawDataRegion.write(sir)
result, buffer = srm.getBuffer()
assert result == RESULT_OK, "Failed to write nodes"
buffer = write_to_buffer(self._rawDataRegion, resource_domain_type=Field.DOMAIN_TYPE_NODES)
assert buffer is not None, "Failed to write nodes"
buffer = buffer.replace(bytes("!#nodeset nodes", "utf-8"), bytes("!#nodeset datapoints", "utf-8"))
sir = self._region.createStreaminformationRegion()
sir.createStreamresourceMemoryBuffer(buffer)
result = self._region.read(sir)
result = read_from_buffer(self._region, buffer)
if result != RESULT_OK:
self.printLog()
print("Node to datapoints log:")
self.print_log()
raise AssertionError("Failed to load nodes as datapoints")
# transfer datapoints to self._region
sir = self._rawDataRegion.createStreaminformationRegion()
srm = sir.createStreamresourceMemory()
sir.setResourceDomainTypes(srm, Field.DOMAIN_TYPE_DATAPOINTS)
self._rawDataRegion.write(sir)
result, buffer = srm.getBuffer()
assert result == RESULT_OK, "Failed to write datapoints"
sir = self._region.createStreaminformationRegion()
sir.createStreamresourceMemoryBuffer(buffer)
result = self._region.read(sir)
buffer = write_to_buffer(self._rawDataRegion, resource_domain_type=Field.DOMAIN_TYPE_DATAPOINTS)
assert buffer is not None, "Failed to write datapoints"
result = read_from_buffer(self._region, buffer)
if result != RESULT_OK:
self.printLog()
self.print_log()
raise AssertionError("Failed to load datapoints, result " + str(result))
self._discoverDataCoordinatesField()
self._discoverMarkerGroup()
Expand Down Expand Up @@ -1027,6 +1025,27 @@ def setModelCoordinatesField(self, modelCoordinatesField: Field):
def setModelCoordinatesFieldByName(self, modelCoordinatesFieldName):
self.setModelCoordinatesField(self._fieldmodule.findFieldByName(modelCoordinatesFieldName))

def _find_first_coordinate_type_field(self):
field = None

mesh = self.getHighestDimensionMesh()
element = mesh.createElementiterator().next()
if element.isValid():
fieldcache = self._fieldmodule.createFieldcache()
fieldcache.setElement(element)
fielditer = self._fieldmodule.createFielditerator()
field = fielditer.next()
while field.isValid():
if field.isTypeCoordinate() and (field.getNumberOfComponents() == 3) and \
(field.castFiniteElement().isValid()):
if field.isDefinedAtLocation(fieldcache):
break
field = fielditer.next()
else:
field = None

return field

def _discoverModelCoordinatesField(self):
"""
Choose default modelCoordinates field.
Expand All @@ -1036,24 +1055,14 @@ def _discoverModelCoordinatesField(self):
field = None
if self._modelCoordinatesFieldName:
field = self._fieldmodule.findFieldByName(self._modelCoordinatesFieldName)
else:
mesh = self.getHighestDimensionMesh()
element = mesh.createElementiterator().next()
if element.isValid():
fieldcache = self._fieldmodule.createFieldcache()
fieldcache.setElement(element)
fielditer = self._fieldmodule.createFielditerator()
field = fielditer.next()
while field.isValid():
if field.isTypeCoordinate() and (field.getNumberOfComponents() == 3) and \
(field.castFiniteElement().isValid()):
if field.isDefinedAtLocation(fieldcache):
break
field = fielditer.next()
else:
field = None
if field:

if field is None or not field.isValid():
field = self._find_first_coordinate_type_field()

if field and field.isValid():
self.setModelCoordinatesField(field)
else:
raise FitterModelCoordinateField("No coordinate field found for model.")

def getModelFitGroup(self):
return self._modelFitGroup
Expand Down Expand Up @@ -1510,11 +1519,22 @@ def getHighestDimensionMesh(self):
return mesh
return None

def printLog(self):
def _log_message_type_to_text(self, message_type):
# 'MESSAGE_TYPE_ERROR', 'MESSAGE_TYPE_INFORMATION', 'MESSAGE_TYPE_INVALID', 'MESSAGE_TYPE_WARNING'
if self._logger.MESSAGE_TYPE_ERROR == message_type:
return "Error"
if self._logger.MESSAGE_TYPE_INFORMATION == message_type:
return "Information"
if self._logger.MESSAGE_TYPE_WARNING == message_type:
return "Warning"

return "Invalid"

def print_log(self):
loggerMessageCount = self._logger.getNumberOfMessages()
if loggerMessageCount > 0:
for i in range(1, loggerMessageCount + 1):
print(self._logger.getMessageTypeAtIndex(i), self._logger.getMessageTextAtIndex(i))
print(f"[Message {i}] {self._log_message_type_to_text(self._logger.getMessageTypeAtIndex(i))}: {self._logger.getMessageTextAtIndex(i)}")
self._logger.removeAllMessages()

def getDiagnosticLevel(self):
Expand Down Expand Up @@ -1551,7 +1571,7 @@ def writeModel(self, modelFileName):
if self._modelFitGroup:
sir.setResourceGroupName(srf, self._modelFitGroup.getName())
result = self._region.write(sir)
# self.printLog()
# self.print_log()

# restore original name
self._modelCoordinatesField.setName(self._modelCoordinatesFieldName)
Expand Down
3 changes: 3 additions & 0 deletions src/scaffoldfitter/fitterexceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

class FitterModelCoordinateField(Exception):
pass
6 changes: 3 additions & 3 deletions src/scaffoldfitter/fitterstepfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def run(self, modelFileNameStem=None):
fieldcache, flattenGroupObjective.getNumberOfComponents())
print(" Flatten group objective", objectiveFormat.format(objective))
if self.getDiagnosticLevel() > 1:
self._fitter.printLog()
self._fitter.print_log()

if self._updateReferenceState:
self._fitter.updateModelReferenceCoordinates()
Expand Down Expand Up @@ -449,7 +449,7 @@ def createDeformationPenaltyObjectiveField(self, deformActiveMeshGroup, strainAc
# convert to local fibre directions, with possible dimension reduction for 2D, 1D
fibreAxes = fieldmodule.createFieldFibreAxes(fibreField, modelReferenceCoordinates)
if not fibreAxes.isValid():
self.getFitter().printLog()
self.getFitter().print_log()
if dimension == 3:
fibreAxesT = fieldmodule.createFieldTranspose(3, fibreAxes)
elif dimension == 2:
Expand Down Expand Up @@ -506,7 +506,7 @@ def createDeformationPenaltyObjectiveField(self, deformActiveMeshGroup, strainAc
deformationTerm = \
(deformationTerm + wtSqDeformationGradient2) if deformationTerm else wtSqDeformationGradient2
if not deformationTerm.isValid():
self.getFitter().printLog()
self.getFitter().print_log()
raise AssertionError("Scaffoldfitter: Failed to get deformation term")

deformationPenaltyObjective = fieldmodule.createFieldMeshIntegral(
Expand Down

0 comments on commit b723acd

Please sign in to comment.