Skip to content

Commit 6950591

Browse files
committed
chore: simplify init
1 parent 2598779 commit 6950591

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

src/py/mat3ra/code/vector.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@
66

77

88
class Vector3D(Vector3DSchema):
9-
pass
9+
def __init__(self, root: List[float]):
10+
super().__init__(root=root)
1011

1112
@property
1213
def value(self):
1314
return self.root
1415

1516

1617
class RoundedVector3D(RoundNumericValuesMixin, Vector3D):
18+
def __init__(self, root: List[float]):
19+
super().__init__(root=root)
20+
1721
@model_serializer
1822
def to_dict(self, skip_rounding: bool = False) -> List[float]:
1923
rounded_value = self.round_array_or_number(self.root) if not skip_rounding else self.root

tests/py/unit/test_vector.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
def test_vector_init():
9-
vector = Vector3D(root=VECTOR_FLOAT)
9+
vector = Vector3D(VECTOR_FLOAT)
1010
assert vector.model_dump() == VECTOR_FLOAT
1111

1212

@@ -19,23 +19,28 @@ def test_vector_init_wrong_type():
1919

2020
def test_vector_init_wrong_size():
2121
try:
22-
_ = Vector3D(root=[1, 2])
22+
_ = Vector3D([1, 2])
2323
assert False
2424
except Exception:
2525
assert True
2626

2727

28+
def test_rounded_vector_init():
29+
vector = RoundedVector3D(VECTOR_FLOAT)
30+
assert vector.model_dump() == VECTOR_FLOAT
31+
32+
2833
def test_rounded_vector_serialization():
2934
class_reference = RoundedVector3D
3035
class_reference.__round_precision__ = 4
31-
vector = class_reference(root=VECTOR_FLOAT)
36+
vector = class_reference(VECTOR_FLOAT)
3237
assert vector.model_dump() == VECTOR_FLOAT_ROUNDED_4
3338
assert vector.value_rounded == VECTOR_FLOAT_ROUNDED_4
3439
assert vector.value == VECTOR_FLOAT
3540

3641
class_reference = RoundedVector3D
3742
class_reference.__round_precision__ = 3
38-
vector = class_reference(root=VECTOR_FLOAT)
43+
vector = class_reference(VECTOR_FLOAT)
3944
assert vector.model_dump() == VECTOR_FLOAT_ROUNDED_3
4045
assert vector.value_rounded == VECTOR_FLOAT_ROUNDED_3
4146
assert vector.value == VECTOR_FLOAT

0 commit comments

Comments
 (0)