From 3aea11b2d7857784d442d97925561ea54a5a8095 Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Fri, 10 Jan 2025 15:40:05 +0000 Subject: [PATCH] [OMPIRBuilder][MLIR] Add support for target 'if' clause This patch implements support for handling the 'if' clause of OpenMP 'target' constructs in the OMPIRBuilder and updates MLIR to LLVM IR translation of the `omp.target` MLIR operation to make use of this new feature. --- .../llvm/Frontend/OpenMP/OMPIRBuilder.h | 26 ++-- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 130 +++++++++++------- .../Frontend/OpenMPIRBuilderTest.cpp | 11 +- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 9 +- mlir/test/Target/LLVMIR/omptarget-if.mlir | 68 +++++++++ mlir/test/Target/LLVMIR/openmp-todo.mlir | 11 -- 6 files changed, 172 insertions(+), 83 deletions(-) create mode 100644 mlir/test/Target/LLVMIR/omptarget-if.mlir diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 4ce47b1c05d9b08..b1a23996c7bdd20 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -2965,21 +2965,25 @@ class OpenMPIRBuilder { /// \param NumThreads Number of teams specified in the thread_limit clause. /// \param Inputs The input values to the region that will be passed. /// as arguments to the outlined function. + /// \param IfCond value of the `if` clause. /// \param BodyGenCB Callback that will generate the region code. /// \param ArgAccessorFuncCB Callback that will generate accessors /// instructions for passed in target arguments where neccessary /// \param Dependencies A vector of DependData objects that carry - // dependency information as passed in the depend clause - // \param HasNowait Whether the target construct has a `nowait` clause or not. - InsertPointOrErrorTy createTarget( - const LocationDescription &Loc, bool IsOffloadEntry, - OpenMPIRBuilder::InsertPointTy AllocaIP, - OpenMPIRBuilder::InsertPointTy CodeGenIP, - TargetRegionEntryInfo &EntryInfo, ArrayRef NumTeams, - ArrayRef NumThreads, SmallVectorImpl &Inputs, - GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB, - TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, - SmallVector Dependencies = {}, bool HasNowait = false); + /// dependency information as passed in the depend clause + /// \param HasNowait Whether the target construct has a `nowait` clause or + /// not. + InsertPointOrErrorTy + createTarget(const LocationDescription &Loc, bool IsOffloadEntry, + OpenMPIRBuilder::InsertPointTy AllocaIP, + OpenMPIRBuilder::InsertPointTy CodeGenIP, + TargetRegionEntryInfo &EntryInfo, ArrayRef NumTeams, + ArrayRef NumThreads, SmallVectorImpl &Inputs, + Value *IfCond, GenMapInfoCallbackTy GenMapInfoCB, + TargetBodyGenCallbackTy BodyGenCB, + TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, + SmallVector Dependencies = {}, + bool HasNowait = false); /// Returns __kmpc_for_static_init_* runtime function for the specified /// size \a IVSize and sign \a IVSigned. Will create a distribute call diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index db77c6a58697648..0e190f4c64a8b39 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -7310,6 +7310,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn, Constant *OutlinedFnID, ArrayRef NumTeams, ArrayRef NumThreads, SmallVectorImpl &Args, + Value *IfCond, OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB, SmallVector Dependencies = {}, bool HasNoWait = false) { @@ -7354,9 +7355,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, return Error::success(); }; - // If we don't have an ID for the target region, it means an offload entry - // wasn't created. In this case we just run the host fallback directly. - if (!OutlinedFnID) { + auto &&EmitTargetCallElse = + [&](OpenMPIRBuilder::InsertPointTy AllocaIP, + OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error { // Assume no error was returned because EmitTargetCallFallbackCB doesn't // produce any. OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() { @@ -7372,65 +7373,87 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, }()); Builder.restoreIP(AfterIP); - return; - } + return Error::success(); + }; + + auto &&EmitTargetCallThen = + [&](OpenMPIRBuilder::InsertPointTy AllocaIP, + OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error { + OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); + + OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP()); + OpenMPIRBuilder::TargetDataRTArgs RTArgs; + OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info, + RTArgs, MapInfo, + /*IsNonContiguous=*/true, + /*ForEndCall=*/false); + + SmallVector NumTeamsC; + SmallVector NumThreadsC; + for (auto V : NumTeams) + NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V)); + for (auto V : NumThreads) + NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V)); + + unsigned NumTargetItems = Info.NumberOfPtrs; + // TODO: Use correct device ID + Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF); + uint32_t SrcLocStrSize; + Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize); + Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize, + llvm::omp::IdentFlag(0), 0); + // TODO: Use correct NumIterations + Value *NumIterations = Builder.getInt64(0); + // TODO: Use correct DynCGGroupMem + Value *DynCGGroupMem = Builder.getInt32(0); + + KArgs = OpenMPIRBuilder::TargetKernelArgs( + NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC, + DynCGGroupMem, HasNoWait); + + // Assume no error was returned because TaskBodyCB and + // EmitTargetCallFallbackCB don't produce any. + OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() { + // The presence of certain clauses on the target directive require the + // explicit generation of the target task. + if (RequiresOuterTargetTask) + return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP, + Dependencies, HasNoWait); + + return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID, + EmitTargetCallFallbackCB, KArgs, + DeviceID, RTLoc, AllocaIP); + }()); - OpenMPIRBuilder::TargetDataInfo Info( - /*RequiresDevicePointerInfo=*/false, - /*SeparateBeginEndCalls=*/true); + Builder.restoreIP(AfterIP); + return Error::success(); + }; - OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP()); - OpenMPIRBuilder::TargetDataRTArgs RTArgs; - OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info, - RTArgs, MapInfo, - /*IsNonContiguous=*/true, - /*ForEndCall=*/false); + // If we don't have an ID for the target region, it means an offload entry + // wasn't created. In this case we just run the host fallback directly. + if (!OutlinedFnID) { + cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP())); + return; + } - SmallVector NumTeamsC; - SmallVector NumThreadsC; - for (auto V : NumTeams) - NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V)); - for (auto V : NumThreads) - NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V)); + // If there's no IF clause, only generate the kernel launch code path. + if (!IfCond) { + cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP())); + return; + } - unsigned NumTargetItems = Info.NumberOfPtrs; - // TODO: Use correct device ID - Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF); - uint32_t SrcLocStrSize; - Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize); - Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize, - llvm::omp::IdentFlag(0), 0); - // TODO: Use correct NumIterations - Value *NumIterations = Builder.getInt64(0); - // TODO: Use correct DynCGGroupMem - Value *DynCGGroupMem = Builder.getInt32(0); - - KArgs = OpenMPIRBuilder::TargetKernelArgs( - NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC, - DynCGGroupMem, HasNoWait); - - // Assume no error was returned because TaskBodyCB and - // EmitTargetCallFallbackCB don't produce any. - OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() { - // The presence of certain clauses on the target directive require the - // explicit generation of the target task. - if (RequiresOuterTargetTask) - return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP, - Dependencies, HasNoWait); - - return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID, - EmitTargetCallFallbackCB, KArgs, - DeviceID, RTLoc, AllocaIP); - }()); - - Builder.restoreIP(AfterIP); + cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen, + EmitTargetCallElse, AllocaIP)); } OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, ArrayRef NumTeams, ArrayRef NumThreads, - SmallVectorImpl &Args, GenMapInfoCallbackTy GenMapInfoCB, + SmallVectorImpl &Args, Value *IfCond, + GenMapInfoCallbackTy GenMapInfoCB, OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc, OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, SmallVector Dependencies, bool HasNowait) { @@ -7455,7 +7478,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( // that represents the target region. Do that now. if (!Config.isTargetDevice()) emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams, - NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait); + NumThreads, Args, IfCond, GenMapInfoCB, Dependencies, + HasNowait); return Builder.saveIP(); } diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index cdca725b147436c..94dce5243d70047 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -6232,7 +6232,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), - Builder.saveIP(), EntryInfo, -1, 0, Inputs, + Builder.saveIP(), EntryInfo, /*NumTeams=*/-1, + /*NumThreads=*/0, Inputs, /*IfCond=*/nullptr, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); Builder.restoreIP(AfterIP); OMPBuilder.finalize(); @@ -6343,8 +6344,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, /*NumTeams=*/-1, - /*NumThreads=*/0, CapturedArgs, GenMapInfoCB, + EntryInfo, /*NumTeams=*/-1, /*NumThreads=*/0, + CapturedArgs, /*IfCond=*/nullptr, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); Builder.restoreIP(AfterIP); @@ -6500,8 +6501,8 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, /*NumTeams=*/-1, - /*NumThreads=*/0, CapturedArgs, GenMapInfoCB, + EntryInfo, /*NumTeams=*/-1, /*NumThreads=*/0, + CapturedArgs, /*IfCond=*/nullptr, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); Builder.restoreIP(AfterIP); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index a364098e0bd8a62..0c637bd32ab3f32 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -285,7 +285,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { checkBare(op, result); checkDevice(op, result); checkHasDeviceAddr(op, result); - checkIf(op, result); checkInReduction(op, result); checkIsDevicePtr(op, result); // Privatization clauses are supported, except on some situations, so we @@ -4112,11 +4111,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, findAllocaInsertPoint(builder, moduleTranslation); llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + llvm::Value *ifCond = nullptr; + if (Value targetIfCond = targetOp.getIfExpr()) + ifCond = moduleTranslation.lookupValue(targetIfCond); + llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = moduleTranslation.getOpenMPBuilder()->createTarget( ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo, - defaultValTeams, defaultValThreads, kernelInput, genMapInfoCB, bodyCB, - argAccessorCB, dds, targetOp.getNowait()); + defaultValTeams, defaultValThreads, kernelInput, ifCond, genMapInfoCB, + bodyCB, argAccessorCB, dds, targetOp.getNowait()); if (failed(handleError(afterIP, opInst))) return failure(); diff --git a/mlir/test/Target/LLVMIR/omptarget-if.mlir b/mlir/test/Target/LLVMIR/omptarget-if.mlir new file mode 100644 index 000000000000000..706ad4411438bad --- /dev/null +++ b/mlir/test/Target/LLVMIR/omptarget-if.mlir @@ -0,0 +1,68 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} { + llvm.func @target_if_variable(%x : i1) { + omp.target if(%x) { + omp.terminator + } + llvm.return + } + + // CHECK-LABEL: define void @target_if_variable( + // CHECK-SAME: i1 %[[IF_COND:.*]]) + // CHECK: br i1 %[[IF_COND]], label %[[THEN_LABEL:.*]], label %[[ELSE_LABEL:.*]] + + // CHECK: [[THEN_LABEL]]: + // CHECK-NOT: {{^.*}}: + // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel + // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0 + // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]] + + // CHECK: [[OFFLOAD_FAIL_LABEL]]: + // CHECK-NEXT: call void @[[FALLBACK_FN:__omp_offloading_.*_.*_target_if_variable_l.*]]() + // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]] + + // CHECK: [[OFFLOAD_CONT_LABEL]]: + // CHECK-NEXT: br label %[[END_LABEL:.*]] + + // CHECK: [[ELSE_LABEL]]: + // CHECK-NEXT: call void @[[FALLBACK_FN]]() + // CHECK-NEXT: br label %[[END_LABEL]] + + llvm.func @target_if_true() { + %0 = llvm.mlir.constant(true) : i1 + omp.target if(%0) { + omp.terminator + } + llvm.return + } + + // CHECK-LABEL: define void @target_if_true() + // CHECK-NOT: {{^.*}}: + // CHECK: br label %[[ENTRY:.*]] + + // CHECK: [[ENTRY]]: + // CHECK-NOT: {{^.*}}: + // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel + // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0 + // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]] + + // CHECK: [[OFFLOAD_FAIL_LABEL]]: + // CHECK-NEXT: call void @[[FALLBACK_FN:.*]]() + // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]] + + llvm.func @target_if_false() { + %0 = llvm.mlir.constant(false) : i1 + omp.target if(%0) { + omp.terminator + } + llvm.return + } + + // CHECK-LABEL: define void @target_if_false() + // CHECK-NEXT: br label %[[ENTRY:.*]] + + // CHECK: [[ENTRY]]: + // CHECK-NEXT: call void @__omp_offloading_{{.*}}_{{.*}}_target_if_false_l{{.*}}() +} + diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index 83a0990d6316201..4e0925c833c3b7d 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -266,17 +266,6 @@ llvm.func @target_has_device_addr(%x : !llvm.ptr) { // ----- -llvm.func @target_if(%x : i1) { - // expected-error@below {{not yet implemented: Unhandled clause if in omp.target operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.target}} - omp.target if(%x) { - omp.terminator - } - llvm.return -} - -// ----- - omp.declare_reduction @add_f32 : f32 init { ^bb0(%arg: f32):