
Commit 673ed46

soulitzer authored and facebook-github-bot committed
Gradcheck small fixes (pytorch#53916)
Summary: Pull Request resolved: pytorch#53916

This PR fixes some bugs that were made clearer by the previous refactor:
- make sure gradcheck returns False when it is supposed to fail and raise_exception=False
- make sure that when test_batched_grad fails, it returns False when raise_exception=False

Removing checkIfNumericalAnalyticAreClose made sense here because underneath it is really just calling `torch.allclose`, and using that directly instead of going through another opaque helper makes the code clearer.

TODO:
- ~add a test to verify that when torch.allclose fails, we indeed return False~
- ~uncomment the test from the previous PR~

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D27201692

Pulled By: soulitzer

fbshipit-source-id: 8b8dc37c59edb7eebc2e8db6f8839ce98a81d78b
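For context, a minimal standalone sketch of the contract this PR pins down (not part of the PR itself; the `BadSquare` helper below is purely illustrative): a function with a deliberately wrong backward should make `gradcheck` raise by default, and simply return `False` when `raise_exception=False`.

```python
import torch
from torch.autograd import gradcheck

class BadSquare(torch.autograd.Function):
    """Computes x ** 2 but returns a deliberately wrong gradient (3x instead of 2x)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return grad_out * 3 * x  # wrong on purpose; the true derivative is 2 * x

x = torch.ones(3, dtype=torch.double, requires_grad=True)

# Default behaviour: the analytical/numerical Jacobian mismatch raises.
try:
    gradcheck(BadSquare.apply, (x,))
except RuntimeError as err:
    print("gradcheck raised:", type(err).__name__)

# With raise_exception=False the same failure is reported as a plain False,
# which is the behaviour this PR makes consistent across gradcheck's sub-checks.
assert not gradcheck(BadSquare.apply, (x,), raise_exception=False)
```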
1 parent 796be04 commit 673ed46

File tree

- test/test_autograd.py
- test/test_overrides.py
- torch/autograd/gradcheck.py

3 files changed: +55 -21 lines changed

test/test_autograd.py

Lines changed: 34 additions & 0 deletions
@@ -4094,6 +4094,8 @@ def test_gradcheck_check_batched_grad(self):
         # runtime error while compute batched grad (print big error)
         with self.assertRaisesRegex(RuntimeError, 'gradcheck or gradgradcheck failed while testing batched gradient'):
             gradcheck(lambda x: x.to_dense(), (x,), check_sparse_nnz=True, check_batched_grad=True)
+        self.assertFalse(gradcheck(lambda x: x.to_dense(), (x,), check_sparse_nnz=True, check_batched_grad=True,
+                                   raise_exception=False))
 
     def test_gradcheck_backward_mul_by_grad_output(self):
         # when grad_input is sparse and has incorrect sparse_dim/dense_dim
@@ -4161,6 +4163,38 @@ def hook(x):
             gradcheck(fn, (x,))
         self.assertFalse(gradcheck(fn, (x,), raise_exception=False))
 
+    def test_gradcheck_jacobian_mismatch(self):
+        def fn(x):  # R -> R, C -> C
+            y = x.clone()
+            y.register_hook(lambda x: x + 1e-2)
+            return y
+        x = torch.ones(2, 2, requires_grad=True)
+        with self.assertRaisesRegex(RuntimeError, 'Jacobian mismatch for output 0 with respect to input 0'):
+            gradcheck(fn, (x,))
+        self.assertFalse(gradcheck(fn, (x,), raise_exception=False))
+
+        x_c = torch.ones(2, 2, requires_grad=True, dtype=torch.complex128)
+        with self.assertRaisesRegex(RuntimeError, 'Gradients failed to compare equal for grad output = 1j'):
+            gradcheck(fn, (x_c,))
+        self.assertFalse(gradcheck(fn, (x_c,), raise_exception=False))
+
+        def fn2(x):  # R -> C
+            y = torch.complex(x, x)
+            y.register_hook(lambda x: x + 1e-2)
+            return y
+        x = torch.ones(2, 2, requires_grad=True)
+        with self.assertRaisesRegex(RuntimeError, 'Gradients failed to compare equal for grad output = 1j'):
+            gradcheck(fn2, (x,))
+        self.assertFalse(gradcheck(fn2, (x,), raise_exception=False))
+
+        def fn3(x):  # C -> R
+            y = torch.real(x)
+            y.register_hook(lambda x: x + 1e-2)
+            return y
+        with self.assertRaisesRegex(RuntimeError, 'Gradients failed to compare equal for grad output = 1'):
+            gradcheck(fn3, (x_c,))
+        self.assertFalse(gradcheck(fn3, (x_c,), raise_exception=False))
+
 
     def test_version_counter(self):
         x = torch.randn(1, 2)
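The new test induces the mismatch by registering a backward hook that perturbs the gradient while leaving the forward pass untouched. A small standalone sketch of that trick (names here are illustrative, not taken from the test suite):

```python
import torch
from torch.autograd import gradcheck

def fn(x):
    y = x.clone()
    # The hook only fires during backward: it shifts every incoming gradient
    # by 1e-2, so the analytical Jacobian drifts away from the numerical one
    # that gradcheck estimates via finite differences of the unchanged forward.
    y.register_hook(lambda grad: grad + 1e-2)
    return y

x = torch.ones(2, 2, dtype=torch.double, requires_grad=True)
print(gradcheck(fn, (x,), raise_exception=False))  # expected: False
```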

test/test_overrides.py

Lines changed: 0 additions & 1 deletion
@@ -804,7 +804,6 @@ def test_gradcheck(self):
         # Tensor-likes.
         self.assertEqual(total_used_attrs, {
             'data',
-            'device',
             'dtype',
             'is_complex',
             'is_floating_point',

torch/autograd/gradcheck.py

Lines changed: 21 additions & 20 deletions
@@ -318,7 +318,7 @@ def get_failed_batched_grad_test_msg(output_idx, input_idx, res, exp):
 """.strip()
 
 
-def test_batched_grad(fail_test, input, output, output_idx):
+def test_batched_grad(fail_test, input, output, output_idx) -> bool:
     # NB: test_batched_grad compares two autograd.grad invocations with a single
     # vmap(autograd.grad) invocation. It's not exactly a "gradcheck" in the
     # sense that we're not comparing an analytical jacobian with a numeric one,
@@ -358,6 +358,7 @@ def vjp(v):
         if torch.allclose(res, exp):
            continue
         return fail_test(get_failed_batched_grad_test_msg(output_idx, input_idx, res, exp))
+    return True
 
 
 def test_backward_mul_by_grad_output(fail_test, outputs, inputs, check_sparse_nnz) -> bool:
@@ -458,6 +459,11 @@ def _differentiable_outputs(x):
     return tuple(o for o in _as_tuple(x) if o.requires_grad)
 
 
+def get_notallclose_msg(analytical, numerical, output_idx, input_idx, error_str='') -> str:
+    return error_str + 'Jacobian mismatch for output %d with respect to input %d,\n' \
+                       'numerical:%s\nanalytical:%s\n' % (output_idx, input_idx, numerical, analytical)
+
+
 # Note [VarArg of Tensors]
 # ~~~~~~~~~~~~~~~~~~~~~~~~
 # 'func' accepts a vararg of tensors, which isn't expressable in the type system at the moment.
@@ -562,31 +568,26 @@ def fn(input):
                 return False
             numerical_from_imag_grad_out = get_numerical_jacobian(fn, tupled_inputs, eps=eps, grad_out=1j)
 
-        def checkIfNumericalAnalyticAreClose(a, n, j, error_str=''):
-            if not torch.allclose(a, n, rtol, atol):
-                return fail_test(error_str + 'Jacobian mismatch for output %d with respect to input %d,\n'
-                                 'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
-
         inp_tensors = iter_tensors(tupled_inputs, True)
 
         for j, (a, n, inp) in enumerate(zip(analytical, numerical, inp_tensors)):
             if a.numel() != 0 or n.numel() != 0:
-                if o.is_complex():
-                    # C -> C, R -> C
-                    a_imag_grad_out = analytical_from_imag_grad_out[j]
-                    n_imag_grad_out = numerical_from_imag_grad_out[j]
-                    checkIfNumericalAnalyticAreClose(a_imag_grad_out, n_imag_grad_out, j,
-                                                     "Gradients failed to compare equal for grad output = 1j. ")
-                if inp.is_complex():
-                    # C -> R, C -> C
-                    checkIfNumericalAnalyticAreClose(a, n, j,
-                                                     "Gradients failed to compare equal for grad output = 1. ")
-                else:
-                    # R -> R, R -> C
-                    checkIfNumericalAnalyticAreClose(a, n, j)
+                if o.is_complex():  # C -> C, R -> C
+                    if not torch.allclose(analytical_from_imag_grad_out[j], numerical_from_imag_grad_out[j], rtol, atol):
+                        return fail_test(get_notallclose_msg(analytical_from_imag_grad_out[j],
+                                                             numerical_from_imag_grad_out[j], i, j,
+                                                             "Gradients failed to compare equal for grad output = 1j. "))
+                if inp.is_complex():  # C -> R, C -> C
+                    if not torch.allclose(a, n, rtol, atol):
+                        return fail_test(get_notallclose_msg(a, n, i, j,
+                                                             "Gradients failed to compare equal for grad output = 1. "))
+                else:  # R -> R, R -> C
+                    if not torch.allclose(a, n, rtol, atol):
+                        return fail_test(get_notallclose_msg(a, n, i, j))
 
         if check_batched_grad:
-            test_batched_grad(fail_test, tupled_inputs, o, j)
+            if not test_batched_grad(fail_test, tupled_inputs, o, i):
+                return False
 
     if not test_backward_mul_by_grad_output(fail_test, outputs, tupled_inputs, check_sparse_nnz):
         return False
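The comparison that replaces checkIfNumericalAnalyticAreClose boils down to one `torch.allclose` call per (output, input) pair plus a formatted message. A rough standalone sketch of that pattern (the helper name, tolerances, and tensors below are illustrative, not the exact gradcheck internals):

```python
import torch

def jacobians_match(analytical, numerical, rtol=1e-3, atol=1e-5):
    """Return (ok, msg): ok is True when both Jacobians agree elementwise
    within rtol/atol, mirroring the allclose check gradcheck now inlines."""
    if torch.allclose(analytical, numerical, rtol, atol):
        return True, ''
    return False, ('Jacobian mismatch,\nnumerical:%s\nanalytical:%s\n'
                   % (numerical, analytical))

# A correct Jacobian of y = 2 * x versus one that is off by 1e-2 everywhere.
numerical = 2 * torch.eye(3, dtype=torch.double)
ok, msg = jacobians_match(numerical + 1e-2, numerical)
print(ok)   # False
print(msg)
```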
