forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_save_load.cpp
158 lines (144 loc) · 5.33 KB
/
test_save_load.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#include <gtest/gtest.h>
#include <test/cpp/jit/test_utils.h>
#include <sstream>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/torch.h>
#include "caffe2/serialize/istream_adapter.h"
namespace torch {
namespace jit {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(SerializationTest, ExtraFilesHookPreference) {
  // Tests that an extra file written explicitly has precedence over
  // extra files written by a hook
  // TODO: test for the warning, too
  const auto script = R"JIT(
    def forward(self):
        x = torch.rand(5, 5)
        x = x.mm(x)
        return x
  )JIT";

  auto module =
      std::make_shared<Module>("Module", std::make_shared<CompilationUnit>());
  module->define(script);

  std::ostringstream oss;
  std::unordered_map<std::string, std::string> extra_files;
  extra_files["metadata.json"] = "abc";
  {
    // The export hook stays installed until it is explicitly cleared. Clear it
    // when this scope exits -- even if save() throws -- so it cannot leak into
    // other tests.
    struct HookReset {
      ~HookReset() {
        SetExportModuleExtraFilesHook(nullptr);
      }
    } hook_reset;
    // The hook tries to write a conflicting "metadata.json"; the explicit
    // value passed to save() must win.
    SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
      return {{"metadata.json", "def"}};
    });
    module->save(oss, extra_files);
  }

  // Read the archive back, requesting only "metadata.json".
  std::istringstream iss(oss.str());
  std::unordered_map<std::string, std::string> loaded_extra_files;
  loaded_extra_files["metadata.json"] = "";
  torch::jit::load(iss, torch::kCPU, loaded_extra_files);
  ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(SerializationTest, ExtraFileHooksNoSecret) {
  // This test installs no export hook, so only the explicitly supplied extra
  // file should be in the archive and "secret.json" must come back empty.
  std::stringstream buffer;

  // Save phase: an empty module plus one explicit extra file.
  {
    ExtraFilesMap to_write{{"metadata.json", "abc"}};
    Module saved("__torch__.m");
    saved.save(buffer, to_write);
  }

  // Rewind so load() reads the archive from its first byte.
  buffer.seekg(0);

  // Load phase: ask for both files, each pre-seeded with an empty value.
  ExtraFilesMap to_read{{"metadata.json", ""}, {"secret.json", ""}};
  jit::load(buffer, c10::nullopt, to_read);
  ASSERT_EQ(to_read["metadata.json"], "abc");
  ASSERT_EQ(to_read["secret.json"], "");
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(SerializationTest, ExtraFileHooksWithSecret) {
  // Files returned by an export hook are saved alongside the explicitly
  // supplied extra files and can be read back on load.
  std::stringstream ss;
  {
    // The export hook stays installed until it is explicitly cleared. Clear it
    // when this scope exits -- even if save() throws -- so "secret.json" cannot
    // leak into archives saved by other tests.
    struct HookReset {
      ~HookReset() {
        SetExportModuleExtraFilesHook(nullptr);
      }
    } hook_reset;
    SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
      return {{"secret.json", "topsecret"}};
    });
    Module m("__torch__.m");
    ExtraFilesMap extra;
    extra["metadata.json"] = "abc";
    m.save(ss, extra);
  }
  // Rewind so load() reads the archive from its first byte.
  ss.seekg(0);
  {
    ExtraFilesMap extra;
    extra["metadata.json"] = "";
    extra["secret.json"] = "";
    jit::load(ss, c10::nullopt, extra);
    ASSERT_EQ(extra["metadata.json"], "abc");
    ASSERT_EQ(extra["secret.json"], "topsecret");
  }
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(SerializationTest, TypeTags) {
  // Pickles each value and checks that the unpickled IValue's type is
  // equivalent to the expected type (each is a subtype of the other), i.e.
  // the type information survives the pickle_save/pickle_load round trip.
  auto list = c10::List<c10::List<int64_t>>();
  list.push_back(c10::List<int64_t>({1, 2, 3}));
  list.push_back(c10::List<int64_t>({4, 5, 6}));
  auto dict = c10::Dict<std::string, at::Tensor>();
  dict.insert("Hello", torch::ones({2, 2}));
  auto dict_list = c10::List<c10::Dict<std::string, at::Tensor>>();
  for (size_t i = 0; i < 5; ++i) {
    auto another_dict = c10::Dict<std::string, at::Tensor>();
    another_dict.insert("Hello" + std::to_string(i), torch::ones({2, 2}));
    dict_list.push_back(another_dict);
  }
  auto tuple = std::tuple<int, std::string>(2, "hi");
  struct TestItem {
    IValue value;
    TypePtr expected_type;
  };
  std::vector<TestItem> items = {
      {list, ListType::create(ListType::create(IntType::get()))},
      {2, IntType::get()},
      {dict, DictType::create(StringType::get(), TensorType::get())},
      {dict_list,
       ListType::create(
           DictType::create(StringType::get(), TensorType::get()))},
      {tuple, TupleType::create({IntType::get(), StringType::get()})}};
  // Iterate by const reference: nothing is mutated, so there is no reason to
  // copy each IValue/TypePtr pair.
  for (const auto& item : items) {
    // Tag assertion failures with the type under test.
    SCOPED_TRACE(item.expected_type->str());
    auto bytes = torch::pickle_save(item.value);
    auto loaded = torch::pickle_load(bytes);
    ASSERT_TRUE(loaded.type()->isSubtypeOf(item.expected_type));
    ASSERT_TRUE(item.expected_type->isSubtypeOf(loaded.type()));
  }
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(SerializationTest, TestJitStream_CUDA) {
  // Deserialize the ScriptModule from a file using torch::jit::load().
  // Load the scripted model. This should have been generated by tests_setup.py
  // Refer: TorchSaveJitStream_CUDA in test/cpp/jit/tests_setup.py
  torch::jit::Module model = torch::jit::load("saved_stream_model.pt");
  std::vector<torch::jit::IValue> inputs;
  auto output = model.forward(inputs);
  auto list_of_elements = output.toTuple()->elements();
  // The saved model is expected to return (is_stream_set, a, b, c). Fail
  // cleanly instead of indexing out of bounds if it returns anything else.
  ASSERT_EQ(list_of_elements.size(), 4u);
  auto is_stream_s = list_of_elements[0].toBool();
  // a, b: the two input tensors
  // c: the model's result of concatenating a and b with torch.cat
  auto a = list_of_elements[1].toTensor();
  auto b = list_of_elements[2].toTensor();
  auto c = list_of_elements[3].toTensor();
  // op: the concatenation recomputed here, used to verify the model's result
  auto op = at::cat({a, b}, 0);
  // Check that the stream was set inside the model
  ASSERT_TRUE(is_stream_s);
  // The recomputed and loaded outputs must agree in shape...
  ASSERT_EQ(op.sizes(), c.sizes());
  // ...and in every element.
  ASSERT_TRUE(op.equal(c));
}
} // namespace jit
} // namespace torch