Skip to content

Commit

Permalink
[SYCL] Add support for work group memory free function kernel paramet…
Browse files Browse the repository at this point in the history
…er (#15861)

This PR concludes the implementation of the work group memory
[extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_work_group_memory.asciidoc).
It adds support for work group memory parameters when using free
function kernels.

---------

Co-authored-by: lorenc.bushi <[email protected]>
  • Loading branch information
lbushi25 and lorenc.bushi authored Nov 19, 2024
1 parent 8922ee7 commit 39483ab
Show file tree
Hide file tree
Showing 10 changed files with 210 additions and 33 deletions.
141 changes: 123 additions & 18 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1522,7 +1522,7 @@ class KernelObjVisitor {
void visitParam(ParmVarDecl *Param, QualType ParamTy,
HandlerTys &...Handlers) {
if (isSyclSpecialType(ParamTy, SemaSYCLRef))
KP_FOR_EACH(handleOtherType, Param, ParamTy);
KP_FOR_EACH(handleSyclSpecialType, Param, ParamTy);
else if (ParamTy->isStructureOrClassType()) {
if (KP_FOR_EACH(handleStructType, Param, ParamTy)) {
CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl();
Expand Down Expand Up @@ -2075,8 +2075,11 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
}

bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type) << ParamTy;
IsInvalid = true;
if (!SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) {
Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type)
<< ParamTy;
IsInvalid = true;
}
return isValid();
}

Expand Down Expand Up @@ -2228,8 +2231,8 @@ class SyclKernelUnionChecker : public SyclKernelFieldHandler {
}

bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
// TODO
unsupportedFreeFunctionParamType();
if (!SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory))
unsupportedFreeFunctionParamType(); // TODO
return true;
}

Expand Down Expand Up @@ -3013,9 +3016,26 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
return handleSpecialType(FD, FieldTy);
}

bool handleSyclSpecialType(ParmVarDecl *, QualType) final {
// TODO
unsupportedFreeFunctionParamType();
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) {
const auto *RecordDecl = ParamTy->getAsCXXRecordDecl();
assert(RecordDecl && "The type must be a RecordDecl");
CXXMethodDecl *InitMethod = getMethodByName(RecordDecl, InitMethodName);
assert(InitMethod && "The type must have the __init method");
// Don't do -1 here because we count on this to be the first parameter
// added (if any).
size_t ParamIndex = Params.size();
for (const ParmVarDecl *Param : InitMethod->parameters()) {
QualType ParamTy = Param->getType();
addParam(Param, ParamTy.getCanonicalType());
// Propagate add_ir_attributes_kernel_parameter attribute.
if (const auto *AddIRAttr =
Param->getAttr<SYCLAddIRAttributesKernelParameterAttr>())
Params.back()->addAttr(AddIRAttr->clone(SemaSYCLRef.getASTContext()));
}
LastParamIndex = ParamIndex;
} else // TODO
unsupportedFreeFunctionParamType();
return true;
}

Expand Down Expand Up @@ -3291,9 +3311,7 @@ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler {
}

bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
// TODO
unsupportedFreeFunctionParamType();
return true;
return handleSpecialType(ParamTy);
}

bool handleSyclSpecialType(const CXXRecordDecl *, const CXXBaseSpecifier &BS,
Expand Down Expand Up @@ -4442,6 +4460,45 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
{});
}

MemberExpr *buildMemberExpr(Expr *Base, ValueDecl *Member) {
DeclAccessPair MemberDAP = DeclAccessPair::make(Member, AS_none);
MemberExpr *Result = SemaSYCLRef.SemaRef.BuildMemberExpr(
Base, /*IsArrow */ false, FreeFunctionSrcLoc, NestedNameSpecifierLoc(),
FreeFunctionSrcLoc, Member, MemberDAP,
/*HadMultipleCandidates*/ false,
DeclarationNameInfo(Member->getDeclName(), FreeFunctionSrcLoc),
Member->getType(), VK_LValue, OK_Ordinary);
return Result;
}

void createSpecialMethodCall(const CXXRecordDecl *RD, StringRef MethodName,
Expr *MemberBaseExpr,
SmallVectorImpl<Stmt *> &AddTo) {
CXXMethodDecl *Method = getMethodByName(RD, MethodName);
if (!Method)
return;
unsigned NumParams = Method->getNumParams();
llvm::SmallVector<Expr *, 4> ParamDREs(NumParams);
llvm::ArrayRef<ParmVarDecl *> KernelParameters =
DeclCreator.getParamVarDeclsForCurrentField();
for (size_t I = 0; I < NumParams; ++I) {
QualType ParamType = KernelParameters[I]->getOriginalType();
ParamDREs[I] = SemaSYCLRef.SemaRef.BuildDeclRefExpr(
KernelParameters[I], ParamType, VK_LValue, FreeFunctionSrcLoc);
}
MemberExpr *MethodME = buildMemberExpr(MemberBaseExpr, Method);
QualType ResultTy = Method->getReturnType();
ExprValueKind VK = Expr::getValueKindForType(ResultTy);
ResultTy = ResultTy.getNonLValueExprType(SemaSYCLRef.getASTContext());
llvm::SmallVector<Expr *, 4> ParamStmts;
const auto *Proto = cast<FunctionProtoType>(Method->getType());
SemaSYCLRef.SemaRef.GatherArgumentsForCall(FreeFunctionSrcLoc, Method,
Proto, 0, ParamDREs, ParamStmts);
AddTo.push_back(CXXMemberCallExpr::Create(
SemaSYCLRef.getASTContext(), MethodME, ParamStmts, ResultTy, VK,
FreeFunctionSrcLoc, FPOptionsOverride()));
}

public:
static constexpr const bool VisitInsideSimpleContainers = false;

Expand All @@ -4461,9 +4518,53 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
return true;
}

bool handleSyclSpecialType(ParmVarDecl *, QualType) final {
// TODO
unsupportedFreeFunctionParamType();
// Default inits the type, then calls the init-method in the body.
// A type may not have a public default constructor as per its spec so
// typically if this is the case the default constructor will be private and
// in such cases we must manually override the access specifier from private
// to public just for the duration of this default initialization.
// TODO: Revisit this approach once https://github.com/intel/llvm/issues/16061
// is closed.
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) {
const auto *RecordDecl = ParamTy->getAsCXXRecordDecl();
AccessSpecifier DefaultConstructorAccess;
auto DefaultConstructor =
std::find_if(RecordDecl->ctor_begin(), RecordDecl->ctor_end(),
[](auto it) { return it->isDefaultConstructor(); });
DefaultConstructorAccess = DefaultConstructor->getAccess();
DefaultConstructor->setAccess(AS_public);

QualType Ty = PD->getOriginalType();
ASTContext &Ctx = SemaSYCLRef.SemaRef.getASTContext();
VarDecl *WorkGroupMemoryClone = VarDecl::Create(
Ctx, DeclCreator.getKernelDecl(), FreeFunctionSrcLoc,
FreeFunctionSrcLoc, PD->getIdentifier(), PD->getType(),
Ctx.getTrivialTypeSourceInfo(Ty), SC_None);
InitializedEntity VarEntity =
InitializedEntity::InitializeVariable(WorkGroupMemoryClone);
InitializationKind InitKind =
InitializationKind::CreateDefault(FreeFunctionSrcLoc);
InitializationSequence InitSeq(SemaSYCLRef.SemaRef, VarEntity, InitKind,
std::nullopt);
ExprResult Init = InitSeq.Perform(SemaSYCLRef.SemaRef, VarEntity,
InitKind, std::nullopt);
WorkGroupMemoryClone->setInit(
SemaSYCLRef.SemaRef.MaybeCreateExprWithCleanups(Init.get()));
WorkGroupMemoryClone->setInitStyle(VarDecl::CallInit);
DefaultConstructor->setAccess(DefaultConstructorAccess);

Stmt *DS = new (SemaSYCLRef.getASTContext())
DeclStmt(DeclGroupRef(WorkGroupMemoryClone), FreeFunctionSrcLoc,
FreeFunctionSrcLoc);
BodyStmts.push_back(DS);
Expr *MemberBaseExpr = SemaSYCLRef.SemaRef.BuildDeclRefExpr(
WorkGroupMemoryClone, Ty, VK_PRValue, FreeFunctionSrcLoc);
createSpecialMethodCall(RecordDecl, InitMethodName, MemberBaseExpr,
BodyStmts);
ArgExprs.push_back(MemberBaseExpr);
} else // TODO
unsupportedFreeFunctionParamType();
return true;
}

Expand Down Expand Up @@ -4748,9 +4849,11 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
return true;
}

bool handleSyclSpecialType(ParmVarDecl *, QualType) final {
// TODO
unsupportedFreeFunctionParamType();
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory))
addParam(PD, ParamTy, SYCLIntegrationHeader::kind_work_group_memory);
else
unsupportedFreeFunctionParamType(); // TODO
return true;
}

Expand Down Expand Up @@ -6227,7 +6330,6 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
O << "#include <sycl/detail/defines_elementary.hpp>\n";
O << "#include <sycl/detail/kernel_desc.hpp>\n";
O << "#include <sycl/ext/oneapi/experimental/free_function_traits.hpp>\n";

O << "\n";

LangOptions LO;
Expand Down Expand Up @@ -6502,6 +6604,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {

O << "\n";
O << "// Forward declarations of kernel and its argument types:\n";
Policy.SuppressDefaultTemplateArgs = false;
FwdDeclEmitter.Visit(K.SyclKernel->getType());
O << "\n";

Expand Down Expand Up @@ -6579,6 +6682,8 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
}
O << ";\n";
O << "}\n";
Policy.SuppressDefaultTemplateArgs = true;
Policy.EnforceDefaultTemplateArgs = false;

// Generate is_kernel, is_single_task_kernel and nd_range_kernel functions.
O << "namespace sycl {\n";
Expand Down
40 changes: 39 additions & 1 deletion clang/test/CodeGenSYCL/free_function_int_header.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// RUN: FileCheck -input-file=%t.h %s
//
// This test checks integration header contents for free functions with scalar,
// pointer and non-decomposed struct parameters.
// pointer, non-decomposed struct parameters and work group memory parameters.

#include "mock_properties.hpp"
#include "sycl.hpp"
Expand Down Expand Up @@ -96,6 +96,12 @@ void ff_7(KArgWithPtrArray<ArrSize> KArg) {

template void ff_7(KArgWithPtrArray<TestArrSize> KArg);

__attribute__((sycl_device))
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 0)]]
void ff_8(sycl::work_group_memory<int>) {
}


// CHECK: const char* const kernel_names[] = {
// CHECK-NEXT: {{.*}}__sycl_kernel_ff_2Piii
// CHECK-NEXT: {{.*}}__sycl_kernel_ff_2Piiii
Expand All @@ -105,6 +111,7 @@ template void ff_7(KArgWithPtrArray<TestArrSize> KArg);
// CHECK-NEXT: {{.*}}__sycl_kernel_ff_410NoPointers8Pointers3Agg
// CHECK-NEXT: {{.*}}__sycl_kernel_ff_6I3Agg7DerivedEvT_T0_i
// CHECK-NEXT: {{.*}}__sycl_kernel_ff_7ILi3EEv16KArgWithPtrArrayIXT_EE
// CHECK-NEXT: {{.*}}__sycl_kernel_ff_8N4sycl3_V117work_group_memoryIiEE
// CHECK-NEXT: ""
// CHECK-NEXT: };

Expand Down Expand Up @@ -148,6 +155,9 @@ template void ff_7(KArgWithPtrArray<TestArrSize> KArg);
// CHECK: //--- _Z18__sycl_kernel_ff_7ILi3EEv16KArgWithPtrArrayIXT_EE
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 48, 0 },

// CHECK: //--- _Z18__sycl_kernel_ff_8N4sycl3_V117work_group_memoryIiEE
// CHECK-NEXT: { kernel_param_kind_t::kind_work_group_memory, 8, 0 },

// CHECK: { kernel_param_kind_t::kind_invalid, -987654321, -987654321 },
// CHECK-NEXT: };

Expand Down Expand Up @@ -294,6 +304,26 @@ template void ff_7(KArgWithPtrArray<TestArrSize> KArg);
// CHECK-NEXT: };
// CHECK-NEXT: }

// CHECK: Definition of _Z18__sycl_kernel_ff_8N4sycl3_V117work_group_memoryIiEE as a free function kernel

// CHECK: Forward declarations of kernel and its argument types:
// CHECK: template <typename DataT> class work_group_memory;

// CHECK: void ff_8(sycl::work_group_memory<int>);
// CHECK-NEXT: static constexpr auto __sycl_shim9() {
// CHECK-NEXT: return (void (*)(class sycl::work_group_memory<int>))ff_8;
// CHECK-NEXT: }
// CHECK-NEXT: namespace sycl {
// CHECK-NEXT: template <>
// CHECK-NEXT: struct ext::oneapi::experimental::is_kernel<__sycl_shim9()> {
// CHECK-NEXT: static constexpr bool value = true;
// CHECK-NEXT: };
// CHECK-NEXT: template <>
// CHECK-NEXT: struct ext::oneapi::experimental::is_single_task_kernel<__sycl_shim9()> {
// CHECK-NEXT: static constexpr bool value = true;
// CHECK-NEXT: };
// CHECK-NEXT: }

// CHECK: #include <sycl/kernel_bundle.hpp>

// CHECK: Definition of kernel_id of _Z18__sycl_kernel_ff_2Piii
Expand Down Expand Up @@ -359,3 +389,11 @@ template void ff_7(KArgWithPtrArray<TestArrSize> KArg);
// CHECK-NEXT: return sycl::detail::get_kernel_id_impl(std::string_view{"_Z18__sycl_kernel_ff_7ILi3EEv16KArgWithPtrArrayIXT_EE"});
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: Definition of kernel_id of _Z18__sycl_kernel_ff_8N4sycl3_V117work_group_memoryIiEE
// CHECK-NEXT: namespace sycl {
// CHECK-NEXT: template <>
// CHECK-NEXT: kernel_id ext::oneapi::experimental::get_kernel_id<__sycl_shim9()>() {
// CHECK-NEXT: return sycl::detail::get_kernel_id_impl(std::string_view{"_Z18__sycl_kernel_ff_8N4sycl3_V117work_group_memoryIiEE"});
// CHECK-NEXT: }
// CHECK-NEXT: }
17 changes: 16 additions & 1 deletion clang/test/CodeGenSYCL/free_function_kernel_params.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: %clang_cc1 -internal-isystem %S/Inputs -fsycl-is-device -triple spir64 \
// RUN: -emit-llvm %s -o - | FileCheck %s
// This test checks parameter IR generation for free functions with parameters
// of non-decomposed struct type.
// of non-decomposed struct type and work group memory type.

#include "sycl.hpp"

Expand Down Expand Up @@ -56,3 +56,18 @@ template void ff_6(KArgWithPtrArray<TestArrSize> KArg);
// CHECK: %struct.KArgWithPtrArray = type { [3 x ptr addrspace(4)], [3 x i32], [3 x i32] }
// CHECK: define dso_local spir_kernel void @{{.*}}__sycl_kernel{{.*}}(ptr noundef byval(%struct.NoPointers) align 4 %__arg_S1, ptr noundef byval(%struct.__generated_Pointers) align 8 %__arg_S2, ptr noundef byval(%struct.__generated_Agg) align 8 %__arg_S3)
// CHECK: define dso_local spir_kernel void @{{.*}}__sycl_kernel_ff_6{{.*}}(ptr noundef byval(%struct.__generated_KArgWithPtrArray) align 8 %__arg_KArg)

__attribute__((sycl_device))
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 0)]]
void ff_7(sycl::work_group_memory<int> mem) {
}

// CHECK: define dso_local spir_kernel void @{{.*}}__sycl_kernel_ff_7{{.*}}(ptr addrspace(3) noundef align 4 %__arg_Ptr)
// CHECK: %__arg_Ptr.addr = alloca ptr addrspace(3), align 8
// CHECK-NEXT: %mem = alloca %"class.sycl::_V1::work_group_memory", align 8
// CHECK: %__arg_Ptr.addr.ascast = addrspacecast ptr %__arg_Ptr.addr to ptr addrspace(4)
// CHECK-NEXT: %mem.ascast = addrspacecast ptr %mem to ptr addrspace(4)
// CHECK: store ptr addrspace(3) %__arg_Ptr, ptr addrspace(4) %__arg_Ptr.addr.ascast, align 8
// CHECK-NEXT: [[REGISTER:%[a-zA-Z0-9_]+]] = load ptr addrspace(3), ptr addrspace(4) %__arg_Ptr.addr.ascast, align 8
// CHECK-NEXT: call spir_func void @{{.*}}work_group_memory{{.*}}__init{{.*}}(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %mem.ascast, ptr addrspace(3) noundef [[REGISTER]])

22 changes: 21 additions & 1 deletion clang/test/SemaSYCL/free_function_kernel_params.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: %clang_cc1 -internal-isystem %S/Inputs -fsycl-is-device -ast-dump \
// RUN: %s -o - | FileCheck %s
// This test checks parameter rewriting for free functions with parameters
// of type scalar, pointer and non-decomposed struct.
// of type scalar, pointer, non-decomposed struct and work group memory.

#include "sycl.hpp"

Expand Down Expand Up @@ -171,3 +171,23 @@ template void ff_6(Agg S1, Derived1 S2, int);
// CHECK-NEXT: DeclRefExpr {{.*}} '__generated_Derived1' lvalue ParmVar {{.*}} '__arg_S2' '__generated_Derived1'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'int' <LValueToRValue>
// CHECK-NEXT: DeclRefExpr {{.*}} 'int' lvalue ParmVar {{.*}} '__arg_end' 'int'

__attribute__((sycl_device))
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 0)]]
void ff_7(sycl::work_group_memory<int> mem) {
}
// CHECK: FunctionDecl {{.*}}__sycl_kernel{{.*}}'void (__local int *)'
// CHECK-NEXT: ParmVarDecl {{.*}} used __arg_Ptr '__local int *'
// CHECK-NEXT: CompoundStmt
// CHECK-NEXT: DeclStmt
// CHECK-NEXT: VarDecl {{.*}} used mem 'sycl::work_group_memory<int>' callinit
// CHECK-NEXT: CXXConstructExpr {{.*}} 'sycl::work_group_memory<int>' 'void () noexcept'
// CHECK-NEXT: CXXMemberCallExpr {{.*}} 'void'
// CHECK-NEXT: MemberExpr {{.*}} 'void (__local int *)' lvalue .__init
// CHECK-NEXT: DeclRefExpr {{.*}} 'sycl::work_group_memory<int>' Var {{.*}} 'mem' 'sycl::work_group_memory<int>'
// CHECK-NEXT: ImplicitCastExpr {{.*}} '__local int *' <LValueToRValue>
// CHECK-NEXT: DeclRefExpr {{.*}} '__local int *' lvalue ParmVar {{.*}} '__arg_Ptr' '__local int *'
// CHECK-NEXT: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(sycl::work_group_memory<int>)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (sycl::work_group_memory<int>)' lvalue Function {{.*}} 'ff_7' 'void (sycl::work_group_memory<int>)'
// CHECK-NEXT: DeclRefExpr {{.*}} 'sycl::work_group_memory<int>' Var {{.*}} 'mem' 'sycl::work_group_memory<int>'
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,10 @@ This extension also depends on the following other SYCL extensions:

== Status

This is a proposed extension specification, intended to gather community
feedback.
Interfaces defined in this specification may not be implemented yet or may be
in a preliminary state.
The specification itself may also change in incompatible ways before it is
finalized.
This is an experimental extension specification, intended to provide early
access to features and gather community feedback. Interfaces defined in this
specification are implemented in {dpcpp}, but they are not finalized and may
change incompatibly in future versions of {dpcpp} without prior notice.
*Shipping software products should not rely on APIs defined in this
specification.*

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//===-------------------- work_group_memory.hpp ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
Expand Down Expand Up @@ -103,6 +102,9 @@ class __SYCL_SPECIAL_CLASS __SYCL_TYPE(work_group_memory) work_group_memory
}

private:
friend class sycl::handler; // needed in order for handler class to be aware
// of the private inheritance with
// work_group_memory_impl as base class
decoratedPtr ptr = nullptr;
};
} // namespace ext::oneapi::experimental
Expand Down
Loading

0 comments on commit 39483ab

Please sign in to comment.