[jax-inference-offloading] consolidate definitions for default tensor dtype #1816
base: main
Pull request overview
This PR refactors the default dtype handling by moving the default value from Python code to the protobuf definition. Instead of using fallback logic (param.vllm_param.dtype or 'bfloat16') in the Python code, the default is now specified directly in the proto file, simplifying the code and making the default more explicit.
Key changes:
- Added default value `'bfloat16'` to the `dtype` field in the `VllmParam` message definition
- Removed all fallback `or 'bfloat16'` logic from the Python code in four locations
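For context, the fallback pattern the overview refers to can be sketched roughly as follows. This is a minimal illustration, not the project's code: `VllmParam` here is a plain dataclass standing in for the generated protobuf message, and `DEFAULT_DTYPE` and `resolve_dtype` are hypothetical names. It relies on the fact that an unset protobuf string field with no declared default reads as the empty string in Python, which is falsy.

```python
from dataclasses import dataclass

# Assumption: constant name chosen for this sketch only.
DEFAULT_DTYPE = "bfloat16"


@dataclass
class VllmParam:
    """Hypothetical stand-in for the generated protobuf message."""

    name: str = ""
    dtype: str = ""  # an unset string field reads as ""


def resolve_dtype(param: VllmParam) -> str:
    # "" is falsy, so `or` substitutes the default when dtype is unset.
    return param.dtype or DEFAULT_DTYPE
```

Repeating `param.dtype or DEFAULT_DTYPE` at every call site is what the PR sets out to avoid; the discussion in this thread is about where that single default should live.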
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| jax-inference-offloading/jax_inference_offloading/api/param_mapping.proto | Added default value for the dtype field in VllmParam message |
| jax-inference-offloading/jax_inference_offloading/vllm/extension.py | Removed fallback logic for dtype in update_weights and update_weights_grouped methods |
jax-inference-offloading/jax_inference_offloading/api/param_mapping.proto
Co-authored-by: Copilot <[email protected]>
Why?
Would it be better to have a single source of truth for model's default dtype?
Default values in protos are a bit of an anti pattern (they even removed the feature completely in proto3). Once you put them in, you can never remove or change them again. I think they're acceptable when there's an obviously meaningful default, but I don't think that's the case here. I'd keep it in the application logic.
That makes sense. I have removed the default value for
  def make_mapping(
-   jax_name, vllm_name, vllm_shape, *, transform=None, jax_prefix="model", vllm_prefix="model"
+   jax_name, vllm_name, vllm_shape, *, transform=None, jax_prefix="model", vllm_prefix="model", dtype="bfloat16"
s/dtype/vllm_dtype/?
At the moment we don’t support any dtype conversion between the JAX and vLLM sides, so only vllm_param carries a dtype field, and the dtypes are expected to match between JAX and vLLM. Once we add conversion support, it may even make sense to stop specifying dtype in make_mapping altogether and instead rely on the handshake to discover the dtype at runtime.
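The "dtypes are expected to match" expectation could be made explicit with a check along these lines. This is only a sketch under assumptions: `check_dtypes_match` is a hypothetical helper, and dtypes are compared as plain strings rather than as JAX or torch dtype objects.

```python
def check_dtypes_match(name: str, jax_dtype: str, vllm_dtype: str) -> None:
    """Raise if the two sides disagree (hypothetical helper, not project code)."""
    if jax_dtype != vllm_dtype:
        raise ValueError(
            f"dtype mismatch for {name!r}: JAX side has {jax_dtype}, "
            f"vLLM side expects {vllm_dtype}; conversion is not supported"
        )
```

If dtype conversion or handshake-based discovery is added later, a check like this would be the natural place to relax or replace.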
Consolidates the definition of the default tensor dtype in refitting specs via the `dtype="bfloat16"` keyword argument of `make_mapping(...)` in `models/__init__.py`. Since all current refitting specs are defined via `make_mapping`, this gives us a single source of truth for the default tensor dtype. This change should not introduce any visible functional changes for existing models.
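As a rough illustration of the single-source-of-truth idea, the sketch below reuses the signature from the diff but with a hypothetical body: the real function builds protobuf messages, whereas this version returns plain dicts. Joining prefix and name with a dot is an assumption, and the parameter names in the usage lines are made up.

```python
def make_mapping(
    jax_name, vllm_name, vllm_shape, *,
    transform=None, jax_prefix="model", vllm_prefix="model", dtype="bfloat16",
):
    # Hypothetical body for illustration; not the project's implementation.
    return {
        "jax_param": {
            "name": f"{jax_prefix}.{jax_name}",  # assumed naming scheme
            "transform": transform,
        },
        "vllm_param": {
            "name": f"{vllm_prefix}.{vllm_name}",  # assumed naming scheme
            "shape": tuple(vllm_shape),
            "dtype": dtype,  # the only place the default is defined
        },
    }


# Illustrative parameter names; a spec picks up the default unless it overrides it.
default_spec = make_mapping("embed.embedding", "embed_tokens.weight", (32000, 4096))
override_spec = make_mapping("norm.scale", "norm.weight", (4096,), dtype="float32")
```

Because every spec goes through the one keyword default, consumers can read `dtype` directly without their own fallback.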