Commit 81b5921

ajyu authored and facebook-github-bot committed
[static runtime] binding for aten::div_out (pytorch#56653)
Summary: Pull Request resolved: pytorch#56653

Test Plan:
```
./buck-out/opt/gen/caffe2/caffe2/fb/predictor/ptvsc2_predictor_bench \
  --scripted_model=/data/users/ansha/tmp/adfinder/aug_1x/210616848_0.predictor.disagg.local.local.pt \
  --pt_inputs=/data/users/ansha/tmp/adfinder/aug_1x/210616848_0.predictor.disagg.input_data.container.pt \
  --iters=500 --warmup_iters=500 --num_threads=1 --pt_enable_static_runtime=1 \
  --pt_cleanup_activations=true --pt_enable_out_variant=1 --pt_optimize_memory=1 \
  --compare_results=1 --do_profile=1 --adsfinder_compatibility=1
```

```
Time per node type:
      1.48563 ms.     35.9861%. fb::sigrid_transforms_torch_bind (1 nodes)
      0.92385 ms.     22.3783%. aten::linear (6 nodes)
     0.681066 ms.     16.4974%. aten::argmin (1 nodes)
     0.239311 ms.     5.79679%. aten::matmul (1 nodes)
     0.140157 ms.     3.39501%. fb::clip_ranges_gather_sigrid_hash_v3 (77 nodes)
    0.0951568 ms.     2.30497%. fb::clip_ranges_gather (263 nodes)
    0.0835801 ms.     2.02455%. aten::sub (1 nodes)
     0.054081 ms.        1.31%. aten::repeat (1 nodes)
    0.0424465 ms.     1.02818%. aten::norm (1 nodes)
    0.0389049 ms.    0.942389%. fb::batch_box_cox (1 nodes)
    0.0346992 ms.    0.840514%. aten::__getitem__ (506 nodes)
    0.0341335 ms.     0.82681%. prim::TupleUnpack (254 nodes)
    0.0306839 ms.    0.743252%. aten::sigmoid (2 nodes)
    0.0280489 ms.    0.679426%. aten::mul (3 nodes)
    0.0265321 ms.    0.642684%. fb::offsets_to_ranges (253 nodes)
    0.0207622 ms.     0.50292%. aten::pow (1 nodes)
    0.0202067 ms.    0.489465%. fb::simple_embedding_bag_sum (3 nodes)
    0.0195497 ms.     0.47355%. fb::casted_batch_one_hot_lengths (1 nodes)
    0.0184351 ms.    0.446551%. fb::concat_add_mul_replacenan_clip (1 nodes)
     0.016382 ms.     0.39682%. aten::sum (3 nodes)
    0.0158651 ms.    0.384299%. prim::TupleConstruct (1 nodes)
    0.0150918 ms.    0.365567%. prim::DictConstruct (2 nodes)
   0.00858005 ms.    0.207833%. aten::div (1 nodes)
   0.00810684 ms.    0.196371%. fb::sigrid_hash_precompute (1 nodes)
   0.00796325 ms.    0.192893%. static_runtime::to_copy (8 nodes)
   0.00782038 ms.    0.189432%. prim::ListConstruct (4 nodes)
    0.0057504 ms.    0.139291%. aten::contiguous (1 nodes)
    0.0044688 ms.    0.108247%. aten::narrow (4 nodes)
   0.00284054 ms.    0.068806%. aten::logit (1 nodes)
   0.00265049 ms.   0.0642024%. aten::add (1 nodes)
   0.00216242 ms.     0.05238%. aten::full (1 nodes)
   0.00207732 ms.   0.0503187%. aten::relu (1 nodes)
   0.00198412 ms.    0.048061%. fb::gather_ranges (4 nodes)
   0.00176954 ms.   0.0428632%. aten::stack (1 nodes)
   0.00175913 ms.   0.0426112%. static_runtime::reshape_copy (2 nodes)
    0.0016996 ms.   0.0411692%. aten::clamp_min (1 nodes)
   0.00128528 ms.   0.0311331%. aten::size (3 nodes)
  0.000849156 ms.    0.020569%. aten::expand_as (1 nodes)
  0.000757672 ms.    0.018353%. fb::clip_ranges (2 nodes)
  0.000596224 ms.   0.0144423%. fb::lengths_to_offsets (3 nodes)
  0.000442632 ms.   0.0107218%. static_runtime::flatten_copy (1 nodes)
  0.000196158 ms.  0.00475151%. prim::device (1 nodes)
      4.12833 ms. in Total
StaticRuntime setup time: 0.000451 ms
Memory allocation time: 0.0089336 ms
Memory deallocation time: 0.0578358 ms
Outputs deallocation time: 0.0431742 ms
Total memory managed: 947328 bytes
Total number of reused tensors: 31
W0421 16:56:34.220682 1522800 PyTorchPredictorContainer.cpp:200] Failed to load metadata file
W0421 16:56:34.220772 1522800 PyTorchPredictorContainer.cpp:457] Couldn't find model param config file xl_model_weights/model_param_config
I0421 16:56:34.220791 1522800 PyTorchPredictorBenchLib.cpp:137] PyTorch predictor: number of prediction threads 1
I0421 16:56:34.366667 1522800 PyTorchPredictorBenchLib.cpp:230] PyTorch run finished. Milliseconds per iter: 145.863. Iters per second: 6.85573
I0421 16:56:34.514202 1522800 PtVsBlackBoxPredictorBenchLib.cpp:132] Finished comparing PT static runtime and jit interpreter results
```

Reviewed By: hlu1

Differential Revision: D27927731

fbshipit-source-id: 595883a31ba0cadf6449799d47bf2294a1d05b41
1 parent 57cba8e commit 81b5921

3 files changed: +58 −0 lines changed

benchmarks/static_runtime/test_scripts.h

Lines changed: 20 additions & 0 deletions
@@ -224,3 +224,23 @@ const std::string embedding_bag_max_last_offset = R"JIT(
   def forward(self, a: Tensor, b: Tensor, c: Tensor):
       return torch.embedding_bag(a, b, c, False, 2, False, None, True)
 )JIT";
+
+const auto div_tensor = R"JIT(
+  def forward(self, a: Tensor, b: Tensor):
+      return torch.div(a, b)
+)JIT";
+
+const auto div_scalar = R"JIT(
+  def forward(self, a: Tensor, b: int):
+      return torch.div(a, b)
+)JIT";
+
+const auto div_tensor_mode = R"JIT(
+  def forward(self, a: Tensor, b: Tensor, c: str):
+      return torch.div(a, b, rounding_mode=c)
+)JIT";
+
+const auto div_scalar_mode = R"JIT(
+  def forward(self, a: Tensor, b: float, c: str):
+      return torch.div(a, b, rounding_mode=c)
+)JIT";

benchmarks/static_runtime/test_static_runtime.cc

Lines changed: 17 additions & 0 deletions
@@ -157,6 +157,23 @@ TEST(StaticRuntime, IndividualOps_Binary) {
   testStaticRuntime(tuple_construct_script_2, args);
 }
 
+TEST(StaticRuntime, IndividualOps_Div) {
+  auto a = at::randn({2, 3});
+  auto b = at::randn({2, 3});
+
+  std::vector<IValue> args0{a, b};
+  testStaticRuntime(div_tensor, args0);
+
+  std::vector<IValue> args1{a, 3};
+  testStaticRuntime(div_scalar, args1);
+
+  std::vector<IValue> args2{a, b, "floor"};
+  testStaticRuntime(div_tensor_mode, args2);
+
+  std::vector<IValue> args3{a, 2.3, "trunc"};
+  testStaticRuntime(div_scalar_mode, args3);
+}
+
 TEST(StaticRuntime, IndividualOps_Reshape) {
   auto a = at::randn({2, 3});
   auto b = std::vector<int64_t>({3, 2});

torch/csrc/jit/runtime/static/ops.cpp

Lines changed: 21 additions & 0 deletions
@@ -3,6 +3,7 @@
 #include <ATen/CPUFunctions.h>
 #include <ATen/InferSize.h>
 #include <ATen/NativeFunctions.h>
+#include <ATen/ScalarOps.h>
 #include <ATen/TensorUtils.h>
 #include <ATen/native/EmbeddingBag.h>
 #include <ATen/native/IndexingUtils.h>
@@ -1129,5 +1130,25 @@ REGISTER_OPERATOR_FUNCTOR(
   };
 });
 
+REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator {
+  return [](ProcessedNode* p_node) {
+    const auto& in0_t = p_node->Input(0).toTensor();
+    c10::optional<std::string> rounding_mode = c10::nullopt;
+    if (p_node->inputs().size() > 2) {
+      rounding_mode = p_node->Input(2).toOptional<std::string>();
+    }
+
+    if (p_node->Output(0).isNone()) {
+      p_node->Output(0) = create_empty_from(in0_t);
+    }
+    auto& out_t = p_node->Output(0).toTensor();
+    fastResizeToZero(out_t);
+
+    const auto& in1_t = p_node->Input(1).isTensor()
+        ? p_node->Input(1).toTensor()
+        : at::native::wrapped_scalar_tensor(p_node->Input(1).toScalar());
+    at::cpu::div_out(out_t, in0_t, in1_t, rounding_mode);
+  };
+});
 } // namespace jit
 } // namespace torch
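
The registered lambda follows the static runtime out-variant pattern: the output tensor is created once, reset with fastResizeToZero before each run, and then filled in place by at::cpu::div_out, so steady-state iterations allocate nothing for this op. Below is a minimal standalone sketch of that reuse pattern (not part of the diff), using the public at::div_out free function in place of the internal create_empty_from/fastResizeToZero helpers:

```
#include <ATen/ATen.h>

int main() {
  at::Tensor a = at::randn({2, 3});
  at::Tensor b = at::randn({2, 3});

  // Allocate the output once, outside the hot loop.
  at::Tensor out = at::empty({0}, a.options());

  for (int i = 0; i < 3; ++i) {
    // Shrinking to zero first (mirroring fastResizeToZero) lets the op resize
    // `out` to the result shape without trying to preserve stale elements.
    out.resize_({0});
    // Writes the quotient into `out` instead of allocating a new tensor.
    at::div_out(out, a, b);
  }
  return 0;
}
```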
