diff --git a/3rdparty/onednn b/3rdparty/onednn index f40443c41342..e2d45252ae9c 160000 --- a/3rdparty/onednn +++ b/3rdparty/onednn @@ -1 +1 @@ -Subproject commit f40443c413429c29570acd6cf5e3d1343cf647b4 +Subproject commit e2d45252ae9c3e91671339579e3c0f0061f81d49 diff --git a/cpp-package/include/mxnet-cpp/executor.h b/cpp-package/include/mxnet-cpp/executor.h index fff559b79df3..bb602fc282c9 100644 --- a/cpp-package/include/mxnet-cpp/executor.h +++ b/cpp-package/include/mxnet-cpp/executor.h @@ -132,7 +132,9 @@ class Executor { 1, nullptr, nullptr), - 0); + 0, + nullptr, + nullptr); } else { CHECK_EQ(MXAutogradBackwardEx(out_handles.size(), out_handles.data(), @@ -144,7 +146,10 @@ class Executor { 1, nullptr, nullptr), - 0); + 0, + 0, + nullptr, + nullptr); } grad_arrays.clear(); grad_arrays.reserve(arg_arrays.size()); diff --git a/example/extensions/lib_pass/pass_lib.cc b/example/extensions/lib_pass/pass_lib.cc index f441877fcad7..d219299ac9ad 100644 --- a/example/extensions/lib_pass/pass_lib.cc +++ b/example/extensions/lib_pass/pass_lib.cc @@ -49,4 +49,4 @@ MXReturnValue initialize(int version) { MX_ERROR_MSG << "MXNet version " << version << " not supported" << std::endl; return MX_FAIL; } -} +} \ No newline at end of file diff --git a/example/extensions/lib_pass/test_pass.py b/example/extensions/lib_pass/test_pass.py index ab89f9566ebe..3fa1fd22b891 100644 --- a/example/extensions/lib_pass/test_pass.py +++ b/example/extensions/lib_pass/test_pass.py @@ -67,4 +67,4 @@ def test_model(pass_name): sym_block2.optimize_for(mx.nd.ones((3,2)), mx.nd.ones((3,2)), backend=pass_name) sym_block2.export('modified') -test_model('myPass') +test_model('myPass') \ No newline at end of file diff --git a/example/extensions/lib_reduce_gradient/Makefile b/example/extensions/lib_reduce_gradient/Makefile new file mode 100644 index 000000000000..95b2fdc4d9bc --- /dev/null +++ b/example/extensions/lib_reduce_gradient/Makefile @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +all: pass_lib + +pass_lib: + g++ -shared -fPIC -std=c++11 add_reduce_op.cc ../../../src/lib_api.cc -o add_reduce_op_lib.so -I ../../../include + +clean: + rm -rf libpass_lib.so diff --git a/example/extensions/lib_reduce_gradient/README.md b/example/extensions/lib_reduce_gradient/README.md new file mode 100644 index 000000000000..bcc3617cb8e2 --- /dev/null +++ b/example/extensions/lib_reduce_gradient/README.md @@ -0,0 +1,76 @@ + + +Add Reduce operation to computation Graph +======================================= + +## Introduction +This is the part of work of transferring [DeepSpeed's work](https://arxiv.org/abs/1910.02054) into MXNet. +Since the difference between symbolic and imperative, we divide the whole proecss into two phases: + +phase 1: Add reduce operation into graph. 
The reduce operation does nothing
+in the forward pass, but in the backward pass it reduces the gradient to the GPU that owns the corresponding parameter (according to the POS trainer).
+
+phase 2: In the backward graph, delete the gradient outputs owned by other ranks so the memory planner can reuse their memory.
+
+## Getting started
+### Prepare NCCL and Horovod
+We use Horovod for communication, so please install Horovod first. We also use NCCL reduce, so please install NCCL as well.
+
+### Compile and load the graph pass
+Compile it the same way as [lib pass](../lib_pass/): run `make`, which builds the dynamic library
+**add_reduce_op_lib.so** from the `add_reduce_op.cc` file. Then load the library in your Python code:
+```python
+import mxnet as mx
+mx.library.load('add_reduce_op_lib.so')
+```
+
+### Prepare options
+Next we need to know which GPU owns each parameter and its gradient.
+To get this partition, use **POS_Trainer** from `pos_trainer.py` in place of the normal MXNet trainer:
+```python
+from pos_trainer import POS_Trainer
+trainer = POS_Trainer(params_dict, "adam", optimizer_params)
+```
+The trainer can then generate the corresponding options:
+```python
+options = trainer.generate_graph_pass_options()
+backward_options = trainer.generate_backward_options()
+```
+### Modify the graph
+Before the forward pass, we call
+```python
+model.optimize_for(x, backend = "add_reduce_op", **options)
+```
+to insert the reduce operations into the graph.
+![example add reduce](addreduce.png)
+
+Then we pass the backward options when calling backward:
+```python
+loss.backward(backward_option = backward_options)
+```
+### Simple example
+Please see `test_reduce.py`.
+
+### Current problems
+1. The reduce operation can cause a deadlock (this does not happen with the NaiveEngine). Moreover, it runs into invalid-address
+problems in complex models such as BERT-Base.
+2. We do remove outputs from the backward graph using the backward options, but we still need to verify that this decreases memory
+consumption.
diff --git a/example/extensions/lib_reduce_gradient/add_reduce_op.cc b/example/extensions/lib_reduce_gradient/add_reduce_op.cc
new file mode 100644
index 000000000000..8faad796029a
--- /dev/null
+++ b/example/extensions/lib_reduce_gradient/add_reduce_op.cc
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file subgraph_lib.cc + * \brief subgraph operator implementation library file + */ + +#include +#include +#include +#include +#include "mxnet/lib_api.h" + +using namespace mxnet::ext; + + + +MXReturnValue add_reduce_op(mxnet::ext::Graph* g, + const std::unordered_map& options) { + std::string cur_rank = ""; + + std::string num_gpus = ""; + std::string nccl_unique_id = ""; + + for (auto kv : options) { + std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl; + if (kv.first == "rank") + { + cur_rank = kv.second.c_str(); + } + if (kv.first == "nccl_unique_id") + nccl_unique_id = kv.second.c_str(); + if (kv.first == "num_gpus") + num_gpus = kv.second.c_str(); + } + size_t length = g->size(); + mxnet::ext::Node *tmp; + std::string root_rank; + mxnet::ext::Node *target_node; + int index = 0; + for (int i = 0;i < length; i += 1) + { + target_node = g->getNode(i); + //std::cout<<"deal with:" << target_node->name<name); + if (it == options.end()) {continue;} // req_grad == null + root_rank = it->second; + mxnet::ext::Node *new_reduce = g->addNode("ncclreduce_" + target_node->name,"_contrib_NCCLReduce"); + index += 1; + auto new_attrs = &new_reduce->attrs; + auto old_attrs = target_node->attrs; + for (auto it = old_attrs.begin(); it!=old_attrs.end(); it++) + { + if (it->first == "__ext_dtype__" || it->first == "__ext_shape__" || it->first == "__profiler_scope__") + { + new_attrs ->insert({{it->first, it->second}}); + } + } + new_attrs->insert({{"nccl_unique_id", nccl_unique_id}}); + new_attrs->insert({{"num_gpus", num_gpus}}); + new_attrs->insert({{"rank", cur_rank}}); + new_attrs->insert({{"root_rank", root_rank}}); + + for (int i=0;ioutputs.size(); i++) + { + new_reduce->outputs.push_back(target_node->outputs[i]); + mxnet::ext::Node *output_node = target_node->outputs[i].node; + int index = target_node->outputs[i].entry; + //std::cout<<"try change:"<name<<":"<inputs.size()<inputs[index].node = new_reduce; + } + for (int i=0;ioutputs.size(); i++) + { + target_node->outputs.pop_back(); + } + target_node->outputs.push_back({new_reduce, 0}); + new_reduce->inputs.push_back({target_node, 0}); + + } + g->print(); + + + return MX_SUCCESS; +} + + + +REGISTER_PASS(add_reduce_op).setBody(add_reduce_op); + +MXReturnValue initialize(int version) { + if (version >= 10700) { + std::cout << "MXNet version " << version << " supported" << std::endl; + return MX_SUCCESS; + } else { + MX_ERROR_MSG << "MXNet version " << version << " not supported" << std::endl; + return MX_FAIL; + } +} diff --git a/example/extensions/lib_reduce_gradient/addreduce.png b/example/extensions/lib_reduce_gradient/addreduce.png new file mode 100644 index 000000000000..40338b972d45 Binary files /dev/null and b/example/extensions/lib_reduce_gradient/addreduce.png differ diff --git a/example/extensions/lib_reduce_gradient/pos_trainer.py b/example/extensions/lib_reduce_gradient/pos_trainer.py new file mode 100644 index 000000000000..eb22e2948973 --- /dev/null +++ b/example/extensions/lib_reduce_gradient/pos_trainer.py @@ -0,0 +1,196 @@ +# POS_Trainer is the stage one : partition optmizer status in DeepSpeed's work +# It can reduce memory consumption in distributed Trainer but slower +# since we can not solve overlapping problem when calling broadcast and optimize parameters. 
+# The usage of this trainer is exactly the same as that of the original Gluon Trainer.
+# Some benchmark results: on 4 V100 GPUs with 16 GB of memory each, the maximum batch sizes
+# for BERT-Large and BERT-Base are:
+# BERT-Large: original: 16, POS: 24
+# BERT-Base:  original: 64, POS: 80
+# The ideal average memory saving per GPU is (N-1)/N * P * K,
+# where N is the number of GPUs, P is the number of parameters and K is the memory
+# multiplier of the optimizer states (e.g. for Adam, K = 12).
+# TODO: add group_num
+from horovod.mxnet.mpi_ops import allreduce, allreduce_
+from horovod.mxnet.mpi_ops import broadcast, broadcast_
+from horovod.mxnet.mpi_ops import init, shutdown
+from horovod.mxnet.mpi_ops import size, local_size, rank, local_rank
+from mxnet.base import _LIB, check_call, mx_uint, c_str, c_str_array, SymbolHandle
+
+import mxnet as mx
+from collections import OrderedDict, defaultdict
+import types
+import time
+import warnings
+from mxnet.gluon.parameter import Parameter
+from horovod.mxnet.mpi_ops import ProcessSet, global_process_set, add_process_set, remove_process_set
+
+class _NCCLReduceHelper(object):
+    _init = False
+    nccl_id = None
+    num_gpus = None
+    rank = None
+
+    @staticmethod
+    def init(num_gpus, root_rank):
+        """Communicate the NCCL unique id across the local ranks."""
+        cls = _NCCLReduceHelper
+        if not cls._init:
+            cls._init = True
+            import ctypes
+            try:
+                from mpi4py import MPI
+            except ImportError:
+                raise ImportError("Spatial parallel modules require the mpi4py package.")
+            import numpy as np
+            nccl_id_size = ctypes.c_int()
+            check_call(_LIB.MXNCCLGetUniqueIdSize(ctypes.byref(nccl_id_size)))
+            nccl_id_size = nccl_id_size.value
+            cls.nccl_id = np.zeros(nccl_id_size, np.byte)
+            check_call(_LIB.MXNCCLGetUniqueId(
+                cls.nccl_id.ctypes.data_as(ctypes.c_void_p)))
+            global_comm = MPI.COMM_WORLD
+            rank = global_comm.rank
+            color = rank // num_gpus
+            comm = global_comm.Split(color, rank)
+            comm.Bcast([cls.nccl_id, nccl_id_size, MPI.BYTE], root=0)
+            cls.num_gpus = num_gpus
+            cls.rank = rank % num_gpus
+            cls.root_rank = root_rank % num_gpus
+        assert num_gpus == cls.num_gpus
+
+
+class POS_Trainer(mx.gluon.Trainer):
+    def __init__(self, params, optimizer, optimizer_params=None,
+                 gradient_predivide_factor=1.0, prefix=None, partition_gradients=False):
+
+        self._world_size = size()
+        self._world_rank = rank()
+
+        self._partition_gradients = partition_gradients
+
+        self._all_params = []
+        self._all_param2idx = {}
+        self._all_params_with_names = params
+        param_list = []
+        if isinstance(params, (dict, OrderedDict)):
+            for key in sorted(list(params.keys())):
+                param_list.append(params[key])
+            params = param_list
+        if not isinstance(params, (list, tuple)):
+            raise ValueError(
+                "First argument must be a list or dict of Parameters, "
+                "got %s." % (type(params)))
+        for i, param in enumerate(params):
+            if not isinstance(param, Parameter):
+                raise ValueError(
+                    "First argument must be a list or dict of Parameters, "
+                    "got list of %s." % (type(param)))
+            if param._uuid in self._all_param2idx:
+                # Shared parameters have the same uuid; we only need to store one of the shared versions
+                continue
+            self._all_param2idx[param._uuid] = i
+            self._all_params.append(param)
+        self._partition_params, self._param2rank = self._partition_parameters(self._all_params)
+        self._own_part = self._partition_params[self._world_rank]
+        super(POS_Trainer, self).__init__(
+            self._own_part, optimizer, optimizer_params=optimizer_params, kvstore=None)
+        self._prefix = prefix if prefix else ""
+        self._scale = gradient_predivide_factor / size()
+        self._gradient_predivide_factor = gradient_predivide_factor
+
+    def _partition_parameters(self, params):
+        """
+        Partition the parameters by their size so that every rank owns roughly the same
+        number of elements.
+        """
+        world_size = self._world_size
+        # one list of owned parameters per rank
+        partition_params = [[] for _ in range(world_size)]
+        param2rank = {}
+        sizes = [0 for _ in range(world_size)]
+        for param in params:
+            if param.grad_req != 'null':
+                current_rank = sizes.index(min(sizes))
+                partition_params[current_rank].append(param)
+                num = 1
+                param2rank[param._uuid] = current_rank
+                for p in param.shape:
+                    num *= p
+                sizes[current_rank] += num
+        return partition_params, param2rank
+
+    def _allreduce_grads(self):
+        """
+        Reimplement gradient aggregation here because we need to communicate through Horovod.
+        Ideally this would use reduce, but since reduce is not available yet we use
+        allreduce instead.
+        """
+        if not self._partition_gradients:
+            for i, param in enumerate(self._all_params):
+                if param.grad_req != 'null':
+                    allreduce_(param.list_grad()[0], average=False,
+                               name=self._prefix + str(i), priority=-i,
+                               prescale_factor=1.0 / self._gradient_predivide_factor)
+
+    def step(self, batch_size, ignore_stale_grad=False):
+        """
+        Makes one step of parameter update.
+        Inherited from Trainer; since each process maintains only its own partition,
+        we broadcast the updated parameters afterwards to keep all ranks consistent.
+        """
+        super(POS_Trainer, self).step(batch_size, ignore_stale_grad)
+        self._broadcast_partition_params()
+
+        if not self._kv_initialized:
+            self._init_kvstore()
+        if self._params_to_init:
+            self._init_params()
+
+    def update(self, batch_size, ignore_stale_grad=False):
+        '''
+        Makes one step of parameter update.
+        Note: this is not supported when parameters are updated on the kvstore;
+        set `update_on_kvstore` to False when creating the trainer.
+        Since each process maintains only its own partition, we broadcast the updated
+        parameters after the update.
+        '''
+        super(POS_Trainer, self).update(batch_size, ignore_stale_grad)
+        self._broadcast_partition_params()
+
+    def _broadcast_partition_params(self):
+        """
+        Broadcast the parameters, since each process only maintains and updates its own partition.
+        """
+        for param in self._all_params:
+            broadcast_(param.data(), self._param2rank[param._uuid], name=str(self._all_param2idx[param._uuid]))
+
+    def generate_graph_pass_options(self):
+        # Generate options for the graph pass; the key is the parameter name and the value is its owning rank.
+        options = {}
+        for name in self._all_params_with_names:
+            param_type = name.split('.')[-1]
+            index = self._param2rank[self._all_params_with_names[name]._uuid]
+            new_name = self._all_params_with_names[name]._uuid.replace('-', '_') + '_' + param_type
+            options[new_name] = index
+
+        helper = _NCCLReduceHelper
+        helper.init(size(), 0)
+        options.update({"num_gpus": size(), "rank": rank(), "nccl_unique_id": helper.nccl_id.ctypes.data})
+        return options
+
+    def generate_backward_options(self):
+        # Generate backward options for output pruning; the key is the backward node name and the value is its owning rank.
+        backward_option = {"partition_grad": True, "current_rank": rank()}
+        for name in self._all_params_with_names:
+            param_type = name.split('.')[-1]
+            index = self._param2rank[self._all_params_with_names[name]._uuid]
+            new_name = 'ncclreduce_' + self._all_params_with_names[name]._uuid.replace('-', '_') + '_' + param_type + "_backward"
+            backward_option[new_name] = index
+        return backward_option
diff --git a/example/extensions/lib_reduce_gradient/test_reduce.py b/example/extensions/lib_reduce_gradient/test_reduce.py
new file mode 100644
index 000000000000..b8e40580f0f0
--- /dev/null
+++ b/example/extensions/lib_reduce_gradient/test_reduce.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+ +# coding: utf-8 +# pylint: disable=arguments-differ + +# This test checks if dynamic loading of library into MXNet is successful +# and checks the end of end computation of custom operator + +import os, ctypes +import mxnet as mx +import time +from mxnet.gluon import nn +from mxnet import nd +import numpy as np +from mxnet.lr_scheduler import PolyScheduler +from mxnet import np, npx +from pos_trainer import POS_Trainer +try: + import horovod.mxnet as hvd +except ImportError: + pass +from mxnet.base import _LIB, check_call, mx_uint, c_str, c_str_array, SymbolHandle + +# load library +if (os.name=='posix'): + path = os.path.abspath('add_reduce_op_lib.so') + mx.library.load(path) +elif (os.name=='nt'): + path = os.path.abspath('add_reduce_op_lib.dll') + mx.library.load(path) + + +class Easynet(nn.HybridBlock): + def __init__(self, n): + super().__init__() + self.ls = nn.HybridSequential() + for i in range(n): + self.ls.add(nn.Dense(in_units=2, units=2, flatten=False)) + + + + def forward(self, input): + input = self.ls(input) + return input + + +def test_model(): + from mxnet import gluon + from mxnet.gluon import Block, nn, HybridBlock + from mxnet import init + + + + hvd.init() + rank = hvd.rank() + size = hvd.size() + ctx = mx.gpu(rank) + + np.random.seed(1234 + 10 * rank) + mx.random.seed(1234 + 10 * rank) + + + number = 2 + model = Easynet(number) + + if rank == 0: + for i in range(number): + model.ls[i].weight.initialize(init=init.One(), ctx=ctx) + model.ls[i].bias.initialize(init=init.One(), ctx=ctx) + else: + for i in range(number): + model.ls[i].weight.initialize(init=init.Zero(), ctx=ctx) + model.ls[i].bias.initialize(init=init.Zero(), ctx=ctx) + + model.hybridize() + + params = model.collect_params() + lr_scheduler = PolyScheduler(max_update=1, + base_lr=1e-3, + warmup_begin_lr=0.0, + pwr=1, + final_lr=0.0, + warmup_steps=0, + warmup_mode='linear') + optimizer_params = {'learning_rate': 1e-3, + 'wd': 1e-2, + 'lr_scheduler': lr_scheduler} + trainer = POS_Trainer(params, "adam", optimizer_params) + + options = trainer.generate_graph_pass_options() + backward_options = trainer.generate_backward_options() + x = np.ones((1,2), ctx = ctx) + label = np.ones((2, ), ctx = ctx) * rank + #print(options) + + loss_function = gluon.loss.L2Loss() + + model.optimize_for(x, backend = "add_reduce_op", **options) + for i in range(1): + with mx.autograd.record(): + out = model(x) + loss = loss_function(out, label).mean() / size + loss.backward(backward_option = backward_options) + mx.npx.waitall() + mx.nd.waitall() + + for name in params: + print(name, params[name].list_grad()[0]) + print('finish') + + +test_model() diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 4210a9fa63d4..7b38d83c9b31 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -515,6 +515,15 @@ MXNET_DLL int MXSetNumOMPThreads(int thread_num); * \param bulk_size new bulk_size * \param prev_bulk_size previous bulk_size */ + + /*! + * \brief Get the compute capability of a given GPU. + * \param dev the GPU number to query + * \param out pointer to integer that will hold the compute capability of the queried GPU. + * \return 0 when success, -1 when failure happens. + */ +MXNET_DLL int MXGetGPUSMArch(int dev, int* out); + MXNET_DLL int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size); /*! @@ -524,6 +533,14 @@ MXNET_DLL int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size); */ MXNET_DLL int MXGetGPUCount(int* out); +/*! + * \brief Get the compute capability of a given GPU. 
+ * \param dev the GPU number to query + * \param out pointer to integer that will hold the compute capability of the queried GPU. + * \return 0 when success, -1 when failure happens. + */ +MXNET_DLL int MXGetGPUSMArch(int dev, int* out); + /*! * \brief get the free and total available memory on a GPU * Note: Deprecated, use MXGetGPUMemoryInformation64 instead. @@ -534,6 +551,22 @@ MXNET_DLL int MXGetGPUCount(int* out); */ MXNET_DLL int MXGetGPUMemoryInformation(int dev, int* free_mem, int* total_mem); +/*! + * \brief Get the size of the NCCL unique id (in bytes). + * \param size pointer to integer that will hold the NCCL unique id size. + * \return 0 when success, -1 when failure happens. + */ +MXNET_DLL int MXNCCLGetUniqueIdSize(int* size); + +/*! + * \brief Get the NCCL unique id. + * \param out pointer to an array that will hold the NCCL unique id. It has to be at least of the + * size returned by MXNCCLGetUniqueIdSize. + * \return 0 when success, -1 when failure happens. + */ +MXNET_DLL int MXNCCLGetUniqueId(void* out); + + /*! * \brief get the free and total available memory on a GPU * \param dev the GPU number to query @@ -1309,7 +1342,10 @@ MXNET_DLL int MXAutogradBackwardEx(uint32_t num_output, int create_graph, int is_train, NDArrayHandle** grad_handles, - int** grad_stypes); + int** grad_stypes, + const mx_uint num_options, + const char** keys, + const char** vals); /* * \brief get the graph constructed by autograd. * \param handle ndarray handle diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h index e4e3f6a938d0..f49637dbf263 100644 --- a/include/mxnet/imperative.h +++ b/include/mxnet/imperative.h @@ -279,7 +279,8 @@ class Imperative { const std::vector& variables, bool is_train, bool retain_graph, - bool create_graph); + bool create_graph, + const std::unordered_map backward_options_map = {}); /*! \brief Return the marked nonleaf nodes. */ std::vector ListNonleafVariables(const nnvm::Symbol& sym) const; /*! \return AutogradRuntime singleton */ diff --git a/include/mxnet/resource.h b/include/mxnet/resource.h index b856002cb76f..038d7c20fcb2 100644 --- a/include/mxnet/resource.h +++ b/include/mxnet/resource.h @@ -39,16 +39,22 @@ struct ResourceRequest { /*! \brief Resource type, indicating what the pointer type is */ enum Type { /*! \brief mshadow::Random object */ - kRandom, + kRandom = 0, /*! \brief A dynamic temp space that can be arbitrary size */ - kTempSpace, + kTempSpace = 1, /*! \brief common::RandGenerator object, which can be used in GPU kernel functions */ - kParallelRandom + kParallelRandom = 2 #if MXNET_USE_CUDNN == 1 , /*! \brief cudnnDropoutDescriptor_t object for GPU dropout kernel functions */ - kCuDNNDropoutDesc + kCuDNNDropoutDesc = 3 #endif // MXNET_USE_CUDNN == 1 +#if MXNET_USE_CUDA + , + /*! \brief Resource indicating the usage of multi GPU communication, used to prevent + * multiple ops of doing it at the same time */ + kMultiGPUComm = 4 +#endif // MXNET_USE_CUDA == 1 }; /*! \brief type of resources */ Type type; diff --git a/python/mxnet/gluon/contrib/__init__.py b/python/mxnet/gluon/contrib/__init__.py index 9b9f0ce8d696..8d421a137336 100644 --- a/python/mxnet/gluon/contrib/__init__.py +++ b/python/mxnet/gluon/contrib/__init__.py @@ -21,3 +21,5 @@ from . import data from . import estimator + +from . 
import nn \ No newline at end of file diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index fd76789918a1..e56ff4e34c9e 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -34,7 +34,7 @@ from functools import reduce # pylint: disable=redefined-builtin import numpy as np from ..base import _LIB, numeric_types, integer_types -from ..base import c_array, c_array_buf, c_handle_array, mx_real_t +from ..base import c_array, c_array_buf, c_handle_array, mx_real_t, c_str_array from ..base import mx_uint, NDArrayHandle, check_call, mx_int, mx_int64 from ..base import ctypes2buffer from ..dlpack import ndarray_to_dlpack_for_read, ndarray_to_dlpack_for_write @@ -2924,7 +2924,7 @@ def detach(self): check_call(_LIB.MXNDArrayDetach(self.handle, ctypes.byref(hdl))) return _ndarray_cls(hdl) - def backward(self, out_grad=None, retain_graph=False, train_mode=True): + def backward(self, out_grad=None, retain_graph=False, train_mode=True, backward_option = {}): """Compute the gradients of this NDArray w.r.t variables. Parameters @@ -2942,7 +2942,11 @@ def backward(self, out_grad=None, retain_graph=False, train_mode=True): ograd_handles = [NDArrayHandle(0)] else: ograd_handles = [out_grad.handle] - + key_list = [] + val_list = [] + for key, val in backward_option.items(): + key_list.append(key) + val_list.append(str(val)) check_call(_LIB.MXAutogradBackwardEx( 1, c_handle_array([self]), c_array(NDArrayHandle, ograd_handles), @@ -2952,7 +2956,10 @@ def backward(self, out_grad=None, retain_graph=False, train_mode=True): ctypes.c_int(0), ctypes.c_int(train_mode), ctypes.c_void_p(0), - ctypes.c_void_p(0))) + ctypes.c_void_p(0), + mx_uint(len(key_list)), + c_str_array(key_list), + c_str_array(val_list))) def tostype(self, stype): """Return a copy of the array with chosen storage type. diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 2fb883f00997..ea8cec633171 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -62,6 +62,13 @@ #include "miniz.h" #include "nnvm/pass_functions.h" +#if MXNET_USE_CUDA +#include "../common/cuda/utils.h" +#endif +#if MXNET_USE_NCCL +#include +#endif + // FTZ only applies to SSE and AVX instructions. #if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \ (defined(_M_IX86_FP) && _M_IX86_FP >= 1) @@ -1936,6 +1943,16 @@ int MXGetGPUCount(int* out) { API_END(); } +int MXGetGPUSMArch(int dev, int* out) { + API_BEGIN(); +#if MXNET_USE_CUDA == 1 + *out = SMArch(dev); +#else + LOG(FATAL) << "Compile with USE_CUDA=1 to query CUDA device properties."; +#endif + API_END(); +} + // Deprecated: use MXGetGPUMemoryInformation64() instead. 
int MXGetGPUMemoryInformation(int dev, int* free_mem, int* total_mem) { API_BEGIN(); @@ -1953,6 +1970,30 @@ int MXGetGPUMemoryInformation64(int dev, uint64_t* free_mem, uint64_t* total_mem API_END(); } +int MXNCCLGetUniqueIdSize(int* size) { + API_BEGIN(); +#if MXNET_USE_CUDA && MXNET_USE_NCCL + *size = sizeof(ncclUniqueId); +#else + LOG(FATAL) << "Compile with USE_CUDA=1 and USE_NCCL=1 to have NCCL support."; +#endif + API_END(); +} + +int MXNCCLGetUniqueId(void* out) { + API_BEGIN(); +#if MXNET_USE_CUDA && MXNET_USE_NCCL + auto ret = ncclGetUniqueId(reinterpret_cast(out)); + if (ret != ncclSuccess) { + LOG(FATAL) << "Failed to get the NCCL unique id"; + } +#else + LOG(FATAL) << "Compile with USE_CUDA=1 and USE_NCCL=1 to have NCCL support."; +#endif + API_END(); +} + + int MXGetVersion(int* out) { API_BEGIN(); *out = static_cast(MXNET_VERSION); diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 2e9c0a373621..0b30373f68b2 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -362,6 +362,9 @@ int MXAutogradBackward(uint32_t num_output, false, true, nullptr, + nullptr, + 0, + nullptr, nullptr); } @@ -374,10 +377,15 @@ int MXAutogradBackwardEx(uint32_t num_output, int create_graph, int is_train, NDArrayHandle** grad_handles, - int** grad_stypes) { + int** grad_stypes, + const mx_uint num_options, + const char** keys, + const char** vals) { MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); - + std::unordered_map backward_options_map; + for (mx_uint i = 0; i < num_options; ++i) + backward_options_map.emplace(keys[i], vals[i]); std::vector outputs, ograds, variables; outputs.reserve(num_output); for (uint32_t i = 0; i < num_output; ++i) { @@ -397,9 +405,8 @@ int MXAutogradBackwardEx(uint32_t num_output, for (uint32_t i = 0; i < num_variables; ++i) { variables.emplace_back(reinterpret_cast(var_handles[i])); } - auto grads = - Imperative::Get()->Backward(outputs, ograds, variables, is_train, retain_graph, create_graph); + Imperative::Get()->Backward(outputs, ograds, variables, is_train, retain_graph, create_graph, backward_options_map); if (num_variables != 0) { ret->ret_handles.clear(); ret->out_types.clear(); diff --git a/src/imperative/attach_op_resource_pass.cc b/src/imperative/attach_op_resource_pass.cc index 17d6d7a41dc3..5b8e3ebf4649 100644 --- a/src/imperative/attach_op_resource_pass.cc +++ b/src/imperative/attach_op_resource_pass.cc @@ -83,6 +83,12 @@ void AttachOpResources(const Graph& g, break; } #endif // MXNET_USE_CUDNN == 1 +#if MXNET_USE_CUDA == 1 + case ResourceRequest::kMultiGPUComm: { + requested.push_back(ResourceManager::Get()->Request(ctx, req)); + break; + } +#endif // MXNET_USE_NCCL == 1 default: LOG(FATAL) << "resource type " << req.type << " is not yet supported"; } diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 894ef09a1d16..c29a7ab96c80 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -259,7 +259,8 @@ void SetBackwardInputEid(const std::vector& bwd_in_dep, bool CachedOp::SetBackwardGraph(GraphInfo* info, const std::vector& reqs, const std::vector& inputs, - bool detect_inplace_addto) { + bool detect_inplace_addto, + const std::unordered_map backward_options_map) { using namespace nnvm; using namespace imperative; std::lock_guard lock(mutex_); @@ -279,7 +280,26 @@ bool CachedOp::SetBackwardGraph(GraphInfo* info, g.attrs["context"] = std::make_shared( std::vector(g.indexed_graph().num_nodes(), default_ctx)); } + //deal with delete + auto 
pg_option = backward_options_map.find("partition_grad"); + if (pg_option != backward_options_map.end()) + { + auto it = backward_options_map.find("current_rank"); + int cur_rank = std::atoi(it->second.c_str()); + auto output_it = g.outputs.begin(); + while (output_it!=g.outputs.end()) + { + it = backward_options_map.find(output_it->node ->attrs.name); + if (it!=backward_options_map.end() && atoi(it->second.c_str()) != cur_rank) + { + g.outputs.erase(output_it); + } + else{ + output_it++; + } + } + } const auto& idx = g.indexed_graph(); if (info->bwd_input_eid.size() != inputs.size()) { @@ -902,7 +922,8 @@ void CachedOp::DynamicBackward(const bool retain_graph, const OpStatePtr& op_state, const std::vector& inputs, const std::vector& reqs, - const std::vector& outputs) { + const std::vector& outputs, + const std::unordered_map backward_options_map) { using namespace nnvm; using namespace imperative; @@ -915,7 +936,7 @@ void CachedOp::DynamicBackward(const bool retain_graph, std::lock_guard lock(state.mutex); state.info.fwd_graph = runtime.info.fwd_graph; state.info.input_map = runtime.info.input_map; - SetBackwardGraph(&state.info, reqs, inputs); + SetBackwardGraph(&state.info, reqs, inputs, false, backward_options_map); runtime.info.full_graph = state.info.full_graph; runtime.info.bwd_input_eid = state.info.bwd_input_eid; } @@ -1005,7 +1026,8 @@ void CachedOp::StaticBackward(const bool retain_graph, const OpStatePtr& state_ptr, const std::vector& inputs, const std::vector& reqs, - const std::vector& outputs) { + const std::vector& outputs, + const std::unordered_map backward_options_map) { using namespace nnvm; using namespace imperative; @@ -1014,9 +1036,10 @@ void CachedOp::StaticBackward(const bool retain_graph, auto& state = state_ptr.get_state(); std::lock_guard lock(state.mutex); - bool match = SetBackwardGraph(&state.info, reqs, inputs, true); + bool match = SetBackwardGraph(&state.info, reqs, inputs, true, backward_options_map); nnvm::Graph& g = state.info.full_graph; + const auto& idx = g.indexed_graph(); auto num_forward_nodes = state.info.fwd_graph.indexed_graph().num_nodes(); @@ -1090,7 +1113,8 @@ void CachedOp::Backward(const bool retain_graph, const OpStatePtr& state, const std::vector& inputs, const std::vector& reqs, - const std::vector& outputs) { + const std::vector& outputs, + const std::unordered_map backward_options_map) { const auto& fwd_idx = fwd_graph_.indexed_graph(); const auto& full_idx = full_graph_.indexed_graph(); const auto& mutable_input_nodes = fwd_idx.mutable_input_nodes(); @@ -1121,9 +1145,9 @@ void CachedOp::Backward(const bool retain_graph, try { if (config_.static_alloc) { - StaticBackward(retain_graph, state, inputs, reqs, outputs); + StaticBackward(retain_graph, state, inputs, reqs, outputs, backward_options_map); } else { - DynamicBackward(retain_graph, state, inputs, reqs, outputs); + DynamicBackward(retain_graph, state, inputs, reqs, outputs, backward_options_map); } } catch (const dmlc::Error& e) { Engine::Get()->set_bulk_size(prev_bulk_size); @@ -1261,7 +1285,7 @@ void CachedOpBackward(const OpStatePtr& state_ptr, // pass a flag to determine whether to record computation inside an operator. // Let's use false here for now and design a solution when the second-order // differentiation is supported. - s.op->Backward(false, s.forward_state, in_ptrs, req, out_ptrs); + s.op->Backward(false, s.forward_state, in_ptrs, req, out_ptrs, {}); Imperative::Get()->set_is_training(orig_is_train); // Clean up what we recorded. 
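Note: the `partition_grad` pruning in `SetBackwardGraph` above is driven by the string map built by `POS_Trainer.generate_backward_options()` earlier in this diff. The sketch below is a minimal illustration of the dictionary it consumes; the node-name keys are hypothetical (the real keys are derived from parameter UUIDs), and the final call assumes a `loss` produced by a model already rewritten with the `add_reduce_op` pass.

```python
# Illustrative only: the real dictionary comes from POS_Trainer.generate_backward_options().
backward_options = {
    "partition_grad": True,   # enables the output-pruning branch in SetBackwardGraph
    "current_rank": 1,        # rank of the calling process
    # one entry per inserted reduce node: backward node name -> rank that owns that gradient
    "ncclreduce_param_0000_weight_backward": 0,
    "ncclreduce_param_0000_bias_backward": 1,
}
# The options travel through NDArray.backward() -> MXAutogradBackwardEx ->
# Imperative::Backward -> CachedOp::Backward; backward-graph outputs whose rank
# differs from current_rank are dropped so the memory planner can reuse their storage.
# loss.backward(backward_option=backward_options)   # assuming `loss` comes from the rewritten model
```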
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 079a56e20a12..e8d441114a0c 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -501,7 +501,8 @@ class CachedOp { const OpStatePtr& state, const std::vector& inputs, const std::vector& reqs, - const std::vector& outputs); + const std::vector& outputs, + const std::unordered_map backward_options_map); // backward storage type inference virtual bool BackwardStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, @@ -606,7 +607,8 @@ class CachedOp { bool SetBackwardGraph(GraphInfo* info, const std::vector& reqs, const std::vector& inputs, - bool detect_inplace_addto = false); + bool detect_inplace_addto = false, + const std::unordered_map backward_options_map = {}); bool CheckDynamicShapeExists(const Context& default_ctx, const std::vector& inputs, bool erase_result); @@ -632,12 +634,14 @@ class CachedOp { const OpStatePtr& op_state, const std::vector& inputs, const std::vector& reqs, - const std::vector& outputs); + const std::vector& outputs, + const std::unordered_map backward_options_map); void StaticBackward(const bool retain_graph, const OpStatePtr& state_ptr, const std::vector& inputs, const std::vector& reqs, - const std::vector& outputs); + const std::vector& outputs, + const std::unordered_map backward_options_map); size_t BwdOriginalInput(const std::vector& input_map, size_t new_i); CachedOpConfig config_; diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index b9bdaac9476f..173bb33f744a 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -440,7 +440,8 @@ std::vector Imperative::Backward(const std::vector& outputs, const std::vector& variables, bool is_train, bool retain_graph, - bool create_graph) { + bool create_graph, + const std::unordered_map backward_options_map) { using namespace nnvm; using namespace imperative; static const std::vector zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")}; @@ -519,7 +520,6 @@ std::vector Imperative::Backward(const std::vector& outputs, for (const auto& i : nleaf_vars) { us.emplace_back(NodeEntry{i, 0, 0}); } - Graph g_graph = pass::MXGradient(graph, graph.outputs, xs, @@ -723,7 +723,11 @@ std::vector Imperative::Backward(const std::vector& outputs, std::move(ref_count), &states, dispatch_modes, - is_recording()); + is_recording(), + nullptr, + nullptr, + false, + backward_options_map); } catch (const dmlc::Error& e) { Engine::Get()->set_bulk_size(prev_bulk_size); set_is_recording(prev_recording); diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc index e3a58804d8ac..c6e9b9124234 100644 --- a/src/imperative/imperative_utils.cc +++ b/src/imperative/imperative_utils.cc @@ -75,7 +75,8 @@ void InvokeOperator(const nnvm::IndexedGraph& idx, const std::vector& ndoutputs, std::vector* p_req, std::vector* p_ref_count, - std::function invoke) { + std::function invoke, + const std::unordered_map backward_options_map) { static const auto bwd_cached_op = Op::Get("_backward_CachedOp"); static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); static auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); @@ -88,7 +89,7 @@ void InvokeOperator(const nnvm::IndexedGraph& idx, const auto& cached_op = dmlc::get(node.source->attrs.parsed); nnvm::Node* fwd_node = node.source->control_deps[0].get(); auto fwd_node_id = idx.node_id(fwd_node); - cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs); + cached_op->Backward(retain_graph, 
states[fwd_node_id], ndinputs, req, ndoutputs, backward_options_map); } else if (createop.count(node.source->op())) { mxnet::ShapeVector arg_shapes; nnvm::DTypeVector arg_dtypes; @@ -138,7 +139,8 @@ void RunGraph(const bool retain_graph, bool recording, mxnet::ShapeVector* shapes, const imperative::CachedOpMonCallback& callback, - const bool monitor_all) { + const bool monitor_all, + const std::unordered_map backward_options_map) { CHECK(shapes == nullptr); for (size_t i = node_start; i < node_end; ++i) { const nnvm::IndexedGraph::Node& node = idx[i]; @@ -161,8 +163,9 @@ void RunGraph(const bool retain_graph, Imperative::Get()->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, state); } }; + InvokeOperator( - idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs, &req, &ref_count, invoke); + idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs, &req, &ref_count, invoke, backward_options_map); if (callback) { mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback); } @@ -224,7 +227,7 @@ void NaiveRunGraph(const bool retain_graph, } }; InvokeOperator( - idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs, &req, &ref_count, invoke); + idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs, &req, &ref_count, invoke, {}); if (callback) { mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback); } diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index ce1a60fb2b20..7f8e7d7f687b 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -1369,7 +1369,8 @@ void RunGraph(const bool retain_graph, bool recording, mxnet::ShapeVector* shapes = nullptr, const CachedOpMonCallback& callback = nullptr, - const bool monitor_all_ = false); + const bool monitor_all_ = false, + const std::unordered_map backward_options_map = {}); void NaiveRunGraph(const bool retain_graph, const Context& default_ctx, diff --git a/src/imperative/naive_cached_op.h b/src/imperative/naive_cached_op.h index cd1365508f97..037d4d489448 100644 --- a/src/imperative/naive_cached_op.h +++ b/src/imperative/naive_cached_op.h @@ -48,7 +48,8 @@ class NaiveCachedOp : public CachedOp { const OpStatePtr& state, const std::vector& inputs, const std::vector& reqs, - const std::vector& outputs) override { + const std::vector& outputs, + const std::unordered_map backward_options_map) override { LOG(FATAL) << "Backward is not supported in NaiveCachedOp."; } // backward storage type inference diff --git a/src/operator/contrib/nn/reduce_op-inl.h b/src/operator/contrib/nn/reduce_op-inl.h new file mode 100644 index 000000000000..8b08d7d6994f --- /dev/null +++ b/src/operator/contrib/nn/reduce_op-inl.h @@ -0,0 +1,97 @@ +#ifndef MXNET_OPERATOR_CONTRIB_NCCLREDUCE_H_ +#define MXNET_OPERATOR_CONTRIB_NCCLREDUCE_H_ + +#include +#include "../../mshadow_op.h" +#include "../../mxnet_op.h" +#include "../../operator_common.h" +#include "../../elemwise_op_common.h" +#include "../../tensor/init_op.h" + +#if MXNET_USE_NCCL +#include +#include +#include + +namespace mxnet { +namespace op { + +struct NCCLReduceParam : public dmlc::Parameter { + int32_t num_gpus; + int32_t root_rank; + int32_t rank; + uintptr_t nccl_unique_id; + + DMLC_DECLARE_PARAMETER(NCCLReduceParam) { + DMLC_DECLARE_FIELD(num_gpus) + .set_default(1) + .describe("Number of all gpus."); + DMLC_DECLARE_FIELD(root_rank) + .set_default(0) + .describe("root rank of reduce operation"); + DMLC_DECLARE_FIELD(rank) + .set_default(0) + .describe("rank of current process"); 
+ DMLC_DECLARE_FIELD(nccl_unique_id) + .describe("NCCL unique ID"); + } +}; + +template +struct ncclreduce_compute { + template + MSHADOW_XINLINE static void Map(int i, + DType* in_grad, + const DType* out_grad) { + KERNEL_ASSIGN(in_grad[i], req, out_grad[i] * 1); + } +}; + +template +void NCCLReduceCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + mshadow::Stream* s = ctx.get_stream(); + const TBlob& out_grad = inputs[0]; + const TBlob& in_grad = outputs[0]; + const NCCLReduceParam& param = nnvm::get(attrs.parsed); + using namespace mxnet_op; + MSHADOW_TYPE_SWITCH(out_grad.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, xpu>::Launch(s, + in_grad.Size(), + in_grad.dptr(), + out_grad.dptr()); + }); + }); +} + + + + + +class NCCLCommContainer { + public: + struct Param { + int num_gpus; + int rank; + uintptr_t nccl_unique_id; + }; + static inline std::unordered_map> comm_map; + + static void Init(const Param& param); +}; + +} // namespace op +} // namespace mxnet + +#else +static_assert(false, "You need to compile with NCCL support to use reduce operation!"); +#endif // MXNET_USE_NCCL + +#endif // MXNET_OPERATOR_CONTRIB_SPATIAL_PARALLEL_SUPPORT_H_ \ No newline at end of file diff --git a/src/operator/contrib/nn/reduce_op.cc b/src/operator/contrib/nn/reduce_op.cc new file mode 100644 index 000000000000..e00245b06ed1 --- /dev/null +++ b/src/operator/contrib/nn/reduce_op.cc @@ -0,0 +1,89 @@ +#include "reduce_op-inl.h" +#include +#include +#include +#include +#include +#include +#include "../../operator_common.h" +#include "../../elemwise_op_common.h" + + +namespace mxnet { +namespace op { + +void NCCLCommContainer::Init(const NCCLCommContainer::Param& param) { + std::lock_guard l(Storage::Get()->GetMutex(Context::kGPU)); + if (NCCLCommContainer::comm_map.count(param.num_gpus) == 0) { + auto [it, inserted] = NCCLCommContainer::comm_map.emplace(param.num_gpus, // NOLINT(*) + std::make_unique()); + CHECK(inserted) << "Could not insert new NCCL communicator!"; + ncclComm_t* comm = it->second.get(); + ncclUniqueId id = *(reinterpret_cast( + reinterpret_cast(param.nccl_unique_id))); + auto result = ncclCommInitRank(comm, param.num_gpus, id, param.rank); + CHECK_EQ(result, ncclSuccess) << "ncclCommInitRank failed!"; + } +} + +bool NCCLReduceShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + return out_attrs->at(0).ndim() != 0U && out_attrs->at(0).Size() != 0U; +} + +inline bool NCCLReduceType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + return out_attrs->at(0) != -1; +} + +DMLC_REGISTER_PARAMETER(NCCLReduceParam); + +NNVM_REGISTER_OP(_contrib_NCCLReduce) +.describe(R"code(Reduce operation +)code" ADD_FILELINE) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FCompute", NCCLReduceCompute) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) +.set_attr("FInferShape", 
NCCLReduceShape) +.set_attr("FInferType", NCCLReduceType) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_NCCLReduce"}) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) +.set_attr("FInplaceIdentity", + [](const NodeAttrs& attrs){ + const NCCLReduceParam& param = nnvm::get(attrs.parsed); + if (param.num_gpus == 1) { + return std::vector{true}; + } else { + return std::vector{false}; + } + }) +.add_argument("data", "NDArray-or-Symbol", "Input ndarray") +.add_arguments(NCCLReduceParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_NCCLReduce) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("TIsBackward", true); +} +} \ No newline at end of file diff --git a/src/operator/contrib/nn/reduce_op.cu b/src/operator/contrib/nn/reduce_op.cu new file mode 100644 index 000000000000..655023c6c964 --- /dev/null +++ b/src/operator/contrib/nn/reduce_op.cu @@ -0,0 +1,59 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file spatial_parallel_support.cu + * \brief Support operators for spatial parallelism + * \author Przemyslaw Tredak +*/ + +#include "reduce_op-inl.h" +#include +#include +#include +#include +#include "../../operator_common.h" +#include "../../../common/utils.h" +#include "../../tensor/elemwise_binary_op.h" + +namespace mxnet { +namespace op { + + +void NCCLReduceBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const NCCLReduceParam& param = nnvm::get(attrs.parsed); + if (req[0] == OpReqType::kNullOp) return; + if (param.num_gpus == 1 && req[0] == OpReqType::kWriteInplace) return; + NCCLCommContainer::Param p = {param.num_gpus, + param.rank, + param.nccl_unique_id}; + NCCLCommContainer::Init(p); + + std::lock_guard l(Storage::Get()->GetMutex(Context::kGPU)); + ncclComm_t comm = *(NCCLCommContainer::comm_map.at(param.num_gpus)); + const index_t size = inputs[0].shape_.Size() * + common::mshadow_type_info(inputs[0].type_flag_).size; + if (req[0] != OpReqType::kAddTo) { + ncclResult_t result = ncclReduce(inputs[0].dptr_, + outputs[0].dptr_, + size, ncclFloat32, ncclAvg, param.root_rank, + comm, + mshadow::Stream::GetStream(ctx.get_stream())); + + + CHECK_EQ(result, ncclSuccess) << "NCCL Reduce failed!"; + } else { + LOG(FATAL) << "kAddTo not supported yet!"; + } +} + +NNVM_REGISTER_OP(_contrib_NCCLReduce) +.set_attr("FCompute", NCCLReduceCompute); + +NNVM_REGISTER_OP(_backward_NCCLReduce) +.set_attr("FCompute", NCCLReduceBackward); + +} // namespace op +} // namespace mxnet \ No newline at end of file diff --git a/src/operator/subgraph_op_common.cc b/src/operator/subgraph_op_common.cc index ad12d3ded5d2..a5647bcad859 100644 --- a/src/operator/subgraph_op_common.cc +++ b/src/operator/subgraph_op_common.cc @@ -274,7 +274,7 @@ void LoopState::Backward(int iter_no, outputs.push_back(&igrad_bufs[i]); CHECK_EQ(outputs.size(), op->num_inputs()); auto state = all_states[iter_no]; - op->Backward(false, state, inputs, req, outputs); + op->Backward(false, state, inputs, req, outputs, {}); // If an input and an output share the array, the output array will be changed // by CachedOp. We need to copy data to the real output. for (size_t i = 0; i < igrads.size(); i++)
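As an aside, the parameter-to-rank assignment performed by `POS_Trainer._partition_parameters()` (added earlier in this diff) is a simple greedy, size-balanced placement: each parameter goes to the rank that currently holds the fewest elements. The standalone sketch below uses made-up parameter names and sizes purely for illustration.

```python
# Greedy size-balanced partitioning, mirroring POS_Trainer._partition_parameters().
param_sizes = {"dense0.weight": 4, "dense0.bias": 2, "dense1.weight": 4, "dense1.bias": 2}
world_size = 2

sizes = [0] * world_size   # number of elements currently assigned to each rank
assignment = {}
for name, num_elements in param_sizes.items():
    owner = sizes.index(min(sizes))   # least-loaded rank so far
    assignment[name] = owner
    sizes[owner] += num_elements

print(assignment)  # {'dense0.weight': 0, 'dense0.bias': 1, 'dense1.weight': 1, 'dense1.bias': 0}
```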