Skip to content

Commit de67ea7

Browse files
committed
chore: lints
1 parent c870975 commit de67ea7

File tree

5 files changed

+72
-85
lines changed

5 files changed

+72
-85
lines changed
Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,2 @@
11
from .analyzer import LatticeMaterialAnalyzer
2-
from .helpers import (
3-
get_lattice_type,
4-
get_material_with_conventional_lattice,
5-
get_material_with_primitive_lattice,
6-
)
2+
from .helpers import get_lattice_type, get_material_with_conventional_lattice, get_material_with_primitive_lattice

src/py/mat3ra/made/tools/analyze/lattice/analyzer.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,16 @@ def detect_lattice_type(self, tolerance=0.1, angle_tolerance=5) -> LatticeTypeEn
3535
# First try vector-based detection for more accurate primitive cell identification
3636
try:
3737
vector_based_type = detect_lattice_type_from_vectors(
38-
self.material.lattice.vector_arrays,
39-
tolerance=tolerance,
40-
angle_tolerance=angle_tolerance
38+
self.material.lattice.vector_arrays, tolerance=tolerance, angle_tolerance=angle_tolerance
4139
)
42-
40+
4341
# If vector-based detection gives a specific result (not TRI), use it
4442
if vector_based_type != LatticeTypeEnum.TRI:
4543
return vector_based_type
4644
except Exception:
4745
# Fall back to pymatgen if vector-based detection fails
4846
pass
49-
47+
5048
# Fallback to pymatgen spacegroup analyzer
5149
lattice_type_str = PymatgenSpacegroupAnalyzer(
5250
to_pymatgen(self.material),

src/py/mat3ra/made/tools/analyze/lattice/utils.py

Lines changed: 65 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -13,62 +13,60 @@
1313

1414

1515
def detect_lattice_type_from_vectors(
16-
vectors: List[List[float]],
17-
tolerance: float = 0.1,
18-
angle_tolerance: float = 5.0
16+
vectors: List[List[float]], tolerance: float = 0.1, angle_tolerance: float = 5.0
1917
) -> LatticeTypeEnum:
2018
"""
2119
Detect lattice type from lattice vectors using pattern matching.
22-
20+
2321
Based on the primitive cell definitions from Setyawan & Curtarolo (2010).
24-
22+
2523
Args:
2624
vectors: 3x3 array of lattice vectors
2725
tolerance: Tolerance for length comparisons (Angstroms)
2826
angle_tolerance: Tolerance for angle comparisons (degrees)
29-
27+
3028
Returns:
3129
Detected lattice type
3230
"""
3331
if len(vectors) != 3 or any(len(v) != 3 for v in vectors):
3432
raise ValueError("Expected 3x3 array of lattice vectors")
35-
33+
3634
# Convert to numpy for easier calculations
3735
v = np.array(vectors)
38-
36+
3937
# Calculate lengths
4038
a = np.linalg.norm(v[0])
4139
b = np.linalg.norm(v[1])
4240
c = np.linalg.norm(v[2])
43-
41+
4442
# Calculate angles between vectors (in degrees)
4543
def angle_between_vectors(v1, v2):
4644
cos_angle = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
4745
cos_angle = np.clip(cos_angle, -1.0, 1.0) # Handle numerical errors
4846
return np.degrees(np.arccos(cos_angle))
49-
47+
5048
alpha = angle_between_vectors(v[1], v[2]) # angle between b and c
51-
beta = angle_between_vectors(v[0], v[2]) # angle between a and c
49+
beta = angle_between_vectors(v[0], v[2]) # angle between a and c
5250
gamma = angle_between_vectors(v[0], v[1]) # angle between a and b
53-
51+
5452
# Helper functions for comparisons
5553
def is_equal(val1, val2, tol=tolerance):
5654
return abs(val1 - val2) <= tol
57-
55+
5856
def is_angle_equal(angle1, angle2, tol=angle_tolerance):
5957
return abs(angle1 - angle2) <= tol
60-
58+
6159
def is_right_angle(angle, tol=angle_tolerance):
6260
return is_angle_equal(angle, 90.0, tol)
63-
61+
6462
def is_120_angle(angle, tol=angle_tolerance):
6563
return is_angle_equal(angle, 120.0, tol)
66-
64+
6765
def is_60_angle(angle, tol=angle_tolerance):
6866
return is_angle_equal(angle, 60.0, tol)
69-
67+
7068
# Check for specific lattice patterns
71-
69+
7270
# 1. Cubic patterns (CUB, FCC, BCC)
7371
if is_equal(a, b) and is_equal(b, c):
7472
# All lengths equal
@@ -84,21 +82,21 @@ def is_60_angle(angle, tol=angle_tolerance):
8482
return LatticeTypeEnum.BCC
8583
else:
8684
return LatticeTypeEnum.RHL
87-
85+
8886
# 2. Tetragonal patterns (TET, BCT)
8987
if is_equal(a, b) and not is_equal(a, c):
9088
if is_right_angle(alpha) and is_right_angle(beta) and is_right_angle(gamma):
9189
return LatticeTypeEnum.TET
9290
elif _is_bct_pattern(v, tolerance):
9391
return LatticeTypeEnum.BCT
94-
92+
9593
# 3. Hexagonal pattern
9694
if is_equal(a, b) and not is_equal(a, c):
9795
if is_right_angle(alpha) and is_right_angle(beta) and is_120_angle(gamma):
9896
return LatticeTypeEnum.HEX
9997
elif _is_hex_pattern(v, tolerance):
10098
return LatticeTypeEnum.HEX
101-
99+
102100
# 4. Orthorhombic patterns (ORC, ORCF, ORCI, ORCC)
103101
if not is_equal(a, b) and not is_equal(b, c) and not is_equal(a, c):
104102
if is_right_angle(alpha) and is_right_angle(beta) and is_right_angle(gamma):
@@ -109,18 +107,18 @@ def is_60_angle(angle, tol=angle_tolerance):
109107
return LatticeTypeEnum.ORCI
110108
elif _is_orcc_pattern(v, tolerance):
111109
return LatticeTypeEnum.ORCC
112-
110+
113111
# 5. Monoclinic patterns (MCL, MCLC)
114112
if is_right_angle(alpha) and is_right_angle(gamma) and not is_right_angle(beta):
115113
if _is_mclc_pattern(v, tolerance):
116114
return LatticeTypeEnum.MCLC
117115
else:
118116
return LatticeTypeEnum.MCL
119-
117+
120118
# 6. Rhombohedral (already checked above in cubic section)
121119
if is_equal(a, b) and is_equal(b, c) and is_angle_equal(alpha, beta) and is_angle_equal(beta, gamma):
122120
return LatticeTypeEnum.RHL
123-
121+
124122
# Default to triclinic if no pattern matches
125123
return LatticeTypeEnum.TRI
126124

@@ -130,14 +128,10 @@ def _is_bcc_pattern(vectors: np.ndarray, tolerance: float) -> bool:
130128
# BCC primitive: [-a/2, a/2, a/2], [a/2, -a/2, a/2], [a/2, a/2, -a/2]
131129
v = vectors
132130
a_est = np.linalg.norm(v[0])
133-
131+
134132
# Expected BCC pattern (normalized)
135-
expected = np.array([
136-
[-0.5, 0.5, 0.5],
137-
[0.5, -0.5, 0.5],
138-
[0.5, 0.5, -0.5]
139-
]) * a_est
140-
133+
expected = np.array([[-0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [0.5, 0.5, -0.5]]) * a_est
134+
141135
# Check if current vectors match expected pattern (allowing for rotations)
142136
return _vectors_match_pattern(v, expected, tolerance)
143137

@@ -149,64 +143,61 @@ def _is_bct_pattern(vectors: np.ndarray, tolerance: float) -> bool:
149143
lengths = [np.linalg.norm(vec) for vec in v]
150144
a_est = min(lengths)
151145
c_est = max(lengths)
152-
153-
expected = np.array([
154-
[-a_est/2, a_est/2, c_est/2],
155-
[a_est/2, -a_est/2, c_est/2],
156-
[a_est/2, a_est/2, -c_est/2]
157-
])
158-
146+
147+
expected = np.array(
148+
[[-a_est / 2, a_est / 2, c_est / 2], [a_est / 2, -a_est / 2, c_est / 2], [a_est / 2, a_est / 2, -c_est / 2]]
149+
)
150+
159151
return _vectors_match_pattern(v, expected, tolerance)
160152

161153

162154
def _is_hex_pattern(vectors: np.ndarray, tolerance: float) -> bool:
163155
"""Check if vectors match HEX primitive cell pattern."""
164156
# HEX primitive: [a/2, -a*sqrt(3)/2, 0], [a/2, a*sqrt(3)/2, 0], [0, 0, c]
165157
v = vectors
166-
158+
167159
# Find the vector along z-axis (should be [0, 0, c])
168160
z_vector_idx = None
169161
for i, vec in enumerate(v):
170162
if abs(vec[0]) < tolerance and abs(vec[1]) < tolerance and abs(vec[2]) > tolerance:
171163
z_vector_idx = i
172164
break
173-
165+
174166
if z_vector_idx is None:
175167
return False
176-
177-
c_est = abs(v[z_vector_idx][2])
178-
168+
179169
# Check other two vectors
180170
other_indices = [i for i in range(3) if i != z_vector_idx]
181171
v1, v2 = v[other_indices[0]], v[other_indices[1]]
182-
172+
183173
# They should have equal lengths and specific pattern
184174
if not abs(np.linalg.norm(v1) - np.linalg.norm(v2)) < tolerance:
185175
return False
186-
176+
187177
a_est = np.linalg.norm(v1)
188-
178+
189179
# Check if they match hexagonal pattern
190-
expected_1 = np.array([a_est/2, -a_est*math.sqrt(3)/2, 0])
191-
expected_2 = np.array([a_est/2, a_est*math.sqrt(3)/2, 0])
192-
193-
return (_vector_matches(v1, expected_1, tolerance) and _vector_matches(v2, expected_2, tolerance)) or \
194-
(_vector_matches(v1, expected_2, tolerance) and _vector_matches(v2, expected_1, tolerance))
180+
expected_1 = np.array([a_est / 2, -a_est * math.sqrt(3) / 2, 0])
181+
expected_2 = np.array([a_est / 2, a_est * math.sqrt(3) / 2, 0])
182+
183+
return (_vector_matches(v1, expected_1, tolerance) and _vector_matches(v2, expected_2, tolerance)) or (
184+
_vector_matches(v1, expected_2, tolerance) and _vector_matches(v2, expected_1, tolerance)
185+
)
195186

196187

197188
def _is_orcf_pattern(vectors: np.ndarray, tolerance: float) -> bool:
198189
"""Check if vectors match ORCF primitive cell pattern."""
199190
# ORCF primitive: [0, b/2, c/2], [a/2, 0, c/2], [a/2, b/2, 0]
200191
v = vectors
201-
192+
202193
# Each vector should have one zero component
203194
zero_components = []
204195
for vec in v:
205196
zero_count = sum(1 for x in vec if abs(x) < tolerance)
206197
if zero_count != 1:
207198
return False
208199
zero_components.append([i for i, x in enumerate(vec) if abs(x) < tolerance][0])
209-
200+
210201
# Should have one zero in each dimension
211202
return set(zero_components) == {0, 1, 2}
212203

@@ -216,18 +207,18 @@ def _is_orci_pattern(vectors: np.ndarray, tolerance: float) -> bool:
216207
# ORCI primitive: [-a/2, b/2, c/2], [a/2, -b/2, c/2], [a/2, b/2, -c/2]
217208
# Similar to BCC but with different lengths
218209
v = vectors
219-
210+
220211
# Check if it has the body-centered pattern with different lengths
221212
lengths = [np.linalg.norm(vec) for vec in v]
222213
if len(set(np.round(lengths, 3))) == 1: # All equal lengths -> not ORCI
223214
return False
224-
215+
225216
# Check sign pattern
226217
sign_patterns = []
227218
for vec in v:
228219
signs = [1 if x > tolerance else (-1 if x < -tolerance else 0) for x in vec]
229220
sign_patterns.append(signs)
230-
221+
231222
# Should have body-centered sign pattern
232223
expected_patterns = [[-1, 1, 1], [1, -1, 1], [1, 1, -1]]
233224
return _sign_patterns_match(sign_patterns, expected_patterns)
@@ -237,38 +228,42 @@ def _is_orcc_pattern(vectors: np.ndarray, tolerance: float) -> bool:
237228
"""Check if vectors match ORCC primitive cell pattern."""
238229
# ORCC primitive: [a/2, b/2, 0], [-a/2, b/2, 0], [0, 0, c]
239230
v = vectors
240-
231+
241232
# One vector should be along z-axis
242233
z_vector_count = sum(1 for vec in v if abs(vec[0]) < tolerance and abs(vec[1]) < tolerance)
243234
if z_vector_count != 1:
244235
return False
245-
236+
246237
# Other two should be in xy-plane with specific pattern
247238
xy_vectors = [vec for vec in v if not (abs(vec[0]) < tolerance and abs(vec[1]) < tolerance)]
248239
if len(xy_vectors) != 2:
249240
return False
250-
241+
251242
# Check if they have opposite x-components and same y-components
252243
v1, v2 = xy_vectors
253-
return (abs(v1[0] + v2[0]) < tolerance and abs(v1[1] - v2[1]) < tolerance and
254-
abs(v1[2]) < tolerance and abs(v2[2]) < tolerance)
244+
return (
245+
abs(v1[0] + v2[0]) < tolerance
246+
and abs(v1[1] - v2[1]) < tolerance
247+
and abs(v1[2]) < tolerance
248+
and abs(v2[2]) < tolerance
249+
)
255250

256251

257252
def _is_mclc_pattern(vectors: np.ndarray, tolerance: float) -> bool:
258253
"""Check if vectors match MCLC primitive cell pattern."""
259254
# MCLC primitive: [a/2, b/2, 0], [-a/2, b/2, 0], [0, c*cos(alpha), c*sin(alpha)]
260255
v = vectors
261-
256+
262257
# Similar to ORCC but with one vector at an angle
263258
xy_vectors = []
264259
angled_vectors = []
265-
260+
266261
for vec in v:
267262
if abs(vec[2]) < tolerance:
268263
xy_vectors.append(vec)
269264
else:
270265
angled_vectors.append(vec)
271-
266+
272267
return len(xy_vectors) == 2 and len(angled_vectors) == 1
273268

274269

@@ -277,23 +272,23 @@ def _vectors_match_pattern(v1: np.ndarray, v2: np.ndarray, tolerance: float) ->
277272
# Simple check: compare sorted lengths and angles
278273
lengths1 = sorted([np.linalg.norm(vec) for vec in v1])
279274
lengths2 = sorted([np.linalg.norm(vec) for vec in v2])
280-
275+
281276
if not all(abs(l1 - l2) < tolerance for l1, l2 in zip(lengths1, lengths2)):
282277
return False
283-
278+
284279
# Check angles between vectors
285280
def get_angles(vectors):
286281
angles = []
287282
for i in range(3):
288-
for j in range(i+1, 3):
283+
for j in range(i + 1, 3):
289284
cos_angle = np.dot(vectors[i], vectors[j]) / (np.linalg.norm(vectors[i]) * np.linalg.norm(vectors[j]))
290285
cos_angle = np.clip(cos_angle, -1.0, 1.0)
291286
angles.append(np.degrees(np.arccos(cos_angle)))
292287
return sorted(angles)
293-
288+
294289
angles1 = get_angles(v1)
295290
angles2 = get_angles(v2)
296-
291+
297292
return all(abs(a1 - a2) < tolerance * 10 for a1, a2 in zip(angles1, angles2)) # More lenient for angles
298293

299294

@@ -305,7 +300,7 @@ def _vector_matches(v1: np.ndarray, v2: np.ndarray, tolerance: float) -> bool:
305300
def _sign_patterns_match(patterns1: List[List[int]], patterns2: List[List[int]]) -> bool:
306301
"""Check if sign patterns match (allowing permutations)."""
307302
from itertools import permutations
308-
303+
309304
for perm in permutations(patterns2):
310305
if patterns1 == list(perm):
311306
return True

tests/py/unit/test_tools_analyze.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@
2222
from unit.fixtures.nanoribbon.nanoribbon import GRAPHENE_ZIGZAG_NANORIBBON
2323
from unit.utils import OSPlatform, get_platform_specific_value
2424

25-
from .fixtures.bulk import BULK_Si_CONVENTIONAL, BULK_Si_PRIMITIVE, BULK_Si_PRIMITIVIZED
25+
from .fixtures.bulk import BULK_Si_CONVENTIONAL, BULK_Si_PRIMITIVE
2626
from .fixtures.interface.zsl import GRAPHENE_NICKEL_INTERFACE
2727
from .fixtures.slab import SI_CONVENTIONAL_SLAB_001
28-
from .utils import assert_two_entities_deep_almost_equal
2928

3029
COMPARISON_PRECISION = 1e-4
3130

tests/py/unit/test_tools_analyze_lattice.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import pytest
22
from mat3ra.made.material import Material
3-
from mat3ra.made.tools.analyze.lattice import get_lattice_type, LatticeMaterialAnalyzer
3+
from mat3ra.made.tools.analyze.lattice import LatticeMaterialAnalyzer, get_lattice_type
44

5-
from .fixtures.bulk import BULK_Si_PRIMITIVE, BULK_Si_CONVENTIONAL, BULK_Si_PRIMITIVIZED, BULK_Hf2O_MCL, BULK_GRAPHITE
5+
from .fixtures.bulk import BULK_GRAPHITE, BULK_Hf2O_MCL, BULK_Si_CONVENTIONAL, BULK_Si_PRIMITIVE, BULK_Si_PRIMITIVIZED
66
from .fixtures.interface.gr_ni_111_top_hcp import GRAPHENE_NICKEL_INTERFACE_TOP_HCP
77
from .utils import assert_two_entities_deep_almost_equal
88

@@ -23,7 +23,6 @@ def test_lattice_material_analyzer(
2323
assert_two_entities_deep_almost_equal(primitive_cell_generated, expected_primitive_material_config)
2424

2525

26-
2726
@pytest.mark.parametrize(
2827
"material, expected_lattice_type",
2928
[

0 commit comments

Comments
 (0)