-
Notifications
You must be signed in to change notification settings - Fork 204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[STF] Implement a reduce algorithm over CUB #3122
Draft
caugonnet
wants to merge
37
commits into
NVIDIA:main
Choose a base branch
from
caugonnet:stf_cub_reduce
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 22 commits
Commits
Show all changes
37 commits
Select commit
Hold shift + click to select a range
ab1c73f
Copy the existing 08-cub-reduce example in another file indicating it…
caugonnet cac9438
Start to implement a reduce method over CUB
caugonnet a012450
Save WIP : start to implement transform_reduce on top of CUB
caugonnet b023269
Do the reduction part as well
caugonnet c920e13
clang-format
caugonnet c3468c2
Better types in transform_reduce
caugonnet a9db17c
Minor code improvements
caugonnet 023954a
compute shape size once
caugonnet 90db5a5
Merge branch 'main' into stf_cub_reduce
caugonnet 2794242
Explain the algorithm and rename scalar to scalar_view
caugonnet 38c820b
Remove some piece of code intended to use ->* in transform_reduce
caugonnet e9d4d3e
Get build to work
andralex 616ff99
Use chaining of operator->* in transform_reduce
andralex 493371f
Merge branch 'main' into stf_cub_reduce
caugonnet 30dae91
Implement an example of exclusive scan over slices
caugonnet afa8a21
clang-format
caugonnet 88baa83
improve reduce example to take a logical data of a slice only
caugonnet f56775f
fix some constness issue
caugonnet 2469994
fix some constness issue
caugonnet 318eb7b
Merge branch 'main' into stf_cub_reduce
caugonnet 0963c9f
Implement transform_exclusive_scan
caugonnet 26b41e1
ReduceOpWrapper -> LambdaOpWrapper
caugonnet 3b6ccfe
Use the ->* operator of the task to support graphs
caugonnet fc33c3a
clang-format
caugonnet 649fde6
WIP : Try to move CUB algorithms to utilities
caugonnet 872cd4d
Temporary experiment to use scopes again
caugonnet b9156bb
Save WIP
caugonnet fdc91a2
Make things work, first pass
andralex 20ab981
Merge branch 'NVIDIA:main' into stf_cub_reduce
caugonnet dac5c4b
clang-format
caugonnet 77314e6
Merge branch 'main' into stf_cub_reduce
caugonnet ea40836
Start to add facilities to create iterator for data instances (or get…
caugonnet 2cf2329
Merge branch 'main' into stf_cub_reduce
caugonnet 28ca24f
Save WIP : start to implement a view_of trait class which defines ite…
caugonnet 25c7e7e
restore private property
caugonnet 09812e9
Remove an inappropriate static assertion
caugonnet 7459acb
Merge branch 'main' into stf_cub_reduce
caugonnet File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Part of CUDASTF in CUDA C++ Core Libraries, | ||
// under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
/** | ||
* @file | ||
* @brief Example of an exclusive scan implemented using CUB | ||
*/ | ||
|
||
#include <cub/cub.cuh> | ||
|
||
#include <cuda/experimental/stf.cuh> | ||
|
||
using namespace cuda::experimental::stf; | ||
|
||
template <typename BinaryOp> | ||
struct OpWrapper | ||
{ | ||
OpWrapper(BinaryOp _op) | ||
: op(mv(_op)) {}; | ||
|
||
template <typename T> | ||
__device__ __forceinline__ T operator()(const T& a, const T& b) const | ||
{ | ||
return op(a, b); | ||
} | ||
|
||
BinaryOp op; | ||
}; | ||
|
||
template <typename Ctx, typename InT, typename OutT, typename BinaryOp> | ||
void exclusive_scan( | ||
Ctx& ctx, logical_data<slice<InT>> in_data, logical_data<slice<OutT>> out_data, BinaryOp&& op, OutT init_val) | ||
{ | ||
size_t nitems = in_data.shape().size(); | ||
|
||
// Determine temporary device storage requirements | ||
void* d_temp_storage = nullptr; | ||
size_t temp_storage_bytes = 0; | ||
cub::DeviceScan::ExclusiveScan( | ||
d_temp_storage, | ||
temp_storage_bytes, | ||
(InT*) nullptr, | ||
(OutT*) nullptr, | ||
OpWrapper<BinaryOp>(op), | ||
init_val, | ||
in_data.shape().size(), | ||
0); | ||
|
||
auto ltemp = ctx.logical_data(shape_of<slice<char>>(temp_storage_bytes)); | ||
|
||
ctx.task(in_data.read(), out_data.write(), ltemp.write()) | ||
->*[&op, init_val, nitems, temp_storage_bytes](cudaStream_t stream, auto d_in, auto d_out, auto d_temp) { | ||
size_t d_temp_size = shape(d_temp).size(); | ||
cub::DeviceScan::ExclusiveScan( | ||
(void*) d_temp.data_handle(), | ||
d_temp_size, | ||
(InT*) d_in.data_handle(), | ||
(OutT*) d_out.data_handle(), | ||
OpWrapper<BinaryOp>(op), | ||
init_val, | ||
nitems, | ||
stream); | ||
}; | ||
} | ||
|
||
template <typename Ctx> | ||
void run() | ||
{ | ||
Ctx ctx; | ||
|
||
const size_t N = 1024 * 16; | ||
|
||
::std::vector<int> X(N); | ||
::std::vector<int> out(N); | ||
|
||
::std::vector<int> ref_out(N); | ||
|
||
for (size_t ind = 0; ind < N; ind++) | ||
{ | ||
X[ind] = rand() % N; | ||
|
||
// compute the exclusive sum of X | ||
ref_out[ind] = (ind == 0) ? 0 : (X[ind - 1] + ref_out[ind - 1]); | ||
} | ||
|
||
auto lX = ctx.logical_data(X.data(), {N}); | ||
auto lout = ctx.logical_data(out.data(), {N}); | ||
|
||
exclusive_scan( | ||
ctx, | ||
lX, | ||
lout, | ||
[] __device__(const int& a, const int& b) { | ||
return a + b; | ||
}, | ||
0); | ||
|
||
ctx.finalize(); | ||
|
||
for (size_t i = 0; i < N; i++) | ||
{ | ||
_CCCL_ASSERT(ref_out[i] == out[i], "Incorrect result"); | ||
} | ||
} | ||
|
||
int main()
{
  // Run the example with the eager stream-based backend.
  run<stream_ctx>();
  // NOTE(review): graph backend left disabled — presumably not yet supported
  // by this example; confirm before enabling.
  // run<graph_ctx>();
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Part of CUDASTF in CUDA C++ Core Libraries, | ||
// under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
/** | ||
* @file | ||
* @brief Example of a reduction implemented using CUB kernels | ||
*/ | ||
|
||
#include <thrust/device_vector.h>

#include <vector>

#include <cuda/experimental/stf.cuh>
|
||
using namespace cuda::experimental::stf; | ||
|
||
template <int BLOCK_THREADS, typename T> | ||
__global__ void reduce(slice<const T> values, slice<T> partials, size_t nelems) | ||
{ | ||
using namespace cub; | ||
typedef BlockReduce<T, BLOCK_THREADS> BlockReduceT; | ||
|
||
auto thread_id = BLOCK_THREADS * blockIdx.x + threadIdx.x; | ||
|
||
// Local reduction | ||
T local_sum = 0; | ||
for (size_t ind = thread_id; ind < nelems; ind += blockDim.x * gridDim.x) | ||
{ | ||
local_sum += values(ind); | ||
} | ||
|
||
__shared__ typename BlockReduceT::TempStorage temp_storage; | ||
|
||
// Per-thread tile data | ||
T result = BlockReduceT(temp_storage).Sum(local_sum); | ||
|
||
if (threadIdx.x == 0) | ||
{ | ||
partials(blockIdx.x) = result; | ||
} | ||
} | ||
|
||
template <typename Ctx> | ||
void run() | ||
{ | ||
Ctx ctx; | ||
|
||
const size_t N = 1024 * 16; | ||
const size_t BLOCK_SIZE = 128; | ||
const size_t num_blocks = 32; | ||
|
||
int *X, ref_tot; | ||
|
||
X = new int[N]; | ||
ref_tot = 0; | ||
|
||
for (size_t ind = 0; ind < N; ind++) | ||
{ | ||
X[ind] = rand() % N; | ||
ref_tot += X[ind]; | ||
} | ||
|
||
auto values = ctx.logical_data(X, {N}); | ||
auto partials = ctx.logical_data(shape_of<slice<int>>(num_blocks)); | ||
auto result = ctx.logical_data(shape_of<slice<int>>(1)); | ||
|
||
ctx.task(values.read(), partials.write(), result.write())->*[&](auto stream, auto values, auto partials, auto result) { | ||
// reduce values into partials | ||
reduce<BLOCK_SIZE, int><<<num_blocks, BLOCK_SIZE, 0, stream>>>(values, partials, N); | ||
|
||
// reduce partials on a single block into result | ||
reduce<BLOCK_SIZE, int><<<1, BLOCK_SIZE, 0, stream>>>(partials, result, num_blocks); | ||
}; | ||
|
||
ctx.host_launch(result.read())->*[&](auto p) { | ||
if (p(0) != ref_tot) | ||
{ | ||
fprintf(stderr, "INCORRECT RESULT: p sum = %d, ref tot = %d\n", p(0), ref_tot); | ||
abort(); | ||
} | ||
}; | ||
|
||
ctx.finalize(); | ||
} | ||
|
||
int main()
{
  // Exercise both backends: eager stream execution and CUDA graph capture.
  run<stream_ctx>();
  run<graph_ctx>();
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am unsure it is worth keeping this file, unless we want to show the interop with CUB