-
Notifications
You must be signed in to change notification settings - Fork 223
Open
Description
Problem
I'm seeing a serialization error when saving a Meridian model with custom priors using:
import numpy as np
import tensorflow_probability as tfp
from schema.serde.meridian_serde import save_meridian
from meridian.model import prior_distribution, spec, model

prior_roi_m = prior_distribution.lognormal_dist_from_mean_std(
    np.asarray([...], dtype=np.float32),
    np.asarray([...], dtype=np.float32),
)
model_spec = spec.ModelSpec(
    prior=prior_distribution.PriorDistribution(roi_m=prior_roi_m)
)
meridian = model.Meridian(input_data=..., model_spec=model_spec)
meridian.sample_prior(1000)
save_meridian(meridian, "./model.binpb")

Calling save_meridian() on a trained Meridian model raises a TypeError originating from DistributionSerde._to_parameter_value_proto().
The serializer encounters a numpy.ndarray inside a distribution's parameters and fails with TypeError: Unsupported type: <class 'numpy.ndarray'>, [...]
Environment
- Meridian version: 1.4.0
- Python: 3.13
Traceback
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[30], line 1
----> 1 save_meridian(meridian, "./model.binpb")
File /opt/conda/envs/mmm-se/lib/python3.13/site-packages/schema/serde/meridian_serde.py:357, in save_meridian(mmm, file_path, distribution_function_registry, eda_function_registry)
353 _make_dirs(os.path.dirname(file_path))
355 with _file_open(file_path, 'wb') as f:
356 # Creates an MmmKernel.
--> 357 serialized_kernel = MeridianSerde().serialize(
358 mmm,
359 distribution_function_registry=distribution_function_registry,
360 eda_function_registry=eda_function_registry,
361 )
362 if file_path.endswith('.binpb'):
363 f.write(serialized_kernel.SerializeToString())
File /opt/conda/envs/mmm-se/lib/python3.13/site-packages/schema/serde/meridian_serde.py:114, in MeridianSerde.serialize(self, obj, model_id, meridian_version, include_convergence_info, distribution_function_registry, eda_function_registry)
103 distribution_registry = (
104 distribution_function_registry
105 if distribution_function_registry is not None
106 else function_registry_utils.FunctionRegistry()
107 )
109 eda_function_registry = (
110 eda_function_registry
111 if eda_function_registry is not None
112 else function_registry_utils.FunctionRegistry()
113 )
--> 114 meridian_model_proto = self._make_meridian_model_proto(
115 mmm=obj,
116 model_id=model_id,
117 meridian_version=meridian_version,
118 distribution_function_registry=distribution_registry,
119 eda_function_registry=eda_function_registry,
120 include_convergence_info=include_convergence_info,
121 )
122 any_model = any_pb2.Any()
123 any_model.Pack(meridian_model_proto)
File /opt/conda/envs/mmm-se/lib/python3.13/site-packages/schema/serde/meridian_serde.py:164, in MeridianSerde._make_meridian_model_proto(self, mmm, model_id, meridian_version, distribution_function_registry, eda_function_registry, include_convergence_info)
131 def _make_meridian_model_proto(
132 self,
133 mmm: model.Meridian,
(...) 138 include_convergence_info: bool = False,
139 ) -> meridian_pb.MeridianModel:
140 """Constructs a MeridianModel proto from the TrainedModel.
141
142 Args:
(...) 154 A MeridianModel proto.
155 """
156 model_proto = meridian_pb.MeridianModel(
157 model_id=model_id,
158 model_version=str(meridian_version),
159 hyperparameters=hyperparameters.HyperparametersSerde().serialize(
160 mmm.model_spec
161 ),
162 prior_tfp_distributions=distribution.DistributionSerde(
163 distribution_function_registry
--> 164 ).serialize(mmm.model_spec.prior),
165 inference_data=inference_data.InferenceDataSerde().serialize(
166 mmm.inference_data
167 ),
168 )
169 # For backwards compatibility, only serialize EDA spec if it exists.
170 if hasattr(mmm, 'eda_spec'):
File /opt/conda/envs/mmm-se/lib/python3.13/site-packages/schema/serde/distribution.py:74, in DistributionSerde.serialize(self, obj)
71 if not hasattr(obj, param):
72 continue
73 getattr(proto, param).CopyFrom(
---> 74 self._to_distribution_proto(getattr(obj, param))
75 )
76 proto.function_registry.update(self.function_registry.hashed_registry)
77 return proto
File /opt/conda/envs/mmm-se/lib/python3.13/site-packages/schema/serde/distribution.py:163, in DistributionSerde._to_distribution_proto(self, dist)
158 dist_name = type(dist).__name__
159 dist_class = getattr(backend.tfd, dist_name)
160 return meridian_pb.TfpDistribution(
161 distribution_type=dist_name,
162 parameters={
--> 163 name: self._to_parameter_value_proto(name, value, dist_class)
164 for name, value in dist.parameters.items()
165 },
166 )
File /opt/conda/envs/mmm-se/lib/python3.13/site-packages/schema/serde/distribution.py:262, in DistributionSerde._to_parameter_value_proto(self, param_name, value, dist)
255 raise ValueError(
256 f"Custom function `{param_name}` detected for"
257 f" {type(dist).__name__}, but not found in registry. Please"
258 " add custom functions to registry when saving models."
259 )
261 # Handle unsupported types.
--> 262 raise TypeError(f"Unsupported type: {type(value)}, {value}")
TypeError: Unsupported type: <class 'numpy.ndarray'>, [ 0.22160271 0.22160271 0.22160271 0.22160271 0.22160271 0.22160271
-0.00124847 0.22160271 0.22160271 0.22160271 0.22160271 0.22160271
0.22160271 0.3527848 0.22160271 0.22160271 0.22160271 0.5815754
0.22160271]
Metadata
Assignees
Labels
No labels