diff --git a/launch.py b/launch.py
index e16994f..76cea67 100644
--- a/launch.py
+++ b/launch.py
@@ -94,6 +94,11 @@ def writeRunFile(modelName: str, command: str):
 
 def printUsage():
     print('Usage: python download-model.py <model>')
+    print()
+    print('Options:')
+    print('  <model> The name of the model to download')
+    print('  --run   Run the model after download')
+    print()
     print('Available models:')
     for model in MODELS:
         print(f'  {model}')
@@ -109,6 +114,7 @@ def printUsage():
     if modelName not in MODELS:
         print(f'Model is not supported: {modelName}')
         exit(1)
+    runAfterDownload = sys.argv.count('--run') > 0
 
     model = MODELS[modelName]
     (modelPath, tokenizerPath) = download(modelName, model)
@@ -123,15 +129,16 @@
     print('To run Distributed Llama you need to execute:')
     print('--- copy start ---')
     print()
-    print(command)
+    print('\033[96m' + command + '\033[0m')
     print()
     print('--- copy end -----')
 
     runFilePath = writeRunFile(modelName, command)
     print(f'🌻 Created {runFilePath} script to easy run')
 
-    result = input('❓ Do you want to run Distributed Llama? ("Y" if yes): ')
-    if (result.upper() == 'Y'):
+    if (not runAfterDownload):
+        runAfterDownload = input('❓ Do you want to run Distributed Llama? ("Y" if yes): ').upper() == 'Y'
+    if (runAfterDownload):
         if (not os.path.isfile('dllama')):
             os.system('make dllama')
         os.system(command)