Commit 9cb901c

Aapo Kyrola authored and Yangqing Jia committed
Forward-only rnns
Summary: Added a forward_only option to recurrent_net and to the RNNCells. If it is set, the backward_step_net is not passed to the operator. When backward_step_net is not available, the operator knows it is in forward-only mode and does not create a workspace for each step, but instead cycles through a single private workspace.

Note: we could avoid doing a lot of the work in recurrent.py's recurrent_net call when the backward step is not needed, but doing that cleanly requires more refactoring than I wanted to do now. Thus we still create the backward step nets etc., but simply do not pass them to the op.

This can be used to create more efficient inference models. You can also sanitize existing inference nets by removing the backward_step_net argument to get the same benefits.

Reviewed By: salexspb

Differential Revision: D4916482

fbshipit-source-id: c99b93c9cb897c32b0f449253f7f6d6a942618ad
1 parent 7440cd5 commit 9cb901c
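As a rough illustration of the "sanitize existing inference nets" point in the summary (not part of this commit), the backward_step_net argument can be dropped from a serialized predict net with plain protobuf manipulation; the operator then takes the forward-only path shown in recurrent_network_op.h below. The helper name and the file handling here are hypothetical:

from caffe2.proto import caffe2_pb2


def strip_backward_step_net(net_def):
    """Remove backward_step_net args from every RecurrentNetwork op in-place."""
    for op in net_def.op:
        if op.type != "RecurrentNetwork":
            continue
        # Delete in reverse so deletion does not shift the remaining indices.
        for i in reversed(range(len(op.arg))):
            if op.arg[i].name == "backward_step_net":
                del op.arg[i]
    return net_def


# Hypothetical usage on a serialized inference net:
# net = caffe2_pb2.NetDef()
# with open("predict_net.pb", "rb") as f:
#     net.ParseFromString(f.read())
# strip_backward_step_net(net)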

File tree: 5 files changed (+157, -109 lines)

caffe2/operators/recurrent_network_op.h

Lines changed: 11 additions & 2 deletions
@@ -257,6 +257,13 @@ class RecurrentNetworkOp final : public Operator<Context> {
           ri, seqLen, batchSize, sharedWs_, &context_);
     }

+    // If we don't have a backward step net, this operator is forward_only
+    // and we can avoid creating multiple workspaces.
+
+    bool has_backward_pass =
+        OperatorBase::GetSingleArgument<string>("backward_step_net", "") != "";
+
+    // With backward pass: we need to create workspace for each timestep
     detail::ScratchWorkspaces* scratch =
         OperatorBase::Output<detail::ScratchWorkspaces>(OutputSize() - 1);
     std::vector<std::shared_ptr<Workspace>>& stepWorkspaces =
@@ -271,13 +278,15 @@ class RecurrentNetworkOp final : public Operator<Context> {
     // have to be stored in step workspaces but can be shared.
     initializeBlobsToRecomputeOnBackward(forwardSharedWs.get());

-    if (seqLen > stepWorkspaces.size()) {
+    if (has_backward_pass && seqLen > stepWorkspaces.size()) {
       stepWorkspaces.resize(seqLen);
     }

     for (auto t = 0; t < seqLen; ++t) {
-      auto& currentStepWorkspace = stepWorkspaces[t];
+      auto& currentStepWorkspace =
+          (has_backward_pass ? stepWorkspaces[t] : forwardSharedWs);
       if (!currentStepWorkspace) {
+        CHECK(has_backward_pass);
         currentStepWorkspace =
             std::make_shared<Workspace>(forwardSharedWs.get());
       }
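
A compact restatement of the workspace logic above (illustrative Python, not the operator code): with a backward pass every timestep gets its own workspace so intermediate activations survive for gradients; without one, every timestep reuses the single shared forward workspace.

def plan_step_workspaces(seq_len, has_backward_pass):
    """Which workspace each timestep runs in, mirroring the C++ above."""
    if has_backward_pass:
        return ["step_ws[%d]" % t for t in range(seq_len)]
    return ["forwardSharedWs"] * seq_len


print(plan_step_workspaces(3, has_backward_pass=True))
print(plan_step_workspaces(3, has_backward_pass=False))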

caffe2/python/lstm_benchmark.py

Lines changed: 8 additions & 1 deletion
@@ -80,6 +80,7 @@ def create_model(args, queue, label_queue, input_shape):
             dim_out=args.hidden_dim,
             scope="lstm1",
             memory_optimization=args.memory_optimization,
+            forward_only=args.forward_only,
         )
     elif args.implementation == "cudnn":
         # We need to feed a placeholder input so that RecurrentInitOp
@@ -104,7 +105,8 @@ def create_model(args, queue, label_queue, input_shape):
         ['softmax', 'loss'],
     )

-    model.AddGradientOperators([loss])
+    if not args.forward_only:
+        model.AddGradientOperators([loss])

     # carry states over
     model.net.Copy(last_hidden, hidden_init)
@@ -232,6 +234,11 @@ def GetArgumentParser():
         action="store_true",
         help="Whether to use memory optimized LSTM or not",
     )
+    parser.add_argument(
+        "--forward_only",
+        action="store_true",
+        help="Whether to run only forward pass"
+    )

     return parser
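
A quick way to see the new flag from Python (a sketch, assuming caffe2.python.lstm_benchmark is importable at this revision and its other arguments keep their defaults):

from caffe2.python.lstm_benchmark import GetArgumentParser

# --forward_only is a plain store_true flag; when set, create_model() skips
# model.AddGradientOperators([loss]) as shown in the diff above.
args = GetArgumentParser().parse_args(["--forward_only"])
print(args.forward_only)  # True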

caffe2/python/operator_test/rnn_cell_test.py

Lines changed: 15 additions & 13 deletions
@@ -378,10 +378,10 @@ def test_lstm_unit_recurrent_network(self, n, d, t, dc, gc):
                 gc, op, inputs, i, [0, 1],
                 input_device_options=input_device_options)

-
    @given(
        input_tensor=lstm_input(),
        forget_bias=st.floats(-10.0, 10.0),
+       fwd_only=st.booleans(),
    )
    @ht_settings(max_examples=25)
    def test_lstm_main(self, **kwargs):
@@ -393,7 +393,7 @@ def test_lstm_main(self, **kwargs):
                       **kwargs)

    def lstm_base(self, lstm_type, outputs_with_grads, memory_optim,
-                 input_tensor, forget_bias):
+                 input_tensor, forget_bias, fwd_only):
        print("LSTM test parameters: ", locals())
        create_lstm, ref = lstm_type
        t, n, d = input_tensor.shape
@@ -412,7 +412,8 @@ def lstm_base(self, lstm_type, outputs_with_grads, memory_optim,
            d, d, scope="external/recurrent",
            outputs_with_grads=outputs_with_grads,
            memory_optimization=memory_optim,
-           forget_bias=forget_bias)
+           forget_bias=forget_bias,
+           forward_only=fwd_only)

        op = model.net._net.op[-1]

@@ -447,16 +448,17 @@ def generate_random_state(n, d):
        )

        # Checking for input, gates_t_w and gates_t_b gradients
-       for param in range(5):
-           self.assertGradientChecks(
-               device_option=hu.cpu_do,
-               op=op,
-               inputs=inputs,
-               outputs_to_check=param,
-               outputs_with_grads=outputs_with_grads,
-               threshold=0.01,
-               stepsize=0.005,
-           )
+       if not fwd_only:
+           for param in range(5):
+               self.assertGradientChecks(
+                   device_option=hu.cpu_do,
+                   op=op,
+                   inputs=inputs,
+                   outputs_to_check=param,
+                   outputs_with_grads=outputs_with_grads,
+                   threshold=0.01,
+                   stepsize=0.005,
+               )

    @given(encoder_output_length=st.integers(1, 3),
           encoder_output_dim=st.integers(1, 3),

caffe2/python/recurrent.py

Lines changed: 102 additions & 88 deletions
@@ -8,10 +8,11 @@
 from caffe2.python import core
 from caffe2.python.scope import CurrentNameScope

+
 def recurrent_net(
         net, cell_net, inputs, initial_cell_inputs,
         links, timestep=None, scope=None, outputs_with_grads=(0,),
-        recompute_blobs_on_backward=None,
+        recompute_blobs_on_backward=None, forward_only=False,
 ):
     '''
     net: the main net operator should be added to
@@ -43,6 +44,8 @@ def recurrent_net(
     recompute_blobs_on_backward: specify a list of blobs that will be
                  recomputed for backward pass, and thus need not to be
                  stored for each forward timestep.
+
+    forward_only: if True, only forward steps are executed
     '''
     assert len(inputs) == 1, "Only one input blob is supported so far"

@@ -77,54 +80,55 @@ def s(name):
     inner_outputs = list(cell_net.Proto().external_output)
     # These gradients are expected to be available during the backward pass
     inner_outputs_map = {o: o + '_grad' for o in inner_outputs}
+    recompute_blobs_on_backward = set()

     # compute the backward pass of the cell net
-    backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
-        cell_net.Proto().op, inner_outputs_map)
-    backward_mapping = {str(k): v for k, v in backward_mapping.items()}
-    backward_cell_net = core.Net("RecurrentBackwardStep")
-    del backward_cell_net.Proto().op[:]
-
-    if recompute_blobs_on_backward is not None:
-        # Insert operators to re-compute the specified blobs.
-        # They are added in the same order as for the forward pass, thus
-        # the order is correct.
-        recompute_blobs_on_backward = set(
-            [str(b) for b in recompute_blobs_on_backward]
-        )
-        for op in cell_net.Proto().op:
-            if not recompute_blobs_on_backward.isdisjoint(set(op.output)):
-                backward_cell_net.Proto().op.extend([op])
-                assert set(op.output).issubset(recompute_blobs_on_backward), \
-                    'Outputs {} are output by op but not recomputed: {}'.format(
-                        set(op.output) - recompute_blobs_on_backward,
-                        op
-                    )
+    if not forward_only:
+        backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
+            cell_net.Proto().op, inner_outputs_map)
+        backward_mapping = {str(k): v for k, v in backward_mapping.items()}
+
+        backward_cell_net = core.Net("RecurrentBackwardStep")
+        del backward_cell_net.Proto().op[:]
+
+        if recompute_blobs_on_backward is not None:
+            # Insert operators to re-compute the specified blobs.
+            # They are added in the same order as for the forward pass, thus
+            # the order is correct.
+            recompute_blobs_on_backward = {str(b) for b in
+                                           recompute_blobs_on_backward}
+
+            for op in cell_net.Proto().op:
+                if not recompute_blobs_on_backward.isdisjoint(set(op.output)):
+                    backward_cell_net.Proto().op.extend([op])
+                    # This fires if other outputs than the declared
+                    # are computed by the ops that are recomputed
+                    assert set(op.output).issubset(recompute_blobs_on_backward)
+
+        backward_cell_net.Proto().op.extend(backward_ops)
+        # compute blobs used but not defined in the backward pass
+        backward_ssa, backward_blob_versions = core.get_ssa(
+            backward_cell_net.Proto())
+        undefined = core.get_undefined_blobs(backward_ssa)
+
+        # also add to the output list the intermediate outputs of fwd_step that
+        # are used by backward.
+        ssa, blob_versions = core.get_ssa(cell_net.Proto())
+        scratches = [
+            blob for (blob, ver) in blob_versions.items()
+            if ver > 0 and
+            blob in undefined and
+            blob not in cell_net.Proto().external_output]
+        backward_cell_net.Proto().external_input.extend(scratches)
+        backward_cell_net.Proto().type = 'simple'
     else:
-        recompute_blobs_on_backward = set()
-
-    backward_cell_net.Proto().op.extend(backward_ops)
-    # compute blobs used but not defined in the backward pass
-    backward_ssa, backward_blob_versions = core.get_ssa(
-        backward_cell_net.Proto())
-    undefined = core.get_undefined_blobs(backward_ssa)
-
-    # also add to the output list the intermediate outputs of fwd_step that
-    # are used by backward.
-    ssa, blob_versions = core.get_ssa(cell_net.Proto())
-    scratches = [
-        blob for (blob, ver) in blob_versions.items()
-        if ver > 0 and
-        blob in undefined and
-        blob not in cell_net.Proto().external_output]
-    backward_cell_net.Proto().external_input.extend(scratches)
+        backward_cell_net = None

     all_inputs = [i[1] for i in inputs] + [
         x[1] for x in initial_cell_inputs] + references
     all_outputs = []

     cell_net.Proto().type = 'simple'
-    backward_cell_net.Proto().type = 'simple'

     # Internal arguments used by RecurrentNetwork operator

@@ -153,53 +157,58 @@ def s(name):
         cell_output = links[str(cell_input)]
         forward_links.append((cell_input, state, 0))
         forward_links.append((cell_output, state, 1))
-        backward_links.append((cell_output + "_grad", states_grad, 1))

-        backward_cell_net.Proto().external_input.append(
-            str(cell_output) + "_grad")
         aliases.append((state, cell_output + "_all", 1))
         aliases.append((state, cell_output + "_last", -1))
         all_outputs.extend([cell_output + "_all", cell_output + "_last"])

         recurrent_states.append(state)

-        recurrent_input_grad = cell_input + "_grad"
-        if not backward_blob_versions.get(recurrent_input_grad, 0):
-            # If nobody writes to this recurrent input gradient, we need
-            # to make sure it gets to the states grad blob after all.
-            # We do this by using backward_links which triggers an alias
-            # This logic is being used for example in a SumOp case
-            backward_links.append(
-                (backward_mapping[cell_input], states_grad, 0))
-        else:
-            backward_links.append((cell_input + "_grad", states_grad, 0))
-
-    for reference in references:
-        # Similar to above, in a case of a SumOp we need to write our parameter
-        # gradient to an external blob. In this case we can be sure that
-        # reference + "_grad" is a correct parameter name as we know how
-        # RecurrentNetworkOp gradient schema looks like.
-        reference_grad = reference + "_grad"
-        if (reference in backward_mapping and
-                reference_grad != str(backward_mapping[reference])):
-            # We can use an Alias because after each timestep
-            # RNN op adds value from reference_grad into and _acc blob
-            # which accumulates gradients for corresponding parameter accross
-            # timesteps. Then in the end of RNN op these two are being
-            # swaped and reference_grad blob becomes a real blob instead of
-            # being an alias
-            backward_cell_net.Alias(
-                backward_mapping[reference], reference_grad)
+        if backward_cell_net is not None:
+            backward_links.append((cell_output + "_grad", states_grad, 1))
+            backward_cell_net.Proto().external_input.append(
+                str(cell_output) + "_grad")
+
+            recurrent_input_grad = cell_input + "_grad"
+            if not backward_blob_versions.get(recurrent_input_grad, 0):
+                # If nobody writes to this recurrent input gradient, we need
+                # to make sure it gets to the states grad blob after all.
+                # We do this by using backward_links which triggers an alias
+                # This logic is being used for example in a SumOp case
+                backward_links.append(
+                    (backward_mapping[cell_input], states_grad, 0))
+            else:
+                backward_links.append((cell_input + "_grad", states_grad, 0))

     for input_t, input_blob in inputs:
         forward_links.append((str(input_t), str(input_blob), 0))
-        backward_links.append((
-            backward_mapping[str(input_t)], str(input_blob) + "_grad", 0
-        ))
-    backward_cell_net.Proto().external_input.extend(
-        cell_net.Proto().external_input)
-    backward_cell_net.Proto().external_input.extend(
-        cell_net.Proto().external_output)
+
+    if backward_cell_net is not None:
+        for reference in references:
+            # Similar to above, in a case of a SumOp we need to write our parameter
+            # gradient to an external blob. In this case we can be sure that
+            # reference + "_grad" is a correct parameter name as we know how
+            # RecurrentNetworkOp gradient schema looks like.
+            reference_grad = reference + "_grad"
+            if (reference in backward_mapping and
+                    reference_grad != str(backward_mapping[reference])):
+                # We can use an Alias because after each timestep
+                # RNN op adds value from reference_grad into and _acc blob
+                # which accumulates gradients for corresponding parameter accross
+                # timesteps. Then in the end of RNN op these two are being
+                # swaped and reference_grad blob becomes a real blob instead of
+                # being an alias
+                backward_cell_net.Alias(
+                    backward_mapping[reference], reference_grad)
+
+        for input_t, input_blob in inputs:
+            backward_links.append((
+                backward_mapping[str(input_t)], str(input_blob) + "_grad", 0
+            ))
+        backward_cell_net.Proto().external_input.extend(
+            cell_net.Proto().external_input)
+        backward_cell_net.Proto().external_input.extend(
+            cell_net.Proto().external_output)

     def unpack_triple(x):
         if x:
@@ -210,18 +219,28 @@ def unpack_triple(x):
     # Splitting to separate lists so we can pass them to c++
     # where we ensemle them back
     link_internal, link_external, link_offset = unpack_triple(forward_links)
-    backward_link_internal, backward_link_external, backward_link_offset = \
-        unpack_triple(backward_links)
     alias_src, alias_dst, alias_offset = unpack_triple(aliases)

-    params = [x for x in references if x in backward_mapping.keys()]
     recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]

-    global _workspace_seq
+    backward_args = {}
+    if backward_cell_net is not None:
+        backward_link_internal, backward_link_external, backward_link_offset = \
+            unpack_triple(backward_links)
+        params = [x for x in references if x in backward_mapping.keys()]
+        backward_args = {
+            'param': map(all_inputs.index, params),
+            'backward_link_internal': map(str, backward_link_internal),
+            'backward_link_external': map(str, backward_link_external),
+            'backward_link_offset': backward_link_offset,
+            'backward_step_net': str(backward_cell_net.Proto()),
+            'outputs_with_grads': outputs_with_grads,
+            'recompute_blobs_on_backward': map(str, recompute_blobs_on_backward)
+        }
+
     results = net.RecurrentNetwork(
         all_inputs,
         all_outputs + [s("step_workspaces")],
-        param=map(all_inputs.index, params),
         alias_src=alias_src,
         alias_dst=map(str, alias_dst),
         alias_offset=alias_offset,
@@ -230,14 +249,9 @@ def unpack_triple(x):
         link_internal=map(str, link_internal),
         link_external=map(str, link_external),
         link_offset=link_offset,
-        backward_link_internal=map(str, backward_link_internal),
-        backward_link_external=map(str, backward_link_external),
-        backward_link_offset=backward_link_offset,
         step_net=str(cell_net.Proto()),
-        backward_step_net=str(backward_cell_net.Proto()),
         timestep="timestep" if timestep is None else str(timestep),
-        outputs_with_grads=outputs_with_grads,
-        recompute_blobs_on_backward=map(str, recompute_blobs_on_backward)
+        **backward_args
     )
     # The last output is a list of step workspaces,
     # which is only needed internally for gradient propogation
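
The key packaging change above is that all backward-related arguments are gathered into backward_args and splatted into the RecurrentNetwork call only when a backward cell net exists, so a forward-only op simply never receives backward_step_net. A minimal standalone illustration of that pattern (names here are illustrative, not the Caffe2 operator schema):

def build_rnn_op_kwargs(step_net_proto, backward_step_net_proto=None):
    """Collect optional backward arguments only when training is needed."""
    kwargs = {"step_net": step_net_proto}
    if backward_step_net_proto is not None:
        kwargs.update({
            "backward_step_net": backward_step_net_proto,
            "outputs_with_grads": (0,),
        })
    return kwargs


print(sorted(build_rnn_op_kwargs("fwd_proto")))                # no backward args
print(sorted(build_rnn_op_kwargs("fwd_proto", "bwd_proto")))   # backward args present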
