Getting a RuntimeError when trying to sample a pde as petrinet #487

Closed
sabinala opened this issue Feb 21, 2024 · 2 comments · Fixed by #491
Labels
bug Something isn't working help wanted Extra attention is needed

Comments

@sabinala
Contributor

Getting the following error when calling result = pyciemss.sample(MODEL, end_time, logging_step_size, num_samples, start_time=start_time), where MODEL is an advection equation with a backward derivative from here. See this notebook related to PR #460. See also this PR in DARPA-ASKEM/Model-Representations.

RuntimeError: The size of tensor a (9) must match the size of tensor b (3) at non-singleton dimension 0

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/altu809/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py", line 10, in wrapped
    result = function(*args, **kwargs)
  File "/Users/altu809/Projects/pyciemss/pyciemss/interfaces.py", line 298, in sample
    samples = pyro.infer.Predictive(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 273, in forward
    return _predictive(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 137, in _predictive
    trace = poutine.trace(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 198, in get_trace
    self(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 180, in __call__
    raise exc from e
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/Users/altu809/Projects/pyciemss/pyciemss/interfaces.py", line 282, in wrapped_model
    full_trajectory = model(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/nn/module.py", line 449, in __call__
    result = super().__call__(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/altu809/Projects/pyciemss/pyciemss/compiled_dynamics.py", line 77, in forward
    simulate(self.deriv, self.initial_state(), start_time, end_time)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py", line 109, in _pyro_simulate
    state, start_time, next_interruption = simulate_to_interruption(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 89, in _pyro_simulate_to_interruption
    msg["value"] = torchdiffeq_simulate_to_interruption(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 248, in torchdiffeq_simulate_to_interruption
    value = simulate_point(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/trajectory.py", line 97, in _pyro_simulate_point
    trajectory: State[T] = simulate_trajectory(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 77, in _pyro_simulate_trajectory
    msg["value"] = torchdiffeq_simulate_trajectory(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 165, in torchdiffeq_simulate_trajectory
    return _torchdiffeq_ode_simulate_inner(dynamics, initial_state, timespan, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 71, in _torchdiffeq_ode_simulate_inner
    solns = _batched_odeint(  # torchdiffeq.odeint(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 123, in _batched_odeint
    yt_raw = torchdiffeq.odeint(func, y0_expanded, t, **odeint_kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py", line 77, in odeint
    solution = solver.integrate(t)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py", line 28, in integrate
    self._before_integrate(t)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/rk_common.py", line 163, in _before_integrate
    first_step = _select_initial_step(self.func, t[0], self.y0, self.order - 1, self.rtol, self.atol,
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py", line 54, in _select_initial_step
    d1 = norm(f0 / scale)
RuntimeError: The size of tensor a (9) must match the size of tensor b (3) at non-singleton dimension 0
                                       Trace Shapes:    
                                        Param Sites:    
        numeric_initial_state_func$$$_nodes.0._value    
        numeric_initial_state_func$$$_nodes.1._value    
        numeric_initial_state_func$$$_nodes.2._value    
numeric_deriv_func$$$_nodes.0._args.0._args.0._value    
                                       Sample Sites:    
                                  persistent_dx dist 3 |
                                               value 3 |
                                   persistent_u dist 3 |
                                               value 3 |
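The message comes from PyTorch's broadcasting check. Per the traceback, torchdiffeq fails at d1 = norm(f0 / scale) inside _select_initial_step, i.e. while dividing the derivative evaluated at the initial state (9 entries here) by a scale that has the state's length (3 entries, matching the three initial-state nodes in the trace shapes). The pure-Python sketch below mimics the broadcasting rule that PyTorch shares with NumPy; it is an illustration, not PyTorch's implementation, and broadcast_shape is a hypothetical helper. It also shows one way 9 entries can appear: a (3, 1) shape combined with a (3,) shape broadcasts to (3, 3), i.e. 9 elements once flattened. Whether that is what happens inside this model is an assumption the traceback does not confirm.

```python
def broadcast_shape(shape_a, shape_b):
    """Return the broadcast shape of two tensor shapes, or raise RuntimeError.

    Illustrative sketch of the NumPy/PyTorch rule: align the shapes from the
    right; each pair of sizes must be equal, or one of them must be 1.
    """
    ndim = max(len(shape_a), len(shape_b))
    # Left-pad the shorter shape with 1s so both shapes have the same rank.
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    result = []
    for dim, (size_a, size_b) in enumerate(zip(a, b)):
        if size_a == size_b or size_b == 1:
            result.append(size_a)
        elif size_a == 1:
            result.append(size_b)
        else:
            # Same wording as the RuntimeError in the traceback above.
            raise RuntimeError(
                f"The size of tensor a ({size_a}) must match the size of "
                f"tensor b ({size_b}) at non-singleton dimension {dim}"
            )
    return tuple(result)


# A column of 3 combined with a row of 3 silently becomes 3 x 3 = 9 elements.
print(broadcast_shape((3, 1), (3,)))  # (3, 3)

# Shapes from the traceback: derivative f0 has 9 entries, scale has 3.
try:
    broadcast_shape((9,), (3,))
except RuntimeError as err:
    print(err)  # The size of tensor a (9) must match the size of tensor b (3) at non-singleton dimension 0
```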
sabinala added the bug and help wanted labels on Feb 21, 2024
@SamWitty
Contributor

@sabinala, could you create a PR that adds this model to the tests? Thanks!
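One check a regression test for this model could include is that the derivative returns one value per state variable before any solver runs, since (per the traceback) a mismatch otherwise only surfaces later as a broadcasting error inside torchdiffeq. The sketch below is a hypothetical, pure-Python stand-in for such a check; check_derivative_shape, good_deriv and bad_deriv are illustrative names, not pyciemss APIs. good_deriv is a first-order backward-difference advection right-hand side, du_i/dt = -c (u_i - u_{i-1}) / dx, with assumed values c = 1, dx = 1 and a zero inflow value; bad_deriv imitates one possible bug in which each state derivative picks up a length-3 parameter vector, giving 3 × 3 = 9 entries.

```python
def check_derivative_shape(deriv, y0, t0=0.0):
    """Fail fast if deriv(t0, y0) does not return one value per state entry."""
    f0 = list(deriv(t0, y0))
    if len(f0) != len(y0):
        raise ValueError(
            f"derivative has {len(f0)} entries but the state has {len(y0)}"
        )
    return f0


def good_deriv(t, u):
    # Backward-difference advection, assuming c = 1, dx = 1 and zero inflow:
    # du_i/dt = -(u_i - u_{i-1}), with u_{-1} = 0.
    left = [0.0] + list(u[:-1])
    return [-(ui - ul) for ui, ul in zip(u, left)]


def bad_deriv(t, u):
    # Hypothetical bug: every state derivative is scaled by a whole length-3
    # parameter vector, so the flattened output has 3 * 3 = 9 entries.
    k = [1.0, 2.0, 3.0]
    return [ki * ui for ui in u for ki in k]


state = [1.0, 0.5, 0.0]
print(check_derivative_shape(good_deriv, state))  # [-1.0, 0.5, 0.5]

try:
    check_derivative_shape(bad_deriv, state)
except ValueError as err:
    print(err)  # derivative has 9 entries but the state has 3
```

Checking lengths up front trades a tiny extra derivative evaluation for an error message that names the actual problem rather than the solver's internal division.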

@sabinala
Contributor Author

@SamWitty see #490.


2 participants