Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Merge pull request #340 from canonizer/onesweep-offset64
Browse files Browse the repository at this point in the history
64-bit Offsets in DeviceRadixSort
  • Loading branch information
alliepiper authored Jan 27, 2022
2 parents 8500ac0 + 178bbaa commit 93f26ab
Show file tree
Hide file tree
Showing 5 changed files with 343 additions and 227 deletions.
45 changes: 33 additions & 12 deletions cub/agent/agent_radix_sort_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "../block/radix_rank_sort_operations.cuh"
#include "../config.cuh"
#include "../thread/thread_reduce.cuh"
#include "../util_math.cuh"
#include "../util_type.cuh"


Expand Down Expand Up @@ -97,12 +98,13 @@ struct AgentRadixSortHistogram
};

typedef RadixSortTwiddle<IS_DESCENDING, KeyT> Twiddle;
typedef OffsetT ShmemAtomicOffsetT;
typedef std::uint32_t ShmemCounterT;
typedef ShmemCounterT ShmemAtomicCounterT;
typedef typename Traits<KeyT>::UnsignedBits UnsignedBits;

struct _TempStorage
{
ShmemAtomicOffsetT bins[MAX_NUM_PASSES][RADIX_DIGITS][NUM_PARTS];
ShmemAtomicCounterT bins[MAX_NUM_PASSES][RADIX_DIGITS][NUM_PARTS];
};

struct TempStorage : Uninitialized<_TempStorage> {};
Expand Down Expand Up @@ -133,8 +135,11 @@ struct AgentRadixSortHistogram
d_keys_in(reinterpret_cast<const UnsignedBits*>(d_keys_in)),
num_items(num_items), begin_bit(begin_bit), end_bit(end_bit),
num_passes((end_bit - begin_bit + RADIX_BITS - 1) / RADIX_BITS)
{}

__device__ __forceinline__ void Init()
{
// init bins
// Initialize bins to 0.
#pragma unroll
for (int bin = threadIdx.x; bin < RADIX_DIGITS; bin += BLOCK_THREADS)
{
Expand Down Expand Up @@ -219,17 +224,33 @@ struct AgentRadixSortHistogram

__device__ __forceinline__ void Process()
{
for (OffsetT tile_offset = blockIdx.x * TILE_ITEMS; tile_offset < num_items;
tile_offset += TILE_ITEMS * gridDim.x)
// Within a portion, avoid overflowing (u)int32 counters.
// Between portions, accumulate results in global memory.
const OffsetT MAX_PORTION_SIZE = 1 << 30;
OffsetT num_portions = cub::DivideAndRoundUp(num_items, MAX_PORTION_SIZE);
for (OffsetT portion = 0; portion < num_portions; ++portion)
{
UnsignedBits keys[ITEMS_PER_THREAD];
LoadTileKeys(tile_offset, keys);
AccumulateSharedHistograms(tile_offset, keys);
}
CTA_SYNC();
// Reset the counters.
Init();
CTA_SYNC();

// accumulate in global memory
AccumulateGlobalHistograms();
// Process the tiles.
OffsetT portion_offset = portion * MAX_PORTION_SIZE;
OffsetT portion_end =
portion_offset + CUB_MIN(MAX_PORTION_SIZE, num_items - portion_offset);
for (OffsetT tile_offset = portion_offset + blockIdx.x * TILE_ITEMS;
tile_offset < portion_end; tile_offset += TILE_ITEMS * gridDim.x)
{
UnsignedBits keys[ITEMS_PER_THREAD];
LoadTileKeys(tile_offset, keys);
AccumulateSharedHistograms(tile_offset, keys);
}
CTA_SYNC();

// Accumulate the result in global memory.
AccumulateGlobalHistograms();
CTA_SYNC();
}
}
};

Expand Down
29 changes: 15 additions & 14 deletions cub/agent/agent_radix_sort_onesweep.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ template <
bool IS_DESCENDING,
typename KeyT,
typename ValueT,
typename OffsetT>
typename OffsetT,
typename PortionOffsetT>
struct AgentRadixSortOnesweep
{
// constants
Expand All @@ -110,14 +111,14 @@ struct AgentRadixSortOnesweep
WARP_THREADS = CUB_PTX_WARP_THREADS,
BLOCK_WARPS = BLOCK_THREADS / WARP_THREADS,
WARP_MASK = ~0,
LOOKBACK_PARTIAL_MASK = 1 << (OffsetT(sizeof(OffsetT)) * 8 - 2),
LOOKBACK_GLOBAL_MASK = 1 << (OffsetT(sizeof(OffsetT)) * 8 - 1),
LOOKBACK_PARTIAL_MASK = 1 << (PortionOffsetT(sizeof(PortionOffsetT)) * 8 - 2),
LOOKBACK_GLOBAL_MASK = 1 << (PortionOffsetT(sizeof(PortionOffsetT)) * 8 - 1),
LOOKBACK_KIND_MASK = LOOKBACK_PARTIAL_MASK | LOOKBACK_GLOBAL_MASK,
LOOKBACK_VALUE_MASK = ~LOOKBACK_KIND_MASK,
};

typedef typename Traits<KeyT>::UnsignedBits UnsignedBits;
typedef OffsetT AtomicOffsetT;
typedef PortionOffsetT AtomicOffsetT;

static const RadixRankAlgorithm RANK_ALGORITHM =
AgentRadixSortOnesweepPolicy::RANK_ALGORITHM;
Expand Down Expand Up @@ -165,7 +166,7 @@ struct AgentRadixSortOnesweep
union
{
OffsetT global_offsets[RADIX_DIGITS];
OffsetT block_idx;
PortionOffsetT block_idx;
};
};

Expand All @@ -183,13 +184,13 @@ struct AgentRadixSortOnesweep
const UnsignedBits* d_keys_in;
ValueT* d_values_out;
const ValueT* d_values_in;
OffsetT num_items;
PortionOffsetT num_items;
ShiftDigitExtractor<KeyT> digit_extractor;

// other thread variables
int warp;
int lane;
OffsetT block_idx;
PortionOffsetT block_idx;
bool full_block;

// helper methods
Expand All @@ -213,7 +214,7 @@ struct AgentRadixSortOnesweep
{
// write the local sum into the bin
AtomicOffsetT& loc = d_lookback[block_idx * RADIX_DIGITS + bin];
OffsetT value = bins[u] | LOOKBACK_PARTIAL_MASK;
PortionOffsetT value = bins[u] | LOOKBACK_PARTIAL_MASK;
ThreadStore<STORE_VOLATILE>(&loc, value);
}
}
Expand All @@ -222,7 +223,7 @@ struct AgentRadixSortOnesweep
struct CountsCallback
{
typedef AgentRadixSortOnesweep<AgentRadixSortOnesweepPolicy, IS_DESCENDING, KeyT,
ValueT, OffsetT> AgentT;
ValueT, OffsetT, PortionOffsetT> AgentT;
AgentT& agent;
int (&bins)[BINS_PER_THREAD];
UnsignedBits (&keys)[ITEMS_PER_THREAD];
Expand Down Expand Up @@ -251,13 +252,13 @@ struct AgentRadixSortOnesweep
int bin = ThreadBin(u);
if (FULL_BINS || bin < RADIX_DIGITS)
{
OffsetT inc_sum = bins[u];
PortionOffsetT inc_sum = bins[u];
int want_mask = ~0;
// backtrack as long as necessary
for (OffsetT block_jdx = block_idx - 1; block_jdx >= 0; --block_jdx)
for (PortionOffsetT block_jdx = block_idx - 1; block_jdx >= 0; --block_jdx)
{
// wait for some value to appear
OffsetT value_j = 0;
PortionOffsetT value_j = 0;
AtomicOffsetT& loc_j = d_lookback[block_jdx * RADIX_DIGITS + bin];
do {
__threadfence_block(); // prevent hoisting loads from loop
Expand All @@ -269,7 +270,7 @@ struct AgentRadixSortOnesweep
if (value_j & LOOKBACK_GLOBAL_MASK) break;
}
AtomicOffsetT& loc_i = d_lookback[block_idx * RADIX_DIGITS + bin];
OffsetT value_i = inc_sum | LOOKBACK_GLOBAL_MASK;
PortionOffsetT value_i = inc_sum | LOOKBACK_GLOBAL_MASK;
ThreadStore<STORE_VOLATILE>(&loc_i, value_i);
s.global_offsets[bin] += inc_sum - bins[u];
}
Expand Down Expand Up @@ -638,7 +639,7 @@ struct AgentRadixSortOnesweep
const KeyT *d_keys_in,
ValueT *d_values_out,
const ValueT *d_values_in,
OffsetT num_items,
PortionOffsetT num_items,
int current_bit,
int num_bits)
: s(temp_storage.Alias())
Expand Down
Loading

0 comments on commit 93f26ab

Please sign in to comment.