Testing
Note: here we use the words kernel and operator interchangeably.
OpTester is a unit test utility that lets you easily perform black-box testing on a kernel. You simply specify the kernel name and version, along with input values and expected output values. The tester runs the kernel with the supplied input and verifies its results against the expected values. Search for OpTester in the test code and you will find plenty of examples; below is one of them:
OpTester test("MatMulInteger", 10);
test.AddInput<uint8_t>("T1", {4, 3}, {11, 7, 3, 10, 6, 2, 9, 5, 1, 8, 4, 0});
test.AddInput<uint8_t>("T2", {3, 2}, {1, 4, 2, 5, 3, 6});
test.AddInput<uint8_t>("a_zero_point", {}, {12});
test.AddInput<uint8_t>("b_zero_point", {}, {0});
test.AddOutput<int32_t>("T3", {4, 2}, {-38, -83, -44, -98, -50, -113, -56, -128});
test.Run();
Here the operator under test is MatMulInteger since opset version 10. The AddInput function supplies an input by specifying the input parameter's name, shape, and values. AddOutput supplies an expected output, which is compared against the computed result at the end of the test.
Under the hood, OpTester constructs a model with a single node, serializes the model to a protobuf string, and hands it to ONNX Runtime for inference. It then compares the inference results with the expected ones.
The code sample above performs black-box testing of a kernel. Sometimes you may want to perform grey-box testing instead, verifying certain internal behavior. One way to do this is to write a test class as a subclass of the kernel under test, map it to a special operator, and test that operator.
As mentioned above, OpTester constructs a model and serializes it, so you need to provide a custom schema for your operator. This is supported via the OpTester::AddCustomOpRegistry() method.
Next, you need to specify a kernel definition and a kernel creation function before you can actually run your kernel. An example is listed below to help you construct your own; here the kernel under test is QLinearMatMul.
struct PrePackTestOp {
  // Name and domain of the custom operator
  static constexpr const char* OpName = "QLinearMatMulPrePack";
  static constexpr const char* OpDomain = "testing";

  // Constructing schema for your custom kernel
  static ONNX_NAMESPACE::OpSchema OpSchema() {
    // Get QLinearMatMul schema from global registry, copy and hack
    auto p_original = ONNX_NAMESPACE::OpSchemaRegistry::Schema("QLinearMatMul", 10, "");
    ONNX_NAMESPACE::OpSchema modified;
    // Set your own name and domain
    modified.SetDoc("Return success, error, or throw based on the input.")
        .SetName(OpName)
        .SetDomain(OpDomain)
        .SinceVersion(10);
    // Copy input definitions
    const auto& inputs = p_original->inputs();
    for (int i = 0; i < static_cast<int>(inputs.size()); i++) {
      const auto& in = inputs[i];
      modified.Input(i, in.GetName(), in.GetDescription(), in.GetTypeStr(),
                     in.GetOption(), in.GetIsHomogeneous(), in.GetMinArity(), in.GetDifferentiationCategory());
    }
    // Copy output definitions
    const auto& outputs = p_original->outputs();
    for (int oi = 0; oi < static_cast<int>(outputs.size()); oi++) {
      const auto& out = outputs[oi];
      modified.Output(oi, out.GetName(), out.GetDescription(), out.GetTypeStr(),
                      out.GetOption(), out.GetIsHomogeneous(), out.GetMinArity(), out.GetDifferentiationCategory());
    }
    // Copy type constraints
    for (const auto& ty : p_original->typeConstraintParams()) {
      modified.TypeConstraint(ty.type_param_str, ty.allowed_type_strs, ty.description);
    }
    return modified;
  }

  // Extending the kernel under test
  class QLinearMatMulPrePackT : public QLinearMatMul {
   public:
    QLinearMatMulPrePackT(const OpKernelInfo& info) : QLinearMatMul(info) {
    }

    Status Compute(OpKernelContext* context) const override {
      /*
       * Custom verification logic here
       */
      Status ret = QLinearMatMul::Compute(context);
      /*
       * Custom verification logic here
       */
      return ret;
    }
  };

  static KernelDefBuilder KernelDef() {
    // TODO: extract this out of the existing op's kernel def instead of copying code!
    KernelDefBuilder def;
    def.SetName(OpName)
        .SetDomain(OpDomain)
        .SinceVersion(10)
        .TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
        .TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(), DataTypeImpl::GetTensorType<int8_t>()})
        .TypeConstraint("T3", DataTypeImpl::GetTensorType<uint8_t>())
        .Provider(onnxruntime::kCpuExecutionProvider);
    return def;
  }
};
// Now we need to add the custom op and then run the test
TEST(QuantizeLinearMatmulOpTest, QLinearMatMulPrePack) {
  // Register custom schema
  auto registry = std::make_shared<CustomRegistry>();
  std::vector<ONNX_NAMESPACE::OpSchema> schemas{PrePackTestOp::OpSchema()};
  Status status;
  ASSERT_TRUE((status = registry->RegisterOpSet(schemas, PrePackTestOp::OpDomain, 10, 11)).IsOK()) << status;

  // Register custom kernel
  KernelCreateFn kernel_create_fn = [](const OpKernelInfo& info) { return new typename PrePackTestOp::QLinearMatMulPrePackT(info); };
  auto kernel_def = PrePackTestOp::KernelDef();
  ASSERT_TRUE((status = registry->RegisterCustomKernel(kernel_def, kernel_create_fn)).IsOK()) << status;

  // Specify which custom op you want to run
  OpTester test_non_empty(PrePackTestOp::OpName, 10, PrePackTestOp::OpDomain);
  // Add schema and kernel def
  test_non_empty.AddCustomOpRegistry(registry);
  test_non_empty.AddInput<uint8_t>("T1", {2, 4}, {208, 236, 0, 238, 3, 214, 255, 29});
  test_non_empty.AddInput<float>("a_scale", {1}, {0.0066f}, true);
  test_non_empty.AddInput<uint8_t>("a_zero_point", {1}, {113}, true);
  test_non_empty.AddInput<uint8_t>("T2", {4, 3}, {152, 51, 244, 60, 26, 255, 0, 127, 246, 127, 254, 247}, true);
  test_non_empty.AddInput<float>("b_scale", {1}, {0.00705f}, true);
  test_non_empty.AddInput<uint8_t>("b_zero_point", {1}, {114}, true);
  test_non_empty.AddInput<float>("y_scale", {1}, {0.0107f}, true);
  test_non_empty.AddInput<uint8_t>("y_zero_point", {1}, {118}, true);
  test_non_empty.AddOutput<uint8_t>("T3", {2, 3}, {168, 115, 255, 1, 66, 151});
  test_non_empty.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
This test runs the code defined in QLinearMatMulPrePackT, allowing you to verify the internals of QLinearMatMul at runtime.
Please use the learning roadmap on the home wiki page to build a general understanding of ORT.