Skip to content

Commit

Permalink
Precompile fn in ZarrChain
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Jan 9, 2025
1 parent 4eebae1 commit d2c6c42
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
17 changes: 12 additions & 5 deletions pymc/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Mapping, MutableMapping, Sequence
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from typing import Any

import arviz as az
Expand Down Expand Up @@ -91,10 +91,11 @@ def __init__(
vars: Sequence[TensorVariable] | None = None,
test_point: dict[str, np.ndarray] | None = None,
draws_per_chunk: int = 1,
fn: Callable | None = None,
):
if not _zarr_available:
raise RuntimeError("You must install zarr to be able to create ZarrChain instances")
super().__init__(name="zarr", model=model, vars=vars, test_point=test_point)
super().__init__(name="zarr", model=model, vars=vars, test_point=test_point, fn=fn)
self._step_method: BlockedStep | CompoundStep | None = None
self.unconstrained_variables = {
var.name for var in self.vars if is_transformed_name(var.name)
Expand Down Expand Up @@ -168,7 +169,7 @@ def record(
:meth:`~ZarrChain.flush`
"""
unconstrained_variables = self.unconstrained_variables
for var_name, var_value in zip(self.varnames, self.fn(draw)):
for var_name, var_value in zip(self.varnames, self.fn(**draw)):
if var_name in unconstrained_variables:
self.buffer(group="unconstrained_posterior", var_name=var_name, value=var_value)
else:
Expand Down Expand Up @@ -452,13 +453,18 @@ def init_trace(
)
self.vars = [var for var in vars if var.name in self.varnames]

self.fn = model.compile_fn(self.vars, inputs=model.value_vars, on_unused_input="ignore")
self.fn = model.compile_fn(
self.vars,
inputs=model.value_vars,
on_unused_input="ignore",
point_fn=False,
)

# Get variable shapes. Most backends will need this
# information.
if test_point is None:
test_point = model.initial_point()
var_values = list(zip(self.varnames, self.fn(test_point)))
var_values = list(zip(self.varnames, self.fn(**test_point)))
self.var_dtype_shapes = {
var: (value.dtype, value.shape)
for var, value in var_values
Expand Down Expand Up @@ -528,6 +534,7 @@ def init_trace(
test_point=test_point,
stats_bijection=StatsBijection(step.stats_dtypes),
draws_per_chunk=self.draws_per_chunk,
fn=self.fn,
)
for _ in range(chains)
]
Expand Down
6 changes: 4 additions & 2 deletions pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,10 @@ def __init__(
zarr_chains: list[ZarrChain] | bytes | None = None,
zarr_chains_is_pickled: bool = False,
):
# For some strange reason, spawn multiprocessing doesn't copy the rng
# seed sequence, so we have to rebuild it from scratch
# Because of https://github.com/numpy/numpy/issues/27727, we can't send
# the rng instance to the child process because pickling (copying) loses
# the seed sequence state information. For this reason, we send a
# RandomGeneratorState instead.
rng = random_generator_from_state(rng_state)
self._msg_pipe = msg_pipe
self._step_method = step_method
Expand Down

0 comments on commit d2c6c42

Please sign in to comment.