6
6
7
7
from __future__ import annotations
8
8
9
- import sys
10
9
from argparse import ArgumentParser , Namespace
11
10
from collections import defaultdict
12
- from typing import Any , cast , final
11
+ from typing import final
13
12
14
13
from rich .console import Console
15
14
from rich .pretty import pretty_repr
16
15
from typing_extensions import override
17
16
18
- from fairseq2 .assets import (
19
- AssetCard ,
20
- AssetNotFoundError ,
21
- AssetStore ,
22
- default_asset_store ,
23
- )
24
- from fairseq2 .datasets import is_dataset_card
17
+ from fairseq2 .assets import AssetCard , AssetCardNotFoundError , AssetStore
18
+ from fairseq2 .context import get_runtime_context
25
19
from fairseq2 .logging import get_log_writer
26
- from fairseq2 .models import is_model_card
27
20
from fairseq2 .recipes .cli import Cli , CliCommandHandler
28
21
from fairseq2 .recipes .console import get_console
29
- from fairseq2 .setup import setup_fairseq2
30
22
31
23
log = get_log_writer (__name__ )
32
24
@@ -38,31 +30,21 @@ def _setup_asset_cli(cli: Cli) -> None:
38
30
39
31
group .add_command (
40
32
"list" ,
41
- ListAssetsCommand (),
33
+ ListAssetsHandler (),
42
34
help = "list assets" ,
43
35
)
44
36
45
37
group .add_command (
46
38
"show" ,
47
- ShowAssetCommand (),
39
+ ShowAssetHandler (),
48
40
help = "show asset" ,
49
41
)
50
42
51
43
52
44
@final
53
- class ListAssetsCommand (CliCommandHandler ):
45
+ class ListAssetsHandler (CliCommandHandler ):
54
46
"""Lists assets available in the current Python environment."""
55
47
56
- _asset_store : AssetStore
57
-
58
- def __init__ (self , asset_store : AssetStore | None = None ) -> None :
59
- """
60
- :param asset_store:
61
- The asset store from which to retrieve the asset cards. If ``None``,
62
- the default asset store will be used.
63
- """
64
- self ._asset_store = asset_store or default_asset_store
65
-
66
48
@override
67
49
def init_parser (self , parser : ArgumentParser ) -> None :
68
50
parser .add_argument (
@@ -73,89 +55,102 @@ def init_parser(self, parser: ArgumentParser) -> None:
73
55
)
74
56
75
57
@override
76
- def run (self , args : Namespace ) -> int :
77
- setup_fairseq2 ()
78
-
79
- usr_assets = self ._retrieve_assets (args , user = True )
80
- glb_assets = self ._retrieve_assets (args , user = False )
58
+ def run (self , parser : ArgumentParser , args : Namespace ) -> int :
59
+ context = get_runtime_context ()
81
60
82
61
console = get_console ()
83
62
84
63
console .print ("[green bold]user:" )
85
64
86
- self ._dump_assets (console , usr_assets )
65
+ assets = self ._retrieve_assets (context .asset_store , args .type , user = True )
66
+
67
+ self ._dump_assets (console , assets )
87
68
88
69
console .print ("[green bold]global:" )
89
70
90
- self ._dump_assets (console , glb_assets )
71
+ assets = self ._retrieve_assets (context .asset_store , args .type , user = False )
72
+
73
+ self ._dump_assets (console , assets )
91
74
92
75
return 0
93
76
77
+ @classmethod
94
78
def _retrieve_assets (
95
- self , args : Namespace , user : bool
79
+ cls , asset_store : AssetStore , asset_type : str , user : bool
96
80
) -> list [tuple [str , list [str ]]]:
97
81
assets : dict [str , list [str ]] = defaultdict (list )
98
82
99
- names = self . _asset_store .retrieve_names (scope = "user" if user else "global" )
83
+ asset_names = asset_store .retrieve_names (scope = "user" if user else "global" )
100
84
101
- for name in names :
85
+ for asset_name in asset_names :
102
86
try :
103
- card = self . _asset_store .retrieve_card (
104
- name , scope = "all" if user else "global"
87
+ card = asset_store .retrieve_card (
88
+ asset_name , scope = "all" if user else "global"
105
89
)
106
- except AssetNotFoundError :
107
- log .warning ("The asset '{}' has an invalid card . Skipping." , name )
90
+ except AssetCardNotFoundError :
91
+ log .warning ("The '{}' asset card is not valid . Skipping." , asset_name )
108
92
109
93
continue
110
94
111
- if name [- 1 ] == "@" :
112
- name = name [:- 1 ]
95
+ if asset_name [- 1 ] == "@" :
96
+ asset_name = asset_name [:- 1 ]
113
97
114
- try :
115
- source = cast (str , card .metadata ["__source__" ])
116
- except KeyError :
98
+ source = card .metadata .get ("__source__" , "unknown source" )
99
+ if not isinstance (source , str ):
117
100
source = "unknown source"
118
101
119
- types = []
102
+ asset_types = []
120
103
121
- if args . type == "all" or args . type == "model" :
122
- if is_model_card (card ):
123
- types .append ("model" )
104
+ if asset_type == "all" or asset_type == "model" :
105
+ if cls . _is_model_card (card ):
106
+ asset_types .append ("model" )
124
107
125
- if args . type == "all" or args . type == "dataset" :
126
- if is_dataset_card (card ):
127
- types .append ("dataset" )
108
+ if asset_type == "all" or asset_type == "dataset" :
109
+ if cls . _is_dataset_card (card ):
110
+ asset_types .append ("dataset" )
128
111
129
- if args . type == "all" or args . type == "tokenizer" :
130
- if self ._is_tokenizer_card (card ):
131
- types .append ("tokenizer" )
112
+ if asset_type == "all" or asset_type == "tokenizer" :
113
+ if cls ._is_tokenizer_card (card ):
114
+ asset_types .append ("tokenizer" )
132
115
133
- if args . type == "all" and not types :
134
- types .append ("other" )
116
+ if asset_type == "all" and not asset_types :
117
+ asset_types .append ("other" )
135
118
136
- if not types :
119
+ if not asset_types :
137
120
continue
138
121
139
122
source_assets = assets [source ]
140
123
141
- for t in types :
142
- source_assets .append (f"{ t } :{ name } " )
124
+ for t in asset_types :
125
+ source_assets .append (f"{ t } :{ asset_name } " )
143
126
144
- return [(source , names ) for source , names in assets .items ()]
127
+ output = []
128
+
129
+ for source , asset_names in assets .items ():
130
+ asset_names .sort ()
131
+
132
+ output .append ((source , asset_names ))
133
+
134
+ output .sort (key = lambda e : e [0 ]) # sort by source
135
+
136
+ return output
137
+
138
+ @staticmethod
139
+ def _is_model_card (card : AssetCard ) -> bool :
140
+ return card .field ("model_family" ).exists ()
145
141
146
142
@staticmethod
147
143
def _is_tokenizer_card (card : AssetCard ) -> bool :
148
144
return card .field ("tokenizer_family" ).exists ()
149
145
150
- def _dump_assets (
151
- self , console : Console , assets : list [tuple [str , list [str ]]]
152
- ) -> None :
153
- if assets :
154
- assets .sort (key = lambda a : a [0 ]) # sort by source.
146
+ @staticmethod
147
+ def _is_dataset_card (card : AssetCard ) -> bool :
148
+ return card .field ("dataset_family" ).exists ()
155
149
150
+ @staticmethod
151
+ def _dump_assets (console : Console , assets : list [tuple [str , list [str ]]]) -> None :
152
+ if assets :
156
153
for source , names in assets :
157
- names .sort (key = lambda n : n [0 ]) # sort by name.
158
-
159
154
console .print (f" [blue bold]{ source } " )
160
155
161
156
for idx , name in enumerate (names ):
@@ -167,21 +162,9 @@ def _dump_assets(
167
162
console .print ()
168
163
169
164
170
- class ShowAssetCommand (CliCommandHandler ):
165
+ class ShowAssetHandler (CliCommandHandler ):
171
166
"""Shows the metadata of an asset."""
172
167
173
- _asset_store : AssetStore
174
-
175
- def __init__ (self , asset_store : AssetStore | None = None ) -> None :
176
- """
177
- :param asset_store:
178
- The asset store from which to retrieve the asset cards. If ``None``,
179
- the default asset store will be used.
180
- """
181
- setup_fairseq2 ()
182
-
183
- self ._asset_store = asset_store or default_asset_store
184
-
185
168
@override
186
169
def init_parser (self , parser : ArgumentParser ) -> None :
187
170
parser .add_argument (
@@ -202,15 +185,12 @@ def init_parser(self, parser: ArgumentParser) -> None:
202
185
parser .add_argument ("name" , help = "name of the asset" )
203
186
204
187
@override
205
- def run (self , args : Namespace ) -> int :
206
- try :
207
- card : AssetCard | None = self ._asset_store .retrieve_card (
208
- args .name , envs = args .envs , scope = args .scope
209
- )
210
- except AssetNotFoundError :
211
- log .error ("An asset with the name '{}' cannot be found." , args .name )
188
+ def run (self , parser : ArgumentParser , args : Namespace ) -> int :
189
+ context = get_runtime_context ()
212
190
213
- sys .exit (1 )
191
+ card : AssetCard | None = context .asset_store .retrieve_card (
192
+ args .name , envs = args .envs , scope = args .scope
193
+ )
214
194
215
195
while card is not None :
216
196
self ._print_metadata (dict (card .metadata ))
@@ -219,17 +199,14 @@ def run(self, args: Namespace) -> int:
219
199
220
200
return 0
221
201
222
- def _print_metadata (self , metadata : dict [str , Any ]) -> None :
202
+ def _print_metadata (self , metadata : dict [str , object ]) -> None :
223
203
console = get_console ()
224
204
225
205
name = metadata .pop ("name" )
226
206
227
207
console .print (f"[green bold]{ name } " )
228
208
229
- try :
230
- source = metadata .pop ("__source__" )
231
- except KeyError :
232
- source = "unknown"
209
+ source = metadata .pop ("__source__" , "unknown" )
233
210
234
211
items = list (metadata .items ())
235
212
0 commit comments