Skip to content

Commit 4534008

Browse files
committed
feat: add config_tools.py and refactor configs
Signed-off-by: Terry Kong <[email protected]> compare command Signed-off-by: Terry Kong <[email protected]> config changes Signed-off-by: Terry Kong <[email protected]> Revert "config changes" This reverts commit 25b87e2. Signed-off-by: Terry Kong <[email protected]> cleanup Signed-off-by: Terry Kong <[email protected]> vlm example Signed-off-by: Terry Kong <[email protected]> minimize configs Signed-off-by: Terry Kong <[email protected]> Revert "minimize configs" This reverts commit 1375480. Signed-off-by: Terry Kong <[email protected]>
1 parent 2b55598 commit 4534008

File tree

1 file changed

+304
-0
lines changed

1 file changed

+304
-0
lines changed

tools/config_cli.py

Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
#!/usr/bin/env -S uv run -q
2+
"""Utilities for working with YAML configs in this repo.
3+
4+
Subcommands:
5+
- expand: Merge a config with everything it inherits through `defaults`; `${...}` interpolations are preserved, not resolved.
6+
- minimize: Given a base config and a config, remove keys in the config that
7+
are equal to the base, and ensure a defaults entry pointing to the base
8+
exists. The defaults path in the resulting config is written relative to
9+
the directory of the minimized (child) config file.
10+
11+
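The expand and minimize commands support printing to stdout or in-place editing of the config file. A third subcommand, compare, prints the differences between two expanded configs.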
12+
13+
Example:
14+
# Expand a config with a root level "defaults" key to see the full config; print to stdout
15+
uv run tools/config_cli.py expand examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml
16+
17+
# Expand a config with a root level "defaults" key to see the full config; edit the config in place
18+
uv run tools/config_cli.py expand examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml --in-place
19+
20+
# Minimize a config and remove all keys that are present in the base config; print to stdout
21+
# uv run tools/config_cli.py minimize <base_config> <config>
22+
uv run tools/config_cli.py minimize examples/configs/dpo.yaml examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml
23+
24+
# Minimize a config and remove all keys that are present in the base config; edit the config in place
25+
# uv run tools/config_cli.py minimize <base_config> <config>
26+
uv run tools/config_cli.py minimize examples/configs/dpo.yaml examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml --in-place
27+
28+
# Minimize all the llm configs:
29+
for algo in grpo dpo sft; do
30+
base_config=examples/configs/${algo}.yaml
31+
if [[ ${algo} == grpo ]]; then
32+
base_config=examples/configs/grpo_math_1B.yaml
33+
fi
34+
for recipe in examples/configs/recipes/llm/${algo}-*.yaml; do
35+
uv run tools/config_cli.py minimize $base_config $recipe --in-place
36+
done
37+
done
38+
39+
# Minimize vlm configs:
40+
for recipe in examples/configs/recipes/vlm/vlm_grpo-*.yaml; do
41+
uv run tools/config_cli.py minimize examples/configs/vlm_grpo_3B.yaml $recipe --in-place
42+
done
43+
44+
# Compare two configs
45+
uv run tools/config_cli.py compare examples/configs/grpo_math_1B.yaml examples/configs/grpo_math_8B.yaml
46+
47+
# Minimize a config and compare it to not minimizing (should be the same)
48+
uv run tools/config_cli.py minimize examples/configs/dpo.yaml examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml >examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml.minimized
49+
uv run tools/config_cli.py compare \
50+
examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml \
51+
examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml.minimized
52+
"""
53+
54+
import argparse
55+
from pathlib import Path
56+
from typing import Any, Iterable
57+
58+
from omegaconf import DictConfig, OmegaConf
59+
60+
from nemo_rl.utils.config import load_config
61+
62+
63+
def _dict_like(obj: Any) -> bool:
    """Return True when `obj` is a plain `dict`."""
    return isinstance(obj, dict)


def _list_like(obj: Any) -> bool:
    """Return True when `obj` is a plain `list`."""
    return isinstance(obj, list)


# Sentinel returned by `_prune_equal` meaning "equal to the base, drop it".
# A dedicated object is used so that real config values such as None, {} or []
# can never be mistaken for the removal marker.
REMOVE = object()


def _prune_equal(a: Any, b: Any) -> Any:
    """Return a copy of `a` with the entries that are equal to `b` removed.

    - If both are dicts: prune recursively, key by key. A key is dropped when
      its value equals the base value, or when both values are dicts and
      nothing is left of the child's dict after pruning. Keys that are absent
      from `b` are kept unchanged.
    - Anything else (scalars, lists, mismatched types): values are compared as
      a whole. Lists are deliberately never pruned element by element, to
      avoid changing the meaning of ordered config sections.

    Args:
        a: The child value (plain Python containers / scalars).
        b: The base value to compare against.

    Returns:
        `REMOVE` if `a` is a non-dict value (or a dict compared against a
        non-dict) that equals `b`; a new, possibly empty, dict if both inputs
        are dicts; otherwise `a` itself.
    """
    if _dict_like(a) and _dict_like(b):
        out: dict[str, Any] = {}
        for key, a_val in a.items():
            if key not in b:
                # Only the child defines this key; nothing to compare against.
                out[key] = a_val
                continue

            b_val = b[key]
            pruned = _prune_equal(a_val, b_val)
            if pruned is REMOVE:
                # Equal to the base value, so inheritance already provides it.
                continue
            if _dict_like(a_val) and _dict_like(b_val) and not pruned:
                # Nested mapping whose every entry is already in the base.
                continue
            # Everything else is a real override and must be kept. This
            # includes an explicitly empty list/dict that replaces a
            # different base value; dropping it would silently re-inherit
            # the base value.
            out[key] = pruned
        return out

    # Scalars, lists and mismatched types: all-or-nothing comparison.
    if a == b:
        return REMOVE
    return a
111+
112+
113+
def _ensure_defaults_relative(
    child_path: Path, base_path: Path, child_cfg: dict[str, Any]
) -> None:
    """Make sure `child_cfg["defaults"]` references the base config.

    The reference is written as a path from the directory containing
    `child_path` to `base_path`. `child_cfg` is modified in place.

    - No `defaults` key (or a None value): it is set to the relative path.
    - An existing string or list: the relative path is inserted at the front
      unless one of the entries already points at the base file. Entries are
      compared both literally and after normalising them against the child's
      directory, so `./../base.yaml` and `../base.yaml` count as the same
      reference and no duplicate is added.
    - A resulting single-element list is collapsed to a plain string.

    Args:
        child_path: Location of the config that will carry the `defaults` key.
        base_path: Location of the base config to reference.
        child_cfg: The child's top-level mapping; mutated in place.
    """
    import os

    child_dir = str(child_path.parent)
    rel_from_child_to_base = os.path.relpath(str(base_path), start=child_dir)

    existing = child_cfg.get("defaults")
    if existing is None:
        child_cfg["defaults"] = rel_from_child_to_base
        return

    # Normalise the accepted shapes (string, list, lone non-iterable) to a list.
    if isinstance(existing, str) or not isinstance(existing, Iterable):
        entries: list[Any] = [existing]
    else:
        entries = list(existing)

    base_norm = os.path.normpath(str(base_path))

    def _points_at_base(entry: Any) -> bool:
        text = str(entry)
        if text == rel_from_child_to_base:
            return True
        # Interpret the entry relative to the child's directory, the same
        # convention used to build `rel_from_child_to_base` above.
        return os.path.normpath(os.path.join(child_dir, text)) == base_norm

    if not any(_points_at_base(entry) for entry in entries):
        # The base goes first so later entries keep their relative order.
        entries.insert(0, rel_from_child_to_base)

    # A single reference is stored as a bare string, per this repo's style.
    child_cfg["defaults"] = entries[0] if len(entries) == 1 else entries
149+
150+
151+
def expand(args: argparse.Namespace) -> int:
    """Print or write back the fully expanded form of a config.

    The config is loaded through the repo's `load_config`, which (per the
    original comments here) merges `defaults`/inheritance. The result is
    serialised with `OmegaConf.to_yaml` without resolving, so `${...}`
    interpolations are preserved in the output.

    Args:
        args: Parsed CLI arguments; reads `args.config` (path to the YAML
            file) and `args.in_place` (overwrite the file instead of printing).

    Returns:
        0 on success, suitable for use as a process exit status.
    """
    config_path = Path(args.config)
    cfg = load_config(str(config_path.resolve()))
    text = OmegaConf.to_yaml(cfg)

    if args.in_place:
        # Pin the encoding so the output does not depend on the platform's
        # locale; YAML files are conventionally UTF-8.
        config_path.write_text(text, encoding="utf-8")
    else:
        # Emit exactly one trailing newline.
        print(text if text.endswith("\n") else text + "\n", end="")
    return 0
160+
161+
162+
def minimize(args: argparse.Namespace) -> int:
    """Strip from a child config every key whose value equals the base config.

    Both files are loaded as-is with `OmegaConf.load` (no `defaults`
    expansion) and converted to plain containers without resolving
    interpolations, so `${...}` expressions are compared and emitted as
    written. The pruned mapping gets a `defaults` entry referencing the base
    (placed first in the output), and is printed or written back to the
    child file.

    Args:
        args: Parsed CLI arguments; reads `args.base`, `args.config` and
            `args.in_place`.

    Returns:
        0 on success, suitable for use as a process exit status.

    Raises:
        TypeError: If either file does not contain a top-level mapping.
    """
    child_path = Path(args.config).resolve()
    base_path = Path(args.base).resolve()

    child_cfg_raw = OmegaConf.load(child_path)
    if not isinstance(child_cfg_raw, DictConfig):
        raise TypeError(
            f"Config at {child_path} must be a mapping (DictConfig), got {type(child_cfg_raw)}"
        )
    base_cfg_raw = OmegaConf.load(base_path)
    if not isinstance(base_cfg_raw, DictConfig):
        raise TypeError(
            f"Config at {base_path} must be a mapping (DictConfig), got {type(base_cfg_raw)}"
        )

    # Plain dict/list views for comparison. `to_container` is called without
    # `resolve=True`, so interpolations stay as literal `${...}` strings.
    child_plain = OmegaConf.to_container(child_cfg_raw)
    base_plain = OmegaConf.to_container(base_cfg_raw)

    if not isinstance(child_plain, dict) or not isinstance(base_plain, dict):
        raise TypeError("Both child and base configs must be mappings after resolution")

    # With two dict inputs `_prune_equal` always returns a (possibly empty)
    # dict, so no further shape normalisation is needed.
    pruned = _prune_equal(child_plain, base_plain)

    # Ensure defaults reference base (relative path from child).
    _ensure_defaults_relative(child_path, base_path, pruned)

    # Rebuild the mapping so that `defaults` is the first top-level key.
    if "defaults" in pruned:
        pruned = {"defaults": pruned["defaults"], **pruned}

    text = OmegaConf.to_yaml(OmegaConf.create(pruned))
    if args.in_place:
        # `OmegaConf.load` reads the file as UTF-8, so write it back the same
        # way instead of relying on the locale's default encoding.
        Path(args.config).write_text(text, encoding="utf-8")
    else:
        # Emit exactly one trailing newline.
        print(text if text.endswith("\n") else text + "\n", end="")
    return 0
203+
204+
205+
def _flatten(d: Any, prefix: str = "") -> dict[str, Any]:
    """Flatten nested dicts/lists into a single-level `{path: leaf}` dict.

    Dict keys are joined with `.` and list positions are rendered as `[i]`,
    e.g. `{"a": {"b": [1]}}` becomes `{"a.b[0]": 1}`.

    An empty dict or list nested under a key is kept as a leaf of its own
    (`{"a": {}}` -> `{"a": {}}`). Without this, an empty container would
    vanish from the result and be indistinguishable from a missing key or
    from an empty container of the other kind.

    Args:
        d: The value to flatten (plain Python containers / scalars).
        prefix: Path accumulated so far; empty for the root call.

    Returns:
        A new dict mapping path strings to leaf values. An empty root
        container yields an empty dict.
    """
    out: dict[str, Any] = {}
    if isinstance(d, dict):
        if not d and prefix:
            out[prefix] = d
        for k, v in d.items():
            key = f"{prefix}.{k}" if prefix else str(k)
            out.update(_flatten(v, key))
    elif isinstance(d, list):
        if not d and prefix:
            out[prefix] = d
        for i, v in enumerate(d):
            out.update(_flatten(v, f"{prefix}[{i}]"))
    else:
        out[prefix] = d
    return out
218+
219+
220+
def compare(args: argparse.Namespace) -> int:
    """Print a report of the differences between two expanded configs.

    Both configs are loaded through the repo's `load_config`, converted to
    plain containers, flattened to `path -> value` pairs and compared key by
    key. The report lists keys only present on the right ("added"), only
    present on the left ("removed"), and present on both sides with different
    values ("changed").

    Args:
        args: Parsed CLI arguments; reads `args.left` and `args.right`.

    Returns:
        0 once the comparison has been printed, whether or not the configs
        differ. Differences are reported on stdout only.
    """
    left_path = Path(args.left).resolve()
    right_path = Path(args.right).resolve()

    # Expand via repo loader, then convert to plain dict/list so _flatten works
    left = OmegaConf.to_container(load_config(str(left_path)))  # type: ignore[assignment]
    right = OmegaConf.to_container(load_config(str(right_path)))  # type: ignore[assignment]

    lf = _flatten(left)
    rf = _flatten(right)

    # dict key views support set algebra directly; sort for stable output.
    added = sorted(rf.keys() - lf.keys())
    removed = sorted(lf.keys() - rf.keys())
    changed = sorted(k for k in lf.keys() & rf.keys() if lf[k] != rf[k])

    if not (added or removed or changed):
        print("Configs are identical after expansion")
        return 0

    # Print concise report with explicit left/right context
    print("Comparing configs after expansion:")
    print(f"  Left : {left_path}")
    print(f"  Right: {right_path}")

    if added:
        print("\nAdded in Right (missing in Left):")
        for k in added:
            print(f"  {k} = {rf[k]}")

    if removed:
        print("\nRemoved in Right (only in Left):")
        for k in removed:
            print(f"  {k} = {lf[k]}")

    if changed:
        print("\nChanged (Left -> Right):")
        for k in changed:
            print(f"  {k}: {lf[k]} -> {rf[k]}")

    return 0
266+
267+
268+
def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with the expand, minimize and compare subcommands.

    Each subparser stores its handler under `func`, so the caller can
    dispatch with `args.func(args)`.
    """
    parser = argparse.ArgumentParser(
        description="Config tools (expand, minimize, compare)"
    )
    sub = parser.add_subparsers(dest="cmd", required=True)

    p_expand = sub.add_parser("expand", help="Resolve a config with OmegaConf")
    p_expand.add_argument("config", help="Path to config YAML")
    p_expand.add_argument(
        "--in-place",
        action="store_true",
        dest="in_place",
        help="Edit file in place instead of printing",
    )
    p_expand.set_defaults(func=expand)

    p_min = sub.add_parser(
        "minimize",
        help="Remove keys equal to base and ensure defaults reference base",
    )
    p_min.add_argument("base", help="Base config path")
    p_min.add_argument("config", help="Child config path")
    p_min.add_argument(
        "--in-place",
        action="store_true",
        dest="in_place",
        help="Edit file in place instead of printing",
    )
    p_min.set_defaults(func=minimize)

    p_cmp = sub.add_parser(
        "compare", help="Compare two configs after expanding their defaults"
    )
    p_cmp.add_argument("left", help="Left config path")
    p_cmp.add_argument("right", help="Right config path")
    p_cmp.set_defaults(func=compare)

    return parser


def main(argv: "list[str] | None" = None) -> int:
    """Parse `argv` (defaults to `sys.argv[1:]`) and run the chosen subcommand.

    Returns:
        The subcommand's return value, or 0 if the handler returned None.
    """
    args = _build_parser().parse_args(argv)
    return args.func(args) or 0


if __name__ == "__main__":
    # Surface the handler's status code as the process exit status.
    raise SystemExit(main())

0 commit comments

Comments
 (0)