|
| 1 | +/** |
| 2 | + * Copyright (c) 2017-present, Facebook, Inc. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +#ifndef GLOW_EXPORTER_ONNXMODELWRITER_H |
| 18 | +#define GLOW_EXPORTER_ONNXMODELWRITER_H |
| 19 | + |
| 20 | +#include "glow/Exporter/CommonOperatorWriter.h" |
| 21 | +#include "glow/Graph/Graph.h" |
| 22 | + |
| 23 | +#include "onnx/onnx_pb.h" |
| 24 | + |
| 25 | +#include "llvm/ADT/ArrayRef.h" |
| 26 | +#include "llvm/ADT/StringRef.h" |
| 27 | + |
| 28 | +#include <string> |
| 29 | + |
| 30 | +/// ONNX traits for protobuf types. |
| 31 | +struct ONNX_TRAITS { |
| 32 | + using GraphProto = ONNX_NAMESPACE::GraphProto; |
| 33 | +}; |
| 34 | + |
| 35 | +namespace glow { |
| 36 | + |
| 37 | +/// Unique set of visited nodes. |
| 38 | +using ReportedNodes = std::unordered_set<const Node *>; |
| 39 | + |
| 40 | +/// Writes ONNX models. |
| 41 | +class ONNXModelWriter : public CommonOperatorWriter<ONNX_TRAITS> { |
| 42 | + // Declare shorter aliases. |
| 43 | + using GraphType = typename ONNX_TRAITS::GraphProto; |
| 44 | + using NodeType = ONNX_NAMESPACE::NodeProto; |
| 45 | + using TensorType = ONNX_NAMESPACE::TensorProto; |
| 46 | + using AttrType = ONNX_NAMESPACE::AttributeProto; |
| 47 | + using ValueInfoType = ONNX_NAMESPACE::ValueInfoProto; |
| 48 | + |
| 49 | + /// Current version of ONNX standard. |
| 50 | + size_t opsetVersion_; |
| 51 | + /// Keeps the track of already visited or processed nodes. |
| 52 | + ReportedNodes reportedNodes_; |
| 53 | + /// Converts \p glowType to \p protoType. |
| 54 | + static typename TensorType::DataType convertType(const Type &glowType); |
| 55 | + /// Writes Glow tensor \p T to proto output \p out. |
| 56 | + static void writeTensor(const Tensor &T, TensorType *out); |
| 57 | + /// Writes tensor shape from placeholder \p PH into protpbuf \p valueProto. |
| 58 | + static void tensorShapeFromPlaceholder(const Placeholder *PH, |
| 59 | + ValueInfoType *valueProto); |
| 60 | + /// Writes all inputs and outputs with operator name \p opName from give Node |
| 61 | + /// \p node into protobuf \p proto. |
| 62 | + static llvm::Error writeAllWithNode(const std::string &opName, |
| 63 | + const Node *node, NodeType *proto); |
| 64 | + /// Writes all inputs and outputs with operator name \p opName from give Node |
| 65 | + /// \p node into created node protobuf using \p graph. |
| 66 | + static llvm::Error writeAll(const std::string &opName, const Node *node, |
| 67 | + GraphType &graph); |
| 68 | + // Finds if uses of \p node have node with the provided \p kind. |
| 69 | + static bool hasUsesOfKind(const Node *node, Kinded::Kind kind); |
| 70 | + |
| 71 | +public: |
| 72 | + /// Creates an ONNX model writer to serialize \p F graph into file |
| 73 | + /// \p modelFilename, writing \p irVersion and \p opsetVersion. |
| 74 | + /// If \p errPtr is not null then if an error occurs it will get assigned |
| 75 | + /// there otherwise if an error occurs it will abort. |
| 76 | + ONNXModelWriter(const std::string &modelFilename, Function &F, |
| 77 | + size_t irVersion, size_t opsetVersion, |
| 78 | + llvm::Error *errPtr = nullptr, bool textMode = false); |
| 79 | + |
| 80 | +private: |
| 81 | + /// \returns error for the unexpected node kind. |
| 82 | + static llvm::Error writeUnexpectedKind(const Node *node) { |
| 83 | + RETURN_ERR(strFormat("Glow can not export node %s, unsupported kind: %d.", |
| 84 | + node->getName().str().c_str(), node->getKind())); |
| 85 | + } |
| 86 | + |
| 87 | + /// Declares the overriden all pure virtual methods, declared in base class. |
| 88 | +#define DEF_NODE(CLASS, NAME) \ |
| 89 | + llvm::Error write##NAME(const CLASS *, GraphType &) override; |
| 90 | +#include "glow/AutoGenNodes.def" |
| 91 | +}; |
| 92 | + |
| 93 | +} // namespace glow |
| 94 | + |
| 95 | +#endif // GLOW_EXPORTER_ONNXMODELWRITER_H |
0 commit comments