Skip to content

Commit

Permalink
Add type validation helpers for load_halo to ensure valid tuple types
Browse files Browse the repository at this point in the history
  • Loading branch information
dimitrivlachos committed Dec 6, 2024
1 parent e2f2f16 commit 703db3b
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions include/device_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <cuda_runtime.h>

#include <cuda/std/tuple>
#include <type_traits>

/*
* Kernel radii 🔳
Expand Down Expand Up @@ -94,6 +95,53 @@ struct PitchedArray2D {
}
};

/*
* Type validation helpers for load_halo
*
* These helpers ensure that only valid types are passed to the variadic
* parameter pack in `load_halo`. A valid type is defined as
* `cuda::std::tuple<PitchedArray2D<T>, PitchedArray2D<T>>`.
*
* `is_valid_pitched_array_tuple` checks a single type for validity.
* `are_valid_pitched_array_tuples` recursively validates all types
* in a parameter pack.
*
* These checks enforce compile-time correctness and prevent runtime
* errors caused by invalid types.
*
* The helpers use a technique called specialization to provide a custom
* implementation for specific cases. The default implementation of
* `is_valid_pitched_array_tuple` returns `false` for all types, but a
* specialized version explicitly recognizes and validates
* `cuda::std::tuple<PitchedArray2D<T>, PitchedArray2D<T>>`, returning `true`.
* This ensures only the intended types are considered valid.
*/

// Default case: a type is not a valid tuple
template <typename T>
struct is_valid_pitched_array_tuple : std::false_type {};

// Specialization for a valid tuple
// Checks if a type is a `cuda::std::tuple` of two `PitchedArray2D` objects.
template <typename T1, typename T2>
struct is_valid_pitched_array_tuple<
cuda::std::tuple<PitchedArray2D<T1>, PitchedArray2D<T2>>> : std::true_type {};

// Validate all types in a variadic parameter pack
template <typename... Args>
struct are_valid_pitched_array_tuples;

// Base case for parameter pack validation: an empty pack is always valid.
template <>
struct are_valid_pitched_array_tuples<> : std::true_type {};

// Recursive case: checks the first type and continues with the rest.
// If any type is invalid, the entire pack is considered invalid.
template <typename First, typename... Rest>
struct are_valid_pitched_array_tuples<First, Rest...>
: std::conditional_t<is_valid_pitched_array_tuple<First>::value,
are_valid_pitched_array_tuples<Rest...>,
std::false_type> {};
/**
* @brief Load the halo region of an image and mask into shared memory.
*
Expand Down Expand Up @@ -125,6 +173,11 @@ __device__ void load_halo(const cooperative_groups::thread_block block,
const uint8_t kernel_width,
const uint8_t kernel_height,
MappedPairs... mapped_pairs) {
// Validate the types in the parameter pack
static_assert(are_valid_pitched_array_tuples<MappedPairs...>::value,
"All mapped_pairs must be cuda::std::tuple<PitchedArray2D<T>, "
"PitchedArray2D<T>>");

// Compute local shared memory coordinates
int local_x = threadIdx.x + kernel_width;
int local_y = threadIdx.y + kernel_height;
Expand Down

0 comments on commit 703db3b

Please sign in to comment.