Skip to content

Commit

Permalink
Check origin labels, not origin ids, during parsing tests. Remove check…
Browse files Browse the repository at this point in the history
…s for number of expected entries. Refactor so that the atom_labels field only contains non-i_seq labels.
  • Loading branch information
cschlick committed Oct 11, 2024
1 parent 8442282 commit 5203ade
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 191 deletions.
82 changes: 32 additions & 50 deletions mmtbx/geometry_restraints/geo_file_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Entry:
"""
name = None # Name or label for the class

def __init__(self,lines,origin_id=0):
def __init__(self,lines,origin_id=0,origin_label="covalent"):
"""
An entry is initialized first, then data is added with entry.lines.append()
Expand All @@ -41,8 +41,8 @@ def __init__(self,lines,origin_id=0):
self.i_seqs = [] # list of integer i_seqs (if possible)
self.atom_labels = [] # list of string atom label from .geo
self._numerical = None # a dict of numerical geo data
self.labels_are_i_seqs = None # boolean, atom labels are i_seqs
self.origin_id = origin_id
self.origin_label = origin_label

# Initialize result data structures
self._proxy = None
Expand All @@ -51,8 +51,17 @@ def __init__(self,lines,origin_id=0):
self._prepare()

# Check if labels are i_seqs (integers)
self.labels_are_i_seqs, self.i_seqs = self._check_labels_are_i_seqs(self.atom_labels)
labels_are_i_seqs, self.i_seqs = self._check_labels_are_i_seqs(self.atom_labels)
if labels_are_i_seqs:
self.atom_labels = []

@property
def labels_are_available(self):
return len(self.atom_labels)>0
@property
def i_seqs_are_available(self):
return len(self.i_seqs)>0

def _prepare(self):
# Parse Atom labels
values = []
Expand Down Expand Up @@ -84,10 +93,6 @@ def _check_labels_are_i_seqs(self,atom_labels):
i_seqs = []
return check, i_seqs

@property
def has_i_seqs(self):
return len(self.i_seqs)>0

@property
def record(self):
"""
Expand Down Expand Up @@ -126,7 +131,7 @@ def proxy(self):
"""
Only create a proxy object if necessary, and if so only do it once
"""
if not self._proxy and self.has_i_seqs:
if not self._proxy and self.i_seqs_are_available:
self._proxy = self.to_proxy()
return self._proxy

Expand Down Expand Up @@ -290,7 +295,10 @@ def __init__(self,*args,**kwargs):
self.atom_labels_i = []
self.atom_labels_j = []
super().__init__(*args,**kwargs)
self.labels_are_i_seqs, self.i_seqs = self._check_labels_are_i_seqs(self.atom_labels)
labels_are_i_seqs, self.i_seqs = self._check_labels_are_i_seqs(self.atom_labels)
if labels_are_i_seqs:
self.atom_labels_i = []
self.atom_labels_j = []
self.i_seqs, self.j_seqs = self.i_seqs # unpack tuple


Expand Down Expand Up @@ -446,26 +454,22 @@ class GeoParser:
"""


def __init__(self,geo_lines,model=None,entry_config=None,strict_counts=True):
def __init__(self,geo_lines,model=None,entry_config=None):
"""
Initialize with a list of Entry subclasses
Params:
geo_lines (list): List of line strings from .geo file
model (mmtbx.model.manager): Optional model file, a source of atom_label: i_seq matching.
entry_config (dict): Configuration dict
strict_counts (bool): Whether or not to enforce # of entries consistent with .geo header
"""

# Set initial arguments
self.entry_config = (entry_config if entry_config else entry_config_default)
self.model = model
self.strict_counts = strict_counts

# Initialize parsing variables
self.lines = geo_lines + ["\n"]
self.expected_counts = defaultdict(int)
self.actual_counts = defaultdict(int)

# Initialize result variables
self.entries = defaultdict(list) # Entry instances
Expand All @@ -474,8 +478,6 @@ def __init__(self,geo_lines,model=None,entry_config=None,strict_counts=True):

# Parse the file
self._parse()
if self.strict_counts:
self._validate_counts()

# If model present, add i_seqs
if self.model:
Expand All @@ -489,11 +491,15 @@ def proxies(self):
def has_proxies(self):
return self.proxies is not None

@property
def i_seqs_are_available(self):
return all([entry.i_seqs_are_available for entry in self.entries_list])

def _fill_labels_from_model(self, model):
"""
Add i_seq attributes for each atom on each entry
"""
if not self.labels_are_i_seqs:
if not self.i_seqs_are_available:
# make id_str:i_seq mapping
map_idstr_to_iseq = {
atom.id_str():atom.i_seq for atom in model.get_atoms()}
Expand Down Expand Up @@ -537,9 +543,11 @@ def _end_entry(self, entry_start_line_number, i, entries_info):
lines_for_entry = self.lines[entry_start_line_number:i]
new_entry = entries_info.entry_class(
lines=lines_for_entry,
origin_id=entries_info.origin_id)
origin_id=entries_info.origin_id,
origin_label=entries_info.origin_label,
)
# append new_entry to its list in self.entries
self.entries[entries_info.entry_class.name].append(new_entry)
self.entries[entries_info.origin_header].append(new_entry)
return i

def _parse(self):
Expand All @@ -553,16 +561,16 @@ def _parse(self):
result = origin_ids.get_origin_label_and_internal(l)
if result:
# if recognized as header, unpack result, store in entries_info
origin_id, header, _, num = result
origin_id, header, label, num = result
entry_class = self.entry_config[header]["entry_class"]
entry_trigger = self.entry_config[header]["entry_trigger"]
entries_info = group_args(
entry_class = entry_class,
origin_id = origin_id,
origin_label = label,
origin_header = header,
entry_trigger = entry_trigger,
)
info_key = str((entry_class.name,origin_id))
self.expected_counts[info_key] = num
entry_start_line_number = -1

elif l.startswith("Sorted by"):
Expand All @@ -578,29 +586,6 @@ def _parse(self):
# entry continues, do nothing
pass

def _validate_counts(self):
"""
Verify numbers match .geo headers, fail assertion otherwise.
If failed, will print expected and actual counts.
"""

# Populate actual counts
for header_name,entries in self.entries.items():
for entry in entries:
info_key = str((header_name, entry.origin_id))
self.actual_counts[info_key]+=1

if not self.expected_counts == self.actual_counts:
print("Validate counts error:")
print()
print("Expected counts:")
print(json.dumps(self.expected_counts,indent=2))
print()
print("Actual counts:")
print(json.dumps(self.actual_counts,indent=2))
assert False, "Expected and actual counts for (restraint header name, origin_id) pairs do not match."

@property
def records(self):
"""
Expand All @@ -616,23 +601,20 @@ def records(self):
self._records = record_dict
return self._records

@property
def labels_are_i_seqs(self):
return all([entry.labels_are_i_seqs for entry in self.entries_list])

def build_proxies(self):
"""
Convert the entry objects to cctbx proxy objects.
Collect into a dict of lists of proxies.
"""
if not self.model and not self.labels_are_i_seqs:
if not self.model and not self.i_seqs_are_available:
raise Sorry("Cannot build proxies without instantiating with a model.")
self._proxies = defaultdict(list)
for entries in self.entries.values():
if len(entries)>0:
entry_class = entries[0].__class__
if hasattr(entry_class,"to_proxy") and not hasattr(entry_class,"ignore"):
self._proxies[entry_class.name] = [entry.proxy for entry in entries]
self._proxies[entry_class.__name__] = [entry.proxy for entry in entries]
return self.proxies


Expand Down
Loading

0 comments on commit 5203ade

Please sign in to comment.