diff --git a/t3/main.py b/t3/main.py index fd2a141f..757976e0 100755 --- a/t3/main.py +++ b/t3/main.py @@ -45,7 +45,12 @@ from arc.species.species import ARCSpecies, check_label from arc.species.converter import check_xyz_dict -from t3.common import PROJECTS_BASE_PATH, VALID_CHARS, delete_root_rmg_log, get_species_by_label, time_lapse +from t3.common import (DATA_BASE_PATH, + PROJECTS_BASE_PATH, + VALID_CHARS, + delete_root_rmg_log, + get_species_by_label, + time_lapse) from t3.logger import Logger from t3.runners.rmg_runner import rmg_runner from t3.schema import InputBase @@ -177,6 +182,7 @@ def __init__(self, self.project_directory = self.schema['project_directory'] self.t3 = self.schema['t3'] self.rmg = self.schema['rmg'] + self.rmg['database'] = auto_complete_rmg_libraries(database=self.rmg['database']) self.qm = self.schema['qm'] self.verbose = self.schema['verbose'] @@ -1538,3 +1544,34 @@ def get_species_with_qm_label(species: Species, rmg_species=qm_species, ) return qm_species + + +def auto_complete_rmg_libraries(database: dict) -> dict: + """ + Update the RMG libraries using auto-completion. + + Args: + database (dict): The RMG libraries dictionary. + + Returns: + dict: The updated RMG libraries dictionary. + """ + database['thermo_libraries'] = database['thermo_libraries'] or list() + database['kinetics_libraries'] = database['kinetics_libraries'] or list() + if database['chemistry_sets'] is not None: + libraries_dict = read_yaml_file(path=os.path.join(DATA_BASE_PATH, 'libraries.yml')) + low_credence = database['use_low_credence_libraries'] + for chemistry_set in database['chemistry_sets']: + if chemistry_set not in libraries_dict: + raise ValueError(f"Chemistry set '{chemistry_set}' not found in the libraries.yml file.") + for key, libraries in zip(['thermo', 'kinetics'], [database['thermo_libraries'], database['kinetics_libraries']]): + if key in libraries_dict[chemistry_set]: + for entry in libraries_dict[chemistry_set][key]: + if isinstance(entry, str) and entry not in libraries: + libraries.append(entry) + elif isinstance(entry, dict): + if entry['credence'] == 'low' and low_credence or entry['credence'] != 'low' and entry['name'] not in libraries: + libraries.append(entry['name']) + del database['chemistry_sets'] + del database['use_low_credence_libraries'] + return database