forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathIntegration.cpp
74 lines (64 loc) · 2.23 KB
/
Integration.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/DimVector.h>
#include <c10/util/Exception.h>
namespace at {
namespace native {
namespace {
// Trapezoid-rule estimate of the integral of a function y of x along `dim`,
// for samples (y_1, ..., y_n) whose consecutive points are separated by
// the distances (dx_1, ..., dx_{n-1}):
//
//     \sum_{i=1}^{n-1} dx_i * (y_i + y_{i+1}) / 2
//
// `dx` is multiplied elementwise with the pairwise sums of neighbouring
// samples, so it is expected to carry one spacing per adjacent pair along
// `dim` (or to broadcast against that shape). The result has `dim` reduced
// away by the sum.
//
// TODO: if we extend TensorIterator to accept 3 inputs,
// we can probably make this a bit more performant.
Tensor do_trapz(const Tensor& y, const Tensor& dx, int64_t dim) {
    // All samples but the last, plus all samples but the first, gives
    // y_i + y_{i+1} for every interval along `dim`.
    const Tensor pair_sums = y.slice(dim, 0, -1) + y.slice(dim, 1);
    // Weight every interval by its width before reducing.
    const Tensor weighted = pair_sums * dx;
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    return weighted.sum(dim) / 2.;
}
// Trapezoid-rule estimate along `dim` for uniformly spaced samples.
// With a constant spacing dx the general formula collapses to
//
//     dx * [(\sum_{i=1}^n y_i) - (y_1 + y_n) / 2]
//
// i.e. every sample counts fully except the two endpoints, which count half.
Tensor do_trapz(const Tensor& y, double dx, int64_t dim) {
    // Sum of every sample along `dim`.
    const Tensor total = y.sum(dim);
    // First and last sample along `dim`; these were over-counted by one half.
    const Tensor endpoints = y.select(dim, 0) + y.select(dim, -1);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    const Tensor half_endpoints = endpoints * (0.5);
    return (total - half_endpoints) * dx;
}
// Returns a zero-filled tensor created with y's options whose shape is
// y's shape with dimension `dim` removed. `dim` may be negative; it is
// wrapped against y.dim() before use.
Tensor zeros_like_except(const Tensor& y, int64_t dim) {
    auto reduced_shape = y.sizes().vec();
    const auto wrapped_dim = maybe_wrap_dim(dim, y.dim());
    // Drop the entry for the dimension that is being reduced away.
    reduced_shape.erase(reduced_shape.begin() + wrapped_dim);
    return at::zeros(reduced_shape, y.options());
}
}
// Integrates y along `dim` with the trapezoid rule, using the sample
// coordinates given in x.
//
// Parameters:
//   y   - sampled function values.
//   x   - sample coordinates. If x is 1-D it must have exactly y.size(dim)
//         entries and is viewed as shape (1, ..., n, ..., 1) so that it lines
//         up with `dim` of y. Otherwise x is used as given and is sliced
//         along `dim` directly.
//   dim - dimension of y to integrate over; negative values are wrapped.
//
// Returns the integral with `dim` reduced away. When y has no samples along
// `dim`, a zero tensor of the reduced shape is returned.
//
// Raises (via TORCH_CHECK) if x or y is a bool tensor, or if a 1-D x does
// not have one value per sample point.
Tensor trapz(const Tensor& y, const Tensor& x, int64_t dim) {
    dim = maybe_wrap_dim(dim, y);
    // asking for the integral with zero samples is a bit nonsensical,
    // but we'll return "0" to match numpy behavior.
    if (y.size(dim) == 0) {
        return zeros_like_except(y, dim);
    }
    // Reject bool inputs up front: the arithmetic below (y_i + y_{i+1} and
    // x_{i+1} - x_i) is not meaningful on bool data and would otherwise
    // either fail with an unrelated error or yield a wrong value.
    TORCH_CHECK(
        y.scalar_type() != kBool && x.scalar_type() != kBool,
        "trapz: received a bool input for `x` or `y`, but bool is not supported");
    Tensor x_viewed;
    if (x.dim() == 1) {
        TORCH_CHECK(x.size(0) == y.size(dim), "trapz: There must be one `x` value for each sample point");
        // Shape of all ones except along `dim`, so the spacings computed
        // from x broadcast against y.
        DimVector sizes(y.dim(), 1);
        sizes[dim] = x.size(0);
        x_viewed = x.view(sizes);
    } else {
        x_viewed = x;
    }
    // Spacing between consecutive sample coordinates along `dim`.
    Tensor x_left = x_viewed.slice(dim, 0, -1);
    Tensor x_right = x_viewed.slice(dim, 1);
    Tensor dx = x_right - x_left;
    return do_trapz(y, dx, dim);
}
// Integrates y along `dim` with the trapezoid rule, assuming a constant
// spacing dx between consecutive samples.
//
// Parameters:
//   y   - sampled function values.
//   dx  - distance between neighbouring sample points.
//   dim - dimension of y to integrate over.
//
// Returns the integral with `dim` reduced away. When y has no samples along
// `dim`, a zero tensor of the reduced shape is returned.
//
// Raises (via TORCH_CHECK) if y is a bool tensor.
Tensor trapz(const Tensor& y, double dx, int64_t dim) {
    // Integrating over zero samples is a bit nonsensical, but we return "0"
    // to match numpy behavior.
    if (y.size(dim) == 0) {
        return zeros_like_except(y, dim);
    }
    // Reject bool input up front: the endpoint correction adds the first and
    // last samples together, which is not an arithmetic sum on bool data and
    // would silently produce a wrong integral.
    TORCH_CHECK(
        y.scalar_type() != kBool,
        "trapz: received a bool input for `y`, but bool is not supported");
    return do_trapz(y, dx, dim);
}
}} // namespace at::native