Skip to content

Commit 8bc7db1

Browse files
committed
UCT plugin: Separate common code in library file
1 parent 3d2d04b commit 8bc7db1

File tree

4 files changed

+1234
-1123
lines changed

4 files changed

+1234
-1123
lines changed

include/ucx_uct_lib.h

+291
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
/*************************************************************************
2+
* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
*
4+
* See LICENSE.txt for license information
5+
************************************************************************/
6+
7+
#ifndef NCCL_UCX_UCT_LIB_H_
8+
#define NCCL_UCX_UCT_LIB_H_
9+
10+
#include <assert.h>
11+
#include <stdint.h>
12+
#include <unistd.h>
13+
14+
#include "p2p_plugin.h"
15+
#include "socket.h"
16+
17+
#include <uct/api/uct.h>
18+
19+
/* Receive limit for the UCT plugin; aliases the IB plugin's limit. */
#define NCCL_UCX_UCT_MAX_RECVS       NCCL_NET_IB_MAX_RECVS
/*
 * 64-bit magic value for listen handles.
 * NOTE(review): presumably stored in nccl_uct_listen_handle_t.magic and
 * checked by the connector - confirm against the implementation file.
 */
#define NCCL_UCT_LISTEN_HANDLE_MAGIC 0x43cf19ed91abdb85
/*
 * NOTE(review): presumably the byte alignment applied to memory
 * registration ranges - confirm against the reg_mr implementation.
 */
#define NCCL_UCT_REG_ALIGN           4096
22+
23+
/*
 * Connection-establishment state, stored in nccl_uct_stage_t.state while a
 * communicator is being connected or accepted. Values are sequential,
 * starting at 0.
 */
typedef enum {
  NCCL_UCT_START = 0,
  NCCL_UCT_CONNECT,
  NCCL_UCT_ACCEPT,
  NCCL_UCT_RECEIVE_REMOTE, /* Acceptor receives ep addr/remote communicator */
  NCCL_UCT_RECEIVE_ADDR,
  NCCL_UCT_RX_READY,
  NCCL_UCT_DONE
} nccl_uct_state_t;
32+
33+
/*
 * UCT EP address to exchange and connect to.
 *
 * Fixed-size, byte-only layout: two one-byte length fields followed by a
 * 64-byte payload buffer.
 * NOTE(review): presumably the device address and the endpoint address are
 * packed back to back in `data` - confirm against the packing code.
 */
typedef struct {
  uint8_t dev_addr_size; /* Length of the device address, in bytes */
  uint8_t ep_addr_size;  /* Length of the endpoint address, in bytes */
  uint8_t data[64];      /* Address payload storage */
} nccl_uct_ep_addr_t;
39+
40+
typedef struct {
41+
uct_iface_h iface;
42+
uct_md_h md;
43+
uct_component_h comp;
44+
void *addr;
45+
size_t addr_size;
46+
void *dev_addr;
47+
size_t dev_addr_size;
48+
size_t ep_addr_size;
49+
size_t rkey_packed_size;
50+
51+
size_t am_max_short;
52+
size_t min_get_zcopy;
53+
} nccl_uct_iface_t;
54+
55+
struct nccl_uct_context;
56+
57+
typedef struct nccl_uct_worker {
58+
struct nccl_uct_worker *next;
59+
struct {
60+
pthread_t thread;
61+
int dev;
62+
} id;
63+
64+
int count;
65+
ucs_async_context_t *async;
66+
uct_worker_h worker;
67+
nccl_uct_iface_t *uct_iface;
68+
struct nccl_uct_context *context;
69+
} nccl_uct_worker_t;
70+
71+
typedef struct {
72+
uct_ep_h ep;
73+
uct_ep_addr_t *addr;
74+
size_t addr_size;
75+
nccl_uct_iface_t *uct_iface;
76+
uint8_t data[];
77+
} nccl_uct_ep_t;
78+
79+
/* All the remote addresses for the communicator */
80+
typedef struct nccl_uct_comm_addr {
81+
nccl_uct_ep_addr_t rma;
82+
/* TODO: Add multi-QP here */
83+
} nccl_uct_comm_addr_t;
84+
85+
/* Either Receiver or Sender communicator, connected to one peer */
86+
typedef struct nccl_uct_comm {
87+
struct ncclSocket sock;
88+
struct nccl_uct_context *context;
89+
int dev;
90+
91+
nccl_uct_worker_t *uct_worker;
92+
nccl_uct_iface_t *uct_iface;
93+
nccl_uct_ep_t *uct_ep;
94+
95+
struct nccl_uct_comm_remote {
96+
nccl_uct_comm_addr_t addr; /* Remote addresses */
97+
const struct nccl_uct_comm *comm; /* Cookie received in connect */
98+
} remote;
99+
100+
/* Local GET on current device */
101+
struct {
102+
int enabled;
103+
nccl_uct_ep_t *uct_ep; /* Locally read from HCA */
104+
nccl_uct_ep_addr_t addr;
105+
106+
uint8_t *mem; /* Dummy memory to read into */
107+
uct_mem_h memh;
108+
} gpu_flush;
109+
} nccl_uct_comm_t;
110+
111+
/* State tracking used while connecting/accepting only */
112+
typedef struct {
113+
nccl_uct_state_t state;
114+
nccl_uct_comm_t *comm; /* current communicator being created */
115+
int offset; /* for Socket reading */
116+
int ready; /* accept must complete after connect */
117+
} nccl_uct_stage_t;
118+
119+
/* Memory registration handle in NCCL UCT plugin returned by ->regMR() */
120+
typedef struct {
121+
uct_mem_h memh;
122+
nccl_uct_comm_t *comm;
123+
uct_rkey_bundle_t bundle;
124+
uint8_t rkey[];
125+
} nccl_uct_memh_t;
126+
127+
/* On-the-wire handle passed OOB by NCCL from listener to connector */
128+
typedef struct {
129+
uint64_t magic;
130+
struct {
131+
union ncclSocketAddress addr;
132+
uint32_t id;
133+
} listener;
134+
nccl_uct_comm_t *comm; /* Created communicator in accept */
135+
nccl_uct_stage_t stage; /* Used by connector */
136+
} nccl_uct_listen_handle_t;
137+
138+
/* Communicator while listening to remote ranks */
139+
typedef struct {
140+
struct ncclSocket sock;
141+
struct nccl_uct_context *context;
142+
int dev;
143+
uint32_t id;
144+
nccl_uct_worker_t *uct_worker;
145+
nccl_uct_comm_t *comm;
146+
147+
/* Used by acceptor */
148+
nccl_uct_stage_t stage;
149+
} nccl_uct_listen_comm_t;
150+
151+
/* Global state of the plugin */
152+
typedef struct nccl_uct_context {
153+
/* Transport to use */
154+
const char *tl_name;
155+
156+
/* IB devices available */
157+
int dev_count;
158+
159+
/* Use by common code to setup communicators */
160+
struct nccl_uct_ops {
161+
ncclResult_t (*comm_alloc)(nccl_uct_comm_t **comm);
162+
ncclResult_t (*comm_init)(nccl_uct_comm_t *comm,
163+
struct nccl_uct_context *context,
164+
nccl_uct_worker_t *worker, int dev,
165+
const nccl_uct_comm_t *remote_comm);
166+
ncclResult_t (*iface_set)(nccl_uct_iface_t *uct_iface);
167+
} ops;
168+
169+
/* Max sizes needed */
170+
size_t am_short_size;
171+
size_t rkey_size;
172+
173+
/* OOB socket for accepting/connecting */
174+
char if_name[MAX_IF_NAME_SIZE];
175+
union ncclSocketAddress if_addr;
176+
177+
/* Number of listener created */
178+
uint32_t listener_count;
179+
180+
/* List of created workers */
181+
nccl_uct_worker_t *worker_list;
182+
} nccl_uct_context_t;
183+
184+
/*
 * Evaluate a UCX call once; if the result is not UCS_OK, log a warning and
 * run the caller-supplied failure action.
 *
 *   statement       Expression yielding a ucs_status_t; evaluated exactly
 *                   once.
 *   failure_action  Statement executed after the warning when the call
 *                   fails.
 *   message, ...    printf-style format and its arguments. `message` must be
 *                   a string literal because it is pasted between
 *                   "Failed: " and ": %s"; the UCX status string is appended
 *                   as the final argument.
 *
 * The `##__VA_ARGS__` form (GNU extension) removes the preceding comma when
 * no variadic arguments are supplied.
 */
#define UCXCHECK(statement, failure_action, message, ...) \
  do { \
    ucs_status_t _status = (statement); \
    if (_status != UCS_OK) { \
      WARN("Failed: " message ": %s", ##__VA_ARGS__, \
           ucs_status_string(_status)); \
      failure_action; \
    } \
  } while (0)
193+
194+
extern nccl_uct_context_t context;
195+
196+
/* Library functions */
197+
ncclResult_t nccl_uct_iface_set_handler(nccl_uct_iface_t *uct_iface, int id,
198+
uct_am_callback_t callback);
199+
ncclResult_t nccl_uct_devices(int *ndev);
200+
ncclResult_t nccl_uct_comm_init(nccl_uct_comm_t *comm,
201+
nccl_uct_context_t *context,
202+
nccl_uct_worker_t *worker, int dev,
203+
const nccl_uct_comm_t *remote_comm);
204+
void nccl_uct_comm_deinit(nccl_uct_comm_t *comm);
205+
int nccl_uct_flush_index(nccl_uct_comm_t *base, int *sizes, int n);
206+
ncclResult_t nccl_uct_flush(nccl_uct_comm_t *base_comm, void *data, int size,
207+
nccl_uct_memh_t *uct_memh,
208+
uct_completion_t *completion, void **request);
209+
210+
/* NCCL common plugin callbacks */
211+
ncclResult_t nccl_uct_listen(int dev, void *listen_handle, void **listen_comm);
212+
ncclResult_t nccl_uct_accept(void *listen_comm, void **recv_comm,
213+
ncclNetDeviceHandle_v7_t **recvDevComm);
214+
ncclResult_t nccl_uct_connect(int dev, void *listen_handle, void **send_comm,
215+
ncclNetDeviceHandle_t **sendDevComm);
216+
ncclResult_t nccl_uct_close_listen(void *listen_comm);
217+
ncclResult_t nccl_uct_reg_mr_dmabuf(void *reg_comm, void *data, size_t size,
218+
int type, uint64_t offset, int fd,
219+
void **mhandle);
220+
ncclResult_t nccl_uct_reg_mr(void *reg_comm, void *data, size_t size, int type,
221+
void **mhandle);
222+
ncclResult_t nccl_uct_dereg_mr(void *dereg_comm, void *mhandle);
223+
224+
/* Compatibility callback */
225+
ncclResult_t nccl_uct_get_properties_v7(int dev,
226+
ncclNetProperties_v7_t *props_v7);
227+
ncclResult_t nccl_uct_reg_mr_v7(void *comm, void *data, int size, int type,
228+
void **mhandle);
229+
ncclResult_t nccl_uct_get_properties_v6(int dev,
230+
ncclNetProperties_v6_t *props_v6);
231+
ncclResult_t nccl_uct_connect_v6(int dev, void *handle, void **send_comm);
232+
ncclResult_t nccl_uct_accept_v6(void *listen_comm, void **recv_comm);
233+
ncclResult_t nccl_uct_get_properties(int dev, ncclNetProperties_t *props);
234+
235+
236+
/*
 * Designated initializer for a plugin structure.
 *
 *   plugin_name          Value for .name.
 *   prefix               Token prefix of the transport-specific functions;
 *                        the macro pastes prefix##_init, _isend, _irecv,
 *                        _iflush, _test and _close (the latter is used for
 *                        both .closeSend and .closeRecv).
 *   get_properties_func  Value for .getProperties.
 *   connect_func         Value for .connect.
 *   accept_func          Value for .accept.
 *   reg_mr_func          Value for .regMr.
 *
 * All remaining members are wired to the common nccl_uct_* callbacks.
 */
#define NCCL_UCT_PLUGIN_BASE(plugin_name, prefix, get_properties_func, \
                             connect_func, accept_func, reg_mr_func) \
  { \
    .name          = plugin_name, \
    .init          = prefix##_init, \
    .devices       = nccl_uct_devices, \
    .getProperties = get_properties_func, \
    .listen        = nccl_uct_listen, \
    .connect       = connect_func, \
    .accept        = accept_func, \
    .regMr         = reg_mr_func, \
    .regMrDmaBuf   = nccl_uct_reg_mr_dmabuf, \
    .deregMr       = nccl_uct_dereg_mr, \
    .isend         = prefix##_isend, \
    .irecv         = prefix##_irecv, \
    .iflush        = prefix##_iflush, \
    .test          = prefix##_test, \
    .closeSend     = prefix##_close, \
    .closeRecv     = prefix##_close, \
    .closeListen   = nccl_uct_close_listen \
  }
257+
258+
/*
 * v8 plugin initializer: NCCL_UCT_PLUGIN_BASE with the current
 * getProperties, connect, accept and regMr callbacks.
 */
#define NCCL_UCT_PLUGIN_V8(plugin_name, prefix) \
  NCCL_UCT_PLUGIN_BASE(plugin_name, prefix, nccl_uct_get_properties, \
                       nccl_uct_connect, nccl_uct_accept, nccl_uct_reg_mr)
261+
262+
/*
 * v7 plugin initializer: NCCL_UCT_PLUGIN_BASE with the v7 getProperties and
 * regMr callbacks; connect and accept are the current ones.
 */
#define NCCL_UCT_PLUGIN_V7(plugin_name, prefix) \
  NCCL_UCT_PLUGIN_BASE(plugin_name, prefix, nccl_uct_get_properties_v7, \
                       nccl_uct_connect, nccl_uct_accept, nccl_uct_reg_mr_v7)
265+
266+
/*
 * v6 plugin initializer: NCCL_UCT_PLUGIN_BASE with the v6 getProperties,
 * connect and accept callbacks, and the v7 regMr callback.
 */
#define NCCL_UCT_PLUGIN_V6(plugin_name, prefix) \
  NCCL_UCT_PLUGIN_BASE(plugin_name, prefix, nccl_uct_get_properties_v6, \
                       nccl_uct_connect_v6, nccl_uct_accept_v6, \
                       nccl_uct_reg_mr_v7)
270+
271+
/*
 * v5 plugin initializer.
 *
 * Uses the same callbacks as NCCL_UCT_PLUGIN_V6 but is spelled out in full
 * because, unlike NCCL_UCT_PLUGIN_BASE, it does not set .regMrDmaBuf.
 *
 *   plugin_name  Value for .name.
 *   prefix       Token prefix of the transport-specific functions
 *                (prefix##_init, _isend, _irecv, _iflush, _test, _close).
 */
#define NCCL_UCT_PLUGIN_V5(plugin_name, prefix) \
  { \
    .name          = plugin_name, \
    .init          = prefix##_init, \
    .devices       = nccl_uct_devices, \
    .getProperties = nccl_uct_get_properties_v6, \
    .listen        = nccl_uct_listen, \
    .connect       = nccl_uct_connect_v6, \
    .accept        = nccl_uct_accept_v6, \
    .regMr         = nccl_uct_reg_mr_v7, \
    .deregMr       = nccl_uct_dereg_mr, \
    .isend         = prefix##_isend, \
    .irecv         = prefix##_irecv, \
    .iflush        = prefix##_iflush, \
    .test          = prefix##_test, \
    .closeSend     = prefix##_close, \
    .closeRecv     = prefix##_close, \
    .closeListen   = nccl_uct_close_listen \
  }
290+
291+
#endif /* NCCL_UCX_UCT_LIB_H_ */

src/Makefile.am

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ libnccl_net_la_LDFLAGS += $(UCX_LDFLAGS)
2525
libnccl_net_la_SOURCES += \
2626
ucx_plugin.c \
2727
ucx_rma_plugin.c \
28+
ucx_uct_lib.c \
2829
ucx_uct_plugin.c
2930
endif
3031

0 commit comments

Comments
 (0)