Trying to use the `minimize` function with certain methods fails, but it succeeds with the other methods. Presumably, the other methods aren't computing Hessians. I tracked the error to here.

I can't quite figure out what is really responsible for the error, but I suspect it's `_vmap` failing to batch properly, because my debugger indicates there is something wrong with the tensors that are yielded. If I look at the `batched_inputs[0]` variable in `_vmap_internals._vmap` and try to print it, view it, or add 1 to it, I get the error `RuntimeError: Batching rule not implemented for aten::is_nonzero. We could not generate a fallback.`
Computing the Hessian in a loop works, but it is hideous and slow:
```python
# Build the Hessian row-by-row from Hessian-vector products: one
# autograd.grad call per basis vector in self._I (the identity matrix).
hvp_map = lambda V: torch.stack(
    [autograd.grad(grad, x, v, retain_graph=True)[0] for v in V], dim=0)
hess = hvp_map(self._I)
```
Is this a real issue, or am I missing something?
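For context, here is a minimal self-contained version of the loop workaround above. The toy objective `f`, the parameter `x`, and the identity matrix standing in for `self._I` are illustrative assumptions, not code from this repository:

```python
import torch
from torch import autograd

def f(x):
    # Toy scalar objective (assumption for illustration).
    return (x ** 2).sum() + x.prod()

x = torch.randn(5, requires_grad=True)

# create_graph=True so the gradient itself can be differentiated again.
grad = autograd.grad(f(x), x, create_graph=True)[0]

# One Hessian-vector product per identity basis vector; each result is
# one row of the Hessian. retain_graph keeps the graph alive across
# iterations of the loop.
I = torch.eye(x.numel())
hess = torch.stack(
    [autograd.grad(grad, x, v, retain_graph=True)[0] for v in I], dim=0)
```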
At this point, now that PyTorch 2 is out, we should remove all uses of the `_vmap` utility and replace them with forward-mode automatic differentiation. Jacobian & Hessian computation will be much more efficient.
I have added your ticket to the PyTorch 2 milestone. Help appreciated!
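For reference, a minimal sketch of what the forward-mode replacement could look like on PyTorch 2, using the `torch.func` API (the toy objective `f` is an assumption, not code from this repository):

```python
import torch
from torch.func import hessian, jacfwd, jacrev

def f(x):
    # Toy scalar objective (assumption for illustration).
    return (x ** 2).sum() + x.prod()

x = torch.randn(5)

# Forward-over-reverse: differentiate the reverse-mode gradient with
# forward-mode AD. One forward pass per Hessian column, no retain_graph
# loop over basis vectors.
hess = jacfwd(jacrev(f))(x)

# torch.func.hessian is a convenience wrapper for the same composition.
assert torch.allclose(hess, hessian(f)(x))
```

Forward-over-reverse is the usual choice for Hessians of an objective: the inner reverse pass handles the many-inputs-to-scalar shape efficiently, while the outer forward pass avoids materializing a second backward graph.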