forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
EmbeddingBag.h
49 lines (43 loc) · 1.29 KB
/
EmbeddingBag.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include <ATen/ATen.h>
namespace at {
namespace native {
void check_arguments(
const Tensor& weight,
const Tensor& indices,
const Tensor& offsets,
const int64_t mode,
const c10::optional<Tensor>& per_sample_weights,
bool include_last_offset);
void make_bag_size_out(
Tensor& bag_size_out,
const Tensor& offsets,
const Tensor& indices,
const int64_t mode,
const bool include_last_offset,
const bool requires_grad);
void make_max_indices_out(
Tensor& max_indices_out,
const Tensor& weight,
const Tensor& indices,
const Tensor& offsets,
const Tensor& bag_size,
const int64_t mode,
bool include_last_offset);
void make_offset2bag_out(
Tensor& offset2bag,
Tensor& output,
const Tensor& weight,
const Tensor& indices,
const Tensor& offsets,
const int64_t mode,
const c10::optional<Tensor>& per_sample_weights,
const int64_t padding_idx = -1);
void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
Tensor& bag_size, Tensor& max_indices,
const Tensor &weight, const Tensor &indices,
const Tensor &offsets, const int64_t mode = 0,
const c10::optional<Tensor>& per_sample_weights = c10::nullopt,
bool include_last_offset = false,
int64_t padding_idx = -1);
} // native
} // at