# Microbenchmarks for quantized EmbeddingBag prepack/unpack conversion ops.
from __future__ import absolute_import, division, print_function, unicode_literals

import operator_benchmark as op_bench
import torch
import numpy as np

# cross_product_configs enumerates the Cartesian product of the argument values.
embeddingbag_conversion_short_configs = op_bench.cross_product_configs(
    num_embeddings=(80,),
    embedding_dim=(128, 256, 512),
    tags=('short',)
)

embeddingbag_conversion_long_configs = op_bench.cross_product_configs(
    num_embeddings=(100, 120, 1000),
    embedding_dim=(16, 64, 128, 256, 512, 1024, 2048),
    tags=('long',)
)

# Float -> packed quantized weight (prepack) ops under test.
conversion_ops = op_bench.op_list(
    attrs=(
        ('qembeddingbag_byte_prepack', torch.ops.quantized.embedding_bag_byte_prepack),
        ('qembeddingbag_4bit_prepack', torch.ops.quantized.embedding_bag_4bit_prepack),
        ('qembeddingbag_2bit_prepack', torch.ops.quantized.embedding_bag_2bit_prepack),
    ),
    attr_names=('op_name', 'op_func'),
)

# Packed quantized weight -> float (unpack) ops under test.
unpack_ops = op_bench.op_list(
    attrs=(
        ('qembeddingbag_byte_unpack', torch.ops.quantized.embedding_bag_byte_unpack),
        ('qembeddingbag_4bit_unpack', torch.ops.quantized.embedding_bag_4bit_unpack),
        ('qembeddingbag_2bit_unpack', torch.ops.quantized.embedding_bag_2bit_unpack),
    ),
    attr_names=('op_name', 'op_func'),
)

class EmbeddingBagFloatToFusedBase(op_bench.TorchBenchmarkBase):
    def init(self, num_embeddings, embedding_dim, op_func):
        # Shift the uniform samples into [1, 2) so all weights are positive.
        self.weight = torch.from_numpy((np.random.random_sample(
            (num_embeddings, embedding_dim)) + 1).astype(np.float32))
        self.op_func = op_func

    def forward(self):
        return self.op_func(self.weight)

class EmbeddingBagFusedToFloatBase(op_bench.TorchBenchmarkBase):
    def init(self, num_embeddings, embedding_dim, op_func):
        # The packed byte format carries 8 extra bytes per row (a per-row fp32
        # scale and zero point), hence embedding_dim + 8 columns of uint8 data.
        weight = torch.randn(num_embeddings, embedding_dim + 8, dtype=torch.float)
        self.packed_weight = weight.to(torch.uint8)
        self.op_func = op_func

    def forward(self):
        return self.op_func(self.packed_weight)

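# A minimal sanity-check sketch (not wired into the benchmark) of the round
# trip these ops measure, assuming the byte-prepack layout noted above:
# prepack quantizes each row to uint8 and appends its fp32 scale/zero point,
# and unpack restores a float tensor of the original shape. The helper name
# below is illustrative, not part of the benchmark harness.
def _example_byte_round_trip():
    weights = torch.from_numpy(
        np.random.random_sample((80, 128)).astype(np.float32))
    packed = torch.ops.quantized.embedding_bag_byte_prepack(weights)
    # packed is uint8 with shape (80, 128 + 8).
    restored = torch.ops.quantized.embedding_bag_byte_unpack(packed)
    return restored.shape == weights.shape
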
op_bench.generate_pt_tests_from_op_list(
    conversion_ops,
    embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
    EmbeddingBagFloatToFusedBase)
op_bench.generate_pt_tests_from_op_list(
    unpack_ops,
    embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
    EmbeddingBagFusedToFloatBase)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
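
# Typical invocation (the --tag_filter flag and the file name below are
# assumptions about the operator_benchmark runner, not verified here):
#   python qembedding_pack_test.py --tag_filter short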