@yhtang
Contributor

@yhtang yhtang commented Dec 5, 2025

Consolidates the definition of the default tensor dtype in refitting specs via the dtype="bfloat16" keyword argument of make_mapping(...) in models/__init__.py. Since all current refitting specs are defined via make_mapping, this gives us a single source of truth for the default tensor dtype. This change should not introduce any visible functional changes for existing models.
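For readers skimming the thread, here is a minimal sketch of what a keyword-only dtype default on make_mapping could look like. Only the signature is taken from the diff in this PR. The Mapping dataclass, the dot-joined prefixes, and the example parameter names are assumptions made for illustration, not the actual code in models/__init__.py.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple

DEFAULT_DTYPE = "bfloat16"  # the default named in the PR description


@dataclass(frozen=True)
class Mapping:
    """Hypothetical stand-in for one refitting-spec entry."""
    jax_name: str
    vllm_name: str
    vllm_shape: Tuple[int, ...]
    dtype: str
    transform: Optional[Callable] = None


def make_mapping(
    jax_name, vllm_name, vllm_shape, *,
    transform=None, jax_prefix="model", vllm_prefix="model", dtype=DEFAULT_DTYPE,
):
    # Assumption: prefixes are joined to names with a dot.
    return Mapping(
        jax_name=f"{jax_prefix}.{jax_name}",
        vllm_name=f"{vllm_prefix}.{vllm_name}",
        vllm_shape=tuple(vllm_shape),
        dtype=dtype,
        transform=transform,
    )


# Callers that omit dtype inherit the single default; others override it.
default_spec = make_mapping("embed.weight", "embed_tokens.weight", (32000, 4096))
fp32_spec = make_mapping("norm.scale", "norm.weight", (4096,), dtype="float32")
```

Because every spec goes through this one function, changing the default later would be a one-line edit in application code.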

Copilot AI left a comment

Pull request overview

This PR refactors the default dtype handling by moving the default value from Python code to the protobuf definition. Instead of using fallback logic (param.vllm_param.dtype or 'bfloat16') in the Python code, the default is now specified directly in the proto file, simplifying the code and making the default more explicit.

Key changes:

  • Added default value 'bfloat16' to the dtype field in the VllmParam message definition
  • Removed the `or 'bfloat16'` fallback logic from the Python code in four locations
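The fallback pattern named in this overview can be sketched as follows. The VllmParam class here is a hypothetical plain-Python stand-in for the generated protobuf message, and resolve_dtype is an illustrative helper name, not a function from extension.py. The sketch relies on one fact: an unset protobuf string field with no explicit default reads as the empty string, which is falsy in Python.

```python
from dataclasses import dataclass


@dataclass
class VllmParam:
    """Hypothetical stand-in for the generated VllmParam message."""
    name: str = ""
    dtype: str = ""  # an unset string field reads as ""


def resolve_dtype(vllm_param, default="bfloat16"):
    # "" is falsy, so `or` substitutes the default when dtype is unset.
    return vllm_param.dtype or default


unset = resolve_dtype(VllmParam(name="w"))
explicit = resolve_dtype(VllmParam(name="w", dtype="float16"))
```

Repeating this expression at each call site is what scatters the default across the code base.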

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| jax-inference-offloading/jax_inference_offloading/api/param_mapping.proto | Added default value for the dtype field in VllmParam message |
| jax-inference-offloading/jax_inference_offloading/vllm/extension.py | Removed fallback logic for dtype in update_weights and update_weights_grouped methods |

yhtang and others added 2 commits December 5, 2025 00:44
@jreiffers
Member

Why?

@yhtang
Contributor Author

yhtang commented Dec 5, 2025

Why?

Wouldn't it be better to have a single source of truth for the model's default dtype?

@jreiffers
Member

Default values in protos are a bit of an anti-pattern (they even removed the feature completely in proto3). Once you put them in, you can never remove or change them again. I think they're acceptable when there's an obviously meaningful default, but I don't think that's the case here. I'd keep it in the application logic.

@yhtang yhtang changed the title from "[jax-inference-offloading] Move default dtype to proto" to "[jax-inference-offloading] consolidate definitions for default tensor dtype" on Dec 10, 2025
@yhtang
Contributor Author

yhtang commented Dec 10, 2025

Default values in protos are a bit of an anti-pattern (they even removed the feature completely in proto3). Once you put them in, you can never remove or change them again. I think they're acceptable when there's an obviously meaningful default, but I don't think that's the case here. I'd keep it in the application logic.

That makes sense. I have removed the default value for dtype from the proto file. Since all refitting specs are currently defined using make_mapping, would it be sufficient to rely on the default value there instead? Assuming that is OK, I've updated the PR title and description accordingly.


  def make_mapping(
-   jax_name, vllm_name, vllm_shape, *, transform=None, jax_prefix="model", vllm_prefix="model"
+   jax_name, vllm_name, vllm_shape, *, transform=None, jax_prefix="model", vllm_prefix="model", dtype="bfloat16"
Member

s/dtype/vllm_dtype/?

Contributor Author

@yhtang yhtang Dec 11, 2025

At the moment we don’t support any dtype conversion between the JAX and vLLM sides, so only vllm_param carries a dtype field, and the dtypes are expected to match between JAX and vLLM. Once we add conversion support, it may even make sense to stop specifying dtype in make_mapping altogether and instead rely on the handshake to discover the dtype at runtime.
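The expectation that dtypes match on both sides could be enforced with a small guard like the one below. This is a sketch only: check_dtype_match is a hypothetical helper that is not part of this PR, and it compares dtype names as plain strings rather than using any JAX or vLLM API.

```python
def check_dtype_match(param_name, jax_dtype, vllm_dtype):
    """Hypothetical guard: raise if the two sides disagree on dtype.

    Dtypes are compared by name, since no conversion is supported.
    """
    if str(jax_dtype) != str(vllm_dtype):
        raise TypeError(
            f"{param_name}: JAX dtype {jax_dtype!s} does not match "
            f"vLLM dtype {vllm_dtype!s}; dtype conversion is not supported"
        )


# Matching dtypes pass silently.
check_dtype_match("model.embed.weight", "bfloat16", "bfloat16")

# A mismatch is reported instead of being silently converted.
try:
    check_dtype_match("model.norm.weight", "float32", "bfloat16")
    mismatch_error = None
except TypeError as exc:
    mismatch_error = str(exc)
```

If runtime discovery through the handshake is added later, a check like this would become the natural place to trigger a conversion instead of raising.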

3 participants