From 5d2be02e892292dc4cb85892400498b0cc655745 Mon Sep 17 00:00:00 2001
From: Devendar Bureddy
Date: Fri, 22 Mar 2024 10:24:49 -0700
Subject: [PATCH] options to disable RS and AG

---
 src/sharp_plugin.c | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/src/sharp_plugin.c b/src/sharp_plugin.c
index 13820c9d..d37e8b2e 100644
--- a/src/sharp_plugin.c
+++ b/src/sharp_plugin.c
@@ -25,11 +25,15 @@ extern ncclNet_v7_t ncclNetPlugin_v7;
 extern ncclNet_v6_t ncclNetPlugin_v6;
 extern ncclNet_v5_t ncclNetPlugin_v5;
 
+extern ncclCollNet_v8_t ncclCollNetPlugin_v8;
+
 int ncclNSharpDevs = -1;
 struct sharp_coll_caps sharp_caps;
 static int ncclSharpV3DatatypesSupported = 0;
 NCCL_PARAM(SharpGroupSizeThresh, "SHARP_GROUP_SIZE_THRESH", 2);
 NCCL_PARAM(SharpV3Datatypes, "SHARP_V3_DATATYPES", 2);
+NCCL_PARAM(SharpDisableRS, "SHARP_DISABLE_REDUCE_SCATTER", 0);
+NCCL_PARAM(SharpDisableAG, "SHARP_DISABLE_ALLGATHER", 0);
 
 enum ncclSharpRequestType {
   NCCL_SHARP_REQ_SHARP_COLL,
@@ -204,7 +208,16 @@ ncclResult_t ncclSharpInit(ncclDebugLogger_t logFunction) {
   setenv("SHARP_COLL_LOCK_ON_COMM_INIT", "1", 0);
   setenv("SHARP_COLL_LOG_LEVEL", "3", 0);
 
-  return ncclNetPlugin_v7.init(logFunction);
+  if (ncclParamSharpDisableAG()) {
+    INFO(NCCL_NET, "Disabled SHARP Allgather");
+    ncclCollNetPlugin_v8.iallgather = NULL;
+  }
+  if (ncclParamSharpDisableRS()) {
+    INFO(NCCL_NET, "Disabled SHARP reduce-scatter");
+    ncclCollNetPlugin_v8.ireducescatter = NULL;
+  }
+
+  return ncclNetPlugin_v8.init(logFunction);
 }
 
 ncclResult_t ncclSharpDevices(int* ndev) {