Skip to content

Conversation

@govind-ramnarayan
Copy link
Collaborator

@govind-ramnarayan govind-ramnarayan commented Jan 14, 2026

Summary by CodeRabbit

Release Notes

  • New Features

    • Added complete Eagle3 model implementation with speculative decoding support for AutoDeploy.
    • Added model type override capability for flexible model configuration during instantiation.
  • Tests

    • Added comprehensive tests for Eagle3 model weight loading, validation, and torch export.
    • Added Eagle3 model configuration and build pipeline tests.

✏️ Tip: You can customize this high-level summary in your review settings.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

Copy link
Member

@lucaslie lucaslie left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good overall. Just needs a clean-up

@govind-ramnarayan govind-ramnarayan changed the title Eagle3 Head in AutoDeploy [AutoDeploy] PyTorch impl for Eagle3 checkpoint Jan 15, 2026
@govind-ramnarayan govind-ramnarayan changed the title [AutoDeploy] PyTorch impl for Eagle3 checkpoint [feat][AutoDeploy] PyTorch impl for Eagle3 checkpoint Jan 15, 2026
@govind-ramnarayan govind-ramnarayan changed the title [feat][AutoDeploy] PyTorch impl for Eagle3 checkpoint [None][feat][AutoDeploy] PyTorch impl for Eagle3 checkpoint Jan 15, 2026
@govind-ramnarayan govind-ramnarayan marked this pull request as ready for review January 15, 2026 00:43
@govind-ramnarayan govind-ramnarayan requested a review from a team as a code owner January 15, 2026 00:43
Signed-off-by: Govind Ramnarayan <[email protected]>
…properly with AutoDeploy

Signed-off-by: Govind Ramnarayan <[email protected]>
…This verifies that the code there is exportable

Signed-off-by: Govind Ramnarayan <[email protected]>
Signed-off-by: Govind Ramnarayan <[email protected]>
…g. Removed debugging output for dtypes as well

Signed-off-by: Govind Ramnarayan <[email protected]>
Signed-off-by: Govind Ramnarayan <[email protected]>
…sting, removed LlamaConfig inheritance

Signed-off-by: Govind Ramnarayan <[email protected]>
Signed-off-by: Govind Ramnarayan <[email protected]>
Signed-off-by: Govind Ramnarayan <[email protected]>
@govind-ramnarayan govind-ramnarayan force-pushed the gramnarayan/export-eagle3 branch from b20bcf7 to 9898953 Compare January 15, 2026 00:45
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 15, 2026

📝 Walkthrough

Walkthrough

This pull request introduces a complete HuggingFace-compatible Eagle3 speculative decoding model for AutoDeploy, including core transformer architecture, configuration classes, AutoDeploy registration hooks, mock testing variants, weight loading analysis utilities, and comprehensive test cases spanning unit and integration tests.

Changes

Cohort / File(s) Summary
Eagle3 Model Implementation
tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py
Adds complete Eagle3 model architecture: EagleRMSNorm, EagleMLP, Eagle3Config (HuggingFace PretrainedConfig), Eagle3Attention (with custom ops and rotary embeddings), Eagle3DecoderLayer, Eagle3Model (core 3x hidden_size fusion), and Eagle3ModelForCausalLM (HuggingFace wrapper). Includes MockEagle3Config and MockEagle3ModelForCausalLM for testing with random hidden states. Registers models with AutoConfig and AutoModelForCausalLMFactory.
Model Public API
tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py
Exports Eagle3ModelForCausalLM in \__all__\\ and imports from modeling_eagle module.
AutoDeploy Config Override
tensorrt_llm/_torch/auto_deploy/models/hf.py
Adds \_override_model_type\\ method to AutoModelForCausalLMFactory that re-instantiates model config with a different config class when model_type override is provided via model_kwargs. Integrated into \_get_model_config\\ flow before recursive config updates.
Model Configuration & Utilities
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
Adds Eagle3 model entry ("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B") to small model configs with tokenizer path resolution logic based on tokenizer_subdir and tokenizer_hub_id.
Unit Tests
tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py
test_ad_build_small_single.py adds Eagle3 model test case with flashinfer and torch-compile transforms. test_ad_speculative_decoding.py adds Eagle3 torch export test with config loading, model instantiation, dummy input creation, and torch.export validation.
Integration Tests
tests/integration/defs/examples/test_ad_speculative_decoding.py
Adds \_analyze_weight_loading\\ helper to inspect checkpoint keys and validate state_dict mapping. Adds \test_eagle_model_with_weights\\ to instantiate MockEagle3ModelForCausalLM via factory, analyze weight loading, validate missing/unexpected keys, load weights, run forward pass, and assert output logits shape.

Sequence Diagram

sequenceDiagram
    participant Client as AutoDeploy Client
    participant Factory as AutoModelForCausalLMFactory
    participant Config as Eagle3Config
    participant Model as Eagle3ModelForCausalLM
    participant Attention as Eagle3Attention
    participant MLP as Eagle3MLP
    participant Head as Linear Head

    Client->>Factory: load_or_random_init(model_type_override="mock_eagle3")
    activate Factory
    Factory->>Factory: _override_model_type()
    Factory->>Config: AutoConfig.for_model("MockEagle3Config")
    Config-->>Factory: MockEagle3Config instance
    Factory->>Model: instantiate Eagle3ModelForCausalLM
    activate Model
    Model->>Model: forward(input_ids, target_hidden_states)
    Note over Model: Fuse 3×hidden_size to hidden_size
    Model->>Model: apply rotary embeddings
    
    loop For each decoder layer
        Model->>Attention: Eagle3Attention.forward()
        activate Attention
        Attention->>Attention: compute Q, K, V
        Attention->>Attention: apply rotary positions
        Attention->>Attention: custom attention op
        Attention-->>Model: attention output
        deactivate Attention
        
        Model->>MLP: Eagle3MLP.forward()
        activate MLP
        MLP->>MLP: gate_proj + up_proj
        MLP->>MLP: activation + down_proj
        MLP-->>Model: mlp output
        deactivate MLP
        
        Model->>Model: residual + norm
    end
    
    Model->>Head: final linear projection
    Head-->>Model: logits (vocab_size)
    Model-->>Client: CausalLMOutputWithPast
    deactivate Model
    deactivate Factory
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2
❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Description check ⚠️ Warning The PR description is largely the template boilerplate with no substantive content about what was changed, why, or test coverage details. Critical sections remain unfilled. Fill in the Description and Test Coverage sections with clear explanations of the changes, rationale, and specific tests added. Document the Eagle3 model implementation, integration, and any relevant test cases.
Docstring Coverage ⚠️ Warning Docstring coverage is 40.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (1 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly summarizes the main change: adding a PyTorch implementation for Eagle3 checkpoint in AutoDeploy, following the required format with ticket ID and type.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@govind-ramnarayan govind-ramnarayan changed the title [None][feat][AutoDeploy] PyTorch impl for Eagle3 checkpoint [None][chore][AutoDeploy] PyTorch impl for Eagle3 checkpoint Jan 15, 2026
@govind-ramnarayan govind-ramnarayan changed the title [None][chore][AutoDeploy] PyTorch impl for Eagle3 checkpoint [None][chore] AutoDeploy: PyTorch impl for Eagle3 checkpoint Jan 15, 2026
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🤖 Fix all issues with AI agents
In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py`:
- Around line 195-198: The call to self.self_attn(...) in modeling_eagle.py
incorrectly indexes the result with [0] even though Eagle3Attention.forward
returns a torch.Tensor; remove the [0] indexing and use the returned tensor
directly (i.e., assign hidden_states = self.self_attn(...)) and update any
downstream code that expected a tuple to work with the single Tensor return from
Eagle3Attention.forward.
- Around line 131-168: In the forward method of the custom attention module
(forward in modeling_eagle.py) remove the debug print statement
print("hidden_states.shape:", hidden_states.shape) so logs aren't polluted in
production; if inspection is still needed replace it with a proper logger.debug
call or conditional debug flag, but do not leave raw print statements in the
q/k/v projection path before calling torch_attention.
- Around line 415-429: The constructor currently reads config.dtype which
doesn't exist on Eagle3Config; change reading of the dtype in __init__ to use
config.torch_dtype or a fallback (e.g. getattr(config, "torch_dtype", None) or
torch.float32) and store it to self._dtype, and ensure forward still uses
self._dtype when creating the mock target_hidden_states; also consider accepting
an override from kwargs (e.g. kwargs.get("torch_dtype")) so callers can supply
dtype via model_kwargs if needed.
- Around line 255-285: The forward method assigns self.midlayer to an
nn.ModuleList when config.num_hidden_layers > 1 but then calls it as if it were
a single layer; replace the direct call to self.midlayer(...) with an explicit
loop over the layers in self.midlayer (each is an Eagle3DecoderLayer) and
sequentially pass/receive the tensors (hidden_states, embeds/input_embeds,
position_embeds) through each layer, updating hidden_states (and embeds if the
architecture requires) per iteration, then set out to the final layer's output
before continuing to norm/lm_head; ensure you handle both the single-layer case
(Eagle3DecoderLayer) and multi-layer case consistently by using the same
variable names used in forward (hidden_states, input_embeds, position_embeds,
out).
🧹 Nitpick comments (3)
tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py (1)

157-163: Consider catching more specific exception types.

The broad Exception catch is flagged by static analysis (BLE001). While torch.export can raise various exception types, consider catching torch.export.ExportError or RuntimeError for more precise error handling. However, since this is a test validating exportability and any failure should result in test failure, the current approach is acceptable.

♻️ Optional: More specific exception handling
     # Attempt torch.export
     try:
         exported_program = torch.export.export(model, args=example_args)
         print("✅ torch.export successful!")
         print("Graph module code preview (first 20 lines):")
         code_lines = exported_program.graph_module.code.split("\n")[:20]
         print("\n".join(code_lines))
-    except Exception as e:
+    except (torch.export.ExportError, RuntimeError, TypeError) as e:
         pytest.fail(f"torch.export failed: {e}")
tests/integration/defs/examples/test_ad_speculative_decoding.py (2)

285-339: Code duplication acknowledged; consider future refactoring.

The TODO comment at line 285-286 correctly notes this replicates logic from hf.py. The inline import re at line 300 works but is unconventional. Consider moving the import to the file's top-level for consistency, though it's minor for a test utility.

The function is well-documented and serves its purpose for validating weight loading behavior. However, as noted in lines 409-410, this could get stale if hf.py changes. Consider exposing this analysis functionality from the factory itself in the future.

♻️ Move import to top-level
 import os
+import re
 from pathlib import Path
 
 import pytest

Then remove the import from inside the function:

 def _analyze_weight_loading(model_path: Path, model: torch.nn.Module):
     ...
-    import re
-
     # 1. Load checkpoint keys

447-468: Hardcoded expected keys may become brittle.

The expected missing/unexpected keys are hardcoded based on the current Eagle3 architecture. If the model architecture changes (e.g., different weight sharing strategy), this test will fail with a potentially confusing assertion error.

Consider adding a comment noting when these expectations should be updated, or make the test more flexible by checking subsets rather than exact equality.

♻️ Alternative: Use subset checks with informative output
     # Verify expected missing and unexpected keys
     # These are the keys we expect based on Eagle3 architecture:
     # - embed_tokens: shared from target model (not in Eagle checkpoint)
     # - t2d: target-to-draft mapping, not used in Eagle3 (uses d2t instead)
+    # NOTE: Update these sets if Eagle3 architecture changes
     expected_missing_keys = {"model.embed_tokens.weight"}
     expected_unexpected_keys = {"model.t2d"}

-    assert missing_keys == expected_missing_keys, (
+    assert expected_missing_keys.issubset(missing_keys), (
+        f"Expected missing keys not found.\n"
+        f"Expected at least: {expected_missing_keys}\n"
+        f"Got: {missing_keys}"
+    )
+    # Warn if there are additional missing keys
+    extra_missing = missing_keys - expected_missing_keys
+    if extra_missing:
+        print(f"⚠️  Additional missing keys (may be expected): {extra_missing}")
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 211c44b and 9898953.

📒 Files selected for processing (7)
  • tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py
  • tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
  • tests/integration/defs/examples/test_ad_speculative_decoding.py
  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces. Do not use tabs
Always maintain the namespace when importing Python modules, even if only one class or function from a module is used
Python filenames should use snake_case (e.g., some_file.py)
Python classes should use PascalCase (e.g., class SomeClass)
Python functions and methods should use snake_case (e.g., def my_awesome_function():)
Python local variables should use snake_case, with prefix k for variable names that start with a number (e.g., k_99th_percentile)
Python global variables should use upper snake_case with prefix G (e.g., G_MY_GLOBAL)
Python constants should use upper snake_case (e.g., MY_CONSTANT)
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Use comments in Python for code within a function, or interfaces that are local to a file
Use Google-style docstrings for Python classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with the format """<type>: Description"""
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except clause to the smallest set of errors possible
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible and use the else block for the main logic

Files:

  • tests/integration/defs/examples/test_ad_speculative_decoding.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
  • tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py
  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py
  • tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py
**/*.{cpp,cc,cxx,h,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification

Files:

  • tests/integration/defs/examples/test_ad_speculative_decoding.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
  • tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py
  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py
  • tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py
🧠 Learnings (8)
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
Repo: NVIDIA/TensorRT-LLM PR: 6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tests/integration/defs/examples/test_ad_speculative_decoding.py
📚 Learning: 2025-08-29T14:07:45.863Z
Learnt from: EmmaQiaoCh
Repo: NVIDIA/TensorRT-LLM PR: 7370
File: tests/unittest/trt/model_api/test_model_quantization.py:24-27
Timestamp: 2025-08-29T14:07:45.863Z
Learning: In TensorRT-LLM's CI infrastructure, pytest skip markers (pytest.mark.skip) are properly honored even when test files have __main__ blocks that call test functions directly. The testing system correctly skips tests without requiring modifications to the __main__ block execution pattern.

Applied to files:

  • tests/integration/defs/examples/test_ad_speculative_decoding.py
📚 Learning: 2025-08-28T10:25:22.370Z
Learnt from: ixlmar
Repo: NVIDIA/TensorRT-LLM PR: 7294
File: tensorrt_llm/_torch/pyexecutor/sampler.py:887-891
Timestamp: 2025-08-28T10:25:22.370Z
Learning: In tensorrt_llm/_torch/pyexecutor/sampler.py, the draft_probs and target_probs tensors have shapes [1, steps] not [steps, vocab_size] as might be expected, making the .squeeze(0) operations appropriate for removing the batch dimension of size 1.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py
📚 Learning: 2025-10-13T13:55:04.170Z
Learnt from: ixlmar
Repo: NVIDIA/TensorRT-LLM PR: 8263
File: examples/models/contrib/sdxl/run_sdxl.py:0-0
Timestamp: 2025-10-13T13:55:04.170Z
Learning: The `diffusers` library (e.g., `DiffusionPipeline`, `StableDiffusionXLPipeline`, `StableDiffusion3Pipeline`) uses the `torch_dtype` parameter in `from_pretrained()` calls, not `dtype`. Only the `transformers` library has migrated to using `dtype`.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py
📚 Learning: 2025-08-09T02:04:49.623Z
Learnt from: Fridah-nv
Repo: NVIDIA/TensorRT-LLM PR: 6760
File: tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py:81-98
Timestamp: 2025-08-09T02:04:49.623Z
Learning: In TensorRT-LLM's auto_deploy module, torch.dtype values in configuration dictionaries must be stored as string representations (e.g., "float16" instead of torch.float16) because OmegaConf.merge does not support torch.dtype types. These string representations are converted to actual torch.dtype objects in downstream code.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py
📚 Learning: 2025-12-19T06:31:54.973Z
Learnt from: nvyocox
Repo: NVIDIA/TensorRT-LLM PR: 10117
File: tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rope_attention.py:336-339
Timestamp: 2025-12-19T06:31:54.973Z
Learning: In tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rope_attention.py, the cast to torch.float16 for qkv_node before creating the AttentionPlugin is intentional and required because DriveOS LLM expects float16 dtype specifically. This should not be changed to preserve original dtype or made configurable for bfloat16 models in the DriveOS LLM ONNX export path.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py
📚 Learning: 2025-08-26T09:37:10.463Z
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which can contain default `cuda_graph_config` values, so `llm_args` may already have this config before the extra options processing.

Applied to files:

  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
📚 Learning: 2025-10-20T16:54:09.824Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py
🧬 Code graph analysis (1)
tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py (1)
tests/integration/defs/triton_server/conftest.py (1)
  • models_root (280-284)
🪛 Ruff (0.14.11)
tests/integration/defs/examples/test_ad_speculative_decoding.py

316-316: Avoid specifying long messages outside the exception class

(TRY003)

tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py

262-262: Avoid specifying long messages outside the exception class

(TRY003)


306-306: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


313-315: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


324-324: Unused method argument: attention_mask

(ARG002)


326-326: Unused method argument: past_key_values

(ARG002)


327-327: Unused method argument: inputs_embeds

(ARG002)


328-328: Unused method argument: labels

(ARG002)


329-329: Unused method argument: use_cache

(ARG002)


330-330: Unused method argument: output_attentions

(ARG002)


331-331: Unused method argument: output_hidden_states

(ARG002)


332-332: Unused method argument: return_dict

(ARG002)


356-360: Avoid specifying long messages outside the exception class

(TRY003)

tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py

163-163: Do not catch blind exception: Exception

(BLE001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (17)
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (2)

511-518: LGTM! Well-structured Eagle3 small model configuration.

The configuration correctly:

  • Uses mock_eagle3 model_type for standalone testing
  • Separates tokenizer path from model path (needed since Eagle3 shares tokenizer with target Llama model)
  • Provides both local subdir and hub fallback paths for the tokenizer

544-548: LGTM! Clean tokenizer path resolution.

The logic correctly handles separate tokenizer paths for models like Eagle3 that share tokenizers with their target models, reusing the existing _hf_model_dir_or_hub_id helper.

tensorrt_llm/_torch/auto_deploy/models/hf.py (2)

191-216: LGTM! Clean model_type override implementation.

The method correctly:

  • Only acts when there's an actual override needed
  • Uses AutoConfig.for_model() to get the registered config class for the target type
  • Preserves original config values via to_dict()/from_dict() round-trip
  • Logs the override for debugging

The docstring clearly explains the use case (Eagle draft model with llama checkpoint needing Eagle architecture).


228-230: LGTM! Correct integration point.

Calling _override_model_type before _recursive_update_config ensures the config class is swapped first, then all model_kwargs (including model_type) are applied to the new config instance.

tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py (4)

46-58: LGTM! Clean RMSNorm implementation.

Simple wrapper around the custom torch_rmsnorm op with standard parameter initialization.


61-74: LGTM! Standard gated MLP implementation.

Follows the typical LLaMA-style gated MLP pattern with configurable activation.


77-90: LGTM! Minimal config registration.

The config inherits attributes from the checkpoint's original config (e.g., LlamaConfig) via the _override_model_type mechanism in hf.py. Registration with AutoConfig is correct.


293-379: LGTM with minor notes.

The implementation correctly:

  • Wraps Eagle3Model in HF-compatible interface
  • Requires target_hidden_states via kwargs with clear error message
  • Uses _checkpoint_conversion_mapping for weight key remapping

The unused method arguments (attention_mask, past_key_values, etc.) are intentional for HF interface compatibility, so the static analysis warnings can be safely ignored.

tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py (1)

1-10: LGTM! Clean public API export.

Correctly exports only Eagle3ModelForCausalLM for production use while keeping MockEagle3ModelForCausalLM internal (accessed via model_type override mechanism for testing).

tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py (1)

201-209: LGTM! Test case follows existing patterns.

The new Eagle3 test case correctly:

  • Uses the hub ID matching the config in _model_test_utils.py
  • Applies appropriate transforms (flashinfer attention + torch-compile)
tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py (3)

16-27: LGTM!

Imports are correctly organized and appropriate for the test functionality.


31-37: LGTM!

The helper function correctly handles the case when llm_models_root() returns None and validates existence before returning the path.


141-143: The mock hidden states dimension is correct. The Eagle3Model docstring explicitly documents that it expects "concatenated hidden states from target model layers (typically 3 layers, resulting in hidden_size * 3 dimensions)." The test correctly creates mock_hidden_states with shape (batch_size, seq_len, hidden_dim * 3) which matches the model's documented input specification. Note: eagle3_layers_to_capture is a higher-level API configuration used to determine which layers to capture from the target model, not a parameter Eagle3Model itself uses.

tests/integration/defs/examples/test_ad_speculative_decoding.py (4)

16-28: LGTM!

Imports are correctly organized with standard library imports first, followed by third-party and local imports.


306-316: LGTM!

Good practices:

  • Uses weights_only=True for secure loading
  • Properly cleans up memory with del state_dict
  • Handles both safetensors and PyTorch binary formats

The Ruff TRY003 warning about the exception message is a minor style concern that can be safely ignored here.


342-356: Excellent documentation.

The docstring provides a comprehensive explanation of the factory initialization flow, making the test's purpose and validation scope clear.


499-512: LGTM!

The forward pass validation is well-structured:

  • Uses torch.inference_mode() appropriately
  • Handles both draft_vocab_size and fallback to vocab_size
  • Clear shape verification with informative assertion message

✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.

Comment on lines +131 to +168
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
print("hidden_states.shape:", hidden_states.shape)
bsz, q_len, _ = hidden_states.size()
cos, sin = position_embeddings

# Projections
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

# Reshape to [Batch, Seq, Heads, Dim]
query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)

query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=2
)

attn_output = torch.ops.auto_deploy.torch_attention(
query_states,
key_states,
value_states,
attn_mask=None,
dropout_p=0.0,
is_causal=self.is_causal,
layout="bsnd",
)

attn_output = attn_output.view(bsz, q_len, self.num_attention_heads * self.head_dim)

attn_output = self.o_proj(attn_output)

return attn_output
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Remove debug print statement.

Line 136 contains a debug print statement that will pollute logs in production.

🔧 Proposed fix
     def forward(
         self,
         hidden_states: torch.Tensor,
         position_embeddings: tuple[torch.Tensor, torch.Tensor],
     ) -> torch.Tensor:
-        print("hidden_states.shape:", hidden_states.shape)
         bsz, q_len, _ = hidden_states.size()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
print("hidden_states.shape:", hidden_states.shape)
bsz, q_len, _ = hidden_states.size()
cos, sin = position_embeddings
# Projections
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Reshape to [Batch, Seq, Heads, Dim]
query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=2
)
attn_output = torch.ops.auto_deploy.torch_attention(
query_states,
key_states,
value_states,
attn_mask=None,
dropout_p=0.0,
is_causal=self.is_causal,
layout="bsnd",
)
attn_output = attn_output.view(bsz, q_len, self.num_attention_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
return attn_output
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
bsz, q_len, _ = hidden_states.size()
cos, sin = position_embeddings
# Projections
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Reshape to [Batch, Seq, Heads, Dim]
query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=2
)
attn_output = torch.ops.auto_deploy.torch_attention(
query_states,
key_states,
value_states,
attn_mask=None,
dropout_p=0.0,
is_causal=self.is_causal,
layout="bsnd",
)
attn_output = attn_output.view(bsz, q_len, self.num_attention_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
return attn_output
🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py` around lines
131 - 168, In the forward method of the custom attention module (forward in
modeling_eagle.py) remove the debug print statement
print("hidden_states.shape:", hidden_states.shape) so logs aren't polluted in
production; if inspection is still needed replace it with a proper logger.debug
call or conditional debug flag, but do not leave raw print statements in the
q/k/v projection path before calling torch_attention.

Comment on lines +195 to +198
hidden_states = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeds,
)[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Type mismatch: Eagle3Attention.forward returns Tensor, not tuple.

Eagle3Attention.forward returns torch.Tensor (line 168), but this code indexes [0] on the result, which will raise TypeError: 'Tensor' object is not subscriptable at runtime.

🐛 Proposed fix: remove indexing
         hidden_states = self.self_attn(
             hidden_states=hidden_states,
             position_embeddings=position_embeds,
-        )[0]
+        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
hidden_states = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeds,
)[0]
hidden_states = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeds,
)
🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py` around lines
195 - 198, The call to self.self_attn(...) in modeling_eagle.py incorrectly
indexes the result with [0] even though Eagle3Attention.forward returns a
torch.Tensor; remove the [0] indexing and use the returned tensor directly
(i.e., assign hidden_states = self.self_attn(...)) and update any downstream
code that expected a tuple to work with the single Tensor return from
Eagle3Attention.forward.

Comment on lines +255 to +285
if config.num_hidden_layers == 1:
self.midlayer = Eagle3DecoderLayer(config, layer_idx=0)
elif config.num_hidden_layers > 1:
self.midlayer = nn.ModuleList(
[Eagle3DecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
)
else:
raise ValueError(f"Invalid number of hidden layers: {config.num_hidden_layers}")

# Output head
self.norm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

def forward(
self,
input_ids: torch.LongTensor,
position_ids: torch.LongTensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
# Fuse hidden states from target model layers
hidden_states = self.fc(hidden_states)
input_embeds = self.embed_tokens(input_ids)

cos, sin = self.rotary_emb(hidden_states, position_ids)
position_embeds = (cos, sin)

out = self.midlayer(
hidden_states=hidden_states,
embeds=input_embeds,
position_embeds=position_embeds,
)[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

cd / && find . -name "modeling_eagle.py" -type f 2>/dev/null

Repository: NVIDIA/TensorRT-LLM

Length of output: 146


🏁 Script executed:

cat -n ./home/jailuser/git/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py | sed -n '240,300p'

Repository: NVIDIA/TensorRT-LLM

Length of output: 178


🏁 Script executed:

cat -n tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py | sed -n '240,300p'

Repository: NVIDIA/TensorRT-LLM

Length of output: 2623


🏁 Script executed:

cat -n tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py | sed -n '1,50p'

Repository: NVIDIA/TensorRT-LLM

Length of output: 2313


🏁 Script executed:

cat -n tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py | sed -n '100,200p'

Repository: NVIDIA/TensorRT-LLM

Length of output: 4458


🏁 Script executed:

cat -n tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py | sed -n '200,220p'

Repository: NVIDIA/TensorRT-LLM

Length of output: 913


🏁 Script executed:

grep -n "self.midlayer" tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 214


🏁 Script executed:

python3 << 'EOF'
import torch.nn as nn

# Verify ModuleList behavior
ml = nn.ModuleList([nn.Linear(10, 10)])

# Try to call it
try:
    result = ml(x=torch.randn(1, 10))
    print("ModuleList is callable")
except TypeError as e:
    print(f"ModuleList is NOT callable: {e}")
EOF

Repository: NVIDIA/TensorRT-LLM

Length of output: 149


🏁 Script executed:

cat -n tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 1076


Bug: midlayer forward calls nn.ModuleList directly, which is not callable.

When config.num_hidden_layers > 1, self.midlayer is assigned as nn.ModuleList. However, lines 281-285 attempt to call self.midlayer(...) directly with keyword arguments. Since nn.ModuleList is not callable (it's a container, not a layer), this will raise a TypeError at runtime. The code only works when num_hidden_layers == 1 because then self.midlayer is an Eagle3DecoderLayer which is callable.

🐛 Proposed fix: iterate through layers
-        out = self.midlayer(
-            hidden_states=hidden_states,
-            embeds=input_embeds,
-            position_embeds=position_embeds,
-        )[0]
+        if isinstance(self.midlayer, nn.ModuleList):
+            out = hidden_states
+            for layer in self.midlayer:
+                out, _ = layer(
+                    hidden_states=out,
+                    embeds=input_embeds,
+                    position_embeds=position_embeds,
+                )
+        else:
+            out, _ = self.midlayer(
+                hidden_states=hidden_states,
+                embeds=input_embeds,
+                position_embeds=position_embeds,
+            )

Note: Verify the intended multi-layer behavior—specifically whether embeds should be the same input_embeds for all layers or updated per layer based on the Eagle3 architecture design.

🧰 Tools
🪛 Ruff (0.14.11)

262-262: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py` around lines
255 - 285, The forward method assigns self.midlayer to an nn.ModuleList when
config.num_hidden_layers > 1 but then calls it as if it were a single layer;
replace the direct call to self.midlayer(...) with an explicit loop over the
layers in self.midlayer (each is an Eagle3DecoderLayer) and sequentially
pass/receive the tensors (hidden_states, embeds/input_embeds, position_embeds)
through each layer, updating hidden_states (and embeds if the architecture
requires) per iteration, then set out to the final layer's output before
continuing to norm/lm_head; ensure you handle both the single-layer case
(Eagle3DecoderLayer) and multi-layer case consistently by using the same
variable names used in forward (hidden_states, input_embeds, position_embeds,
out).

Comment on lines +415 to +429
def __init__(self, config):
super().__init__(config)
self._hidden_size = config.hidden_size
self._dtype = config.dtype

def forward(self, input_ids, **kwargs):
# Inject mock hidden states if not provided
if "target_hidden_states" not in kwargs:
batch_size, seq_len = input_ids.shape
kwargs["target_hidden_states"] = torch.randn(
(batch_size, seq_len, self._hidden_size * 3),
dtype=self._dtype,
device=input_ids.device,
)
return super().forward(input_ids, **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, let me examine the modeling_eagle.py file around the problematic lines
head -n 435 tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py | tail -n 55

Repository: NVIDIA/TensorRT-LLM

Length of output: 2090


🏁 Script executed:

# Check if Eagle3Config is defined in this file and what attributes it has
rg -n "class.*Eagle3Config" tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py -A 20

Repository: NVIDIA/TensorRT-LLM

Length of output: 3161


🏁 Script executed:

# Look for how config.dtype or torch_dtype is used elsewhere in the codebase
rg -n "config\.dtype|config\.torch_dtype" tensorrt_llm/_torch/auto_deploy/models/custom/ -A 2 -B 2

Repository: NVIDIA/TensorRT-LLM

Length of output: 562


🏁 Script executed:

# Check if there are any initialization patterns or test files for MockEagle3
find tensorrt_llm -name "*test*eagle*" -o -name "*eagle*test*" 2>/dev/null | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Look at the full Eagle3Config definition and related configs
rg -n "Eagle3Config\|PretrainedConfig" tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py | head -30

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Check the imports to understand PretrainedConfig source
rg -n "from.*PretrainedConfig|import.*PretrainedConfig" tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 136


🏁 Script executed:

# Look for how Eagle3Config instances are created in the codebase
rg -n "Eagle3Config\(\|MockEagle3Config\(" tensorrt_llm -A 3 -B 3

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Check if there's any config initialization or kwargs handling
rg -n "model_kwargs\|dtype" tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Look at the broader context around config usage in the Eagle3ModelForCausalLM class
sed -n '300,330p' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1284


🏁 Script executed:

# Search for any torch_dtype usage in the codebase as alternative
rg -n "torch_dtype" tensorrt_llm/_torch/auto_deploy/models/custom/ -B 2 -A 2

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Search for where MockEagle3ModelForCausalLM is instantiated or referenced
rg -n "MockEagle3\|mock_eagle3" tensorrt_llm -B 3 -A 3

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Look for config creation patterns in auto_deploy
rg -n "from_pretrained\|__init__" tensorrt_llm/_torch/auto_deploy/ --type py | grep -i config | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Check if there's any documentation or comments about dtype in the Eagle files
rg -n "dtype\|torch.float" tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Look for the Eagle3Model init to understand config expectations
sed -n '115,200p' tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 3149


🏁 Script executed:

# Check the file to see if there's any initialization of config with dtype anywhere
head -n 60 tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 2278


🏁 Script executed:

# Search for any from_pretrained or config creation patterns in the entire auto_deploy directory
rg -n "from_pretrained\|load_checkpoint\|config_dict" tensorrt_llm/_torch/auto_deploy/ --type py -A 2 | head -40

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Look for any test or example files that might use MockEagle3
find tensorrt_llm -type f -name "*.py" -exec grep -l "MockEagle3\|mock_eagle3" {} \;

Repository: NVIDIA/TensorRT-LLM

Length of output: 126


🏁 Script executed:

# Check if config attributes are set anywhere else (looking for setattr or attribute assignment)
rg -n "config\.\w+\s*=" tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 213


🏁 Script executed:

# Look at HuggingFace PretrainedConfig documentation - check what attributes are standard
# Let's search for torch_dtype usage which is the standard HF attribute
rg -n "torch_dtype" tensorrt_llm/_torch/auto_deploy/ -B 2 -A 2

Repository: NVIDIA/TensorRT-LLM

Length of output: 2900


Fix config.dtype access which will cause AttributeError at runtime.

Line 418 accesses config.dtype, but Eagle3Config and MockEagle3Config do not define this attribute. HuggingFace's PretrainedConfig uses torch_dtype as the standard attribute, not dtype. This will raise AttributeError unless dtype is explicitly provided in model_kwargs.

Suggested fix
     def __init__(self, config):
         super().__init__(config)
         self._hidden_size = config.hidden_size
-        self._dtype = config.dtype
+        self._dtype = getattr(config, "torch_dtype", None) or torch.float16
🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py` around lines
415 - 429, The constructor currently reads config.dtype which doesn't exist on
Eagle3Config; change reading of the dtype in __init__ to use config.torch_dtype
or a fallback (e.g. getattr(config, "torch_dtype", None) or torch.float32) and
store it to self._dtype, and ensure forward still uses self._dtype when creating
the mock target_hidden_states; also consider accepting an override from kwargs
(e.g. kwargs.get("torch_dtype")) so callers can supply dtype via model_kwargs if
needed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Status: In review

Development

Successfully merging this pull request may close these issues.

2 participants