From ad203b2597a7154b160e2e8771da194d33d085cb Mon Sep 17 00:00:00 2001 From: grumpyp Date: Wed, 2 Oct 2024 12:53:56 +0200 Subject: [PATCH 1/2] populate SUPPORTED_COMMANDS cli --- tests/test_cli.py | 23 +++++++++++++++++++++++ trl/commands/cli.py | 5 ++--- trl/commands/cli_utils.py | 17 +++++++++++++++++ 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index b1b9b8aa67..be5b2ab321 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -14,6 +14,10 @@ import subprocess import sys import unittest +from pathlib import Path +import os +import glob +from trl.commands.cli_utils import populate_supported_commands @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") @@ -43,3 +47,22 @@ def test_dpo_cli(): def test_env_cli(): output = subprocess.run("trl env", capture_output=True, text=True, shell=True, check=True) assert "- Python version: " in output.stdout + + +def test_populate_supported_commands(): + commands = populate_supported_commands() + + # Check for specific commands + assert 'sft' in commands, "SFT command not found" + assert 'dpo' in commands, "DPO command not found" + + # Check that all commands are strings and don't have .py extension + for cmd in commands: + assert isinstance(cmd, str), f"Command {cmd} is not a string" + assert not cmd.endswith('.py'), f"Command {cmd} should not have .py extension" + + # Check that the number of commands matches the number of .py files in the scripts directory + trl_dir = Path(__file__).resolve().parent.parent + scripts_path = os.path.join(trl_dir, 'examples', 'scripts', '*.py') + py_files = glob.glob(scripts_path) + assert len(commands) == len(py_files), f"Number of commands ({len(commands)}) doesn't match number of .py files ({len(py_files)})" \ No newline at end of file diff --git a/trl/commands/cli.py b/trl/commands/cli.py index 3a9f8f83a3..4bf9b82853 100644 --- a/trl/commands/cli.py +++ b/trl/commands/cli.py @@ -31,11 +31,10 @@ is_liger_kernel_available, is_llmblender_available, ) -from .cli_utils import get_git_commit_hash +from .cli_utils import get_git_commit_hash, populate_supported_commands -SUPPORTED_COMMANDS = ["sft", "dpo", "chat", "kto", "env"] - +SUPPORTED_COMMANDS = populate_supported_commands() def print_env(): accelerate_config = accelerate_config_str = "not found" diff --git a/trl/commands/cli_utils.py b/trl/commands/cli_utils.py index 44918af961..ac3dd038bf 100644 --- a/trl/commands/cli_utils.py +++ b/trl/commands/cli_utils.py @@ -17,8 +17,11 @@ import inspect import logging import os +import glob +import functools import subprocess import sys +from pathlib import Path from argparse import Namespace from dataclasses import dataclass, field @@ -305,3 +308,17 @@ def get_git_commit_hash(package_name): return None except Exception as e: return f"Error: {str(e)}" + + +@functools.cache +def populate_supported_commands(): + # Path to the script examples directory + trl_dir = Path(__file__).resolve().parent.parent.parent + scripts_path = os.path.join(trl_dir, 'examples', 'scripts', '*.py') + # find all the scripts in the examples directory + trainer_files = glob.glob(scripts_path) + + # Extract command names without the .py extension + commands = [os.path.basename(f).replace('.py', '') for f in trainer_files] + + return commands \ No newline at end of file From c21134ac02731aad93800413dc2d3e742dc38a80 Mon Sep 17 00:00:00 2001 From: grumpyp Date: Wed, 2 Oct 2024 13:42:58 +0200 Subject: [PATCH 2/2] delete caching --- trl/commands/cli_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/trl/commands/cli_utils.py b/trl/commands/cli_utils.py index ac3dd038bf..7751e75b3c 100644 --- a/trl/commands/cli_utils.py +++ b/trl/commands/cli_utils.py @@ -18,7 +18,6 @@ import logging import os import glob -import functools import subprocess import sys from pathlib import Path @@ -310,7 +309,6 @@ def get_git_commit_hash(package_name): return f"Error: {str(e)}" -@functools.cache def populate_supported_commands(): # Path to the script examples directory trl_dir = Path(__file__).resolve().parent.parent.parent