diff --git a/deepmd/backend/jax.py b/deepmd/backend/jax.py index bb2fba5a7c..3b2b0fcc56 100644 --- a/deepmd/backend/jax.py +++ b/deepmd/backend/jax.py @@ -38,7 +38,7 @@ class JAXBackend(Backend): # | Backend.Feature.NEIGHBOR_STAT ) """The features of the backend.""" - suffixes: ClassVar[list[str]] = [".jax"] + suffixes: ClassVar[list[str]] = [".jax", ".savedmodel"] """The suffixes of the backend.""" def is_available(self) -> bool: diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index d29ce8862e..ff8e401063 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -555,7 +555,7 @@ def call( coord_ext, atype_ext, nlist, self.davg, self.dstd ) nf, nloc, nnei, _ = rr.shape - sec = xp.asarray(self.sel_cumsum) + sec = self.sel_cumsum ng = self.neuron[-1] gr = xp.zeros([nf * nloc, ng, 4], dtype=self.dstd.dtype) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 43070f8a07..1a086326ce 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -39,6 +39,45 @@ def deserialize_to_file(model_file: str, data: dict) -> None: model_def_script=ocp.args.JsonSave(model_def_script), ), ) + elif model_file.endswith(".savedmodel"): + import tensorflow as tf + from jax.experimental import ( + jax2tf, + ) + + model = BaseModel.deserialize(data["model"]) + model_def_script = data["model_def_script"] + call_lower = model.call_lower + + my_model = tf.Module() + + # Save a function that can take scalar inputs. + my_model.call_lower = tf.function( + jax2tf.convert( + call_lower, + polymorphic_shapes=[ + "(nf, nloc + nghost, 3)", + "(nf, nloc + nghost)", + f"(nf, nloc, {model.get_nnei()})", + "(nf, np)", + "(nf, na)", + ], + ), + autograph=False, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int64), + tf.TensorSpec([None, None, model.get_nnei()], tf.int64), + tf.TensorSpec([None, None], tf.float64), + tf.TensorSpec([None, None], tf.float64), + ], + ) + my_model.model_def_script = model_def_script + tf.saved_model.save( + my_model, + model_file, + options=tf.saved_model.SaveOptions(experimental_custom_gradients=True), + ) else: raise ValueError("JAX backend only supports converting .jax directory")