diff --git a/docker/Dockerfile b/docker/Dockerfile index 3eb864ecd6..2b04b1bd9b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -114,7 +114,8 @@ fi # The venv is symlinked to avoid bloating the layer size uv sync --link-mode symlink --locked --no-install-project uv sync --link-mode symlink --locked --extra vllm --no-install-project -uv sync --link-mode symlink --locked --extra sglang --no-install-project +# Skipped: sgl-kernel build broken (DeepGEMM ref 54f99a8 missing upstream) +# uv sync --link-mode symlink --locked --extra sglang --no-install-project uv sync --link-mode symlink --locked --extra mcore --no-install-project uv sync --link-mode symlink --locked --extra automodel --no-install-project uv sync --link-mode symlink --locked --all-groups --no-install-project diff --git a/docs/repro-spec-decode-build.md b/docs/repro-spec-decode-build.md new file mode 100644 index 0000000000..a043bdf7f4 --- /dev/null +++ b/docs/repro-spec-decode-build.md @@ -0,0 +1,264 @@ +# Reproduction Report: NeMo-RL v0.5.0 + vLLM Speculative Decoding with Draft Models + +**Date:** 2026-02-10 +**Author:** Shaunak J +**Host:** viking-prod-228 (8x NVIDIA H20, computelab-sc-01 cluster) + +## Overview + +This report documents how to build a custom NeMo-RL Docker image that includes vLLM speculative decoding with draft models (vLLM PR #24322), and how to run GRPO training with it. + +## Source Components + +| Component | Source | Commit / Version | How Obtained | +|-----------|--------|-------------------|--------------| +| NeMo-RL | `https://github.com/NVIDIA-NeMo/RL.git` | `946862e7` (HEAD of `main` on 2026-02-10) | Already cloned locally | +| vLLM (custom) | `https://github.com/vllm-project/vllm.git` | `4a5299c93ff97c26def537b92562df5ada530fea` | Merge commit of PR #24322 | +| Precompiled vLLM wheel | `https://wheels.vllm.ai` | Same commit as above | Auto-downloaded during build | +| Base Docker image | `nvcr.io/nvidia/cuda-dl-base:25.05-cuda12.9-devel-ubuntu24.04` | Pulled by Dockerfile | From NGC registry | + +## Commit Hash Provenance + +### `946862e7` — NeMo-RL + +The latest commit on the NeMo-RL `main` branch at the time of this build: + +``` +946862e7 docs: add release runs to front page readme for 0.5 (#1879) +``` + +This is the same codebase as the `nvcr.io/nvidia/nemo-rl:v0.5.0` release. + +### `4a5299c93ff97c26def537b92562df5ada530fea` — vLLM + +The merge commit of [vllm-project/vllm#24322](https://github.com/vllm-project/vllm/pull/24322) ("feat: spec decode with draft models"), merged on 2026-01-19. Retrieved via: + +```bash +gh pr view 24322 --repo vllm-project/vllm --json mergeCommit +# => 4a5299c93ff97c26def537b92562df5ada530fea +``` + +The precompiled wheel for this commit exists at: +``` +https://wheels.vllm.ai/4a5299c93ff97c26def537b92562df5ada530fea/vllm-0.14.0rc2.dev156%2Bg4a5299c93-cp38-abi3-manylinux_2_31_x86_64.whl +``` + +This was verified by HTTP HEAD request returning 200. + +### Replaced defaults in `tools/build-custom-vllm.sh` + +The build script had two hardcoded defaults that were written by the NeMo-RL team for their v0.10-era vLLM pin: + +- `GIT_REF` default: `cc99baf14dacc2497d0c5ed84e076ef2c37f6a4d` — the previous vLLM commit pin +- `VLLM_WHEEL_COMMIT` default: `862f2ef893d9751db0a92bd2d4ae0e3d9677872f` — chosen as `git merge-base --fork-point origin/main tags/v0.10.0` + +Both were replaced with the PR #24322 merge commit. 
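+Before switching the defaults, it can help to confirm that a precompiled wheel actually exists for the candidate commit (this is the HTTP HEAD verification mentioned above). A minimal sketch of that check, assuming `curl` is available and that the wheel filename matches what wheels.vllm.ai publishes for the commit:
+
+```bash
+# Pre-flight check: expect "200" if the precompiled wheel exists for this commit
+WHEEL_COMMIT=4a5299c93ff97c26def537b92562df5ada530fea
+WHEEL_URL="https://wheels.vllm.ai/${WHEEL_COMMIT}/vllm-0.14.0rc2.dev156%2Bg4a5299c93-cp38-abi3-manylinux_2_31_x86_64.whl"
+curl -sI -o /dev/null -w '%{http_code}\n' "$WHEEL_URL"
+```
+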
The wheel URL template was also updated because vLLM changed their naming convention from `vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl` to `vllm--cp38-abi3-manylinux_2_31_x86_64.whl`. + +## Local Modifications + +### 1. `tools/build-custom-vllm.sh` + +**Line 24** — Changed default `GIT_REF`: +```diff +-GIT_REF=${2:-cc99baf14dacc2497d0c5ed84e076ef2c37f6a4d} ++GIT_REF=${2:-4a5299c93ff97c26def537b92562df5ada530fea} +``` + +**Lines 27-28** — Changed default wheel commit and URL pattern, made env var overridable: +```diff +-VLLM_WHEEL_COMMIT=${3:-862f2ef893d9751db0a92bd2d4ae0e3d9677872f} # use full commit hash from the main branch +-export VLLM_PRECOMPILED_WHEEL_LOCATION="https://wheels.vllm.ai/${VLLM_WHEEL_COMMIT}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" ++VLLM_WHEEL_COMMIT=${3:-4a5299c93ff97c26def537b92562df5ada530fea} # merge commit of vllm PR #24322 (spec decode with draft models) ++export VLLM_PRECOMPILED_WHEEL_LOCATION="${VLLM_PRECOMPILED_WHEEL_LOCATION:-https://wheels.vllm.ai/${VLLM_WHEEL_COMMIT}/vllm-0.14.0rc2.dev156%2Bg4a5299c93-cp38-abi3-manylinux_2_31_x86_64.whl}" +``` + +### 2. `docker/Dockerfile` + +**Line 117** — Commented out sglang extra (broken upstream dependency): +```diff +-uv sync --link-mode symlink --locked --extra sglang --no-install-project ++# Skipped: sgl-kernel build broken (DeepGEMM ref 54f99a8 missing upstream) ++# uv sync --link-mode symlink --locked --extra sglang --no-install-project +``` + +**Reason:** The `sgl-kernel` build fails because it tries to fetch DeepGEMM at commit `54f99a8af537b3c6eb4819b69907ccbe2b600792` which no longer exists in the upstream repo. This only affects the sglang inference backend, not vLLM. + +## Step-by-Step Reproduction + +### Prerequisites + +- Docker with buildx support +- GPU node with NVIDIA drivers +- ~100 GB disk space for the build +- Internet access (to pull base image, clone vLLM, download wheel, download HF models) + +### 1. Clone and prepare the source + +```bash +git clone https://github.com/NVIDIA-NeMo/RL.git +cd RL +git checkout 946862e7 + +# Initialize submodules (required for Megatron-LM, Automodel, etc.) +git submodule update --init --depth 1 +``` + +### 2. Apply patches to build-custom-vllm.sh + +```bash +# Update GIT_REF default to PR #24322 merge commit +sed -i 's|GIT_REF=${2:-cc99baf14dacc2497d0c5ed84e076ef2c37f6a4d}|GIT_REF=${2:-4a5299c93ff97c26def537b92562df5ada530fea}|' \ + tools/build-custom-vllm.sh + +# Update VLLM_WHEEL_COMMIT default +sed -i 's|VLLM_WHEEL_COMMIT=${3:-862f2ef893d9751db0a92bd2d4ae0e3d9677872f}.*|VLLM_WHEEL_COMMIT=${3:-4a5299c93ff97c26def537b92562df5ada530fea} # merge commit of vllm PR #24322|' \ + tools/build-custom-vllm.sh + +# Update wheel URL pattern and make it overridable +sed -i 's|^export VLLM_PRECOMPILED_WHEEL_LOCATION="https://wheels.vllm.ai/${VLLM_WHEEL_COMMIT}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"|export VLLM_PRECOMPILED_WHEEL_LOCATION="${VLLM_PRECOMPILED_WHEEL_LOCATION:-https://wheels.vllm.ai/${VLLM_WHEEL_COMMIT}/vllm-0.14.0rc2.dev156%2Bg4a5299c93-cp38-abi3-manylinux_2_31_x86_64.whl}"|' \ + tools/build-custom-vllm.sh +``` + +### 3. Apply patch to Dockerfile + +```bash +sed -i 's|^uv sync --link-mode symlink --locked --extra sglang --no-install-project$|# Skipped: sgl-kernel build broken (DeepGEMM ref missing upstream)\n# uv sync --link-mode symlink --locked --extra sglang --no-install-project|' \ + docker/Dockerfile +``` + +### 4. 
Build the Docker image + +```bash +docker buildx build \ + --build-arg BUILD_CUSTOM_VLLM=1 \ + --target release \ + --build-context nemo-rl=. \ + -f docker/Dockerfile \ + --tag nemo-rl:v0.5.0-spec-decode \ + . +``` + +This takes approximately 20-30 minutes. The build will: +1. Pull the CUDA base image +2. Install system dependencies, uv, Python 3.12 +3. Clone vLLM at commit `4a5299c` and build it with the precompiled wheel +4. Sync all dependency extras (vllm, mcore, automodel) except sglang +5. Prefetch Ray worker virtual environments +6. Generate container fingerprint + +### 5. Save the image (optional) + +```bash +docker save nemo-rl:v0.5.0-spec-decode | gzip > nemo-rl-v0.5.0-spec-decode.tar.gz +``` + +The resulting image is approximately 21 GB compressed. + +To load on another machine: +```bash +docker load < nemo-rl-v0.5.0-spec-decode.tar.gz +``` + +### 6. Run the container + +```bash +docker run --gpus all --ipc=host \ + --ulimit memlock=-1 --ulimit stack=67108864 \ + -v /path/to/your/scratch:/path/to/your/scratch \ + -it nemo-rl:v0.5.0-spec-decode bash +``` + +### 7. Create the training config + +Save as `grpo-qwen3-4b-spec-decode-1n8g.yaml`: + +```yaml +defaults: ../../grpo_math_1B.yaml +grpo: + max_num_steps: 200 + num_prompts_per_step: 32 + num_generations_per_prompt: 16 +checkpointing: + checkpoint_dir: results/grpo-qwen3-4b-spec-decode-1n8g +policy: + model_name: Qwen/Qwen3-4B + tokenizer: + name: Qwen/Qwen3-4B + train_global_batch_size: 512 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 2048 + dynamic_batching: + enabled: true + sequence_packing: + enabled: false + make_sequence_length_divisible_by: 1 + generation: + max_new_tokens: 1024 + vllm_cfg: + tensor_parallel_size: 1 + gpu_memory_utilization: 0.85 + max_model_len: 2048 + vllm_kwargs: + speculative_config: + method: draft_model + model: Qwen/Qwen3-0.6B + num_speculative_tokens: 5 + draft_tensor_parallel_size: 1 +data: + max_input_seq_length: 1024 +logger: + log_dir: logs/grpo-qwen3-4b-spec-decode-1n8g + wandb_enabled: false + tensorboard_enabled: true +cluster: + gpus_per_node: 8 +``` + +**Model pairing note:** The target and draft models must have the same vocabulary size. Qwen3 models all use vocab_size=151936 across all sizes, so any Qwen3 pair works. Qwen2.5 models have inconsistent vocab sizes (7B+ use 152064, smaller models use 151936), which prevents cross-size draft pairing. + +### 8. Run GRPO training + +```bash +python examples/run_grpo.py \ + --config /path/to/grpo-qwen3-4b-spec-decode-1n8g.yaml \ + ++logger.log_dir=/path/to/your/scratch/logs/grpo-qwen3-4b-spec-decode \ + ++checkpointing.checkpoint_dir=/path/to/your/scratch/results/grpo-qwen3-4b-spec-decode +``` + +**Important:** Use `python` (not `uv run`) inside the container, as custom vLLM containers use frozen environments. See `docs/guides/use-custom-vllm.md` for details. + +## Verified Results + +Training was verified running on viking-prod-228 (8x H20 GPUs). 
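+Since the config enables TensorBoard logging, the per-step metrics summarized below can also be inspected from the log directory. A minimal sketch, assuming TensorBoard is available inside the container and the log path matches the `++logger.log_dir` override used above:
+
+```bash
+# Point TensorBoard at this run's log directory (example path from this host)
+tensorboard --logdir /tmp/logs/grpo-qwen3-4b-spec-decode --port 6006
+```
+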
Key metrics from the first 3 steps: + +| Metric | Step 1 | Step 2 | Step 3 | +|--------|--------|--------|--------| +| Loss | -0.035 | -0.014 | -0.016 | +| Avg Reward | 0.340 | 0.272 | 0.229 | +| Mean Gen Length | 1008 | 1015 | 1001 | +| E2E Tokens/sec (total) | 2488 | 2564 | 2555 | +| Step Time | 228s | 224s | 224s | +| Generation % of step | 51.5% | 46.2% | 46.2% | + +vLLM confirmed speculative decoding active: +``` +Initializing a V1 LLM engine (v0.14.0rc2.dev156+g4a5299c93.d20260210) + speculative_config=SpeculativeConfig(method='draft_model', model='Qwen/Qwen3-0.6B', num_spec_tokens=5) +Loading drafter model... +Starting to load draft model Qwen/Qwen3-0.6B. TP=1, rank=0 +Model loading took 8.67 GiB memory +``` + +## Known Issues + +1. **sglang extra is disabled** — `sgl-kernel` cannot build due to a missing DeepGEMM commit upstream. This does not affect vLLM-based inference or training. +2. **Async scheduling disabled with draft model spec decode** — vLLM logs: "Async scheduling not supported with draft_model-based speculative decoding and will be disabled." This is expected behavior from the vLLM implementation. +3. **`min_p`, `logit_bias`, `min_tokens` unsupported with spec decode** — vLLM warning, not an issue for standard GRPO training. + +## File Locations (on viking-prod-228) + +- Source: `/home/scratch.shaunakj_other/Development/RL/` +- Saved image: `/home/scratch.shaunakj_other/Development/nemo-rl-v0.5.0-spec-decode.tar.gz` +- Training logs: `/tmp/logs/grpo-qwen3-4b-spec-decode/exp_001/` +- Training config: `/home/scratch.shaunakj_other/Development/RL/examples/configs/recipes/llm/grpo-qwen2.5-7b-spec-decode-1n8g.yaml` diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-spec-decode-1n8g.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-spec-decode-1n8g.yaml new file mode 100644 index 0000000000..bb4c62ae1e --- /dev/null +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-spec-decode-1n8g.yaml @@ -0,0 +1,46 @@ +defaults: ../../grpo_math_1B.yaml +grpo: + max_num_steps: 200 + num_prompts_per_step: 32 + num_generations_per_prompt: 16 +checkpointing: + checkpoint_dir: results/grpo-qwen2.5-7b-spec-decode-1n8g +policy: + model_name: Qwen/Qwen3-4B + tokenizer: + name: Qwen/Qwen3-4B + train_global_batch_size: 512 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 2048 + dynamic_batching: + enabled: true + sequence_packing: + enabled: false + make_sequence_length_divisible_by: 1 + generation: + max_new_tokens: 1024 + vllm_cfg: + tensor_parallel_size: 1 + gpu_memory_utilization: 0.85 + max_model_len: 2048 + async_engine: true + enable_vllm_metrics_logger: true + vllm_metrics_logger_interval: 0.5 + vllm_kwargs: + speculative_config: + method: draft_model + model: Qwen/Qwen3-0.6B + num_speculative_tokens: 5 + draft_tensor_parallel_size: 1 +data: + max_input_seq_length: 1024 +logger: + log_dir: logs/grpo-qwen2.5-7b-spec-decode-1n8g + wandb_enabled: false + tensorboard_enabled: true + wandb: + project: nemo-rl + name: grpo-qwen2.5-7b-spec-decode-1n8g +cluster: + gpus_per_node: 8 diff --git a/examples/configs/recipes/llm/grpo-qwen3-14b-draft0.6b-spec-decode-1n8g.yaml b/examples/configs/recipes/llm/grpo-qwen3-14b-draft0.6b-spec-decode-1n8g.yaml new file mode 100644 index 0000000000..c3fd2ed658 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-qwen3-14b-draft0.6b-spec-decode-1n8g.yaml @@ -0,0 +1,43 @@ +defaults: /opt/nemo-rl/examples/configs/grpo_math_1B.yaml +grpo: + max_num_steps: 200 + num_prompts_per_step: 32 + num_generations_per_prompt: 16 
+checkpointing: + checkpoint_dir: results/grpo-qwen3-14b-draft0.6b-spec-decode-1n8g +policy: + model_name: Qwen/Qwen3-14B + tokenizer: + name: Qwen/Qwen3-14B + train_global_batch_size: 512 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 2048 + dynamic_batching: + enabled: true + sequence_packing: + enabled: false + make_sequence_length_divisible_by: 1 + generation: + max_new_tokens: 1024 + vllm_cfg: + tensor_parallel_size: 2 + gpu_memory_utilization: 0.80 + max_model_len: 2048 + async_engine: true + enable_vllm_metrics_logger: true + vllm_metrics_logger_interval: 0.5 + vllm_kwargs: + speculative_config: + method: draft_model + model: Qwen/Qwen3-0.6B + num_speculative_tokens: 5 + draft_tensor_parallel_size: 2 +data: + max_input_seq_length: 1024 +logger: + log_dir: logs/grpo-qwen3-14b-draft0.6b-spec-decode-1n8g + wandb_enabled: false + tensorboard_enabled: true +cluster: + gpus_per_node: 8 diff --git a/examples/configs/recipes/llm/grpo-qwen3-14b-draft0.6b-spec-decode-tuned-1n8g.yaml b/examples/configs/recipes/llm/grpo-qwen3-14b-draft0.6b-spec-decode-tuned-1n8g.yaml new file mode 100644 index 0000000000..fcd02b924f --- /dev/null +++ b/examples/configs/recipes/llm/grpo-qwen3-14b-draft0.6b-spec-decode-tuned-1n8g.yaml @@ -0,0 +1,44 @@ +defaults: /opt/nemo-rl/examples/configs/grpo_math_1B.yaml +grpo: + max_num_steps: 200 + num_prompts_per_step: 32 + num_generations_per_prompt: 16 +checkpointing: + checkpoint_dir: results/grpo-qwen3-14b-draft0.6b-spec-decode-tuned-1n8g +policy: + model_name: Qwen/Qwen3-14B + tokenizer: + name: Qwen/Qwen3-14B + train_global_batch_size: 512 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 2048 + dynamic_batching: + enabled: true + sequence_packing: + enabled: false + make_sequence_length_divisible_by: 1 + generation: + max_new_tokens: 1024 + temperature: 0.6 + vllm_cfg: + tensor_parallel_size: 1 + gpu_memory_utilization: 0.85 + max_model_len: 2048 + async_engine: true + enable_vllm_metrics_logger: true + vllm_metrics_logger_interval: 0.5 + vllm_kwargs: + speculative_config: + method: draft_model + model: Qwen/Qwen3-0.6B + num_speculative_tokens: 2 + draft_tensor_parallel_size: 1 +data: + max_input_seq_length: 1024 +logger: + log_dir: logs/grpo-qwen3-14b-draft0.6b-spec-decode-tuned-1n8g + wandb_enabled: false + tensorboard_enabled: true +cluster: + gpus_per_node: 8 diff --git a/examples/configs/recipes/llm/grpo-qwen3-14b-no-spec-decode-1n8g.yaml b/examples/configs/recipes/llm/grpo-qwen3-14b-no-spec-decode-1n8g.yaml new file mode 100644 index 0000000000..2602fe17dc --- /dev/null +++ b/examples/configs/recipes/llm/grpo-qwen3-14b-no-spec-decode-1n8g.yaml @@ -0,0 +1,34 @@ +defaults: /opt/nemo-rl/examples/configs/grpo_math_1B.yaml +grpo: + max_num_steps: 200 + num_prompts_per_step: 32 + num_generations_per_prompt: 16 +checkpointing: + checkpoint_dir: results/grpo-qwen3-14b-no-spec-decode-1n8g +policy: + model_name: Qwen/Qwen3-14B + tokenizer: + name: Qwen/Qwen3-14B + train_global_batch_size: 512 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 2048 + dynamic_batching: + enabled: true + sequence_packing: + enabled: false + make_sequence_length_divisible_by: 1 + generation: + max_new_tokens: 1024 + vllm_cfg: + tensor_parallel_size: 2 + gpu_memory_utilization: 0.80 + max_model_len: 2048 +data: + max_input_seq_length: 1024 +logger: + log_dir: logs/grpo-qwen3-14b-no-spec-decode-1n8g + wandb_enabled: false + tensorboard_enabled: true +cluster: + gpus_per_node: 8 diff --git 
a/examples/configs/recipes/llm/grpo-qwen3-14b-spec-decode-1n8g.yaml b/examples/configs/recipes/llm/grpo-qwen3-14b-spec-decode-1n8g.yaml new file mode 100644 index 0000000000..fea094bb39 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-qwen3-14b-spec-decode-1n8g.yaml @@ -0,0 +1,43 @@ +defaults: /opt/nemo-rl/examples/configs/grpo_math_1B.yaml +grpo: + max_num_steps: 200 + num_prompts_per_step: 32 + num_generations_per_prompt: 16 +checkpointing: + checkpoint_dir: results/grpo-qwen3-14b-spec-decode-1n8g +policy: + model_name: Qwen/Qwen3-14B + tokenizer: + name: Qwen/Qwen3-14B + train_global_batch_size: 512 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 2048 + dynamic_batching: + enabled: true + sequence_packing: + enabled: false + make_sequence_length_divisible_by: 1 + generation: + max_new_tokens: 1024 + vllm_cfg: + tensor_parallel_size: 2 + gpu_memory_utilization: 0.80 + max_model_len: 2048 + async_engine: true + enable_vllm_metrics_logger: true + vllm_metrics_logger_interval: 0.5 + vllm_kwargs: + speculative_config: + method: draft_model + model: Qwen/Qwen3-1.7B + num_speculative_tokens: 5 + draft_tensor_parallel_size: 2 +data: + max_input_seq_length: 1024 +logger: + log_dir: logs/grpo-qwen3-14b-spec-decode-1n8g + wandb_enabled: false + tensorboard_enabled: true +cluster: + gpus_per_node: 8 diff --git a/examples/configs/recipes/llm/grpo-qwen3-32b-spec-decode-lowbatch-1n8g.yaml b/examples/configs/recipes/llm/grpo-qwen3-32b-spec-decode-lowbatch-1n8g.yaml new file mode 100644 index 0000000000..c6b16939b6 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-qwen3-32b-spec-decode-lowbatch-1n8g.yaml @@ -0,0 +1,45 @@ +defaults: /opt/nemo-rl/examples/configs/grpo_math_1B.yaml +grpo: + max_num_steps: 10 + num_prompts_per_step: 2 + num_generations_per_prompt: 4 +checkpointing: + checkpoint_dir: results/grpo-qwen3-32b-spec-decode-lowbatch + enabled: false +policy: + model_name: Qwen/Qwen3-32B + tokenizer: + name: Qwen/Qwen3-32B + train_global_batch_size: 8 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 5120 + dynamic_batching: + enabled: true + sequence_packing: + enabled: false + make_sequence_length_divisible_by: 1 + generation: + max_new_tokens: 4096 + temperature: 0.6 + vllm_cfg: + tensor_parallel_size: 1 + gpu_memory_utilization: 0.85 + max_model_len: 5120 + async_engine: true + enable_vllm_metrics_logger: true + vllm_metrics_logger_interval: 0.5 + vllm_kwargs: + speculative_config: + method: draft_model + model: Qwen/Qwen3-0.6B + num_speculative_tokens: 3 + draft_tensor_parallel_size: 1 +data: + max_input_seq_length: 1024 +logger: + log_dir: logs/grpo-qwen3-32b-spec-decode-lowbatch + wandb_enabled: false + tensorboard_enabled: true +cluster: + gpus_per_node: 8 diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index b0971aa198..9983d9f582 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -39,6 +39,7 @@ ) from nemo_rl.algorithms.utils import ( calculate_baseline_and_std_per_prompt, + compute_spec_decode_token_acceptance_metrics, log_generation_metrics_to_wandb, print_performance_metrics, set_seed, @@ -1842,6 +1843,19 @@ def grpo_train( f" • Mean Generation Length: {metrics_logging_data['mean_gen_tokens_per_sample']:.4f}", flush=True, ) + step_token_acceptance_metrics = ( + compute_spec_decode_token_acceptance_metrics( + metrics.get("generation_logger_metrics", {}) + ) + ) + if "token_acceptance_rate" in step_token_acceptance_metrics: + print( + " • Token Acceptance Rate: " + 
f"{step_token_acceptance_metrics['token_acceptance_rate']:.4f} " + f"({step_token_acceptance_metrics['accepted_draft_tokens']:.0f}/" + f"{step_token_acceptance_metrics['proposed_draft_tokens']:.0f})", + flush=True, + ) print("\n⏱️ Timing:", flush=True) # Display total time first, separately @@ -2816,6 +2830,18 @@ def async_grpo_train( print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}") print(f" • Buffer Size: {buffer_size_current}") print(f" • Avg Trajectory Age: {avg_trajectory_age:.2f} steps") + step_token_acceptance_metrics = ( + compute_spec_decode_token_acceptance_metrics( + metrics.get("generation_logger_metrics", {}) + ) + ) + if "token_acceptance_rate" in step_token_acceptance_metrics: + print( + " • Token Acceptance Rate: " + f"{step_token_acceptance_metrics['token_acceptance_rate']:.4f} " + f"({step_token_acceptance_metrics['accepted_draft_tokens']:.0f}/" + f"{step_token_acceptance_metrics['proposed_draft_tokens']:.0f})" + ) print("\n⏱️ Timing:") total_time = timing_metrics.get("total_step_time", 0) diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index cc99033aba..cb98ce5e3e 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -383,6 +383,61 @@ def maybe_pad_last_batch(batch: dict, dp_size: int, mbs: int) -> dict: return batch +def _counter_delta_over_step(counter_series: list[Any]) -> float: + """Estimate step-local increments from a cumulative counter timeline.""" + parsed_values = [] + for value in counter_series: + try: + parsed_values.append(float(value)) + except (TypeError, ValueError): + continue + + if len(parsed_values) < 2: + return 0.0 + + delta = parsed_values[-1] - parsed_values[0] + # If the counter resets during the step, fallback to the latest counter value. + if delta < 0: + return max(0.0, parsed_values[-1]) + return delta + + +def compute_spec_decode_token_acceptance_metrics( + generation_logger_metrics: dict[str, Any], +) -> dict[str, float]: + """Compute speculative decoding token acceptance metrics from logger timelines.""" + accepted_per_worker = generation_logger_metrics.get("spec_decode_accepted_tokens") + proposed_per_worker = generation_logger_metrics.get("spec_decode_proposed_tokens") + if not isinstance(accepted_per_worker, dict) or not isinstance( + proposed_per_worker, dict + ): + return {} + + accepted_draft_tokens = sum( + _counter_delta_over_step(worker_series) + for worker_series in accepted_per_worker.values() + if isinstance(worker_series, list) + ) + proposed_draft_tokens = sum( + _counter_delta_over_step(worker_series) + for worker_series in proposed_per_worker.values() + if isinstance(worker_series, list) + ) + + if accepted_draft_tokens <= 0 and proposed_draft_tokens <= 0: + return {} + + token_acceptance_metrics: dict[str, float] = { + "accepted_draft_tokens": accepted_draft_tokens, + "proposed_draft_tokens": proposed_draft_tokens, + } + if proposed_draft_tokens > 0: + token_acceptance_metrics["token_acceptance_rate"] = ( + accepted_draft_tokens / proposed_draft_tokens + ) + return token_acceptance_metrics + + def print_performance_metrics( train_results: dict[str, float], metrics: dict[str, Any], @@ -562,6 +617,19 @@ def visualize_per_worker_timeline( "Num Pending Samples", None, ) + token_acceptance_metrics = compute_spec_decode_token_acceptance_metrics( + vllm_logger_metrics + ) + if "token_acceptance_rate" in token_acceptance_metrics: + accepted_draft_tokens = token_acceptance_metrics["accepted_draft_tokens"] + proposed_draft_tokens = 
token_acceptance_metrics["proposed_draft_tokens"] + token_acceptance_rate = token_acceptance_metrics["token_acceptance_rate"] + print( + " • Spec Decode Token Acceptance Rate: " + f"{token_acceptance_rate:.4f} " + f"({accepted_draft_tokens:.0f}/{proposed_draft_tokens:.0f})" + ) + performance_metrics.update(token_acceptance_metrics) # ===================================================== # Throughputs diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 1366ce28c5..9452f3777c 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -843,6 +843,8 @@ def get_vllm_logger_metrics(self) -> dict[str, Any]: "num_pending_samples": {}, # dp_idx -> list[int] "kv_cache_usage_perc": {}, # dp_idx -> list[float] "generation_tokens": {}, # dp_idx -> list[int] + "spec_decode_accepted_tokens": {}, # dp_idx -> list[int] + "spec_decode_proposed_tokens": {}, # dp_idx -> list[int] } for dp_idx, stats in zip(dp_indices, results): @@ -862,6 +864,16 @@ def get_vllm_logger_metrics(self) -> dict[str, Any]: generation_tokens = stats.get("generation_tokens") if generation_tokens: vllm_logger_metrics["generation_tokens"][dp_idx] = generation_tokens + spec_decode_accepted_tokens = stats.get("spec_decode_accepted_tokens") + if spec_decode_accepted_tokens: + vllm_logger_metrics["spec_decode_accepted_tokens"][dp_idx] = ( + spec_decode_accepted_tokens + ) + spec_decode_proposed_tokens = stats.get("spec_decode_proposed_tokens") + if spec_decode_proposed_tokens: + vllm_logger_metrics["spec_decode_proposed_tokens"][dp_idx] = ( + spec_decode_proposed_tokens + ) return vllm_logger_metrics diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index 0e4ea5cdeb..7e613a1c04 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -212,6 +212,35 @@ def _start_vllm_metrics_logger(self) -> None: self.num_pending_samples: list[int] = [] self.kv_cache_usage_perc: list[float] = [] self.generation_tokens: list[int] = [] + self.spec_decode_accepted_tokens: list[int] = [] + self.spec_decode_proposed_tokens: list[int] = [] + self._spec_decode_accepted_counter_name: Optional[str] = None + self._spec_decode_proposed_counter_name: Optional[str] = None + + def _normalize_metric_name(name: str) -> str: + return name.lower().replace("-", "_") + + def _is_spec_decode_token_counter(name: str) -> bool: + normalized_name = _normalize_metric_name(name) + return ( + "token" in normalized_name + and ("spec" in normalized_name or "draft" in normalized_name) + ) + + def _is_accepted_token_counter(name: str) -> bool: + normalized_name = _normalize_metric_name(name) + return _is_spec_decode_token_counter(normalized_name) and ( + "accept" in normalized_name or "accepted" in normalized_name + ) + + def _is_proposed_token_counter(name: str) -> bool: + normalized_name = _normalize_metric_name(name) + return _is_spec_decode_token_counter(normalized_name) and ( + ("proposed" in normalized_name) + or ("proposal" in normalized_name) + or ("speculative" in normalized_name) + or ("draft" in normalized_name) + ) def _logger_loop(): # Delay a little to let engine settle @@ -233,6 +262,22 @@ def _logger_loop(): elif isinstance(m, Counter): if m.name == "vllm:generation_tokens": self.generation_tokens.append(int(m.value)) + elif ( + self._spec_decode_accepted_counter_name is None + and 
_is_accepted_token_counter(m.name) + ): + self._spec_decode_accepted_counter_name = m.name + elif ( + self._spec_decode_proposed_counter_name is None + and _is_proposed_token_counter(m.name) + and not _is_accepted_token_counter(m.name) + ): + self._spec_decode_proposed_counter_name = m.name + + if m.name == self._spec_decode_accepted_counter_name: + self.spec_decode_accepted_tokens.append(int(m.value)) + if m.name == self._spec_decode_proposed_counter_name: + self.spec_decode_proposed_tokens.append(int(m.value)) except Exception: print( "⚠️[vLLM Metric Logger] Exception in vLLM metrics logger", @@ -261,6 +306,12 @@ def get_vllm_logger_metrics(self) -> dict[str, Any]: "num_pending_samples": copy.deepcopy(self.num_pending_samples), "kv_cache_usage_perc": copy.deepcopy(self.kv_cache_usage_perc), "generation_tokens": copy.deepcopy(self.generation_tokens), + "spec_decode_accepted_tokens": copy.deepcopy( + self.spec_decode_accepted_tokens + ), + "spec_decode_proposed_tokens": copy.deepcopy( + self.spec_decode_proposed_tokens + ), } return metric @@ -273,6 +324,8 @@ def clear_vllm_logger_metrics(self) -> None: self.num_pending_samples = [] self.kv_cache_usage_perc = [] self.generation_tokens = [] + self.spec_decode_accepted_tokens = [] + self.spec_decode_proposed_tokens = [] async def post_init_async(self): self.vllm_device_ids = await self.report_device_id_async() diff --git a/pyproject.toml b/pyproject.toml index 981c413241..b92675c87c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,7 @@ vllm = [ # sudo apt-get update # sudo apt-get install libibverbs-dev "deep_ep @ git+https://github.com/deepseek-ai/DeepEP.git@bfded34800dfec415b71503f8205181de90b2480", - "vllm==0.11.2", + "vllm>=0.12", "num2words>=0.5.14", ] sglang = [ diff --git a/tools/build-custom-vllm.sh b/tools/build-custom-vllm.sh index 72b2cec71b..e6a3ea1709 100644 --- a/tools/build-custom-vllm.sh +++ b/tools/build-custom-vllm.sh @@ -21,11 +21,11 @@ REPO_ROOT="$(realpath "$SCRIPT_DIR/..")" # Parse command line arguments GIT_URL=${1:-https://github.com/vllm-project/vllm.git} -GIT_REF=${2:-cc99baf14dacc2497d0c5ed84e076ef2c37f6a4d} +GIT_REF=${2:-4a5299c93ff97c26def537b92562df5ada530fea} # NOTE: VLLM_USE_PRECOMPILED=1 didn't always seem to work since the wheels were sometimes built against an incompatible torch/cuda combo. # This commit was chosen as one close to the v0.10 release: git merge-base --fork-point origin/main tags/v0.10.0 -VLLM_WHEEL_COMMIT=${3:-862f2ef893d9751db0a92bd2d4ae0e3d9677872f} # use full commit hash from the main branch -export VLLM_PRECOMPILED_WHEEL_LOCATION="https://wheels.vllm.ai/${VLLM_WHEEL_COMMIT}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" +VLLM_WHEEL_COMMIT=${3:-4a5299c93ff97c26def537b92562df5ada530fea} # merge commit of vllm PR #24322 (spec decode with draft models) +export VLLM_PRECOMPILED_WHEEL_LOCATION="${VLLM_PRECOMPILED_WHEEL_LOCATION:-https://wheels.vllm.ai/${VLLM_WHEEL_COMMIT}/vllm-0.14.0rc2.dev156%2Bg4a5299c93-cp38-abi3-manylinux_2_31_x86_64.whl}" BUILD_DIR=$(realpath "$SCRIPT_DIR/../3rdparty/vllm") if [[ -e "$BUILD_DIR" ]]; then