From fbe938d0cd40521dac4ec0b9912ee8fd9259a6d7 Mon Sep 17 00:00:00 2001 From: Nikolay Shirokovskiy Date: Mon, 6 May 2024 11:00:56 +0300 Subject: [PATCH] Add trace check for error assertions In the scope of the referenced Tarantool issue we are going to change error trace of API to point to the caller place. The error should be box.error and the trace will be changed for several modules at the beginning (fix all API at once is difficult). We are going to use existing tests to test the change. In particular in case of Luatest let's check trace in `assert_error*` assertions besides the main assertion. Required for https://github.com/tarantool/tarantool/issues/9914 --- luatest/assertions.lua | 89 ++++++++++++++++++++++++-- luatest/utils.lua | 5 ++ test/luaunit/assertions_error_test.lua | 75 ++++++++++++++++++++++ test/utils_test.lua | 8 +++ 4 files changed, 172 insertions(+), 5 deletions(-) diff --git a/luatest/assertions.lua b/luatest/assertions.lua index edf768c9..bbf26faf 100644 --- a/luatest/assertions.lua +++ b/luatest/assertions.lua @@ -10,6 +10,8 @@ local mismatch_formatter = require('luatest.mismatch_formatter') local pp = require('luatest.pp') local log = require('luatest.log') local utils = require('luatest.utils') +local tarantool = require('tarantool') +local ffi = require('ffi') local prettystr = pp.tostring local prettystr_pairs = pp.tostring_pair @@ -18,6 +20,8 @@ local M = {} local xfail = false +local box_error_type = ffi.typeof(box.error.new(box.error.UNKNOWN)) + -- private exported functions (for testing) M.private = {} @@ -85,12 +89,87 @@ local function error_msg_equality(actual, expected, deep_analysis) end M.private.error_msg_equality = error_msg_equality +-- +-- The wrapper is used when trace check is required. See pcall_check_trace. +-- +-- Without wrapper the trace will point to the pcall implementation. So trace +-- check is not strict enough (the trace can point to any pcall in below in +-- call trace). +-- +local trace_line = debug.getinfo(1, 'l').currentline + 2 +local function wrapped_call(fn, ...) + local res = utils.table_pack(fn(...)) + -- With `return fn(...)` wrapper does not work due to tail call + -- optimization. + return unpack(res, 1, res.n) +end + +-- Expected trace for trace check. See pcall_check_trace. +local wrapped_trace = { + file = debug.getinfo(1, 'S').short_src, + line = trace_line, +} + +-- Used in tests to force check for given module. +M.private.check_trace_module = nil + +-- +-- Return true if error trace check is required for function. Basically it is +-- just a wrapper around Tarantool's utils.proper_trace_required. Additionally +-- old Tarantool versions where this function is not present are handled. +-- +local function trace_check_is_required(fn) + local src = debug.getinfo(fn, 'S').short_src + if M.private.check_trace_module == src then + return true + end + if tarantool._internal ~= nil and + tarantool._internal.trace_check_is_required ~= nil then + local path = debug.getinfo(fn, 'S').short_src + return tarantool._internal.trace_check_is_required(path) + end + return false +end + +-- +-- Substitute for pcall but additionally checks error trace if required. +-- +-- The error should be box.error and trace should point to the place +-- where fn is called. +-- +-- level is used to set proper level in error assertions that use this function. +-- +local function pcall_check_trace(level, fn, ...) + local fn_explicit = fn + if type(fn) ~= 'function' then + fn_explicit = debug.getmetatable(fn).__call + end + if not trace_check_is_required(fn_explicit) then + return pcall(fn, ...) + end + local ok, err = pcall(wrapped_call, fn, ...) + if ok then + return ok, err + end + if type(err) ~= 'cdata' or ffi.typeof(err) ~= box_error_type then + fail_fmt(level + 1, nil, 'Error raised is not a box.error: %s', + prettystr(err)) + end + local unpacked = err:unpack() + if not comparator.equals(unpacked.trace[1], wrapped_trace) then + fail_fmt(level + 1, nil, + 'Unexpected error trace, expected: %s, actual: %s', + prettystr(wrapped_trace), prettystr(unpacked.trace[1])) + end + return ok, err +end + --- Check that calling fn raises an error. -- -- @func fn -- @param ... arguments for function function M.assert_error(fn, ...) - local ok, err = pcall(fn, ...) + local ok, err = pcall_check_trace(2, fn, ...) if ok then failure("Expected an error when calling function but no error generated", nil, 2) end @@ -464,7 +543,7 @@ function M.assert_str_matches(value, pattern, start, final, message) end local function _assert_error_msg_equals(stripFileAndLine, expectedMsg, func, ...) - local no_error, error_msg = pcall(func, ...) + local no_error, error_msg = pcall_check_trace(3, func, ...) if no_error then local failure_message = string.format( 'Function successfully returned: %s\nExpected error: %s', @@ -530,7 +609,7 @@ end -- @func fn -- @param ... arguments for function function M.assert_error_msg_contains(expected_partial, fn, ...) - local no_error, error_msg = pcall(fn, ...) + local no_error, error_msg = pcall_check_trace(2, fn, ...) log.info('Assert error message %s contains %s', error_msg, expected_partial) if no_error then local failure_message = string.format( @@ -553,7 +632,7 @@ end -- @func fn -- @param ... arguments for function function M.assert_error_msg_matches(pattern, fn, ...) - local no_error, error_msg = pcall(fn, ...) + local no_error, error_msg = pcall_check_trace(2, fn, ...) if no_error then local failure_message = string.format( 'Function successfully returned: %s\nExpected error matching: %s', @@ -578,7 +657,7 @@ end -- @func fn -- @param ... arguments for function function M.assert_error_covers(expected, fn, ...) - local ok, actual = pcall(fn, ...) + local ok, actual = pcall_check_trace(2, fn, ...) if ok then fail_fmt(2, nil, 'Function successfully returned: %s\nExpected error: %s', diff --git a/luatest/utils.lua b/luatest/utils.lua index 38375daf..746b40e2 100644 --- a/luatest/utils.lua +++ b/luatest/utils.lua @@ -191,4 +191,9 @@ function utils.is_tarantool_binary(path) return path:find('^.*/tarantool[^/]*$') ~= nil end +-- Return args as table with 'n' set to args number. +function utils.table_pack(...) + return {n = select('#', ...), ...} +end + return utils diff --git a/test/luaunit/assertions_error_test.lua b/test/luaunit/assertions_error_test.lua index a385f301..e3865966 100644 --- a/test/luaunit/assertions_error_test.lua +++ b/test/luaunit/assertions_error_test.lua @@ -4,6 +4,7 @@ local g = t.group() local helper = require('test.helpers.general') local assert_failure = helper.assert_failure local assert_failure_equals = helper.assert_failure_equals +local assert_failure_contains = helper.assert_failure_contains local function f() end @@ -17,6 +18,28 @@ local function f_with_table_error() error(setmetatable({this_table="has error"}, ts)) end +local f_check_trace = function(level) + box.error(box.error.UNKNOWN, level) +end + +local line = debug.getinfo(1, 'l').currentline + 2 +local f_check_trace_wrapper = function() + f_check_trace(2) +end + +local _, err = pcall(f_check_trace_wrapper) +local box_error_has_level = err:unpack().trace[1].line == line + +local f_check_success = function() + return {1, 'foo'} +end + +local THIS_MODULE = debug.getinfo(1, 'S').short_src + +g.after_each(function() + t.private.check_trace_module = nil +end) + function g.test_assert_error() local x = 1 @@ -51,6 +74,10 @@ function g.test_assert_error() -- error generated as table t.assert_error(f_with_table_error, 1) + -- test assert failure due to unexpected error trace + t.private.check_trace_module = THIS_MODULE + assert_failure_contains('Unexpected error trace, expected:', + t.assert_error, f_check_trace, 1) end function g.test_assert_errorMsgContains() @@ -64,6 +91,12 @@ function g.test_assert_errorMsgContains() -- error message is a table which converts to a string t.assert_error_msg_contains('This table has error', f_with_table_error, 1) + + -- test assert failure due to unexpected error trace + t.private.check_trace_module = THIS_MODULE + assert_failure_contains('Unexpected error trace, expected:', + t.assert_error_msg_contains, 'bar', f_check_trace, + 1) end function g.test_assert_error_msg_equals() @@ -103,6 +136,11 @@ function g.test_assert_error_msg_equals() -- expected table, error generated as string, no match assert_failure(t.assert_error_msg_equals, {1}, function() error("{1}") end, 33) + + -- test assert failure due to unexpected error trace + t.private.check_trace_module = THIS_MODULE + assert_failure_contains('Unexpected error trace, expected:', + t.assert_error_msg_equals, 'bar', f_check_trace, 1) end function g.test_assert_errorMsgMatches() @@ -117,6 +155,11 @@ function g.test_assert_errorMsgMatches() -- one space added to cause failure assert_failure(t.assert_error_msg_matches, ' This is an error', f_with_error, x) assert_failure(t.assert_error_msg_matches, "This", f_with_table_error, 33) + + -- test assert failure due to unexpected error trace + t.private.check_trace_module = THIS_MODULE + assert_failure_contains('Unexpected error trace, expected:', + t.assert_error_msg_matches, 'bar', f_check_trace, 1) end function g.test_assert_errorCovers() @@ -140,4 +183,36 @@ function g.test_assert_errorCovers() -- bad error coverage assert_failure(t.assert_error_covers, {b = 2}, function(a, b) error({a = a, b = b}) end, 1, 3) + + -- test assert failure due to unexpected error trace + t.private.check_trace_module = THIS_MODULE + assert_failure_contains('Unexpected error trace, expected:', + t.assert_error_covers, 'bar', f_check_trace, 1) +end + +function g.test_error_trace_check() + local foo = function(a) error(a) end + -- test when trace check is NOT required + t.assert_error_msg_content_equals('foo', foo, 'foo') + + local ftor = setmetatable({}, { + __call = function(_, ...) return f_check_trace(...) end + }) + t.private.check_trace_module = THIS_MODULE + + -- test when trace check IS required + if box_error_has_level then + t.assert_error_covers({code = box.error.UNKNOWN}, f_check_trace, 2) + t.assert_error_covers({code = box.error.UNKNOWN}, ftor, 2) + end + + -- check if there is no error then the returned value is reported correctly + assert_failure_contains('Function successfully returned: {1, "foo"}', + t.assert_error_msg_equals, 'bar', f_check_success) + -- test assert failure due to unexpected error type + assert_failure_contains('Error raised is not a box.error:', + t.assert_error, foo, 'foo') + -- test assert failure due to unexpected error trace + assert_failure_contains('Unexpected error trace, expected:', + t.assert_error, f_check_trace, 1) end diff --git a/test/utils_test.lua b/test/utils_test.lua index d9312b2e..0c451223 100644 --- a/test/utils_test.lua +++ b/test/utils_test.lua @@ -22,3 +22,11 @@ g.test_is_tarantool_binary = function() ("Unexpected result for %q"):format(path)) end end + +g.test_table_pack = function() + t.assert_equals(utils.table_pack(), {n = 0}) + t.assert_equals(utils.table_pack(1), {n = 1, 1}) + t.assert_equals(utils.table_pack(1, 2), {n = 2, 1, 2}) + t.assert_equals(utils.table_pack(1, 2, nil), {n = 3, 1, 2}) + t.assert_equals(utils.table_pack(1, 2, nil, 3), {n = 4, 1, 2, nil, 3}) +end