 from caffe2.python import core
 from caffe2.python.scope import CurrentNameScope
 
+
 def recurrent_net(
         net, cell_net, inputs, initial_cell_inputs,
         links, timestep=None, scope=None, outputs_with_grads=(0,),
-        recompute_blobs_on_backward=None,
+        recompute_blobs_on_backward=None, forward_only=False,
 ):
     '''
     net: the main net operator should be added to
@@ -43,6 +44,8 @@ def recurrent_net(
     recompute_blobs_on_backward: specify a list of blobs that will be
                  recomputed for the backward pass, and thus need not be
                  stored for each forward timestep.
+
+    forward_only: if True, only forward steps are executed
     '''
     assert len(inputs) == 1, "Only one input blob is supported so far"
 
@@ -77,54 +80,55 @@ def s(name):
     inner_outputs = list(cell_net.Proto().external_output)
     # These gradients are expected to be available during the backward pass
     inner_outputs_map = {o: o + '_grad' for o in inner_outputs}
+    recompute_blobs_on_backward = set()
 
     # compute the backward pass of the cell net
-    backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
-        cell_net.Proto().op, inner_outputs_map)
-    backward_mapping = {str(k): v for k, v in backward_mapping.items()}
-    backward_cell_net = core.Net("RecurrentBackwardStep")
-    del backward_cell_net.Proto().op[:]
-
-    if recompute_blobs_on_backward is not None:
-        # Insert operators to re-compute the specified blobs.
-        # They are added in the same order as for the forward pass, thus
-        # the order is correct.
-        recompute_blobs_on_backward = set(
-            [str(b) for b in recompute_blobs_on_backward]
-        )
-        for op in cell_net.Proto().op:
-            if not recompute_blobs_on_backward.isdisjoint(set(op.output)):
-                backward_cell_net.Proto().op.extend([op])
-                assert set(op.output).issubset(recompute_blobs_on_backward), \
-                    'Outputs {} are output by op but not recomputed: {}'.format(
-                        set(op.output) - recompute_blobs_on_backward,
-                        op
-                    )
+    if not forward_only:
+        backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
+            cell_net.Proto().op, inner_outputs_map)
+        backward_mapping = {str(k): v for k, v in backward_mapping.items()}
+
+        backward_cell_net = core.Net("RecurrentBackwardStep")
+        del backward_cell_net.Proto().op[:]
+
+        if recompute_blobs_on_backward is not None:
+            # Insert operators to re-compute the specified blobs.
+            # They are added in the same order as for the forward pass, thus
+            # the order is correct.
+            recompute_blobs_on_backward = {str(b) for b in
+                                           recompute_blobs_on_backward}
+
+            for op in cell_net.Proto().op:
+                if not recompute_blobs_on_backward.isdisjoint(set(op.output)):
+                    backward_cell_net.Proto().op.extend([op])
+                    # This fires if the recomputed ops produce outputs other
+                    # than the ones declared for recomputation
+                    assert set(op.output).issubset(recompute_blobs_on_backward)
+
+        backward_cell_net.Proto().op.extend(backward_ops)
+        # compute blobs used but not defined in the backward pass
+        backward_ssa, backward_blob_versions = core.get_ssa(
+            backward_cell_net.Proto())
+        undefined = core.get_undefined_blobs(backward_ssa)
+
+        # also add to the output list the intermediate outputs of fwd_step that
+        # are used by backward.
+        ssa, blob_versions = core.get_ssa(cell_net.Proto())
+        scratches = [
+            blob for (blob, ver) in blob_versions.items()
+            if ver > 0 and
+            blob in undefined and
+            blob not in cell_net.Proto().external_output]
+        backward_cell_net.Proto().external_input.extend(scratches)
+        backward_cell_net.Proto().type = 'simple'
     else:
-        recompute_blobs_on_backward = set()
-
-    backward_cell_net.Proto().op.extend(backward_ops)
-    # compute blobs used but not defined in the backward pass
-    backward_ssa, backward_blob_versions = core.get_ssa(
-        backward_cell_net.Proto())
-    undefined = core.get_undefined_blobs(backward_ssa)
-
-    # also add to the output list the intermediate outputs of fwd_step that
-    # are used by backward.
-    ssa, blob_versions = core.get_ssa(cell_net.Proto())
-    scratches = [
-        blob for (blob, ver) in blob_versions.items()
-        if ver > 0 and
-        blob in undefined and
-        blob not in cell_net.Proto().external_output]
-    backward_cell_net.Proto().external_input.extend(scratches)
+        backward_cell_net = None
 
     all_inputs = [i[1] for i in inputs] + [
         x[1] for x in initial_cell_inputs] + references
     all_outputs = []
 
     cell_net.Proto().type = 'simple'
-    backward_cell_net.Proto().type = 'simple'
 
     # Internal arguments used by RecurrentNetwork operator
 
@@ -153,53 +157,58 @@ def s(name):
         cell_output = links[str(cell_input)]
         forward_links.append((cell_input, state, 0))
         forward_links.append((cell_output, state, 1))
-        backward_links.append((cell_output + "_grad", states_grad, 1))
 
-        backward_cell_net.Proto().external_input.append(
-            str(cell_output) + "_grad")
         aliases.append((state, cell_output + "_all", 1))
         aliases.append((state, cell_output + "_last", -1))
         all_outputs.extend([cell_output + "_all", cell_output + "_last"])
 
         recurrent_states.append(state)
 
-        recurrent_input_grad = cell_input + "_grad"
-        if not backward_blob_versions.get(recurrent_input_grad, 0):
-            # If nobody writes to this recurrent input gradient, we need
-            # to make sure it gets to the states grad blob after all.
-            # We do this by using backward_links which triggers an alias
-            # This logic is being used for example in a SumOp case
-            backward_links.append(
-                (backward_mapping[cell_input], states_grad, 0))
-        else:
-            backward_links.append((cell_input + "_grad", states_grad, 0))
-
-    for reference in references:
-        # Similar to above, in a case of a SumOp we need to write our parameter
-        # gradient to an external blob. In this case we can be sure that
-        # reference + "_grad" is a correct parameter name as we know how
-        # RecurrentNetworkOp gradient schema looks like.
-        reference_grad = reference + "_grad"
-        if (reference in backward_mapping and
-                reference_grad != str(backward_mapping[reference])):
-            # We can use an Alias because after each timestep
-            # RNN op adds value from reference_grad into and _acc blob
-            # which accumulates gradients for corresponding parameter accross
-            # timesteps. Then in the end of RNN op these two are being
-            # swaped and reference_grad blob becomes a real blob instead of
-            # being an alias
-            backward_cell_net.Alias(
-                backward_mapping[reference], reference_grad)
+        if backward_cell_net is not None:
+            backward_links.append((cell_output + "_grad", states_grad, 1))
+            backward_cell_net.Proto().external_input.append(
+                str(cell_output) + "_grad")
+
+            recurrent_input_grad = cell_input + "_grad"
+            if not backward_blob_versions.get(recurrent_input_grad, 0):
+                # If nobody writes to this recurrent input gradient, we need
+                # to make sure it still gets to the states grad blob. We do
+                # this via backward_links, which triggers an alias. This
+                # logic is used, for example, in the SumOp case.
+                backward_links.append(
+                    (backward_mapping[cell_input], states_grad, 0))
+            else:
+                backward_links.append((cell_input + "_grad", states_grad, 0))
 
     for input_t, input_blob in inputs:
         forward_links.append((str(input_t), str(input_blob), 0))
-        backward_links.append((
-            backward_mapping[str(input_t)], str(input_blob) + "_grad", 0
-        ))
-    backward_cell_net.Proto().external_input.extend(
-        cell_net.Proto().external_input)
-    backward_cell_net.Proto().external_input.extend(
-        cell_net.Proto().external_output)
+
+    if backward_cell_net is not None:
+        for reference in references:
+            # Similar to above, in the case of a SumOp we need to write our
+            # parameter gradient to an external blob. Here we can be sure
+            # that reference + "_grad" is a correct parameter name, as we
+            # know what the RecurrentNetworkOp gradient schema looks like.
+            reference_grad = reference + "_grad"
+            if (reference in backward_mapping and
+                    reference_grad != str(backward_mapping[reference])):
+                # We can use an Alias because after each timestep the RNN op
+                # adds the value of reference_grad into an _acc blob, which
+                # accumulates the gradients for the corresponding parameter
+                # across timesteps. At the end of the RNN op the two blobs
+                # are swapped and reference_grad becomes a real blob instead
+                # of an alias.
+                backward_cell_net.Alias(
+                    backward_mapping[reference], reference_grad)
+
+        for input_t, input_blob in inputs:
+            backward_links.append((
+                backward_mapping[str(input_t)], str(input_blob) + "_grad", 0
+            ))
+        backward_cell_net.Proto().external_input.extend(
+            cell_net.Proto().external_input)
+        backward_cell_net.Proto().external_input.extend(
+            cell_net.Proto().external_output)
 
     def unpack_triple(x):
         if x:
@@ -210,18 +219,28 @@ def unpack_triple(x):
     # Splitting to separate lists so we can pass them to c++
     # where we ensemble them back
     link_internal, link_external, link_offset = unpack_triple(forward_links)
-    backward_link_internal, backward_link_external, backward_link_offset = \
-        unpack_triple(backward_links)
     alias_src, alias_dst, alias_offset = unpack_triple(aliases)
 
-    params = [x for x in references if x in backward_mapping.keys()]
     recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]
 
-    global _workspace_seq
+    backward_args = {}
+    if backward_cell_net is not None:
+        backward_link_internal, backward_link_external, backward_link_offset = \
+            unpack_triple(backward_links)
+        params = [x for x in references if x in backward_mapping.keys()]
+        backward_args = {
+            'param': map(all_inputs.index, params),
+            'backward_link_internal': map(str, backward_link_internal),
+            'backward_link_external': map(str, backward_link_external),
+            'backward_link_offset': backward_link_offset,
+            'backward_step_net': str(backward_cell_net.Proto()),
+            'outputs_with_grads': outputs_with_grads,
+            'recompute_blobs_on_backward': map(str, recompute_blobs_on_backward)
+        }
+
     results = net.RecurrentNetwork(
         all_inputs,
         all_outputs + [s("step_workspaces")],
-        param=map(all_inputs.index, params),
         alias_src=alias_src,
         alias_dst=map(str, alias_dst),
         alias_offset=alias_offset,
@@ -230,14 +249,9 @@ def unpack_triple(x):
         link_internal=map(str, link_internal),
         link_external=map(str, link_external),
         link_offset=link_offset,
-        backward_link_internal=map(str, backward_link_internal),
-        backward_link_external=map(str, backward_link_external),
-        backward_link_offset=backward_link_offset,
         step_net=str(cell_net.Proto()),
-        backward_step_net=str(backward_cell_net.Proto()),
         timestep="timestep" if timestep is None else str(timestep),
-        outputs_with_grads=outputs_with_grads,
-        recompute_blobs_on_backward=map(str, recompute_blobs_on_backward)
+        **backward_args
     )
     # The last output is a list of step workspaces,
     # which is only needed internally for gradient propagation
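# Usage sketch: one way the new forward_only flag might be exercised.
# Everything below is an illustrative assumption: the toy cell, the blob
# names and shapes, and the import path (assuming this helper lives in
# caffe2.python.recurrent). Only the recurrent_net(...) signature, including
# forward_only, comes from the change above.
from caffe2.python import core
from caffe2.python.recurrent import recurrent_net

# A deliberately trivial, parameter-free "cell", just to show the wiring:
#     h_t = tanh(x_t + h_{t-1})
cell = core.Net("rnn_cell_example")
x_t, h_prev = cell.AddExternalInputs("x_t", "h_prev")
h_next = cell.Tanh(cell.Sum([x_t, h_prev], "h_sum"), "h_next")
cell.AddExternalOutputs(h_next)

pred_net = core.Net("prediction")
# "input_sequence" (T x N x D) and "h_init" (1 x N x D) are placeholders that
# are assumed to already exist in the enclosing workspace.
outputs = recurrent_net(
    net=pred_net,
    cell_net=cell,
    inputs=[(x_t, "input_sequence")],          # (per-step blob, full sequence)
    initial_cell_inputs=[(h_prev, "h_init")],  # (state blob, initial value)
    links={str(h_prev): str(h_next)},          # feed the state back each step
    outputs_with_grads=(0,),
    forward_only=True,  # inference only: no backward step net is built
)
# Per the aliases built in recurrent_net, the result exposes h_next for every
# timestep ("h_next_all") and for the final timestep ("h_next_last").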