Skip to content

Commit 9d1ea51

Browse files
npmiller authored and fabiomestre committed
[SYCL][CUDA] Remove size checks from USM allocations (#10034)
These checks are causing issues for very large USM allocations because the `MAX_MEM_ALLOC_SIZE` reported is lower than what CUDA actually supports. We will follow up with an update on the reported `MAX_MEM_ALLOC_SIZE`, but it makes sense to remove the checks either way, as the CUDA allocation functions will return an error if they can't allocate the memory.
1 parent 16d4f24 commit 9d1ea51

File tree

1 file changed

+15
-40
lines changed

1 file changed

+15
-40
lines changed

usm.cpp

+15-40
Original file line number | Diff line number | Diff line change
@@ -24,15 +24,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(
2424
[[maybe_unused]] ur_usm_pool_handle_t pool, size_t size, void **ppMem) {
2525
UR_ASSERT(ppMem, UR_RESULT_ERROR_INVALID_NULL_POINTER);
2626
UR_ASSERT(hContext, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
27-
28-
size_t DeviceMaxMemAllocSize = 0;
29-
UR_ASSERT(urDeviceGetInfo(hContext->getDevice(),
30-
UR_DEVICE_INFO_MAX_MEM_ALLOC_SIZE, sizeof(size_t),
31-
static_cast<void *>(&DeviceMaxMemAllocSize),
32-
nullptr) == UR_RESULT_SUCCESS,
33-
UR_RESULT_ERROR_INVALID_DEVICE);
34-
UR_ASSERT(size > 0 && size <= DeviceMaxMemAllocSize,
35-
UR_RESULT_ERROR_INVALID_USM_SIZE);
27+
UR_ASSERT(!pUSMDesc || (pUSMDesc->align == 0 ||
28+
((pUSMDesc->align & (pUSMDesc->align - 1)) == 0)),
29+
UR_RESULT_ERROR_INVALID_VALUE);
3630

3731
ur_result_t Result = UR_RESULT_SUCCESS;
3832
try {
@@ -42,13 +36,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(
4236
Result = Err;
4337
}
4438

45-
UR_ASSERT(!pUSMDesc || (pUSMDesc->align == 0 ||
46-
((pUSMDesc->align & (pUSMDesc->align - 1)) == 0)),
47-
UR_RESULT_ERROR_INVALID_VALUE);
48-
49-
assert(Result == UR_RESULT_SUCCESS &&
50-
(!pUSMDesc || pUSMDesc->align == 0 ||
51-
reinterpret_cast<std::uintptr_t>(*ppMem) % pUSMDesc->align == 0));
39+
if (Result == UR_RESULT_SUCCESS) {
40+
assert((!pUSMDesc || pUSMDesc->align == 0 ||
41+
reinterpret_cast<std::uintptr_t>(*ppMem) % pUSMDesc->align == 0));
42+
}
5243

5344
return Result;
5445
}
@@ -66,15 +57,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
6657
((pUSMDesc->align & (pUSMDesc->align - 1)) == 0)),
6758
UR_RESULT_ERROR_INVALID_VALUE);
6859

69-
size_t DeviceMaxMemAllocSize = 0;
70-
UR_ASSERT(urDeviceGetInfo(hDevice, UR_DEVICE_INFO_MAX_MEM_ALLOC_SIZE,
71-
sizeof(size_t),
72-
static_cast<void *>(&DeviceMaxMemAllocSize),
73-
nullptr) == UR_RESULT_SUCCESS,
74-
UR_RESULT_ERROR_INVALID_DEVICE);
75-
UR_ASSERT(size > 0 && size <= DeviceMaxMemAllocSize,
76-
UR_RESULT_ERROR_INVALID_USM_SIZE);
77-
7860
ur_result_t Result = UR_RESULT_SUCCESS;
7961
try {
8062
ScopedContext Active(hContext);
@@ -83,9 +65,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
8365
return Err;
8466
}
8567

86-
assert(Result == UR_RESULT_SUCCESS &&
87-
(!pUSMDesc || pUSMDesc->align == 0 ||
88-
reinterpret_cast<std::uintptr_t>(*ppMem) % pUSMDesc->align == 0));
68+
if (Result == UR_RESULT_SUCCESS) {
69+
assert((!pUSMDesc || pUSMDesc->align == 0 ||
70+
reinterpret_cast<std::uintptr_t>(*ppMem) % pUSMDesc->align == 0));
71+
}
8972

9073
return Result;
9174
}
@@ -103,15 +86,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
10386
((pUSMDesc->align & (pUSMDesc->align - 1)) == 0)),
10487
UR_RESULT_ERROR_INVALID_VALUE);
10588

106-
size_t DeviceMaxMemAllocSize = 0;
107-
UR_ASSERT(urDeviceGetInfo(hDevice, UR_DEVICE_INFO_MAX_MEM_ALLOC_SIZE,
108-
sizeof(size_t),
109-
static_cast<void *>(&DeviceMaxMemAllocSize),
110-
nullptr) == UR_RESULT_SUCCESS,
111-
UR_RESULT_ERROR_INVALID_DEVICE);
112-
UR_ASSERT(size > 0 && size <= DeviceMaxMemAllocSize,
113-
UR_RESULT_ERROR_INVALID_USM_SIZE);
114-
11589
ur_result_t Result = UR_RESULT_SUCCESS;
11690
try {
11791
ScopedContext Active(hContext);
@@ -121,9 +95,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
12195
return Err;
12296
}
12397

124-
assert(Result == UR_RESULT_SUCCESS &&
125-
(!pUSMDesc || pUSMDesc->align == 0 ||
126-
reinterpret_cast<std::uintptr_t>(*ppMem) % pUSMDesc->align == 0));
98+
if (Result == UR_RESULT_SUCCESS) {
99+
assert((!pUSMDesc || pUSMDesc->align == 0 ||
100+
reinterpret_cast<std::uintptr_t>(*ppMem) % pUSMDesc->align == 0));
101+
}
127102

128103
return Result;
129104
}

0 commit comments

Comments
 (0)