diff --git a/luatest/assertions.lua b/luatest/assertions.lua index edf768c..bbf26fa 100644 --- a/luatest/assertions.lua +++ b/luatest/assertions.lua @@ -10,6 +10,8 @@ local mismatch_formatter = require('luatest.mismatch_formatter') local pp = require('luatest.pp') local log = require('luatest.log') local utils = require('luatest.utils') +local tarantool = require('tarantool') +local ffi = require('ffi') local prettystr = pp.tostring local prettystr_pairs = pp.tostring_pair @@ -18,6 +20,8 @@ local M = {} local xfail = false +local box_error_type = ffi.typeof(box.error.new(box.error.UNKNOWN)) + -- private exported functions (for testing) M.private = {} @@ -85,12 +89,87 @@ local function error_msg_equality(actual, expected, deep_analysis) end M.private.error_msg_equality = error_msg_equality +-- +-- The wrapper is used when trace check is required. See pcall_check_trace. +-- +-- Without wrapper the trace will point to the pcall implementation. So trace +-- check is not strict enough (the trace can point to any pcall in below in +-- call trace). +-- +local trace_line = debug.getinfo(1, 'l').currentline + 2 +local function wrapped_call(fn, ...) + local res = utils.table_pack(fn(...)) + -- With `return fn(...)` wrapper does not work due to tail call + -- optimization. + return unpack(res, 1, res.n) +end + +-- Expected trace for trace check. See pcall_check_trace. +local wrapped_trace = { + file = debug.getinfo(1, 'S').short_src, + line = trace_line, +} + +-- Used in tests to force check for given module. +M.private.check_trace_module = nil + +-- +-- Return true if error trace check is required for function. Basically it is +-- just a wrapper around Tarantool's utils.proper_trace_required. Additionally +-- old Tarantool versions where this function is not present are handled. +-- +local function trace_check_is_required(fn) + local src = debug.getinfo(fn, 'S').short_src + if M.private.check_trace_module == src then + return true + end + if tarantool._internal ~= nil and + tarantool._internal.trace_check_is_required ~= nil then + local path = debug.getinfo(fn, 'S').short_src + return tarantool._internal.trace_check_is_required(path) + end + return false +end + +-- +-- Substitute for pcall but additionally checks error trace if required. +-- +-- The error should be box.error and trace should point to the place +-- where fn is called. +-- +-- level is used to set proper level in error assertions that use this function. +-- +local function pcall_check_trace(level, fn, ...) + local fn_explicit = fn + if type(fn) ~= 'function' then + fn_explicit = debug.getmetatable(fn).__call + end + if not trace_check_is_required(fn_explicit) then + return pcall(fn, ...) + end + local ok, err = pcall(wrapped_call, fn, ...) + if ok then + return ok, err + end + if type(err) ~= 'cdata' or ffi.typeof(err) ~= box_error_type then + fail_fmt(level + 1, nil, 'Error raised is not a box.error: %s', + prettystr(err)) + end + local unpacked = err:unpack() + if not comparator.equals(unpacked.trace[1], wrapped_trace) then + fail_fmt(level + 1, nil, + 'Unexpected error trace, expected: %s, actual: %s', + prettystr(wrapped_trace), prettystr(unpacked.trace[1])) + end + return ok, err +end + --- Check that calling fn raises an error. -- -- @func fn -- @param ... arguments for function function M.assert_error(fn, ...) - local ok, err = pcall(fn, ...) + local ok, err = pcall_check_trace(2, fn, ...) if ok then failure("Expected an error when calling function but no error generated", nil, 2) end @@ -464,7 +543,7 @@ function M.assert_str_matches(value, pattern, start, final, message) end local function _assert_error_msg_equals(stripFileAndLine, expectedMsg, func, ...) - local no_error, error_msg = pcall(func, ...) + local no_error, error_msg = pcall_check_trace(3, func, ...) if no_error then local failure_message = string.format( 'Function successfully returned: %s\nExpected error: %s', @@ -530,7 +609,7 @@ end -- @func fn -- @param ... arguments for function function M.assert_error_msg_contains(expected_partial, fn, ...) - local no_error, error_msg = pcall(fn, ...) + local no_error, error_msg = pcall_check_trace(2, fn, ...) log.info('Assert error message %s contains %s', error_msg, expected_partial) if no_error then local failure_message = string.format( @@ -553,7 +632,7 @@ end -- @func fn -- @param ... arguments for function function M.assert_error_msg_matches(pattern, fn, ...) - local no_error, error_msg = pcall(fn, ...) + local no_error, error_msg = pcall_check_trace(2, fn, ...) if no_error then local failure_message = string.format( 'Function successfully returned: %s\nExpected error matching: %s', @@ -578,7 +657,7 @@ end -- @func fn -- @param ... arguments for function function M.assert_error_covers(expected, fn, ...) - local ok, actual = pcall(fn, ...) + local ok, actual = pcall_check_trace(2, fn, ...) if ok then fail_fmt(2, nil, 'Function successfully returned: %s\nExpected error: %s', diff --git a/luatest/utils.lua b/luatest/utils.lua index 38375da..746b40e 100644 --- a/luatest/utils.lua +++ b/luatest/utils.lua @@ -191,4 +191,9 @@ function utils.is_tarantool_binary(path) return path:find('^.*/tarantool[^/]*$') ~= nil end +-- Return args as table with 'n' set to args number. +function utils.table_pack(...) + return {n = select('#', ...), ...} +end + return utils diff --git a/test/luaunit/assertions_error_test.lua b/test/luaunit/assertions_error_test.lua index a385f30..e386596 100644 --- a/test/luaunit/assertions_error_test.lua +++ b/test/luaunit/assertions_error_test.lua @@ -4,6 +4,7 @@ local g = t.group() local helper = require('test.helpers.general') local assert_failure = helper.assert_failure local assert_failure_equals = helper.assert_failure_equals +local assert_failure_contains = helper.assert_failure_contains local function f() end @@ -17,6 +18,28 @@ local function f_with_table_error() error(setmetatable({this_table="has error"}, ts)) end +local f_check_trace = function(level) + box.error(box.error.UNKNOWN, level) +end + +local line = debug.getinfo(1, 'l').currentline + 2 +local f_check_trace_wrapper = function() + f_check_trace(2) +end + +local _, err = pcall(f_check_trace_wrapper) +local box_error_has_level = err:unpack().trace[1].line == line + +local f_check_success = function() + return {1, 'foo'} +end + +local THIS_MODULE = debug.getinfo(1, 'S').short_src + +g.after_each(function() + t.private.check_trace_module = nil +end) + function g.test_assert_error() local x = 1 @@ -51,6 +74,10 @@ function g.test_assert_error() -- error generated as table t.assert_error(f_with_table_error, 1) + -- test assert failure due to unexpected error trace + t.private.check_trace_module = THIS_MODULE + assert_failure_contains('Unexpected error trace, expected:', + t.assert_error, f_check_trace, 1) end function g.test_assert_errorMsgContains() @@ -64,6 +91,12 @@ function g.test_assert_errorMsgContains() -- error message is a table which converts to a string t.assert_error_msg_contains('This table has error', f_with_table_error, 1) + + -- test assert failure due to unexpected error trace + t.private.check_trace_module = THIS_MODULE + assert_failure_contains('Unexpected error trace, expected:', + t.assert_error_msg_contains, 'bar', f_check_trace, + 1) end function g.test_assert_error_msg_equals() @@ -103,6 +136,11 @@ function g.test_assert_error_msg_equals() -- expected table, error generated as string, no match assert_failure(t.assert_error_msg_equals, {1}, function() error("{1}") end, 33) + + -- test assert failure due to unexpected error trace + t.private.check_trace_module = THIS_MODULE + assert_failure_contains('Unexpected error trace, expected:', + t.assert_error_msg_equals, 'bar', f_check_trace, 1) end function g.test_assert_errorMsgMatches() @@ -117,6 +155,11 @@ function g.test_assert_errorMsgMatches() -- one space added to cause failure assert_failure(t.assert_error_msg_matches, ' This is an error', f_with_error, x) assert_failure(t.assert_error_msg_matches, "This", f_with_table_error, 33) + + -- test assert failure due to unexpected error trace + t.private.check_trace_module = THIS_MODULE + assert_failure_contains('Unexpected error trace, expected:', + t.assert_error_msg_matches, 'bar', f_check_trace, 1) end function g.test_assert_errorCovers() @@ -140,4 +183,36 @@ function g.test_assert_errorCovers() -- bad error coverage assert_failure(t.assert_error_covers, {b = 2}, function(a, b) error({a = a, b = b}) end, 1, 3) + + -- test assert failure due to unexpected error trace + t.private.check_trace_module = THIS_MODULE + assert_failure_contains('Unexpected error trace, expected:', + t.assert_error_covers, 'bar', f_check_trace, 1) +end + +function g.test_error_trace_check() + local foo = function(a) error(a) end + -- test when trace check is NOT required + t.assert_error_msg_content_equals('foo', foo, 'foo') + + local ftor = setmetatable({}, { + __call = function(_, ...) return f_check_trace(...) end + }) + t.private.check_trace_module = THIS_MODULE + + -- test when trace check IS required + if box_error_has_level then + t.assert_error_covers({code = box.error.UNKNOWN}, f_check_trace, 2) + t.assert_error_covers({code = box.error.UNKNOWN}, ftor, 2) + end + + -- check if there is no error then the returned value is reported correctly + assert_failure_contains('Function successfully returned: {1, "foo"}', + t.assert_error_msg_equals, 'bar', f_check_success) + -- test assert failure due to unexpected error type + assert_failure_contains('Error raised is not a box.error:', + t.assert_error, foo, 'foo') + -- test assert failure due to unexpected error trace + assert_failure_contains('Unexpected error trace, expected:', + t.assert_error, f_check_trace, 1) end diff --git a/test/utils_test.lua b/test/utils_test.lua index d9312b2..0c45122 100644 --- a/test/utils_test.lua +++ b/test/utils_test.lua @@ -22,3 +22,11 @@ g.test_is_tarantool_binary = function() ("Unexpected result for %q"):format(path)) end end + +g.test_table_pack = function() + t.assert_equals(utils.table_pack(), {n = 0}) + t.assert_equals(utils.table_pack(1), {n = 1, 1}) + t.assert_equals(utils.table_pack(1, 2), {n = 2, 1, 2}) + t.assert_equals(utils.table_pack(1, 2, nil), {n = 3, 1, 2}) + t.assert_equals(utils.table_pack(1, 2, nil, 3), {n = 4, 1, 2, nil, 3}) +end