Skip to content

Commit

Permalink
Cleanup, typos
Browse files Browse the repository at this point in the history
  • Loading branch information
olegsobolev committed Oct 8, 2024
1 parent 8a9c476 commit 6e135cd
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 72 deletions.
5 changes: 3 additions & 2 deletions cctbx/geometry_restraints/linking_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,13 @@ def get_origin_label_and_internal(self, query_header, verbose=False):
Line from a .geo file
Returns:
None if not a header line
If heasder:
If header:
tuple of:
Origin_id :
Bond type : Bond, Bond angle, ...
Subtype : A string related mostly to the cif_link used but also
origin_id
Number of restraints: int - number from the header.
'''
if verbose:
for origin_label, info in self.data.items():
Expand Down
127 changes: 57 additions & 70 deletions mmtbx/geometry_restraints/geo_file_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Entry:
"""
name = None # Name or label for the class

def __init__(self,lines,line_idx=None,origin_id=0,origin_label='covalent'):
def __init__(self,lines,origin_id=0):
"""
An entry is initialized first, then data is added with entry.lines.append()
Expand All @@ -42,7 +42,6 @@ def __init__(self,lines,line_idx=None,origin_id=0,origin_label='covalent'):
self.atom_labels = [] # list of string atom label from .geo
self._numerical = None # a dict of numerical geo data
self.labels_are_i_seqs = None # boolean, atom labels are i_seqs
self.labels_are_id_strs = None # boolean, atom labels are id_strs
self.origin_id = origin_id

# Initialize result data structures
Expand Down Expand Up @@ -72,14 +71,14 @@ def _prepare(self):
# Numerical values
values = self.lines[-1]
numerical_values = values.split()
numerical_values = [self._coerce_type(v) for v in numerical_values]
numerical_values = [coerce_type(v) for v in numerical_values]
self._numerical = dict(zip(numerical_labels,numerical_values))

def _check_labels_are_i_seqs(self,atom_labels):
"""
If all the labels are integers, assume i_seqs
"""
i_seqs = [self._try_int(v) for v in atom_labels]
i_seqs = [try_int(v) for v in atom_labels]
check = all([isinstance(i_seq,int) for i_seq in i_seqs])
if not check:
i_seqs = []
Expand Down Expand Up @@ -131,48 +130,6 @@ def proxy(self):
self._proxy = self.to_proxy()
return self._proxy


def _try_int(self, val):
"""
Try to convert to int
"""
try:
return int(val)
except (ValueError, TypeError):
return None

def _try_float(self, val):
"""
Try to convert to flaot
"""
try:
return float(val)
except (ValueError, TypeError):
return None

def _try_numeric(self,val):
"""
Try to convert input to 1) int, 2) float
Otherwise return None
"""
out = self._try_int(val)
if out:
return out
out = self._try_float(val)
if out:
return out
return None

def _coerce_type(self,val):
"""
If input can be numeric, convert it
Else return unmodified
"""
out = self._try_numeric(val)
if out:
return out
return val # input (probably string)

### Start of Entry subclasses

class NonBondedEntry(Entry):
Expand Down Expand Up @@ -297,7 +254,7 @@ def _prepare(self):
nums_T = list(map(list, zip(*nums))) # transpose
numerical_values = nums_T

numerical_values = [[self._coerce_type(v) for v in vals] for vals in numerical_values]
numerical_values = [[coerce_type(v) for v in vals] for vals in numerical_values]
self._numerical = dict(zip(numerical_labels,numerical_values))

@property
Expand Down Expand Up @@ -393,7 +350,7 @@ def _prepare(self):
num_labels = shlex.split(line0[num_idx:])
num_values = all_parts[0][2:]

num_values = [self._coerce_type(v) for v in num_values]
num_values = [coerce_type(v) for v in num_values]
self._numerical = dict(zip(num_labels,num_values))

def _check_labels_are_i_seqs(self,*args):
Expand All @@ -406,10 +363,6 @@ def _check_labels_are_i_seqs(self,*args):
check = check_i and check_j
return check, (i_seqs, j_seqs)

if not check:
i_seqs = []
return check, i_seqs

@property
def n_atoms(self):
return len(self.atom_labels_i) + len(sel.atom_labels_j)
Expand Down Expand Up @@ -511,10 +464,6 @@ def __init__(self,geo_lines,model=None,entry_config=None,strict_counts=True):

# Initialize parsing variables
self.lines = geo_lines + ["\n"]
self.current_entry= None
self.current_entry_class = None
self.current_origin_id = 0
self.current_origin_label = 'covalent'
self.expected_counts = defaultdict(int)
self.actual_counts = defaultdict(int)

Expand Down Expand Up @@ -546,21 +495,17 @@ def _fill_labels_from_model(self, model):
"""
if not self.labels_are_i_seqs:
# make i_seq:id_str mapping
map_iseq_to_idstr = {
atom.i_seq:atom.id_str()
for atom in model.get_atoms()}

# Reverse
map_idstr_to_iseq = {id_str:i_seq for i_seq,id_str in map_iseq_to_idstr.items()}
map_idstr_to_iseq = {
atom.id_str():atom.i_seq for atom in model.get_atoms()}

for entry in self.entries_list:
# Case where entry was loaded with id_strs, can load i_seqs
entry.labels_are_id_strs = True
labels_are_id_strs = True
for val in entry.atom_labels:
if val not in map_idstr_to_iseq:
entry.labels_are_id_strs = False
labels_are_id_strs = False

if entry.labels_are_id_strs:
if labels_are_id_strs:
entry.i_seqs = [map_idstr_to_iseq[idstr] for idstr in entry.atom_labels]

@property
Expand Down Expand Up @@ -600,15 +545,15 @@ def _end_entry(self, entry_start_line_number, i, entries_info):
def _parse(self):
entries_info=None # entries_info = group_args(entry_class, origin_id, entry_trigger)
entry_start_line_number = -1
for i, l in enumerate(self.lines + ["\n"]):
for i, l in enumerate(self.lines):
if entries_info is None:
# new block starts because we don't know entries_info

# Query linking class with line
result = origin_ids.get_origin_label_and_internal(l)
if result:
# if recognized as header, unpack result, store in entries_info
origin_id, header, restraint_class, num = result
origin_id, header, _, num = result
entry_class = self.entry_config[header]["entry_class"]
entry_trigger = self.entry_config[header]["entry_trigger"]
entries_info = group_args(
Expand All @@ -617,7 +562,7 @@ def _parse(self):
entry_trigger = entry_trigger,
)
info_key = str((entry_class.name,origin_id))
self.expected_counts[info_key]+=num
self.expected_counts[info_key] = num
entry_start_line_number = -1

elif l.startswith("Sorted by"):
Expand All @@ -637,7 +582,7 @@ def _validate_counts(self):
"""
Verify numbers match .geo headers, fail assertion otherwise.
I failed, will print expected and actual counts.
If failed, will print expected and actual counts.
"""

# Populate actual counts
Expand Down Expand Up @@ -668,7 +613,7 @@ def records(self):
for entry_name, entries in self.entries.items():
if len(entries)>0:
record_dict[entry_name] = [entry.record for entry in entries]
self._records = record_dict
self._records = record_dict
return self._records

@property
Expand All @@ -689,3 +634,45 @@ def build_proxies(self):
if hasattr(entry_class,"to_proxy") and not hasattr(entry_class,"ignore"):
self._proxies[entry_class.name] = [entry.proxy for entry in entries]
return self.proxies


def try_int(val):
"""
Try to convert to int
"""
try:
return int(val)
except (ValueError, TypeError):
return None

def try_float(val):
"""
Try to convert to flaot
"""
try:
return float(val)
except (ValueError, TypeError):
return None

def try_numeric(val):
"""
Try to convert input to 1) int, 2) float
Otherwise return None
"""
out = try_int(val)
if out:
return out
out = try_float(val)
if out:
return out
return None

def coerce_type(val):
"""
If input can be numeric, convert it
Else return unmodified
"""
out = try_numeric(val)
if out:
return out
return val # input (probably string)

0 comments on commit 6e135cd

Please sign in to comment.