Skip to content

Commit 4fc9e95

Browse files
supriyar authored and facebook-github-bot committed
[quant] Add benchmarks for embedding_bag conversion ops (pytorch#43291)
Summary: Pull Request resolved: pytorch#43291 Test Float2Fused and Fused2Float conversion operators for embedding_bag byte and 4-bit ops Test Plan: ``` python -m pt.qembedding_pack_test ``` Imported from OSS Reviewed By: radkris-git Differential Revision: D23231641 fbshipit-source-id: a2afe51bba52980d2e96dfd7dbc183327e9349fd
1 parent c8bc298 commit 4fc9e95

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

benchmarks/operator_benchmark/benchmark_all_quantized_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
qtensor_method_test,
2323
quantization_test,
2424
qunary_test,
25+
qembedding_pack_test,
2526
)
2627

2728

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from __future__ import absolute_import, division, print_function, unicode_literals
2+
3+
import operator_benchmark as op_bench
4+
import torch
5+
import numpy as np
6+
7+
embeddingbag_conversion_short_configs = op_bench.cross_product_configs(
8+
num_embeddings=(80,),
9+
embedding_dim=(128, 256, 512),
10+
tags=('short',)
11+
)
12+
13+
embeddingbag_conversion_long_configs = op_bench.cross_product_configs(
14+
num_embeddings=(100, 120, 1000),
15+
embedding_dim=(16, 64, 128, 256, 512, 1024, 2048),
16+
tags=('long',)
17+
)
18+
19+
conversion_ops = op_bench.op_list(
20+
attrs=(
21+
('qembeddingbag_byte_prepack', torch.ops.quantized.embedding_bag_byte_prepack),
22+
('qembeddingbag_4bit_prepack', torch.ops.quantized.embedding_bag_4bit_prepack),
23+
('qembeddingbag_2bit_prepack', torch.ops.quantized.embedding_bag_2bit_prepack),
24+
),
25+
attr_names=('op_name', 'op_func'),
26+
)
27+
28+
unpack_ops = op_bench.op_list(
29+
attrs=(
30+
('qembeddingbag_byte_unpack', torch.ops.quantized.embedding_bag_byte_unpack),
31+
('qembeddingbag_4bit_unpack', torch.ops.quantized.embedding_bag_4bit_unpack),
32+
('qembeddingbag_2bit_unpack', torch.ops.quantized.embedding_bag_2bit_unpack),
33+
),
34+
attr_names=('op_name', 'op_func'),
35+
)
36+
37+
class EmbeddingBagFloatToFusedBase(op_bench.TorchBenchmarkBase):
38+
def init(self, num_embeddings, embedding_dim, op_func):
39+
self.weight = torch.from_numpy((np.random.random_sample((
40+
num_embeddings, embedding_dim)) + 1).astype(np.float32))
41+
self.op_func = op_func
42+
43+
def forward(self):
44+
return self.op_func(self.weight)
45+
46+
class EmbeddingBagFusedToFloatBase(op_bench.TorchBenchmarkBase):
47+
def init(self, num_embeddings, embedding_dim, op_func):
48+
weight = torch.randn(num_embeddings, embedding_dim + 8, dtype=torch.float)
49+
self.packed_weight = weight.to(torch.uint8)
50+
self.op_func = op_func
51+
52+
def forward(self):
53+
return self.op_func(self.packed_weight)
54+
55+
56+
op_bench.generate_pt_tests_from_op_list(conversion_ops,
57+
embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
58+
EmbeddingBagFloatToFusedBase)
59+
op_bench.generate_pt_tests_from_op_list(unpack_ops,
60+
embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
61+
EmbeddingBagFusedToFloatBase)
62+
63+
if __name__ == "__main__":
64+
op_bench.benchmark_runner.main()

0 commit comments

Comments
 (0)