Skip to content

Commit 26c36ac

Browse files
TL/CUDA: check devices
1 parent 915c5ac commit 26c36ac

File tree

6 files changed

+34
-8
lines changed

6 files changed

+34
-8
lines changed

src/components/tl/cuda/tl_cuda.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ typedef struct ucc_tl_cuda_sync {
142142
ucc_tl_cuda_mem_info_t mem_info_src;
143143
ucc_tl_cuda_mem_info_t mem_info_dst;
144144
cudaEvent_t ipc_event_local;
145-
cudaIpcEventHandle_t ev_handle;
146145
union {
147146
struct {
148147
size_t sbytes[UCC_TL_CUDA_MAX_PEERS];

src/components/tl/cuda/tl_cuda_coll.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ extern const char
2020
#define UCC_TL_CUDA_CHECK_DEVICE_MATCH(_team) do { \
2121
int _dev; \
2222
CUDA_CHECK(cudaGetDevice(&_dev)); \
23-
if (_dev != (_team)->device) { \
23+
if (((_team)->device != -1) && _dev != (_team)->device) { \
2424
tl_error(UCC_TL_TEAM_LIB(_team), "CUDA device mismatch, " \
2525
"current device %d, team device %d\n", _dev, \
2626
(_team)->device); \

src/components/tl/cuda/tl_cuda_team.c

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,17 @@ ucc_status_t ucc_tl_cuda_comm_init_post(ucc_tl_cuda_team_t *team)
3838
if (cu_ctx == NULL || cu_st != CUDA_SUCCESS) {
3939
tl_debug(tl_lib,
4040
"cannot create CUDA TL team without active CUDA context");
41-
return UCC_ERR_NO_RESOURCE;
41+
team->device_id = TL_CUDA_DEVICE_INVALID;
42+
team->state = TL_CUDA_STATE_ERROR;
43+
goto exchnage_rank_ids;
4244
}
4345

4446
status = CUDA_FUNC(cudaGetDevice(&team->device));
4547
if (status != UCC_OK) {
4648
tl_debug(tl_lib, "failed to get current device id");
47-
return status;
49+
team->device_id = TL_CUDA_DEVICE_INVALID;
50+
team->state = TL_CUDA_STATE_ERROR;
51+
goto exchnage_rank_ids;
4852
}
4953

5054
status = ucc_tl_cuda_topo_get_pci_id(team->device, &team->device_id);
@@ -88,6 +92,7 @@ ucc_status_t ucc_tl_cuda_comm_init_post(ucc_tl_cuda_team_t *team)
8892
goto free_scratch;
8993
}
9094

95+
exchnage_rank_ids:
9196
rank_id->pci_id = team->device_id;
9297
status = team->oob.allgather(rank_id, team->ids, rank_id_size,
9398
team->oob.coll_info, &team->oob_req);
@@ -127,6 +132,17 @@ ucc_status_t ucc_tl_cuda_comm_init_test(ucc_tl_cuda_team_t *team)
127132
return status;
128133
}
129134
team->oob.req_free(team->oob_req);
135+
/* check all ranks have valid CUDA device set */
136+
for (r = 0; r < tsize; r++) {
137+
rank_id = GET_RANK_ID(team->ids, r, max_concurrent);
138+
if (ucc_tl_cuda_topo_device_id_equal(&rank_id->pci_id,
139+
&TL_CUDA_DEVICE_INVALID)) {
140+
tl_debug(tl_lib, "rank %d device is invalid, team can't be created",
141+
r);
142+
team->state = TL_CUDA_STATE_ERROR;
143+
return UCC_ERR_NO_RESOURCE;
144+
}
145+
}
130146

131147
status = ucc_tl_cuda_team_topo_create(&team->super, &team->topo);
132148
if (status != UCC_OK) {
@@ -230,10 +246,11 @@ UCC_CLASS_INIT_FUNC(ucc_tl_cuda_team_t, ucc_base_context_t *tl_context,
230246
ucc_tl_cuda_rank_id_t *rank_id;
231247

232248
UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params);
233-
self->oob = params->params.oob;
234-
self->stream = NULL;
235-
self->topo = NULL;
236-
self->scratch.loc = NULL;
249+
self->oob = params->params.oob;
250+
self->stream = NULL;
251+
self->topo = NULL;
252+
self->device = -1;
253+
memset(&self->scratch, 0, sizeof(ucc_tl_cuda_scratch_t));
237254

238255
if (!ucc_team_map_is_single_node(params->team, params->map)) {
239256
tl_debug(tl_context->lib, "multinode team is not supported");

src/components/tl/cuda/tl_cuda_topo.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ pthread_mutex_t nvml_lock = PTHREAD_MUTEX_INITIALIZER;
2525
} \
2626
} while (0)
2727

28+
const ucc_tl_cuda_device_pci_id_t TL_CUDA_DEVICE_INVALID = {
29+
.domain = 0xFFFF,
30+
.bus = 0xFF,
31+
.device = 0xFF,
32+
.function = 0xFF,
33+
};
34+
2835
static ucc_status_t
2936
ucc_tl_cuda_topo_pci_id_from_str(const char * bus_id_str,
3037
ucc_tl_cuda_device_pci_id_t *pci_id)

src/components/tl/cuda/tl_cuda_topo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ typedef struct ucc_tl_cuda_device_id {
1919
uint8_t function; /* range: 0 to 7 */
2020
} ucc_tl_cuda_device_pci_id_t;
2121

22+
extern const ucc_tl_cuda_device_pci_id_t TL_CUDA_DEVICE_INVALID;
23+
2224
typedef enum ucc_tl_cuda_topo_dev_type {
2325
UCC_TL_CUDA_TOPO_DEV_TYPE_GPU,
2426
UCC_TL_CUDA_TOPO_DEV_TYPE_SWITCH,

src/components/tl/nccl/tl_nccl_team.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
6060
UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params);
6161

6262
size = UCC_TL_TEAM_SIZE(self);
63+
self->nccl_comm = NULL;
6364
self->comm_state = TL_NCCL_COMM_STATE_INIT;
6465
self->unique_id = ucc_malloc(sizeof(ncclUniqueId) * (size + 1),
6566
"tl_nccl_unique_id");

0 commit comments

Comments
 (0)