Open
Description
The discriminator loss in the GAN is -err_real + err_fake.
When differential privacy is enabled, we can backpropagate the fake error (with hooks disabled) and the real error (with hooks enabled) separately.
However, the current code backpropagates err_fake and err_real. Shouldn't it backpropagate err_fake and -err_real instead, to match the loss used without differential privacy?
That is, when differential privacy is enabled for GANs, we should backpropagate minus the real error rather than the real error itself.
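To see why the sign matters, here is a minimal sketch using a hypothetical one-parameter linear discriminator D(x) = w * x, with gradients written out by hand instead of computed by PyTorch autograd. It relies only on the fact that separate backward() calls accumulate (sum) their gradients, so the two-pass DP update reproduces the gradient of -err_real + err_fake only when the real term is negated. The batches and the helper grad_mean_output are illustrative assumptions, not synthcity code.

```python
# Hypothetical toy discriminator D(x) = w * x with a single scalar weight w,
# chosen so that gradients can be written by hand (no autograd needed).
def grad_mean_output(xs):
    """Gradient w.r.t. w of mean(w * x) over a batch, which is just mean(x)."""
    return sum(xs) / len(xs)

real = [1.0, 2.0, 3.0]   # hypothetical batch of real samples
fake = [0.5, -0.5, 2.0]  # hypothetical batch of generator outputs

# Non-private path: a single backward pass on errD = -errD_real + errD_fake,
# where errD_real = mean(D(real)) and errD_fake = mean(D(fake)).
grad_combined = -grad_mean_output(real) + grad_mean_output(fake)

# DP path: two backward passes whose gradients are summed into w.grad.
grad_fake = grad_mean_output(fake)        # errD_fake.backward()
grad_real_pos = grad_mean_output(real)    # errD_real.backward()    (current code)
grad_real_neg = -grad_mean_output(real)   # (-errD_real).backward() (proposed fix)

grad_current = grad_fake + grad_real_pos  # accumulated gradient with the current code
grad_fixed = grad_fake + grad_real_neg    # accumulated gradient with the proposed fix

print("combined loss gradient:", grad_combined)
print("current DP gradient:   ", grad_current)
print("fixed DP gradient:     ", grad_fixed)
```

With the current sign, the accumulated gradient differs from the non-private one by 2 * mean(real), so the discriminator is pushed to lower its score on real data instead of raising it.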
How to Reproduce
- Go to 'synthcity/plugins/core/models/gan.py', line 452. Here, the positive real error is backpropagated. However, judging by the loss defined at line 438, I believe this should be -(err_real).
Expected Behavior
Change line 452 in 'synthcity/plugins/core/models/gan.py' from errD_real.backward() to (-errD_real).backward().