forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_custom_class.cpp
140 lines (116 loc) · 4.3 KB
/
test_custom_class.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#include <gtest/gtest.h>
#include <test/cpp/jit/test_custom_class_registrations.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/custom_class.h>
#include <torch/script.h>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
namespace torch {
namespace jit {
// Passes torchbind custom-class objects into a scripted method through the
// IValue API, built two ways: with make_custom_class and by wrapping a
// c10::intrusive_ptr in an IValue. The scripted `forward` returns
// `(s.pop(), s)`; the test checks the popped string and that the returned
// object is the very same instance that was passed in.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(CustomClassTest, TorchbindIValueAPI) {
  script::Module m("m");

  // test make_custom_class API
  auto custom_class_obj = make_custom_class<MyStackClass<std::string>>(
      std::vector<std::string>{"foo", "bar"});

  // TorchScript is indentation sensitive: the `return` has to be nested
  // under `def` for the method body to be valid.
  m.define(R"(
    def forward(self, s : __torch__.torch.classes._TorchScriptTesting._StackString):
      return s.pop(), s
  )");

  // Runs `forward` on `obj` and verifies the (popped string, object) tuple.
  // Arguments are taken by const reference; run_method makes its own IValue
  // copy, so no additional copies are needed here.
  auto test_with_obj = [&m](const IValue& obj, const std::string& expected) {
    auto res = m.run_method("forward", obj);
    auto tup = res.toTuple();
    AT_ASSERT(tup->elements().size() == 2);
    // `tup` keeps the element alive, so binding by reference is safe.
    const auto& str = tup->elements()[0].toStringRef();
    auto other_obj =
        tup->elements()[1].toCustomClass<MyStackClass<std::string>>();
    AT_ASSERT(str == expected);
    // Pointer identity: the script must hand back the same object.
    auto ref_obj = obj.toCustomClass<MyStackClass<std::string>>();
    AT_ASSERT(other_obj.get() == ref_obj.get());
  };

  test_with_obj(custom_class_obj, "bar");

  // test IValue() API
  auto my_new_stack = c10::make_intrusive<MyStackClass<std::string>>(
      std::vector<std::string>{"baz", "boo"});
  auto new_stack_ivalue = c10::IValue(my_new_stack);

  test_with_obj(new_stack_ivalue, "boo");
}
class TorchBindTestClass : public torch::jit::CustomClassHolder {
public:
std::string get() {
return "Hello, I am your test custom class";
}
};
// Class-level doc string passed to torch::class_ when TorchBindTestClass is
// registered; TestDocString compares the registered class doc against it.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr char class_doc_string[] = R"doc(
I am docstring for TorchBindTestClass
Args:
What is an argument? Oh never mind, I don't take any.
Return:
How would I know? I am just a holder of some meaningless test methods.
)doc";
// Doc string attached to the "get_with_docstring" method registration; the
// two adjacent literals are concatenated by the compiler into one string.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr char method_doc_string[] =
    "I am docstring for TorchBindTestClass "
    "get_with_docstring method";
namespace {
// Registers TorchBindTestClass as "_TorchBindTestClass" in the
// "_TorchBindTest" namespace, with a class-level doc string. `get` is bound
// twice: once without a doc string and once with method_doc_string.
// The anonymous namespace already gives `reg` internal linkage.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
auto reg = torch::class_<TorchBindTestClass>(
               "_TorchBindTest",
               "_TorchBindTestClass",
               class_doc_string)
               .def("get", &TorchBindTestClass::get)
               .def(
                   "get_with_docstring",
                   &TorchBindTestClass::get,
                   method_doc_string);
} // namespace
// Verifies that doc strings supplied at registration time can be read back
// from the registered class type: the class-level doc, a method registered
// without a doc (expected to be empty), and a method registered with one.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(CustomClassTest, TestDocString) {
  auto registered = getCustomClass(
      "__torch__.torch.classes._TorchBindTest._TorchBindTestClass");
  AT_ASSERT(registered);

  // Class-level documentation.
  const auto& class_doc = registered->doc_string();
  AT_ASSERT(class_doc == class_doc_string);

  // "get" was registered without a doc string.
  const auto& plain_method_doc = registered->getMethod("get").doc_string();
  AT_ASSERT(plain_method_doc.empty());

  // "get_with_docstring" carries method_doc_string.
  const auto& documented_method_doc =
      registered->getMethod("get_with_docstring").doc_string();
  AT_ASSERT(documented_method_doc == method_doc_string);
}
// Saves and reloads a module holding a custom-class attribute, both as-is
// and after freezing, and checks that `forward` yields the same result
// before and after the round trip.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(CustomClassTest, Serialization) {
  script::Module m("m");

  // test make_custom_class API
  auto custom_class_obj = make_custom_class<MyStackClass<std::string>>(
      std::vector<std::string>{"foo", "bar"});
  m.register_attribute(
      "s",
      custom_class_obj.type(),
      custom_class_obj,
      // NOLINTNEXTLINE(bugprone-argument-comment)
      /*is_parameter=*/false);

  // TorchScript is indentation sensitive: the `return` has to be nested
  // under `def` for the method body to be valid.
  m.define(R"(
    def forward(self):
      return self.s.return_a_tuple()
  )");

  // Runs `forward` and checks that it returns a 2-tuple whose second
  // element is the integer 123.
  auto test_with_obj = [](script::Module& mod) {
    auto res = mod.run_method("forward");
    auto tup = res.toTuple();
    AT_ASSERT(tup->elements().size() == 2);
    auto i = tup->elements()[1].toInt();
    AT_ASSERT(i == 123);
  };

  auto frozen_m = torch::jit::freeze_module(m.clone());

  test_with_obj(m);
  test_with_obj(frozen_m);

  // Round-trip the original module through an in-memory stream and make
  // sure the reloaded copy still behaves the same.
  std::ostringstream oss;
  m.save(oss);
  std::istringstream iss(oss.str());
  auto loaded_module = torch::jit::load(iss, torch::kCPU);
  test_with_obj(loaded_module);

  // Same round trip for the frozen module.
  std::ostringstream oss_frozen;
  frozen_m.save(oss_frozen);
  std::istringstream iss_frozen(oss_frozen.str());
  auto loaded_frozen_module = torch::jit::load(iss_frozen, torch::kCPU);
  test_with_obj(loaded_frozen_module);
}
} // namespace jit
} // namespace torch