Skip to content

Commit 4279b42

Browse files
jackm321facebook-github-bot
authored andcommitted
Create a union type GlowIValue for loading values from PyTorch (pytorch#3479)
Summary: * Add a type which maps closely to torch::jit::IValue but holds Glow tensors. This is a departure from the previous approach which is to just cast everything as Glow Constants. * Separate values loaded from PyTorch into two categories, Tensors for which only shapes are known at load time, and all other values (ints, lists, etc) for which values can be reasoned about at load time. For each of these categories, the information known at load-time will eventually be guarded so that, for example if a tensor's shape changes or any other values's contents change, the graph will not be valid and reloading will happen * This also adds most of the necessary ingredients to make loading quantized graphs more simple Documentation: doxygen Will add some information to pytorch.md about how this works once it's settled Pull Request resolved: pytorch#3479 Test Plan: CI Differential Revision: D17184105 Pulled By: jackm321 fbshipit-source-id: fad7fbd67e904d5f4ce40347fc7dfd377764acf8
1 parent 3b005bb commit 4279b42

File tree

15 files changed

+1435
-821
lines changed

15 files changed

+1435
-821
lines changed

include/glow/Graph/Nodes.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,20 @@ class Constant : public Storage {
9191
return k->getKind() == Kinded::Kind::ConstantKind;
9292
}
9393

94-
/// \returns a mutable reference to the payload tensor. If the payload tensor
95-
/// is unowned then it will be converted to an owned copy before returning.
96-
Tensor &getPayloadMutable() {
97-
// If payload is unowned, make an owned copy of the payload for
98-
// modification.
94+
/// If payload is unowned, make an owned copy of the payload for
95+
/// modification.
96+
void ensureIsOwned() {
9997
if (payload_.isUnowned()) {
10098
payload_ = payload_.clone();
10199
}
100+
}
101+
102+
/// \returns a mutable reference to the payload tensor. If the payload tensor
103+
/// is unowned then it will be converted to an owned copy before returning.
104+
Tensor &getPayloadMutable() {
105+
/// Make sure the payload is owned before handing out a mutable reference.
106+
ensureIsOwned();
107+
102108
assert(!payload_.isUnowned() &&
103109
"Can only modify Constants with owned payloads");
104110
return payload_;

torch_glow/setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ test=pytest
33

44
[tool:pytest]
55
testpaths = tests
6-
addopts = --verbose
6+
addopts = --verbose

torch_glow/src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ link_directories(${PYTORCH_DIR}/lib)
1414
add_library(PyTorchModelLoader
1515
CachingGraphRunner.cpp
1616
GlowFuser.cpp
17+
GlowIValue.cpp
1718
PyTorchCommon.cpp
1819
FuseLinear.cpp
1920
PyTorchModelLoader.cpp)

torch_glow/src/CachingGraphRunner.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ llvm::Error CachingGraphRunner::runImpl(const PerGlowGraphInfo &info,
104104
for (auto size : ph->dims()) {
105105
sizes.push_back(static_cast<int64_t>(size));
106106
}
107-
auto ptT = at::empty(
108-
sizes, at::TensorOptions().dtype(
109-
PyTorchModelLoader::convertGlowType(ph->getType())));
107+
108+
auto ptT = glowTypeToEmptyPTTensor(*ph->getType());
109+
110110
glow::Tensor t(ptT.data_ptr(), ph->getType());
111111

112112
outputs.push_back(std::move(ptT));

torch_glow/src/GlowIValue.cpp

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "GlowIValue.h"
18+
19+
#include <ATen/core/ivalue.h>
20+
21+
#include "PyTorchCommon.h"
22+
23+
#include "glow/Support/Error.h"
24+
#include "glow/Support/Support.h"
25+
26+
namespace glow {
27+
28+
// static
29+
const char *GlowIValue::tagToStr(GlowIValue::Tag tag) {
30+
switch (tag) {
31+
case GlowIValue::Tag::None:
32+
return "None";
33+
case GlowIValue::Tag::Tensor:
34+
return "Tensor";
35+
case GlowIValue::Tag::Double:
36+
return "Double";
37+
case GlowIValue::Tag::Int:
38+
return "Int";
39+
case GlowIValue::Tag::Bool:
40+
return "Bool";
41+
case GlowIValue::Tag::IntList:
42+
return "IntList";
43+
case GlowIValue::Tag::DoubleList:
44+
return "DoubleList";
45+
case GlowIValue::Tag::BoolList:
46+
return "BoolList";
47+
case GlowIValue::Tag::Tuple:
48+
return "Tuple";
49+
}
50+
}
51+
52+
void GlowIValue::reset() {
53+
switch (tag_) {
54+
case Tag::Tensor:
55+
delete payload_.asTensor;
56+
break;
57+
case Tag::IntList:
58+
delete payload_.asIntList;
59+
break;
60+
case Tag::DoubleList:
61+
delete payload_.asDoubleList;
62+
break;
63+
case Tag::BoolList:
64+
delete payload_.asBoolList;
65+
break;
66+
case Tag::Tuple:
67+
delete payload_.asTuple;
68+
break;
69+
case Tag::None:
70+
case Tag::Double:
71+
case Tag::Int:
72+
case Tag::Bool:
73+
// Nothing to free.
74+
break;
75+
}
76+
tag_ = Tag::None;
77+
}
78+
79+
GlowIValue::~GlowIValue() { reset(); }
80+
81+
GlowIValue::GlowIValue(GlowIValue &&other) {
82+
std::swap(tag_, other.tag_);
83+
std::swap(payload_, other.payload_);
84+
}
85+
86+
GlowIValue &GlowIValue::operator=(GlowIValue &&other) {
87+
reset();
88+
std::swap(tag_, other.tag_);
89+
std::swap(payload_, other.payload_);
90+
return *this;
91+
}
92+
93+
GlowIValue::Tag GlowIValue::getTag() const { return tag_; }
94+
95+
const char *GlowIValue::getTagString() const { return tagToStr(tag_); }
96+
97+
bool GlowIValue::isNone() const { return Tag::None == tag_; }
98+
bool GlowIValue::isTensor() const { return Tag::Tensor == tag_; }
99+
bool GlowIValue::isDouble() const { return Tag::Double == tag_; }
100+
bool GlowIValue::isInt() const { return Tag::Int == tag_; }
101+
bool GlowIValue::isBool() const { return Tag::Bool == tag_; }
102+
bool GlowIValue::isIntList() const { return Tag::IntList == tag_; }
103+
bool GlowIValue::isDoubleList() const { return Tag::DoubleList == tag_; }
104+
bool GlowIValue::isBoolList() const { return Tag::BoolList == tag_; }
105+
bool GlowIValue::isTuple() const { return Tag::Tuple == tag_; }
106+
107+
#define ExpectTag(EXPECTED_TAG) \
108+
RETURN_ERR_IF_NOT(tag_ == (EXPECTED_TAG), \
109+
strFormat("Expected GlowIValue with tag %s but found %s", \
110+
tagToStr((EXPECTED_TAG)), tagToStr(tag_)))
111+
112+
llvm::Expected<Tensor *> GlowIValue::toTensor() {
113+
ExpectTag(Tag::Tensor);
114+
return payload_.asTensor;
115+
}
116+
117+
llvm::Expected<const Tensor *> GlowIValue::toTensor() const {
118+
ExpectTag(Tag::Tensor);
119+
return payload_.asTensor;
120+
}
121+
122+
llvm::Expected<double> GlowIValue::toDouble() const {
123+
ExpectTag(Tag::Double);
124+
return payload_.asDouble;
125+
}
126+
127+
llvm::Expected<int64_t> GlowIValue::toInt() const {
128+
ExpectTag(Tag::Int);
129+
return payload_.asInt;
130+
}
131+
132+
llvm::Expected<bool> GlowIValue::toBool() const {
133+
ExpectTag(Tag::Bool);
134+
return payload_.asBool;
135+
}
136+
137+
llvm::Expected<std::vector<int64_t> *> GlowIValue::toIntList() {
138+
ExpectTag(Tag::IntList);
139+
return payload_.asIntList;
140+
}
141+
142+
llvm::Expected<const std::vector<int64_t> *> GlowIValue::toIntList() const {
143+
ExpectTag(Tag::IntList);
144+
return payload_.asIntList;
145+
}
146+
147+
llvm::Expected<std::vector<double> *> GlowIValue::toDoubleList() {
148+
ExpectTag(Tag::DoubleList);
149+
return payload_.asDoubleList;
150+
}
151+
152+
llvm::Expected<const std::vector<double> *> GlowIValue::toDoubleList() const {
153+
ExpectTag(Tag::DoubleList);
154+
return payload_.asDoubleList;
155+
}
156+
157+
llvm::Expected<std::vector<bool> *> GlowIValue::toBoolList() {
158+
ExpectTag(Tag::BoolList);
159+
return payload_.asBoolList;
160+
}
161+
162+
llvm::Expected<const std::vector<bool> *> GlowIValue::toBoolList() const {
163+
ExpectTag(Tag::BoolList);
164+
return payload_.asBoolList;
165+
}
166+
167+
llvm::Expected<std::vector<GlowIValue> *> GlowIValue::toTuple() {
168+
ExpectTag(Tag::Tuple);
169+
return payload_.asTuple;
170+
}
171+
172+
llvm::Expected<const std::vector<GlowIValue> *> GlowIValue::toTuple() const {
173+
ExpectTag(Tag::Tuple);
174+
return payload_.asTuple;
175+
}
176+
177+
#undef ExpectTag
178+
179+
void GlowIValue::fromNone() {
180+
reset();
181+
tag_ = Tag::None;
182+
}
183+
184+
void GlowIValue::fromTensor(Tensor tensor) {
185+
reset();
186+
tag_ = Tag::Tensor;
187+
payload_.asTensor = new glow::Tensor(std::move(tensor));
188+
}
189+
190+
void GlowIValue::fromDouble(double d) {
191+
reset();
192+
tag_ = Tag::Double;
193+
payload_.asDouble = d;
194+
}
195+
196+
void GlowIValue::fromInt(int64_t i) {
197+
reset();
198+
tag_ = Tag::Int;
199+
payload_.asInt = i;
200+
}
201+
202+
void GlowIValue::fromBool(bool b) {
203+
reset();
204+
tag_ = Tag::Bool;
205+
payload_.asBool = b;
206+
}
207+
208+
void GlowIValue::fromIntList(std::vector<int64_t> intList) {
209+
reset();
210+
tag_ = Tag::IntList;
211+
payload_.asIntList = new std::vector<int64_t>;
212+
std::swap(intList, *payload_.asIntList);
213+
}
214+
215+
void GlowIValue::fromDoubleList(std::vector<double> doubleList) {
216+
reset();
217+
tag_ = Tag::DoubleList;
218+
payload_.asDoubleList = new std::vector<double>;
219+
std::swap(doubleList, *payload_.asDoubleList);
220+
}
221+
222+
void GlowIValue::fromBoolList(std::vector<bool> boolList) {
223+
reset();
224+
tag_ = Tag::BoolList;
225+
payload_.asBoolList = new std::vector<bool>;
226+
std::swap(boolList, *payload_.asBoolList);
227+
}
228+
229+
void GlowIValue::fromTuple(std::vector<GlowIValue> glowIValList) {
230+
reset();
231+
tag_ = Tag::Tuple;
232+
payload_.asTuple = new std::vector<GlowIValue>;
233+
std::swap(glowIValList, *payload_.asTuple);
234+
}
235+
236+
llvm::Error GlowIValue::fromIValue(const at::IValue &ival) {
237+
reset();
238+
if (ival.isNone()) {
239+
fromNone();
240+
} else if (ival.isTensor()) {
241+
glow::Tensor t = ptTensorToGlowTensor(ival.toTensor());
242+
fromTensor(std::move(t));
243+
} else if (ival.isDouble()) {
244+
fromDouble(ival.toDouble());
245+
} else if (ival.isInt()) {
246+
fromInt(ival.toInt());
247+
} else if (ival.isBool()) {
248+
fromBool(ival.toBool());
249+
} else if (ival.isDoubleList()) {
250+
const auto ivalDoubles = ival.toDoubleList();
251+
std::vector<double> doubles(ivalDoubles.begin(), ivalDoubles.end());
252+
fromDoubleList(std::move(doubles));
253+
} else if (ival.isIntList()) {
254+
const auto ivalInts = ival.toIntList();
255+
std::vector<int64_t> ints(ivalInts.begin(), ivalInts.end());
256+
fromIntList(std::move(ints));
257+
} else if (ival.isBoolList()) {
258+
const auto ivalBools = ival.toBoolList();
259+
std::vector<bool> bools(ivalBools.begin(), ivalBools.end());
260+
fromBoolList(std::move(bools));
261+
} else if (ival.isTuple()) {
262+
const auto ivalTuple = ival.toTuple();
263+
const auto &elems = ivalTuple->elements();
264+
std::vector<GlowIValue> tuple;
265+
for (const auto &elem : elems) {
266+
GlowIValue glowIVal;
267+
RETURN_IF_ERR(glowIVal.fromIValue(elem));
268+
tuple.push_back(std::move(glowIVal));
269+
}
270+
fromTuple(std::move(tuple));
271+
} else {
272+
RETURN_ERR("Encountered unhandled IValue type");
273+
}
274+
return llvm::Error::success();
275+
}
276+
277+
} // namespace glow

0 commit comments

Comments
 (0)