
Commit

Merge branch 'release_1.0.0_rc2' of ssh://git-ccr-1.devtools.intel.com:29418/dl_framework-intel_caffe
daisyden committed Jun 6, 2017
2 parents b0ef323 + cef51e9 commit b8cb4f5
Showing 36 changed files with 3,036 additions and 2,132 deletions.
1 change: 1 addition & 0 deletions LICENSE
@@ -1,3 +1,4 @@

COPYRIGHT

All modification made by Intel Corporation: © 2017 Intel Corporation.
6 changes: 5 additions & 1 deletion Makefile.mkldnn
@@ -8,6 +8,10 @@ MKLDNN_COMMIT := `cat ${CAFFE_ROOTDIR}/mkldnn.commit`
MKLDNN_CXX := $(CXX)
MKLDNN_CC := $(CC)

RETURN_STRING=$(shell ./external/mkl/prepare_mkl.sh)
MKLROOT=$(firstword $(RETURN_STRING))
MKL_ROOTDIR := $(MKLROOT)

# We do this because earlier versions of CMake have problems with ccache
ifneq (,$(findstring ccache,$(CXX)))
MKLDNN_CXX := $(lastword $(CXX))
@@ -18,7 +22,7 @@ ifneq (,$(findstring ccache,$(CC)))
endif

MKLDNN_GITHUB := https://github.com/01org/mkl-dnn.git
MKLDNN_CMAKE_FLAGS += $(MKLDNN_SRCDIR) -DCMAKE_INSTALL_PREFIX=$(CAFFE_ROOTDIR)/$(MKLDNN_INSTALLDIR) -B$(CAFFE_ROOTDIR)/$(MKLDNN_BUILDDIR) -DCMAKE_CXX_COMPILER="$(MKLDNN_CXX)" -DCMAKE_C_COMPILER="$(MKLDNN_CC)"
MKLDNN_CMAKE_FLAGS += $(MKLDNN_SRCDIR) -DCMAKE_INSTALL_PREFIX=$(CAFFE_ROOTDIR)/$(MKLDNN_INSTALLDIR) -DMKLROOT=${MKL_ROOTDIR} -B$(CAFFE_ROOTDIR)/$(MKLDNN_BUILDDIR) -DCMAKE_CXX_COMPILER="$(MKLDNN_CXX)" -DCMAKE_C_COMPILER="$(MKLDNN_CC)"

ifeq ("$(wildcard $(MKLDNN_INSTALLDIR)/include/mkldnn.hpp)", "")
mkldnn_download:
3 changes: 3 additions & 0 deletions README.md
@@ -71,3 +71,6 @@ Please cite Caffe in your publications if it helps your research:

***
*Other names and brands may be claimed as the property of others



7 changes: 4 additions & 3 deletions external/mkl/prepare_mkl.sh
@@ -74,10 +74,11 @@ echo $VERSION_LINE # Return Version Line
# MKL
DST=`dirname $0`
OMP=0
VERSION_MATCH=20170101
ARCHIVE_BASENAME=mklml_lnx_2017.0.2.20170110.tgz
VERSION_MATCH=20170425
ARCHIVE_BASENAME=mklml_lnx_2018.0.20170425.tgz
MKL_CONTENT_DIR=`echo $ARCHIVE_BASENAME | rev | cut -d "." -f 2- | rev`
GITHUB_RELEASE_TAG=self_containted_MKLGOLD_u2
GITHUB_RELEASE_TAG=1.0.0

MKLURL="https://github.com/intel/caffe/releases/download/$GITHUB_RELEASE_TAG/$ARCHIVE_BASENAME"
# there are different MKL libs to be used for GCC and for ICC
reg='^[0-9]+$'
2 changes: 1 addition & 1 deletion include/caffe/layers/accuracy_layer.hpp
@@ -76,7 +76,7 @@ class AccuracyLayer : public Layer<Dtype> {
// If there are two top blobs, then the second blob will contain
// accuracies per class.
virtual inline int MinTopBlobs() const { return 1; }
virtual inline int MaxTopBlos() const { return 2; }
virtual inline int MaxTopBlobs() const { return 2; }

protected:
/**
14 changes: 14 additions & 0 deletions include/caffe/layers/base_conv_layer.hpp
100644 → 100755
@@ -153,6 +153,13 @@ class BaseConvolutionLayer : public Layer<Dtype> {
pad_.cpu_data()[0], pad_.cpu_data()[1],
stride_.cpu_data()[0], stride_.cpu_data()[1],
dilation_.cpu_data()[0], dilation_.cpu_data()[1], col_buff);
} else if (!force_nd_im2col_ && num_spatial_axes_ == 3) {
im3d2col_cpu(data, conv_in_channels_,
conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2], conv_input_shape_.cpu_data()[3],
kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], kernel_shape_.cpu_data()[2],
pad_.cpu_data()[0], pad_.cpu_data()[1], pad_.cpu_data()[2],
stride_.cpu_data()[0], stride_.cpu_data()[1], stride_.cpu_data()[2],
dilation_.cpu_data()[0], dilation_.cpu_data()[1], dilation_.cpu_data()[2], col_buff);
} else {
im2col_nd_cpu(data, num_spatial_axes_, conv_input_shape_.cpu_data(),
col_buffer_shape_.data(), kernel_shape_.cpu_data(),
@@ -167,6 +174,13 @@ class BaseConvolutionLayer : public Layer<Dtype> {
pad_.cpu_data()[0], pad_.cpu_data()[1],
stride_.cpu_data()[0], stride_.cpu_data()[1],
dilation_.cpu_data()[0], dilation_.cpu_data()[1], data);
} else if (!force_nd_im2col_ && num_spatial_axes_ == 3) {
col2im3d_cpu(col_buff, conv_in_channels_,
conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2], conv_input_shape_.cpu_data()[3],
kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], kernel_shape_.cpu_data()[2],
pad_.cpu_data()[0], pad_.cpu_data()[1], pad_.cpu_data()[2],
stride_.cpu_data()[0], stride_.cpu_data()[1], stride_.cpu_data()[2],
dilation_.cpu_data()[0], dilation_.cpu_data()[1], dilation_.cpu_data()[2], data);
} else {
col2im_nd_cpu(col_buff, num_spatial_axes_, conv_input_shape_.cpu_data(),
col_buffer_shape_.data(), kernel_shape_.cpu_data(),
1,029 changes: 519 additions & 510 deletions include/caffe/layers/mkldnn_layers.hpp

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions include/caffe/layers/softmax_loss_layer.hpp
@@ -100,6 +100,9 @@ class SoftmaxWithLossLayer : public LossLayer<Dtype> {
virtual inline int MinTopBlobs() const { return 1; }
virtual inline int MaxTopBlobs() const { return 2; }

virtual inline int ExactNumBottomBlobs() const { return -1; }
virtual inline int MinBottomBlobs() const { return 2; }
virtual inline int MaxBottomBlobs() const { return 3; }
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
4 changes: 4 additions & 0 deletions include/caffe/mkldnn_base.hpp
@@ -203,6 +203,10 @@ template <typename Dtype>
class MKLDNNPrimitive {
public:
explicit MKLDNNPrimitive():aprimitive(), mkldnn_stream() {}

//API for initializing with shared_ptr<primitive>
MKLDNNPrimitive(shared_ptr<primitive> aprimitive_input) {this->aprimitive = aprimitive_input;}

virtual ~MKLDNNPrimitive() {}
void reset(primitive* pprimitive) { this->aprimitive.reset(pprimitive);}
shared_ptr<primitive> aprimitive;
17 changes: 10 additions & 7 deletions include/caffe/util/compareToolUtilities.h
@@ -372,13 +372,16 @@ int collectAndCheckLayerData(bool collect_step,
}

if (bottom_need_backward[i].size() > 0 && bottom_need_backward[i][0]) {
getFileName(file_name, false, "FwrdBtmDat", i);
checkData(file_name, bottom_vecs[i][0]->cpu_data(),
layers[i]->type(), output_dir,
&erronous_layers);
checkAllNans(bottom_vecs[i][0]->cpu_diff(),
bottom_vecs[i][0]->count(), "bottom.diff",
layers[i]->type(), &erronous_layers);
// We check data only for out-of-place computations
if (bottom_vecs[i][0] != top_vecs[i][0]) {
getFileName(file_name, false, "FwrdBtmDat", i);
checkData(file_name, bottom_vecs[i][0]->cpu_data(),
layers[i]->type(), output_dir,
&erronous_layers);
}
checkAllNans(bottom_vecs[i][0]->cpu_diff(),
bottom_vecs[i][0]->count(), "bottom.diff",
layers[i]->type(), &erronous_layers);
}

checkAllNans(top_vecs[i][0]->cpu_diff(),
14 changes: 14 additions & 0 deletions include/caffe/util/im2col.hpp
100644 → 100755
@@ -53,6 +53,13 @@ void im2col_cpu(const Dtype* data_im, const int channels,
const int stride_w, const int dilation_h, const int dilation_w,
Dtype* data_col);

template <typename Dtype>
void im3d2col_cpu(const Dtype* data_im, const int channels,
const int depth, const int height, const int width, const int kernel_d, const int kernel_h, const int kernel_w,
const int pad_d, const int pad_h, const int pad_w, const int stride_d, const int stride_h,
const int stride_w, const int dilation_d, const int dilation_h, const int dilation_w,
Dtype* data_col);

template <typename Dtype>
void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes,
const int* im_shape, const int* col_shape,
@@ -66,6 +73,13 @@ void col2im_cpu(const Dtype* data_col, const int channels,
const int stride_w, const int dilation_h, const int dilation_w,
Dtype* data_im);

template <typename Dtype>
void col2im3d_cpu(const Dtype* data_col, const int channels,
const int depth, const int height, const int width, const int kernel_d, const int kernel_h, const int kernel_w,
const int pad_d, const int pad_h, const int pad_w, const int stride_d, const int stride_h,
const int stride_w, const int dilation_d, const int dilation_h, const int dilation_w,
Dtype* data_im);

template <typename Dtype>
void im2col_nd_gpu(const Dtype* data_im, const int num_spatial_axes,
const int col_size, const int* im_shape, const int* col_shape,
5 changes: 3 additions & 2 deletions models/intel_optimized_models/resnet_50/solver.prototxt
@@ -1,5 +1,5 @@
#This solver is described by Computer Vision Group Jena (CVGJ) in [ImageNet pre-trained models with batch normalization] (https://arxiv.org/pdf/1612.01452.pdf)
net: "train_val.prototxt"
net: "models/intel_optimized_models/resnet_50/train_val.prototxt"
test_iter: 5000
test_interval: 15000
base_lr: 0.1
@@ -11,6 +11,7 @@ power: 1
momentum: 0.9
weight_decay: 0.0001
snapshot: 30000
snapshot_prefix: "caffe-resnet50"
snapshot_prefix: "models/intel_optimized_models/resnet_50/caffe-resnet50"
test_initialization: false
solver_mode: CPU

4 changes: 2 additions & 2 deletions models/intel_optimized_models/resnet_50/train_val.prototxt
@@ -22,7 +22,7 @@ transform_param {
mean_value: 123
}
data_param {
source: "/data/compressed_lmdb/ilsvrc12_train_lmdb"
source: "examples/imagenet/ilsvrc12_train_lmdb"
batch_size: 128
backend: LMDB
shuffle: true
@@ -46,7 +46,7 @@ transform_param {
mean_value: 123
}
data_param {
source: "/data/compressed_lmdb/ilsvrc12_val_lmdb/"
source: "examples/imagenet/ilsvrc12_val_lmdb/"
batch_size: 10
backend: LMDB
}
2 changes: 1 addition & 1 deletion src/caffe/layers/concat_layer.cpp
@@ -132,7 +132,7 @@ void ConcatLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
if (propagate_down[i]) {
Dtype* bottom_diff = bottom[i]->mutable_cpu_diff();
#ifdef _OPENMP
#pragma omp parallel for
#pragma omp parallel for if(num_concats_ > 1)
#endif
for (int n = 0; n < num_concats_; ++n) {
caffe_copy(bottom_concat_axis * concat_input_size_, top_diff +
24 changes: 19 additions & 5 deletions src/caffe/layers/conv_layer.cpp
@@ -73,7 +73,9 @@ void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[i]->cpu_data();
Dtype* top_data = top[i]->mutable_cpu_data();
#ifdef _OPENMP
#pragma omp parallel for num_threads(this->num_of_threads_)
#pragma omp parallel if(this->num_of_threads_ > 1) num_threads(this->num_of_threads_)
{
#pragma omp for
#endif
for (int n = 0; n < this->num_; ++n) {
this->forward_cpu_gemm(bottom_data + n*this->bottom_dim_,
Expand All @@ -84,6 +86,9 @@ void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
this->forward_cpu_bias(top_data + n * this->top_dim_, bias);
}
}
#ifdef _OPENMP
}
#endif
}
}

@@ -111,8 +116,10 @@ void ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,

if (this->param_propagate_down_[0]) {
#ifdef _OPENMP
this->clear_weight_mt();
#pragma omp parallel num_threads(this->num_of_threads_)
if (this->num_of_threads_ > 1) {
this->clear_weight_mt();
}
#pragma omp parallel if(this->num_of_threads_ > 1) num_threads(this->num_of_threads_)
#endif
{
#ifdef _OPENMP
@@ -125,20 +132,27 @@ void ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
}

#ifdef _OPENMP
this->sum_weight_mt(weight_diff);
if (this->num_of_threads_ > 1) {
this->sum_weight_mt(weight_diff);
}
#endif
}
}

if (propagate_down[i]) {
#ifdef _OPENMP
#pragma omp parallel for num_threads(this->num_of_threads_)
#pragma omp parallel if(this->num_of_threads_ > 1) num_threads(this->num_of_threads_)
{
#pragma omp for
#endif
for (int n = 0; n < this->num_; ++n) {
// gradient w.r.t. bottom data, if necessary.
this->backward_cpu_gemm(top_diff + n * this->top_dim_, weight,
bottom_diff + n * this->bottom_dim_);
}
#ifdef _OPENMP
}
#endif
}
}
}
9 changes: 9 additions & 0 deletions src/caffe/layers/dropout_layer.cpp
@@ -69,6 +69,15 @@ void DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
// The line below sets the corresponding SyncedMemory->_head to HEAD_AT_CPU.
// It fixes the issue "Check failed: this->_cpu_ptr == cpu_ptr (0 vs. 0x5587dfc87ec0)" (GoogleNet V1).
// The reason: after the pooling layer MKLDNNPoolingLayer<Dtype>::Forward_cpu (pool5/7x7_s1), top[0]->prv_data() holds a value.
// It is converted to cpu data in the dropout layer, and _head is set to HEAD_AT_CPU after executing top[0]->mutable_cpu_data().
// However, top[0]->cpu_data() and top[0]->prv_data() then both hold values,
// so in the inner product layer loss3/classifier the data is converted from the bottom prv data
// and the reorder changes from this->_reorder_usr2prv to this->_reorder_extprv2prv_pd,
// which eventually triggers the assertion.
top[0]->set_prv_data_descriptor(NULL);
unsigned int* mask = rand_vec_.mutable_cpu_data();
const int count = bottom[0]->count();
if (this->phase_ == TRAIN) {
4 changes: 2 additions & 2 deletions src/caffe/layers/hdf5_data_layer.cpp
@@ -98,10 +98,10 @@ void HDF5DataLayer<Dtype>::LoadHDF5FileData(const char* filename) {
// Shuffle if needed.
if (this->layer_param_.hdf5_data_param().shuffle()) {
std::random_shuffle(data_permutation_.begin(), data_permutation_.end());
DLOG(INFO) << "Successully loaded " << hdf_blobs_[0]->shape(0)
DLOG(INFO) << "Successfully loaded " << hdf_blobs_[0]->shape(0)
<< " rows (shuffled)";
} else {
DLOG(INFO) << "Successully loaded " << hdf_blobs_[0]->shape(0) << " rows";
DLOG(INFO) << "Successfully loaded " << hdf_blobs_[0]->shape(0) << " rows";
}
}

30 changes: 22 additions & 8 deletions src/caffe/layers/mkl_batch_norm_layer.cpp
100644 → 100755
@@ -66,6 +66,7 @@ void MKLBatchNormLayer<Dtype>::Init(const vector<Blob<Dtype>*>& bottom,
eps_ = this->layer_param_.batch_norm_param().eps();
use_weight_bias_ = this->layer_param_.batch_norm_param().use_weight_bias();
bias_term_ = this->layer_param_.batch_norm_param().bias_term();
use_global_stats_ = this->layer_param_.batch_norm_param().use_global_stats();

CHECK(use_weight_bias_) << "BatchNorm without scaling is not supported yet";

@@ -111,6 +112,7 @@ void MKLBatchNormLayer<Dtype>::Init(const vector<Blob<Dtype>*>& bottom,
dnnReleaseBuffer<Dtype>(variance_buffer_);
dnnReleaseBuffer<Dtype>(scaleShift_buffer_);
dnnReleaseBuffer<Dtype>(diffScaleShift_buffer_);

// "Lazy" allocation because here we don't know
// what layout is used by neighbours.

@@ -271,9 +273,15 @@ void MKLBatchNormLayer<Dtype>::Forward_cpu(
bwd_top_diff ->create_internal_layout(batchNormFwd, dnnResourceDst);
bwd_bottom_diff->create_internal_layout(batchNormFwd, dnnResourceSrc);

e = dnnBatchNormalizationCreateBackward<Dtype>(
&batchNormBwd, NULL, mem_descr->layout_int, eps_, dnnUseScaleShift);
CHECK_EQ(e, E_SUCCESS);
if (!use_global_stats_) {
e = dnnBatchNormalizationCreateBackward<Dtype>(
&batchNormBwd, NULL, mem_descr->layout_int, eps_, dnnUseScaleShift);
CHECK_EQ(e, E_SUCCESS);
} else {
e = dnnBatchNormalizationCreateBackward<Dtype>(
&batchNormBwd, NULL, mem_descr->layout_int, eps_, dnnUseScaleShift | dnnUseInputMeanVariance);
CHECK_EQ(e, E_SUCCESS);
}
}
} else {
DLOG(INFO) << "Using cpu_data in MKLBatchNormLayer.";
@@ -290,9 +298,15 @@ void MKLBatchNormLayer<Dtype>::Forward_cpu(
dnnUseScaleShift | dnnUseInputMeanVariance);
CHECK_EQ(e, E_SUCCESS);

e = dnnBatchNormalizationCreateBackward<Dtype>(
&batchNormBwd, NULL, layout_usr_, eps_, dnnUseScaleShift);
CHECK_EQ(e, E_SUCCESS);
if (!use_global_stats_) {
e = dnnBatchNormalizationCreateBackward<Dtype>(
&batchNormBwd, NULL, layout_usr_, eps_, dnnUseScaleShift);
CHECK_EQ(e, E_SUCCESS);
} else {
e = dnnBatchNormalizationCreateBackward<Dtype>(
&batchNormBwd, NULL, layout_usr_, eps_, dnnUseScaleShift | dnnUseInputMeanVariance);
CHECK_EQ(e, E_SUCCESS);
}
}
bottom_data =
reinterpret_cast<void *>(const_cast<Dtype*>(bottom[0]->cpu_data()));
@@ -360,13 +374,13 @@ void MKLBatchNormLayer<Dtype>::Forward_cpu(
// doing Backward
// TODO: make a caffe_copy working on blobs
caffe_copy(amount_to_copy, static_cast<Dtype*>(bottom_data),
temp_.mutable_cpu_data());
temp_.mutable_cpu_data());
}

if (use_global_stats_) {
// use the stored mean/variance estimates.
const Dtype scale_factor = this->blobs_[2]->cpu_data()[0] == 0 ?
0 : 1 / this->blobs_[2]->cpu_data()[0];
0 : 1 / this->blobs_[2]->cpu_data()[0];
caffe_cpu_scale(this->blobs_[0]->count(), scale_factor,
this->blobs_[0]->cpu_data(), mean_buffer_);
caffe_cpu_scale(this->blobs_[1]->count(), scale_factor,
