13
13
14
14
15
15
def detect_lattice_type_from_vectors (
16
- vectors : List [List [float ]],
17
- tolerance : float = 0.1 ,
18
- angle_tolerance : float = 5.0
16
+ vectors : List [List [float ]], tolerance : float = 0.1 , angle_tolerance : float = 5.0
19
17
) -> LatticeTypeEnum :
20
18
"""
21
19
Detect lattice type from lattice vectors using pattern matching.
22
-
20
+
23
21
Based on the primitive cell definitions from Setyawan & Curtarolo (2010).
24
-
22
+
25
23
Args:
26
24
vectors: 3x3 array of lattice vectors
27
25
tolerance: Tolerance for length comparisons (Angstroms)
28
26
angle_tolerance: Tolerance for angle comparisons (degrees)
29
-
27
+
30
28
Returns:
31
29
Detected lattice type
32
30
"""
33
31
if len (vectors ) != 3 or any (len (v ) != 3 for v in vectors ):
34
32
raise ValueError ("Expected 3x3 array of lattice vectors" )
35
-
33
+
36
34
# Convert to numpy for easier calculations
37
35
v = np .array (vectors )
38
-
36
+
39
37
# Calculate lengths
40
38
a = np .linalg .norm (v [0 ])
41
39
b = np .linalg .norm (v [1 ])
42
40
c = np .linalg .norm (v [2 ])
43
-
41
+
44
42
# Calculate angles between vectors (in degrees)
45
43
def angle_between_vectors (v1 , v2 ):
46
44
cos_angle = np .dot (v1 , v2 ) / (np .linalg .norm (v1 ) * np .linalg .norm (v2 ))
47
45
cos_angle = np .clip (cos_angle , - 1.0 , 1.0 ) # Handle numerical errors
48
46
return np .degrees (np .arccos (cos_angle ))
49
-
47
+
50
48
alpha = angle_between_vectors (v [1 ], v [2 ]) # angle between b and c
51
- beta = angle_between_vectors (v [0 ], v [2 ]) # angle between a and c
49
+ beta = angle_between_vectors (v [0 ], v [2 ]) # angle between a and c
52
50
gamma = angle_between_vectors (v [0 ], v [1 ]) # angle between a and b
53
-
51
+
54
52
# Helper functions for comparisons
55
53
def is_equal (val1 , val2 , tol = tolerance ):
56
54
return abs (val1 - val2 ) <= tol
57
-
55
+
58
56
def is_angle_equal (angle1 , angle2 , tol = angle_tolerance ):
59
57
return abs (angle1 - angle2 ) <= tol
60
-
58
+
61
59
def is_right_angle (angle , tol = angle_tolerance ):
62
60
return is_angle_equal (angle , 90.0 , tol )
63
-
61
+
64
62
def is_120_angle (angle , tol = angle_tolerance ):
65
63
return is_angle_equal (angle , 120.0 , tol )
66
-
64
+
67
65
def is_60_angle (angle , tol = angle_tolerance ):
68
66
return is_angle_equal (angle , 60.0 , tol )
69
-
67
+
70
68
# Check for specific lattice patterns
71
-
69
+
72
70
# 1. Cubic patterns (CUB, FCC, BCC)
73
71
if is_equal (a , b ) and is_equal (b , c ):
74
72
# All lengths equal
@@ -84,21 +82,21 @@ def is_60_angle(angle, tol=angle_tolerance):
84
82
return LatticeTypeEnum .BCC
85
83
else :
86
84
return LatticeTypeEnum .RHL
87
-
85
+
88
86
# 2. Tetragonal patterns (TET, BCT)
89
87
if is_equal (a , b ) and not is_equal (a , c ):
90
88
if is_right_angle (alpha ) and is_right_angle (beta ) and is_right_angle (gamma ):
91
89
return LatticeTypeEnum .TET
92
90
elif _is_bct_pattern (v , tolerance ):
93
91
return LatticeTypeEnum .BCT
94
-
92
+
95
93
# 3. Hexagonal pattern
96
94
if is_equal (a , b ) and not is_equal (a , c ):
97
95
if is_right_angle (alpha ) and is_right_angle (beta ) and is_120_angle (gamma ):
98
96
return LatticeTypeEnum .HEX
99
97
elif _is_hex_pattern (v , tolerance ):
100
98
return LatticeTypeEnum .HEX
101
-
99
+
102
100
# 4. Orthorhombic patterns (ORC, ORCF, ORCI, ORCC)
103
101
if not is_equal (a , b ) and not is_equal (b , c ) and not is_equal (a , c ):
104
102
if is_right_angle (alpha ) and is_right_angle (beta ) and is_right_angle (gamma ):
@@ -109,18 +107,18 @@ def is_60_angle(angle, tol=angle_tolerance):
109
107
return LatticeTypeEnum .ORCI
110
108
elif _is_orcc_pattern (v , tolerance ):
111
109
return LatticeTypeEnum .ORCC
112
-
110
+
113
111
# 5. Monoclinic patterns (MCL, MCLC)
114
112
if is_right_angle (alpha ) and is_right_angle (gamma ) and not is_right_angle (beta ):
115
113
if _is_mclc_pattern (v , tolerance ):
116
114
return LatticeTypeEnum .MCLC
117
115
else :
118
116
return LatticeTypeEnum .MCL
119
-
117
+
120
118
# 6. Rhombohedral (already checked above in cubic section)
121
119
if is_equal (a , b ) and is_equal (b , c ) and is_angle_equal (alpha , beta ) and is_angle_equal (beta , gamma ):
122
120
return LatticeTypeEnum .RHL
123
-
121
+
124
122
# Default to triclinic if no pattern matches
125
123
return LatticeTypeEnum .TRI
126
124
@@ -130,14 +128,10 @@ def _is_bcc_pattern(vectors: np.ndarray, tolerance: float) -> bool:
130
128
# BCC primitive: [-a/2, a/2, a/2], [a/2, -a/2, a/2], [a/2, a/2, -a/2]
131
129
v = vectors
132
130
a_est = np .linalg .norm (v [0 ])
133
-
131
+
134
132
# Expected BCC pattern (normalized)
135
- expected = np .array ([
136
- [- 0.5 , 0.5 , 0.5 ],
137
- [0.5 , - 0.5 , 0.5 ],
138
- [0.5 , 0.5 , - 0.5 ]
139
- ]) * a_est
140
-
133
+ expected = np .array ([[- 0.5 , 0.5 , 0.5 ], [0.5 , - 0.5 , 0.5 ], [0.5 , 0.5 , - 0.5 ]]) * a_est
134
+
141
135
# Check if current vectors match expected pattern (allowing for rotations)
142
136
return _vectors_match_pattern (v , expected , tolerance )
143
137
@@ -149,64 +143,61 @@ def _is_bct_pattern(vectors: np.ndarray, tolerance: float) -> bool:
149
143
lengths = [np .linalg .norm (vec ) for vec in v ]
150
144
a_est = min (lengths )
151
145
c_est = max (lengths )
152
-
153
- expected = np .array ([
154
- [- a_est / 2 , a_est / 2 , c_est / 2 ],
155
- [a_est / 2 , - a_est / 2 , c_est / 2 ],
156
- [a_est / 2 , a_est / 2 , - c_est / 2 ]
157
- ])
158
-
146
+
147
+ expected = np .array (
148
+ [[- a_est / 2 , a_est / 2 , c_est / 2 ], [a_est / 2 , - a_est / 2 , c_est / 2 ], [a_est / 2 , a_est / 2 , - c_est / 2 ]]
149
+ )
150
+
159
151
return _vectors_match_pattern (v , expected , tolerance )
160
152
161
153
162
154
def _is_hex_pattern (vectors : np .ndarray , tolerance : float ) -> bool :
163
155
"""Check if vectors match HEX primitive cell pattern."""
164
156
# HEX primitive: [a/2, -a*sqrt(3)/2, 0], [a/2, a*sqrt(3)/2, 0], [0, 0, c]
165
157
v = vectors
166
-
158
+
167
159
# Find the vector along z-axis (should be [0, 0, c])
168
160
z_vector_idx = None
169
161
for i , vec in enumerate (v ):
170
162
if abs (vec [0 ]) < tolerance and abs (vec [1 ]) < tolerance and abs (vec [2 ]) > tolerance :
171
163
z_vector_idx = i
172
164
break
173
-
165
+
174
166
if z_vector_idx is None :
175
167
return False
176
-
177
- c_est = abs (v [z_vector_idx ][2 ])
178
-
168
+
179
169
# Check other two vectors
180
170
other_indices = [i for i in range (3 ) if i != z_vector_idx ]
181
171
v1 , v2 = v [other_indices [0 ]], v [other_indices [1 ]]
182
-
172
+
183
173
# They should have equal lengths and specific pattern
184
174
if not abs (np .linalg .norm (v1 ) - np .linalg .norm (v2 )) < tolerance :
185
175
return False
186
-
176
+
187
177
a_est = np .linalg .norm (v1 )
188
-
178
+
189
179
# Check if they match hexagonal pattern
190
- expected_1 = np .array ([a_est / 2 , - a_est * math .sqrt (3 )/ 2 , 0 ])
191
- expected_2 = np .array ([a_est / 2 , a_est * math .sqrt (3 )/ 2 , 0 ])
192
-
193
- return (_vector_matches (v1 , expected_1 , tolerance ) and _vector_matches (v2 , expected_2 , tolerance )) or \
194
- (_vector_matches (v1 , expected_2 , tolerance ) and _vector_matches (v2 , expected_1 , tolerance ))
180
+ expected_1 = np .array ([a_est / 2 , - a_est * math .sqrt (3 ) / 2 , 0 ])
181
+ expected_2 = np .array ([a_est / 2 , a_est * math .sqrt (3 ) / 2 , 0 ])
182
+
183
+ return (_vector_matches (v1 , expected_1 , tolerance ) and _vector_matches (v2 , expected_2 , tolerance )) or (
184
+ _vector_matches (v1 , expected_2 , tolerance ) and _vector_matches (v2 , expected_1 , tolerance )
185
+ )
195
186
196
187
197
188
def _is_orcf_pattern (vectors : np .ndarray , tolerance : float ) -> bool :
198
189
"""Check if vectors match ORCF primitive cell pattern."""
199
190
# ORCF primitive: [0, b/2, c/2], [a/2, 0, c/2], [a/2, b/2, 0]
200
191
v = vectors
201
-
192
+
202
193
# Each vector should have one zero component
203
194
zero_components = []
204
195
for vec in v :
205
196
zero_count = sum (1 for x in vec if abs (x ) < tolerance )
206
197
if zero_count != 1 :
207
198
return False
208
199
zero_components .append ([i for i , x in enumerate (vec ) if abs (x ) < tolerance ][0 ])
209
-
200
+
210
201
# Should have one zero in each dimension
211
202
return set (zero_components ) == {0 , 1 , 2 }
212
203
@@ -216,18 +207,18 @@ def _is_orci_pattern(vectors: np.ndarray, tolerance: float) -> bool:
216
207
# ORCI primitive: [-a/2, b/2, c/2], [a/2, -b/2, c/2], [a/2, b/2, -c/2]
217
208
# Similar to BCC but with different lengths
218
209
v = vectors
219
-
210
+
220
211
# Check if it has the body-centered pattern with different lengths
221
212
lengths = [np .linalg .norm (vec ) for vec in v ]
222
213
if len (set (np .round (lengths , 3 ))) == 1 : # All equal lengths -> not ORCI
223
214
return False
224
-
215
+
225
216
# Check sign pattern
226
217
sign_patterns = []
227
218
for vec in v :
228
219
signs = [1 if x > tolerance else (- 1 if x < - tolerance else 0 ) for x in vec ]
229
220
sign_patterns .append (signs )
230
-
221
+
231
222
# Should have body-centered sign pattern
232
223
expected_patterns = [[- 1 , 1 , 1 ], [1 , - 1 , 1 ], [1 , 1 , - 1 ]]
233
224
return _sign_patterns_match (sign_patterns , expected_patterns )
@@ -237,38 +228,42 @@ def _is_orcc_pattern(vectors: np.ndarray, tolerance: float) -> bool:
237
228
"""Check if vectors match ORCC primitive cell pattern."""
238
229
# ORCC primitive: [a/2, b/2, 0], [-a/2, b/2, 0], [0, 0, c]
239
230
v = vectors
240
-
231
+
241
232
# One vector should be along z-axis
242
233
z_vector_count = sum (1 for vec in v if abs (vec [0 ]) < tolerance and abs (vec [1 ]) < tolerance )
243
234
if z_vector_count != 1 :
244
235
return False
245
-
236
+
246
237
# Other two should be in xy-plane with specific pattern
247
238
xy_vectors = [vec for vec in v if not (abs (vec [0 ]) < tolerance and abs (vec [1 ]) < tolerance )]
248
239
if len (xy_vectors ) != 2 :
249
240
return False
250
-
241
+
251
242
# Check if they have opposite x-components and same y-components
252
243
v1 , v2 = xy_vectors
253
- return (abs (v1 [0 ] + v2 [0 ]) < tolerance and abs (v1 [1 ] - v2 [1 ]) < tolerance and
254
- abs (v1 [2 ]) < tolerance and abs (v2 [2 ]) < tolerance )
244
+ return (
245
+ abs (v1 [0 ] + v2 [0 ]) < tolerance
246
+ and abs (v1 [1 ] - v2 [1 ]) < tolerance
247
+ and abs (v1 [2 ]) < tolerance
248
+ and abs (v2 [2 ]) < tolerance
249
+ )
255
250
256
251
257
252
def _is_mclc_pattern (vectors : np .ndarray , tolerance : float ) -> bool :
258
253
"""Check if vectors match MCLC primitive cell pattern."""
259
254
# MCLC primitive: [a/2, b/2, 0], [-a/2, b/2, 0], [0, c*cos(alpha), c*sin(alpha)]
260
255
v = vectors
261
-
256
+
262
257
# Similar to ORCC but with one vector at an angle
263
258
xy_vectors = []
264
259
angled_vectors = []
265
-
260
+
266
261
for vec in v :
267
262
if abs (vec [2 ]) < tolerance :
268
263
xy_vectors .append (vec )
269
264
else :
270
265
angled_vectors .append (vec )
271
-
266
+
272
267
return len (xy_vectors ) == 2 and len (angled_vectors ) == 1
273
268
274
269
@@ -277,23 +272,23 @@ def _vectors_match_pattern(v1: np.ndarray, v2: np.ndarray, tolerance: float) ->
277
272
# Simple check: compare sorted lengths and angles
278
273
lengths1 = sorted ([np .linalg .norm (vec ) for vec in v1 ])
279
274
lengths2 = sorted ([np .linalg .norm (vec ) for vec in v2 ])
280
-
275
+
281
276
if not all (abs (l1 - l2 ) < tolerance for l1 , l2 in zip (lengths1 , lengths2 )):
282
277
return False
283
-
278
+
284
279
# Check angles between vectors
285
280
def get_angles (vectors ):
286
281
angles = []
287
282
for i in range (3 ):
288
- for j in range (i + 1 , 3 ):
283
+ for j in range (i + 1 , 3 ):
289
284
cos_angle = np .dot (vectors [i ], vectors [j ]) / (np .linalg .norm (vectors [i ]) * np .linalg .norm (vectors [j ]))
290
285
cos_angle = np .clip (cos_angle , - 1.0 , 1.0 )
291
286
angles .append (np .degrees (np .arccos (cos_angle )))
292
287
return sorted (angles )
293
-
288
+
294
289
angles1 = get_angles (v1 )
295
290
angles2 = get_angles (v2 )
296
-
291
+
297
292
return all (abs (a1 - a2 ) < tolerance * 10 for a1 , a2 in zip (angles1 , angles2 )) # More lenient for angles
298
293
299
294
@@ -305,7 +300,7 @@ def _vector_matches(v1: np.ndarray, v2: np.ndarray, tolerance: float) -> bool:
305
300
def _sign_patterns_match (patterns1 : List [List [int ]], patterns2 : List [List [int ]]) -> bool :
306
301
"""Check if sign patterns match (allowing permutations)."""
307
302
from itertools import permutations
308
-
303
+
309
304
for perm in permutations (patterns2 ):
310
305
if patterns1 == list (perm ):
311
306
return True
0 commit comments