From fce2f99969d943b3639338f7bfaf85b37e2633d5 Mon Sep 17 00:00:00 2001 From: Bryce Adelstein Lelbach aka wash Date: Wed, 11 Dec 2024 11:46:50 -0800 Subject: [PATCH] [cuda.cooperative] Add striped_to_blocked exchanges to cuda.cooperative. --- .../experimental/block/__init__.py | 1 + .../experimental/block/_block_exchange.py | 38 +++++++++++++++++++ 2 files changed, 39 insertions(+) create mode 100644 python/cuda_cooperative/cuda/cooperative/experimental/block/_block_exchange.py diff --git a/python/cuda_cooperative/cuda/cooperative/experimental/block/__init__.py b/python/cuda_cooperative/cuda/cooperative/experimental/block/__init__.py index f51c3dccfb6..38a1e3414ba 100644 --- a/python/cuda_cooperative/cuda/cooperative/experimental/block/__init__.py +++ b/python/cuda_cooperative/cuda/cooperative/experimental/block/__init__.py @@ -7,3 +7,4 @@ from cuda.cooperative.experimental.block._block_scan import exclusive_sum from cuda.cooperative.experimental.block._block_radix_sort import radix_sort_keys, radix_sort_keys_descending from cuda.cooperative.experimental.block._block_load_store import load, store +from cuda.cooperative.experimental.block._block_exchange import striped_to_blocked diff --git a/python/cuda_cooperative/cuda/cooperative/experimental/block/_block_exchange.py b/python/cuda_cooperative/cuda/cooperative/experimental/block/_block_exchange.py new file mode 100644 index 00000000000..08571c60e04 --- /dev/null +++ b/python/cuda_cooperative/cuda/cooperative/experimental/block/_block_exchange.py @@ -0,0 +1,38 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +from cuda.cooperative.experimental._types import * +from cuda.cooperative.experimental._common import make_binary_tempfile, normalize_dim_param + + +def striped_to_blocked(dtype, threads_in_block, items_per_thread, warp_time_slicing=False): + template = Algorithm('BlockExchange', + 'StripedToBlocked', + 'block_exchange', + ['cub/block/block_exchange.cuh'], + [TemplateParameter('T'), + TemplateParameter('BLOCK_DIM_X'), + TemplateParameter('ITEMS_PER_THREAD'), + TemplateParameter('WARP_TIME_SLICING'), + TemplateParameter('BLOCK_DIM_Y'), + TemplateParameter('BLOCK_DIM_Z')], + [[Pointer(numba.uint8), + DependentArray(Dependency( + 'T'), Dependency('ITEMS_PER_THREAD'))], + [Pointer(numba.uint8), + DependentArray(Dependency( + 'T'), Dependency('ITEMS_PER_THREAD')), + DependentArray(Dependency( + 'T'), Dependency('ITEMS_PER_THREAD'))]]) + dim = normalize_dim_param(threads_in_block) + specialization = template.specialize({'T': dtype, + 'BLOCK_DIM_X': dim[0], + 'ITEMS_PER_THREAD': items_per_thread, + 'WARP_TIME_SLICING': int(warp_time_slicing), + 'BLOCK_DIM_Y': dim[1], + 'BLOCK_DIM_Z': dim[2]}) + return Invocable(temp_files=[make_binary_tempfile(ltoir, '.ltoir') for ltoir in specialization.get_lto_ir()], + temp_storage_bytes=specialization.get_temp_storage_bytes(), + algorithm=specialization)