Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
fix stop-gradient not put in SIR problem
Browse files Browse the repository at this point in the history
  • Loading branch information
2742195759 committed Oct 11, 2023
1 parent 44d8049 commit 8fb2f17
Showing 1 changed file with 26 additions and 21 deletions.
47 changes: 26 additions & 21 deletions sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,30 +475,35 @@ def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs):
FunctionGraph.get_opcode_executor_stack()
),
)
if outputs is not None:
if is_inplace_api(func):
# if we want to use a non-inplace api (static api) to replace an inplace behavior (in simulation)
# just set it back in SIR, and return outputs to replace tensor meta (it might changes?)
# in this case, the output will not exactly be used
compute_fn(
func,
inputs_symbols,
convert_to_symbol(args[0]),
stmt_stacks,
)
else:
compute_fn(
func,
inputs_symbols,
convert_to_symbol(outputs),
stmt_stacks,
) # symbolic only contain symbols.
self._put_inner(outputs)
if is_inplace_api(func):
# if we want to use a non-inplace api (static api) to replace an inplace behavior (in simulation)
# just set it back in SIR, and return outputs to replace tensor meta (it might changes?)
# in this case, the output will not exactly be used
compute_fn(
func,
inputs_symbols,
convert_to_symbol(args[0]),
stmt_stacks,
)
elif outputs is not None:
compute_fn(
func,
inputs_symbols,
convert_to_symbol(outputs),
stmt_stacks,
) # symbolic only contain symbols.
self._put_inner(outputs)
return VariableFactory.from_value(
outputs, self, DummyTracker(list(args) + list(kwargs.values()))
)
else:
return None
elif outputs is None:
# tensor.stop_gradient=True
compute_fn(
func,
inputs_symbols,
None,
stmt_stacks,
)

def _put_inner(self, vars: VariableBase):
"""
Expand Down

0 comments on commit 8fb2f17

Please sign in to comment.