Skip to content

Commit

Permalink
PyMJCF recursive include tags relative to base model
Browse files Browse the repository at this point in the history
  • Loading branch information
guyazran committed Aug 3, 2023
1 parent d6f9cb4 commit 111b1d2
Showing 1 changed file with 90 additions and 22 deletions.
112 changes: 90 additions & 22 deletions dm_control/mjcf/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@


def from_xml_string(xml_string, escape_separators=False,
model_dir='', resolve_references=True, assets=None):
model_dir='', resolve_references=True, assets=None,
base_model_dir=None):
"""Parses an XML string into an MJCF object model.
Args:
Expand All @@ -41,6 +42,9 @@ def from_xml_string(xml_string, escape_separators=False,
assets: (optional) A dictionary of pre-loaded assets, of the form
`{filename: bytestring}`. If present, PyMJCF will search for assets in
this dictionary before attempting to load them from the filesystem.
base_model_dir: (optional) Path to the directory containing the base model.
This is used to prefix the paths of <include> elements' file attributes
to support nested includes as in the MuJoCo compiler.
Returns:
An `mjcf.RootElement`.
Expand All @@ -49,11 +53,12 @@ def from_xml_string(xml_string, escape_separators=False,
return _parse(xml_root, escape_separators,
model_dir=model_dir,
resolve_references=resolve_references,
assets=assets)
assets=assets, base_model_dir=base_model_dir)


def from_file(file_handle, escape_separators=False,
model_dir='', resolve_references=True, assets=None):
model_dir='', resolve_references=True, assets=None,
base_model_dir=None):
"""Parses an XML file into an MJCF object model.
Args:
Expand All @@ -68,6 +73,9 @@ def from_file(file_handle, escape_separators=False,
assets: (optional) A dictionary of pre-loaded assets, of the form
`{filename: bytestring}`. If present, PyMJCF will search for assets in
this dictionary before attempting to load them from the filesystem.
base_model_dir: (optional) Path to the directory containing the base model.
This is used to prefix the paths of <include> elements' file attributes
to support nested includes as in the MuJoCo compiler.
Returns:
An `mjcf.RootElement`.
Expand All @@ -76,11 +84,11 @@ def from_file(file_handle, escape_separators=False,
return _parse(xml_root, escape_separators,
model_dir=model_dir,
resolve_references=resolve_references,
assets=assets)
assets=assets, base_model_dir=base_model_dir)


def from_path(path, escape_separators=False, resolve_references=True,
assets=None):
assets=None, base_model_dir=None):
"""Parses an XML file into an MJCF object model.
Args:
Expand All @@ -94,6 +102,9 @@ def from_path(path, escape_separators=False, resolve_references=True,
assets: (optional) A dictionary of pre-loaded assets, of the form
`{filename: bytestring}`. If present, PyMJCF will search for assets in
this dictionary before attempting to load them from the filesystem.
base_model_dir: (optional) Path to the directory containing the base model.
This is used to prefix the paths of <include> elements' file attributes
to support nested includes as in the MuJoCo compiler.
Returns:
An `mjcf.RootElement`.
Expand All @@ -103,11 +114,12 @@ def from_path(path, escape_separators=False, resolve_references=True,
xml_root = etree.fromstring(contents)
return _parse(xml_root, escape_separators,
model_dir=model_dir, resolve_references=resolve_references,
assets=assets)
assets=assets, base_model_dir=base_model_dir)


def _parse(xml_root, escape_separators=False,
model_dir='', resolve_references=True, assets=None):
model_dir='', resolve_references=True, assets=None,
base_model_dir=None):
"""Parses a complete MJCF model from an XML.
Args:
Expand All @@ -122,6 +134,9 @@ def _parse(xml_root, escape_separators=False,
assets: (optional) A dictionary of pre-loaded assets, of the form
`{filename: bytestring}`. If present, PyMJCF will search for assets in
this dictionary before attempting to load them from the filesystem.
base_model_dir: (optional) Path to the directory containing the base model.
This is used to prefix the paths of <include> elements' file attributes
to support nested includes as in the MuJoCo compiler.
Returns:
An `mjcf.RootElement`.
Expand All @@ -140,20 +155,9 @@ def _parse(xml_root, escape_separators=False,
# Recursively parse any included XML files.
to_include = []
for include_tag in xml_root.findall('include'):
try:
# First look for the path to the included XML file in the assets dict.
path_or_xml_string = assets[include_tag.attrib['file']]
parsing_func = from_xml_string
except KeyError:
# If it's not present in the assets dict then attempt to load the XML
# from the filesystem.
path_or_xml_string = os.path.join(model_dir, include_tag.attrib['file'])
parsing_func = from_path
included_mjcf = parsing_func(
path_or_xml_string,
escape_separators=escape_separators,
resolve_references=resolve_references,
assets=assets)
included_mjcf = _parse_include(include_tag, escape_separators, model_dir,
resolve_references, assets, base_model_dir)

to_include.append(included_mjcf)
# We must remove <include/> tags before parsing the main XML file, since
# these are a schema violation.
Expand All @@ -165,7 +169,7 @@ def _parse(xml_root, escape_separators=False,
except KeyError:
model = None
mjcf_root = element.RootElement(
model=model, model_dir=model_dir, assets=assets)
model=model, model_dir=base_model_dir or model_dir, assets=assets)
_parse_children(xml_root, mjcf_root, escape_separators)

# Merge in the included XML files.
Expand All @@ -180,6 +184,70 @@ def _parse(xml_root, escape_separators=False,
return mjcf_root


def _parse_include(include_tag, escape_separators, model_dir, resolve_references, assets, base_model_dir):
"""
Parses an included XML file.
Args:
include_tag: An `etree.Element` object with tag 'include'.
escape_separators: (optional) A boolean, whether to replace '/' characters
in element identifiers. If `False`, any '/' present in the XML causes
a ValueError to be raised.
model_dir: (optional) Path to the directory containing the model XML file.
This is used to prefix the paths of all asset files.
resolve_references: (optional) A boolean indicating whether the parser
should attempt to resolve reference attributes to a corresponding element.
assets: (optional) A dictionary of pre-loaded assets, of the form
`{filename: bytestring}`. If present, PyMJCF will search for assets in
this dictionary before attempting to load them from the filesystem.
base_model_dir: (optional) Path to the directory containing the base model.
This is used to prefix the paths of <include> elements' file attributes
to support nested includes as in the MuJoCo compiler.
Returns:
An `mjcf.RootElement`.
Raises:
FileNotFoundError: If the included the inner paths of the included XML could
not be resolved.
"""

base_dirs = [model_dir] # always look in the current model dir first
if base_model_dir is not None:
base_dirs.append(base_model_dir) # then look in the base model dir if provided

not_found_exception = None # a container for the final exception if some file references are not resolved

# try to parse the included XML file from each of the base dirs
for working_dir in base_dirs:

# setup new parsing kwargs dict with current base model dir
parsing_func_kwargs = dict(
escape_separators=escape_separators,
resolve_references=resolve_references,
assets=assets,
base_model_dir=working_dir
)

try:
# First look for the path to the included XML file in the assets dict.
path_or_xml_string = assets[include_tag.attrib['file']]
parsing_func = from_xml_string
parsing_func_kwargs.update(dict(model_dir=working_dir)) # requires explicit model dir
except KeyError:
# If it's not present in the assets dict then attempt to load the XML
# from the filesystem.
path_or_xml_string = os.path.join(working_dir, include_tag.attrib['file'])
parsing_func = from_path
try:
# if successfully parsed the included XML file, stop searching
return parsing_func(path_or_xml_string, **parsing_func_kwargs)
except FileNotFoundError as e:
# base model dir did not resolve the inner include paths
not_found_exception = e

raise FileNotFoundError('Could not find an appropriate base path for include tag') from not_found_exception

def _parse_children(xml_element, mjcf_element, escape_separators=False):
"""Parses all children of a given XML element into an MJCF element.
Expand Down

0 comments on commit 111b1d2

Please sign in to comment.