  * limitations under the License.
  */

-#include <rmm/mr/device/device_memory_resource.hpp>
+#include <rmm/resource_ref.hpp>

 #include <cudf_jni_apis.hpp>
 #include <pthread.h>
@@ -384,10 +384,10 @@ class full_thread_state {
  * mitigation we might want to do to avoid killing a task with an out of
  * memory error.
  */
-class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
+class spark_resource_adaptor final {
  public:
   spark_resource_adaptor(JNIEnv* env,
-                         rmm::mr::device_memory_resource* mr,
+                         rmm::device_async_resource_ref mr,
                          std::shared_ptr<spdlog::logger>& logger,
                          bool const is_log_enabled)
     : resource{mr}, logger{logger}, is_log_enabled{is_log_enabled}
@@ -399,7 +399,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
     logger->set_pattern("%H:%M:%S.%f,%v");
   }

-  rmm::mr::device_memory_resource* get_wrapped_resource() { return resource; }
+  rmm::device_async_resource_ref get_wrapped_resource() { return resource; }

   /**
    * Update the internal state so that a specific thread is dedicated to a task.
@@ -870,7 +870,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
   }

  private:
-  rmm::mr::device_memory_resource* const resource;
+  rmm::device_async_resource_ref resource;
   std::shared_ptr<spdlog::logger> logger;  ///< spdlog logger object
   bool const is_log_enabled;

@@ -1728,13 +1728,46 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
     return ret;
   }

-  void* do_allocate(std::size_t const num_bytes, rmm::cuda_stream_view stream) override
+  /**
+   * Sync allocation method required to satisfy cuda::mr::resource concept
+   * Synchronous memory allocations are not supported
+   */
+  void* allocate(std::size_t, std::size_t) { return nullptr; }
+
+  /**
+   * Sync deallocation method required to satisfy cuda::mr::resource concept
+   * Synchronous memory deallocations are not supported
+   */
+  void deallocate(void*, std::size_t, std::size_t) {}
+
+  /**
+   * Equality comparison method required to satisfy cuda::mr::resource concept
+   */
+  friend bool operator==(const spark_resource_adaptor& lhs, const spark_resource_adaptor& rhs)
+  {
+    return (lhs.resource == rhs.resource) && (lhs.jvm == rhs.jvm);
+  }
+
+  /**
+   * Inequality comparison method required to satisfy cuda::mr::resource concept
+   */
+  friend bool operator!=(const spark_resource_adaptor& lhs, const spark_resource_adaptor& rhs)
+  {
+    return !(lhs == rhs);
+  }
+
+  /**
+   * Async allocation method required to satisfy cuda::mr::async_resource concept
+   */
+  void* allocate_async(std::size_t const num_bytes,
+                       std::size_t const alignment,
+                       rmm::cuda_stream_view stream)
   {
     auto const tid = static_cast<long>(pthread_self());
     while (true) {
       bool const likely_spill = pre_alloc(tid);
       try {
-        void* ret = resource->allocate(num_bytes, stream);
+        void* ret = resource.allocate_async(num_bytes, alignment, stream);
         post_alloc_success(tid, likely_spill);
         return ret;
       } catch (rmm::out_of_memory const& e) {
@@ -1787,9 +1820,15 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
     wake_next_highest_priority_blocked(lock, true, is_for_cpu);
   }

-  void do_deallocate(void* p, std::size_t size, rmm::cuda_stream_view stream) override
+  /**
+   * Async deallocation method required to satisfy cuda::mr::async_resource concept
+   */
+  void deallocate_async(void* p,
+                        std::size_t size,
+                        std::size_t const alignment,
+                        rmm::cuda_stream_view stream)
   {
-    resource->deallocate(p, size, stream);
+    resource.deallocate_async(p, size, alignment, stream);
     // deallocate success
     if (size > 0) {
       std::unique_lock<std::mutex> lock(state_mutex);
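
Note (not part of the patch): the methods added in the two hunks above are exactly what the cuda::mr::async_resource concept checks for, which is why the class no longer needs to inherit from rmm::mr::device_memory_resource. The standalone sketch below illustrates the same pattern with a hypothetical passthrough_resource that forwards to an upstream resource; the class and member names are illustrative only, and it assumes an RMM version that ships <rmm/resource_ref.hpp>.

#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cstddef>

// Hypothetical resource: not part of the patch, only a sketch of the concept requirements.
class passthrough_resource {
 public:
  explicit passthrough_resource(rmm::device_async_resource_ref upstream) : upstream_{upstream} {}

  // Sync interface required by cuda::mr::resource
  void* allocate(std::size_t bytes, std::size_t align) { return upstream_.allocate(bytes, align); }
  void deallocate(void* p, std::size_t bytes, std::size_t align)
  {
    upstream_.deallocate(p, bytes, align);
  }

  // Async interface required by cuda::mr::async_resource
  void* allocate_async(std::size_t bytes, std::size_t align, rmm::cuda_stream_view stream)
  {
    return upstream_.allocate_async(bytes, align, stream);
  }
  void deallocate_async(void* p,
                        std::size_t bytes,
                        std::size_t align,
                        rmm::cuda_stream_view stream)
  {
    upstream_.deallocate_async(p, bytes, align, stream);
  }

  // Equality comparison required by cuda::mr::resource
  friend bool operator==(passthrough_resource const& a, passthrough_resource const& b)
  {
    return a.upstream_ == b.upstream_;
  }
  friend bool operator!=(passthrough_resource const& a, passthrough_resource const& b)
  {
    return !(a == b);
  }

  // Advertises device-accessible allocations, which rmm::device_async_resource_ref requires
  friend void get_property(passthrough_resource const&, cuda::mr::device_accessible) noexcept {}

 private:
  rmm::device_async_resource_ref upstream_;
};

int main()
{
  rmm::mr::cuda_memory_resource cuda_mr;
  passthrough_resource adaptor{cuda_mr};

  // No inheritance needed: any type satisfying the concept binds to the ref.
  rmm::device_async_resource_ref ref{adaptor};
  void* p = ref.allocate_async(256, 256, rmm::cuda_stream_view{});
  ref.deallocate_async(p, 256, 256, rmm::cuda_stream_view{});
  return 0;
}

Binding a type to rmm::device_async_resource_ref is non-owning, so the underlying resource must outlive every ref that points at it.
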
@@ -1818,7 +1857,7 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_cr
   JNI_NULL_CHECK(env, child, "child is null", 0);
   try {
     cudf::jni::auto_set_device(env);
-    auto wrapped = reinterpret_cast<rmm::mr::device_memory_resource*>(child);
+    auto wrapped = reinterpret_cast<rmm::device_async_resource_ref*>(child);
     cudf::jni::native_jstring nlogloc(env, log_loc);
     std::shared_ptr<spdlog::logger> logger;
     bool is_log_enabled;
@@ -1837,7 +1876,7 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_cr
       }
     }

-    auto ret = new spark_resource_adaptor(env, wrapped, logger, is_log_enabled);
+    auto ret = new spark_resource_adaptor(env, *wrapped, logger, is_log_enabled);
     return cudf::jni::ptr_as_jlong(ret);
   }
   CATCH_STD(env, 0)
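
Note (not part of the patch): in the JNI hunks above, the jlong child handle is now treated as a pointer to an rmm::device_async_resource_ref rather than to an rmm::mr::device_memory_resource, and the ref is dereferenced and copied by value into the adaptor. The sketch below walks through that handle round trip under the same RMM assumption as before; the ref_ptr and handle names are illustrative only.

#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cstdint>

int main()
{
  rmm::mr::cuda_memory_resource cuda_mr;

  // What the Java side conceptually holds: a heap-allocated resource ref, passed as a jlong.
  auto* ref_ptr = new rmm::device_async_resource_ref{cuda_mr};
  auto handle   = reinterpret_cast<std::intptr_t>(ref_ptr);  // stand-in for the jlong

  // What the native entry point does with the handle: cast back and copy the ref by value.
  auto* wrapped                       = reinterpret_cast<rmm::device_async_resource_ref*>(handle);
  rmm::device_async_resource_ref copy = *wrapped;  // cheap copy; still refers to cuda_mr

  void* p = copy.allocate_async(128, 128, rmm::cuda_stream_view{});
  copy.deallocate_async(p, 128, 128, rmm::cuda_stream_view{});

  delete ref_ptr;
  return 0;
}

Because device_async_resource_ref is a non-owning view, the caller that created the child resource remains responsible for keeping it alive for the adaptor's lifetime.
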