Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
fb5d3d7
new_post_v0
QG-phy Dec 22, 2025
937ca2b
new post v0.2
QG-phy Dec 23, 2025
b2db559
feat: add support for built-in polynomial checkpoints
QG-phy Dec 23, 2025
7651d63
feat: add GUI detection and refactor calculator interface
QG-phy Dec 23, 2025
a184318
feat: Implement unified post-processing for band structure and DOS, i…
QG-phy Dec 23, 2025
c5916eb
feat: add get_hk method to HamiltonianCalculator interface
QG-phy Dec 23, 2025
8eacb12
feat: add eigenstates calculation and DOS analysis capabilities
QG-phy Dec 23, 2025
4191840
refactor: improve DOS calculation interface and configuration
QG-phy Dec 23, 2025
debd3b0
refactor(test): replace mock tests with integration tests using real …
QG-phy Dec 23, 2025
5807758
feat: add unified post-processing tutorial notebook
QG-phy Dec 23, 2025
5b21708
feat: ignore band structure and DOS plot files in .gitignore
QG-phy Dec 23, 2025
937a9c3
fix: ensure dtype consistency in tensor product operations
QG-phy Dec 23, 2025
08121a7
feat: add band_data initialization and simplify tensor type check
QG-phy Dec 23, 2025
6b26328
feat: add PythTB export tutorial notebook
QG-phy Dec 23, 2025
e5953c5
feat: implement Fermi level calculation and integration
QG-phy Dec 24, 2025
9548d77
feat: add export functionality for TBSystem to third-party formats
QG-phy Dec 24, 2025
6d893be
feat: support dict input for export interfaces
QG-phy Dec 24, 2025
7827695
feat: add PythTB-Wannier postprocessing example and rename existing n…
QG-phy Dec 24, 2025
77d8d19
feat: add Wannier90 integration and export functionality
QG-phy Dec 24, 2025
964109c
test: update export tests to match implementation changes
QG-phy Dec 24, 2025
e8de838
refactor: simplify HR2HK onsite block construction and SOC handling
QG-phy Dec 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
test*.ipynb
*_bands.png
*_dos.png
examples/**/*centres.xyz
examples/**/*.win
**/processed*/*
Expand Down
6 changes: 5 additions & 1 deletion dptb/nn/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from dptb.utils.tools import j_must_have, j_loader
import copy

import os
log = logging.getLogger(__name__)

def build_model(
Expand Down Expand Up @@ -43,6 +43,10 @@ def build_model(

# load the model_options and common_options from checkpoint if not provided
if not from_scratch:
if checkpoint in ['poly2', 'poly4']:
modelname = f'base_{checkpoint}.pth'
checkpoint = os.path.join(os.path.dirname(__file__), 'dftb', modelname)

if checkpoint.split(".")[-1] == "json":
ckptconfig = j_loader(checkpoint)
else:
Expand Down
83 changes: 32 additions & 51 deletions dptb/nn/hr2hk.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@


class HR2HK(torch.nn.Module):
# this is actually a general FFT from real space hamiltonian/overlap to kspace hamiltonian/overlap
# the more correct name should be HSR2HSK. But to keep consistent with previous naming convention, we still use HR2HK here.
def __init__(
self,
basis: Dict[str, Union[str, list]]=None,
Expand Down Expand Up @@ -45,9 +47,6 @@ def __init__(
self.node_field = node_field
self.out_field = out_field




def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:

# construct bond wise hamiltonian block from obital pair wise node/edge features
Expand All @@ -67,15 +66,12 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
soc = data.get(AtomicDataDict.NODE_SOC_SWITCH_KEY, False)
if isinstance(soc, torch.Tensor):
soc = soc.all()
if soc:
# if self.overlap:
# print("Overlap for SOC is realized by kronecker product.")

if soc:
# this soc only support sktb.
orbpair_soc = data[AtomicDataDict.NODE_SOC_KEY]
soc_upup_block = torch.zeros((len(data[AtomicDataDict.ATOM_TYPE_KEY]), self.idp.full_basis_norb, self.idp.full_basis_norb), dtype=self.ctype, device=self.device)
soc_updn_block = torch.zeros((len(data[AtomicDataDict.ATOM_TYPE_KEY]), self.idp.full_basis_norb, self.idp.full_basis_norb), dtype=self.ctype, device=self.device)


ist = 0
for i,iorb in enumerate(self.idp.full_basis):
jst = 0
Expand All @@ -92,45 +88,48 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:

if i <= j:
bondwise_hopping[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_hopping[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1)
onsite_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_onsite[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1)


# constructing onsite blocks
if self.overlap:
# if iorb == jorb:
# onsite_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = factor * torch.eye(2*li+1, dtype=self.dtype, device=self.device).reshape(1, 2*li+1, 2*lj+1).repeat(onsite_block.shape[0], 1, 1)
if i <= j:
onsite_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_onsite[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1)

if soc and i == j:
soc_updn_tmp = orbpair_soc[:, self.idp.orbpair_soc_maps[orbpair]].reshape(-1, 2*li+1, 2*(2*lj+1))
# j==i -> 2*lj+1 == 2*li+1
soc_upup_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1, :2*lj+1]
soc_updn_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1, 2*lj+1:]
else:
if i <= j:
onsite_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_onsite[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1)

if soc and i==j:
if soc and i==j and not self.overlap:
# For now, The SOC part is only added to Hamiltonian, not overlap matrix.
# For now, The SOC only has onsite contribution.
soc_updn_tmp = orbpair_soc[:,self.idp.orbpair_soc_maps[orbpair]].reshape(-1, 2*li+1, 2*(2*lj+1))
soc_upup_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1,:2*lj+1]
soc_updn_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1,2*lj+1:]

# constructing onsite blocks
#if self.overlap:
# # if iorb == jorb:
# # onsite_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = factor * torch.eye(2*li+1, dtype=self.dtype, device=self.device).reshape(1, 2*li+1, 2*lj+1).repeat(onsite_block.shape[0], 1, 1)
# if i <= j:
# onsite_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_onsite[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1)
# if soc and i == j:
# soc_updn_tmp = orbpair_soc[:, self.idp.orbpair_soc_maps[orbpair]].reshape(-1, 2*li+1, 2*(2*lj+1))
# # j==i -> 2*lj+1 == 2*li+1
# soc_upup_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1, :2*lj+1]
# soc_updn_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1, 2*lj+1:]
#else:
# if i <= j:
# onsite_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_onsite[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1)
#
# if soc and i==j:
# soc_updn_tmp = orbpair_soc[:,self.idp.orbpair_soc_maps[orbpair]].reshape(-1, 2*li+1, 2*(2*lj+1))
# soc_upup_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1,:2*lj+1]
# soc_updn_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1,2*lj+1:]

jst += 2*lj+1
ist += 2*li+1
self.onsite_block = onsite_block
self.bondwise_hopping = bondwise_hopping
if soc:
# 先保存已有的
if soc and not self.overlap:
# store for later use
# for now, soc only contribute to Hamiltonain, thus for overlap not store soc parts.
self.soc_upup_block = soc_upup_block
self.soc_updn_block = soc_updn_block

# R2K procedure can be done for all kpoint at once.
all_norb = self.idp.atom_norb[data[AtomicDataDict.ATOM_TYPE_KEY]].sum()
block = torch.zeros(kpoints.shape[0], all_norb, all_norb, dtype=self.ctype, device=self.device)
# block = torch.complex(block, torch.zeros_like(block))
# if data[AtomicDataDict.NODE_SOC_SWITCH_KEY].all():
# block_uu = torch.zeros(data[AtomicDataDict.KPOINT_KEY].shape[0], all_norb, all_norb, dtype=self.ctype, device=self.device)
# block_ud = torch.zeros(data[AtomicDataDict.KPOINT_KEY].shape[0], all_norb, all_norb, dtype=self.ctype, device=self.device)
atom_id_to_indices = {}
ist = 0
for i, oblock in enumerate(onsite_block):
Expand All @@ -139,21 +138,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
block[:,ist:ist+masked_oblock.shape[0],ist:ist+masked_oblock.shape[1]] = masked_oblock.squeeze(0)
atom_id_to_indices[i] = slice(ist, ist+masked_oblock.shape[0])
ist += masked_oblock.shape[0]

# if data[AtomicDataDict.NODE_SOC_SWITCH_KEY].all():
# ist = 0
# for i, soc_block in enumerate(soc_upup_block):
# mask = self.idp.mask_to_basis[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[i]]
# masked_soc_block = soc_block[mask][:,mask]
# block_uu[:,ist:ist+masked_soc_block.shape[0],ist:ist+masked_soc_block.shape[1]] = masked_soc_block.squeeze(0)
# ist += masked_soc_block.shape[0]
# ist = 0
# for i, soc_block in enumerate(soc_updn_block):
# mask = self.idp.mask_to_basis[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[i]]
# masked_soc_block = soc_block[mask][:,mask]
# block_ud[:,ist:ist+masked_soc_block.shape[0],ist:ist+masked_soc_block.shape[1]] = masked_soc_block.squeeze(0)
# ist += masked_soc_block.shape[0]


for i, hblock in enumerate(bondwise_hopping):
iatom = data[AtomicDataDict.EDGE_INDEX_KEY][0][i]
jatom = data[AtomicDataDict.EDGE_INDEX_KEY][1][i]
Expand Down Expand Up @@ -182,10 +167,6 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
data[self.out_field] = S_soc
else:
HK_SOC = torch.zeros(kpoints.shape[0], 2*all_norb, 2*all_norb, dtype=self.ctype, device=self.device)
#HK_SOC[:,:all_norb,:all_norb] = block + block_uu
#HK_SOC[:,:all_norb,all_norb:] = block_ud
#HK_SOC[:,all_norb:,:all_norb] = block_ud.conj()
#HK_SOC[:,all_norb:,all_norb:] = block + block_uu.conj()
ist = 0
assert len(soc_upup_block) == len(soc_updn_block)
for i in range(len(soc_upup_block)):
Expand Down
4 changes: 2 additions & 2 deletions dptb/nn/tensor_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def batch_wigner_D(l_max, alpha, beta, gamma, _Jd):
D_total = sum(dims)

# Construct block-diagonal J matrix
J_full_small = torch.zeros(D_total, D_total, device=device)
J_full_small = torch.zeros(D_total, D_total, device=device, dtype=alpha.dtype)
for l in range(l_max + 1):
start = offsets[l]
J_full_small[start:start+2*l+1, start:start+2*l+1] = _Jd[l]
J_full_small[start:start+2*l+1, start:start+2*l+1] = _Jd[l].to(dtype=alpha.dtype)

J_full = J_full_small.unsqueeze(0).expand(N, -1, -1)
angle_stack = torch.cat([alpha, beta, gamma], dim=0)
Expand Down
61 changes: 60 additions & 1 deletion dptb/postprocess/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from typing import Union, Optional
from copy import deepcopy
from ase.io import read

import sys
from dptb.data import AtomicData, AtomicDataDict, block_to_feature
from dptb.utils.argcheck import get_cutoffs_from_model_options
import matplotlib.pyplot as plt

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -110,3 +111,61 @@ def load_data_for_model(
# Actually, ElecStruCal.get_data does NOT run self.model(data). It runs self.model.idp(data).
# self.get_eigs runs self.model(data).
return data_obj

def is_gui_available():
"""
Detect if GUI display is available for matplotlib.

Returns:
bool: True if GUI is available, False otherwise
"""
try:
# Check if we're in a Jupyter notebook environment
if 'ipykernel' in sys.modules or 'IPython' in sys.modules:
# In Jupyter, we can typically show plots
return True

# Check DISPLAY environment variable (Unix-like systems)
if sys.platform.startswith('linux') or sys.platform.startswith('darwin'):
display = os.environ.get('DISPLAY')
if display is None:
return False
Comment on lines +129 to +132
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

macOS does not require DISPLAY for GUI applications.

On macOS (darwin), GUI applications use Quartz/Cocoa, not X11, so DISPLAY is typically unset even when a GUI is available. This check will incorrectly return False on most macOS systems with a working display.

🔎 Proposed fix
         # Check DISPLAY environment variable (Unix-like systems)
-        if sys.platform.startswith('linux') or sys.platform.startswith('darwin'):
+        if sys.platform.startswith('linux'):
             display = os.environ.get('DISPLAY')
             if display is None:
                 return False
🤖 Prompt for AI Agents
In dptb/postprocess/common.py around lines 129-132, the current check treats
both linux and darwin the same and returns False when DISPLAY is unset, but
macOS (darwin) does not require DISPLAY; change the logic so only Linux checks
DISPLAY: remove darwin from the platform check (or explicitly check for linux
only), keep the existing behavior for Linux (return False when DISPLAY is None),
and ensure darwin falls through to the normal GUI-available path instead of
returning False.


# Try to get the current matplotlib backend
backend = plt.get_backend().lower()

# Non-interactive backends
non_gui_backends = ['agg', 'pdf', 'ps', 'svg', 'cairo', 'gdk', 'template']
if any(non_gui in backend for non_gui in non_gui_backends):
return False

# Try to create a test figure to see if it works
# This is a more robust check
try:
import matplotlib
# Save current backend
current_backend = matplotlib.get_backend()

# Try to use a GUI backend if not already
if 'agg' in backend.lower():
# Try common GUI backends
for test_backend in ['TkAgg', 'Qt5Agg', 'Qt4Agg', 'WXAgg']:
try:
matplotlib.use(test_backend, force=True)
test_fig = plt.figure()
plt.close(test_fig)
matplotlib.use(current_backend, force=True)
return True
except:
continue
return False
else:
# Current backend seems to be GUI-based
return True

except Exception:
return False

except Exception:
# If any error occurs, assume no GUI is available
return False
30 changes: 16 additions & 14 deletions dptb/postprocess/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,19 @@ def __init__(
self.model = model
self.model.eval()

def _get_data_and_blocks(self, data: Union[AtomicData, ase.Atoms, str], AtomicData_options: dict = {}, e_fermi: float = 0.0):
def _get_data_and_blocks(self, data: Union[AtomicData, ase.Atoms, dict, str], AtomicData_options: dict = {}, e_fermi: float = 0.0):
# Check for overlap
if getattr(self.model, "overlap", False):
raise ValueError("Export to Wannier90 format does not support models with non-orthogonal bases (overlap). Please use an orthogonal model.")

# Use centralized data loading
data = load_data_for_model(
data=data,
model=self.model,
device=self.device,
AtomicData_options=AtomicData_options
)
if not isinstance(data,dict):
data = load_data_for_model(
data=data,
model=self.model,
device=self.device,
AtomicData_options=AtomicData_options
)
self.positions = data[AtomicDataDict.POSITIONS_KEY].numpy()
self.scaled_pos = data[AtomicDataDict.POSITIONS_KEY] @ data[AtomicDataDict.CELL_KEY].inverse()
self.scaled_pos = self.scaled_pos.numpy()
Expand Down Expand Up @@ -280,20 +281,21 @@ def __init__(
log.exception("PythTB not installed. Run `pip install pythtb`")
raise

def get_model(self, data: Union[AtomicData, ase.Atoms, str], AtomicData_options: dict = {}, e_fermi: float = 0.0):
def get_model(self, data: Union[AtomicData, ase.Atoms, dict, str], AtomicData_options: dict = {}, e_fermi: float = 0.0):
from pythtb import tb_model

# Check for overlap
if getattr(self.model, "overlap", False):
raise ValueError("Export to PythTB does not support models with non-orthogonal bases (overlap). Please use an orthogonal model.")

# Use centralized data loading
data = load_data_for_model(
data=data,
model=self.model,
device=self.device,
AtomicData_options=AtomicData_options
)
if not isinstance(data, dict):
data = load_data_for_model(
data=data,
model=self.model,
device=self.device,
AtomicData_options=AtomicData_options
)

# Run model forward
data = self.model(data)
Expand Down
4 changes: 4 additions & 0 deletions dptb/postprocess/unified/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .system import TBSystem
from .calculator import HamiltonianCalculator, DeePTBAdapter

__all__ = ["TBSystem", "HamiltonianCalculator", "DeePTBAdapter"]
Loading