Skip to content

Commit

Permalink
Fix numerical instability in complex-valued TDHF diagonalization (pys…
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm authored Dec 17, 2024
1 parent cc2d99a commit cb55121
Showing 1 changed file with 69 additions and 42 deletions.
111 changes: 69 additions & 42 deletions gpu4pyscf/tdscf/_lr_eig.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,8 @@ def eig(aop, x0, precond, tol_residual=1e-5, nroots=1, x0sym=None, pick=None,
fresh_start = True
for icyc in range(max_cycle):
if fresh_start:
vlast = None
conv_last = conv = np.zeros(nroots, dtype=bool)
xs = np.zeros((0, x0_size))
ax = np.zeros((0, x0_size))
row1 = 0
Expand Down Expand Up @@ -400,7 +402,7 @@ def eig(aop, x0, precond, tol_residual=1e-5, nroots=1, x0sym=None, pick=None,

w, e, elast = w[:space_inc], w[:nroots], e
v = v[:,:space_inc]
if not fresh_start:
if vlast is not None:
elast, conv_last = _sort_elast(elast, conv, vlast, v[:,:nroots], log)
vlast = v[:,:nroots]

Expand Down Expand Up @@ -447,22 +449,30 @@ def eig(aop, x0, precond, tol_residual=1e-5, nroots=1, x0sym=None, pick=None,
xt[:,:half_size] -= c.T.dot(xs[:,half_size:].conj())
xt[:,half_size:] -= c.T.dot(xs[:,:half_size].conj())

if x0sym is None:
xt = _symmetric_orth(xt)
else:
xt_orth = []
xt_orth_ir = []
for ir in set(xt_ir):
idx = np.where(xt_ir == ir)[0]
xt_sub = _symmetric_orth(xt[idx])
xt_orth.append(xt_sub)
xt_orth_ir.append([ir] * len(xt_sub))
if xt_orth:
xt = np.vstack(xt_orth)
xs_ir = np.hstack([xs_ir, *xt_orth_ir])
# Remove quasi linearly dependent bases, as they cause more numerical
# errors in _symmetric_orth
xt_norm = np.linalg.norm(xt, axis=1)
xt_to_keep = (dx_norm > tol_residual) & (xt_norm > max(lindep**.5, tol_residual))
xt = xt[xt_to_keep]
if len(xt) > 0:
xt /= xt_norm[xt_to_keep, None]
if x0sym is None:
xt = _symmetric_orth(xt)
else:
xt = []
xt_orth = xt_orth_ir = xt_sub = None
xt_ir = xt_ir[xt_to_keep]
xt_orth = []
xt_orth_ir = []
for ir in set(xt_ir):
idx = np.where(xt_ir == ir)[0]
xt_sub = _symmetric_orth(xt[idx])
xt_orth.append(xt_sub)
xt_orth_ir.append([ir] * len(xt_sub))
if xt_orth:
xt = np.vstack(xt_orth)
xs_ir = np.hstack([xs_ir, *xt_orth_ir])
else:
xt = []
xt_orth = xt_orth_ir = xt_sub = None

if len(xt) == 0:
log.debug(f'Linear dependency in trial subspace. |r| for each state {dx_norm}')
Expand Down Expand Up @@ -527,7 +537,7 @@ def real_eig(aop, x0, precond, tol_residual=1e-5, nroots=1, x0sym=None, pick=Non
Eigenvectors.
'''

#assert pick is None
assert pick is None
assert callable(precond)

if isinstance(verbose, logger.Logger):
Expand Down Expand Up @@ -789,13 +799,19 @@ def _qr(xs, lindep=1e-14):
return xs[:nv], idx

def _symmetric_orth(xt, lindep=1e-6):
xt = np.asarray(xt)
if xt.dtype == np.float64:
return _symmetric_orth_real(xt, lindep)
else:
return _symmetric_orth_cmplx(xt, lindep)

def _symmetric_orth_real(xt, lindep=1e-6):
'''
Symmetric orthogonalization for xt = {[X, Y]},
and its dual basis vectors {[Y, X]}
'''
xt = np.asarray(xt)
x0_size = xt.shape[1]
s11 = xt.conj().dot(xt.T)
s11 = xt.dot(xt.T)
s21 = _conj_dot(xt, xt)
# Symmetric orthogonalize s, where
# s = [[s11, s21.conj().T],
Expand All @@ -813,15 +829,9 @@ def _symmetric_orth(xt, lindep=1e-6):
n = csc.shape[0]
for i in range(n):
_s21 = csc[i:,i:]
if _s21.dtype == np.float64:
# s21 is symmetric for real vectors
w, u = np.linalg.eigh(_s21)
mask = 1 - abs(w) > lindep
else:
# svd(s[:n,n:]) => svd(_s21.conj().T) => u, w
w2, u = np.linalg.eigh(_s21.conj().T.dot(_s21))
mask = 1 - w2**.5 > lindep
w = np.einsum('pi,pi->i', u.conj(), _s21.dot(u))
# s21 is symmetric for real vectors
w, u = np.linalg.eigh(_s21)
mask = 1 - abs(w) > lindep
if np.any(mask):
c = c[:,i:]
break
Expand All @@ -836,22 +846,16 @@ def _symmetric_orth(xt, lindep=1e-6):
e, c = np.linalg.eigh(c_orth.T.dot(s11).dot(c_orth))
c *= e**-.5
c_orth = c_orth.dot(c)
if s21.dtype == np.float64:
csc = c_orth.T.dot(s21).dot(c_orth)
w, u = np.linalg.eigh(csc)
c_orth = c_orth.dot(u)
else:
sc = s21.dot(c_orth)
w2, u = np.linalg.eigh(sc.conj().T.dot(sc))
c_orth = c_orth.dot(u)
w = np.einsum('pi,pi->i', c_orth.conj(), sc.dot(u))
csc = c_orth.T.dot(s21).dot(c_orth)
w, u = np.linalg.eigh(csc)
c_orth = c_orth.dot(u)

# Symmetric diagonalize
# [1 w] => c = [a b]
# [w 1] [b a]
# [1 w.conj()] => c = [a b]
# [w 1 ] [b a]
# where
# a = ((1+w)**-.5 + (1-w)**-.5)/2
# b = ((1+w)**-.5 - (1-w)**-.5)/2
# b = (phase*(1+w)**-.5 - phase*(1-w)**-.5)/2
a1 = (1 + w)**-.5
a2 = (1 - w)**-.5
a = (a1 + a2) / 2
Expand All @@ -860,8 +864,31 @@ def _symmetric_orth(xt, lindep=1e-6):
m = xt.shape[1] // 2
x_orth = (c_orth * a).T.dot(xt)
# Contribution from the conjugated basis
x_orth[:,:m] += (c_orth * b).T.dot(xt[:,m:].conj())
x_orth[:,m:] += (c_orth * b).T.dot(xt[:,:m].conj())
x_orth[:,:m] += (c_orth * b).T.dot(xt[:,m:])
x_orth[:,m:] += (c_orth * b).T.dot(xt[:,:m])
return x_orth

def _symmetric_orth_cmplx(xt, lindep=1e-6):
n, m = xt.shape
if n == 0:
raise RuntimeError('Linear dependency in trial bases')
m = m // 2
# The conjugated basis np.hstack([xt[:,m:], xt[:,:m]]).conj()
s11 = xt.conj().dot(xt.T)
s21 = _conj_dot(xt, xt)
s = np.block([[s11, s21.conj().T],
[s21, s11.conj() ]])
e, c = scipy.linalg.eigh(s)
if e[0] < lindep:
if n == 1:
return xt
return _symmetric_orth_cmplx(xt[:-1], lindep)

c_orth = (c * e**-.5).dot(c[:n].conj().T)
x_orth = c_orth[:n].T.dot(xt)
# Contribution from the conjugated basis
x_orth[:,:m] += c_orth[n:].T.dot(xt[:,m:].conj())
x_orth[:,m:] += c_orth[n:].T.dot(xt[:,:m].conj())
return x_orth

def _sym_dot(V, U1, m0, m1):
Expand Down

0 comments on commit cb55121

Please sign in to comment.