Commit fd77927: Merge pull request #359 from drnikolaev/caffe-0.16.2-pr
June 2017 release
2 parents: 5a06f0e + 2a4581e

110 files changed: +4422 -887 lines

3rdparty/half_float/half.hpp (+22 -23)

@@ -197,7 +197,16 @@
 #endif

 #ifdef __CUDA_ARCH__
-#include "caffe/util/gpu_math_functions.cuh"
+#include "caffe/util/half.cuh"
+#include "caffe/util/gpu_math_functions.cuh"
+#endif
+
+#if !defined(CPU_ONLY) && defined(__CUDA_ARCH__)
+#define CAFFE_UTIL_HD __host__ __device__
+#define CAFFE_UTIL_IHD __inline__ __host__ __device__
+#else
+#define CAFFE_UTIL_HD
+#define CAFFE_UTIL_IHD inline
 #endif

 /// Default rounding mode.

@@ -956,29 +965,24 @@ namespace half_float
   friend struct std::hash<half>;
 #endif

-  public:
+ public:
   /// Default constructor.
   /// This initializes the half to 0. Although this does not match the builtin types' default-initialization semantics
   /// and may be less efficient than no initialization, it is needed to provide proper value-initialization semantics.
   HALF_CONSTEXPR
   CAFFE_UTIL_HD
   half() : data_() {}

+  template<typename H>
   CAFFE_UTIL_HD
-  __half geth() const {
-    __half h;
-    h.x = data_;
-    return h;
+  const H* gethp() const {
+    return reinterpret_cast<const H*>(&data_);
   }

+  template<typename H>
   CAFFE_UTIL_HD
-  const __half* gethp() const {
-    return reinterpret_cast<const __half*>(this);
-  }
-
-  CAFFE_UTIL_HD
-  __half* gethp() {
-    return reinterpret_cast<__half*>(this);
+  H* gethp() {
+    return reinterpret_cast<H*>(&data_);
   }

   CAFFE_UTIL_HD

@@ -995,18 +999,13 @@ namespace half_float
   /// Copy constructor.
   /// \tparam T type of concrete half expression
   /// \param rhs half expression to copy from
-  // half(detail::expr rhs) : data_(detail::float2half<round_style>(rhs)) {}
-
   CAFFE_UTIL_HD
   half(detail::expr rhs) {
     assign(rhs);
   }

   /// Conversion constructor.
   /// \param rhs float to convert
-  // template<typename T>
-  // half(const T& rhs) : data_(detail::float2half<round_style>((float)rhs)) {}
-
   template<typename T>
   CAFFE_UTIL_HD
   half(const T& rhs) {

@@ -1030,8 +1029,8 @@ namespace half_float
   // operator float() const { return detail::half2float(data_); }
   CAFFE_UTIL_HD operator float() const {
 #ifdef __CUDA_ARCH__
-    __half h;
-    h.x = data_;
+    ::half h;
+    h.setx(data_);
     return __half2float(h);
 #else
     return detail::half2float(data_);

@@ -1040,7 +1039,7 @@ namespace half_float

   CAFFE_UTIL_HD void assign(float rhs) {
 #ifdef __CUDA_ARCH__
-    data_ = float2half_clip(rhs).x;
+    data_ = float2half_clip(rhs).x();
 #else
     data_ = detail::float2half<round_style>(rhs);
 #endif

@@ -1117,9 +1116,9 @@ namespace half_float
     float after = static_cast<float>(*this);
     if (before == after && before != 0.f && rhs != 0.f) {
 #ifdef __CUDA_ARCH__
-      CUPRINTF("GPU PRECISION LOSS: %g -= %g\n", before, rhs);
+      printf("GPU PRECISION LOSS: %g -= %g\n", before, rhs);
 #else
-      CUPRINTF("CPU PRECISION LOSS: %g -= %g\n", before, rhs);
+      printf("CPU PRECISION LOSS: %g -= %g\n", before, rhs);
 #endif
     }
 #else
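
For orientation, a minimal device-side sketch (not part of this commit) of how the new templated gethp<H>() accessor might be consumed. The function name to_float and the include path are assumptions; gethp<H>() comes from the hunk above and __half2float from the CUDA runtime:

    // Sketch only: reinterpret half_float::half's 16-bit storage as CUDA's __half.
    #include <cuda_fp16.h>
    #include "half.hpp"   // 3rdparty/half_float; actual include path may differ

    __device__ float to_float(const half_float::half& h) {
      // gethp<__half>() returns the internal bits reinterpreted as const __half*.
      return __half2float(*h.gethp<__half>());
    }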

CMakeLists.txt (+3)

@@ -14,6 +14,9 @@ set(CAFFE_TARGET_VERSION "0.16.1")
 set(CAFFE_TARGET_SOVERSION "0.16")
 add_definitions(-DCAFFE_VERSION=${CAFFE_TARGET_VERSION})

+# Skip `typedef __half half;`
+add_definitions(-DCUDA_NO_HALF=1)
+
 # ---[ Using cmake scripts and modules
 list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)

Makefile (+2 -1)

@@ -42,8 +42,9 @@ DYNAMIC_SONAME_SHORT := $(DYNAMIC_NAME_SHORT).$(DYNAMIC_VERSION_MAJOR).$(DYNAMIC
 DYNAMIC_VERSIONED_NAME_SHORT := $(DYNAMIC_SONAME_SHORT).$(DYNAMIC_VERSION_REVISION)
 DYNAMIC_NAME := $(LIB_BUILD_DIR)/$(DYNAMIC_VERSIONED_NAME_SHORT)
 COMMON_FLAGS += -DCAFFE_VERSION=$(DYNAMIC_VERSION_MAJOR).$(DYNAMIC_VERSION_MINOR).$(DYNAMIC_VERSION_REVISION)
-# FP16 Caffe requires C++ 11
+# NVCaffe requires C++ 11
 COMMON_FLAGS += -std=c++11
+COMMON_FLAGS += -DCUDA_NO_HALF

 ##############################
 # Get all source files
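
Both the CMake and Make builds now define CUDA_NO_HALF. As a rough illustration (an assumption based on the "Skip `typedef __half half;`" comment above, not code from this commit), the define keeps CUDA's global alias out of the way:

    // Sketch only: what -DCUDA_NO_HALF suppresses. cuda_fp16.h contains roughly
    //   #if !defined(CUDA_NO_HALF)
    //   typedef __half half;   // global-namespace alias
    //   #endif
    // With the alias gone, the unqualified name `half` is free for NVCaffe's own
    // ::half wrapper (caffe/util/half.cuh) and half_float::half stays unambiguous.
    #include <cuda_fp16.h>
    static_assert(sizeof(__half) == 2, "FP16 storage is 16 bits");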

cmake/Cuda.cmake (+3 -1)

@@ -56,7 +56,7 @@ endfunction()
 # caffe_select_nvcc_arch_flags(out_variable)
 function(caffe_select_nvcc_arch_flags out_variable)
   # List of arch names
-  set(__archs_names "Fermi" "Kepler" "Maxwell" "All" "Manual")
+  set(__archs_names "Fermi" "Kepler" "Maxwell" "Pascal" "Volta" "All" "Manual")
   set(__archs_name_default "All")
   if(NOT CMAKE_CROSSCOMPILING)
     list(APPEND __archs_names "Auto")

@@ -91,6 +91,8 @@ function(caffe_select_nvcc_arch_flags out_variable)
     set(__cuda_arch_bin "50")
   elseif(${CUDA_ARCH_NAME} STREQUAL "Pascal")
     set(__cuda_arch_bin "60 61 62")
+  elseif(${CUDA_ARCH_NAME} STREQUAL "Volta")
+    set(__cuda_arch_bin "70")
   elseif(${CUDA_ARCH_NAME} STREQUAL "All")
     set(__cuda_arch_bin ${Caffe_known_gpu_archs})
   elseif(${CUDA_ARCH_NAME} STREQUAL "Auto")

cmake/Modules/FindNVML.cmake (+2 -6)

@@ -9,12 +9,8 @@
 #  NVML_LIBRARY

 file (GLOB MLPATH /usr/lib/nvidia-???)
-
-find_path(NVML_INCLUDE_DIR NAMES nvml.h
-  PATHS ${CUDA_INCLUDE_DIRS} ${NVML_ROOT_DIR}/include
-  )
-
-find_library(NVML_LIBRARY nvidia-ml PATHS ${MLPATH} ${NVML_ROOT_DIR}/lib ${NVML_ROOT_DIR}/lib64)
+find_path(NVML_INCLUDE_DIR NAMES nvml.h PATHS ${CUDA_INCLUDE_DIRS} ${NVML_ROOT_DIR}/include)
+find_library(NVML_LIBRARY nvidia-ml PATHS ${MLPATH} /usr/local/cuda/lib64/stubs ${NVML_ROOT_DIR}/lib ${NVML_ROOT_DIR}/lib64)

 include(FindPackageHandleStandardArgs)
 find_package_handle_standard_args(NVML DEFAULT_MSG NVML_INCLUDE_DIR NVML_LIBRARY)

cmake/lint.cmake (+1 -1)

@@ -1,7 +1,7 @@

 set(CMAKE_SOURCE_DIR ..)
 set(LINT_COMMAND ${CMAKE_SOURCE_DIR}/scripts/cpp_lint.py)
-set(SRC_FILE_EXTENSIONS h hpp hu c cpp cu cc)
+set(SRC_FILE_EXTENSIONS h hpp hu c cpp cu cc cuh)
 set(EXCLUDE_FILE_EXTENSTIONS pb.h pb.cc)
 set(LINT_DIRS include src/caffe examples tools python matlab)

common_plot.py (+44)

@@ -9,8 +9,14 @@
 def get_test_accuracy(log, top_k):
     iteration = re.findall(r'Iteration (\d*), Testing net \(#0\)', log)
     accuracy = re.findall(r'Test net output #\d: accuracy/top-{top_k} = (\d*.\d*)'.format(top_k=top_k), log)
+    if len(accuracy)==0:
+        accuracy = re.findall(r'Test net output #\d: top-{top_k} = (\d*.\d*)'.format(top_k=top_k), log)
     if len(accuracy)==0:
         accuracy = re.findall(r'Test net output #\d: loss/top-{top_k} = (\d*.\d*)'.format(top_k=top_k), log)
+    if len(accuracy)==0:
+        accuracy = re.findall(r'Test net output #\d: accuracy/top{top_k} = (\d*.\d*)'.format(top_k=top_k), log)
+    if len(accuracy)==0:
+        accuracy = re.findall(r'Test net output #\d: accuracy = (\d*.\d*)', log)
     iteration = [int(i) for i in iteration]
     accuracy = [float(i) for i in accuracy]
     return iteration, accuracy

@@ -25,6 +31,13 @@ def get_test_loss(log):
     loss = [float(i) for i in loss]
     return iteration, loss

+def get_train_loss(log):
+    iteration = re.findall(r'Iteration (\d*), lr = ', log)
+    loss = re.findall(r'Train net output #\d: loss = (\d*.\d*)', log)
+    iteration = [int(i) for i in iteration]
+    loss = [float(i) for i in loss]
+    return iteration, loss
+

 def get_net_name(log):
     return re.findall(r"Solving (.*)\n", log)[0]

@@ -44,13 +57,22 @@ def parse_files(files, top_k=1, separate=False):
         data[net_name]["loss"] = {}
         data[net_name]["loss"]["loss"] = []
         data[net_name]["loss"]["iteration"] = []
+        data[net_name]["train_loss"] = {}
+        data[net_name]["train_loss"]["loss"] = []
+        data[net_name]["train_loss"]["iteration"] = []
+
         iteration, accuracy = get_test_accuracy(log, top_k)
         data[net_name]["accuracy"]["iteration"].extend(iteration)
         data[net_name]["accuracy"]["accuracy"].extend(accuracy)

         iteration, loss = get_test_loss(log)
         data[net_name]["loss"]["iteration"].extend(iteration)
         data[net_name]["loss"]["loss"].extend(loss)
+
+        iteration, loss = get_train_loss(log)
+        data[net_name]["train_loss"]["iteration"].extend(iteration)
+        data[net_name]["train_loss"]["loss"].extend(loss)
+
     return data

@@ -172,3 +194,25 @@ def plot_loss(data, value_at_hover=False):
     plt.xlim(0)
     plt.grid()
     return plt
+
+def plot_train_loss(data, value_at_hover=False):
+    nets = data.keys()
+    colors = iter(cm.rainbow(np.linspace(0, 1, len(nets))))
+    fig = plt.figure()
+    ax = fig.add_subplot(111)
+    for net in nets:
+        iteration = data[net]["train_loss"]["iteration"]
+        loss = data[net]["train_loss"]["loss"]
+        iteration, loss = (list(t) for t in zip(*sorted(zip(iteration, loss))))
+        ax.scatter(iteration, loss, color=next(colors))
+        if value_at_hover:
+            cursor = FollowDotCursor(ax, iteration, loss)
+
+    plt.legend(nets, loc='upper right')
+    plt.title("Log Loss")
+    plt.xlabel("Iteration")
+    plt.ylabel("Log Loss")
+    plt.xlim(0)
+    plt.grid()
+    return plt

include/caffe/blob.hpp (+16 -71)

@@ -36,49 +36,6 @@ class TBlob;
  */
 class Blob {
  public:
-  // This proxy makes sure that we can't rely on cached values while pointer
-  // to data is being used and data potentially might be changed.
-  // When pointer is actually given, proxy flushes the cache.
-  // There are use cases where we "preliminary convert data" but don't change it yet.
-  // In such cases cache is still valid until we really change data.
-  // For example, this line doesn't change blob's state:
-  //   Blob::PtrProxy<Ftype> top_data = top[i]->mutable_gpu_data<Ftype>();
-  // The state will be changed at the moment of passing a raw pointer to,
-  // let say, CuDNN routine.
-  template<typename Ptype>
-  class PtrProxy {
-   public:
-    PtrProxy() : tensor_(), is_gpu_(false), zero_new_mem_(true) {}
-
-    PtrProxy(shared_ptr<Tensor> tensor, bool is_gpu, bool zero_new_mem = true)
-        : tensor_(tensor), is_gpu_(is_gpu), zero_new_mem_(zero_new_mem) {}
-
-    operator Ptype*() {
-      CHECK(tensor_);
-      return reinterpret_cast<Ptype*>(tensor_->mutable_memory(tp<Ptype>(), is_gpu_, zero_new_mem_));
-    }
-
-    ~PtrProxy() {}
-
-    PtrProxy(PtrProxy&& other) : tensor_(std::move(other.tensor_)), is_gpu_(other.is_gpu_),
-        zero_new_mem_(other.zero_new_mem_) {}
-
-    PtrProxy& operator=(PtrProxy&& other) {
-      tensor_ = std::move(other.tensor_);
-      is_gpu_ = other.is_gpu_;
-      zero_new_mem_ = other.zero_new_mem_;
-      return *this;
-    }
-
-    PtrProxy(const PtrProxy&) = delete;
-    PtrProxy& operator=(const PtrProxy& other) = delete;
-
-   private:
-    shared_ptr<Tensor> tensor_;
-    bool is_gpu_;
-    bool zero_new_mem_;
-  };
-
   void Swap(Blob& other) noexcept {
     std::swap(data_tensor_, other.data_tensor_);
     std::swap(diff_tensor_, other.diff_tensor_);

@@ -387,26 +344,15 @@
   }

   template<typename Dtype>
-  PtrProxy<Dtype> mutable_cpu_data(bool zero_new_mem = true) {
+  Dtype* mutable_cpu_data() {
     convert_data(tp<Dtype>());
-    return PtrProxy<Dtype>(data_tensor_, false, zero_new_mem);
+    return static_cast<Dtype*>(data_tensor_->mutable_synced_mem()->mutable_cpu_data());
   }

   template<typename Dtype>
-  PtrProxy<Dtype> mutable_cpu_diff(bool zero_new_mem = true) {
+  Dtype* mutable_cpu_diff() {
     convert_diff(tp<Dtype>());
-    return PtrProxy<Dtype>(diff_tensor_, false, zero_new_mem);
-  }
-
-  // pycaffe needs these two, do NOT use them anywhere else
-  template<typename Dtype>
-  Dtype* mutable_cpu_data_raw() {
-    return (Dtype*) Blob::mutable_cpu_data<Dtype>();
-  }
-
-  template<typename Dtype>
-  Dtype* mutable_cpu_diff_raw() {
-    return (Dtype*) Blob::mutable_cpu_diff<Dtype>();
+    return static_cast<Dtype*>(diff_tensor_->mutable_synced_mem()->mutable_cpu_data());
   }

   // Element-wise accessor. Might be slow due to syncing from GPU to CPU.

@@ -572,15 +518,15 @@
   }

   template<typename Dtype>
-  PtrProxy<Dtype> mutable_gpu_data(bool zero_new_mem = true) {
+  Dtype* mutable_gpu_data() {
     convert_data(tp<Dtype>());
-    return PtrProxy<Dtype>(data_tensor_, true, zero_new_mem);
+    return static_cast<Dtype*>(data_tensor_->mutable_synced_mem()->mutable_gpu_data());
   }

   template<typename Dtype>
-  PtrProxy<Dtype> mutable_gpu_diff(bool zero_new_mem = true) {
+  Dtype* mutable_gpu_diff() {
     convert_diff(tp<Dtype>());
-    return PtrProxy<Dtype>(diff_tensor_, true, zero_new_mem);
+    return static_cast<Dtype*>(diff_tensor_->mutable_synced_mem()->mutable_gpu_data());
   }

   void async_gpu_push() {

@@ -701,19 +647,18 @@ class TBlob : public Blob {
   }

   template<typename T = Dtype>
-  PtrProxy <T> mutable_cpu_data(bool zero_new_mem = true) {
+  T* mutable_cpu_data() {
     check_integrity(true, data_type(), tp<T>());
-    return Blob::mutable_cpu_data<T>(zero_new_mem);
+    return Blob::mutable_cpu_data<T>();
   }

   template<typename T = Dtype>
-  PtrProxy <T> mutable_cpu_diff(bool zero_new_mem = true) {
+  T* mutable_cpu_diff() {
     check_integrity(false, diff_type(), tp<T>());
-    return Blob::mutable_cpu_diff<T>(zero_new_mem);
+    return Blob::mutable_cpu_diff<T>();
   }

 #ifndef CPU_ONLY
-
   template<typename T = Dtype>
   const T* gpu_data() const {
     check_integrity(true, data_type(), tp<T>());

@@ -727,15 +672,15 @@ class TBlob : public Blob {
   }

   template<typename T = Dtype>
-  PtrProxy <T> mutable_gpu_data(bool zero_new_mem = true) {
+  T* mutable_gpu_data() {
     check_integrity(true, data_type(), tp<T>());
-    return Blob::mutable_gpu_data<T>(zero_new_mem);
+    return Blob::mutable_gpu_data<T>();
   }

   template<typename T = Dtype>
-  PtrProxy <T> mutable_gpu_diff(bool zero_new_mem = true) {
+  T* mutable_gpu_diff() {
     check_integrity(false, diff_type(), tp<T>());
-    return Blob::mutable_gpu_diff<T>(zero_new_mem);
+    return Blob::mutable_gpu_diff<T>();
   }
 #endif
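
As a call-site illustration (not taken from this commit), the removal of PtrProxy means the example in the deleted comment now reduces to a plain pointer assignment; fill_top and Ftype below are placeholders, while mutable_cpu_data<>() and count() are existing Blob accessors:

    // Sketch only: the accessor now returns Ftype* directly instead of a
    // Blob::PtrProxy<Ftype>; type conversion still happens in convert_data().
    template <typename Ftype>
    void fill_top(caffe::Blob* top, Ftype value) {
      Ftype* top_data = top->mutable_cpu_data<Ftype>();
      for (int i = 0; i < top->count(); ++i) {
        top_data[i] = value;
      }
    }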
