Skip to content

Commit

Permalink
Use output consistently throughout main (#760)
Browse files Browse the repository at this point in the history
Solves a KeyError issue with an empty output dict in main, especially
after ARC cannot find a TS for a reaction.
  • Loading branch information
calvinp0 authored Aug 11, 2024
2 parents df19d8f + 8cb54d1 commit a3f79be
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
14 changes: 8 additions & 6 deletions arc/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,9 +634,11 @@ def execute(self) -> dict:
ts_adapters=self.ts_adapters,
report_e_elect=self.report_e_elect,
skip_nmd=self.skip_nmd,
output=self.output,
)

save_yaml_file(path=os.path.join(self.project_directory, 'output', 'status.yml'), content=self.scheduler.output)
self.output = self.scheduler.output
save_yaml_file(path=os.path.join(self.project_directory, 'output', 'status.yml'), content=self.output)

if not self.keep_checks:
delete_check_files(self.project_directory)
Expand All @@ -650,7 +652,7 @@ def execute(self) -> dict:
project_directory=self.project_directory,
species_dict=self.scheduler.species_dict,
reactions=self.scheduler.rxn_list,
output_dict=self.scheduler.output,
output_dict=self.output,
bac_type=self.bac_type,
freq_scale_factor=self.freq_scale_factor,
compute_thermo=self.compute_thermo,
Expand Down Expand Up @@ -708,7 +710,7 @@ def save_project_info_file(self):
txt += '\nConsidered the following species and TSs:\n'
for species in self.species:
descriptor = 'TS' if species.is_ts else 'Species'
failed = '' if self.scheduler.output[species.label]['convergence'] else ' (Failed!)'
failed = '' if self.output[species.label]['convergence'] else ' (Failed!)'
txt += f'{descriptor} {species.label}{failed} (run time: {species.run_time})\n'
if self.reactions:
for rxn in self.reactions:
Expand All @@ -730,14 +732,14 @@ def save_project_info_file(self):
if not species.is_ts:
spc_dict = dict()
spc_dict['label'] = species.label
spc_dict['success'] = self.scheduler.output[species.label]['convergence']
spc_dict['success'] = self.output[species.label]['convergence']
spc_dict['smiles'] = species.mol.copy(deep=True).to_smiles() if species.mol is not None else None
spc_dict['adj'] = species.mol.copy(deep=True).to_adjacency_list() if species.mol is not None else None
content['species'].append(spc_dict)
for reaction in self.reactions:
rxn_dict = dict()
rxn_dict['label'] = reaction.label
rxn_dict['success'] = self.scheduler.output[reaction.ts_species.label]['convergence']
rxn_dict['success'] = self.output[reaction.ts_species.label]['convergence']
content['reactions'].append(rxn_dict)
save_yaml_file(path=path, content=content)

Expand All @@ -750,7 +752,7 @@ def summary(self) -> dict:
"""
status_dict = {}
logger.info(f'\n\n\nAll jobs terminated. Summary for project {self.project}:\n')
for label, output in self.scheduler.output.items():
for label, output in self.output.items():
if output['convergence']:
status_dict[label] = True
logger.info(f'Species {label} converged successfully\n')
Expand Down
4 changes: 3 additions & 1 deletion arc/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class Scheduler(object):
ts_adapters (list, optional): Entries represent different TS adapters.
report_e_elect (bool, optional): Whether to report electronic energy. Default is ``False``.
skip_nmd (bool, optional): Whether to skip normal mode displacement check. Default is ``False``.
output (dict, optional): Output dictionary with status per job type and final QM file paths for all species.
Attributes:
project (str): The project's name. Used for naming the working directory.
Expand Down Expand Up @@ -257,6 +258,7 @@ def __init__(self,
ts_adapters: List[str] = None,
report_e_elect: Optional[bool] = False,
skip_nmd: Optional[bool] = False,
output: Optional[dict] = None,
) -> None:

self.project = project
Expand Down Expand Up @@ -287,7 +289,7 @@ def __init__(self,
self.freq_scale_factor = freq_scale_factor
self.ts_adapters = ts_adapters if ts_adapters is not None else default_ts_adapters
self.ts_adapters = [ts_adapter.lower() for ts_adapter in self.ts_adapters]
self.output = dict()
self.output = output or dict()
self.output_multi_spc = dict()
self.report_e_elect = report_e_elect
self.skip_nmd = skip_nmd
Expand Down

0 comments on commit a3f79be

Please sign in to comment.