diff --git a/basis_set_exchange/api.py b/basis_set_exchange/api.py index c4cfec5c6..8f656a720 100644 --- a/basis_set_exchange/api.py +++ b/basis_set_exchange/api.py @@ -232,6 +232,9 @@ def get_basis(name, basis_dict = manip.optimize_general(basis_dict, False) needs_pruning = True + # Split any blocked contractions + basis_dict = manip.split_blocked_contractions(basis_dict, False) + # uncontract_segmented implies uncontract_general if uncontract_segmented: basis_dict = manip.uncontract_segmented(basis_dict, False) diff --git a/basis_set_exchange/manip.py b/basis_set_exchange/manip.py index 3ce0f9df6..65795a846 100644 --- a/basis_set_exchange/manip.py +++ b/basis_set_exchange/manip.py @@ -763,3 +763,99 @@ def truhlar_calendarize(basis, month, use_copy=True): basis = prune_basis(basis, False) return basis + + +def split_blocked_contractions(basis, use_copy=True): + '''Checks if the contraction coefficients in a general contraction can be made block diagonal and thereby split into two or more distinct shells + + + Parameters + ---------- + basis : dict + Basis set dictionary to work with + use_copy: bool + If True, the input basis set is not modified. + ''' + + if use_copy: + basis = copy.deepcopy(basis) + + for eldata in basis['elements'].values(): + + if 'electron_shells' not in eldata: + continue + + orig_shells = eldata['electron_shells'] + new_shells = [] + for sh in orig_shells: + coefficients = sh['coefficients'] + ncontr = len(coefficients) + nam = len(sh['angular_momentum']) + # Skip sp shells and shells with only one general contraction + if nam > 1 or ncontr == 1: + new_shells.append(sh) + continue + + exponents = sh['exponents'] + nprim = len(exponents) + + # Figure out which contractions share primitives between them + shared_primitives = [[False for _ in range(ncontr)] for _ in range(ncontr)] + for icontr in range(ncontr): + for jcontr in range(icontr + 1): + for iprim in range(nprim): + if float(coefficients[icontr][iprim]) != 0.0 and float(coefficients[jcontr][iprim]) != 0.0: + shared_primitives[icontr][jcontr] = True + shared_primitives[jcontr][icontr] = True + break + + # Which contractions have been processed + contraction_processed = [False for _ in range(ncontr)] + + # Indices of the contractions that are coupled + blocks = [] + for icontr in range(ncontr): + if not contraction_processed[icontr]: + # List of contracted functions in the block + block = [icontr] + contraction_processed[icontr] = True + + # Form the list + for jcontr in range(icontr + 1, ncontr): + # No need to analyze functions that have already been processed + if contraction_processed[jcontr]: + continue + if shared_primitives[icontr][jcontr]: + block.append(jcontr) + contraction_processed[jcontr] = True + blocks.append(block) + + # Do we need to do anything? + if len(blocks) == 1: + # All functions are in a single block; we keep the shell as it is + new_shells.append(sh) + continue + + # Create new shells + for block in blocks: + # Identify the used primitives + used_primitives = [] + for iprim in range(nprim): + for icontr in block: + if float(coefficients[icontr][iprim]) != 0.0: + if iprim not in used_primitives: + used_primitives.append(iprim) + continue + + # Form submatrices + reduced_exponents = [exponents[p] for p in used_primitives] + reduced_coefficients = [[coefficients[b][p] for p in used_primitives] for b in block] + redsh = sh.copy() + redsh['exponents'] = reduced_exponents + redsh['coefficients'] = reduced_coefficients + new_shells.append(redsh) + + # Replace the shells in the basis + eldata['electron_shells'] = new_shells + + return basis diff --git a/basis_set_exchange/readers/read.py b/basis_set_exchange/readers/read.py index e3c40c6a5..4868af2b2 100644 --- a/basis_set_exchange/readers/read.py +++ b/basis_set_exchange/readers/read.py @@ -7,6 +7,7 @@ from ..skel import create_skel from ..validator import validate_data from ..compose import _whole_basis_types +from ..manip import split_blocked_contractions from .turbomole import read_turbomole from .g94 import read_g94 from .nwchem import read_nwchem @@ -109,6 +110,9 @@ def read_formatted_basis_str(basis_str, basis_fmt, validate=False, as_component= if validate: validate_data(bs_type, data) + # Split any blocked contractions + data = split_blocked_contractions(data, False) + return data