forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathNamedTensor.h
140 lines (112 loc) · 4.93 KB
/
NamedTensor.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#pragma once

#include <ATen/core/Dimname.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/C++17.h>

#include <algorithm>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>
namespace at {
class TensorBase;
// XXX: This file exists because TensorImpl is in c10, but Dimname is in ATen.
// Due to the c10/ATen library split, TensorImpl cannot depend on Dimname,
// so we have a couple of workarounds.
//
// In the long term, we'll move Dimname to c10 and everything in this file
// can be refactored out. The main blocker for that is that "c10::Symbol"
// actually exists outside of c10 and needs to be moved in.
// TensorImpl has a unique_ptr<NamedTensorMetaInterface> field; NamedTensorMeta
// is the ATen-side implementation of that interface.
// XXX: Ideally we would just put optional<vector<Dimname>> into TensorImpl.
//
// This class has an important invariant: there must be at least ONE
// non-wildcard name in names_. Every constructor and setter takes a
// HAS_NON_WILDCARD tag so the caller explicitly acknowledges the invariant,
// and check_invariants() verifies it (in debug builds only).
struct TORCH_API NamedTensorMeta final : public c10::NamedTensorMetaInterface {
  // This enum is to remind people that the invariant on constructors is that
  // the list of dimnames must have at least one non-wildcard.
  enum HAS_NON_WILDCARD {
    HasNonWildcard
  };

  // Copies `names` into owned storage.
  explicit NamedTensorMeta(HAS_NON_WILDCARD, DimnameList names)
    : names_(names.vec()) {
    check_invariants();
  }

  // Takes ownership of `names` without copying the elements.
  explicit NamedTensorMeta(HAS_NON_WILDCARD, std::vector<Dimname>&& names)
    : names_(std::move(names)) {
    check_invariants();
  }

  // Deep copy. names_ is an lvalue, so it binds to the DimnameList
  // constructor and the element storage is copied.
  std::unique_ptr<c10::NamedTensorMetaInterface> clone() const override {
    return std::make_unique<NamedTensorMeta>(HasNonWildcard, names_);
  }

  // Non-owning view of the stored names. The view points into names_'s
  // buffer. It is invalidated by set_names(HAS_NON_WILDCARD,
  // std::vector<Dimname>&&), which replaces the buffer, and by destruction of
  // this object. The DimnameList overload of set_names writes in place and
  // keeps existing views valid.
  DimnameList names() const { return names_; }

  // Used for an assertion in TensorImpl.h.
  int64_t slow_dim() const override {
    // size() is size_t; the interface returns int64_t, so convert explicitly.
    return static_cast<int64_t>(names_.size());
  }

  // Asserts that at least one name is not a wildcard. This is a debug-only
  // check (TORCH_INTERNAL_ASSERT_DEBUG_ONLY).
  void check_invariants() const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      std::any_of(names_.begin(), names_.end(), [](const Dimname& n) { return !n.isWildcard(); }));
  }

  // Overwrites the names element-by-element. `new_names` must have exactly
  // as many entries as the current names (always asserted, not debug-only).
  void set_names(HAS_NON_WILDCARD, DimnameList new_names) {
    TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
    std::copy(new_names.begin(), new_names.end(), names_.begin());
    check_invariants();
  }

  // Replaces the names by taking ownership of `new_names`. The sizes must
  // match (always asserted). This replaces names_'s buffer, so any
  // DimnameList previously returned by names() is invalidated.
  void set_names(HAS_NON_WILDCARD, std::vector<Dimname>&& new_names) {
    TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
    names_ = std::move(new_names);
    check_invariants();
  }

  // INVARIANT: at least one Dimname is non-WILDCARD
  std::vector<Dimname> names_;
};
// Switch that controls whether operations honour tensor names.
// While it is off, every operation ignores the names field of its tensors,
// i.e. each tensor behaves as if its names were nullopt.
struct TORCH_API NamesMode {
  // Turns name handling on (true) or off (false).
  static void set_enabled(bool enabled);
  // Reports whether name handling is currently on.
  static bool is_enabled();
};
// A RAII, thread local (!) guard that enables or disables names upon
// construction, and sets it back to the original value upon destruction.
struct TORCH_API NoNamesGuard {
NoNamesGuard() : prev_mode(NamesMode::is_enabled()), initialized(true) {
NamesMode::set_enabled(false);
}
~NoNamesGuard() {
if (initialized) {
reset();
}
}
void reset() {
TORCH_INTERNAL_ASSERT(initialized);
NamesMode::set_enabled(prev_mode);
}
private:
bool prev_mode;
bool initialized;
};
// Limit on the dimensionality of named tensors (presumably enforced by the
// helpers below — confirm in NamedTensor.cpp).
constexpr size_t kMaxNamedTensorDim = 64;

// Checks `names` against a tensor, either given directly or described only by
// its number of dimensions. See the definitions for the exact rules.
void check_names_valid_for(const TensorBase& tensor, DimnameList names);
void check_names_valid_for(size_t tensor_dim, DimnameList names);

// Replaces the names of `tensor` with `names` in place and returns `tensor`.
// The second overload consumes a vector and lets the caller choose whether
// the names are validated first.
TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, c10::optional<DimnameList> names);
TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::vector<Dimname>&& names, bool validate_names);

// Returns a DimnameList of length `len`. Presumably all wildcard names —
// confirm in the definition.
DimnameList default_names(size_t len);
namespace impl {
// Free-function helpers that operate directly on a TensorImpl; originally
// added for working with names in TH.
// XXX: Ideally these would exist as methods on TensorImpl.

// Checks `names` against the given TensorImpl.
void check_names_valid_for(TensorImpl* impl, DimnameList names);

// Replace the names stored on `impl` in place. `validate_names` presumably
// toggles a validity check before the names are stored — confirm in
// NamedTensor.cpp.
TORCH_API void internal_set_names_inplace(TensorImpl* impl, c10::optional<DimnameList> names, bool validate_names);
TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::vector<Dimname>&& names, bool validate_names);

// True exactly when the tensor's names have been allocated and at least one
// of them is not 'None'. A tensor whose names were never allocated is treated
// like one whose names are all 'None', so both cases give false.
TORCH_API bool has_names(const TensorImpl* impl);

// Names of every dimension of the tensor. An unnamed tensor is reported as
// having 'None' in all dimensions, so the result is a list of 'None's rather
// than an absent value.
TORCH_API DimnameList get_names(const TensorImpl* impl);

// Lower-level accessor and mostly an implementation detail; prefer
// impl::get_names / Tensor::names() whenever possible for a cleaner API.
// Yields the tensor's names if they have been allocated, or nullopt if they
// haven't. A tensor constructed with names=None has no allocated names.
TORCH_API c10::optional<DimnameList> get_opt_names(const TensorImpl* impl);
} // namespace impl
} // namespace at