Skip to content

Commit 87989a6

Browse files
simpkins authored and facebook-github-bot committed
[caffe2] support serializing float data as bfloat16 (pytorch#53735)
Summary: Pull Request resolved: pytorch#53735 Add an option to BlobSerializationOptions to request that float data be serialized as bfloat16. This reduces the serialized data size at the expense of some loss in precision. ghstack-source-id: 124317910 Test Plan: Included a new unit test. Reviewed By: mraway Differential Revision: D26658205 fbshipit-source-id: 74521ed161059066355a3f208488ed01a344dbb5
1 parent b032316 commit 87989a6

File tree

4 files changed

+209
-6
lines changed

4 files changed

+209
-6
lines changed

caffe2/core/blob_serialization.cc

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
#include "caffe2/core/blob.h"
1111
#include "caffe2/core/common.h"
1212
#include "caffe2/utils/proto_utils.h"
13+
#ifdef USE_FBGEMM
14+
#include "fbgemm/FbgemmConvert.h"
15+
#endif
1316

1417
C10_DEFINE_int(
1518
caffe2_tensor_chunk_size,
@@ -388,7 +391,75 @@ void SerializeTensorData(const SerializeParams<at::Half>& params) {
388391
params.tensor_proto);
389392
}
390393

394+
#ifdef USE_FBGEMM
namespace {
// folly/lang/Bits.h cannot be included here, so roll our own 16-bit
// byte-swap helpers for bfloat16 values.
fbgemm::bfloat16 ByteSwap(fbgemm::bfloat16 value) {
#ifdef _MSC_VER
  return _byteswap_ushort(value);
#else
  return __builtin_bswap16(value);
#endif
}

// Byte-swap num_elements bfloat16 values from src into dest.
// src and dest may point at the same memory (in-place swapping works).
// This only runs on big-endian hosts, so a SIMD version isn't worth the
// extra complexity.
void ByteSwapArray(
    const fbgemm::bfloat16* src,
    fbgemm::bfloat16* dest,
    size_t num_elements) {
  for (size_t idx = 0; idx < num_elements; ++idx) {
    dest[idx] = ByteSwap(src[idx]);
  }
}
} // namespace
#endif // USE_FBGEMM
419+
391420
// Serialize float tensor data into params.tensor_proto.
//
// The FLOAT_BFLOAT16 option requests converting the data to bfloat16 before
// serializing, which roughly halves the serialized size at the cost of some
// precision. The conversion is only available when compiled with fbgemm;
// otherwise (and for all other float formats) the data is stored in the
// normal protobuf float_data field.
void SerializeTensorData(const SerializeParams<float>& params) {
#ifdef USE_FBGEMM
  if (params.options.float_format() ==
      BlobSerializationOptions_FloatFormat_FLOAT_BFLOAT16) {
    // fbgemm's conversion runs on the CPU, so non-CPU tensor data must be
    // copied into a temporary CPU buffer first.
    std::unique_ptr<float[]> tmp_buffer;
    const float* src;
    if (params.context.device() == CPU) {
      src = params.input.data();
    } else {
      tmp_buffer.reset(new float[params.input.size()]);
      params.context.CopyToCPU(
          params.input.size(), params.input.data(), tmp_buffer.get());
      // Fix: src was previously left unassigned on this branch, leaving it
      // uninitialized when serializing non-CPU tensors.
      src = tmp_buffer.get();
    }

    params.SetDataFormat(TensorProto_SerializationFormat_FMT_BFLOAT16);
    // TODO: it would be nice if we could use
    // folly::resizeWithoutInitialization() here
    params.tensor_proto.mutable_raw_data()->resize(
        params.input.size() * sizeof(fbgemm::bfloat16));

    Range<fbgemm::bfloat16*> dest(
        reinterpret_cast<fbgemm::bfloat16*>(
            &(*params.tensor_proto.mutable_raw_data())[0]),
        params.input.size());

    fbgemm::FloatToBfloat16_simd(src, dest.data(), params.input.size());

    // Note: technically a platform can have different integer from floating
    // point endianness, and we ideally should check floating point endianness
    // here. However, the fbgemm code doesn't appear to make this distinction,
    // and at least in the Bfloat16ToFloat_ref() code it appears to assume that
    // floating point and integer endianness are the same.
    if (!kIsLittleEndian) {
      ByteSwapArray(dest.data(), dest.data(), dest.size());
    }
    return;
  }
#endif

  params.SetDataFormat(TensorProto_SerializationFormat_FMT_PROTOBUF);
  params.CopyToRepeatedField(params.tensor_proto.mutable_float_data());
}
394465

@@ -792,6 +863,48 @@ DESERIALIZE_IMPL(float, FMT_PROTOBUF) {
792863
params.CopyFromRepeatedField(params.tensor_proto.float_data());
793864
}
794865

866+
// Deserialize bfloat16-encoded raw_data back into a float tensor.
DESERIALIZE_IMPL(float, FMT_BFLOAT16) {
#ifdef USE_FBGEMM
  CAFFE_ENFORCE_EQ(
      params.dest.size() * sizeof(fbgemm::bfloat16),
      params.tensor_proto.raw_data().size(),
      "incorrect data size in serialized bfloat16 data");
  auto wire_data = reinterpret_cast<const fbgemm::bfloat16*>(
      params.tensor_proto.raw_data().data());

  // The serialized representation is little-endian; on big-endian hosts
  // byte-swap into a scratch buffer before converting.
  std::unique_ptr<fbgemm::bfloat16[]> swap_buf;
  const fbgemm::bfloat16* src = wire_data;
  if (!kIsLittleEndian) {
    swap_buf.reset(new fbgemm::bfloat16[params.dest.size()]);
    ByteSwapArray(wire_data, swap_buf.get(), params.dest.size());
    src = swap_buf.get();
  }

  // fbgemm converts on the CPU; for a non-CPU destination, convert into an
  // intermediate CPU buffer and copy the floats over afterwards.
  std::unique_ptr<float[]> cpu_buf;
  float* out = params.dest.data();
  if (params.context.device() != CPU) {
    cpu_buf.reset(new float[params.dest.size()]);
    out = cpu_buf.get();
  }

  fbgemm::Bfloat16ToFloat_simd(src, out, params.dest.size());
  if (params.context.device() != CPU) {
    params.context.CopyFromCPU(params.dest.size(), out, params.dest.data());
  }
#else
  // We cannot load serialized bfloat16 data without fbgemm.
  CAFFE_ENFORCE(
      false, "cannot perform bfloat16 to float conversion without fbgemm");
#endif
}
907+
795908
DESERIALIZE_IMPL(double, FMT_PROTOBUF) {
796909
params.CopyFromRepeatedField(params.tensor_proto.double_data());
797910
}
@@ -825,6 +938,7 @@ void DeserializeTensorBody(
825938
DeserializeParams<T> params(dest, tensor_proto, context);
826939
switch (format) {
827940
DESERIALIZE_FORMAT_CASE(FMT_PROTOBUF);
941+
DESERIALIZE_FORMAT_CASE(FMT_BFLOAT16);
828942
}
829943

830944
// This can happen if the blob was serialized by a newer version of the code

caffe2/proto/caffe2.proto

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ message TensorProto {
4949
// the protobuf typed fields, although in some cases raw little endian data
5050
// is stored in the byte_data field instead.
5151
FMT_PROTOBUF = 0;
52+
// bfloat16 data stored in the raw_data field.
53+
FMT_BFLOAT16 = 1;
5254
}
5355
// data_format is a SerializationFormat enum value.
5456
// However, we intentionally store it as an integer value so we can
@@ -504,6 +506,19 @@ message BlobSerializationOptions {
504506
// - a chunk size of -1 means to disable chunking, and serialize the blob in
505507
// a single chunk.
506508
optional int64 chunk_size = 2;
509+
510+
enum FloatFormat {
  // Use the current default serialization format, as chosen by the
  // current version of the code. (At the time of writing this is PROTOBUF)
  FLOAT_DEFAULT = 0;
  // Store the data in the TensorProto's float_data field
  FLOAT_PROTOBUF = 1;
  // Serialize float values as bfloat16. Note that this conversion is lossy.
  FLOAT_BFLOAT16 = 2;
}
519+
520+
// Settings for how to serialize tensors containing float values
521+
optional FloatFormat float_format = 3;
507522
}
508523

509524
message SerializationOptions {

caffe2/proto/caffe2_pb2.pyi

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,11 @@ class TensorProto(google.protobuf.message.Message):
8080
class _SerializationFormat(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[SerializationFormat.V], builtins.type):
8181
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor = ...
8282
FMT_PROTOBUF = TensorProto.SerializationFormat.V(0)
83+
FMT_BFLOAT16 = TensorProto.SerializationFormat.V(1)
8384
class SerializationFormat(metaclass=_SerializationFormat):
8485
V = typing.NewType('V', builtins.int)
8586
FMT_PROTOBUF = TensorProto.SerializationFormat.V(0)
87+
FMT_BFLOAT16 = TensorProto.SerializationFormat.V(1)
8688

8789
class Segment(google.protobuf.message.Message):
8890
DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
@@ -708,18 +710,32 @@ global___DBReaderProto = DBReaderProto
708710

709711
class BlobSerializationOptions(google.protobuf.message.Message):
    # Generated protobuf stub for caffe2.BlobSerializationOptions; structure
    # mirrors the FloatFormat enum and fields declared in caffe2.proto.
    DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
    class _FloatFormat(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[FloatFormat.V], builtins.type):
        DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor = ...
        FLOAT_DEFAULT = BlobSerializationOptions.FloatFormat.V(0)
        FLOAT_PROTOBUF = BlobSerializationOptions.FloatFormat.V(1)
        FLOAT_BFLOAT16 = BlobSerializationOptions.FloatFormat.V(2)
    class FloatFormat(metaclass=_FloatFormat):
        V = typing.NewType('V', builtins.int)
        FLOAT_DEFAULT = BlobSerializationOptions.FloatFormat.V(0)
        FLOAT_PROTOBUF = BlobSerializationOptions.FloatFormat.V(1)
        FLOAT_BFLOAT16 = BlobSerializationOptions.FloatFormat.V(2)

    BLOB_NAME_REGEX_FIELD_NUMBER: builtins.int
    CHUNK_SIZE_FIELD_NUMBER: builtins.int
    FLOAT_FORMAT_FIELD_NUMBER: builtins.int
    blob_name_regex: typing.Text = ...
    chunk_size: builtins.int = ...
    float_format: global___BlobSerializationOptions.FloatFormat.V = ...

    def __init__(self,
        *,
        blob_name_regex : typing.Optional[typing.Text] = ...,
        chunk_size : typing.Optional[builtins.int] = ...,
        float_format : typing.Optional[global___BlobSerializationOptions.FloatFormat.V] = ...,
        ) -> None: ...
    def HasField(self, field_name: typing_extensions.Literal[u"blob_name_regex",b"blob_name_regex",u"chunk_size",b"chunk_size",u"float_format",b"float_format"]) -> builtins.bool: ...
    def ClearField(self, field_name: typing_extensions.Literal[u"blob_name_regex",b"blob_name_regex",u"chunk_size",b"chunk_size",u"float_format",b"float_format"]) -> None: ...
global___BlobSerializationOptions = BlobSerializationOptions
724740

725741
class SerializationOptions(google.protobuf.message.Message):

caffe2/python/operator_test/load_save_test.py

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -461,9 +461,9 @@ def float_array(dtype: Type[np.floating], size: int) -> np.ndarray:
461461

462462
return blobs
463463

464-
def load_and_check_blobs(
464+
def load_blobs(
465465
self,
466-
blobs: List[Tuple[str, np.ndarray]],
466+
blob_names: List[str],
467467
dbs: List[str],
468468
db_type: Optional[str] = None
469469
) -> None:
@@ -472,13 +472,21 @@ def load_and_check_blobs(
472472
load_op = core.CreateOperator(
473473
"Load",
474474
[],
475-
[name for name, data in blobs],
475+
blob_names,
476476
absolute_path=1,
477477
dbs=dbs,
478478
db_type=db_type or self._db_type,
479479
)
480480
self.assertTrue(workspace.RunOperatorOnce(load_op))
481-
self.assertEqual(len(workspace.Blobs()), len(blobs))
481+
self.assertEqual(len(workspace.Blobs()), len(blob_names))
482+
483+
def load_and_check_blobs(
    self,
    blobs: List[Tuple[str, np.ndarray]],
    dbs: List[str],
    db_type: Optional[str] = None
) -> None:
    """Load the named blobs from the given DBs, then verify that each
    deserialized blob is exactly equal to its expected array."""
    names = [name for name, _ in blobs]
    self.load_blobs(names, dbs, db_type)
    for name, expected in blobs:
        np.testing.assert_array_equal(workspace.FetchBlob(name), expected)
484492

@@ -636,5 +644,55 @@ def testSaveWithOptions(self) -> None:
636644
)
637645

638646

647+
def testSaveFloatToBfloat16(self) -> None:
    """Verify the FLOAT_BFLOAT16 serialization option: the bfloat16 blob is
    smaller on disk and round-trips with reduced (but bounded) precision,
    while the default-format blob round-trips exactly."""
    tmp_folder = self.make_tempdir()
    tmp_file = str(tmp_folder / "save.output")

    # Create 2 blobs with the same float data
    float_data = np.random.random_sample(4000).astype(np.float32)
    workspace.FeedBlob("float1", float_data)
    workspace.FeedBlob("float2", float_data)
    blob_names = ["float1", "float2"]

    # Serialize the data, using bfloat16 serialization for one of the blobs
    save_op = core.CreateOperator(
        "Save",
        blob_names,
        [],
        absolute_path=1,
        db=tmp_file,
        db_type=self._db_type,
        options=caffe2_pb2.SerializationOptions(
            options=[
                BlobSerializationOptions(
                    blob_name_regex="float1",
                    float_format=BlobSerializationOptions.FLOAT_BFLOAT16,
                ),
            ],
        ),
    )
    self.assertTrue(workspace.RunOperatorOnce(save_op))

    # As long as fbgemm was available for us to perform bfloat16 conversion,
    # the serialized data for float1 should be almost half the size of float2
    if workspace.has_fbgemm:
        blob_chunks = self._read_chunk_info(Path(tmp_file))
        self.assertEqual(len(blob_chunks["float1"]), 1, blob_chunks["float1"])
        self.assertEqual(len(blob_chunks["float2"]), 1, blob_chunks["float2"])
        self.assertLess(
            blob_chunks["float1"][0].value_size,
            0.6 * blob_chunks["float2"][0].value_size
        )

    self.load_blobs(blob_names, [tmp_file])

    # float2 was saved with the default format and should round-trip exactly
    np.testing.assert_array_equal(workspace.FetchBlob("float2"), float_data)
    # float1 was saved as bfloat16, so it should only be close-ish to the
    # input data.  (The original comment here incorrectly said "float2".)
    np.testing.assert_array_almost_equal(
        workspace.FetchBlob("float1"), float_data, decimal=2
    )
695+
696+
639697
if __name__ == '__main__':
640698
unittest.main()

0 commit comments

Comments
 (0)