forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_utils.h
52 lines (39 loc) · 1.54 KB
/
test_utils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#pragma once
#include <torch/csrc/jit/testing/file_check.h>
#include "test/cpp/jit/test_base.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/runtime/autodiff.h"
#include "torch/csrc/jit/runtime/interpreter.h"
namespace torch {
namespace jit {
using tensor_list = std::vector<at::Tensor>;
using namespace torch::autograd;
// work around the fact that variable_tensor_list doesn't duplicate all
// of std::vector's constructors.
// most constructors are never used in the implementation, just in our tests.
Stack createStack(std::vector<at::Tensor>&& list);
void assertAllClose(const tensor_list& a, const tensor_list& b);
std::vector<at::Tensor> run(
InterpreterState& interp,
const std::vector<at::Tensor>& inputs);
std::pair<tensor_list, tensor_list> runGradient(
Gradient& grad_spec,
tensor_list& tensors_in,
tensor_list& tensor_grads_in);
std::shared_ptr<Graph> build_lstm();
at::Tensor t_use(at::Tensor x);
at::Tensor t_def(at::Tensor x);
// given the difference of output vs expected tensor, check whether the
// difference is within a relative tolerance range. This is a standard way of
// matching tensor values up to certain precision
bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs);
bool almostEqual(const at::Tensor& a, const at::Tensor& b);
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
std::pair<at::Tensor, at::Tensor> lstm(
at::Tensor input,
at::Tensor hx,
at::Tensor cx,
at::Tensor w_ih,
at::Tensor w_hh);
} // namespace jit
} // namespace torch