Skip to content

Commit

Permalink
[CIR][ABI][Lowering] Fixes calls with union type (#1119)
Browse files Browse the repository at this point in the history
This PR handles calls with unions passed by value in the calling
convention pass.

#### Implementation
As one may know, data layout for unions in CIR and in LLVM differ one
from another. In CIR we track all the union members, while in LLVM IR
only the largest one.

And here we need to take this difference into account: we need to find a
type of the largest member and treat it as the first (and only) union
member in order to preserve all the logic from the original codegen.

There is a method `StructType::getLargestMember` - but looks like it
produces different results (with the one I implemented or better to say
copy-pasted). Maybe it's done intentionally, I don't know.

The LLVM IR produced has also some difference from the original one. In
the original IR `gep` is emitted - and we can not do the same. If we
create `getMemberOp` we may fail on type checking for unions - since the
first member type may differ from the largest type. This is why we
create `bitcast` instead. Relates to the issue #1061
  • Loading branch information
gitoleg authored Nov 14, 2024
1 parent 16a027a commit d1ad076
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 21 deletions.
1 change: 0 additions & 1 deletion clang/include/clang/CIR/MissingFeatures.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ struct MissingFeatures {
static bool X86TypeClassification() { return false; }

static bool ABIClangTypeKind() { return false; }
static bool ABIEnterStructForCoercedAccess() { return false; }
static bool ABIFuncPtr() { return false; }
static bool ABIInRegAttribute() { return false; }
static bool ABINestedRecordLayout() { return false; }
Expand Down
13 changes: 6 additions & 7 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,13 +519,12 @@ void StructType::computeSizeAndAlignment(

// Found a nested union: recurse into it to fetch its largest member.
auto structMember = mlir::dyn_cast<StructType>(ty);
if (structMember && structMember.isUnion()) {
auto candidate = structMember.getLargestMember(dataLayout);
if (dataLayout.getTypeSize(candidate) > largestMemberSize) {
largestMember = candidate;
largestMemberSize = dataLayout.getTypeSize(largestMember);
}
} else if (dataLayout.getTypeSize(ty) > largestMemberSize) {
if (!largestMember ||
dataLayout.getTypeABIAlignment(ty) >
dataLayout.getTypeABIAlignment(largestMember) ||
(dataLayout.getTypeABIAlignment(ty) ==
dataLayout.getTypeABIAlignment(largestMember) &&
dataLayout.getTypeSize(ty) > largestMemberSize)) {
largestMember = ty;
largestMemberSize = dataLayout.getTypeSize(largestMember);
}
Expand Down
41 changes: 30 additions & 11 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ mlir::Value createCoercedBitcast(mlir::Value Src, mlir::Type DestTy,
CastKind::bitcast, Src);
}

// FIXME(cir): Create a custom rewriter class to abstract this away.
mlir::Value createBitcast(mlir::Value Src, mlir::Type Ty, LowerFunction &LF) {
return LF.getRewriter().create<CastOp>(Src.getLoc(), Ty, CastKind::bitcast,
Src);
}

/// Given a struct pointer that we are accessing some number of bytes out of it,
/// try to gep into the struct to get at its inner goodness. Dive as deep as
/// possible without entering an element with an in-memory size smaller than
Expand All @@ -67,6 +73,9 @@ mlir::Value enterStructPointerForCoercedAccess(mlir::Value SrcPtr,

mlir::Type FirstElt = SrcSTy.getMembers()[0];

if (SrcSTy.isUnion())
FirstElt = SrcSTy.getLargestMember(CGF.LM.getDataLayout().layout);

// If the first elt is at least as large as what we're looking for, or if the
// first element is the same size as the whole struct, we can enter it. The
// comparison must be made on the store size and not the alloca size. Using
Expand All @@ -76,10 +85,26 @@ mlir::Value enterStructPointerForCoercedAccess(mlir::Value SrcPtr,
FirstEltSize < CGF.LM.getDataLayout().getTypeStoreSize(SrcSTy))
return SrcPtr;

cir_cconv_assert_or_abort(
!cir::MissingFeatures::ABIEnterStructForCoercedAccess(), "NYI");
return SrcPtr; // FIXME: This is a temporary workaround for the assertion
// above.
auto &rw = CGF.getRewriter();
auto *ctxt = rw.getContext();
auto ptrTy = PointerType::get(ctxt, FirstElt);
if (mlir::isa<StructType>(SrcPtr.getType())) {
auto addr = SrcPtr;
if (auto load = mlir::dyn_cast<LoadOp>(SrcPtr.getDefiningOp()))
addr = load.getAddr();
cir_cconv_assert(mlir::isa<PointerType>(addr.getType()));
// we can not use getMemberOp here since we need a pointer to the first
// element. And in the case of unions we pick a type of the largest elt,
// that may or may not be the first one. Thus, getMemberOp verification
// may fail.
auto cast = createBitcast(addr, ptrTy, CGF);
SrcPtr = rw.create<LoadOp>(SrcPtr.getLoc(), cast);
}

if (auto sty = mlir::dyn_cast<StructType>(SrcPtr.getType()))
return enterStructPointerForCoercedAccess(SrcPtr, sty, DstSize, CGF);

return SrcPtr;
}

/// Convert a value Val to the specific Ty where both
Expand Down Expand Up @@ -141,12 +166,6 @@ static mlir::Value coerceIntOrPtrToIntOrPtr(mlir::Value val, mlir::Type typ,
return val;
}

// FIXME(cir): Create a custom rewriter class to abstract this away.
mlir::Value createBitcast(mlir::Value Src, mlir::Type Ty, LowerFunction &LF) {
return LF.getRewriter().create<CastOp>(Src.getLoc(), Ty, CastKind::bitcast,
Src);
}

AllocaOp createTmpAlloca(LowerFunction &LF, mlir::Location loc, mlir::Type ty) {
auto &rw = LF.getRewriter();
auto *ctxt = rw.getContext();
Expand Down Expand Up @@ -302,7 +321,7 @@ mlir::Value createCoercedValue(mlir::Value Src, mlir::Type Ty,
// extension or truncation to the desired type.
if ((mlir::isa<IntType>(Ty) || mlir::isa<PointerType>(Ty)) &&
(mlir::isa<IntType>(SrcTy) || mlir::isa<PointerType>(SrcTy))) {
cir_cconv_unreachable("NYI");
return coerceIntOrPtrToIntOrPtr(Src, Ty, CGF);
}

// If load is legal, just bitcast the src pointer.
Expand Down
32 changes: 31 additions & 1 deletion clang/test/CIR/CallConvLowering/AArch64/union.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,34 @@ void foo(U u) {}
U init() {
U u;
return u;
}
}

typedef union {

struct {
short a;
char b;
char c;
};

int x;
} A;

void passA(A x) {}

// CIR: cir.func {{.*@callA}}()
// CIR: %[[#V0:]] = cir.alloca !ty_A, !cir.ptr<!ty_A>, ["x"] {alignment = 4 : i64}
// CIR: %[[#V1:]] = cir.cast(bitcast, %[[#V0:]] : !cir.ptr<!ty_A>), !cir.ptr<!s32i>
// CIR: %[[#V2:]] = cir.load %[[#V1]] : !cir.ptr<!s32i>, !s32i
// CIR: %[[#V3:]] = cir.cast(integral, %[[#V2]] : !s32i), !u64i
// CIR: cir.call @passA(%[[#V3]]) : (!u64i) -> ()

// LLVM: void @callA()
// LLVM: %[[#V0:]] = alloca %union.A, i64 1, align 4
// LLVM: %[[#V1:]] = load i32, ptr %[[#V0]], align 4
// LLVM: %[[#V2:]] = sext i32 %[[#V1]] to i64
// LLVM: call void @passA(i64 %[[#V2]])
void callA() {
A x;
passA(x);
}
2 changes: 1 addition & 1 deletion clang/test/CIR/Lowering/unions.cir
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ module {
cir.global external @u2 = #cir.zero : !ty_U2_
cir.global external @u3 = #cir.zero : !ty_U3_
// CHECK: llvm.mlir.global external @u2() {addr_space = 0 : i32} : !llvm.struct<"union.U2", (f64)>
// CHECK: llvm.mlir.global external @u3() {addr_space = 0 : i32} : !llvm.struct<"union.U3", (i32)>
// CHECK: llvm.mlir.global external @u3() {addr_space = 0 : i32} : !llvm.struct<"union.U3", (struct<"union.U1", (i32)>)>

// CHECK: llvm.func @test
cir.func @test(%arg0: !cir.ptr<!ty_U1_>) {
Expand Down

0 comments on commit d1ad076

Please sign in to comment.