Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Enable swap(...) for temporaries and between different reference types #1646

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions testing/device_reference.cu
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ DECLARE_UNITTEST(TestDeviceReferenceManipulation);

void TestDeviceReferenceSwap(void)
{
using std::swap;
typedef int T;

thrust::device_vector<T> v(2);
Expand All @@ -218,14 +219,37 @@ void TestDeviceReferenceSwap(void)
ref2 = 13;

// test thrust::swap()
thrust::swap(ref1, ref2);
swap(ref1, ref2);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the point of this code is to test thrust::swap, no?

The existence of thrust::swap is a bit unfortunate. If you try to swap a type that has both std:: and thrust:: as associated namespace, the compiler has no way to pick between std::swap and thrust::swap. We could consider defining thrust::swap as a global function object that dispatches to (unqualified) swap in a context that has brought std::swap in with a using declaration. Then a simple call to thrust::swap will do the ADL lookup internally, and the potential ambiguity between std::swap and thrust::swap goes away. I think that's a change that is unlikely to break anybody.

@allisonvacanti thoughts?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, your proposed fix sounds good to me so long as nothing is specializing thrust::swap.

The THRUST_INLINE_CONSTANT macro may be needed when defining the function object, see https://github.com/NVIDIA/thrust/blob/main/thrust/detail/config/cpp_compatibility.h#L47-L74

ASSERT_EQUAL(13, ref1);
ASSERT_EQUAL(7, ref2);

// test thrust::swap(device_reference<T>, device_reference<T>)
swap(v.front(), v.back());
ASSERT_EQUAL(7, v.front());
ASSERT_EQUAL(13, v.back());

// test .swap()
ref1.swap(ref2);
ASSERT_EQUAL(7, ref1);
ASSERT_EQUAL(13, ref2);
ASSERT_EQUAL(13, ref1);
ASSERT_EQUAL(7, ref2);

// test .swap(device_reference<T>)
v.front().swap(v.back());
ASSERT_EQUAL(7, v.front());
ASSERT_EQUAL(13, v.back());

// test thrust::swap(device_reference<T>, T&)
T val = 29;
swap(v.front(), val);
ASSERT_EQUAL(7, val);
ASSERT_EQUAL(29, v.front());
ASSERT_EQUAL(13, v.back());

// test thrust::swap(T&, device_reference<T>)
swap(val, v.back());
ASSERT_EQUAL(13, val);
ASSERT_EQUAL(29, v.front());
ASSERT_EQUAL(7, v.back());
}
DECLARE_UNITTEST(TestDeviceReferenceSwap);

34 changes: 31 additions & 3 deletions thrust/detail/reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class reference
* \param other The \p tagged_reference to swap with.
*/
__host__ __device__
void swap(derived_type& other)
void swap(derived_type other)
{
// Avoid default-constructing a system; instead, just use a null pointer
// for dispatch. This assumes that `get_value` will not access any system
Expand Down Expand Up @@ -372,7 +372,7 @@ class reference

template <typename System>
__host__ __device__
void swap(System* system, derived_type& other)
void swap(System* system, derived_type other)
{
using thrust::system::detail::generic::select_system;
using thrust::system::detail::generic::iter_swap;
Expand Down Expand Up @@ -509,10 +509,38 @@ class tagged_reference<void const, Tag> {};
*/
template <typename Element, typename Tag>
__host__ __device__
void swap(tagged_reference<Element, Tag>& x, tagged_reference<Element, Tag>& y)
void swap(tagged_reference<Element, Tag> x, tagged_reference<Element, Tag> y)
{
x.swap(y);
}

/*! Exchanges the values of two objects referred to by a \p tagged_reference and a regular reference.
*
* \param x The \p tagged_reference of interest.
* \param y The regular reference of interest.
*/
template <typename Element, typename Tag>
__host__ __device__
void swap(Element& x, tagged_reference<Element, Tag> y)
{
Element tmp = x;
x = y;
y = tmp;
}

/*! Exchanges the values of two objects referred to by a regular reference and a \p tagged_reference.
*
* \param x The regular reference of interest.
* \param y The \p tagged_reference of interest.
*/
template <typename Element, typename Tag>
__host__ __device__
void swap(tagged_reference<Element, Tag> x, Element& y)
{
Element tmp = x;
x = y;
y = tmp;
}

THRUST_NAMESPACE_END

30 changes: 28 additions & 2 deletions thrust/device_reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ template<typename T>
* \p other The other \p device_reference with which to swap.
*/
__host__ __device__
void swap(device_reference &other);
void swap(device_reference other);

/*! Prefix increment operator increments the object referenced by this
* \p device_reference.
Expand Down Expand Up @@ -962,11 +962,37 @@ template<typename T>
*/
template<typename T>
__host__ __device__
void swap(device_reference<T>& x, device_reference<T>& y)
void swap(device_reference<T> x, device_reference<T> y)
{
x.swap(y);
}

/*! swaps the value of a \p device_reference with a regular reference.
* \p x The \p device_reference of interest.
* \p y The regular reference of interest.
*/
template<typename T>
__host__ __device__
void swap(device_reference<T> x, T &y)
{
T tmp = x;
x = y;
y = tmp;
}

/*! swaps the value of a regular reference with a \p device_reference.
* \p x The regular reference of interest.
* \p y The \p device_reference of interest.
*/
template<typename T>
__host__ __device__
void swap(T &x, device_reference<T> y)
{
T tmp = x;
x = y;
y = tmp;
}

// declare these methods for the purpose of Doxygenating them
// they actually are defined for a derived-from class
#if THRUST_DOXYGEN
Expand Down