Skip to content

Commit d99be06

Browse files
committed
update: generate classes
1 parent cd3890a commit d99be06

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

src/py/mat3ra/code/entity.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@ def __init__(self, code: ValidationErrorCode, details: Optional[ErrorDetails] =
3333

3434

3535
class InMemoryEntityPydantic(BaseModel):
36-
model_config = {"arbitrary_types_allowed": True}
36+
model_config = {"arbitrary_types_allowed": True, "extra": "allow"}
37+
38+
# Factory helper field mapping field names to class names
39+
_class_factory: Dict = {}
3740

3841
@classmethod
3942
def create(cls: Type[T], config: Dict[str, Any]) -> T:
@@ -61,6 +64,18 @@ def clean(cls: Type[T], config: Dict[str, Any]) -> Dict[str, Any]:
6164
validated_model = cls.model_validate(config)
6265
return validated_model.model_dump()
6366

67+
def model_post_init(self, __context: Any) -> None:
68+
for field_name, field_value in self.__dict__.items():
69+
if isinstance(field_value, BaseModel):
70+
class_reference = self._class_factory.get(field_name)
71+
if not class_reference:
72+
continue
73+
else:
74+
instance_field_name = field_name + "_instance"
75+
config = field_value.model_dump() # convert from BaseModel to dict
76+
class_instance = class_reference(**config)
77+
setattr(self, instance_field_name, class_instance)
78+
6479
def get_cls_name(self) -> str:
6580
return self.__class__.__name__
6681

tests/py/unit/test_entity.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,26 @@
66
REFERENCE_OBJECT_VALID = {"key1": "value1", "key2": 1}
77
REFERENCE_OBJECT_INVALID = {"key1": "value1", "key2": "value2"}
88
REFERENCE_OBJECT_VALID_JSON = json.dumps(REFERENCE_OBJECT_VALID)
9+
REFERENCE_OBJECT_NESTED_VALID = {"nested_key1": {**REFERENCE_OBJECT_VALID}}
910

1011

1112
class ExampleSchema(BaseModel):
1213
key1: str
1314
key2: int
1415

1516

17+
class ExampleNestedSchema(BaseModel):
18+
nested_key1: ExampleSchema
19+
20+
1621
class ExampleClass(ExampleSchema, InMemoryEntityPydantic):
1722
pass
1823

1924

25+
class ExampleNestedClass(ExampleNestedSchema, InMemoryEntityPydantic):
26+
_class_factory = {"nested_key1": ExampleClass}
27+
28+
2029
example_class_instance_valid = ExampleClass(**REFERENCE_OBJECT_VALID)
2130

2231

@@ -27,6 +36,16 @@ def test_create():
2736
assert in_memory_entity.key2 == 1
2837

2938

39+
def test_create_nested():
40+
# Test creating an instance with nested valid data
41+
in_memory_entity = ExampleNestedClass.create(REFERENCE_OBJECT_NESTED_VALID)
42+
assert isinstance(in_memory_entity, ExampleNestedClass)
43+
assert isinstance(in_memory_entity.nested_key1, ExampleSchema)
44+
assert in_memory_entity.nested_key1.key1 == "value1"
45+
assert in_memory_entity.nested_key1.key2 == 1
46+
assert isinstance(in_memory_entity.nested_key1_instance, ExampleClass)
47+
48+
3049
def test_validate():
3150
# Test valid case
3251
in_memory_entity = ExampleClass.create(REFERENCE_OBJECT_VALID)

0 commit comments

Comments
 (0)