Skip to content

Commit

Permalink
[OMPIRBuilder][MLIR] Add support for target 'if' clause
Browse files Browse the repository at this point in the history
This patch implements support for handling the 'if' clause of OpenMP 'target'
constructs in the OMPIRBuilder and updates MLIR to LLVM IR translation of the
`omp.target` MLIR operation to make use of this new feature.
  • Loading branch information
skatrak committed Jan 10, 2025
1 parent cef1269 commit 3aea11b
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 83 deletions.
26 changes: 15 additions & 11 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2965,21 +2965,25 @@ class OpenMPIRBuilder {
/// \param NumThreads Number of teams specified in the thread_limit clause.
/// \param Inputs The input values to the region that will be passed.
/// as arguments to the outlined function.
/// \param IfCond value of the `if` clause.
/// \param BodyGenCB Callback that will generate the region code.
/// \param ArgAccessorFuncCB Callback that will generate accessors
/// instructions for passed in target arguments where neccessary
/// \param Dependencies A vector of DependData objects that carry
// dependency information as passed in the depend clause
// \param HasNowait Whether the target construct has a `nowait` clause or not.
InsertPointOrErrorTy createTarget(
const LocationDescription &Loc, bool IsOffloadEntry,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
/// dependency information as passed in the depend clause
/// \param HasNowait Whether the target construct has a `nowait` clause or
/// not.
InsertPointOrErrorTy
createTarget(const LocationDescription &Loc, bool IsOffloadEntry,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
Value *IfCond, GenMapInfoCallbackTy GenMapInfoCB,
TargetBodyGenCallbackTy BodyGenCB,
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
SmallVector<DependData> Dependencies = {},
bool HasNowait = false);

/// Returns __kmpc_for_static_init_* runtime function for the specified
/// size \a IVSize and sign \a IVSigned. Will create a distribute call
Expand Down
130 changes: 77 additions & 53 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7310,6 +7310,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
Value *IfCond,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
bool HasNoWait = false) {
Expand Down Expand Up @@ -7354,9 +7355,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
return Error::success();
};

// If we don't have an ID for the target region, it means an offload entry
// wasn't created. In this case we just run the host fallback directly.
if (!OutlinedFnID) {
auto &&EmitTargetCallElse =
[&](OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
// Assume no error was returned because EmitTargetCallFallbackCB doesn't
// produce any.
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
Expand All @@ -7372,65 +7373,87 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
}());

Builder.restoreIP(AfterIP);
return;
}
return Error::success();
};

auto &&EmitTargetCallThen =
[&](OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
OpenMPIRBuilder::TargetDataInfo Info(
/*RequiresDevicePointerInfo=*/false,
/*SeparateBeginEndCalls=*/true);

OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
RTArgs, MapInfo,
/*IsNonContiguous=*/true,
/*ForEndCall=*/false);

SmallVector<Value *, 3> NumTeamsC;
SmallVector<Value *, 3> NumThreadsC;
for (auto V : NumTeams)
NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
for (auto V : NumThreads)
NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));

unsigned NumTargetItems = Info.NumberOfPtrs;
// TODO: Use correct device ID
Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
uint32_t SrcLocStrSize;
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
llvm::omp::IdentFlag(0), 0);
// TODO: Use correct NumIterations
Value *NumIterations = Builder.getInt64(0);
// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);

KArgs = OpenMPIRBuilder::TargetKernelArgs(
NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
DynCGGroupMem, HasNoWait);

// Assume no error was returned because TaskBodyCB and
// EmitTargetCallFallbackCB don't produce any.
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
if (RequiresOuterTargetTask)
return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
Dependencies, HasNoWait);

return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
EmitTargetCallFallbackCB, KArgs,
DeviceID, RTLoc, AllocaIP);
}());

OpenMPIRBuilder::TargetDataInfo Info(
/*RequiresDevicePointerInfo=*/false,
/*SeparateBeginEndCalls=*/true);
Builder.restoreIP(AfterIP);
return Error::success();
};

OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
RTArgs, MapInfo,
/*IsNonContiguous=*/true,
/*ForEndCall=*/false);
// If we don't have an ID for the target region, it means an offload entry
// wasn't created. In this case we just run the host fallback directly.
if (!OutlinedFnID) {
cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP()));
return;
}

SmallVector<Value *, 3> NumTeamsC;
SmallVector<Value *, 3> NumThreadsC;
for (auto V : NumTeams)
NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
for (auto V : NumThreads)
NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
// If there's no IF clause, only generate the kernel launch code path.
if (!IfCond) {
cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP()));
return;
}

unsigned NumTargetItems = Info.NumberOfPtrs;
// TODO: Use correct device ID
Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
uint32_t SrcLocStrSize;
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
llvm::omp::IdentFlag(0), 0);
// TODO: Use correct NumIterations
Value *NumIterations = Builder.getInt64(0);
// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);

KArgs = OpenMPIRBuilder::TargetKernelArgs(
NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
DynCGGroupMem, HasNoWait);

// Assume no error was returned because TaskBodyCB and
// EmitTargetCallFallbackCB don't produce any.
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
if (RequiresOuterTargetTask)
return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
Dependencies, HasNoWait);

return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
EmitTargetCallFallbackCB, KArgs,
DeviceID, RTLoc, AllocaIP);
}());

Builder.restoreIP(AfterIP);
cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen,
EmitTargetCallElse, AllocaIP));
}

OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
SmallVectorImpl<Value *> &Args, Value *IfCond,
GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
SmallVector<DependData> Dependencies, bool HasNowait) {
Expand All @@ -7455,7 +7478,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
NumThreads, Args, IfCond, GenMapInfoCB, Dependencies,
HasNowait);
return Builder.saveIP();
}

Expand Down
11 changes: 6 additions & 5 deletions llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6232,7 +6232,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
ASSERT_EXPECTED_INIT(
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
Builder.saveIP(), EntryInfo, -1, 0, Inputs,
Builder.saveIP(), EntryInfo, /*NumTeams=*/-1,
/*NumThreads=*/0, Inputs, /*IfCond=*/nullptr,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
OMPBuilder.finalize();
Expand Down Expand Up @@ -6343,8 +6344,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
ASSERT_EXPECTED_INIT(
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
EntryInfo, /*NumTeams=*/-1,
/*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
EntryInfo, /*NumTeams=*/-1, /*NumThreads=*/0,
CapturedArgs, /*IfCond=*/nullptr, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);

Expand Down Expand Up @@ -6500,8 +6501,8 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
ASSERT_EXPECTED_INIT(
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
EntryInfo, /*NumTeams=*/-1,
/*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
EntryInfo, /*NumTeams=*/-1, /*NumThreads=*/0,
CapturedArgs, /*IfCond=*/nullptr, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkBare(op, result);
checkDevice(op, result);
checkHasDeviceAddr(op, result);
checkIf(op, result);
checkInReduction(op, result);
checkIsDevicePtr(op, result);
// Privatization clauses are supported, except on some situations, so we
Expand Down Expand Up @@ -4112,11 +4111,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

llvm::Value *ifCond = nullptr;
if (Value targetIfCond = targetOp.getIfExpr())
ifCond = moduleTranslation.lookupValue(targetIfCond);

llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
defaultValTeams, defaultValThreads, kernelInput, genMapInfoCB, bodyCB,
argAccessorCB, dds, targetOp.getNowait());
defaultValTeams, defaultValThreads, kernelInput, ifCond, genMapInfoCB,
bodyCB, argAccessorCB, dds, targetOp.getNowait());

if (failed(handleError(afterIP, opInst)))
return failure();
Expand Down
68 changes: 68 additions & 0 deletions mlir/test/Target/LLVMIR/omptarget-if.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
llvm.func @target_if_variable(%x : i1) {
omp.target if(%x) {
omp.terminator
}
llvm.return
}

// CHECK-LABEL: define void @target_if_variable(
// CHECK-SAME: i1 %[[IF_COND:.*]])
// CHECK: br i1 %[[IF_COND]], label %[[THEN_LABEL:.*]], label %[[ELSE_LABEL:.*]]

// CHECK: [[THEN_LABEL]]:
// CHECK-NOT: {{^.*}}:
// CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
// CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
// CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]

// CHECK: [[OFFLOAD_FAIL_LABEL]]:
// CHECK-NEXT: call void @[[FALLBACK_FN:__omp_offloading_.*_.*_target_if_variable_l.*]]()
// CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]

// CHECK: [[OFFLOAD_CONT_LABEL]]:
// CHECK-NEXT: br label %[[END_LABEL:.*]]

// CHECK: [[ELSE_LABEL]]:
// CHECK-NEXT: call void @[[FALLBACK_FN]]()
// CHECK-NEXT: br label %[[END_LABEL]]

llvm.func @target_if_true() {
%0 = llvm.mlir.constant(true) : i1
omp.target if(%0) {
omp.terminator
}
llvm.return
}

// CHECK-LABEL: define void @target_if_true()
// CHECK-NOT: {{^.*}}:
// CHECK: br label %[[ENTRY:.*]]

// CHECK: [[ENTRY]]:
// CHECK-NOT: {{^.*}}:
// CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
// CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
// CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]

// CHECK: [[OFFLOAD_FAIL_LABEL]]:
// CHECK-NEXT: call void @[[FALLBACK_FN:.*]]()
// CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]

llvm.func @target_if_false() {
%0 = llvm.mlir.constant(false) : i1
omp.target if(%0) {
omp.terminator
}
llvm.return
}

// CHECK-LABEL: define void @target_if_false()
// CHECK-NEXT: br label %[[ENTRY:.*]]

// CHECK: [[ENTRY]]:
// CHECK-NEXT: call void @__omp_offloading_{{.*}}_{{.*}}_target_if_false_l{{.*}}()
}

11 changes: 0 additions & 11 deletions mlir/test/Target/LLVMIR/openmp-todo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -266,17 +266,6 @@ llvm.func @target_has_device_addr(%x : !llvm.ptr) {

// -----

llvm.func @target_if(%x : i1) {
// expected-error@below {{not yet implemented: Unhandled clause if in omp.target operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.target}}
omp.target if(%x) {
omp.terminator
}
llvm.return
}

// -----

omp.declare_reduction @add_f32 : f32
init {
^bb0(%arg: f32):
Expand Down

0 comments on commit 3aea11b

Please sign in to comment.