Skip to content

Commit 8347658

Browse files
committed
[Cuda] Save the Cuda native error code on adapter-specific errors
1 parent 290bb93 commit 8347658

File tree

3 files changed

+20
-8
lines changed

3 files changed

+20
-8
lines changed

source/adapters/cuda/adapter.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
9191

9292
UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetLastError(
9393
ur_adapter_handle_t, const char **ppMessage, int32_t *pError) {
94-
std::ignore = pError;
94+
*pError = ErrorAdapterNativeCode;
9595
*ppMessage = ErrorMessage;
9696
return ErrorMessageCode;
9797
}

source/adapters/cuda/common.cpp

+17-6
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010

1111
#include "common.hpp"
1212
#include "logger/ur_logger.hpp"
13+
#include "ur_api.h"
1314

1415
#include <cuda.h>
1516

1617
#include <sstream>
18+
#include <string.h>
1719

1820
ur_result_t mapErrorUR(CUresult Result) {
1921
switch (Result) {
@@ -105,6 +107,7 @@ void detail::ur::assertion(bool Condition, const char *Message) {
105107
// Global variables for ZER_EXT_RESULT_ADAPTER_SPECIFIC_ERROR
106108
thread_local ur_result_t ErrorMessageCode = UR_RESULT_SUCCESS;
107109
thread_local char ErrorMessage[MaxMessageSize];
110+
thread_local int32_t ErrorAdapterNativeCode = 0;
108111

109112
// Utility function for setting a message and warning
110113
[[maybe_unused]] void setErrorMessage(const char *pMessage,
@@ -114,16 +117,24 @@ thread_local char ErrorMessage[MaxMessageSize];
114117
ErrorMessageCode = ErrorCode;
115118
}
116119

117-
void setPluginSpecificMessage(CUresult cu_res) {
120+
[[maybe_unused]] void setAdapterSpecificMessage(CUresult cu_res) {
121+
ErrorAdapterNativeCode = static_cast<int32_t>(cu_res);
122+
// according to the documentation of the cuGetErrorName and cuGetErrorString
123+
// CUDA driver APIs, both error_name and error_string are null-terminated.
118124
const char *error_string;
119125
const char *error_name;
120126
cuGetErrorName(cu_res, &error_name);
121127
cuGetErrorString(cu_res, &error_string);
122-
char *message = (char *)malloc(strlen(error_string) + strlen(error_name) + 2);
123-
strcpy(message, error_name);
124-
strcat(message, "\n");
125-
strcat(message, error_string);
128+
static constexpr char new_line[] = "\n";
129+
// non-null-terminated sizes
130+
const size_t error_string_size = std::strlen(error_string);
131+
const size_t error_name_size = std::strlen(error_name);
132+
char *message = reinterpret_cast<char *>(
133+
std::malloc(error_string_size + error_name_size + sizeof(new_line)));
134+
std::strcpy(message, error_name);
135+
std::strcat(message, new_line);
136+
std::strncat(message, error_string, error_string_size);
126137

127138
setErrorMessage(message, UR_RESULT_ERROR_ADAPTER_SPECIFIC);
128-
free(message);
139+
std::free(message);
129140
}

source/adapters/cuda/common.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,13 @@ std::string getCudaVersionString();
3535
constexpr size_t MaxMessageSize = 256;
3636
extern thread_local ur_result_t ErrorMessageCode;
3737
extern thread_local char ErrorMessage[MaxMessageSize];
38+
extern thread_local int32_t ErrorAdapterNativeCode;
3839

3940
// Utility function for setting a message and warning
4041
[[maybe_unused]] void setErrorMessage(const char *pMessage,
4142
ur_result_t ErrorCode);
4243

43-
void setPluginSpecificMessage(CUresult cu_res);
44+
void setAdapterSpecificMessage(CUresult cu_res);
4445

4546
/// ------ Error handling, matching OpenCL plugin semantics.
4647
namespace detail {

0 commit comments

Comments
 (0)