Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

populate SUPPORTED_COMMANDS cli #2157

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
import subprocess
import sys
import unittest
from pathlib import Path
import os
import glob
from trl.commands.cli_utils import populate_supported_commands


@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
Expand Down Expand Up @@ -43,3 +47,22 @@ def test_dpo_cli():
def test_env_cli():
output = subprocess.run("trl env", capture_output=True, text=True, shell=True, check=True)
assert "- Python version: " in output.stdout


def test_populate_supported_commands():
commands = populate_supported_commands()

# Check for specific commands
assert 'sft' in commands, "SFT command not found"
assert 'dpo' in commands, "DPO command not found"

# Check that all commands are strings and don't have .py extension
for cmd in commands:
assert isinstance(cmd, str), f"Command {cmd} is not a string"
assert not cmd.endswith('.py'), f"Command {cmd} should not have .py extension"

# Check that the number of commands matches the number of .py files in the scripts directory
trl_dir = Path(__file__).resolve().parent.parent
scripts_path = os.path.join(trl_dir, 'examples', 'scripts', '*.py')
py_files = glob.glob(scripts_path)
assert len(commands) == len(py_files), f"Number of commands ({len(commands)}) doesn't match number of .py files ({len(py_files)})"
5 changes: 2 additions & 3 deletions trl/commands/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,10 @@
is_liger_kernel_available,
is_llmblender_available,
)
from .cli_utils import get_git_commit_hash
from .cli_utils import get_git_commit_hash, populate_supported_commands


SUPPORTED_COMMANDS = ["sft", "dpo", "chat", "kto", "env"]

SUPPORTED_COMMANDS = populate_supported_commands()

def print_env():
accelerate_config = accelerate_config_str = "not found"
Expand Down
15 changes: 15 additions & 0 deletions trl/commands/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import inspect
import logging
import os
import glob
import subprocess
import sys
from pathlib import Path
from argparse import Namespace
from dataclasses import dataclass, field

Expand Down Expand Up @@ -305,3 +307,16 @@ def get_git_commit_hash(package_name):
return None
except Exception as e:
return f"Error: {str(e)}"


def populate_supported_commands():
# Path to the script examples directory
trl_dir = Path(__file__).resolve().parent.parent.parent
scripts_path = os.path.join(trl_dir, 'examples', 'scripts', '*.py')
# find all the scripts in the examples directory
trainer_files = glob.glob(scripts_path)

# Extract command names without the .py extension
commands = [os.path.basename(f).replace('.py', '') for f in trainer_files]

return commands