Skip to content

Commit f269fb8

Browse files
gmagogsfmfacebook-github-bot
authored andcommitted
Add Enum TorchScript serialization and deserialization support (pytorch#42963)
Summary: Pull Request resolved: pytorch#42963 * Adds code printing for enum type * Enhance enum type to include all contained enum names and values * Adds code parsing for enum type in deserialization * Enabled serialization/deserialization test in most TestCases. (With a few dangling issues to be addressed in later PRs to avoid this PR grows too large) Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D23223281 Pulled By: gmagogsfm fbshipit-source-id: 716d1866b7770dfb7bd8515548cfe7dc4c4585f7
1 parent aa53b2d commit f269fb8

File tree

11 files changed

+240
-180
lines changed

11 files changed

+240
-180
lines changed

aten/src/ATen/core/jit_type.h

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,18 +1131,19 @@ struct CAFFE2_API TupleType : public NamedType {
11311131

11321132
struct EnumType;
11331133
using EnumTypePtr = std::shared_ptr<EnumType>;
1134+
using EnumNameValue = std::pair<std::string, IValue>;
11341135
struct CAFFE2_API EnumType : public NamedType {
11351136
friend struct Type;
11361137
static const TypeKind Kind = TypeKind::EnumType;
11371138

11381139
static EnumTypePtr create(
11391140
const c10::QualifiedName& qualified_class_name,
1140-
TypePtr value, std::weak_ptr<::torch::jit::CompilationUnit> cu) {
1141+
TypePtr value, std::vector<EnumNameValue> enum_names_values, std::weak_ptr<::torch::jit::CompilationUnit> cu) {
11411142
switch (value->kind()) {
11421143
case TypeKind::IntType:
11431144
case TypeKind::FloatType:
11441145
case TypeKind::StringType:
1145-
return EnumTypePtr(new EnumType(qualified_class_name, std::move(value), std::move(cu)));
1146+
return EnumTypePtr(new EnumType(qualified_class_name, std::move(value), std::move(enum_names_values), std::move(cu)));
11461147
default:
11471148
AT_ERROR(
11481149
"Cannot create Enum with value type '",
@@ -1183,17 +1184,32 @@ struct CAFFE2_API EnumType : public NamedType {
11831184
return name().value();
11841185
}
11851186

1187+
at::ArrayRef<TypePtr> containedTypes() const override {
1188+
return value_type_;
1189+
}
1190+
1191+
const at::ArrayRef<EnumNameValue> enumNamesValues() const {
1192+
return enum_names_values_;
1193+
}
1194+
11861195
private:
1187-
EnumType(c10::QualifiedName qualified_class_name, TypePtr value_type, std::weak_ptr<torch::jit::CompilationUnit> cu)
1196+
EnumType(
1197+
c10::QualifiedName qualified_class_name,
1198+
TypePtr value_type,
1199+
std::vector<EnumNameValue> enum_names_values,
1200+
std::weak_ptr<torch::jit::CompilationUnit> cu)
11881201
: NamedType(TypeKind::EnumType, std::move(qualified_class_name)),
1189-
value_type_(std::move(value_type)), cu_(cu) {}
1202+
value_type_(std::move(value_type)),
1203+
enum_names_values_(std::move(enum_names_values)),
1204+
cu_(cu) {}
11901205

11911206
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
11921207
const auto& n = name().value();
11931208
return n.qualifiedName();
11941209
}
11951210

11961211
TypePtr value_type_;
1212+
std::vector<EnumNameValue> enum_names_values_;
11971213
std::weak_ptr<::torch::jit::CompilationUnit> cu_;
11981214
};
11991215

aten/src/ATen/test/ivalue_test.cpp

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -261,36 +261,48 @@ TEST(IValueTest, ListNestedEquality) {
261261

262262
TEST(IValueTest, EnumEquality) {
263263
auto cu = std::make_shared<CompilationUnit>();
264-
auto int_enum_type1 = EnumType::create("enum_class_1", IntType::get(), cu);
265-
auto int_enum_type2 = EnumType::create("enum_class_2", IntType::get(), cu);
266-
auto string_enum_type = EnumType::create("enum_class_3", StringType::get(), cu);
264+
IValue int_ivalue_1(1);
265+
IValue int_ivalue_2(2);
266+
IValue str_ivalue_1("1");
267+
auto int_enum_type1 = EnumType::create(
268+
"enum_class_1",
269+
IntType::get(),
270+
{{"enum_name_1", int_ivalue_1}, {"enum_name_2", int_ivalue_2}},
271+
cu);
272+
auto int_enum_type2 = EnumType::create(
273+
"enum_class_2",
274+
IntType::get(),
275+
{{"enum_name_1", int_ivalue_1}, {"enum_name_2", int_ivalue_2}},
276+
cu);
277+
auto string_enum_type = EnumType::create(
278+
"enum_class_3", StringType::get(), {{"enum_name_1", str_ivalue_1}}, cu);
267279

268280
EXPECT_EQ(
269281
IValue(c10::make_intrusive<ivalue::EnumHolder>(
270-
int_enum_type1, "enum_name_1", IValue(1))),
282+
int_enum_type1, "enum_name_1", int_ivalue_1)),
271283
IValue(c10::make_intrusive<ivalue::EnumHolder>(
272-
int_enum_type1, "enum_name_1", IValue(1)))
284+
int_enum_type1, "enum_name_1", int_ivalue_1))
273285
);
274286

275287
EXPECT_NE(
276288
IValue(c10::make_intrusive<ivalue::EnumHolder>(
277-
int_enum_type1, "enum_name_1", IValue(1))),
289+
int_enum_type1, "enum_name_1", int_ivalue_1)),
278290
IValue(c10::make_intrusive<ivalue::EnumHolder>(
279-
int_enum_type2, "enum_name_1", IValue(1)))
291+
int_enum_type2, "enum_name_1", int_ivalue_1))
280292
);
281293

282294
EXPECT_NE(
283295
IValue(c10::make_intrusive<ivalue::EnumHolder>(
284-
int_enum_type1, "enum_name_1", IValue(1))),
296+
int_enum_type1, "enum_name_1", int_ivalue_1)),
285297
IValue(c10::make_intrusive<ivalue::EnumHolder>(
286-
int_enum_type1, "enum_name_2", IValue(1)))
298+
int_enum_type1, "enum_name_2", int_ivalue_2))
287299
);
288300

289301
EXPECT_NE(
290302
IValue(c10::make_intrusive<ivalue::EnumHolder>(
291-
int_enum_type1, "enum_name_1", IValue(1))),
303+
int_enum_type1, "enum_name_1", int_ivalue_1)),
292304
IValue(c10::make_intrusive<ivalue::EnumHolder>(
293-
string_enum_type, "enum_name_1", IValue("1")))
305+
string_enum_type, "enum_name_1", str_ivalue_1))
294306
);
295307
}
296308

0 commit comments

Comments
 (0)