Skip to content

Commit

Permalink
Refactor scheduler code
Browse files Browse the repository at this point in the history
Create helper function for getting UR details out of CG.
  • Loading branch information
EwanC committed Oct 23, 2024
1 parent be6d4a9 commit c838513
Showing 1 changed file with 47 additions and 54 deletions.
101 changes: 47 additions & 54 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2477,98 +2477,91 @@ static ur_result_t SetKernelParamsAndLaunch(
return Error;
}

ur_result_t enqueueImpCommandBufferKernel(
context Ctx, DeviceImplPtr DeviceImpl,
ur_exp_command_buffer_handle_t CommandBuffer,
const CGExecKernel &CommandGroup,
std::vector<ur_exp_command_buffer_sync_point_t> &SyncPoints,
ur_exp_command_buffer_sync_point_t *OutSyncPoint,
ur_exp_command_buffer_command_handle_t *OutCommand,
const std::function<void *(Requirement *Req)> &getMemAllocationFunc) {
auto ContextImpl = sycl::detail::getSyclObjImpl(Ctx);
const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter();

const std::vector<std::weak_ptr<sycl::detail::CGExecKernel>>
&AlternativeKernels = CommandGroup.MAlternativeKernels;
namespace {
std::tuple<ur_kernel_handle_t, std::shared_ptr<device_image_impl>,
const KernelArgMask *>
getCGKernelInfo(const CGExecKernel &CommandGroup, ContextImplPtr ContextImpl,
DeviceImplPtr DeviceImpl,
std::vector<ur_kernel_handle_t> &UrKernelsToRelease,
std::vector<ur_program_handle_t> &UrProgramsToRelease) {

// UR kernel and program for 'CommandGroup'
ur_kernel_handle_t UrKernel = nullptr;
ur_program_handle_t UrProgram = nullptr;

// Impl objects created when 'CommandGroup' is from a kernel bundle
std::shared_ptr<kernel_impl> SyclKernelImpl = nullptr;
std::shared_ptr<device_image_impl> DeviceImageImpl = nullptr;

// List of ur objects to be released after UR call
std::vector<ur_kernel_handle_t> UrKernelsToRelease;
std::vector<ur_program_handle_t> UrProgramsToRelease;

auto Kernel = CommandGroup.MSyclKernel;
auto KernelBundleImplPtr = CommandGroup.MKernelBundle;
const KernelArgMask *EliminatedArgMask = nullptr;

// Use kernel_bundle if available unless it is interop.
// Interop bundles can't be used in the first branch, because the kernels
// in interop kernel bundles (if any) do not have kernel_id
// and can therefore not be looked up, but since they are self-contained
// they can simply be launched directly.
if (KernelBundleImplPtr && !KernelBundleImplPtr->isInterop()) {
if (auto KernelBundleImplPtr = CommandGroup.MKernelBundle;
KernelBundleImplPtr && !KernelBundleImplPtr->isInterop()) {
auto KernelName = CommandGroup.MKernelName;
kernel_id KernelID =
detail::ProgramManager::getInstance().getSYCLKernelID(KernelName);

kernel SyclKernel =
KernelBundleImplPtr->get_kernel(KernelID, KernelBundleImplPtr);
SyclKernelImpl = detail::getSyclObjImpl(SyclKernel);

auto SyclKernelImpl = detail::getSyclObjImpl(SyclKernel);
UrKernel = SyclKernelImpl->getHandleRef();
DeviceImageImpl = SyclKernelImpl->getDeviceImage();
UrProgram = DeviceImageImpl->get_ur_program_ref();
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();
} else if (Kernel != nullptr) {
} else if (auto Kernel = CommandGroup.MSyclKernel; Kernel != nullptr) {
UrKernel = Kernel->getHandleRef();
UrProgram = Kernel->getProgramRef();
EliminatedArgMask = Kernel->getKernelArgMask();
} else {
ur_program_handle_t UrProgram = nullptr;
std::tie(UrKernel, std::ignore, EliminatedArgMask, UrProgram) =
sycl::detail::ProgramManager::getInstance().getOrCreateKernel(
ContextImpl, DeviceImpl, CommandGroup.MKernelName);
UrKernelsToRelease.push_back(UrKernel);
UrProgramsToRelease.push_back(UrProgram);
}
return std::make_tuple(UrKernel, DeviceImageImpl, EliminatedArgMask);
}
} // anonymous namespace

ur_result_t enqueueImpCommandBufferKernel(
context Ctx, DeviceImplPtr DeviceImpl,
ur_exp_command_buffer_handle_t CommandBuffer,
const CGExecKernel &CommandGroup,
std::vector<ur_exp_command_buffer_sync_point_t> &SyncPoints,
ur_exp_command_buffer_sync_point_t *OutSyncPoint,
ur_exp_command_buffer_command_handle_t *OutCommand,
const std::function<void *(Requirement *Req)> &getMemAllocationFunc) {
// List of ur objects to be released after UR call. We don't do anything
// with the ur_program_handle_t objects, but need to update their reference
// count.
std::vector<ur_kernel_handle_t> UrKernelsToRelease;
std::vector<ur_program_handle_t> UrProgramsToRelease;

ur_kernel_handle_t UrKernel = nullptr;
std::shared_ptr<device_image_impl> DeviceImageImpl = nullptr;
const KernelArgMask *EliminatedArgMask = nullptr;

auto ContextImpl = sycl::detail::getSyclObjImpl(Ctx);
std::tie(UrKernel, DeviceImageImpl, EliminatedArgMask) =
getCGKernelInfo(CommandGroup, ContextImpl, DeviceImpl, UrKernelsToRelease,
UrProgramsToRelease);

// Build up the list of UR kernel handles that the UR command could be
// updated to use.
std::vector<ur_kernel_handle_t> AltUrKernels;
const std::vector<std::weak_ptr<sycl::detail::CGExecKernel>>
&AlternativeKernels = CommandGroup.MAlternativeKernels;
for (const auto &AltCGKernelWP : AlternativeKernels) {
auto AltCGKernel = AltCGKernelWP.lock();
assert(AltCGKernel != nullptr);

ur_kernel_handle_t AltUrKernel = nullptr;
if (auto KernelBundleImplPtr = AltCGKernel->MKernelBundle;
KernelBundleImplPtr && !KernelBundleImplPtr->isInterop()) {
auto KernelName = AltCGKernel->MKernelName;
kernel_id KernelID =
detail::ProgramManager::getInstance().getSYCLKernelID(KernelName);
kernel SyclKernel =
KernelBundleImplPtr->get_kernel(KernelID, KernelBundleImplPtr);
AltUrKernel = detail::getSyclObjImpl(SyclKernel)->getHandleRef();
} else if (AltCGKernel->MSyclKernel != nullptr) {
AltUrKernel = Kernel->getHandleRef();
} else {
ur_program_handle_t UrProgram = nullptr;
std::tie(AltUrKernel, std::ignore, std::ignore, UrProgram) =
sycl::detail::ProgramManager::getInstance().getOrCreateKernel(
ContextImpl, DeviceImpl, AltCGKernel->MKernelName);
UrKernelsToRelease.push_back(AltUrKernel);
UrProgramsToRelease.push_back(UrProgram);
}

if (AltUrKernel != UrKernel) {
// Don't include command-group 'CommandGroup' in the list to pass to UR,
// as this will be used for the primary ur kernel parameter.
AltUrKernels.push_back(AltUrKernel);
}
std::tie(AltUrKernel, std::ignore, std::ignore) =
getCGKernelInfo(*AltCGKernel.get(), ContextImpl, DeviceImpl,
UrKernelsToRelease, UrProgramsToRelease);
AltUrKernels.push_back(AltUrKernel);
}

const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter();
auto SetFunc = [&Adapter, &UrKernel, &DeviceImageImpl, &Ctx,
&getMemAllocationFunc](sycl::detail::ArgDesc &Arg,
size_t NextTrueIndex) {
Expand Down

0 comments on commit c838513

Please sign in to comment.