@@ -332,25 +332,60 @@ func (l *ImportLauncher) ImportSpecToMLMDArtifact(ctx context.Context) (artifact
332332
333333func (l * ImportLauncher ) handleHuggingFaceImport (ctx context.Context , artifactURI string , artifact * pb.Artifact ) error {
334334 parts := strings .TrimPrefix (artifactURI , "huggingface://" )
335- pathParts := strings .Split (parts , "/" )
336335
337- if len (pathParts ) < 2 {
338- return fmt .Errorf ("invalid HuggingFace URI format: %q, expected huggingface://repo_id/revision" , artifactURI )
336+ if parts == artifactURI {
337+ return fmt .Errorf ("invalid artifact URI: %q\n " +
338+ "For HuggingFace Hub models and datasets, use the 'huggingface://' URI scheme.\n " +
339+ "Examples:\n " +
340+ " huggingface://gpt2\n " +
341+ " huggingface://meta-llama/Llama-2-7b\n " +
342+ " huggingface://wikitext?repo_type=dataset" ,
343+ artifactURI )
344+ }
345+
346+ var queryStr string
347+ if idx := strings .Index (parts , "?" ); idx != - 1 {
348+ queryStr = parts [idx + 1 :]
349+ parts = parts [:idx ]
339350 }
340351
352+ pathParts := strings .Split (parts , "/" )
353+ if len (pathParts ) < 1 {
354+ return fmt .Errorf ("invalid HuggingFace URI format: %q, expected huggingface://repo_id[/revision]" , artifactURI )
355+ }
356+
357+ repoID := strings .Join (pathParts , "/" )
341358 revision := "main"
342- if len (pathParts ) > 2 {
343- revision = pathParts [len (pathParts )- 1 ]
344- repoID := strings .Join (pathParts [:len (pathParts )- 1 ], "/" )
345- pathParts = []string {repoID , revision }
346- } else {
347- pathParts = []string {parts , revision }
359+
360+ if len (pathParts ) >= 2 && pathParts [len (pathParts )- 1 ] != "" {
361+ lastPart := pathParts [len (pathParts )- 1 ]
362+ if ! strings .Contains (lastPart , "." ) {
363+ revision = lastPart
364+ repoID = strings .Join (pathParts [:len (pathParts )- 1 ], "/" )
365+ }
348366 }
349367
350- repoID := pathParts [0 ]
351- revision = pathParts [1 ]
368+ repoType := "model"
369+ allowPatterns := ""
370+ ignorePatterns := ""
371+ if queryStr != "" {
372+ params := strings .Split (queryStr , "&" )
373+ for _ , param := range params {
374+ kv := strings .Split (param , "=" )
375+ if len (kv ) == 2 {
376+ switch kv [0 ] {
377+ case "repo_type" :
378+ repoType = kv [1 ]
379+ case "allow_patterns" :
380+ allowPatterns = kv [1 ]
381+ case "ignore_patterns" :
382+ ignorePatterns = kv [1 ]
383+ }
384+ }
385+ }
386+ }
352387
353- glog .Infof ("Downloading HuggingFace model repo_id=%q revision=%q" , repoID , revision )
388+ glog .Infof ("Downloading HuggingFace repo_id=%q revision=%q repo_type=%q allow_patterns=%q ignore_patterns=%q " , repoID , revision , repoType , allowPatterns , ignorePatterns )
354389
355390 storeSessionInfo := objectstore.SessionInfo {
356391 Provider : "huggingface" ,
@@ -367,17 +402,29 @@ func (l *ImportLauncher) handleHuggingFaceImport(ctx context.Context, artifactUR
367402
368403 artifact .CustomProperties ["hf_repo_id" ] = metadata .StringValue (repoID )
369404 artifact .CustomProperties ["hf_revision" ] = metadata .StringValue (revision )
405+ artifact .CustomProperties ["hf_repo_type" ] = metadata .StringValue (repoType )
406+ if allowPatterns != "" {
407+ artifact .CustomProperties ["hf_allow_patterns" ] = metadata .StringValue (allowPatterns )
408+ }
409+ if ignorePatterns != "" {
410+ artifact .CustomProperties ["hf_ignore_patterns" ] = metadata .StringValue (ignorePatterns )
411+ }
370412
371413 if l .importer .GetDownloadToWorkspace () {
372- if err := l .downloadHuggingFaceModel (ctx , repoID , revision , artifactURI ); err != nil {
414+ if err := l .downloadHuggingFaceModel (ctx , repoID , revision , repoType , allowPatterns , ignorePatterns ); err != nil {
373415 return fmt .Errorf ("failed to download HuggingFace model: %w" , err )
374416 }
375417 }
376418
377419 return nil
378420}
379421
380- func (l * ImportLauncher ) downloadHuggingFaceModel (ctx context.Context , repoID , revision , artifactURI string ) error {
422+ func (l * ImportLauncher ) downloadHuggingFaceModel (ctx context.Context , repoID , revision , repoType , allowPatterns , ignorePatterns string ) error {
423+ artifactURI := fmt .Sprintf ("huggingface://%s" , repoID )
424+ if revision != "main" {
425+ artifactURI = fmt .Sprintf ("huggingface://%s/%s" , repoID , revision )
426+ }
427+
381428 bucketConfig , err := l .resolveBucketConfigForURI (ctx , artifactURI )
382429 if err != nil {
383430 return err
@@ -393,31 +440,83 @@ func (l *ImportLauncher) downloadHuggingFaceModel(ctx context.Context, repoID, r
393440 hfToken , _ = objectstore .GetHuggingFaceTokenFromSessionInfo (ctx , l .k8sClient , l .launcherV2Options .Namespace , bucketConfig .SessionInfo )
394441 }
395442
396- glog .Infof ("Downloading HuggingFace model %q (revision=%q) to workspace path %q" , repoID , revision , localPath )
397-
398443 if err := os .MkdirAll (localPath , 0755 ); err != nil {
399444 return fmt .Errorf ("failed to create directory %q: %w" , localPath , err )
400445 }
401446
402- cmd := exec .CommandContext (ctx , "python3" , "-c" , fmt .Sprintf (`
403- from huggingface_hub import snapshot_download
447+ isSpecificFile := strings .Contains (repoID , "/" ) && strings .Contains (repoID [strings .LastIndex (repoID , "/" )+ 1 :], "." )
448+
449+ if isSpecificFile {
450+ lastSlash := strings .LastIndex (repoID , "/" )
451+ filename := repoID [lastSlash + 1 :]
452+ actualRepoID := repoID [:lastSlash ]
453+
454+ glog .Infof ("Downloading specific file from HuggingFace repo_id=%q revision=%q filename=%q repo_type=%q to %q" , actualRepoID , revision , filename , repoType , localPath )
455+
456+ pythonScript := fmt .Sprintf (`
457+ from huggingface_hub import hf_hub_download
458+ import inspect
404459import os
405460repo_id = %q
406461revision = %q
407- cache_dir = %q
462+ filename = %q
463+ repo_type = %q
464+ local_dir = %q
408465token = os.environ.get('HF_TOKEN', None)
409- snapshot_download(repo_id=repo_id, revision=revision, cache_dir=cache_dir, token=token, local_dir=%q)
410- ` , repoID , revision , localPath , localPath ))
411-
412- if hfToken != "" {
413- cmd .Env = append (os .Environ (), fmt .Sprintf ("HF_TOKEN=%s" , hfToken ))
414- }
466+ sig = inspect.signature(hf_hub_download)
467+ kwargs = {}
468+ for param in ['repo_id', 'filename', 'revision', 'repo_type', 'local_dir', 'token']:
469+ if param in sig.parameters:
470+ kwargs[param] = locals()[param]
471+ hf_hub_download(**kwargs)
472+ ` , actualRepoID , revision , filename , repoType , localPath )
473+
474+ cmd := exec .CommandContext (ctx , "python3" , "-c" , pythonScript )
475+ if hfToken != "" {
476+ cmd .Env = append (os .Environ (), fmt .Sprintf ("HF_TOKEN=%s" , hfToken ))
477+ }
478+ if output , err := cmd .CombinedOutput (); err != nil {
479+ return fmt .Errorf ("failed to download HuggingFace file: %w, output: %s" , err , string (output ))
480+ }
481+ glog .Infof ("Successfully downloaded HuggingFace file to %q" , localPath )
482+ } else {
483+ glog .Infof ("Downloading HuggingFace repo_id=%q revision=%q repo_type=%q to %q" , repoID , revision , repoType , localPath )
415484
416- if output , err := cmd .CombinedOutput (); err != nil {
417- return fmt .Errorf ("failed to download HuggingFace model: %w, output: %s" , err , string (output ))
485+ pythonScript := fmt .Sprintf (`
486+ from huggingface_hub import snapshot_download
487+ import inspect
488+ import os
489+ repo_id = %q
490+ revision = %q
491+ repo_type = %q
492+ local_dir = %q
493+ allow_patterns = %q
494+ ignore_patterns = %q
495+ token = os.environ.get('HF_TOKEN', None)
496+ # Build kwargs dynamically based on what parameters the function accepts
497+ # This ensures compatibility with current and future versions of huggingface-hub
498+ sig = inspect.signature(snapshot_download)
499+ kwargs = {}
500+ base_params = ['repo_id', 'revision', 'repo_type', 'local_dir', 'token']
501+ optional_params = {'allow_patterns': allow_patterns, 'ignore_patterns': ignore_patterns}
502+ for param in base_params:
503+ if param in sig.parameters:
504+ kwargs[param] = locals()[param]
505+ for param, value in optional_params.items():
506+ if value and param in sig.parameters:
507+ kwargs[param] = value
508+ snapshot_download(**kwargs)
509+ ` , repoID , revision , repoType , localPath , allowPatterns , ignorePatterns )
510+
511+ cmd := exec .CommandContext (ctx , "python3" , "-c" , pythonScript )
512+ if hfToken != "" {
513+ cmd .Env = append (os .Environ (), fmt .Sprintf ("HF_TOKEN=%s" , hfToken ))
514+ }
515+ if output , err := cmd .CombinedOutput (); err != nil {
516+ return fmt .Errorf ("failed to download HuggingFace model: %w, output: %s" , err , string (output ))
517+ }
518+ glog .Infof ("Successfully downloaded HuggingFace repo to %q" , localPath )
418519 }
419-
420- glog .Infof ("Successfully downloaded HuggingFace model to %q" , localPath )
421520 return nil
422521}
423522
0 commit comments