Skip to content

Commit 8ee678d

Browse files
Yuri Putivskyfacebook-github-bot
authored andcommitted
Glow model onnx exporters (pytorch#3167)
Summary: Glow exporter (ONNX) serializes the Glow Function graph from memory into the file according to ONNX standard (extended). Documentation: Three layers of classes (similar to Glow loaders architecture). Protobuf writer, CommonOperatorWriter, and ONNXModelWriter. [Optional Fixes #issue] Pull Request resolved: pytorch#3167 Test Plan: Please see a detailed explanation of how to fill out the fields in the relevant sections in PULL_REQUEST.md. Differential Revision: D15998903 fbshipit-source-id: 302b3c8884d501e16d138bce9f6edc68bb4cd677
1 parent abc6b82 commit 8ee678d

File tree

10 files changed

+1424
-15
lines changed

10 files changed

+1424
-15
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License.");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#ifndef GLOW_EXPORTER_COMMONOPERATORWRITER_H
18+
#define GLOW_EXPORTER_COMMONOPERATORWRITER_H
19+
20+
#include "glow/Exporter/ProtobufWriter.h"
21+
22+
namespace glow {
23+
/// Declares writer methods for all operators. Every writer method serializes
24+
/// Glow node into the provided protobuf.
25+
template <typename Traits> class CommonOperatorWriter : public ProtobufWriter {
26+
protected:
27+
virtual ~CommonOperatorWriter() = default;
28+
29+
/// Declare pure virtual methods, one per each node kind.
30+
/// Derived class must to implement all of it.
31+
#define DEF_NODE(CLASS, NAME) \
32+
virtual llvm::Error write##NAME(const CLASS *node, \
33+
typename Traits::GraphProto &graph) = 0;
34+
#include "glow/AutoGenNodes.def"
35+
36+
/// Function invokes the correspondent virtual method according to \p node
37+
/// type to serialize node information into \p graph (protobuf), reports
38+
/// visited intermediate nodes through \p reporter, \returns llvm::Error.
39+
llvm::Error writeOperator(const Node *node,
40+
typename Traits::GraphProto &graph) {
41+
switch (node->getKind()) {
42+
#define DEF_NODE(CLASS, NAME) \
43+
case glow::Kinded::Kind::CLASS##Kind: \
44+
return write##NAME(llvm::cast<CLASS>(node), graph);
45+
#include "glow/AutoGenNodes.def"
46+
default:
47+
llvm_unreachable(
48+
"Not reachable, values and instructions are not handled here");
49+
return llvm::Error::success();
50+
}
51+
}
52+
53+
using ProtobufWriter::ProtobufWriter;
54+
};
55+
} // namespace glow
56+
57+
#endif // GLOW_EXPORTER_COMMONOPERATORWRITER_H
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#ifndef GLOW_EXPORTER_ONNXMODELWRITER_H
18+
#define GLOW_EXPORTER_ONNXMODELWRITER_H
19+
20+
#include "glow/Exporter/CommonOperatorWriter.h"
21+
#include "glow/Graph/Graph.h"
22+
23+
#include "onnx/onnx_pb.h"
24+
25+
#include "llvm/ADT/ArrayRef.h"
26+
#include "llvm/ADT/StringRef.h"
27+
28+
#include <string>
29+
30+
/// ONNX traits for protobuf types.
31+
struct ONNX_TRAITS {
32+
using GraphProto = ONNX_NAMESPACE::GraphProto;
33+
};
34+
35+
namespace glow {
36+
37+
/// Unique set of visited nodes.
38+
using ReportedNodes = std::unordered_set<const Node *>;
39+
40+
/// Writes ONNX models.
41+
class ONNXModelWriter : public CommonOperatorWriter<ONNX_TRAITS> {
42+
// Declare shorter aliases.
43+
using GraphType = typename ONNX_TRAITS::GraphProto;
44+
using NodeType = ONNX_NAMESPACE::NodeProto;
45+
using TensorType = ONNX_NAMESPACE::TensorProto;
46+
using AttrType = ONNX_NAMESPACE::AttributeProto;
47+
using ValueInfoType = ONNX_NAMESPACE::ValueInfoProto;
48+
49+
/// Current version of ONNX standard.
50+
size_t opsetVersion_;
51+
/// Keeps the track of already visited or processed nodes.
52+
ReportedNodes reportedNodes_;
53+
/// Converts \p glowType to \p protoType.
54+
static typename TensorType::DataType convertType(const Type &glowType);
55+
/// Writes Glow tensor \p T to proto output \p out.
56+
static void writeTensor(const Tensor &T, TensorType *out);
57+
/// Writes tensor shape from placeholder \p PH into protpbuf \p valueProto.
58+
static void tensorShapeFromPlaceholder(const Placeholder *PH,
59+
ValueInfoType *valueProto);
60+
/// Writes all inputs and outputs with operator name \p opName from give Node
61+
/// \p node into protobuf \p proto.
62+
static llvm::Error writeAllWithNode(const std::string &opName,
63+
const Node *node, NodeType *proto);
64+
/// Writes all inputs and outputs with operator name \p opName from give Node
65+
/// \p node into created node protobuf using \p graph.
66+
static llvm::Error writeAll(const std::string &opName, const Node *node,
67+
GraphType &graph);
68+
// Finds if uses of \p node have node with the provided \p kind.
69+
static bool hasUsesOfKind(const Node *node, Kinded::Kind kind);
70+
71+
public:
72+
/// Creates an ONNX model writer to serialize \p F graph into file
73+
/// \p modelFilename, writing \p irVersion and \p opsetVersion.
74+
/// If \p errPtr is not null then if an error occurs it will get assigned
75+
/// there otherwise if an error occurs it will abort.
76+
ONNXModelWriter(const std::string &modelFilename, Function &F,
77+
size_t irVersion, size_t opsetVersion,
78+
llvm::Error *errPtr = nullptr, bool textMode = false);
79+
80+
private:
81+
/// \returns error for the unexpected node kind.
82+
static llvm::Error writeUnexpectedKind(const Node *node) {
83+
RETURN_ERR(strFormat("Glow can not export node %s, unsupported kind: %d.",
84+
node->getName().str().c_str(), node->getKind()));
85+
}
86+
87+
/// Declares the overriden all pure virtual methods, declared in base class.
88+
#define DEF_NODE(CLASS, NAME) \
89+
llvm::Error write##NAME(const CLASS *, GraphType &) override;
90+
#include "glow/AutoGenNodes.def"
91+
};
92+
93+
} // namespace glow
94+
95+
#endif // GLOW_EXPORTER_ONNXMODELWRITER_H
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#ifndef GLOW_EXPORTER_PROTOBUFWRITER_H
18+
#define GLOW_EXPORTER_PROTOBUFWRITER_H
19+
20+
#include "glow/Graph/Graph.h"
21+
#include "glow/Support/Error.h"
22+
23+
#include <fstream>
24+
#include <google/protobuf/text_format.h>
25+
26+
namespace glow {
27+
/// Writes model: graph and weights.
28+
class ProtobufWriter {
29+
protected:
30+
/// The graph that we are constructing.
31+
Function &G_;
32+
/// Output file stream.
33+
std::ofstream ff_;
34+
35+
llvm::Error writeModel(const ::google::protobuf::Message &modelProto,
36+
bool textMode = false);
37+
38+
public:
39+
/// Constructs new ProtobufWriter object. It will write protopuf messages into
40+
/// \p modelFilename using graph and constants from \p F.
41+
/// If \p errPtr is not null then if an error occurs it will get assigned
42+
/// there otherwise if an error occurs it will abort.
43+
ProtobufWriter(const std::string &modelFilename, Function &F,
44+
llvm::Error *errPtr = nullptr);
45+
};
46+
47+
} // namespace glow
48+
49+
#endif // GLOW_EXPORTER_PROTOBUFWRITER_H

include/glow/Support/Error.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ class GlowErr final : public llvm::ErrorInfo<GlowErr> {
8585
COMPILE_UNSUPPORTED_NODE_AFTER_OPTIMIZE,
8686
// Compilation error; Compilation context not correctly setup.
8787
COMPILE_CONTEXT_MALFORMED,
88+
// Model writer encountered an invalid file name.
89+
MODEL_WRITER_INVALID_FILENAME,
90+
// Model writer cannot serialize graph to the file.
91+
MODEL_WRITER_SERIALIZATION_ERROR,
8892
};
8993

9094
/// GlowErr is not convertable to std::error_code. This is included for
@@ -156,6 +160,10 @@ class GlowErr final : public llvm::ErrorInfo<GlowErr> {
156160
return "COMPILE_UNSUPPORTED_NODE_AFTER_OPTIMIZE";
157161
case ErrorCode::COMPILE_CONTEXT_MALFORMED:
158162
return "COMPILE_CONTEXT_MALFORMED";
163+
case ErrorCode::MODEL_WRITER_INVALID_FILENAME:
164+
return "MODEL_WRITER_INVALID_FILENAME";
165+
case ErrorCode::MODEL_WRITER_SERIALIZATION_ERROR:
166+
return "MODEL_WRITER_SERIALIZATION_ERROR";
159167
};
160168

161169
llvm_unreachable("unsupported ErrorCode");

lib/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_subdirectory(CodeGen)
77
add_subdirectory(Converter)
88
add_subdirectory(ExecutionContext)
99
add_subdirectory(ExecutionEngine)
10+
add_subdirectory(Exporter)
1011
add_subdirectory(Graph)
1112
add_subdirectory(IR)
1213
add_subdirectory(Importer)

lib/Exporter/CMakeLists.txt

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
include_directories(${PROTOBUF_INCLUDE_DIRS})
2+
include_directories(${CMAKE_CURRENT_BINARY_DIR})
3+
4+
add_definitions(-DGOOGLE_PROTOBUF_NO_RTTI)
5+
6+
add_library(Exporter
7+
ProtobufWriter.cpp
8+
ONNXModelWriter.cpp)
9+
target_compile_definitions(Exporter
10+
INTERFACE
11+
-DGOOGLE_PROTOBUF_NO_RTTI)
12+
target_link_libraries(Exporter
13+
PRIVATE
14+
Base
15+
Graph
16+
Importer
17+
LLVMSupport
18+
Support)
19+
target_link_libraries(Exporter PUBLIC onnx_proto ${PROTOBUF_LIBRARY})
20+
21+
if (MSVC AND LINK_PROTOBUF_AS_DLL)
22+
# For protobuf warning when it is build as dll.
23+
# Suppresses a warning that is treated as error.
24+
# Basically one of the header files has interface class
25+
# containing STL string. Which might cause issues
26+
# if things are build with different compilers.
27+
#
28+
# Sets general warning level as 2 for this project.
29+
# There are few warnings that are treated as errors that
30+
# come from VS include headers
31+
target_compile_options(onnx_proto PUBLIC /wd4251)
32+
target_compile_options(onnx_proto PUBLIC /W2)
33+
target_compile_definitions(onnx_proto PUBLIC -DPROTOBUF_USE_DLLS)
34+
endif()

0 commit comments

Comments
 (0)