-
Notifications
You must be signed in to change notification settings - Fork 9
/
histogram.cu
142 lines (120 loc) · 5.67 KB
/
histogram.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
#include <math.h>
#define THREAD_COUNT 1024
// Kernel: one thread per element of a (channels x tensorSize) tensor.
// Accumulates a per-channel histogram of nBins float counts.
// histogram must be zero-initialised, laid out as channels rows of nBins.
// minv/maxv hold one per-channel minimum/maximum each.
// Launch: 1-D grid covering at least channels * tensorSize threads.
__global__ void computeHistogram(float *tensor, float *histogram, float *minv, float *maxv, unsigned int channels, unsigned int tensorSize, unsigned int nBins)
{
    unsigned int index = threadIdx.x + blockIdx.x * blockDim.x;
    if (index < channels * tensorSize)
    {
        // Compute which channel we're in
        unsigned int channel = index / tensorSize;
        // Guard against a constant channel (max == min): the unguarded
        // division yields NaN, and casting NaN to an integer is undefined.
        // All elements of such a channel belong in bin 0.
        float range = maxv[channel] - minv[channel];
        // Normalize the value into [0, nBins]
        float value = (range > 0.0f)
            ? (tensor[index] - minv[channel]) / range * float(nBins)
            : 0.0f;
        // Compute bin index, clamping value == max into the last bin
        unsigned int bin = min((unsigned int)(value), nBins - 1);
        // Many threads may hit the same bin, hence the atomic increment
        atomicAdd(histogram + (channel * nBins) + bin, 1.0f);
    }
}
// return cumulative histogram shifted to the right by 1
// ==> histogram[c][0] always == 0
// In-place exclusive prefix sum over one histogram row per block.
// Launch with one block per channel and a single thread per block;
// blockIdx.x selects the row. Afterwards row[0] == 0 and row[i] holds
// the total count of all bins strictly below i.
__global__ void accumulateHistogram(float *histogram, unsigned int nBins)
{
    float *row = histogram + blockIdx.x * nBins;
    float runningTotal = 0.0f;
    for (unsigned int bin = 0; bin < nBins; ++bin)
    {
        float count = row[bin];
        row[bin] = runningTotal;
        runningTotal += count;
    }
}
// Kernel: builds linkMap, giving every element its rank ("pseudo-sorted"
// position) within its channel: the exclusive cumulative count of lower
// bins plus a running per-bin counter. `indirection` is a random
// permutation of element indices used to visit elements in shuffled order
// (avoids the "blurry top" artifact). localIndexes must be zero-initialised
// and channels * nBins wide, like cumulativeHistogram.
// Launch: 1-D grid covering at least channels * tensorSize threads.
__global__ void buildSortedLinkmap(float *tensor, unsigned int *linkMap, float *cumulativeHistogram, unsigned int *localIndexes, long *indirection, float *minv, float *maxv, unsigned int channels, unsigned int tensorSize, unsigned int nBins)
{
    unsigned int index = threadIdx.x + blockIdx.x * blockDim.x;
    if (index < channels * tensorSize)
    {
        // Shuffle image -- Avoid the blurry top bug
        index = indirection[index];
        // Compute which channel we're in
        unsigned int channel = index / tensorSize;
        // Normalize the value in range [0, nBins]
        float value = (tensor[index] - minv[channel]) / (maxv[channel] - minv[channel]) * float(nBins);
        // Compute bin index, clamped into [0, nBins - 1]
        int binIndex = min((unsigned int)(value), nBins - 1);
        // BUG FIX: histogram rows are nBins wide, not 256 — the previous
        // hard-coded 256 stride indexed out of bounds whenever nBins != 256.
        // Increment and retrieve the count of pixels already seen in this bin
        int localIndex = atomicAdd(&localIndexes[(channel * nBins) + binIndex], 1);
        // Retrieve the number of pixels in all lower bins (exclusive cumulative histogram)
        unsigned int lowerPixelCount = cumulativeHistogram[(channel * nBins) + binIndex];
        // Set the linkmap for index to its position as "pseudo-sorted"
        linkMap[index] = lowerPixelCount + localIndex;
    }
}
// Kernel: maps each element's rank (linkMap) back to a bin value using the
// target cumulative histogram: the output is the highest bin whose scaled
// cumulative count is <= the element's rank. Writes the bin index, as a
// float, over tensor in place.
// BUG FIX: the bin count was hard-coded to 256 while targetHistogram rows
// are nBins wide; nBins is now a trailing parameter defaulting to 256 so
// existing six-argument launches keep their previous behavior.
// Launch: 1-D grid covering at least channels * tensorSize threads.
__global__ void rebuild(float *tensor, unsigned int *linkMap, float *targetHistogram, float scale, unsigned int channels, unsigned int tensorSize, unsigned int nBins = 256)
{
    unsigned int index = threadIdx.x + blockIdx.x * blockDim.x;
    if (index < channels * tensorSize)
    {
        unsigned int channel = index / tensorSize;
        unsigned int value = 0;
        // Linear scan: keep the last bin whose scaled cumulative count
        // does not exceed this element's rank
        for (unsigned int i = 0; i < nBins; ++i)
            if (linkMap[index] >= targetHistogram[(channel * nBins) + i] * scale) value = i;
        tensor[index] = (float)value;
    }
}
// Host wrapper: computes a per-channel histogram of `t` on the GPU.
// The input is moved to CUDA and viewed as (channels, -1); a 1-D tensor is
// treated as one channel. Returns a (channels x numBins) float tensor of
// counts; bin edges span each channel's own [min, max] range.
at::Tensor computeHistogram(at::Tensor const &t, unsigned int numBins)
{
    at::Tensor unsqueezed(t);
    unsqueezed = unsqueezed.cuda();
    if (unsqueezed.ndimension() == 1)
        unsqueezed.unsqueeze_(0);
    if (unsqueezed.ndimension() > 2)
        unsqueezed = unsqueezed.view({unsqueezed.size(0), -1});
    unsigned int c = unsqueezed.size(0);     // Number of channels
    unsigned int n = unsqueezed.numel() / c; // Number of elements per channel
    // Per-channel min/max used to normalize values into bin space
    at::Tensor min = torch::min_values(unsqueezed, 1, true).cuda();
    at::Tensor max = torch::max_values(unsqueezed, 1, true).cuda();
    at::Tensor h = at::zeros({int(c), int(numBins)}, unsqueezed.type()).cuda();
    // Ceil-divide: covers all c*n elements without launching a wasted extra
    // block when the element count is an exact multiple of THREAD_COUNT
    unsigned int blocks = (c * n + THREAD_COUNT - 1) / THREAD_COUNT;
    computeHistogram<<<blocks, THREAD_COUNT>>>(unsqueezed.data<float>(),
                                               h.data<float>(),
                                               min.data<float>(),
                                               max.data<float>(),
                                               c, n, numBins);
    return h;
}
// Host entry point: in-place histogram matching. Remaps featureMaps so its
// per-channel histogram matches targetHistogram (shape: channels x nBins).
// targetHistogram is converted to an exclusive cumulative histogram in place.
// On return, featureMaps holds bin indices divided by nBins, i.e. values in
// [0, 1). Throws std::runtime_error if a device allocation fails (the
// original code ignored cudaMalloc return codes and would pass bad pointers
// to the kernels).
void matchHistogram(at::Tensor &featureMaps, at::Tensor &targetHistogram)
{
    // Cache one random permutation per tensor size: buildSortedLinkmap visits
    // elements in shuffled order to avoid the "blurry top" artifact.
    static std::map<unsigned int, at::Tensor> randomIndices;
    if (randomIndices[featureMaps.numel()].numel() != featureMaps.numel())
        randomIndices[featureMaps.numel()] = torch::randperm(featureMaps.numel(), torch::TensorOptions().dtype(at::kLong)).cuda();
    at::Tensor unsqueezed(featureMaps);
    if (unsqueezed.ndimension() == 1)
        unsqueezed.unsqueeze_(0);
    if (unsqueezed.ndimension() > 2)
        unsqueezed = unsqueezed.view({unsqueezed.size(0), -1});
    unsigned int nBins = targetHistogram.size(1);
    unsigned int c = unsqueezed.size(0);     // Number of channels
    unsigned int n = unsqueezed.numel() / c; // Number of elements per channel
    // Scale = number of elements in features / number of elements in target
    float scale = float(featureMaps.numel()) / targetHistogram.sum().item<float>();
    at::Tensor featuresHistogram = computeHistogram(unsqueezed, nBins);
    // Turn both histograms into exclusive cumulative form (one block per channel)
    accumulateHistogram<<<c, 1>>>(featuresHistogram.data<float>(), nBins);
    accumulateHistogram<<<c, 1>>>(targetHistogram.data<float>(), nBins);
    // Scratch buffers: per-element ranks and per-bin running counters
    unsigned int *linkMap = NULL;
    if (cudaMalloc(&linkMap, c * n * sizeof(unsigned int)) != cudaSuccess)
        throw std::runtime_error("matchHistogram: cudaMalloc(linkMap) failed");
    unsigned int *localIndexes = NULL;
    if (cudaMalloc(&localIndexes, c * nBins * sizeof(unsigned int)) != cudaSuccess)
    {
        cudaFree(linkMap); // don't leak the first buffer on partial failure
        throw std::runtime_error("matchHistogram: cudaMalloc(localIndexes) failed");
    }
    cudaMemset(localIndexes, 0, c * nBins * sizeof(unsigned int));
    at::Tensor min = torch::min_values(unsqueezed, 1, true).cuda();
    at::Tensor max = torch::max_values(unsqueezed, 1, true).cuda();
    // Ceil-divide: covers all c*n elements without a wasted extra block
    unsigned int blocks = (c * n + THREAD_COUNT - 1) / THREAD_COUNT;
    buildSortedLinkmap<<<blocks, THREAD_COUNT>>>(featureMaps.data<float>(), linkMap, featuresHistogram.data<float>(), localIndexes, randomIndices[featureMaps.numel()].data<long>(), min.data<float>(), max.data<float>(), c, n, nBins);
    rebuild<<<blocks, THREAD_COUNT>>>(featureMaps.data<float>(), linkMap, targetHistogram.data<float>(), scale, c, n);
    // Normalize bin indices into [0, 1)
    featureMaps.div_(float(nBins));
    cudaFree(linkMap);
    cudaFree(localIndexes);
}