feat(jax): export call_lower to SavedModel via jax2tf #4254
base: devel
Conversation
Signed-off-by: Jinzhe Zeng <[email protected]>
📝 Walkthrough

This pull request introduces modifications across three files.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (3)
deepmd/backend/jax.py (1)

Line range hint 27-42: Document supported formats in class docstring.

Consider enhancing the class documentation to explicitly mention the supported file formats. This would help users understand which formats they can use with the JAX backend.

```diff
 class JAXBackend(Backend):
-    """JAX backend."""
+    """JAX backend.
+
+    Supports the following model formats:
+    - .jax: Native JAX format
+    - .savedmodel: TensorFlow SavedModel format
+    """
```

deepmd/jax/utils/serialization.py (2)
58-64: Clarify the usage of `polymorphic_shapes` in `jax2tf.convert`.

The `polymorphic_shapes` parameter uses variables like `nf`, `nloc`, and `nghost`. If these are placeholders for dynamic dimensions, consider providing documentation or comments explaining their meanings. This will enhance code readability and maintainability for future developers.
Line range hint 88-89: Update `serialize_from_file` to handle `.savedmodel` files.

Currently, `serialize_from_file` only supports `.jax` files and raises a `ValueError` for other formats. Since `deserialize_to_file` now handles `.savedmodel` files, consider adding support for serializing from `.savedmodel` files for consistency. Alternatively, update the error message to reflect the supported formats.

Example of updating the error message:

```diff
 else:
-    raise ValueError("JAX backend only supports converting .jax directory")
+    raise ValueError("JAX backend only supports converting .jax and .savedmodel formats")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
- deepmd/backend/jax.py (1 hunks)
- deepmd/dpmodel/descriptor/se_e2_a.py (1 hunks)
- deepmd/jax/utils/serialization.py (1 hunks)
🔇 Additional comments (6)
deepmd/backend/jax.py (1)
41-41: Verify SavedModel serialization implementation.

The addition of the `".savedmodel"` suffix aligns with the PR objective. Let's verify that the serialization hooks properly support this format.
✅ Verification successful
SavedModel serialization implementation verified
The JAX backend properly implements SavedModel format support:
- The backend correctly registers the `.savedmodel` suffix in the `JAXBackend` class
- The serialization implementation in `deepmd/jax/utils/serialization.py` includes:
  - Converting the JAX model to TensorFlow using `jax2tf.convert`
  - Saving in the TensorFlow SavedModel format using `tf.saved_model.save`
  - Proper shape specifications for the converted model
  - Support for custom gradients
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Verify that serialization hooks handle SavedModel format

# Test: Check if deserialize_to_file has SavedModel handling
rg -A 10 "deserialize_to_file.*savedmodel" deepmd/jax/utils/serialization.py

# Test: Check if serialize_from_file has SavedModel handling
rg -A 10 "serialize_from_file.*savedmodel" deepmd/jax/utils/serialization.py
```

Length of output: 157
Script:

```bash
#!/bin/bash
# Let's check for any SavedModel related code in serialization files
rg -l "SavedModel|savedmodel" --type py

# Check implementation details in serialization-related files
rg -A 10 "serialize|deserialize" --type py

# Look for any JAX to SavedModel conversion functions
ast-grep --pattern 'def $_($$) { $$$ savedmodel $$$ }'
```

Length of output: 456711
Script:

```bash
#!/bin/bash
# Let's check for SavedModel related code in JAX backend
rg -A 10 "SavedModel|savedmodel" deepmd/jax/utils/serialization.py deepmd/backend/jax.py

# Check if there are any JAX to TF conversion functions
rg -A 10 "jax2tf" deepmd/jax/utils/serialization.py deepmd/backend/jax.py

# Look for any model export/save functions
rg -A 10 "save.*model|export.*model" deepmd/jax/utils/serialization.py deepmd/backend/jax.py
```

Length of output: 5000
deepmd/dpmodel/descriptor/se_e2_a.py (2)
558-558: LGTM: Simplified array assignment.

The removal of the unnecessary `xp.asarray` conversion is appropriate since `self.sel_cumsum` is already suitable for indexing operations.
Line range hint 580-599: Performance optimization: Replaced einsum with broadcasting.

The replacement of `einsum` operations with equivalent `sum` and broadcasting operations is a good optimization. The changes maintain mathematical equivalence while potentially offering better performance on some hardware accelerators.

Original:

```python
gr_tmp = xp.einsum("lni,lnj->lij", gg, tr)
grrg = xp.einsum("flid,fljd->flij", gr, gr1)
```

New:

```python
gr_tmp = xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1)
grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)
```

Let's verify the numerical equivalence of these operations:
✅ Verification successful
Optimization change is mathematically equivalent and safe to use
The replacement of einsum with broadcasting and sum operations is mathematically sound for the following reasons:
- The original einsum operation `xp.einsum("lni,lnj->lij", gg, tr)` is equivalent to the new broadcasting operation `xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1)`, as they both perform the same contraction along the shared neighbor axis.
- Similarly, `xp.einsum("flid,fljd->flij", gr, gr1)` is equivalent to `xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)`, maintaining the same contraction pattern.
- The codebase shows consistent use of both einsum and broadcasting+sum patterns across different implementations (TensorFlow, PyTorch, and NumPy), indicating this is a well-understood transformation.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Search for any tests that verify the descriptor's output
rg -l "test.*descriptor.*se.*a" --type py
```

Length of output: 85
Script:

```bash
#!/bin/bash
# Let's examine the test file content to see if it verifies numerical equivalence
cat source/tests/tf/test_descrpt_se_a_mask.py
```

Length of output: 8099
Script:

```bash
#!/bin/bash
# Let's look for any other tests that might verify the descriptor output
rg -l "test.*descriptor.*se.*[aA]" --type py

# Also search for any numerical tests specifically related to einsum operations
rg "einsum|sum.*broadcast" --type py
```

Length of output: 3078
deepmd/jax/utils/serialization.py (3)
50-51: Verify that `call_lower` is a valid attribute of `model`.

The code assigns `call_lower = model.call_lower`. Please ensure that all models being deserialized have the `call_lower` attribute. If some models might not have this attribute, consider adding a check or handling potential `AttributeError` exceptions to prevent runtime errors.
68-73: Ensure correct tensor shapes in `tf.TensorSpec`.

The `tf.TensorSpec` definitions include dynamic dimensions (`None`) and a call to `model.get_nnei()`. Verify that `model.get_nnei()` returns an integer and that the tensor shapes align with the expected input dimensions. Misalignment can lead to runtime errors when the SavedModel is used.
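For reference, a minimal sketch of how dynamic dimensions mix with a concrete neighbor count in `tf.TensorSpec` (shapes, dtypes, and the `nnei` stand-in are illustrative, not taken from the PR; the static dimension must be a plain Python `int`, not a traced value):

```python
import tensorflow as tf

nnei = 120  # stand-in for model.get_nnei(); must be an int for a valid shape

specs = [
    tf.TensorSpec([None, None, 3], tf.float64),   # coordinates: (nf, nall, 3)
    tf.TensorSpec([None, None], tf.int32),        # atom types: (nf, nall)
    tf.TensorSpec([None, None, nnei], tf.int64),  # neighbor list: (nf, nloc, nnei)
]

assert specs[2].shape.as_list() == [None, None, 120]
```

If `model.get_nnei()` ever returned a NumPy scalar or tensor, wrapping it in `int(...)` before building the spec would be a cheap safeguard.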
76-79: Review the necessity of `experimental_custom_gradients=True`.

The option `experimental_custom_gradients=True` is used in `tf.saved_model.SaveOptions`. Confirm that custom gradients are required for your use case. If not, removing this option could simplify the code and avoid potential compatibility issues with future TensorFlow versions.
Codecov Report

Attention: Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##            devel    #4254      +/-   ##
==========================================
- Coverage   84.22%   84.21%   -0.01%
==========================================
  Files         548      548
  Lines       51426    51435       +9
  Branches     3051     3051
==========================================
+ Hits        43314    43317       +3
- Misses       7151     7160       +9
+ Partials      961      958       -3
```

☔ View full report in Codecov by Sentry.
Summary by CodeRabbit

New Features
- Added support for exporting JAX models to the TensorFlow SavedModel format via jax2tf.

Bug Fixes
- Removed a redundant conversion of the `sec` variable for improved clarity and performance in calculations.

Refactor
- Optimized computations in the `DescrptSeAArrayAPI` class by replacing tensor contraction with broadcasting.