Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix generated tests with anonymous unions constructed from bytes #624

Merged
merged 5 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 31 additions & 23 deletions server/src/Tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,16 +280,18 @@ std::shared_ptr<StructValueView> KTestObjectParser::structView(const std::vector
size_t structEndOffset = offsetInBits + curStruct.size;
size_t fieldIndex = 0;
bool dirtyInitializedStruct = false;
bool isInitializedStruct = curStruct.subType == types::SubType::Struct;
for (const auto &field: curStruct.fields) {
bool dirtyInitializedField = false;
bool isInitializedField = true;
size_t fieldLen = typesHandler.typeSize(field.type);
size_t fieldStartOffset = offsetInBits + field.offset;
size_t fieldEndOffset = fieldStartOffset + fieldLen;
if (curStruct.subType == types::SubType::Union) {
prevFieldEndOffset = offsetInBits;
}

auto dirtyCheck = [&](int i) {
auto dirtyCheck = [&](size_t i) {
if (i >= byteArray.size()) {
LOG_S(ERROR) << "Bad type size info: " << field.name << " index: " << fieldIndex;
} else if (byteArray[i] == 0) {
Expand All @@ -302,15 +304,16 @@ std::shared_ptr<StructValueView> KTestObjectParser::structView(const std::vector

if (prevFieldEndOffset < fieldStartOffset) {
// check an alignment gap
for (int i = prevFieldEndOffset/8; i < fieldStartOffset/8; ++i) {
for (size_t i = prevFieldEndOffset / 8; i < fieldStartOffset / 8; ++i) {
if (dirtyCheck(i)) {
break;
}
}
}
if (!dirtyInitializedField && curStruct.subType == types::SubType::Union) {
// check the rest of the union
for (int i = fieldEndOffset/8; i < structEndOffset/8; ++i) {
if (!dirtyInitializedField && (curStruct.subType == types::SubType::Union ||
fieldIndex + 1 == curStruct.fields.size())) {
// check the rest of the union or the last field of the struct
for (size_t i = fieldEndOffset / 8; i < structEndOffset / 8; ++i) {
if (dirtyCheck(i)) {
break;
}
Expand All @@ -325,6 +328,7 @@ std::shared_ptr<StructValueView> KTestObjectParser::structView(const std::vector
PrinterUtils::getFieldAccess(name, field), objects,
initReferences);
dirtyInitializedField |= sv->isDirtyInit();
isInitializedField = sv->isInitialized();
subViews.push_back(sv);
}
break;
Expand Down Expand Up @@ -392,34 +396,38 @@ std::shared_ptr<StructValueView> KTestObjectParser::structView(const std::vector
throw NoSuchTypeException(message);
}

if (!dirtyInitializedField && sizeOfFieldToInitUnion < fieldLen) {
if (!dirtyInitializedField && sizeOfFieldToInitUnion < fieldLen &&
curStruct.subType == types::SubType::Union) {
fieldIndexToInitUnion = fieldIndex;
sizeOfFieldToInitUnion = fieldLen;
} else {
dirtyInitializedStruct = true;
isInitializedStruct = true;
dirtyInitializedStruct = false;
}
if (curStruct.subType == types::SubType::Struct) {
dirtyInitializedStruct |= dirtyInitializedField;
isInitializedStruct &= isInitializedField;
}
prevFieldEndOffset = fieldEndOffset;
++fieldIndex;
}

std::optional<std::string> entryValue;
if (curStruct.subType == types::SubType::Union) {
if (fieldIndexToInitUnion == SIZE_MAX && !curStruct.name.empty()) {
// init by memory copy
entryValue = PrinterUtils::convertBytesToUnion(
curStruct.name,
arrayView(byteArray, lazyPointersArray,
types::Type::createSimpleTypeFromName("utbot_byte"),
curStruct.size,
offsetInBits, usage)->getEntryValue(nullptr));
dirtyInitializedStruct = false;
}
if (fieldIndexToInitUnion != SIZE_MAX) {
dirtyInitializedStruct = false;
}
if (!isInitializedStruct && !curStruct.name.empty() && !anonymousField) {
// init by memory copy
entryValue = PrinterUtils::convertBytesToStruct(
curStruct.name,
arrayView(byteArray, lazyPointersArray,
types::Type::createSimpleTypeFromName("utbot_byte"),
curStruct.size,
offsetInBits, usage)->getEntryValue(nullptr));
isInitializedStruct = true;
dirtyInitializedStruct = false;
}
if (!isInitializedStruct) {
dirtyInitializedStruct = false;
}
return std::make_shared<StructValueView>(curStruct, subViews, entryValue,
anonymousField, dirtyInitializedStruct, fieldIndexToInitUnion);
anonymousField, isInitializedStruct, dirtyInitializedStruct, fieldIndexToInitUnion);
}

std::string KTestObjectParser::primitiveCharView(const types::Type &type, std::string value) {
Expand Down
7 changes: 7 additions & 0 deletions server/src/Tests.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,15 +238,21 @@ namespace tests {
std::vector<std::shared_ptr<AbstractValueView>> _subViews,
std::optional<std::string> _entryValue,
bool _anonymous,
bool _isInit,
bool _dirtyInit,
size_t _fieldIndexToInitUnion)
: AbstractValueView(std::move(_subViews))
, entryValue(std::move(_entryValue))
, structInfo(_structInfo)
, anonymous(_anonymous)
, isInit(_isInit)
, dirtyInit(_dirtyInit)
, fieldIndexToInitUnion(_fieldIndexToInitUnion){}

bool isInitialized() const {
return isInit;
}

bool isDirtyInit() const {
return dirtyInit;
}
Expand Down Expand Up @@ -317,6 +323,7 @@ namespace tests {
std::optional<std::string> entryValue;

bool anonymous;
bool isInit;
bool dirtyInit;
size_t fieldIndexToInitUnion;
};
Expand Down
2 changes: 1 addition & 1 deletion server/src/utils/PrinterUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace PrinterUtils {
std::string convertToBytesFunctionName(const std::string &typeName) {
return StringUtils::stringFormat("from_bytes<%s>", typeName);
}
std::string convertBytesToUnion(const std::string &typeName, const std::string &bytes) {
std::string convertBytesToStruct(const std::string &typeName, const std::string &bytes) {
return StringUtils::stringFormat("%s(%s)", convertToBytesFunctionName(typeName), bytes);
}

Expand Down
2 changes: 1 addition & 1 deletion server/src/utils/PrinterUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ namespace PrinterUtils {

std::string convertToBytesFunctionName(std::string const &typeName);

std::string convertBytesToUnion(const std::string &typeName, const std::string &bytes);
std::string convertBytesToStruct(const std::string &typeName, const std::string &bytes);

std::string wrapperName(const std::string &declName,
utbot::ProjectContext const &projectContext,
Expand Down
3 changes: 2 additions & 1 deletion server/src/visitors/AbstractValueViewVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,15 @@ namespace visitor {
auto subViews = view ? &view->getSubViews() : nullptr;

bool oldFlag = inUnion;
inUnion = structInfo.subType == types::SubType::Union;
inUnion |= structInfo.subType == types::SubType::Union;
for (int i = 0; i < structInfo.fields.size(); ++i) {
auto const &field = structInfo.fields[i];
auto newName = PrinterUtils::getFieldAccess(name, field);
auto const *newView = (subViews && i < subViews->size()) ? (*subViews)[i].get() : nullptr;
auto newAccess = PrinterUtils::getFieldAccess(access, field);
visitAny(field.type, newName, newView, newAccess, depth + 1);
}
inUnion = oldFlag;
}

void AbstractValueViewVisitor::visitEnum(const types::Type &type,
Expand Down
2 changes: 1 addition & 1 deletion server/src/visitors/VerboseAssertsReturnValueVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace visitor {
auto signature = processExpect(type, gtestMacro, {PrinterUtils::fillVarName(access, PrinterUtils::EXPECTED),
getDecorateActualVarName(access)});
signature = changeSignatureToNullCheck(signature, type, view, access);
printer->strFunctionCall(signature.name, signature.args);
printer->strFunctionCall(signature.name, signature.args, SCNL, std::nullopt, true, 0, std::nullopt, inUnion);
}

void VerboseAssertsReturnValueVisitor::visitPointer(const types::Type &type,
Expand Down
34 changes: 2 additions & 32 deletions server/test/framework/Server_Tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,35 +456,6 @@ namespace {
}
}

void checkStructWithUnion_C(BaseTestGen &testGen) {
for (const auto &[methodName, methodDescription] :
testGen.tests.at(struct_with_union_c).methods) {
if (methodName == "struct_with_union_of_unnamed_type_as_return_type") {
checkTestCasePredicates(
methodDescription.testCases,
std::vector<TestCasePredicate>(
{[] (const tests::Tests::MethodTestCase& testCase) {
return stoi(testCase.paramValues[0].view->getEntryValue(nullptr)) <
stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) &&
testCase.returnValue.view->getEntryValue(nullptr) == "{{{'\\x99', -2.530171e-98}}}";
},
[] (const tests::Tests::MethodTestCase& testCase) {
return stoi(testCase.paramValues[0].view->getEntryValue(nullptr)) ==
stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) &&
StringUtils::startsWith(testCase.returnValue.view->getEntryValue(nullptr),
"{from_bytes<StructWithUnionOfUnnamedType_un>({");
},
[] (const tests::Tests::MethodTestCase& testCase) {
return stoi(testCase.paramValues[0].view->getEntryValue(nullptr)) >
stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) &&
testCase.returnValue.view->getEntryValue(nullptr) == "{{{'\\0', -2.530171e-98}}}";
}
}),
methodName);
}
}
}

void checkInnerBasicFunctions_C(BaseTestGen &testGen) {
for (const auto &[methodName, methodDescription] :
testGen.tests.at(inner_basic_functions_c).methods) {
Expand Down Expand Up @@ -2187,8 +2158,7 @@ namespace {
auto testGen = FileTestGen(*request, writer.get(), TESTMODE);
Status status = Server::TestsGenServiceImpl::ProcessBaseTestRequest(testGen, writer.get());
ASSERT_TRUE(status.ok()) << status.error_message();
EXPECT_GE(testUtils::getNumberOfTests(testGen.tests), 3);
checkStructWithUnion_C(testGen);
EXPECT_GE(testUtils::getNumberOfTests(testGen.tests), 6);

fs::path testsDirPath = getTestFilePath("tests");

Expand All @@ -2214,7 +2184,7 @@ namespace {
auto resultsMap = coverageGenerator.getTestResultMap();
auto tests = coverageGenerator.getTestsToLaunch();

StatusCountMap expectedStatusCountMap{ { testsgen::TEST_PASSED, 3 } };
StatusCountMap expectedStatusCountMap{ { testsgen::TEST_PASSED, 6 } };
testUtils::checkStatuses(resultsMap, tests);
}

Expand Down
8 changes: 4 additions & 4 deletions server/test/framework/Syntax_Tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1882,7 +1882,7 @@ namespace {
return stoi(testCase.paramValues[0].view->getEntryValue(nullptr)) ==
stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) &&
StringUtils::startsWith(testCase.returnValue.view->getEntryValue(nullptr),
"{from_bytes<StructWithStructInUnion::DeepUnion>({");
"{from_bytes<StructWithStructInUnion::DeepUnion>({");;
},
[] (const tests::Tests::MethodTestCase& testCase) {
return stoi(testCase.paramValues[0].view->getEntryValue(nullptr)) >
Expand All @@ -1903,18 +1903,18 @@ namespace {
std::vector<TestCasePredicate>(
{[] (const tests::Tests::MethodTestCase& testCase) {
return stoi(testCase.paramValues[0].view->getEntryValue(nullptr)) <
stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) &&
stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) &&
testCase.returnValue.view->getEntryValue(nullptr) == "{{{'\\x99', -2.530171e-98}}}";
},
[] (const tests::Tests::MethodTestCase& testCase) {
return stoi(testCase.paramValues[0].view->getEntryValue(nullptr)) ==
stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) &&
stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) &&
StringUtils::startsWith(testCase.returnValue.view->getEntryValue(nullptr),
"{from_bytes<StructWithUnionOfUnnamedType_un>({");
},
[] (const tests::Tests::MethodTestCase& testCase) {
return stoi(testCase.paramValues[0].view->getEntryValue(nullptr)) >
stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) &&
stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) &&
testCase.returnValue.view->getEntryValue(nullptr) == "{{{'\\0', -2.530171e-98}}}";
}
})
Expand Down
13 changes: 13 additions & 0 deletions server/test/suites/server/struct_with_union.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,17 @@ struct StructWithUnionOfUnnamedType struct_with_union_of_unnamed_type_as_return_
ans.un.ds.d = 1.0101;
}
return ans;
}

struct StructWithAnonymousUnion struct_with_anonymous_union_as_return_type(int a, int b) {
struct StructWithAnonymousUnion ans;
if (a > b) {
ans.ptr = 0;
} else if (a < b) {
ans.x = 153;
} else {
ans.c = 'k';
ans.d = 1.0101;
}
return ans;
}
13 changes: 13 additions & 0 deletions server/test/suites/server/struct_with_union.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,19 @@ struct StructWithUnionOfUnnamedType {
} un;
};

struct StructWithAnonymousUnion {
union {
int x;
struct {
char c;
double d;
};
long long *ptr;
};
};

struct StructWithUnionOfUnnamedType struct_with_union_of_unnamed_type_as_return_type(int a, int b);

struct StructWithAnonymousUnion struct_with_anonymous_union_as_return_type(int a, int b);

#endif // SIMPLE_TEST_PROJECT_STRUCT_WITH_UNION_H
Loading