From 65b29ef39b24b016f0c125e9827c65f097565da9 Mon Sep 17 00:00:00 2001 From: fzou1 Date: Sat, 22 Jul 2017 23:57:25 +0800 Subject: [PATCH] fix hang issue if resuming training and compilation issue as macro is called in another function than ForwardBackwardImpl --- include/caffe/multinode/multi_solver.hpp | 2 +- include/caffe/multinode/multi_sync.hpp | 20 +++++----- src/caffe/multinode/multi_solver.cpp | 50 ++++++------------------ 3 files changed, 23 insertions(+), 49 deletions(-) diff --git a/include/caffe/multinode/multi_solver.hpp b/include/caffe/multinode/multi_solver.hpp index 1b5664d5f..5d2082821 100644 --- a/include/caffe/multinode/multi_solver.hpp +++ b/include/caffe/multinode/multi_solver.hpp @@ -64,7 +64,7 @@ class MultiSolver { Net& net = *root_solver_->net(); const std::vector>> & layers{ net.layers() }; layer_finished_flags_.resize(layers.size()); - std::fill(layer_finished_flags_.begin(), layer_finished_flags_.end(), false); + std::fill(layer_finished_flags_.begin(), layer_finished_flags_.end(), true); #endif } diff --git a/include/caffe/multinode/multi_sync.hpp b/include/caffe/multinode/multi_sync.hpp index d08f7f13c..b979e89fe 100644 --- a/include/caffe/multinode/multi_sync.hpp +++ b/include/caffe/multinode/multi_sync.hpp @@ -182,15 +182,15 @@ namespace caffe { mn::train::commit(); #ifdef PERFORMANCE_MONITORING - statsIterResult.resize(caffe::mn::train::get_session().get_operation_count()); - caffe::mn::train::stats::start(); + statsIterResult.resize(caffe::mn::train::get_session().get_operation_count()); + caffe::mn::train::stats::start(); #endif solver->add_callback(this); solver->Solve(); #ifdef PERFORMANCE_MONITORING - dump_stats_to_file(); + dump_stats_to_file(); #endif } @@ -206,6 +206,10 @@ namespace caffe { } void on_iter_finished(int layer_id) { +#ifdef FW_OVERLAP_OPT + solver->set_layer_finished_flag(layer_id, false); +#endif + boost::shared_ptr> &layer = layers[layer_id]; if (layer->layerOp == nullptr) { return; @@ -238,16 +242,11 @@ namespace caffe { } std::vector ¶m_ids = layer_param_ids[layer_id]; - -#ifdef FW_OVERLAP_OPT - int finished_count = 0; -#endif - for (int i=0; iParamNeedReduce(i) #ifdef FW_OVERLAP_OPT || (param_ids_finished_flags[layer_id][i] == true)) { - finished_count++; + param_ids_finished_flags[layer_id][i] = true; #else ) { #endif @@ -264,7 +263,6 @@ namespace caffe { #ifdef FW_OVERLAP_OPT assert(is_completed); param_ids_finished_flags[layer_id][i] = true; - finished_count++; #endif if (CAN_USE_PRV(net_params[param_ids[i]])) { if (delwt_buf != net_params[param_ids[i]]->prv_diff()) @@ -279,6 +277,8 @@ namespace caffe { } #ifdef FW_OVERLAP_OPT + int finished_count = std::count(param_ids_finished_flags[layer_id].begin(), + param_ids_finished_flags[layer_id].end(), true); if (finished_count == param_ids.size()) { solver->set_layer_finished_flag(layer_id, true); } diff --git a/src/caffe/multinode/multi_solver.cpp b/src/caffe/multinode/multi_solver.cpp index 0516b27b5..13ad8da2b 100644 --- a/src/caffe/multinode/multi_solver.cpp +++ b/src/caffe/multinode/multi_solver.cpp @@ -46,16 +46,13 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. namespace caffe { -#define START_ITER 1 - - #ifdef CAFFE_PER_LAYER_TIMINGS #define LAYER_TIMING_START() do { \ - timer.Start(); \ + root_solver_->timer.Start(); \ }while(0) #define LAYER_TIMING_STOP(name, index) do { \ - name##_time_per_layer[index] += timer.MicroSeconds(); \ + root_solver_->name##_time_per_layer[index] += root_solver_->timer.MicroSeconds(); \ }while(0) #else #define LAYER_TIMING_START() @@ -101,50 +98,29 @@ inline void MultiSolver::WaitAndUpdateGradient(int layer_id) { template Dtype MultiSolver::ForwardBackwardImpl(bool first, bool last) { - Dtype loss = 0; Net& net = *root_solver_->net(); const std::vector>>& layers{ net.layers() }; const std::vector& layer_need_backward{ net.layer_need_backward() }; -#ifdef FW_OVERLAP_OPT - int iter = root_solver_->iter(); -#endif - -#ifdef CAFFE_PER_LAYER_TIMINGS - Timer& timer = root_solver_->timer; - std::vector& forward_time_per_layer = root_solver_->forward_time_per_layer; - std::vector& backward_time_per_layer = root_solver_->backward_time_per_layer; - std::vector& update_time_per_layer = root_solver_->update_time_per_layer; - std::vector& startcomm_time_per_layer = root_solver_->startcomm_time_per_layer; - std::vector& waitcomm_time_per_layer = root_solver_->waitcomm_time_per_layer; -#endif /* CAFFE_PER_LAYER_TIMINGS */ - for (int i = 0; i < layers.size(); ++i) { #ifdef FW_OVERLAP_OPT - if (first && iter >= START_ITER + 1) { + if (first && IsSkipWaitGradient(i) == false) { while (layer_finished_flags_[i] == false) { - if (IsSkipWaitGradient(i)) { - break; - } - WaitAndUpdateGradient(i); - if (layer_finished_flags_[i]) { + if (layer_finished_flags_[i]) break; - } for (int k=i+1; k::ForwardBackwardImpl(bool first, bool last) { } LAYER_TIMING_START(); - net.BackwardFromTo(i, i); - LAYER_TIMING_STOP(backward, i); - if (last && (layers[i]->layerOp != nullptr) && layers[i]->layerOp->HasParameterSets()) { + if (last && (layers[i]->layerOp != nullptr) + && layers[i]->layerOp->HasParameterSets()) { LAYER_TIMING_START(); for (int j = 0; j < callbacks_.size(); ++j) { callbacks_[j]->on_iter_finished(i); @@ -174,6 +149,7 @@ Dtype MultiSolver::ForwardBackwardImpl(bool first, bool last) { } #ifdef FW_OVERLAP_OPT + int iter = root_solver_->iter(); int max_iter = root_solver_->param().max_iter(); bool test = (root_solver_->param().test_interval() && ((iter + 1) % root_solver_->param().test_interval() == 0)); @@ -183,12 +159,7 @@ Dtype MultiSolver::ForwardBackwardImpl(bool first, bool last) { #else if (last) { #endif - - for (int i = 0; i < layers.size(); ++i) { -#ifdef FW_OVERLAP_OPT - if (layer_finished_flags_[i]) - continue; -#endif + for (int i = 0; i < layers.size(); ++i) { if (IsSkipWaitGradient(i)) { #ifdef FW_OVERLAP_OPT finished_count++; @@ -196,9 +167,12 @@ Dtype MultiSolver::ForwardBackwardImpl(bool first, bool last) { #endif continue; } +#ifdef FW_OVERLAP_OPT + if (layer_finished_flags_[i]) + continue; +#endif WaitAndUpdateGradient(i); - #ifdef FW_OVERLAP_OPT if (layer_finished_flags_[i]) finished_count++;