AssertionError: Event handling for fixed step solvers currently requires step_size to be provided in options. #502

Closed
sabinala opened this issue Feb 26, 2024 · 3 comments · May be fixed by #503

@sabinala (Contributor):

I'm getting the error message below when I try:

def make_var_threshold(var: str, threshold: torch.Tensor):
    return lambda time, state: state[var] - threshold  
    
infection_threshold = make_var_threshold("I", torch.tensor(400.0))
dynamic_parameter_interventions1 = {infection_threshold: {"p_cbeta": torch.tensor(0.3)}}

result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time, 
                         dynamic_parameter_interventions=dynamic_parameter_interventions1, 
                         solver_method="euler")

Here, model3 = os.path.join(MODEL_PATH, "SIR_stockflow.json") and MODEL_PATH = "https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/".
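
For reference, a self-contained version of the failing call (a minimal sketch: the imports are filled in, and the values of start_time, end_time, logging_step_size, and num_samples are placeholder assumptions, as the originals were not shown):

import os
import torch
import pyciemss

MODEL_PATH = "https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/"
model3 = os.path.join(MODEL_PATH, "SIR_stockflow.json")

# Placeholder run settings, assumed for illustration only
start_time = 0.0
end_time = 100.0
logging_step_size = 1.0
num_samples = 10

def make_var_threshold(var: str, threshold: torch.Tensor):
    # Event function: crosses zero when state[var] reaches the threshold
    return lambda time, state: state[var] - threshold

infection_threshold = make_var_threshold("I", torch.tensor(400.0))
dynamic_parameter_interventions1 = {infection_threshold: {"p_cbeta": torch.tensor(0.3)}}

# Fails: "euler" is a fixed-step solver and no step_size is supplied
result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time,
                         dynamic_parameter_interventions=dynamic_parameter_interventions1,
                         solver_method="euler")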

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[6], line 10
      7 infection_threshold = make_var_threshold("I", torch.tensor(400.0))
      8 dynamic_parameter_interventions1 = {infection_threshold: {"p_cbeta": torch.tensor(0.3)}}
---> 10 result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time, 
     11                          dynamic_parameter_interventions=dynamic_parameter_interventions1, 
     12                          solver_method="euler")
     13 display(result["data"].head())
     15 # Plot the result

File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:29, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
     17 log_message = """
     18     ###############################
     19 
   (...)
     26     ################################
     27 """
     28 logging.exception(log_message, function.__name__, function.__doc__)
---> 29 raise e

File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:10, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
      8 try:
      9     start_time = time.perf_counter()
---> 10     result = function(*args, **kwargs)
     11     end_time = time.perf_counter()
     12     logging.info(
     13         "Elapsed time for %s: %f", function.__name__, end_time - start_time
     14     )

File ~/Projects/pyciemss/pyciemss/interfaces.py:298, in sample(model_path_or_json, end_time, logging_step_size, num_samples, noise_model, noise_model_kwargs, solver_method, solver_options, start_time, inferred_parameters, static_state_interventions, static_parameter_interventions, dynamic_state_interventions, dynamic_parameter_interventions)
    294         compiled_noise_model(full_trajectory)
    296 parallel = False if len(intervention_handlers) > 0 else True
--> 298 samples = pyro.infer.Predictive(
    299     wrapped_model,
    300     guide=inferred_parameters,
    301     num_samples=num_samples,
    302     parallel=parallel,
    303 )()
    305 return prepare_interchange_dictionary(samples)

File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:273, in Predictive.forward(self, *args, **kwargs)
    263     return_sites = None if not return_sites else return_sites
    264     posterior_samples = _predictive(
    265         self.guide,
    266         posterior_samples,
   (...)
    271         model_kwargs=kwargs,
    272     )
--> 273 return _predictive(
    274     self.model,
    275     posterior_samples,
    276     self.num_samples,
    277     return_sites=return_sites,
    278     parallel=self.parallel,
    279     model_args=args,
    280     model_kwargs=kwargs,
    281 )

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:78, in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
     67 def _predictive(
     68     model,
     69     posterior_samples,
   (...)
     75     model_kwargs={},
     76 ):
     77     model = torch.no_grad()(poutine.mask(model, mask=False))
---> 78     max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
     79     vectorize = pyro.plate(
     80         "_num_predictive_samples", num_samples, dim=-max_plate_nesting - 1
     81     )
     82     model_trace = prune_subsample_sites(
     83         poutine.trace(model).get_trace(*model_args, **model_kwargs)
     84     )

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:21, in _guess_max_plate_nesting(model, args, kwargs)
     15 """
     16 Guesses max_plate_nesting by running the model once
     17 without enumeration. This optimistically assumes static model
     18 structure.
     19 """
     20 with poutine.block():
---> 21     model_trace = poutine.trace(model).get_trace(*args, **kwargs)
     22 sites = [site for site in model_trace.nodes.values() if site["type"] == "sample"]
     24 dims = [
     25     frame.dim
     26     for site in sites
     27     for frame in site["cond_indep_stack"]
     28     if frame.vectorized
     29 ]

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.get_trace(self, *args, **kwargs)
    190 def get_trace(self, *args, **kwargs):
    191     """
    192     :returns: data structure
    193     :rtype: pyro.poutine.Trace
   (...)
    196     Calls this poutine and returns its trace instead of the function's return value.
    197     """
--> 198     self(*args, **kwargs)
    199     return self.msngr.get_trace()

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
    170 self.msngr.trace.add_node(
    171     "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs
    172 )
    173 try:
--> 174     ret = self.fn(*args, **kwargs)
    175 except (ValueError, RuntimeError) as e:
    176     exc_type, exc_value, traceback = sys.exc_info()

File ~/anaconda3/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)

File ~/Projects/pyciemss/pyciemss/interfaces.py:282, in sample.<locals>.wrapped_model()
    280         for handler in intervention_handlers:
    281             stack.enter_context(handler)
--> 282         full_trajectory = model(
    283             torch.as_tensor(start_time),
    284             torch.as_tensor(end_time),
    285             logging_times=logging_times,
    286             is_traced=True,
    287         )
    289 if noise_model is not None:
    290     compiled_noise_model = compile_noise_model(
    291         noise_model, vars=set(full_trajectory.keys()), **noise_model_kwargs
    292     )

File ~/anaconda3/lib/python3.10/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
    447 def __call__(self, *args, **kwargs):
    448     with self._pyro_context:
--> 449         result = super().__call__(*args, **kwargs)
    450     if (
    451         pyro.settings.get("validate_poutine")
    452         and not self._pyro_context.active
    453         and _is_module_local_param_enabled()
    454     ):
    455         self._check_module_local_param_usage()

File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Projects/pyciemss/pyciemss/compiled_dynamics.py:77, in CompiledDynamics.forward(self, start_time, end_time, logging_times, is_traced)
     75 if logging_times is not None:
     76     with LogTrajectory(logging_times) as lt:
---> 77         simulate(self.deriv, self.initial_state(), start_time, end_time)
     78         state = lt.trajectory
     79 else:

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:281, in effectful.<locals>._fn(*args, **kwargs)
    264 msg = {
    265     "type": type,
    266     "name": name,
   (...)
    278     "infer": infer,
    279 }
    280 # apply the stack and return its return value
--> 281 apply_stack(msg)
    282 return msg["value"]

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:212, in apply_stack(initial_msg)
    209 for frame in reversed(stack):
    210     pointer = pointer + 1
--> 212     frame._process_message(msg)
    214     if msg["stop"]:
    215         break

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:162, in Messenger._process_message(self, msg)
    160 method = getattr(self, "_pyro_{}".format(msg["type"]), None)
    161 if method is not None:
--> 162     return method(msg)
    163 return None

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py:109, in Solver._pyro_simulate(self, msg)
    106     if ph.priority > start_time:
    107         break
--> 109 state, start_time, next_interruption = simulate_to_interruption(
    110     possible_interruptions,
    111     dynamics,
    112     state,
    113     start_time,
    114     end_time,
    115     **msg["kwargs"],
    116 )
    118 if next_interruption is not None:
    119     dynamics, state = next_interruption.callback(dynamics, state)

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:281, in effectful.<locals>._fn(*args, **kwargs)
    264 msg = {
    265     "type": type,
    266     "name": name,
   (...)
    278     "infer": infer,
    279 }
    280 # apply the stack and return its return value
--> 281 apply_stack(msg)
    282 return msg["value"]

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:212, in apply_stack(initial_msg)
    209 for frame in reversed(stack):
    210     pointer = pointer + 1
--> 212     frame._process_message(msg)
    214     if msg["stop"]:
    215         break

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:162, in Messenger._process_message(self, msg)
    160 method = getattr(self, "_pyro_{}".format(msg["type"]), None)
    161 if method is not None:
--> 162     return method(msg)
    163 return None

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py:89, in TorchDiffEq._pyro_simulate_to_interruption(self, msg)
     87 interruptions, dynamics, initial_state, start_time, end_time = msg["args"]
     88 msg["kwargs"].update(self.odeint_kwargs)
---> 89 msg["value"] = torchdiffeq_simulate_to_interruption(
     90     interruptions,
     91     dynamics,
     92     initial_state,
     93     start_time,
     94     end_time,
     95     **msg["kwargs"],
     96 )
     97 msg["done"] = True

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py:244, in torchdiffeq_simulate_to_interruption(interruptions, dynamics, initial_state, start_time, end_time, **kwargs)
    234 def torchdiffeq_simulate_to_interruption(
    235     interruptions: List[Interruption[torch.Tensor]],
    236     dynamics: Dynamics[torch.Tensor],
   (...)
    240     **kwargs,
    241 ) -> Tuple[State[torch.Tensor], torch.Tensor, Optional[Interruption[torch.Tensor]]]:
    242     assert len(interruptions) > 0, "should have at least one interruption here"
--> 244     (next_interruption,), interruption_time = _torchdiffeq_get_next_interruptions(
    245         dynamics, initial_state, start_time, interruptions, **kwargs
    246     )
    248     value = simulate_point(
    249         dynamics, initial_state, start_time, interruption_time, **kwargs
    250     )
    251     return value, interruption_time, next_interruption

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py:192, in _torchdiffeq_get_next_interruptions(dynamics, start_state, start_time, interruptions, **kwargs)
    189 combined_event_f = torchdiffeq_combined_event_f(interruptions, var_order)
    191 # Simulate to the event execution.
--> 192 event_time, event_solutions = _batched_odeint(  # torchdiffeq.odeint_event(
    193     functools.partial(_deriv, dynamics, var_order),
    194     tuple(start_state[v] for v in var_order),
    195     start_time,
    196     event_fn=combined_event_f,
    197     **kwargs,
    198 )
    200 # event_state has both the first and final state of the interrupted simulation. We just want the last.
    201 event_solution: Tuple[torch.Tensor, ...] = tuple(
    202     s[..., -1] for s in event_solutions
    203 )  # TODO support event_dim > 0

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py:119, in _batched_odeint(func, y0, t, event_fn, **odeint_kwargs)
    112 y0_expanded = tuple(
    113     # y0_[(None,) * (len(y0_batch_shape) - (len(y0_.shape) - event_dim)) + (...,)]
    114     y0_.expand(y0_batch_shape + y0_.shape[len(y0_.shape) - event_dim :])
    115     for y0_ in y0
    116 )
    118 if event_fn is not None:
--> 119     event_t, yt_raw = torchdiffeq.odeint_event(
    120         func, y0_expanded, t, event_fn=event_fn, **odeint_kwargs
    121     )
    122 else:
    123     yt_raw = torchdiffeq.odeint(func, y0_expanded, t, **odeint_kwargs)

File ~/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py:101, in odeint_event(func, y0, t0, event_fn, reverse_time, odeint_interface, **kwargs)
     98 else:
     99     t = torch.cat([t0.reshape(-1), t0.reshape(-1).detach() + 1.0])
--> 101 event_t, solution = odeint_interface(func, y0, t, event_fn=event_fn, **kwargs)
    103 # Dummy values for rtol, atol, method, and options.
    104 shapes, _func, _, t, _, _, _, _, event_fn, _ = _check_inputs(func, y0, t, 0.0, 0.0, None, None, event_fn, SOLVERS)

File ~/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py:79, in odeint(func, y0, t, rtol, atol, method, options, event_fn)
     77     solution = solver.integrate(t)
     78 else:
---> 79     event_t, solution = solver.integrate_until_event(t[0], event_fn)
     80     event_t = event_t.to(t)
     81     if t_is_reversed:

File ~/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py:122, in FixedGridODESolver.integrate_until_event(self, t0, event_fn)
    121 def integrate_until_event(self, t0, event_fn):
--> 122     assert self.step_size is not None, "Event handling for fixed step solvers currently requires `step_size` to be provided in options."
    124     t0 = t0.type_as(self.y0)
    125     y0 = self.y0

AssertionError: Event handling for fixed step solvers currently requires `step_size` to be provided in options.
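
The assertion originates in torchdiffeq, not pyciemss: FixedGridODESolver.integrate_until_event (the last frame above) requires a step_size before it will do event handling with a fixed-step method. A minimal standalone sketch of the same failure and its fix, assuming only torch and torchdiffeq:

import torch
import torchdiffeq

def f(t, y):
    # Exponential decay: y(t) = y0 * exp(-t)
    return -y

def event_fn(t, y):
    # Event fires when y crosses 0.5, i.e. at t = ln(2) ~ 0.693
    return y - 0.5

y0 = torch.tensor([1.0])
t0 = torch.tensor(0.0)

# Raises the same AssertionError: no step_size for the fixed-step "euler" method
# torchdiffeq.odeint_event(f, y0, t0, event_fn=event_fn, method="euler")

# Passing step_size through options satisfies the assertion
event_t, solution = torchdiffeq.odeint_event(
    f, y0, t0, event_fn=event_fn, method="euler",
    options={"step_size": 1e-2},
)
print(event_t)  # ~0.69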
@djinnome (Contributor):

This fixes the problem:

def make_var_threshold(var: str, threshold: torch.Tensor):
    return lambda time, state: state[var] - threshold  
    
infection_threshold = make_var_threshold("I", torch.tensor(40.0))
dynamic_parameter_interventions1 = {infection_threshold: {"p_cbeta": torch.tensor(0.3)}}
# Specify solver options including the step_size
solver_options = {"step_size": 1e-2}  # Example step size, adjust as needed

result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time, 
                         dynamic_parameter_interventions=dynamic_parameter_interventions1, 
                         solver_method="euler",
                         solver_options=solver_options)

However, with dopri5 you do not need to specify step_size at all, and it is much faster and more stable:

def make_var_threshold(var: str, threshold: torch.Tensor):
    return lambda time, state: state[var] - threshold

infection_threshold = make_var_threshold("I", torch.tensor(400.0))
dynamic_parameter_interventions1 = {infection_threshold: {"p_cbeta": torch.tensor(0.3)}}

# No step_size (or solver_options) needed: dopri5 is an adaptive-step solver
result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time, 
                         dynamic_parameter_interventions=dynamic_parameter_interventions1, 
                         solver_method="dopri5")
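
If more accuracy is needed with dopri5, the tolerances can be tuned instead of a step size. A hedged sketch: rtol and atol are standard torchdiffeq odeint keywords, but it is an assumption here that pyciemss forwards solver_options to the underlying odeint call unchanged (the chirho TorchDiffEq frames in the traceback suggest solver options become odeint kwargs):

# Tolerance values here are illustrative, not recommendations
solver_options = {"rtol": 1e-7, "atol": 1e-9}

result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time,
                         dynamic_parameter_interventions=dynamic_parameter_interventions1,
                         solver_method="dopri5",
                         solver_options=solver_options)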

@sabinala (Contributor, Author):

I believe this issue is now resolved and can be closed.
