Skip to content

Commit

Permalink
In case a generally contracted shell can be split into independent bl…
Browse files Browse the repository at this point in the history
…ocks, do so.
  • Loading branch information
susilehtola committed Feb 26, 2021
1 parent 32721c8 commit 7bb517b
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 0 deletions.
3 changes: 3 additions & 0 deletions basis_set_exchange/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ def get_basis(name,
basis_dict = manip.optimize_general(basis_dict, False)
needs_pruning = True

# Split any blocked contractions
basis_dict = manip.split_blocked_contractions(basis_dict, False)

# uncontract_segmented implies uncontract_general
if uncontract_segmented:
basis_dict = manip.uncontract_segmented(basis_dict, False)
Expand Down
96 changes: 96 additions & 0 deletions basis_set_exchange/manip.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,3 +763,99 @@ def truhlar_calendarize(basis, month, use_copy=True):

basis = prune_basis(basis, False)
return basis


def split_blocked_contractions(basis, use_copy=True):
'''Checks if the contraction coefficients in a general contraction can be made block diagonal and thereby split into two or more distinct shells
Parameters
----------
basis : dict
Basis set dictionary to work with
use_copy: bool
If True, the input basis set is not modified.
'''

if use_copy:
basis = copy.deepcopy(basis)

for eldata in basis['elements'].values():

if 'electron_shells' not in eldata:
continue

orig_shells = eldata['electron_shells']
new_shells = []
for sh in orig_shells:
coefficients = sh['coefficients']
ncontr = len(coefficients)
nam = len(sh['angular_momentum'])
# Skip sp shells and shells with only one general contraction
if nam > 1 or ncontr == 1:
new_shells.append(sh)
continue

exponents = sh['exponents']
nprim = len(exponents)

# Figure out which contractions share primitives between them
shared_primitives = [[False for _ in range(ncontr)] for _ in range(ncontr)]
for icontr in range(ncontr):
for jcontr in range(icontr + 1):
for iprim in range(nprim):
if float(coefficients[iprim][icontr]) != 0.0 and float(coefficients[iprim][jcontr]) != 0.0:
shared_primitives[icontr][jcontr] = True
shared_primitives[jcontr][icontr] = True
break

# Which contractions have been processed
contraction_processed = [False for _ in range(ncontr)]

# Indices of the contractions that are coupled
blocks = []
for icontr in range(ncontr):
if not contraction_processed[icontr]:
# List of contracted functions in the block
block = [icontr]
contraction_processed[icontr] = True

# Form the list
for jcontr in range(icontr + 1, ncontr):
# No need to analyze functions that have already been processed
if contraction_processed[jcontr]:
continue
if shared_primitives[icontr][jcontr]:
block.append(jcontr)
contraction_processed[jcontr] = True
blocks.append(block)

# Do we need to do anything?
if len(blocks) == 1:
# All functions are in a single block; we keep the shell as it is
new_shells.append(sh)
continue

# Create new shells
for block in blocks:
# Identify the used primitives
used_primitives = []
for iprim in range(nprim):
for icontr in block:
if float(coefficients[iprim][icontr]) != 0.0:
if iprim not in used_primitives:
used_primitives.append(iprim)
continue

# Form submatrices
reduced_exponents = [exponents[p] for p in used_primitives]
reduced_coefficients = [[coefficients[b][p] for p in used_primitives] for b in block]
redsh = sh.copy()
redsh['exponents'] = reduced_exponents
redsh['coefficients'] = reduced_coefficients
new_shells.append(redsh)

# Replace the shells in the basis
eldata['electron_shells'] = new_shells

return basis
4 changes: 4 additions & 0 deletions basis_set_exchange/readers/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ..skel import create_skel
from ..validator import validate_data
from ..compose import _whole_basis_types
from ..manip import split_blocked_contractions
from .turbomole import read_turbomole
from .g94 import read_g94
from .nwchem import read_nwchem
Expand Down Expand Up @@ -109,6 +110,9 @@ def read_formatted_basis_str(basis_str, basis_fmt, validate=False, as_component=
if validate:
validate_data(bs_type, data)

# Split any blocked contractions
data = split_blocked_contractions(data, False)

return data


Expand Down

0 comments on commit 7bb517b

Please sign in to comment.