Commit 634b051

Update resource adaptor for rmm

Signed-off-by: Paul Mattione <[email protected]>

1 parent: 79253a9

1 file changed: src/main/cpp/src/SparkResourceAdaptorJni.cpp (+50, -11)
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include <rmm/mr/device/device_memory_resource.hpp>
+#include <rmm/resource_ref.hpp>
 
 #include <cudf_jni_apis.hpp>
 #include <pthread.h>
@@ -384,10 +384,10 @@ class full_thread_state {
  * mitigation we might want to do to avoid killing a task with an out of
  * memory error.
  */
-class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
+class spark_resource_adaptor final {
  public:
   spark_resource_adaptor(JNIEnv* env,
-                         rmm::mr::device_memory_resource* mr,
+                         rmm::device_async_resource_ref mr,
                          std::shared_ptr<spdlog::logger>& logger,
                          bool const is_log_enabled)
     : resource{mr}, logger{logger}, is_log_enabled{is_log_enabled}
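
For context on the new parameter type: rmm::device_async_resource_ref is an alias for cuda::mr::async_resource_ref<cuda::mr::device_accessible>, a non-owning, copyable handle that type-erases any resource satisfying the async-resource concept. That is why the constructor can take it by value where it previously took a raw device_memory_resource*. A minimal sketch of creating and using such a ref, assuming a recent RMM in which device_memory_resource satisfies the concept:

#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

int main()
{
  // The ref below is non-owning, so the concrete resource must outlive it.
  rmm::mr::cuda_memory_resource cuda_mr;

  // Type-erased, cheap-to-copy handle; this is the type the adaptor now stores.
  rmm::device_async_resource_ref ref{cuda_mr};

  // 256 bytes with CUDA's guaranteed 256-byte allocation alignment,
  // on the default stream.
  void* p = ref.allocate_async(256, 256, rmm::cuda_stream_view{});
  ref.deallocate_async(p, 256, 256, rmm::cuda_stream_view{});
  return 0;
}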
@@ -399,7 +399,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
     logger->set_pattern("%H:%M:%S.%f,%v");
   }
 
-  rmm::mr::device_memory_resource* get_wrapped_resource() { return resource; }
+  rmm::device_async_resource_ref get_wrapped_resource() { return resource; }
 
   /**
    * Update the internal state so that a specific thread is dedicated to a task.
@@ -870,7 +870,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
   }
 
  private:
-  rmm::mr::device_memory_resource* const resource;
+  rmm::device_async_resource_ref resource;
   std::shared_ptr<spdlog::logger> logger;  ///< spdlog logger object
   bool const is_log_enabled;
 
@@ -1728,13 +1728,46 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
     return ret;
   }
 
-  void* do_allocate(std::size_t const num_bytes, rmm::cuda_stream_view stream) override
+  /**
+   * Sync allocation method required to satisfy the cuda::mr::resource concept.
+   * Synchronous memory allocations are not supported.
+   */
+  void* allocate(std::size_t, std::size_t) { return nullptr; }
+
+  /**
+   * Sync deallocation method required to satisfy the cuda::mr::resource concept.
+   * Synchronous memory deallocations are not supported.
+   */
+  void deallocate(void*, std::size_t, std::size_t) {}
+
+  /**
+   * Equality comparison method required to satisfy the cuda::mr::resource concept.
+   */
+  friend bool operator==(const spark_resource_adaptor& lhs, const spark_resource_adaptor& rhs)
+  {
+    return (lhs.resource == rhs.resource) && (lhs.jvm == rhs.jvm);
+  }
+
+  /**
+   * Inequality comparison method required to satisfy the cuda::mr::resource concept.
+   */
+  friend bool operator!=(const spark_resource_adaptor& lhs, const spark_resource_adaptor& rhs)
+  {
+    return !(lhs == rhs);
+  }
+
+  /**
+   * Async allocation method required to satisfy the cuda::mr::async_resource concept.
+   */
+  void* allocate_async(std::size_t const num_bytes,
+                       std::size_t const alignment,
+                       rmm::cuda_stream_view stream)
   {
     auto const tid = static_cast<long>(pthread_self());
     while (true) {
       bool const likely_spill = pre_alloc(tid);
       try {
-        void* ret = resource->allocate(num_bytes, stream);
+        void* ret = resource.allocate_async(num_bytes, alignment, stream);
         post_alloc_success(tid, likely_spill);
         return ret;
       } catch (rmm::out_of_memory const& e) {
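
The members added above (allocate/deallocate, the two comparison operators, and allocate_async/deallocate_async) are exactly the surface the libcu++ concepts inspect. As a standalone illustration, not part of the commit (and depending on the libcu++ version, <cuda/memory_resource> may require defining LIBCUDACXX_ENABLE_EXPERIMENTAL_MEMORY_RESOURCE), a minimal type passing the same checks could look like this:

#include <cstddef>
#include <cuda/memory_resource>  // cuda::mr::resource / cuda::mr::async_resource
#include <cuda/stream_ref>

// Illustrative only: these members are all the cuda::mr concepts require.
struct minimal_async_resource {
  void* allocate(std::size_t, std::size_t) { return nullptr; }  // sync unsupported
  void deallocate(void*, std::size_t, std::size_t) {}           // sync unsupported
  void* allocate_async(std::size_t bytes, std::size_t, cuda::stream_ref)
  {
    return ::operator new(bytes);  // stand-in; a real resource would use the stream
  }
  void deallocate_async(void* p, std::size_t, std::size_t, cuda::stream_ref)
  {
    ::operator delete(p);
  }
  bool operator==(minimal_async_resource const&) const { return true; }
  bool operator!=(minimal_async_resource const& o) const { return !(*this == o); }
};

static_assert(cuda::mr::resource<minimal_async_resource>);
static_assert(cuda::mr::async_resource<minimal_async_resource>);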
@@ -1787,9 +1820,15 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
     wake_next_highest_priority_blocked(lock, true, is_for_cpu);
   }
 
-  void do_deallocate(void* p, std::size_t size, rmm::cuda_stream_view stream) override
+  /**
+   * Async deallocation method required to satisfy the cuda::mr::async_resource concept.
+   */
+  void deallocate_async(void* p,
+                        std::size_t size,
+                        std::size_t const alignment,
+                        rmm::cuda_stream_view stream)
   {
-    resource->deallocate(p, size, stream);
+    resource.deallocate_async(p, size, alignment, stream);
     // deallocate success
     if (size > 0) {
       std::unique_lock<std::mutex> lock(state_mutex);
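
Note that the async interface makes alignment an explicit parameter, where the old do_allocate/do_deallocate pair used RMM's default. A small sketch of the call-site difference; alloc_through_ref is a hypothetical helper, not from this commit:

#include <cstddef>
#include <rmm/aligned.hpp>  // rmm::CUDA_ALLOCATION_ALIGNMENT
#include <rmm/cuda_stream_view.hpp>
#include <rmm/resource_ref.hpp>

// Old style: mr->allocate(bytes, stream) on a rmm::mr::device_memory_resource*.
// New style: call through the value-type ref, passing alignment explicitly.
void* alloc_through_ref(rmm::device_async_resource_ref mr,
                        std::size_t bytes,
                        rmm::cuda_stream_view stream)
{
  return mr.allocate_async(bytes, rmm::CUDA_ALLOCATION_ALIGNMENT, stream);
}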
@@ -1818,7 +1857,7 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_cr
   JNI_NULL_CHECK(env, child, "child is null", 0);
   try {
     cudf::jni::auto_set_device(env);
-    auto wrapped = reinterpret_cast<rmm::mr::device_memory_resource*>(child);
+    auto wrapped = reinterpret_cast<rmm::device_async_resource_ref*>(child);
     cudf::jni::native_jstring nlogloc(env, log_loc);
     std::shared_ptr<spdlog::logger> logger;
     bool is_log_enabled;
@@ -1837,7 +1876,7 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_cr
       }
     }
 
-    auto ret = new spark_resource_adaptor(env, wrapped, logger, is_log_enabled);
+    auto ret = new spark_resource_adaptor(env, *wrapped, logger, is_log_enabled);
     return cudf::jni::ptr_as_jlong(ret);
   }
   CATCH_STD(env, 0)
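
One behavioral consequence of the two changes above: the child jlong now points at a rmm::device_async_resource_ref object rather than at the resource itself, and the adaptor copies that ref (*wrapped), so the caller must keep both the ref object and the resource it refers to alive for the adaptor's lifetime. A hedged sketch of a handle producer consistent with the new cast (illustrative names, not from this commit; both allocations are leaked here for brevity):

#include <cstdint>
#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

// Hypothetical producer of the jlong handle the JNI entry point casts back.
std::intptr_t make_child_handle()
{
  auto* child_mr  = new rmm::mr::cuda_memory_resource{};            // concrete resource
  auto* child_ref = new rmm::device_async_resource_ref{*child_mr};  // object the cast targets
  return reinterpret_cast<std::intptr_t>(child_ref);
}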
