Skip to content

Commit

Permalink
fix regression due to code generation (#474)
Browse files Browse the repository at this point in the history
  • Loading branch information
nolmoonen authored Apr 15, 2024
1 parent fd43a96 commit ef49284
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions library/src/rng/sobol.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,21 @@
namespace rocrand_impl::host
{

template<bool Scrambled, class Engine, class Constant>
__host__ __device__ Engine create_engine(const Constant* vectors,
[[maybe_unused]] Constant scramble_constant,
const unsigned int offset)
{
if constexpr(Scrambled)
{
return Engine(vectors, scramble_constant, offset);
}
else
{
return Engine(vectors, offset);
}
};

template<unsigned int OutputPerThread,
bool Scrambled,
class Engine,
Expand Down Expand Up @@ -110,19 +125,6 @@ __host__ __device__ void generate_sobol(dim3 block_idx,
}();

const Constant scramble_constant = Scrambled ? scramble_constants[dimension] : 0;
const auto create_engine
= [scramble_constant](const Constant* vectors, const unsigned int offset)
{
if constexpr(Scrambled)
{
return Engine(vectors, scramble_constant, offset);
}
else
{
(void)scramble_constant;
return Engine(vectors, offset);
}
};

data += dimension * n;

Expand All @@ -135,7 +137,9 @@ __host__ __device__ void generate_sobol(dim3 block_idx,
if constexpr(output_per_thread == 1)
{
const unsigned int engine_offset = engine_id * output_per_thread;
Engine engine = create_engine(vectors_ptr, offset + engine_offset);
Engine engine = create_engine<Scrambled, Engine>(vectors_ptr,
scramble_constant,
offset + engine_offset);

while(index < n)
{
Expand All @@ -155,7 +159,9 @@ __host__ __device__ void generate_sobol(dim3 block_idx,
const unsigned int engine_offset
= engine_id * output_per_thread
+ (engine_id == 0 ? 0 : head_size); // The first engine writes head_size values
Engine engine = create_engine(vectors_ptr, offset + engine_offset);
Engine engine = create_engine<Scrambled, Engine>(vectors_ptr,
scramble_constant,
offset + engine_offset);

if(engine_id == 0)
{
Expand Down

0 comments on commit ef49284

Please sign in to comment.