Skip to content

Commit 9ecd381

Browse files
committed
feat: Added more features
1 parent 835d185 commit 9ecd381

File tree

4 files changed

+93
-59
lines changed

4 files changed

+93
-59
lines changed

confly/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__version__ = "0.0.4"
1+
__version__ = "0.0.5"
22

33
from confly.confly import Confly

confly/confly.py

Lines changed: 86 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,10 @@
66
import operator
77
from functools import reduce
88
import regex
9+
import math
910

1011

11-
class Confly:
12-
def __init__(self, config: Optional[Union[str, Path, dict]] = None, config_dir: Optional[Union[str, Path]] = None, args: List[str] = None, cli: bool = False):
13-
self.config = config
14-
self.config_dir = config_dir
15-
self.general_op_regex = regex.compile(r"""
12+
GENERAL_OP_REGEX = regex.compile(r"""
1613
\$\{
1714
(?P<op>\w+) # Operation name (add, mul, etc.)
1815
\s*:\s* # Colon with optional spaces
@@ -25,7 +22,9 @@ def __init__(self, config: Optional[Union[str, Path, dict]] = None, config_dir:
2522
)
2623
\}
2724
""", regex.VERBOSE)
28-
self.cfg_regex = regex.compile(r"""
25+
26+
27+
CFG_REGEX = regex.compile(r"""
2928
\$\{
3029
(?P<op>cfg) # Only match 'cfg' literally
3130
\s*:\s* # Colon with optional spaces
@@ -37,7 +36,32 @@ def __init__(self, config: Optional[Union[str, Path, dict]] = None, config_dir:
3736
)
3837
\}
3938
""", regex.VERBOSE)
40-
self.op_regex = None
39+
40+
41+
VAR_REGEX = regex.compile(r"""
42+
\$\{
43+
(?P<op>var) # Only match 'cfg' literally
44+
\s*:\s* # Colon with optional spaces
45+
(?P<arg> # Start capturing argument
46+
(?: # Non-capturing group for content
47+
[^{}]+ # Non-brace content
48+
| \{ (?0) \} # Or nested {...} recursively
49+
)*
50+
)
51+
\}
52+
""", regex.VERBOSE)
53+
54+
55+
OPERATOR_MAPPING = {
56+
"div": operator.truediv,
57+
"sqrt": None
58+
}
59+
60+
61+
class Confly:
62+
def __init__(self, config: Optional[Union[str, Path, dict]] = None, config_dir: Optional[Union[str, Path]] = None, args: List[str] = None, cli: bool = False):
63+
self.config = config
64+
self.config_dir = config_dir
4165

4266
if isinstance(self.config, Path):
4367
self.config = str(self.config)
@@ -48,33 +72,17 @@ def __init__(self, config: Optional[Union[str, Path, dict]] = None, config_dir:
4872
self.config_dir = Path.cwd()
4973

5074
if isinstance(self.config, str):
51-
# arg_configs, arg_parameters = self._parse_args(args, cli)
52-
# self.config = self._update_config(arg_configs)
53-
# self.config = self._interpolate(self.config, self._interpolate_cfg, r'\$\{cfg:\s*([^}]+)\}')
54-
# self.config = self._interpolate(self.config, self._interpolate_env, r'\$\{env:\s*([^}]+)\}')
55-
# self.config = self._update_parameters(arg_parameters)
56-
# self.config = self._interpolate(self.config, self._interpolate_cfg, r'\$\{cfg:\s*([^}]+)\}')
57-
# self.config = self._interpolate(self.config, self._interpolate_env, r'\$\{env:\s*([^}]+)\}')
58-
# self.config = self._interpolate(self.config, self._interpolate_var, r'\$\{var:\s*([^}]+)\}')
59-
# self.config = self._interpolate(self.config, self._interpolate_var, r'\$\{(add|sub|mul|div|sqrt|pow):\s*([^}]+)\}')
60-
# self.config = self._apply_recursively(self._maybe_convert_to_numeric, self.config)
61-
62-
arg_configs, arg_parameters = self._parse_args(args, cli)
75+
arg_configs, overrides = self._parse_args(args, cli)
6376
self.config = self._update_config(arg_configs)
64-
self.op_regex = self.cfg_regex
65-
self.config = self._interpolate(self.config)
66-
self.config = self._update_parameters(arg_parameters)
67-
self.op_regex = self.general_op_regex
68-
self.config = self._interpolate(self.config)
77+
self.config = self._interpolate(self.config, self.config, CFG_REGEX, "", overrides)
78+
self.config = self._update_overrides(overrides)
79+
self.config = self._interpolate(self.config, self.config, GENERAL_OP_REGEX, "", overrides)
6980
self.config = self._apply_recursively(self._maybe_convert_to_numeric, self.config)
7081

7182
for key, value in self.config.items():
7283
setattr(self, key, Confly(value) if isinstance(value, dict) else value)
7384
del self.config
7485
del self.config_dir
75-
del self.op_regex
76-
del self.general_op_regex
77-
del self.cfg_regex
7886

7987

8088
def _parse_args(self, args, cli: bool):
@@ -93,12 +101,14 @@ def _parse_args(self, args, cli: bool):
93101
args = []
94102
if cli:
95103
args.append(sys.argv[1:])
96-
configs, parameters = [], []
104+
configs, parameters = [], {}
97105
for arg in args:
98106
if "=" in arg:
99-
parameters.append(arg)
107+
arg = arg if arg[0] == "." else "." + arg
108+
arg = arg.split("=")
109+
parameters[arg[0]] = arg[1]
100110
elif "--" in arg:
101-
parameters.append(arg[2:] + "=True")
111+
parameters["." + arg[2:]] = True
102112
else:
103113
configs.append(arg)
104114
return configs, parameters
@@ -121,53 +131,61 @@ def _update_config(self, arg_configs: list):
121131
config = "${cfg:" + ",".join(arg_configs) + "}"
122132
return config
123133

124-
def _interpolate(self, obj):
134+
def _interpolate(self, obj, conf, op_regex, current_path, overrides=None):
135+
if overrides is not None and current_path in overrides:
136+
obj = overrides[current_path]
137+
return obj
125138
if isinstance(obj, dict):
126-
return {k: self._interpolate(v) for k, v in obj.items()}
139+
return {k: self._interpolate(v, conf, op_regex, f"{current_path}.{k}", overrides) for k, v in obj.items()}
127140
elif isinstance(obj, list) or isinstance(obj, tuple):
128-
return [self._interpolate(elem) for elem in obj]
129-
elif isinstance(obj, str) and self._is_entire_expression(obj):
130-
expr, op, arg = self._get_expression(obj)
131-
obj = self._interpolate_op(expr, op, arg)
141+
return [self._interpolate(elem, conf, op_regex, current_path, overrides) for elem in obj]
142+
elif isinstance(obj, str) and self._is_entire_expression(obj, op_regex):
143+
expr, op, arg = self._get_expression(obj, op_regex)
144+
obj = self._interpolate_op(expr, op, arg, conf)
145+
obj = self._interpolate(obj, conf, op_regex, current_path, overrides)
146+
if op_regex == CFG_REGEX:
147+
obj = self._interpolate(obj, obj, VAR_REGEX, current_path, overrides)
132148
return obj
133-
elif isinstance(obj, str) and self._contains_expression(obj):
134-
while self._contains_expression(obj):
135-
expr, op, arg = self._get_expression(obj)
136-
interpolated_expr = self._interpolate(expr)
137-
obj = obj.replace(expr, interpolated_expr, 1)
149+
elif isinstance(obj, str) and self._contains_expression(obj, op_regex):
150+
while self._contains_expression(obj, op_regex):
151+
expr, op, arg = self._get_expression(obj, op_regex)
152+
interpolated_expr = self._interpolate(expr, conf, op_regex, current_path, overrides)
153+
obj = obj.replace(expr, str(interpolated_expr), 1)
138154
return obj
139155
else:
140156
return obj
141157

142-
def _is_entire_expression(self, obj: str) -> bool:
143-
return bool(regex.fullmatch(self.op_regex, obj))
158+
def _is_entire_expression(self, obj: str, op_regex) -> bool:
159+
return bool(regex.fullmatch(op_regex, obj))
144160

145-
def _contains_expression(self, obj: str) -> bool:
146-
return bool(regex.search(self.op_regex, obj))
161+
def _contains_expression(self, obj: str, op_regex) -> bool:
162+
return bool(regex.search(op_regex, obj))
147163

148-
def _get_expression(self, obj: str):
149-
for m in self.op_regex.finditer(obj):
164+
def _get_expression(self, obj: str, op_regex):
165+
for m in op_regex.finditer(obj):
150166
expr = m.group(0)
151167
op = m.group("op")
152168
arg = m.group("arg")
153169
break
154170
return expr, op, arg
155171

156-
def _interpolate_op(self, expr, op, arg):
172+
def _interpolate_op(self, expr, op, arg, conf):
157173
if op == "var":
158-
return self._interpolate_var(arg)
174+
return self._interpolate_var(arg, conf)
175+
if op == "gvar":
176+
return self._interpolate_var(arg, self.config)
159177
elif op == "cfg":
160178
return self._interpolate_cfg(arg)
161179
elif op == "env":
162180
return self._interpolate_env(arg)
163-
elif hasattr(operator, op):
181+
elif hasattr(operator, op) or hasattr(math, op) or op in OPERATOR_MAPPING:
164182
return self._interpolate_math(op, arg)
165183
else:
166184
return expr
167185

168-
def _interpolate_var(self, obj):
186+
def _interpolate_var(self, obj, conf):
169187
keys = obj.split(".")
170-
interpolated_variable = self.config
188+
interpolated_variable = conf
171189
for key in keys:
172190
if key not in interpolated_variable:
173191
raise RuntimeError(f"Interpolation failed as {obj} is not defined.")
@@ -193,11 +211,22 @@ def _interpolate_env(self, obj):
193211
def _interpolate_math(self, op, args):
194212
args = [arg.strip() for arg in args.split(",")]
195213
args = self._apply_recursively(self._maybe_convert_to_numeric, args)
196-
op = getattr(operator, op)
197-
result = str(reduce(op, args))
214+
if op == "sqrt" and len(args) == 2:
215+
result = str(math.pow(args[0], 1/args[1]))
216+
elif op in OPERATOR_MAPPING:
217+
op = OPERATOR_MAPPING[op]
218+
result = str(op(*args))
219+
elif hasattr(operator, op):
220+
op = getattr(operator, op)
221+
result = str(reduce(op, args))
222+
elif hasattr(operator, math):
223+
op = getattr(math, op)
224+
result = str(op(*args))
225+
else:
226+
raise RuntimeError(f"Operator ({op}) must be a function of 'operator', 'math' or 'OPERATOR_MAPPING'.")
198227
return result
199228

200-
def _update_parameters(self, arg_parameters: list):
229+
def _update_overrides(self, overrides: list):
201230
"""
202231
Update the configuration with command-line parameter overrides.
203232
@@ -208,8 +237,8 @@ def _update_parameters(self, arg_parameters: list):
208237
Returns:
209238
dict: The updated configuration with parameter overrides applied.
210239
"""
211-
for para in arg_parameters:
212-
key_path, value = para.split("=")
240+
for key_path, value in overrides.items():
241+
key_path = key_path[1:]
213242
keys = key_path.split(".")
214243
sub_config = self.config
215244
for key in keys[:-1]:

example.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from confly import Confly
22

3-
config = Confly(config="tests/configs/interpolation_test")
3+
args = ["model.arch=lol", "model.tmp=1"]
4+
config = Confly(config="tests/configs/interpolation_test", args=args)
45

56
print(config)

tests/configs/math_test.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
a: ${add:5,3,1}
2+
b: ${sqrt:5,2}
3+
c: ${pow:5,3,1}
4+
d: ${div:5,3}

0 commit comments

Comments
 (0)