Skip to content

Bug: TypeError when calling save_meridian (Meridian 1.4.0) #1364

@rickyhugo

Description

@rickyhugo

Problem

I'm seeing a serialization error when saving a Meridian model that uses custom priors. Minimal reproduction:

import numpy as np
import tensorflow_probability as tfp

from schema.serde.meridian_serde import save_meridian
from meridian.model import prior_distribution, spec, model

prior_roi_m = prior_distribution.lognormal_dist_from_mean_std(
    np.asarray([...], dtype=np.float32),
    np.asarray([...], dtype=np.float32),
)

model_spec = spec.ModelSpec(
    prior=prior_distribution.PriorDistribution(roi_m=prior_roi_m)
)

meridian = model.Meridian(input_data=..., model_spec=model_spec)
meridian.sample_prior(1000)

save_meridian(meridian, "./model.binpb")

Calling `save_meridian()` on a Meridian model built with these custom priors (in the example above, after calling `sample_prior()`) raises a `TypeError` originating from `DistributionSerde._to_parameter_value_proto()`.

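The serializer encounters a `numpy.ndarray` inside a distribution's parameters and fails with `TypeError: Unsupported type: <class 'numpy.ndarray'>, [...]`. The prior is created by passing NumPy arrays to `prior_distribution.lognormal_dist_from_mean_std`, and the resulting distribution appears to keep those arrays as `numpy.ndarray` values in `dist.parameters`. `_to_parameter_value_proto()` is called for each entry of `dist.parameters`, finds no branch that handles `numpy.ndarray`, and falls through to its final `raise TypeError(...)` (see the traceback below).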

Environment

  • Meridian version: 1.4.0
  • Python: 3.13

Traceback

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[30], line 1
----> 1 save_meridian(meridian, "./model.binpb")

File /opt/conda/envs/mmm-se/lib/python3.13/site-packages/schema/serde/meridian_serde.py:357, in save_meridian(mmm, file_path, distribution_function_registry, eda_function_registry)
    353   _make_dirs(os.path.dirname(file_path))
    355 with _file_open(file_path, 'wb') as f:
    356   # Creates an MmmKernel.
--> 357   serialized_kernel = MeridianSerde().serialize(
    358       mmm,
    359       distribution_function_registry=distribution_function_registry,
    360       eda_function_registry=eda_function_registry,
    361   )
    362   if file_path.endswith('.binpb'):
    363     f.write(serialized_kernel.SerializeToString())

File /opt/conda/envs/mmm-se/lib/python3.13/site-packages/schema/serde/meridian_serde.py:114, in MeridianSerde.serialize(self, obj, model_id, meridian_version, include_convergence_info, distribution_function_registry, eda_function_registry)
    103 distribution_registry = (
    104     distribution_function_registry
    105     if distribution_function_registry is not None
    106     else function_registry_utils.FunctionRegistry()
    107 )
    109 eda_function_registry = (
    110     eda_function_registry
    111     if eda_function_registry is not None
    112     else function_registry_utils.FunctionRegistry()
    113 )
--> 114 meridian_model_proto = self._make_meridian_model_proto(
    115     mmm=obj,
    116     model_id=model_id,
    117     meridian_version=meridian_version,
    118     distribution_function_registry=distribution_registry,
    119     eda_function_registry=eda_function_registry,
    120     include_convergence_info=include_convergence_info,
    121 )
    122 any_model = any_pb2.Any()
    123 any_model.Pack(meridian_model_proto)

File /opt/conda/envs/mmm-se/lib/python3.13/site-packages/schema/serde/meridian_serde.py:164, in MeridianSerde._make_meridian_model_proto(self, mmm, model_id, meridian_version, distribution_function_registry, eda_function_registry, include_convergence_info)
    131 def _make_meridian_model_proto(
    132     self,
    133     mmm: model.Meridian,
   (...)    138     include_convergence_info: bool = False,
    139 ) -> meridian_pb.MeridianModel:
    140   """Constructs a MeridianModel proto from the TrainedModel.
    141 
    142   Args:
   (...)    154     A MeridianModel proto.
    155   """
    156   model_proto = meridian_pb.MeridianModel(
    157       model_id=model_id,
    158       model_version=str(meridian_version),
    159       hyperparameters=hyperparameters.HyperparametersSerde().serialize(
    160           mmm.model_spec
    161       ),
    162       prior_tfp_distributions=distribution.DistributionSerde(
    163           distribution_function_registry
--> 164       ).serialize(mmm.model_spec.prior),
    165       inference_data=inference_data.InferenceDataSerde().serialize(
    166           mmm.inference_data
    167       ),
    168   )
    169   # For backwards compatibility, only serialize EDA spec if it exists.
    170   if hasattr(mmm, 'eda_spec'):

File /opt/conda/envs/mmm-se/lib/python3.13/site-packages/schema/serde/distribution.py:74, in DistributionSerde.serialize(self, obj)
     71   if not hasattr(obj, param):
     72     continue
     73   getattr(proto, param).CopyFrom(
---> 74       self._to_distribution_proto(getattr(obj, param))
     75   )
     76 proto.function_registry.update(self.function_registry.hashed_registry)
     77 return proto

File /opt/conda/envs/mmm-se/lib/python3.13/site-packages/schema/serde/distribution.py:163, in DistributionSerde._to_distribution_proto(self, dist)
    158 dist_name = type(dist).__name__
    159 dist_class = getattr(backend.tfd, dist_name)
    160 return meridian_pb.TfpDistribution(
    161     distribution_type=dist_name,
    162     parameters={
--> 163         name: self._to_parameter_value_proto(name, value, dist_class)
    164         for name, value in dist.parameters.items()
    165     },
    166 )

File /opt/conda/envs/mmm-se/lib/python3.13/site-packages/schema/serde/distribution.py:262, in DistributionSerde._to_parameter_value_proto(self, param_name, value, dist)
    255     raise ValueError(
    256         f"Custom function `{param_name}` detected for"
    257         f" {type(dist).__name__}, but not found in registry. Please"
    258         " add custom functions to registry when saving models."
    259     )
    261 # Handle unsupported types.
--> 262 raise TypeError(f"Unsupported type: {type(value)}, {value}")

TypeError: Unsupported type: <class 'numpy.ndarray'>, [ 0.22160271  0.22160271  0.22160271  0.22160271  0.22160271  0.22160271
 -0.00124847  0.22160271  0.22160271  0.22160271  0.22160271  0.22160271
  0.22160271  0.3527848   0.22160271  0.22160271  0.22160271  0.5815754
  0.22160271]

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions