Skip to content

Commit

Permalink
Add trace check for error assertions
Browse files Browse the repository at this point in the history
In the scope of the referenced Tarantool issue we are going to change
error trace of API to point to the caller place. The error should be
box.error and the trace will be changed for several modules at the
beginning (fix all API at once is difficult).

We are going to use existing tests to test the change. In particular in
case of Luatest let's check trace in `assert_error*` assertions besides
the main assertion.

Required for tarantool/tarantool#9914
  • Loading branch information
nshy committed Jun 3, 2024
1 parent d985997 commit b8b16d2
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Improve `luatest.log` function if a `nil` value is passed (gh-360).
- Added `assert_error_covers`.
- Add more logs (gh-326).
- Make `assert_error_*` additionally check error trace if required.

## 1.0.1

Expand Down
89 changes: 84 additions & 5 deletions luatest/assertions.lua
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ local mismatch_formatter = require('luatest.mismatch_formatter')
local pp = require('luatest.pp')
local log = require('luatest.log')
local utils = require('luatest.utils')
local tarantool = require('tarantool')
local ffi = require('ffi')

local prettystr = pp.tostring
local prettystr_pairs = pp.tostring_pair
Expand All @@ -18,6 +20,8 @@ local M = {}

local xfail = false

local box_error_type = ffi.typeof(box.error.new(box.error.UNKNOWN))

-- private exported functions (for testing)
M.private = {}

Expand Down Expand Up @@ -85,12 +89,87 @@ local function error_msg_equality(actual, expected, deep_analysis)
end
M.private.error_msg_equality = error_msg_equality

--
-- The wrapper is used when trace check is required. See pcall_check_trace.
--
-- Without wrapper the trace will point to the pcall implementation. So trace
-- check is not strict enough (the trace can point to any pcall in below in
-- call trace).
--
local trace_line = debug.getinfo(1, 'l').currentline + 2
local function wrapped_call(fn, ...)
local res = utils.table_pack(fn(...))
-- With `return fn(...)` wrapper does not work due to tail call
-- optimization.
return unpack(res, 1, res.n)
end

-- Expected trace for trace check. See pcall_check_trace.
local wrapped_trace = {
file = debug.getinfo(1, 'S').short_src,
line = trace_line,
}

-- Used in tests to force check for given module.
M.private.check_trace_module = nil

--
-- Return true if error trace check is required for function. Basically it is
-- just a wrapper around Tarantool's utils.proper_trace_required. Additionally
-- old Tarantool versions where this function is not present are handled.
--
local function trace_check_is_required(fn)
local src = debug.getinfo(fn, 'S').short_src
if M.private.check_trace_module == src then
return true
end
if tarantool._internal ~= nil and
tarantool._internal.trace_check_is_required ~= nil then
local path = debug.getinfo(fn, 'S').short_src
return tarantool._internal.trace_check_is_required(path)
end
return false
end

--
-- Substitute for pcall but additionally checks error trace if required.
--
-- The error should be box.error and trace should point to the place
-- where fn is called.
--
-- level is used to set proper level in error assertions that use this function.
--
local function pcall_check_trace(level, fn, ...)
local fn_explicit = fn
if type(fn) ~= 'function' then
fn_explicit = debug.getmetatable(fn).__call
end
if not trace_check_is_required(fn_explicit) then
return pcall(fn, ...)
end
local ok, err = pcall(wrapped_call, fn, ...)
if ok then
return ok, err
end
if type(err) ~= 'cdata' or ffi.typeof(err) ~= box_error_type then
fail_fmt(level + 1, nil, 'Error raised is not a box.error: %s',
prettystr(err))
end
local unpacked = err:unpack()
if not comparator.equals(unpacked.trace[1], wrapped_trace) then
fail_fmt(level + 1, nil,
'Unexpected error trace, expected: %s, actual: %s',
prettystr(wrapped_trace), prettystr(unpacked.trace[1]))
end
return ok, err
end

--- Check that calling fn raises an error.
--
-- @func fn
-- @param ... arguments for function
function M.assert_error(fn, ...)
local ok, err = pcall(fn, ...)
local ok, err = pcall_check_trace(2, fn, ...)
if ok then
failure("Expected an error when calling function but no error generated", nil, 2)
end
Expand Down Expand Up @@ -464,7 +543,7 @@ function M.assert_str_matches(value, pattern, start, final, message)
end

local function _assert_error_msg_equals(stripFileAndLine, expectedMsg, func, ...)
local no_error, error_msg = pcall(func, ...)
local no_error, error_msg = pcall_check_trace(3, func, ...)
if no_error then
local failure_message = string.format(
'Function successfully returned: %s\nExpected error: %s',
Expand Down Expand Up @@ -530,7 +609,7 @@ end
-- @func fn
-- @param ... arguments for function
function M.assert_error_msg_contains(expected_partial, fn, ...)
local no_error, error_msg = pcall(fn, ...)
local no_error, error_msg = pcall_check_trace(2, fn, ...)
log.info('Assert error message %s contains %s', error_msg, expected_partial)
if no_error then
local failure_message = string.format(
Expand All @@ -553,7 +632,7 @@ end
-- @func fn
-- @param ... arguments for function
function M.assert_error_msg_matches(pattern, fn, ...)
local no_error, error_msg = pcall(fn, ...)
local no_error, error_msg = pcall_check_trace(2, fn, ...)
if no_error then
local failure_message = string.format(
'Function successfully returned: %s\nExpected error matching: %s',
Expand All @@ -578,7 +657,7 @@ end
-- @func fn
-- @param ... arguments for function
function M.assert_error_covers(expected, fn, ...)
local ok, actual = pcall(fn, ...)
local ok, actual = pcall_check_trace(2, fn, ...)
if ok then
fail_fmt(2, nil,
'Function successfully returned: %s\nExpected error: %s',
Expand Down
5 changes: 5 additions & 0 deletions luatest/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,9 @@ function utils.is_tarantool_binary(path)
return path:find('^.*/tarantool[^/]*$') ~= nil
end

-- Return args as table with 'n' set to args number.
function utils.table_pack(...)
return {n = select('#', ...), ...}
end

return utils
75 changes: 75 additions & 0 deletions test/luaunit/assertions_error_test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ local g = t.group()
local helper = require('test.helpers.general')
local assert_failure = helper.assert_failure
local assert_failure_equals = helper.assert_failure_equals
local assert_failure_contains = helper.assert_failure_contains

local function f()
end
Expand All @@ -17,6 +18,28 @@ local function f_with_table_error()
error(setmetatable({this_table="has error"}, ts))
end

local f_check_trace = function(level)
box.error(box.error.UNKNOWN, level)
end

local line = debug.getinfo(1, 'l').currentline + 2
local f_check_trace_wrapper = function()
f_check_trace(2)
end

local _, err = pcall(f_check_trace_wrapper)
local box_error_has_level = err:unpack().trace[1].line == line

local f_check_success = function()
return {1, 'foo'}
end

local THIS_MODULE = debug.getinfo(1, 'S').short_src

g.after_each(function()
t.private.check_trace_module = nil
end)

function g.test_assert_error()
local x = 1

Expand Down Expand Up @@ -51,6 +74,10 @@ function g.test_assert_error()
-- error generated as table
t.assert_error(f_with_table_error, 1)

-- test assert failure due to unexpected error trace
t.private.check_trace_module = THIS_MODULE
assert_failure_contains('Unexpected error trace, expected:',
t.assert_error, f_check_trace, 1)
end

function g.test_assert_errorMsgContains()
Expand All @@ -64,6 +91,12 @@ function g.test_assert_errorMsgContains()

-- error message is a table which converts to a string
t.assert_error_msg_contains('This table has error', f_with_table_error, 1)

-- test assert failure due to unexpected error trace
t.private.check_trace_module = THIS_MODULE
assert_failure_contains('Unexpected error trace, expected:',
t.assert_error_msg_contains, 'bar', f_check_trace,
1)
end

function g.test_assert_error_msg_equals()
Expand Down Expand Up @@ -103,6 +136,11 @@ function g.test_assert_error_msg_equals()

-- expected table, error generated as string, no match
assert_failure(t.assert_error_msg_equals, {1}, function() error("{1}") end, 33)

-- test assert failure due to unexpected error trace
t.private.check_trace_module = THIS_MODULE
assert_failure_contains('Unexpected error trace, expected:',
t.assert_error_msg_equals, 'bar', f_check_trace, 1)
end

function g.test_assert_errorMsgMatches()
Expand All @@ -117,6 +155,11 @@ function g.test_assert_errorMsgMatches()
-- one space added to cause failure
assert_failure(t.assert_error_msg_matches, ' This is an error', f_with_error, x)
assert_failure(t.assert_error_msg_matches, "This", f_with_table_error, 33)

-- test assert failure due to unexpected error trace
t.private.check_trace_module = THIS_MODULE
assert_failure_contains('Unexpected error trace, expected:',
t.assert_error_msg_matches, 'bar', f_check_trace, 1)
end

function g.test_assert_errorCovers()
Expand All @@ -140,4 +183,36 @@ function g.test_assert_errorCovers()
-- bad error coverage
assert_failure(t.assert_error_covers, {b = 2},
function(a, b) error({a = a, b = b}) end, 1, 3)

-- test assert failure due to unexpected error trace
t.private.check_trace_module = THIS_MODULE
assert_failure_contains('Unexpected error trace, expected:',
t.assert_error_covers, 'bar', f_check_trace, 1)
end

function g.test_error_trace_check()
local foo = function(a) error(a) end
-- test when trace check is NOT required
t.assert_error_msg_content_equals('foo', foo, 'foo')

local ftor = setmetatable({}, {
__call = function(_, ...) return f_check_trace(...) end
})
t.private.check_trace_module = THIS_MODULE

-- test when trace check IS required
if box_error_has_level then
t.assert_error_covers({code = box.error.UNKNOWN}, f_check_trace, 2)
t.assert_error_covers({code = box.error.UNKNOWN}, ftor, 2)
end

-- check if there is no error then the returned value is reported correctly
assert_failure_contains('Function successfully returned: {1, "foo"}',
t.assert_error_msg_equals, 'bar', f_check_success)
-- test assert failure due to unexpected error type
assert_failure_contains('Error raised is not a box.error:',
t.assert_error, foo, 'foo')
-- test assert failure due to unexpected error trace
assert_failure_contains('Unexpected error trace, expected:',
t.assert_error, f_check_trace, 1)
end
8 changes: 8 additions & 0 deletions test/utils_test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,11 @@ g.test_is_tarantool_binary = function()
("Unexpected result for %q"):format(path))
end
end

g.test_table_pack = function()
t.assert_equals(utils.table_pack(), {n = 0})
t.assert_equals(utils.table_pack(1), {n = 1, 1})
t.assert_equals(utils.table_pack(1, 2), {n = 2, 1, 2})
t.assert_equals(utils.table_pack(1, 2, nil), {n = 3, 1, 2})
t.assert_equals(utils.table_pack(1, 2, nil, 3), {n = 4, 1, 2, nil, 3})
end

0 comments on commit b8b16d2

Please sign in to comment.