Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Commit 178bbaa

Browse files
canonizeralliepiper
authored andcommitted
Templated type of num_items in DeviceRadixSort.
List of individual changes: - Fixed test errors - OffsetT == unsigned long long for the 64-bit case - using std::{is_same,conditional} - using "portion" consistently for 2^28-2^30-sized chunks of the input array - HasEnoughMemory() takes overwrite into account. - moved checking for enough memory earlier. - added a CTA_SYNC() to the histogram kernel - disabled tests with NumItemsT != int for segmented sort - testing with 4.5 bln. items - tests for different NumItemsT - NumItemsT for all device sorting functions - wrapped ChooseOffsetT into namespace detail - fixed typos - templatized the type of num_items in 2 methods of DeviceRadixSort - tuned radix sort with 64-bit OffsetT for V100 - tuned for 64-bit OffsetT for A100 - separate tuning parameters for 64-bit OffsetT - improved downsweep policy for GP100 - option for 64-bit num_items with 32-bit shared memory histogram counters. - introduced PartOffsetT into Onesweep kernel. - OffsetT is now only used for offsets into the whole array (e.g. bin counts or global read/write offsets) - PartOffsetT is used for offsets that do not exceed a single part (e.g. decoupled look-back, block index, number of items inside a part) - this fixes problems when OffsetT is unsigned, and also contributes towards supporting 64-bit num_items
1 parent 8500ac0 commit 178bbaa

5 files changed

+343
-227
lines changed

cub/agent/agent_radix_sort_histogram.cuh

+33-12
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "../block/radix_rank_sort_operations.cuh"
3939
#include "../config.cuh"
4040
#include "../thread/thread_reduce.cuh"
41+
#include "../util_math.cuh"
4142
#include "../util_type.cuh"
4243

4344

@@ -97,12 +98,13 @@ struct AgentRadixSortHistogram
9798
};
9899

99100
typedef RadixSortTwiddle<IS_DESCENDING, KeyT> Twiddle;
100-
typedef OffsetT ShmemAtomicOffsetT;
101+
typedef std::uint32_t ShmemCounterT;
102+
typedef ShmemCounterT ShmemAtomicCounterT;
101103
typedef typename Traits<KeyT>::UnsignedBits UnsignedBits;
102104

103105
struct _TempStorage
104106
{
105-
ShmemAtomicOffsetT bins[MAX_NUM_PASSES][RADIX_DIGITS][NUM_PARTS];
107+
ShmemAtomicCounterT bins[MAX_NUM_PASSES][RADIX_DIGITS][NUM_PARTS];
106108
};
107109

108110
struct TempStorage : Uninitialized<_TempStorage> {};
@@ -133,8 +135,11 @@ struct AgentRadixSortHistogram
133135
d_keys_in(reinterpret_cast<const UnsignedBits*>(d_keys_in)),
134136
num_items(num_items), begin_bit(begin_bit), end_bit(end_bit),
135137
num_passes((end_bit - begin_bit + RADIX_BITS - 1) / RADIX_BITS)
138+
{}
139+
140+
__device__ __forceinline__ void Init()
136141
{
137-
// init bins
142+
// Initialize bins to 0.
138143
#pragma unroll
139144
for (int bin = threadIdx.x; bin < RADIX_DIGITS; bin += BLOCK_THREADS)
140145
{
@@ -219,17 +224,33 @@ struct AgentRadixSortHistogram
219224

220225
__device__ __forceinline__ void Process()
221226
{
222-
for (OffsetT tile_offset = blockIdx.x * TILE_ITEMS; tile_offset < num_items;
223-
tile_offset += TILE_ITEMS * gridDim.x)
227+
// Within a portion, avoid overflowing (u)int32 counters.
228+
// Between portions, accumulate results in global memory.
229+
const OffsetT MAX_PORTION_SIZE = 1 << 30;
230+
OffsetT num_portions = cub::DivideAndRoundUp(num_items, MAX_PORTION_SIZE);
231+
for (OffsetT portion = 0; portion < num_portions; ++portion)
224232
{
225-
UnsignedBits keys[ITEMS_PER_THREAD];
226-
LoadTileKeys(tile_offset, keys);
227-
AccumulateSharedHistograms(tile_offset, keys);
228-
}
229-
CTA_SYNC();
233+
// Reset the counters.
234+
Init();
235+
CTA_SYNC();
230236

231-
// accumulate in global memory
232-
AccumulateGlobalHistograms();
237+
// Process the tiles.
238+
OffsetT portion_offset = portion * MAX_PORTION_SIZE;
239+
OffsetT portion_end =
240+
portion_offset + CUB_MIN(MAX_PORTION_SIZE, num_items - portion_offset);
241+
for (OffsetT tile_offset = portion_offset + blockIdx.x * TILE_ITEMS;
242+
tile_offset < portion_end; tile_offset += TILE_ITEMS * gridDim.x)
243+
{
244+
UnsignedBits keys[ITEMS_PER_THREAD];
245+
LoadTileKeys(tile_offset, keys);
246+
AccumulateSharedHistograms(tile_offset, keys);
247+
}
248+
CTA_SYNC();
249+
250+
// Accumulate the result in global memory.
251+
AccumulateGlobalHistograms();
252+
CTA_SYNC();
253+
}
233254
}
234255
};
235256

cub/agent/agent_radix_sort_onesweep.cuh

+15-14
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ template <
9292
bool IS_DESCENDING,
9393
typename KeyT,
9494
typename ValueT,
95-
typename OffsetT>
95+
typename OffsetT,
96+
typename PortionOffsetT>
9697
struct AgentRadixSortOnesweep
9798
{
9899
// constants
@@ -110,14 +111,14 @@ struct AgentRadixSortOnesweep
110111
WARP_THREADS = CUB_PTX_WARP_THREADS,
111112
BLOCK_WARPS = BLOCK_THREADS / WARP_THREADS,
112113
WARP_MASK = ~0,
113-
LOOKBACK_PARTIAL_MASK = 1 << (OffsetT(sizeof(OffsetT)) * 8 - 2),
114-
LOOKBACK_GLOBAL_MASK = 1 << (OffsetT(sizeof(OffsetT)) * 8 - 1),
114+
LOOKBACK_PARTIAL_MASK = 1 << (PortionOffsetT(sizeof(PortionOffsetT)) * 8 - 2),
115+
LOOKBACK_GLOBAL_MASK = 1 << (PortionOffsetT(sizeof(PortionOffsetT)) * 8 - 1),
115116
LOOKBACK_KIND_MASK = LOOKBACK_PARTIAL_MASK | LOOKBACK_GLOBAL_MASK,
116117
LOOKBACK_VALUE_MASK = ~LOOKBACK_KIND_MASK,
117118
};
118119

119120
typedef typename Traits<KeyT>::UnsignedBits UnsignedBits;
120-
typedef OffsetT AtomicOffsetT;
121+
typedef PortionOffsetT AtomicOffsetT;
121122

122123
static const RadixRankAlgorithm RANK_ALGORITHM =
123124
AgentRadixSortOnesweepPolicy::RANK_ALGORITHM;
@@ -165,7 +166,7 @@ struct AgentRadixSortOnesweep
165166
union
166167
{
167168
OffsetT global_offsets[RADIX_DIGITS];
168-
OffsetT block_idx;
169+
PortionOffsetT block_idx;
169170
};
170171
};
171172

@@ -183,13 +184,13 @@ struct AgentRadixSortOnesweep
183184
const UnsignedBits* d_keys_in;
184185
ValueT* d_values_out;
185186
const ValueT* d_values_in;
186-
OffsetT num_items;
187+
PortionOffsetT num_items;
187188
ShiftDigitExtractor<KeyT> digit_extractor;
188189

189190
// other thread variables
190191
int warp;
191192
int lane;
192-
OffsetT block_idx;
193+
PortionOffsetT block_idx;
193194
bool full_block;
194195

195196
// helper methods
@@ -213,7 +214,7 @@ struct AgentRadixSortOnesweep
213214
{
214215
// write the local sum into the bin
215216
AtomicOffsetT& loc = d_lookback[block_idx * RADIX_DIGITS + bin];
216-
OffsetT value = bins[u] | LOOKBACK_PARTIAL_MASK;
217+
PortionOffsetT value = bins[u] | LOOKBACK_PARTIAL_MASK;
217218
ThreadStore<STORE_VOLATILE>(&loc, value);
218219
}
219220
}
@@ -222,7 +223,7 @@ struct AgentRadixSortOnesweep
222223
struct CountsCallback
223224
{
224225
typedef AgentRadixSortOnesweep<AgentRadixSortOnesweepPolicy, IS_DESCENDING, KeyT,
225-
ValueT, OffsetT> AgentT;
226+
ValueT, OffsetT, PortionOffsetT> AgentT;
226227
AgentT& agent;
227228
int (&bins)[BINS_PER_THREAD];
228229
UnsignedBits (&keys)[ITEMS_PER_THREAD];
@@ -251,13 +252,13 @@ struct AgentRadixSortOnesweep
251252
int bin = ThreadBin(u);
252253
if (FULL_BINS || bin < RADIX_DIGITS)
253254
{
254-
OffsetT inc_sum = bins[u];
255+
PortionOffsetT inc_sum = bins[u];
255256
int want_mask = ~0;
256257
// backtrack as long as necessary
257-
for (OffsetT block_jdx = block_idx - 1; block_jdx >= 0; --block_jdx)
258+
for (PortionOffsetT block_jdx = block_idx - 1; block_jdx >= 0; --block_jdx)
258259
{
259260
// wait for some value to appear
260-
OffsetT value_j = 0;
261+
PortionOffsetT value_j = 0;
261262
AtomicOffsetT& loc_j = d_lookback[block_jdx * RADIX_DIGITS + bin];
262263
do {
263264
__threadfence_block(); // prevent hoisting loads from loop
@@ -269,7 +270,7 @@ struct AgentRadixSortOnesweep
269270
if (value_j & LOOKBACK_GLOBAL_MASK) break;
270271
}
271272
AtomicOffsetT& loc_i = d_lookback[block_idx * RADIX_DIGITS + bin];
272-
OffsetT value_i = inc_sum | LOOKBACK_GLOBAL_MASK;
273+
PortionOffsetT value_i = inc_sum | LOOKBACK_GLOBAL_MASK;
273274
ThreadStore<STORE_VOLATILE>(&loc_i, value_i);
274275
s.global_offsets[bin] += inc_sum - bins[u];
275276
}
@@ -638,7 +639,7 @@ struct AgentRadixSortOnesweep
638639
const KeyT *d_keys_in,
639640
ValueT *d_values_out,
640641
const ValueT *d_values_in,
641-
OffsetT num_items,
642+
PortionOffsetT num_items,
642643
int current_bit,
643644
int num_bits)
644645
: s(temp_storage.Alias())

0 commit comments

Comments
 (0)