From dbb9399e4c5d1aaa01cd59504ce23763255912cd Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Wed, 1 May 2024 14:14:18 -0700 Subject: [PATCH] Add torch.jit.trace option [ghstack-poisoned] --- torchbenchmark/util/backends/jit.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/torchbenchmark/util/backends/jit.py b/torchbenchmark/util/backends/jit.py index 01bb196f6a..bcc99d6e91 100644 --- a/torchbenchmark/util/backends/jit.py +++ b/torchbenchmark/util/backends/jit.py @@ -68,3 +68,19 @@ def _torchscript(): model.set_module(module) return _torchscript, extra_args + +@create_backend +def torchscript_trace( + model: "torchbenchmark.util.model.BenchmarkModel", backend_args: List[str] +): + model.jit = True + backend_args, extra_args = parse_torchscript_args(backend_args) + def _torchscript_trace(): + module, example_inputs = model.get_module() + module = torch.jit.trace( + module, + example_inputs=example_inputs, + ) + model.set_module(module) + + return _torchscript_trace, extra_args