Improve usability
mtanghu committed Feb 4, 2022
1 parent f6afb8a commit 5c815ff
Showing 2 changed files with 48 additions and 27 deletions.
65 changes: 44 additions & 21 deletions dni.py
@@ -2,33 +2,33 @@
import torch.nn as nn
from torch.autograd import Variable

import traceback

# TODO: add help section for this!! -- maybe make a doc for this somewhere? like in a src file?
# TODO: assert somewhere the hidden state needs to have requires_grad=True, and explain why!
# TODO: figure out how to make an assert that will stop people from forgetting to set retain_grad to true
# TODO: add assert message to the prev.grad is not None that is helpful

# TODO: ADD A COMMENT ABOUT THE ACCUMULATED GRADIENTS!
# TODO: add help section for this!! -- maybe make a doc for this somewhere? like in a src file?
# just somewhere the parameters need to be explained
# TODO: think about and find out what happens with stacked GRUs and LSTMs

# TODO: add one last



class Synthesizer(nn.Module):
def __init__(self, size, is_lstm = True, hidden_layers = 1, factor = 1, allow_backwarding = False, activation = nn.GELU,
optimizer = torch.optim.Adam, lr = .0001, aux = True, device = 'cuda', use_improvement = True):
optimizer = torch.optim.Adam, lr = .0001, aux = False, use_improvement = True):
super().__init__()
self.factor = factor
self.device = device
self.is_lstm = is_lstm
self.use_improvement = use_improvement
self.allow_backwarding = allow_backwarding
self.device = 'cpu'

if self.is_lstm:
# the hidden state and cell state will be concatenated (as per paper) thus doubling gradient size
size = size * 2

layers = [nn.Linear(size, size)] + [nn.Sequential(activation(), nn.Linear(size, size)) for i in range(hidden_layers)]
self.synthesizer = nn.Sequential(*layers).to(self.device)
self.synthesizer = nn.Sequential(*layers)

# initialize last layer of synthesizer to only output 0 (as per paper)
# this will stop spurious gradient from being backwarded at the start
@@ -40,7 +40,7 @@ def __init__(self, size, is_lstm = True, hidden_layers = 1, factor = 1, allow_ba
self.aux = True
if self.aux:
aux_layers = [nn.Linear(size, size)] + [nn.Sequential(activation(), nn.Linear(size, size)) for i in range(hidden_layers)]
self.aux_layers = nn.Sequential(*aux_layers).to(self.device)
self.aux_layers = nn.Sequential(*aux_layers)

# this self.parameters() call picks up both the synthesizer and the aux layers registered above
self.optimizer = optimizer(self.parameters(), lr = lr)
@@ -51,40 +51,64 @@ def __init__(self, size, is_lstm = True, hidden_layers = 1, factor = 1, allow_ba
self.prev_hidden = None
self.prev_synth = None
self.synthetic_loss = 0


def backward_synthetic(self, last_hidden):
# assert self.prev_hidden is None or self.prev_hidden.grad is not None, "make sure to set retain_graph=True for backward call"

def to(self, device):
self.device = device
return super().to(device)

def cuda(self):
self.device = torch.cuda.current_device()
return super().cuda()

def cpu(self):
self.device = 'cpu'
return super().cpu()

def backward_synthetic(self, last_hidden):
last_hidden = torch.cat(last_hidden, dim = 2) if self.is_lstm else last_hidden

# predict the gradient of the future loss; not detaching here lets the synthesizer's prediction loss also flow back into the model
if self.allow_backwarding:
synthetic_gradient = self.synthesizer(last_hidden)
else:
synthetic_gradient = self.synthesizer(last_hidden.detach())

if self.prev_hidden is not None and self.prev_hidden.grad is None:
raise ValueError(
"Loss gradient not found; make sure to call .backward_synthetic() AFTER loss.backward(retain_graph=True). "
"The graph needs to be retained since .backward_synthetic() uses it."
)

# backward this future loss scaled by a factor (.1 in the paper) for stable training
last_hidden.backward(gradient = synthetic_gradient.detach() * self.factor, retain_graph = True)
try:
# backward this future loss scaled by a factor (.1 in the paper) for stable training
last_hidden.backward(gradient = synthetic_gradient.detach() * self.factor, retain_graph = True)
except RuntimeError:
traceback.print_exc()
raise RuntimeError(
'Unable to backward the synthetic gradient. See the error above; if it says '
'"Trying to backward through the graph a second time" '
'then you need to set retain_graph=True in your backward call, '
'e.g. loss.backward(retain_graph=True). The synthesizer uses the graph and will free it on its own.'
)

# auxiliary task described in the paper
if self.aux and self.predicted_grad is not None:
aux_loss = self.loss_func(self.predicted_grad, synthetic_gradient.detach())
aux_loss.backward(retain_graph = True)
self.aux_loss = aux_loss.item()
self.predicted_grad = self.aux_layers(last_hidden.detach())

if self.prev_hidden is not None:
assert self.prev_hidden.grad is not None

# update synthesizer
# TODO: ADD A COMMENT ABOUT THE GRADIENTS ACCUMULATED HERE!
# update synthesizer with the accumulated gradients in the prev_hidden
# right now prev_hidden.grad = d_loss/d_prev_hidden + d_future_loss/d_prev_hidden (from the synthetic gradient)
synth_loss = self.loss_func(self.prev_synth, self.prev_hidden.grad)
synth_loss.backward()

# store the synthetic gradient loss for monitoring purposes
self.synthetic_loss = synth_loss.item()

# save the last hidden state and unroll an extra RNN core by requiring grad (in paper)
# graph should be discarded here after the detach unless prev_synth keeps a copy of the graph
self.prev_hidden = last_hidden.detach()
self.prev_hidden = Variable(self.prev_hidden, requires_grad = True)
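
Note on the call order enforced above: the task loss has to be backpropagated with retain_graph=True before backward_synthetic() is called on the new hidden state, since the synthesizer reuses the graph and then frees it itself. A minimal sketch of one truncated-BPTT chunk, assuming synth, rnn, optim, loss_func, h_n and the data chunks are already set up (all names here are placeholders, not from this commit; the runnable version is examples/copy_task.py below):

# sketch only -- rnn, optim, loss_func, chunks and target are hypothetical stand-ins
for split, target in chunks:
    out, h_n = rnn(split, h_n)
    task_loss = loss_func(out, target)

    # backward the task loss first, keeping the graph alive for the synthesizer
    task_loss.backward(retain_graph = True)

    # then backward the synthetic gradient; a fresh grad-requiring hidden state comes back
    h_n = synth.backward_synthetic(h_n)

    optim.step()
    optim.zero_grad()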

@@ -100,8 +124,7 @@ def backward_synthetic(self, last_hidden):
def step(self):
if self.use_improvement:
# NEW IDEA!!
# set the last future loss to 0 since this is the end of the epoch
# this is very important for stopping synthetic gradients from exploding
# set the last future loss to 0 since this is the end of the epoch to stop synthetic gradients from exploding
synth_loss = self.loss_func(self.prev_synth, torch.zeros(self.prev_synth.shape).to(self.device))
synth_loss.backward()

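With the device argument gone from the constructor, the synthesizer is now moved like any other nn.Module; the overridden to()/cuda()/cpu() methods both move the layers and record the device so that step() can build its zero target on it. A minimal usage sketch (D_MODEL and the choice of 'cuda' are placeholders):

synth = dni.Synthesizer(size = D_MODEL, is_lstm = True)
synth = synth.to('cuda')    # or synth.cuda(); either moves the layers and updates synth.device
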
10 changes: 4 additions & 6 deletions examples/copy_task.py
@@ -33,7 +33,7 @@ def forward(self, input, h0):


if __name__ == "__main__":
device = 'cpu'
device = 'cuda'

### DATA ###
# create prompts with a stop character at the end
@@ -46,7 +46,7 @@ def forward(self, input, h0):
loss = nn.CrossEntropyLoss()

# instantiate DNI model that will backward synthetic gradients
synth = dni.Synthesizer(size = D_MODEL, is_lstm = True, device = device)
synth = dni.Synthesizer(size = D_MODEL, is_lstm = True).cuda()

losses = []
synth_losses = []
@@ -63,11 +63,10 @@ def forward(self, input, h0):
# standard forward pass
out, h_n = rnn(split, h_n)
cross_loss = loss(out.view(-1, len(ALPHABET)+1), torch.zeros(BATCH_SIZE, TBPTT, dtype = torch.long).view(-1).to(device))
cross_loss.backward(retain_graph = True)

# just add ONE line for synthetic gradients
h_n = synth.backward_synthetic(h_n)

cross_loss.backward()

torch.nn.utils.clip_grad_norm_(rnn.parameters(), 25)
optim.step()
@@ -78,11 +77,10 @@ def forward(self, input, h0):
# standard forward pass
out, h_n = rnn(torch.zeros(TBPTT, BATCH_SIZE, dtype = torch.long).to(device), h_n)
cross_loss = loss(out.reshape(-1, len(ALPHABET)+1), split.reshape(-1))
cross_loss.backward(retain_graph = True)

# just add ONE line for synthetic gradients
h_n = synth.backward_synthetic(h_n)

cross_loss.backward()

losses.append(cross_loss.item())

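After each backward_synthetic() call the synthesizer stores its regression loss as a plain float in synth.synthetic_loss (and the auxiliary loss in synth.aux_loss once the auxiliary task has run), so it can be logged alongside the task loss. A hypothetical logging addition to the loop above:

# hypothetical logging -- losses and synth_losses are the lists defined earlier in this example
synth_losses.append(synth.synthetic_loss)
if len(losses) % 100 == 0:
    print(f'task loss {losses[-1]:.4f} | synthetic gradient loss {synth_losses[-1]:.4f}')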
