diff --git a/lanedet/utils/config.py b/lanedet/utils/config.py index 8158f76..9bf7cff 100755 --- a/lanedet/utils/config.py +++ b/lanedet/utils/config.py @@ -1,8 +1,10 @@ # Copyright (c) Open-MMLab. All rights reserved. import ast import os.path as osp +import platform import shutil import sys +import re import tempfile from argparse import Action, ArgumentParser from collections import abc @@ -12,8 +14,12 @@ from yapf.yapflib.yapf_api import FormatCode + + + BASE_KEY = '_base_' DELETE_KEY = '_delete_' +DEPRECATION_KEY = '_deprecation_' RESERVED_KEYS = ['filename', 'text', 'pretty_text'] def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): @@ -93,16 +99,70 @@ def _validate_py_syntax(filename): f'file {filename}') @staticmethod - def _file2dict(filename): + def _substitute_predefined_vars(filename, temp_config_name): + file_dirname = osp.dirname(filename) + file_basename = osp.basename(filename) + file_basename_no_extension = osp.splitext(file_basename)[0] + file_extname = osp.splitext(filename)[1] + support_templates = dict( + fileDirname=file_dirname, + fileBasename=file_basename, + fileBasenameNoExtension=file_basename_no_extension, + fileExtname=file_extname) + with open(filename, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + config_file = f.read() + for key, value in support_templates.items(): + regexp = r'\{\{\s*' + str(key) + r'\s*\}\}' + value = value.replace('\\', '/') + config_file = re.sub(regexp, value, config_file) + with open(temp_config_name, 'w') as tmp_config_file: + tmp_config_file.write(config_file) + + @staticmethod + def _pre_substitute_base_vars(filename, temp_config_name): + """Substitute base variable placehoders to string, so that parsing + would work.""" + with open(filename, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + config_file = f.read() + base_var_dict = {} + regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}' + base_vars = set(re.findall(regexp, config_file)) + for base_var in base_vars: + randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}' + base_var_dict[randstr] = base_var + regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}' + config_file = re.sub(regexp, f'"{randstr}"', config_file) + with open(temp_config_name, 'w') as tmp_config_file: + tmp_config_file.write(config_file) + return base_var_dict + + @staticmethod + def _file2dict(filename, use_predefined_variables=True): filename = osp.abspath(osp.expanduser(filename)) check_file_exist(filename) - if filename.endswith('.py'): - with tempfile.TemporaryDirectory() as temp_config_dir: - temp_config_file = tempfile.NamedTemporaryFile( - dir=temp_config_dir, suffix='.py') - temp_config_name = osp.basename(temp_config_file.name) - shutil.copyfile(filename, - osp.join(temp_config_dir, temp_config_name)) + fileExtname = osp.splitext(filename)[1] + if fileExtname not in ['.py', '.json', '.yaml', '.yml']: + raise IOError('Only py/yml/yaml/json type are supported now!') + + with tempfile.TemporaryDirectory() as temp_config_dir: + temp_config_file = tempfile.NamedTemporaryFile( + dir=temp_config_dir, suffix=fileExtname) + if platform.system() == 'Windows': + temp_config_file.close() + temp_config_name = osp.basename(temp_config_file.name) + # Substitute predefined variables + if use_predefined_variables: + Config._substitute_predefined_vars(filename, + temp_config_file.name) + else: + shutil.copyfile(filename, temp_config_file.name) + # Substitute base variables from placeholders to strings + base_var_dict = Config._pre_substitute_base_vars( + temp_config_file.name, temp_config_file.name) + + if filename.endswith('.py'): temp_module_name = osp.splitext(temp_config_name)[0] sys.path.insert(0, temp_config_dir) Config._validate_py_syntax(filename) @@ -115,16 +175,28 @@ def _file2dict(filename): } # delete imported module del sys.modules[temp_module_name] - # close temp file - temp_config_file.close() - elif filename.endswith(('.yml', '.yaml', '.json')): - import mmcv - cfg_dict = mmcv.load(filename) - else: - raise IOError('Only py/yml/yaml/json type are supported now!') - - cfg_text = '' - with open(filename, 'r') as f: + elif filename.endswith(('.yml', '.yaml', '.json')): + import mmcv + cfg_dict = mmcv.load(temp_config_file.name) + # close temp file + temp_config_file.close() + + # check deprecation information + if DEPRECATION_KEY in cfg_dict: + deprecation_info = cfg_dict.pop(DEPRECATION_KEY) + warning_msg = f'The config file {filename} will be deprecated ' \ + 'in the future.' + if 'expected' in deprecation_info: + warning_msg += f' Please use {deprecation_info["expected"]} ' \ + 'instead.' + if 'reference' in deprecation_info: + warning_msg += ' More information can be found at ' \ + f'{deprecation_info["reference"]}' + warnings.warn(warning_msg) + + cfg_text = filename + '\n' + with open(filename, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows cfg_text += f.read() if BASE_KEY in cfg_dict: @@ -142,10 +214,16 @@ def _file2dict(filename): base_cfg_dict = dict() for c in cfg_dict_list: - if len(base_cfg_dict.keys() & c.keys()) > 0: - raise KeyError('Duplicate key is not allowed among bases') + duplicate_keys = base_cfg_dict.keys() & c.keys() + if len(duplicate_keys) > 0: + raise KeyError('Duplicate key is not allowed among bases. ' + f'Duplicate keys: {duplicate_keys}') base_cfg_dict.update(c) + # Subtitute base variables from strings to their actual values + cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, + base_cfg_dict) + base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) cfg_dict = base_cfg_dict