Skip to content

Commit

Permalink
Merge pull request #154 from ReactionMechanismGenerator/label_fix
Browse files Browse the repository at this point in the history
Improve parsing parameters from RMG SA csv files
  • Loading branch information
alongd authored Jul 8, 2024
2 parents 0a07454 + b525380 commit 56a236c
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 22 deletions.
46 changes: 46 additions & 0 deletions t3/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,49 @@ def get_chem_to_rmg_rxn_index_map(chem_annotated_path: str) -> Dict[int, int]:
splits = line.split()
rxn_map[int(splits[4].split('#')[1].split(';')[0])] = int(splits[-1].split('#')[1])
return rxn_map


def get_observable_label_from_header(header: str) -> str:
    """
    Get the observable label from a header in an RMG SA csv file.

    The observable is the text enclosed by the first bracket pair in the header.
    Brackets are matched by depth, so a label that itself contains brackets
    (e.g., 'CC(C)CO[O](237)') is returned intact rather than being cut at the first ']'.

    observable extraction examples:
        get 'ethane(1)' from `dln[ethane(1)]/dln[k8]: H(6)+ethane(1)=H2(12)+C2H5(5)`
        get 'CC(C)CO[O](237)' from `dln[CC(C)CO[O](237)]/dG[H(15)]`

    Args:
        header (str): The header from an RMG SA csv file.

    Raises:
        IndexError: If the header contains no '[' character.

    Returns:
        str: The observable label.
    """
    open_pos = header.find('[')
    if open_pos == -1:
        # Keep the exception type of the former ``header.split('[')[1]`` implementation.
        raise IndexError(f'Could not find an observable label in header: {header}')
    start_pos = open_pos + 1
    end_pos = len(header)  # fall back to the end of the header if the bracket is never closed
    depth = 1
    for i in range(start_pos, len(header)):
        char = header[i]
        if char == '[':
            depth += 1
        elif char == ']':
            depth -= 1
            if depth == 0:
                end_pos = i
                break
    return header[start_pos:end_pos]


def get_parameter_from_header(header: str) -> Optional[str]:
    """
    Get the parameter label from a header in an RMG SA csv file.

    The parameter is the text enclosed by the bracket pair that follows the first
    recognized derivative marker. Markers are searched in the order '/dG[', '/dln[', '/dH['.
    Brackets are matched by depth, so a label that itself contains brackets is returned intact.

    parameter extraction examples:
        for species get 'C2H4(8)' from `dln[ethane(1)]/dG[C2H4(8)]`
        for species get 'C2H4(8)' from `dln[ethane(1)]/dH[C2H4(8)]`
        for species get 'CC(C)CO[O](237)' from `dln[H(15)]/dG[CC(C)CO[O](237)]`
        for reaction, get k8 from `dln[ethane(1)]/dln[k8]: H(6)+ethane(1)=H2(12)+C2H5(5)`

    Args:
        header (str): The header from an RMG SA csv file.

    Returns:
        Optional[str]: The parameter label, or ``None`` if the header contains none of the markers.
    """
    # '/dG[' and '/dln[' keep their original precedence; '/dH[' is only consulted when both are absent.
    for marker in ('/dG[', '/dln[', '/dH['):
        marker_pos = header.find(marker)
        if marker_pos != -1:
            break
    else:
        return None
    start_pos = marker_pos + len(marker)
    end_pos = len(header)  # fall back to the end of the header if the bracket is never closed
    depth = 1
    for i in range(start_pos, len(header)):
        char = header[i]
        if char == '[':
            depth += 1
        elif char == ']':
            depth -= 1
            if depth == 0:
                end_pos = i
                break
    return header[start_pos:end_pos]

16 changes: 8 additions & 8 deletions t3/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,16 +750,16 @@ def determine_species_based_on_sa(self) -> List[int]:
return species_keys

sa_dict_max = {'kinetics': dict(), 'thermo': dict()}
for key in ['kinetics', 'thermo']:
for observable_label in self.sa_dict[key].keys():
if observable_label not in sa_dict_max[key]:
sa_dict_max[key][observable_label] = list()
for parameter in self.sa_dict[key][observable_label].keys():
for sa_dict_key in ['kinetics', 'thermo']:
for observable_label in self.sa_dict[sa_dict_key]:
if observable_label not in sa_dict_max[sa_dict_key]:
sa_dict_max[sa_dict_key][observable_label] = list()
for parameter in self.sa_dict[sa_dict_key][observable_label]:
entry = dict()
entry['parameter'] = parameter # rxn number (int) or spc label (str)
entry['max_sa'] = max(self.sa_dict[key][observable_label][parameter].max(),
abs(self.sa_dict[key][observable_label][parameter].min()))
sa_dict_max[key][observable_label].append(entry)
entry['max_sa'] = max(self.sa_dict[sa_dict_key][observable_label][parameter].max(),
abs(self.sa_dict[sa_dict_key][observable_label][parameter].min()))
sa_dict_max[sa_dict_key][observable_label].append(entry)

for observable_label, sa_list in sa_dict_max['kinetics'].items():
sa_list_sorted = sorted(sa_list, key=lambda item: item['max_sa'], reverse=True)
Expand Down
13 changes: 5 additions & 8 deletions t3/simulate/cantera_constantHP.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from rmgpy.tools.canteramodel import generate_cantera_conditions
from rmgpy.tools.data import GenericData

from t3.common import get_observable_label_from_header, get_parameter_from_header
from t3.logger import Logger
from t3.simulate.adapter import SimulateAdapter
from t3.simulate.factory import register_simulate_adapter
Expand Down Expand Up @@ -400,23 +401,19 @@ def get_sa_coefficients(self):

# extract kinetic SA
for rxn in reaction_sensitivity_data:
# for kinetics, get `ethane(1)` from `dln[ethane(1)]/dln[k8]: H(6)+ethane(1)=H2(12)+C2H5(5)`
observable_label = rxn.label.split('[')[1].split(']')[0]
observable_label = get_observable_label_from_header(rxn)
if observable_label not in sa_dict['kinetics']:
sa_dict['kinetics'][observable_label] = dict()
# for kinetics, get k8 from `dln[ethane(1)]/dln[k8]: H(6)+ethane(1)=H2(12)+C2H5(5)` then only extract 8
parameter = rxn.label.split('[')[2].split(']')[0]
parameter = get_parameter_from_header(rxn)
parameter = int(parameter[1:])
sa_dict['kinetics'][observable_label][parameter] = rxn.data

# extract thermo SA
for spc in thermodynamic_sensitivity_data:
# for thermo get 'C2H4(8)' from `dln[ethane(1)]/dH[C2H4(8)]`
observable_label = spc.label.split('[')[1].split(']')[0]
observable_label = get_observable_label_from_header(spc)
if observable_label not in sa_dict['thermo']:
sa_dict['thermo'][observable_label] = dict()
# for thermo get 'C2H4(8)' from `dln[ethane(1)]/dH[C2H4(8)]`
parameter = spc.label.split('[')[2].split(']')[0]
parameter = get_parameter_from_header(spc)
sa_dict['thermo'][observable_label][parameter] = spc.data

return sa_dict
Expand Down
11 changes: 5 additions & 6 deletions t3/simulate/rmg_constantTP.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import itertools
import os
import pandas as pd
import re
import shutil
from typing import List, Optional

Expand All @@ -18,7 +19,8 @@
from rmgpy.tools.loader import load_rmg_py_job
from rmgpy.tools.plot import plot_sensitivity

from t3.common import get_chem_to_rmg_rxn_index_map, get_species_by_label, get_values_within_range, time_lapse
from t3.common import get_chem_to_rmg_rxn_index_map, get_species_by_label, get_values_within_range, \
get_observable_label_from_header, get_parameter_from_header, time_lapse
from t3.logger import Logger
from t3.simulate.adapter import SimulateAdapter
from t3.simulate.factory import register_simulate_adapter
Expand Down Expand Up @@ -256,17 +258,14 @@ def get_sa_coefficients(self) -> Optional[dict]:
elif '/dG[' in header:
sa_type = 'thermo'
if sa_type is not None:
observable_label = header.split('[')[1].split(']')[0]
observable_label = get_observable_label_from_header(header)
observable = get_species_by_label(observable_label, self.rmg_model.reaction_model.core.species)
if observable is None:
self.logger.error(f'Could not identify observable species for label: {observable_label}')
observable_label = observable.to_chemkin()
if observable_label not in sa_dict[sa_type].keys():
sa_dict[sa_type][observable_label] = dict()
# parameter extraction examples:
# for species get 'C2H4(8)' from `dln[ethane(1)]/dG[C2H4(8)]`
# for reaction, get 8 from `dln[ethane(1)]/dln[k8]: H(6)+ethane(1)=H2(12)+C2H5(5)`
parameter = header.split('[')[2].split(']')[0]
parameter = get_parameter_from_header(header)
if sa_type == 'kinetics':
parameter = parameter[1:]
parameter = chem_to_rmg_rxn_index_map[int(parameter)] \
Expand Down
42 changes: 42 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,45 @@ def test_get_chem_to_rmg_rxn_index_map():
39: 36, 40: 37, 41: 38, 42: 39, 43: 40, 44: 41, 45: 41, 46: 42, 47: 43, 48: 44, 49: 45, 50: 46,
51: 47, 52: 48, 53: 49, 54: 50, 55: 51, 56: 52, 57: 53, 58: 54, 59: 55, 60: 56, 61: 57, 62: 58,
63: 59, 64: 60, 65: 61}


def test_get_observable_label_from_header():
    """
    Check that `get_observable_label_from_header` extracts the observable label
    from representative headers of the RMG simulation csv file.
    """
    # (header, expected observable label) pairs, checked in order
    cases = [
        ('dln[H(15)]/dln[k2]: O2(14)+H(15)(+M)<=>HO2(17)(+M)', 'H(15)'),
        ('dln[C2H4(12)]/dln[k16]: C8H17(11)<=>C8H17(5)', 'C2H4(12)'),
        ('dln[H(15)]/dG[CC(C)CO[O](237)]', 'H(15)'),
        ('dln[ethane(1)]/dln[k8]: H(6)+ethane(1)=H2(12)+C2H5(5)', 'ethane(1)'),
    ]
    for header, expected_label in cases:
        label = common.get_observable_label_from_header(header)
        assert label == expected_label


def test_get_parameter_from_header():
    """
    Check that `get_parameter_from_header` extracts the parameter label
    from representative headers of the RMG simulation csv file.
    """
    # (header, expected parameter label) pairs, checked in order
    cases = [
        ('dln[H(15)]/dln[k2]: O2(14)+H(15)(+M)<=>HO2(17)(+M)', 'k2'),
        ('dln[C2H4(12)]/dln[k16]: C8H17(11)<=>C8H17(5)', 'k16'),
        ('dln[H(15)]/dG[t-C4H9(60)]', 't-C4H9(60)'),
        ('dln[H(15)]/dG[CC(C)CO[O](237)]', 'CC(C)CO[O](237)'),
        ('dln[ethane(1)]/dG[C2H4(8)]', 'C2H4(8)'),
        ('dln[ethane(1)]/dln[k8]: H(6)+ethane(1)=H2(12)+C2H5(5)', 'k8'),
    ]
    for header, expected_label in cases:
        label = common.get_parameter_from_header(header)
        assert label == expected_label

0 comments on commit 56a236c

Please sign in to comment.