Skip to content

Commit fe27e5d

Browse files
committed
Improve CLI
1 parent 7df1097 commit fe27e5d

File tree

9 files changed

+270
-325
lines changed

9 files changed

+270
-325
lines changed

src/fairseq2/recipes/__init__.py

+4-31
Original file line number | Diff line number | Diff line change
@@ -61,10 +61,12 @@ def main() -> None:
6161

6262

6363
def _run() -> int:
64-
from fairseq2 import __version__
64+
from fairseq2 import __version__, setup_fairseq2
6565

6666
setup_basic_logging()
6767

68+
setup_fairseq2()
69+
6870
cli = Cli(
6971
name="fairseq2",
7072
origin_module="fairseq2",
@@ -86,33 +88,4 @@ def _setup_cli(cli: Cli) -> None:
8688
_setup_wav2vec2_asr_cli(cli)
8789
_setup_hg_cli(cli)
8890

89-
run_extensions("setup_fairseq2_cli", cli)
90-
91-
_setup_legacy_extensions(cli)
92-
93-
94-
def _setup_legacy_extensions(cli: Cli) -> None:
95-
should_trace = "FAIRSEQ2_EXTENSION_TRACE" in os.environ
96-
97-
for entry_point in entry_points(group="fairseq2.cli"):
98-
try:
99-
extension = entry_point.load()
100-
101-
extension(cli)
102-
except TypeError:
103-
if should_trace:
104-
raise ExtensionError(
105-
entry_point.value, f"The '{entry_point.value}' entry point is not a valid CLI extension function." # fmt: skip
106-
) from None
107-
108-
log.warning("The '{}' entry point is not a valid CLI extension function. Set `FAIRSEQ2_EXTENSION_TRACE` environment variable to print the stack trace.", entry_point.value) # fmt: skip
109-
except Exception as ex:
110-
if should_trace:
111-
raise ExtensionError(
112-
entry_point.value, f"The '{entry_point.value}' CLI extension function has failed. See the nested exception for details." # fmt: skip
113-
) from ex
114-
115-
log.warning("The '{}' CLI extension function has failed. Set `FAIRSEQ2_EXTENSION_TRACE` environment variable to print the stack trace.", entry_point.value) # fmt: skip
116-
117-
if should_trace:
118-
log.info("The `{}` CLI extension function run successfully.", entry_point.value) # fmt: skip
91+
run_extensions("fairseq2.cli", cli)

src/fairseq2/recipes/assets.py

+69-92
Original file line number | Diff line number | Diff line change
@@ -6,27 +6,19 @@
66

77
from __future__ import annotations
88

9-
import sys
109
from argparse import ArgumentParser, Namespace
1110
from collections import defaultdict
12-
from typing import Any, cast, final
11+
from typing import final
1312

1413
from rich.console import Console
1514
from rich.pretty import pretty_repr
1615
from typing_extensions import override
1716

18-
from fairseq2.assets import (
19-
AssetCard,
20-
AssetNotFoundError,
21-
AssetStore,
22-
default_asset_store,
23-
)
24-
from fairseq2.datasets import is_dataset_card
17+
from fairseq2.assets import AssetCard, AssetCardNotFoundError, AssetStore
18+
from fairseq2.context import get_runtime_context
2519
from fairseq2.logging import get_log_writer
26-
from fairseq2.models import is_model_card
2720
from fairseq2.recipes.cli import Cli, CliCommandHandler
2821
from fairseq2.recipes.console import get_console
29-
from fairseq2.setup import setup_fairseq2
3022

3123
log = get_log_writer(__name__)
3224

@@ -38,31 +30,21 @@ def _setup_asset_cli(cli: Cli) -> None:
3830

3931
group.add_command(
4032
"list",
41-
ListAssetsCommand(),
33+
ListAssetsHandler(),
4234
help="list assets",
4335
)
4436

4537
group.add_command(
4638
"show",
47-
ShowAssetCommand(),
39+
ShowAssetHandler(),
4840
help="show asset",
4941
)
5042

5143

5244
@final
53-
class ListAssetsCommand(CliCommandHandler):
45+
class ListAssetsHandler(CliCommandHandler):
5446
"""Lists assets available in the current Python environment."""
5547

56-
_asset_store: AssetStore
57-
58-
def __init__(self, asset_store: AssetStore | None = None) -> None:
59-
"""
60-
:param asset_store:
61-
The asset store from which to retrieve the asset cards. If ``None``,
62-
the default asset store will be used.
63-
"""
64-
self._asset_store = asset_store or default_asset_store
65-
6648
@override
6749
def init_parser(self, parser: ArgumentParser) -> None:
6850
parser.add_argument(
@@ -73,89 +55,102 @@ def init_parser(self, parser: ArgumentParser) -> None:
7355
)
7456

7557
@override
76-
def run(self, args: Namespace) -> int:
77-
setup_fairseq2()
78-
79-
usr_assets = self._retrieve_assets(args, user=True)
80-
glb_assets = self._retrieve_assets(args, user=False)
58+
def run(self, parser: ArgumentParser, args: Namespace) -> int:
59+
context = get_runtime_context()
8160

8261
console = get_console()
8362

8463
console.print("[green bold]user:")
8564

86-
self._dump_assets(console, usr_assets)
65+
assets = self._retrieve_assets(context.asset_store, args.type, user=True)
66+
67+
self._dump_assets(console, assets)
8768

8869
console.print("[green bold]global:")
8970

90-
self._dump_assets(console, glb_assets)
71+
assets = self._retrieve_assets(context.asset_store, args.type, user=False)
72+
73+
self._dump_assets(console, assets)
9174

9275
return 0
9376

77+
@classmethod
9478
def _retrieve_assets(
95-
self, args: Namespace, user: bool
79+
cls, asset_store: AssetStore, asset_type: str, user: bool
9680
) -> list[tuple[str, list[str]]]:
9781
assets: dict[str, list[str]] = defaultdict(list)
9882

99-
names = self._asset_store.retrieve_names(scope="user" if user else "global")
83+
asset_names = asset_store.retrieve_names(scope="user" if user else "global")
10084

101-
for name in names:
85+
for asset_name in asset_names:
10286
try:
103-
card = self._asset_store.retrieve_card(
104-
name, scope="all" if user else "global"
87+
card = asset_store.retrieve_card(
88+
asset_name, scope="all" if user else "global"
10589
)
106-
except AssetNotFoundError:
107-
log.warning("The asset '{}' has an invalid card. Skipping.", name)
90+
except AssetCardNotFoundError:
91+
log.warning("The '{}' asset card is not valid. Skipping.", asset_name)
10892

10993
continue
11094

111-
if name[-1] == "@":
112-
name = name[:-1]
95+
if asset_name[-1] == "@":
96+
asset_name = asset_name[:-1]
11397

114-
try:
115-
source = cast(str, card.metadata["__source__"])
116-
except KeyError:
98+
source = card.metadata.get("__source__", "unknown source")
99+
if not isinstance(source, str):
117100
source = "unknown source"
118101

119-
types = []
102+
asset_types = []
120103

121-
if args.type == "all" or args.type == "model":
122-
if is_model_card(card):
123-
types.append("model")
104+
if asset_type == "all" or asset_type == "model":
105+
if cls._is_model_card(card):
106+
asset_types.append("model")
124107

125-
if args.type == "all" or args.type == "dataset":
126-
if is_dataset_card(card):
127-
types.append("dataset")
108+
if asset_type == "all" or asset_type == "dataset":
109+
if cls._is_dataset_card(card):
110+
asset_types.append("dataset")
128111

129-
if args.type == "all" or args.type == "tokenizer":
130-
if self._is_tokenizer_card(card):
131-
types.append("tokenizer")
112+
if asset_type == "all" or asset_type == "tokenizer":
113+
if cls._is_tokenizer_card(card):
114+
asset_types.append("tokenizer")
132115

133-
if args.type == "all" and not types:
134-
types.append("other")
116+
if asset_type == "all" and not asset_types:
117+
asset_types.append("other")
135118

136-
if not types:
119+
if not asset_types:
137120
continue
138121

139122
source_assets = assets[source]
140123

141-
for t in types:
142-
source_assets.append(f"{t}:{name}")
124+
for t in asset_types:
125+
source_assets.append(f"{t}:{asset_name}")
143126

144-
return [(source, names) for source, names in assets.items()]
127+
output = []
128+
129+
for source, asset_names in assets.items():
130+
asset_names.sort()
131+
132+
output.append((source, asset_names))
133+
134+
output.sort(key=lambda e: e[0]) # sort by source
135+
136+
return output
137+
138+
@staticmethod
139+
def _is_model_card(card: AssetCard) -> bool:
140+
return card.field("model_family").exists()
145141

146142
@staticmethod
147143
def _is_tokenizer_card(card: AssetCard) -> bool:
148144
return card.field("tokenizer_family").exists()
149145

150-
def _dump_assets(
151-
self, console: Console, assets: list[tuple[str, list[str]]]
152-
) -> None:
153-
if assets:
154-
assets.sort(key=lambda a: a[0]) # sort by source.
146+
@staticmethod
147+
def _is_dataset_card(card: AssetCard) -> bool:
148+
return card.field("dataset_family").exists()
155149

150+
@staticmethod
151+
def _dump_assets(console: Console, assets: list[tuple[str, list[str]]]) -> None:
152+
if assets:
156153
for source, names in assets:
157-
names.sort(key=lambda n: n[0]) # sort by name.
158-
159154
console.print(f" [blue bold]{source}")
160155

161156
for idx, name in enumerate(names):
@@ -167,21 +162,9 @@ def _dump_assets(
167162
console.print()
168163

169164

170-
class ShowAssetCommand(CliCommandHandler):
165+
class ShowAssetHandler(CliCommandHandler):
171166
"""Shows the metadata of an asset."""
172167

173-
_asset_store: AssetStore
174-
175-
def __init__(self, asset_store: AssetStore | None = None) -> None:
176-
"""
177-
:param asset_store:
178-
The asset store from which to retrieve the asset cards. If ``None``,
179-
the default asset store will be used.
180-
"""
181-
setup_fairseq2()
182-
183-
self._asset_store = asset_store or default_asset_store
184-
185168
@override
186169
def init_parser(self, parser: ArgumentParser) -> None:
187170
parser.add_argument(
@@ -202,15 +185,12 @@ def init_parser(self, parser: ArgumentParser) -> None:
202185
parser.add_argument("name", help="name of the asset")
203186

204187
@override
205-
def run(self, args: Namespace) -> int:
206-
try:
207-
card: AssetCard | None = self._asset_store.retrieve_card(
208-
args.name, envs=args.envs, scope=args.scope
209-
)
210-
except AssetNotFoundError:
211-
log.error("An asset with the name '{}' cannot be found.", args.name)
188+
def run(self, parser: ArgumentParser, args: Namespace) -> int:
189+
context = get_runtime_context()
212190

213-
sys.exit(1)
191+
card: AssetCard | None = context.asset_store.retrieve_card(
192+
args.name, envs=args.envs, scope=args.scope
193+
)
214194

215195
while card is not None:
216196
self._print_metadata(dict(card.metadata))
@@ -219,17 +199,14 @@ def run(self, args: Namespace) -> int:
219199

220200
return 0
221201

222-
def _print_metadata(self, metadata: dict[str, Any]) -> None:
202+
def _print_metadata(self, metadata: dict[str, object]) -> None:
223203
console = get_console()
224204

225205
name = metadata.pop("name")
226206

227207
console.print(f"[green bold]{name}")
228208

229-
try:
230-
source = metadata.pop("__source__")
231-
except KeyError:
232-
source = "unknown"
209+
source = metadata.pop("__source__", "unknown")
233210

234211
items = list(metadata.items())
235212

0 commit comments

Comments
 (0)