Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split general contractions into blocks whenever possible #199

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions basis_set_exchange/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ def get_basis(name,
basis_dict = manip.optimize_general(basis_dict, False)
needs_pruning = True

# Split any blocked contractions
basis_dict = manip.split_blocked_contractions(basis_dict, False)

# uncontract_segmented implies uncontract_general
if uncontract_segmented:
basis_dict = manip.uncontract_segmented(basis_dict, False)
Expand Down
96 changes: 96 additions & 0 deletions basis_set_exchange/manip.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,3 +763,99 @@ def truhlar_calendarize(basis, month, use_copy=True):

basis = prune_basis(basis, False)
return basis


def split_blocked_contractions(basis, use_copy=True):
'''Checks if the contraction coefficients in a general contraction can be made block diagonal and thereby split into two or more distinct shells


Parameters
----------
basis : dict
Basis set dictionary to work with
use_copy: bool
If True, the input basis set is not modified.
'''

if use_copy:
basis = copy.deepcopy(basis)

for eldata in basis['elements'].values():

if 'electron_shells' not in eldata:
continue

orig_shells = eldata['electron_shells']
new_shells = []
for sh in orig_shells:
coefficients = sh['coefficients']
ncontr = len(coefficients)
nam = len(sh['angular_momentum'])
# Skip sp shells and shells with only one general contraction
if nam > 1 or ncontr == 1:
new_shells.append(sh)
continue

exponents = sh['exponents']
nprim = len(exponents)

# Figure out which contractions share primitives between them
shared_primitives = [[False for _ in range(ncontr)] for _ in range(ncontr)]
for icontr in range(ncontr):
for jcontr in range(icontr + 1):
for iprim in range(nprim):
if float(coefficients[icontr][iprim]) != 0.0 and float(coefficients[jcontr][iprim]) != 0.0:
shared_primitives[icontr][jcontr] = True
shared_primitives[jcontr][icontr] = True
break

# Which contractions have been processed
contraction_processed = [False for _ in range(ncontr)]

# Indices of the contractions that are coupled
blocks = []
for icontr in range(ncontr):
if not contraction_processed[icontr]:
# List of contracted functions in the block
block = [icontr]
contraction_processed[icontr] = True

# Form the list
for jcontr in range(icontr + 1, ncontr):
# No need to analyze functions that have already been processed
if contraction_processed[jcontr]:
continue
if shared_primitives[icontr][jcontr]:
block.append(jcontr)
contraction_processed[jcontr] = True
blocks.append(block)

# Do we need to do anything?
if len(blocks) == 1:
# All functions are in a single block; we keep the shell as it is
new_shells.append(sh)
continue

# Create new shells
for block in blocks:
# Identify the used primitives
used_primitives = []
for iprim in range(nprim):
for icontr in block:
if float(coefficients[icontr][iprim]) != 0.0:
if iprim not in used_primitives:
used_primitives.append(iprim)
continue

# Form submatrices
reduced_exponents = [exponents[p] for p in used_primitives]
reduced_coefficients = [[coefficients[b][p] for p in used_primitives] for b in block]
redsh = sh.copy()
redsh['exponents'] = reduced_exponents
redsh['coefficients'] = reduced_coefficients
new_shells.append(redsh)

# Replace the shells in the basis
eldata['electron_shells'] = new_shells

return basis
4 changes: 4 additions & 0 deletions basis_set_exchange/readers/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ..skel import create_skel
from ..validator import validate_data
from ..compose import _whole_basis_types
from ..manip import split_blocked_contractions
from .turbomole import read_turbomole
from .g94 import read_g94
from .nwchem import read_nwchem
Expand Down Expand Up @@ -109,6 +110,9 @@ def read_formatted_basis_str(basis_str, basis_fmt, validate=False, as_component=
if validate:
validate_data(bs_type, data)

# Split any blocked contractions
data = split_blocked_contractions(data, False)

return data


Expand Down
6 changes: 4 additions & 2 deletions basis_set_exchange/tests/test_curate_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ def test_curate_roundtrip(tmp_path, basis, fmt):
uncontract_spdf = 0

bse_formatted = api.get_basis(basis, fmt=fmt)
bse_dict = api.get_basis(basis, uncontract_general=uncontract_general, make_general=make_general)
bse_dict = manip.uncontract_spdf(bse_dict, uncontract_spdf)
bse_dict = api.get_basis(basis,
uncontract_general=uncontract_general,
make_general=make_general,
uncontract_spdf=uncontract_spdf)

outfile_path = os.path.join(tmp_path, 'roundtrip.txt')
with open(outfile_path, 'w', encoding='utf-8') as outfile:
Expand Down