diff --git a/python/cuda_cooperative/cuda/cooperative/experimental/block/__init__.py b/python/cuda_cooperative/cuda/cooperative/experimental/block/__init__.py index da73294b518..ea91a53448d 100644 --- a/python/cuda_cooperative/cuda/cooperative/experimental/block/__init__.py +++ b/python/cuda_cooperative/cuda/cooperative/experimental/block/__init__.py @@ -2,24 +2,26 @@ # # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from cuda.cooperative.experimental.block._block_merge_sort import merge_sort_keys +from cuda.cooperative.experimental.block._block_load_store import load +from cuda.cooperative.experimental.block._block_load_store import store +from cuda.cooperative.experimental.block._block_exchange import striped_to_blocked from cuda.cooperative.experimental.block._block_reduce import reduce from cuda.cooperative.experimental.block._block_reduce import sum from cuda.cooperative.experimental.block._block_scan import exclusive_sum +from cuda.cooperative.experimental.block._block_merge_sort import merge_sort_keys from cuda.cooperative.experimental.block._block_radix_sort import radix_sort_keys from cuda.cooperative.experimental.block._block_radix_sort import ( radix_sort_keys_descending, ) -from cuda.cooperative.experimental.block._block_load_store import load -from cuda.cooperative.experimental.block._block_load_store import store __all__ = [ - "merge_sort_keys", + "load", + "store", + "striped_to_blocked", "reduce", "sum", "exclusive_sum", + "merge_sort_keys", "radix_sort_keys", "radix_sort_keys_descending", - "load", - "store", -] +] \ No newline at end of file diff --git a/python/cuda_cooperative/cuda/cooperative/experimental/block/_block_exchange.py b/python/cuda_cooperative/cuda/cooperative/experimental/block/_block_exchange.py new file mode 100644 index 00000000000..08571c60e04 --- /dev/null +++ b/python/cuda_cooperative/cuda/cooperative/experimental/block/_block_exchange.py @@ -0,0 +1,38 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. 
# ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception


from cuda.cooperative.experimental._types import *
from cuda.cooperative.experimental._common import make_binary_tempfile, normalize_dim_param


def striped_to_blocked(dtype, threads_in_block, items_per_thread, warp_time_slicing=False):
    """Specialize the ``BlockExchange`` / ``StripedToBlocked`` algorithm template.

    An ``Algorithm`` description is built against the header
    ``cub/block/block_exchange.cuh``, specialized with the arguments given
    here, and wrapped in an ``Invocable``.

    Args:
        dtype: Value bound to the ``T`` template parameter.
        threads_in_block: Passed through ``normalize_dim_param``; elements
            0, 1 and 2 of the result are bound to ``BLOCK_DIM_X``,
            ``BLOCK_DIM_Y`` and ``BLOCK_DIM_Z`` respectively.
        items_per_thread: Value bound to ``ITEMS_PER_THREAD``.
        warp_time_slicing: Converted with ``int()`` and bound to
            ``WARP_TIME_SLICING``. Defaults to ``False``.

    Returns:
        An ``Invocable`` holding one temporary ``.ltoir`` file per entry
        yielded by the specialization's ``get_lto_ir()``, the value of its
        ``get_temp_storage_bytes()``, and the specialization itself.
    """

    def per_thread_items():
        # Build a new descriptor on every call so that each parameter slot
        # receives its own object rather than a shared one.
        return DependentArray(Dependency('T'), Dependency('ITEMS_PER_THREAD'))

    template_parameters = [
        TemplateParameter(name)
        for name in (
            'T',
            'BLOCK_DIM_X',
            'ITEMS_PER_THREAD',
            'WARP_TIME_SLICING',
            'BLOCK_DIM_Y',
            'BLOCK_DIM_Z',
        )
    ]

    # Two parameter lists are registered: a byte pointer followed by one
    # per-thread array, and a byte pointer followed by two per-thread arrays.
    # NOTE(review): presumably the pointer is the temp-storage buffer and the
    # lists map to the single-array and input/output forms of the CUB call;
    # confirm against the CUB header.
    one_array_overload = [Pointer(numba.uint8), per_thread_items()]
    two_array_overload = [
        Pointer(numba.uint8),
        per_thread_items(),
        per_thread_items(),
    ]

    template = Algorithm(
        'BlockExchange',
        'StripedToBlocked',
        'block_exchange',
        ['cub/block/block_exchange.cuh'],
        template_parameters,
        [one_array_overload, two_array_overload],
    )

    block_dims = normalize_dim_param(threads_in_block)
    template_arguments = {
        'T': dtype,
        'BLOCK_DIM_X': block_dims[0],
        'ITEMS_PER_THREAD': items_per_thread,
        'WARP_TIME_SLICING': int(warp_time_slicing),
        'BLOCK_DIM_Y': block_dims[1],
        'BLOCK_DIM_Z': block_dims[2],
    }
    specialization = template.specialize(template_arguments)

    # Materialize the LTO-IR files before querying the temp-storage size,
    # matching the order in which the keyword arguments were evaluated.
    ltoir_files = []
    for blob in specialization.get_lto_ir():
        ltoir_files.append(make_binary_tempfile(blob, '.ltoir'))
    storage_bytes = specialization.get_temp_storage_bytes()

    return Invocable(
        temp_files=ltoir_files,
        temp_storage_bytes=storage_bytes,
        algorithm=specialization,
    )