diff --git a/arc/mapping/driver.py b/arc/mapping/driver.py index 428b3ce98e..21d27eb0b8 100644 --- a/arc/mapping/driver.py +++ b/arc/mapping/driver.py @@ -224,7 +224,7 @@ def map_rxn(rxn: 'ARCReaction', """ r_label_dict, p_label_dict = get_atom_indices_of_labeled_atoms_in_a_reaction(arc_reaction=rxn) - assign_labels_to_products(rxn) + assign_labels_to_products(rxn=rxn, products=rxn.get_family_products()) reactants, products = copy_species_list_for_mapping(rxn.r_species), copy_species_list_for_mapping(rxn.p_species) label_species_atoms(reactants), label_species_atoms(products) diff --git a/arc/mapping/engine.py b/arc/mapping/engine.py index def3109b0c..4afa348837 100644 --- a/arc/mapping/engine.py +++ b/arc/mapping/engine.py @@ -83,7 +83,7 @@ def pair_reaction_products(reaction: 'ARCReaction', products (List[ARCSpecies]): Species that correspond to the ARCReaction products that require pairing. Returns: - Dict[int, int]: Keys are specie indices in the ARC reaction, values are respective indices in the product list. + Dict[int, int]: Keys are species indices in the ARC reaction, values are respective indices in the product list. """ if reaction.is_isomerization(): return {0: 0} @@ -225,7 +225,7 @@ def pair_reaction_products(reaction: 'ARCReaction', # Returns: # Tuple[Dict[int, Union[List[int], int]], Dict[int, Union[List[int], int]]]: # The first tuple entry refers to reactants, the second to products. -# Keys are specie indices in the ARC reaction, +# Keys are specied indices in the ARC reaction, # values are respective indices in the RMG reaction. # If ``concatenate`` is ``True``, values are lists of integers. Otherwise, values are integers. # """ @@ -1146,33 +1146,34 @@ def get_label_dict(rxn: 'ARCReaction') -> Optional[Dict[str, int]]: return None -def assign_labels_to_products(rxn: 'ARCReaction'): +def assign_labels_to_products(rxn: 'ARCReaction', + products: List[Molecule], + ): """ Add the indices to the reactants and products. Args: rxn ('ARCReaction'): The reaction to be mapped. + products (List[Molecule]): The products generated from the RMG family with the same atom order as the reactants. Returns: Adding labels to the atoms of the reactants and products, to be identified later. """ label_dict = get_label_dict(rxn) - print(f'\n\nlabel_dict: {label_dict}\n\n') atom_index = 0 for r in rxn.r_species: for atom in r.mol.atoms: if atom_index in label_dict.values(): atom.label = key_by_val(label_dict, atom_index) atom_index += 1 - - - - + product_pairs = pair_reaction_products(reaction=rxn, products=products) atom_index = 0 - for product in rxn.p_species: - for atom in product.mol.atoms: - if atom_index in label_dict.values() and (atom.label is str or atom.label is None): - atom.label = key_by_val(label_dict, atom_index) + for product_index in range(len(products)): + rxn_product, fam_product = rxn.p_species[product_index], products[product_pairs[product_index]] + atom_map = map_two_species(spc_1=rxn_product, spc_2=fam_product, map_type='list') + for i, atom in enumerate(fam_product.atoms): + if atom_index in label_dict.values(): + rxn_product.mol.atoms[atom_map[i]].label = key_by_val(label_dict, atom_index) atom_index += 1 @@ -1185,7 +1186,7 @@ def update_xyz(spcs: List[ARCSpecies]) -> List[ARCSpecies]: spcs: the scission products that needs to be updated Returns: - new: A newely generated copies of the ARCSpecies, with updated xyz + list: A newly generated copies of the ARCSpecies, with updated xyz. """ new = list() for spc in spcs: @@ -1224,7 +1225,7 @@ def pairing_reactants_and_products_for_mapping(r_cuts: List[ARCSpecies], p_cuts: A list of the scissored species in the reactants Returns: - a list of paired reactant and products, to be sent to map_two_species. + list: Paired reactant and products, to be sent to map_two_species. """ pairs = [] for reactant_cut in r_cuts: @@ -1238,7 +1239,7 @@ def pairing_reactants_and_products_for_mapping(r_cuts: List[ARCSpecies], def map_pairs(pairs): """ - A function that maps the mached species together + A function that maps the matched species together Args: pairs: A list of the pairs of reactants and species @@ -1246,11 +1247,9 @@ def map_pairs(pairs): Returns: A list of the mapped species """ - maps = list() for pair in pairs: maps.append(map_two_species(pair[0], pair[1])) - return maps @@ -1261,11 +1260,11 @@ def label_species_atoms(spcs): Args: spcs: ARCSpecies object to be labeled. """ - index=0 + index = 0 for spc in spcs: for atom in spc.mol.atoms: atom.label = str(index) - index+=1 + index += 1 def glue_maps(maps, pairs_of_reactant_and_products): @@ -1273,11 +1272,11 @@ def glue_maps(maps, pairs_of_reactant_and_products): a function that joins together the maps from the parts of the reaction. Args: - rxn: ARCReaction that requires atom mapping maps: The list of all maps of the isomorphic cuts. + pairs_of_reactant_and_products: The pairs of the reactants and products. Returns: - an Atom Map of the compleate reaction. + list: An Atom Map of the complete reaction. """ am_dict = dict() for _map, pair in zip(maps, pairs_of_reactant_and_products): @@ -1319,18 +1318,21 @@ def determine_bdes_on_spc_based_on_atom_labels(spc: "ARCSpecies", bde: Tuple[int return False -def cut_species_based_on_atom_indices(species: List["ARCSpecies"], bdes: List[Tuple[int, int]]) -> Optional[List["ARCSpecies"]]: +def cut_species_based_on_atom_indices(species: List["ARCSpecies"], + bdes: List[Tuple[int, int]], + ) -> Optional[List["ARCSpecies"]]: """ A function for scissoring species based on their atom indices. + Args: species (List[ARCSpecies]): The species list that requires scission. bdes (List[Tuple[int, int]]): A list of the atoms between which the bond should be scissored. The atoms are described using the atom labels, and not the actuall atom positions. + Returns: Optional[List["ARCSpecies"]]: The species list input after the scission. """ if not bdes: return species - for bde in bdes: for index, spc in enumerate(species): if determine_bdes_on_spc_based_on_atom_labels(spc, bde): @@ -1351,7 +1353,6 @@ def cut_species_based_on_atom_indices(species: List["ARCSpecies"], bdes: List[Tu except SpeciesError: return None break - return species @@ -1372,6 +1373,7 @@ def copy_species_list_for_mapping(species: List["ARCSpecies"]) -> List["ARCSpeci def find_all_bdes(rxn: "ARCReaction", is_reactants: bool, + products: Optional[List["Molecule"]] = None, ) -> List[Tuple[int, int]]: """ A function for finding all the broken(/formed) bonds during a chemical reaction, based on the atom indices. @@ -1379,15 +1381,32 @@ def find_all_bdes(rxn: "ARCReaction", Args: rxn (ARCReaction): The reaction in question. is_reactants (bool): Whether the species list represents reactants or products. + products (List[Molecule], optional): The products generated from the RMG family with the same atom order + as the reactants. If given, the BDE values will be mapped from them + to the reaction products. Returns: List[Tuple[int, int]]: A list of tuples of the form (atom_index1, atom_index2) for each broken bond. Note that these represent the atom indices to be cut, and not final BDEs. """ label_dict = get_label_dict(rxn) + if not is_reactants: + product_pairs = pair_reaction_products(reaction=rxn, products=products) if products is not None else None + print(label_dict) bdes = list() if rxn.family is not None: for action in ReactionFamily(rxn.family).actions: - if action[0].lower() == ("break_bond" if is_reactants else "form_bond"): + print(action) + if (action[0].lower() == "break_bond" and is_reactants + or action[0].lower() == "form_bond" and not is_reactants): + print(f'appending {action[1]} and {action[3]}: {(label_dict[action[1]] + 1, label_dict[action[3]] + 1)}') bdes.append((label_dict[action[1]] + 1, label_dict[action[3]] + 1)) return bdes + + +rxn_product, fam_product = rxn.p_species[product_index], products[product_pairs[product_index]] +atom_map = map_two_species(spc_1=rxn_product, spc_2=fam_product, map_type='list') +for i, atom in enumerate(fam_product.atoms): + if atom_index in label_dict.values(): + rxn_product.mol.atoms[atom_map[i]].label = key_by_val(label_dict, atom_index) + atom_index += 1 \ No newline at end of file diff --git a/arc/mapping/engine_test.py b/arc/mapping/engine_test.py index eb2a1fbb84..f56eb4fd3c 100644 --- a/arc/mapping/engine_test.py +++ b/arc/mapping/engine_test.py @@ -10,7 +10,6 @@ from random import shuffle import itertools -from arc.common import _check_r_n_p_symbols_between_rmg_and_arc_rxns from arc.mapping.engine import * from arc.reaction import ARCReaction @@ -512,16 +511,11 @@ def test_pair_reaction_products(self): def test_assign_labels_to_products(self): """Test assigning labels to products based on the atom map of the reaction""" rxn_1_test = ARCReaction(r_species=[self.r_1, self.r_2], p_species=[self.p_1, self.p_2]) - assign_labels_to_products(rxn_1_test) - print([atom.label for atom in rxn_1_test.p_species[0].mol.atoms]) - index = 0 - for product in rxn_1_test.p_species: - print(product.label, index) - for atom in product.mol.atoms: - if not isinstance(atom.label, str) or atom.label != "": - print(atom.label, index) - self.assertEqual(self.p_label_dict_rxn_1[atom.label], index) - index += 1 + assign_labels_to_products(rxn_1_test, rxn_1_test.get_family_products()) + self.assertEqual([atom.label for atom in rxn_1_test.r_species[0].mol.atoms], ['*3', '', '']) + self.assertEqual([atom.label for atom in rxn_1_test.r_species[1].mol.atoms], ['*1', '', '*2']) + self.assertEqual([atom.label for atom in rxn_1_test.p_species[0].mol.atoms], ['*3', '', '*1', '']) + self.assertEqual([atom.label for atom in rxn_1_test.p_species[1].mol.atoms], ['', '*2']) def test_inc_vals(self): """Test creating an atom map via map_two_species() and incrementing all values""" @@ -550,19 +544,22 @@ def test_label_species_atoms(self): def test_cut_species_based_on_atom_indices(self): """test the cut_species_for_mapping function""" rxn_1_test = ARCReaction(r_species=[self.r_1, self.r_2], p_species=[self.p_1, self.p_2], - rmg_family_set=['F_Abstraction']) - reactants, products = copy_species_list_for_mapping(rxn_1_test.r_species), copy_species_list_for_mapping(rxn_1_test.p_species) + rmg_family_set=['H_Abstraction']) + reactants = copy_species_list_for_mapping(rxn_1_test.r_species) + products = copy_species_list_for_mapping(rxn_1_test.p_species) label_species_atoms(reactants), label_species_atoms(products) r_bdes, p_bdes = find_all_bdes(rxn_1_test, True), find_all_bdes(rxn_1_test, False) + print(r_bdes, p_bdes) r_cuts = cut_species_based_on_atom_indices(reactants, r_bdes) p_cuts = cut_species_based_on_atom_indices(products, p_bdes) - - self.assertIn("C[CH]C", [r_cut.mol.copy(deep=True).smiles for r_cut in r_cuts]) - self.assertIn("[F]", [r_cut.mol.copy(deep=True).smiles for r_cut in r_cuts]) - self.assertIn("[CH3]", [r_cut.mol.copy(deep=True).smiles for r_cut in r_cuts]) - self.assertIn("C[CH]C", [p_cut.mol.copy(deep=True).smiles for p_cut in p_cuts]) - self.assertIn("[F]", [p_cut.mol.copy(deep=True).smiles for p_cut in p_cuts]) - self.assertIn("[CH3]", [p_cut.mol.copy(deep=True).smiles for p_cut in p_cuts]) + print([a.mol for a in r_cuts], [a.mol for a in p_cuts]) + + self.assertIn('[C]#CF', [r_cut.mol.copy(deep=True).smiles for r_cut in r_cuts]) + self.assertIn('[C]#N', [r_cut.mol.copy(deep=True).smiles for r_cut in r_cuts]) + self.assertIn('[H]', [r_cut.mol.copy(deep=True).smiles for r_cut in r_cuts]) + self.assertIn('[C]#CF', [p_cut.mol.copy(deep=True).smiles for p_cut in p_cuts]) + self.assertIn('[C]#N', [p_cut.mol.copy(deep=True).smiles for p_cut in p_cuts]) + self.assertIn('[H]', [p_cut.mol.copy(deep=True).smiles for p_cut in p_cuts]) spc = ARCSpecies(label="test", smiles="CNC", bdes=[(1, 2), (2, 3)]) for i, a in enumerate(spc.mol.atoms): diff --git a/arc/reaction/reaction.py b/arc/reaction/reaction.py index d019ba6da7..7d499d3e15 100644 --- a/arc/reaction/reaction.py +++ b/arc/reaction/reaction.py @@ -2,10 +2,9 @@ A module for representing a reaction. """ -from typing import Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from arkane.common import get_element_mass -from rmgpy.reaction import Reaction from rmgpy.species import Species from arc.common import get_logger @@ -20,6 +19,9 @@ from arc.mapping.driver import map_reaction from arc.species.species import ARCSpecies, check_atom_balance, check_label +if TYPE_CHECKING: + from rmgpy.molecule import Molecule + logger = get_logger() @@ -533,6 +535,22 @@ def determine_family(self, return family, family_own_reverse return None, None + def get_family_products(self) -> Optional[List['Molecule']]: + """ + Determine the RMG reaction family. + Populates the .family, and .family_own_reverse attributes. + + Returns: + Optional[List[Molecule]]: The products of the reaction with the same atom order as the reactants, + generated by the family. Currently only returning the first product list. + """ + product_dicts = get_reaction_family_products(rxn=self, + rmg_family_set=[self.family], + ) + if len(product_dicts): + return product_dicts[0]['products'] + return None + def check_attributes(self): """Check that the Reaction object is defined correctly""" self.set_label_reactants_products()