diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 4ce47b1c05d9b0..b1a23996c7bdd2 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -2965,21 +2965,25 @@ class OpenMPIRBuilder { /// \param NumThreads Number of teams specified in the thread_limit clause. /// \param Inputs The input values to the region that will be passed. /// as arguments to the outlined function. + /// \param IfCond value of the `if` clause. /// \param BodyGenCB Callback that will generate the region code. /// \param ArgAccessorFuncCB Callback that will generate accessors /// instructions for passed in target arguments where neccessary /// \param Dependencies A vector of DependData objects that carry - // dependency information as passed in the depend clause - // \param HasNowait Whether the target construct has a `nowait` clause or not. - InsertPointOrErrorTy createTarget( - const LocationDescription &Loc, bool IsOffloadEntry, - OpenMPIRBuilder::InsertPointTy AllocaIP, - OpenMPIRBuilder::InsertPointTy CodeGenIP, - TargetRegionEntryInfo &EntryInfo, ArrayRef NumTeams, - ArrayRef NumThreads, SmallVectorImpl &Inputs, - GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB, - TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, - SmallVector Dependencies = {}, bool HasNowait = false); + /// dependency information as passed in the depend clause + /// \param HasNowait Whether the target construct has a `nowait` clause or + /// not. + InsertPointOrErrorTy + createTarget(const LocationDescription &Loc, bool IsOffloadEntry, + OpenMPIRBuilder::InsertPointTy AllocaIP, + OpenMPIRBuilder::InsertPointTy CodeGenIP, + TargetRegionEntryInfo &EntryInfo, ArrayRef NumTeams, + ArrayRef NumThreads, SmallVectorImpl &Inputs, + Value *IfCond, GenMapInfoCallbackTy GenMapInfoCB, + TargetBodyGenCallbackTy BodyGenCB, + TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, + SmallVector Dependencies = {}, + bool HasNowait = false); /// Returns __kmpc_for_static_init_* runtime function for the specified /// size \a IVSize and sign \a IVSigned. Will create a distribute call diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index db77c6a5869764..0e190f4c64a8b3 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -7310,6 +7310,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn, Constant *OutlinedFnID, ArrayRef NumTeams, ArrayRef NumThreads, SmallVectorImpl &Args, + Value *IfCond, OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB, SmallVector Dependencies = {}, bool HasNoWait = false) { @@ -7354,9 +7355,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, return Error::success(); }; - // If we don't have an ID for the target region, it means an offload entry - // wasn't created. In this case we just run the host fallback directly. - if (!OutlinedFnID) { + auto &&EmitTargetCallElse = + [&](OpenMPIRBuilder::InsertPointTy AllocaIP, + OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error { // Assume no error was returned because EmitTargetCallFallbackCB doesn't // produce any. OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() { @@ -7372,65 +7373,87 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, }()); Builder.restoreIP(AfterIP); - return; - } + return Error::success(); + }; + + auto &&EmitTargetCallThen = + [&](OpenMPIRBuilder::InsertPointTy AllocaIP, + OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error { + OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); + + OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP()); + OpenMPIRBuilder::TargetDataRTArgs RTArgs; + OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info, + RTArgs, MapInfo, + /*IsNonContiguous=*/true, + /*ForEndCall=*/false); + + SmallVector NumTeamsC; + SmallVector NumThreadsC; + for (auto V : NumTeams) + NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V)); + for (auto V : NumThreads) + NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V)); + + unsigned NumTargetItems = Info.NumberOfPtrs; + // TODO: Use correct device ID + Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF); + uint32_t SrcLocStrSize; + Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize); + Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize, + llvm::omp::IdentFlag(0), 0); + // TODO: Use correct NumIterations + Value *NumIterations = Builder.getInt64(0); + // TODO: Use correct DynCGGroupMem + Value *DynCGGroupMem = Builder.getInt32(0); + + KArgs = OpenMPIRBuilder::TargetKernelArgs( + NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC, + DynCGGroupMem, HasNoWait); + + // Assume no error was returned because TaskBodyCB and + // EmitTargetCallFallbackCB don't produce any. + OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() { + // The presence of certain clauses on the target directive require the + // explicit generation of the target task. + if (RequiresOuterTargetTask) + return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP, + Dependencies, HasNoWait); + + return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID, + EmitTargetCallFallbackCB, KArgs, + DeviceID, RTLoc, AllocaIP); + }()); - OpenMPIRBuilder::TargetDataInfo Info( - /*RequiresDevicePointerInfo=*/false, - /*SeparateBeginEndCalls=*/true); + Builder.restoreIP(AfterIP); + return Error::success(); + }; - OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP()); - OpenMPIRBuilder::TargetDataRTArgs RTArgs; - OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info, - RTArgs, MapInfo, - /*IsNonContiguous=*/true, - /*ForEndCall=*/false); + // If we don't have an ID for the target region, it means an offload entry + // wasn't created. In this case we just run the host fallback directly. + if (!OutlinedFnID) { + cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP())); + return; + } - SmallVector NumTeamsC; - SmallVector NumThreadsC; - for (auto V : NumTeams) - NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V)); - for (auto V : NumThreads) - NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V)); + // If there's no IF clause, only generate the kernel launch code path. + if (!IfCond) { + cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP())); + return; + } - unsigned NumTargetItems = Info.NumberOfPtrs; - // TODO: Use correct device ID - Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF); - uint32_t SrcLocStrSize; - Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize); - Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize, - llvm::omp::IdentFlag(0), 0); - // TODO: Use correct NumIterations - Value *NumIterations = Builder.getInt64(0); - // TODO: Use correct DynCGGroupMem - Value *DynCGGroupMem = Builder.getInt32(0); - - KArgs = OpenMPIRBuilder::TargetKernelArgs( - NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC, - DynCGGroupMem, HasNoWait); - - // Assume no error was returned because TaskBodyCB and - // EmitTargetCallFallbackCB don't produce any. - OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() { - // The presence of certain clauses on the target directive require the - // explicit generation of the target task. - if (RequiresOuterTargetTask) - return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP, - Dependencies, HasNoWait); - - return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID, - EmitTargetCallFallbackCB, KArgs, - DeviceID, RTLoc, AllocaIP); - }()); - - Builder.restoreIP(AfterIP); + cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen, + EmitTargetCallElse, AllocaIP)); } OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, ArrayRef NumTeams, ArrayRef NumThreads, - SmallVectorImpl &Args, GenMapInfoCallbackTy GenMapInfoCB, + SmallVectorImpl &Args, Value *IfCond, + GenMapInfoCallbackTy GenMapInfoCB, OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc, OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, SmallVector Dependencies, bool HasNowait) { @@ -7455,7 +7478,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( // that represents the target region. Do that now. if (!Config.isTargetDevice()) emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams, - NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait); + NumThreads, Args, IfCond, GenMapInfoCB, Dependencies, + HasNowait); return Builder.saveIP(); } diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index cdca725b147436..94dce5243d7004 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -6232,7 +6232,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), - Builder.saveIP(), EntryInfo, -1, 0, Inputs, + Builder.saveIP(), EntryInfo, /*NumTeams=*/-1, + /*NumThreads=*/0, Inputs, /*IfCond=*/nullptr, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); Builder.restoreIP(AfterIP); OMPBuilder.finalize(); @@ -6343,8 +6344,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, /*NumTeams=*/-1, - /*NumThreads=*/0, CapturedArgs, GenMapInfoCB, + EntryInfo, /*NumTeams=*/-1, /*NumThreads=*/0, + CapturedArgs, /*IfCond=*/nullptr, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); Builder.restoreIP(AfterIP); @@ -6500,8 +6501,8 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, /*NumTeams=*/-1, - /*NumThreads=*/0, CapturedArgs, GenMapInfoCB, + EntryInfo, /*NumTeams=*/-1, /*NumThreads=*/0, + CapturedArgs, /*IfCond=*/nullptr, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); Builder.restoreIP(AfterIP); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index a364098e0bd8a6..0c637bd32ab3f3 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -285,7 +285,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { checkBare(op, result); checkDevice(op, result); checkHasDeviceAddr(op, result); - checkIf(op, result); checkInReduction(op, result); checkIsDevicePtr(op, result); // Privatization clauses are supported, except on some situations, so we @@ -4112,11 +4111,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, findAllocaInsertPoint(builder, moduleTranslation); llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + llvm::Value *ifCond = nullptr; + if (Value targetIfCond = targetOp.getIfExpr()) + ifCond = moduleTranslation.lookupValue(targetIfCond); + llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = moduleTranslation.getOpenMPBuilder()->createTarget( ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo, - defaultValTeams, defaultValThreads, kernelInput, genMapInfoCB, bodyCB, - argAccessorCB, dds, targetOp.getNowait()); + defaultValTeams, defaultValThreads, kernelInput, ifCond, genMapInfoCB, + bodyCB, argAccessorCB, dds, targetOp.getNowait()); if (failed(handleError(afterIP, opInst))) return failure(); diff --git a/mlir/test/Target/LLVMIR/omptarget-if.mlir b/mlir/test/Target/LLVMIR/omptarget-if.mlir new file mode 100644 index 00000000000000..706ad4411438ba --- /dev/null +++ b/mlir/test/Target/LLVMIR/omptarget-if.mlir @@ -0,0 +1,68 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} { + llvm.func @target_if_variable(%x : i1) { + omp.target if(%x) { + omp.terminator + } + llvm.return + } + + // CHECK-LABEL: define void @target_if_variable( + // CHECK-SAME: i1 %[[IF_COND:.*]]) + // CHECK: br i1 %[[IF_COND]], label %[[THEN_LABEL:.*]], label %[[ELSE_LABEL:.*]] + + // CHECK: [[THEN_LABEL]]: + // CHECK-NOT: {{^.*}}: + // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel + // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0 + // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]] + + // CHECK: [[OFFLOAD_FAIL_LABEL]]: + // CHECK-NEXT: call void @[[FALLBACK_FN:__omp_offloading_.*_.*_target_if_variable_l.*]]() + // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]] + + // CHECK: [[OFFLOAD_CONT_LABEL]]: + // CHECK-NEXT: br label %[[END_LABEL:.*]] + + // CHECK: [[ELSE_LABEL]]: + // CHECK-NEXT: call void @[[FALLBACK_FN]]() + // CHECK-NEXT: br label %[[END_LABEL]] + + llvm.func @target_if_true() { + %0 = llvm.mlir.constant(true) : i1 + omp.target if(%0) { + omp.terminator + } + llvm.return + } + + // CHECK-LABEL: define void @target_if_true() + // CHECK-NOT: {{^.*}}: + // CHECK: br label %[[ENTRY:.*]] + + // CHECK: [[ENTRY]]: + // CHECK-NOT: {{^.*}}: + // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel + // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0 + // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]] + + // CHECK: [[OFFLOAD_FAIL_LABEL]]: + // CHECK-NEXT: call void @[[FALLBACK_FN:.*]]() + // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]] + + llvm.func @target_if_false() { + %0 = llvm.mlir.constant(false) : i1 + omp.target if(%0) { + omp.terminator + } + llvm.return + } + + // CHECK-LABEL: define void @target_if_false() + // CHECK-NEXT: br label %[[ENTRY:.*]] + + // CHECK: [[ENTRY]]: + // CHECK-NEXT: call void @__omp_offloading_{{.*}}_{{.*}}_target_if_false_l{{.*}}() +} + diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index 83a0990d631620..4e0925c833c3b7 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -266,17 +266,6 @@ llvm.func @target_has_device_addr(%x : !llvm.ptr) { // ----- -llvm.func @target_if(%x : i1) { - // expected-error@below {{not yet implemented: Unhandled clause if in omp.target operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.target}} - omp.target if(%x) { - omp.terminator - } - llvm.return -} - -// ----- - omp.declare_reduction @add_f32 : f32 init { ^bb0(%arg: f32):