Skip to content

Commit

Permalink
feat: allow recursive schemas (#104)
Browse files Browse the repository at this point in the history
* feat: support recursive schemas
  • Loading branch information
simosho authored Mar 10, 2024
1 parent 46684c6 commit 4d8a20e
Show file tree
Hide file tree
Showing 11 changed files with 265 additions and 54 deletions.
130 changes: 88 additions & 42 deletions jsf/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
),
initial_state: Dict[str, Any] = MappingProxyType({}),
allow_none_optionals: confloat(ge=0.0, le=1.0) = 0.5,
max_recursive_depth: int = 10,
):
"""Initializes the JSF generator with the provided schema and
configuration options.
Expand All @@ -62,16 +63,19 @@ def __init__(
context (Dict[str, Any], optional): A dictionary that provides additional utilities for handling the schema, such as a faker for generating fake data, a random number generator, and datetime utilities. It also includes an internal dictionary for handling List, Union, and Tuple types. Defaults to a dictionary with "faker", "random", "datetime", and "__internal__" keys.
initial_state (Dict[str, Any], optional): A dictionary that represents the initial state of the parser. If you wish to extend the state so it can be accesses by your schema you can add any references in here. Defaults to an empty dictionary.
allow_none_optionals (confloat, optional): A parameter that determines the probability of optional fields being set to None. Defaults to 0.5.
max_recursive_depth (int, optional): A parameter that determines the maximum depth when generating a recursive schema. Defaults to 10.
"""
self.root_schema = schema
self.definitions = {}
self.base_state = {
"__counter__": count(start=1),
"__all_json_paths__": [],
"__depth__": 0,
**initial_state,
}
self.base_context = context
self.allow_none_optionals = allow_none_optionals
self.max_recursive_depth = max_recursive_depth

self.root = None
self._parse(schema)
Expand All @@ -89,6 +93,7 @@ def from_json(
),
initial_state: Dict[str, Any] = MappingProxyType({}),
allow_none_optionals: confloat(ge=0.0, le=1.0) = 0.5,
max_recursive_depth: int = 10,
) -> "JSF":
"""Initializes the JSF generator with the provided schema at the given
path and configuration options.
Expand All @@ -98,9 +103,12 @@ def from_json(
context (Dict[str, Any], optional): A dictionary that provides additional utilities for handling the schema, such as a faker for generating fake data, a random number generator, and datetime utilities. It also includes an internal dictionary for handling List, Union, and Tuple types. Defaults to a dictionary with "faker", "random", "datetime", and "__internal__" keys.
initial_state (Dict[str, Any], optional): A dictionary that represents the initial state of the parser. If you wish to extend the state so it can be accesses by your schema you can add any references in here. Defaults to an empty dictionary.
allow_none_optionals (confloat, optional): A parameter that determines the probability of optional fields being set to None. Defaults to 0.5.
max_recursive_depth (int, optional): A parameter that determines the maximum depth when generating a recursive schema. Defaults to 10.
"""
with open(path) as f:
return JSF(json.load(f), context, initial_state, allow_none_optionals)
return JSF(
json.load(f), context, initial_state, allow_none_optionals, max_recursive_depth
)

def __parse_primitive(self, name: str, path: str, schema: Dict[str, Any]) -> PrimitiveTypes:
item_type, is_nullable = self.__is_field_nullable(schema)
Expand All @@ -111,62 +119,79 @@ def __parse_primitive(self, name: str, path: str, schema: Dict[str, Any]) -> Pri
"path": path,
"is_nullable": is_nullable,
"allow_none_optionals": self.allow_none_optionals,
"max_recursive_depth": self.max_recursive_depth,
**schema,
}
)

def __parse_object(self, name: str, path: str, schema: Dict[str, Any]) -> Object:
def __parse_object(
self, name: str, path: str, schema: Dict[str, Any], root: Optional[AllTypes] = None
) -> Object:
_, is_nullable = self.__is_field_nullable(schema)
model = Object.from_dict(
{
"name": name,
"path": path,
"is_nullable": is_nullable,
"allow_none_optionals": self.allow_none_optionals,
"max_recursive_depth": self.max_recursive_depth,
**schema,
}
)
root = model if root is None else root
props = []
for _name, definition in schema.get("properties", {}).items():
props.append(self.__parse_definition(_name, path=f"{path}/{_name}", schema=definition))
props.append(
self.__parse_definition(_name, path=f"{path}/{_name}", schema=definition, root=root)
)
model.properties = props
pattern_props = []
for _name, definition in schema.get("patternProperties", {}).items():
pattern_props.append(
self.__parse_definition(_name, path=f"{path}/{_name}", schema=definition)
self.__parse_definition(_name, path=f"{path}/{_name}", schema=definition, root=root)
)
model.patternProperties = pattern_props

return model

def __parse_array(self, name: str, path: str, schema: Dict[str, Any]) -> Array:
def __parse_array(
self, name: str, path: str, schema: Dict[str, Any], root: Optional[AllTypes] = None
) -> Array:
_, is_nullable = self.__is_field_nullable(schema)
arr = Array.from_dict(
{
"name": name,
"path": path,
"is_nullable": is_nullable,
"allow_none_optionals": self.allow_none_optionals,
"max_recursive_depth": self.max_recursive_depth,
**schema,
}
)
arr.items = self.__parse_definition(name, name, schema["items"])
root = arr if root is None else root
arr.items = self.__parse_definition(name, f"{path}/items", schema["items"], root=root)
return arr

def __parse_tuple(self, name: str, path: str, schema: Dict[str, Any]) -> JSFTuple:
def __parse_tuple(
self, name: str, path: str, schema: Dict[str, Any], root: Optional[AllTypes] = None
) -> JSFTuple:
_, is_nullable = self.__is_field_nullable(schema)
arr = JSFTuple.from_dict(
{
"name": name,
"path": path,
"is_nullable": is_nullable,
"allow_none_optionals": self.allow_none_optionals,
"max_recursive_depth": self.max_recursive_depth,
**schema,
}
)
root = arr if root is None else root
arr.items = []
for i, item in enumerate(schema["items"]):
arr.items.append(self.__parse_definition(name, path=f"{name}[{i}]", schema=item))
arr.items.append(
self.__parse_definition(name, path=f"{path}/{name}[{i}]", schema=item, root=root)
)
return arr

def __is_field_nullable(self, schema: Dict[str, Any]) -> Tuple[str, bool]:
Expand All @@ -181,40 +206,55 @@ def __is_field_nullable(self, schema: Dict[str, Any]) -> Tuple[str, bool]:
return random.choice(item_type_deep_copy), False
return item_type, False

def __parse_anyOf(self, name: str, path: str, schema: Dict[str, Any]) -> AnyOf:
def __parse_anyOf(
self, name: str, path: str, schema: Dict[str, Any], root: Optional[AllTypes] = None
) -> AnyOf:
model = AnyOf(name=name, path=path, max_recursive_depth=self.max_recursive_depth, **schema)
root = model if root is None else root
schemas = []
for d in schema["anyOf"]:
schemas.append(self.__parse_definition(name, path, d))
return AnyOf(name=name, path=path, schemas=schemas, **schema)
schemas.append(self.__parse_definition(name, path, d, root=root))
model.schemas = schemas
return model

def __parse_allOf(self, name: str, path: str, schema: Dict[str, Any]) -> AllOf:
def __parse_allOf(
self, name: str, path: str, schema: Dict[str, Any], root: Optional[AllTypes] = None
) -> AllOf:
combined_schema = dict(ChainMap(*schema["allOf"]))
return AllOf(
name=name,
path=path,
combined_schema=self.__parse_definition(name, path, combined_schema),
**schema,
)
model = AllOf(name=name, path=path, max_recursive_depth=self.max_recursive_depth, **schema)
root = model if root is None else root
model.combined_schema = self.__parse_definition(name, path, combined_schema, root=root)
return model

def __parse_oneOf(self, name: str, path: str, schema: Dict[str, Any]) -> OneOf:
def __parse_oneOf(
self, name: str, path: str, schema: Dict[str, Any], root: Optional[AllTypes] = None
) -> OneOf:
model = OneOf(name=name, path=path, max_recursive_depth=self.max_recursive_depth, **schema)
root = model if root is None else root
schemas = []
for d in schema["oneOf"]:
schemas.append(self.__parse_definition(name, path, d))
return OneOf(name=name, path=path, schemas=schemas, **schema)
schemas.append(self.__parse_definition(name, path, d, root=root))
model.schemas = schemas
return model

def __parse_named_definition(self, def_name: str) -> AllTypes:
def __parse_named_definition(self, path: str, def_name: str, root) -> AllTypes:
schema = self.root_schema
parsed_definition = None
for def_tag in ("definitions", "$defs"):
for name, definition in schema.get(def_tag, {}).items():
if name == def_name:
parsed_definition = self.__parse_definition(
name, path=f"#/{def_tag}", schema=definition
)
self.definitions[f"#/{def_tag}/{name}"] = parsed_definition
if path.startswith(f"#/{def_tag}/{def_name}"):
root.is_recursive = True
return root
definition = schema.get(def_tag, {}).get(def_name)
if definition is not None:
parsed_definition = self.__parse_definition(
def_name, path=f"{path}/#/{def_tag}/{def_name}", schema=definition, root=root
)
self.definitions[f"#/{def_tag}/{def_name}"] = parsed_definition
return parsed_definition

def __parse_definition(self, name: str, path: str, schema: Dict[str, Any]) -> AllTypes:
def __parse_definition(
self, name: str, path: str, schema: Dict[str, Any], root: Optional[AllTypes] = None
) -> AllTypes:
self.base_state["__all_json_paths__"].append(path)
item_type, is_nullable = self.__is_field_nullable(schema)
if "const" in schema:
Expand All @@ -232,25 +272,26 @@ def __parse_definition(self, name: str, path: str, schema: Dict[str, Any]) -> Al
"path": path,
"is_nullable": is_nullable,
"allow_none_optionals": self.allow_none_optionals,
"max_recursive_depth": self.max_recursive_depth,
**schema,
}
)
elif "type" in schema:
if item_type == "object" and "properties" in schema:
return self.__parse_object(name, path, schema)
return self.__parse_object(name, path, schema, root)
elif item_type == "object" and "anyOf" in schema:
return self.__parse_anyOf(name, path, schema)
return self.__parse_anyOf(name, path, schema, root)
elif item_type == "object" and "allOf" in schema:
return self.__parse_allOf(name, path, schema)
return self.__parse_allOf(name, path, schema, root)
elif item_type == "object" and "oneOf" in schema:
return self.__parse_oneOf(name, path, schema)
return self.__parse_oneOf(name, path, schema, root)
elif item_type == "array":
if (schema.get("contains") is not None) or isinstance(schema.get("items"), dict):
return self.__parse_array(name, path, schema)
return self.__parse_array(name, path, schema, root)
if isinstance(schema.get("items"), list) and all(
isinstance(x, dict) for x in schema.get("items", [])
):
return self.__parse_tuple(name, path, schema)
return self.__parse_tuple(name, path, schema, root)
else:
return self.__parse_primitive(name, path, schema)
elif "$ref" in schema:
Expand All @@ -261,28 +302,33 @@ def __parse_definition(self, name: str, path: str, schema: Dict[str, Any]) -> Al
else:
# parse referenced definition
ref_name = frag.split("/")[-1]
cls = self.__parse_named_definition(ref_name)
cls = self.__parse_named_definition(path, ref_name, root)
else:
with s_open(ext, "r") as f:
external_jsf = JSF(json.load(f))
cls = deepcopy(external_jsf.definitions.get(f"#{frag}"))
cls.name = name
cls.path = path
if path != "#" and cls == root:
cls.name = name
elif path != "#":
cls.name = name
cls.path = path
return cls
elif "anyOf" in schema:
return self.__parse_anyOf(name, path, schema)
return self.__parse_anyOf(name, path, schema, root)
elif "allOf" in schema:
return self.__parse_allOf(name, path, schema)
return self.__parse_allOf(name, path, schema, root)
elif "oneOf" in schema:
return self.__parse_oneOf(name, path, schema)
return self.__parse_oneOf(name, path, schema, root)
else:
raise ValueError(f"Cannot parse schema {repr(schema)}") # pragma: no cover

def _parse(self, schema: Dict[str, Any]) -> AllTypes:
for def_tag in ("definitions", "$defs"):
for name, definition in schema.get(def_tag, {}).items():
if f"#/{def_tag}/{name}" not in self.definitions:
item = self.__parse_definition(name, path=f"#/{def_tag}", schema=definition)
item = self.__parse_definition(
name, path=f"#/{def_tag}/{name}", schema=definition
)
self.definitions[f"#/{def_tag}/{name}"] = item

self.root = self.__parse_definition(name="root", path="#", schema=schema)
Expand Down
7 changes: 6 additions & 1 deletion jsf/schema_types/_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ def generate(self, context: Dict[str, Any]) -> Optional[List[Tuple]]:
try:
return super().generate(context)
except ProviderNotSetException:
return tuple(item.generate(context) for item in self.items)
depth = context["state"]["__depth__"]
output = []
for item in self.items:
output.append(item.generate(context))
context["state"]["__depth__"] = depth
return tuple(output)

def model(self, context: Dict[str, Any]) -> Tuple[Type, Any]:
_type = eval(
Expand Down
5 changes: 4 additions & 1 deletion jsf/schema_types/anyof.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ def generate(self, context: Dict[str, Any]) -> Optional[Any]:
try:
return super().generate(context)
except ProviderNotSetException:
return random.choice(self.schemas).generate(context)
filtered_schemas = []
if context["state"]["__depth__"] > self.max_recursive_depth:
filtered_schemas = [schema for schema in self.schemas if not schema.is_recursive]
return random.choice(filtered_schemas or self.schemas).generate(context)

def model(self, context: Dict[str, Any]) -> None:
pass
11 changes: 7 additions & 4 deletions jsf/schema_types/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,22 @@ def generate(self, context: Dict[str, Any]) -> Optional[List[Any]]:
elif isinstance(self.fixed, int):
self.minItems = self.maxItems = self.fixed

output = [
self.items.generate(context)
for _ in range(random.randint(int(self.minItems), int(self.maxItems)))
]
depth = context["state"]["__depth__"]
output = []
for _ in range(random.randint(int(self.minItems), int(self.maxItems))):
output.append(self.items.generate(context))
context["state"]["__depth__"] = depth
if self.uniqueItems and self.items.type == "object":
output = [dict(s) for s in {frozenset(d.items()) for d in output}]
while len(output) < self.minItems:
output.append(self.items.generate(context))
output = [dict(s) for s in {frozenset(d.items()) for d in output}]
context["state"]["__depth__"] = depth
elif self.uniqueItems:
output = set(output)
while len(output) < self.minItems:
output.add(self.items.generate(context))
context["state"]["__depth__"] = depth
output = list(output)
return output

Expand Down
10 changes: 9 additions & 1 deletion jsf/schema_types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,25 @@ class BaseSchema(BaseModel):
provider: Optional[str] = Field(None, alias="$provider")
set_state: Optional[Dict[str, str]] = Field(None, alias="$state")
is_nullable: bool = False
is_recursive: bool = False
allow_none_optionals: float = Field(0.5, ge=0.0, le=1.0)
max_recursive_depth: int = 10

@classmethod
def from_dict(cls, d: Dict[str, Any]) -> Self:
raise NotImplementedError # pragma: no cover

def generate(self, context: Dict[str, Any]) -> Any:
if self.is_recursive:
context["state"]["__depth__"] += 1

if self.set_state is not None:
context["state"][self.path] = {k: eval(v, context)() for k, v in self.set_state.items()}

if self.is_nullable and random.uniform(0, 1) < self.allow_none_optionals:
if self.is_nullable and (
random.uniform(0, 1) < self.allow_none_optionals
or context["state"]["__depth__"] > self.max_recursive_depth
):
return None
if self.provider is not None:
return eval(self.provider, context)()
Expand Down
Loading

0 comments on commit 4d8a20e

Please sign in to comment.