forked from indy256/codelibrary
-
Notifications
You must be signed in to change notification settings - Fork 0
/
wavelet_tree.cpp
86 lines (76 loc) · 2.43 KB
/
wavelet_tree.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include <bits/stdc++.h>
using namespace std;
// Wavelet tree over a sequence of ints.
//
// Positions are 1-based and query ranges [l, r] are inclusive. Supported queries:
//   kth(l, r, k)   - k-th smallest value (k is 1-based) in positions [l, r],
//   LTE(l, r, k)   - how many values in positions [l, r] are <= k,
//   count(l, r, k) - how many values in positions [l, r] are == k.
// Each level halves the value range [lo, hi], so the tree is about
// log2(max - min) + 1 levels deep.
//
// Every node owns its two children through raw pointers (deleted in the
// destructor), so the type is move-only.
struct wavelet_tree {
    int lo, hi;                // every value stored in this node lies in [lo, hi]
    wavelet_tree *l = nullptr; // owned child holding the values in [lo, mid]
    wavelet_tree *r = nullptr; // owned child holding the values in [mid + 1, hi]
    // b[i] = how many of the first i elements of this node go to the left child
    // (b[0] == 0). Left empty for leaves (lo == hi) and for nodes with no elements.
    std::vector<int> b;

    // Builds a tree over the values in [from, to), using their min and max as the
    // value range. The input is copied first, so the caller's array is left
    // untouched (the raw constructor below reorders its slice in place).
    // An empty range is accepted and produces a tree holding no elements.
    static wavelet_tree create(const int *from, const int *to) {
        std::vector<int> data(from, to);
        if (data.empty())
            return wavelet_tree(data.data(), data.data(), 0, 0);
        const auto bounds = std::minmax_element(data.begin(), data.end());
        int *first = data.data();
        return wavelet_tree(first, first + data.size(), *bounds.first, *bounds.second);
    }

    // Builds the node for values in [x, y] from the array slice [from, to).
    // Precondition: every element of [from, to) lies in [x, y].
    // NOTE: the slice is reordered in place. Once the whole subtree is built it
    // is sorted. Use create() to keep the caller's data intact.
    wavelet_tree(int *from, int *to, int x, int y) : lo(x), hi(y) {
        if (lo == hi || from >= to)
            return;
        const int mid = split_point(lo, hi);
        const auto goes_left = [mid](int v) { return v <= mid; };
        b.reserve(static_cast<std::size_t>(to - from) + 1);
        b.push_back(0);
        for (const int *it = from; it != to; ++it)
            b.push_back(b.back() + (goes_left(*it) ? 1 : 0));
        // The partition must be stable: each child keeps its elements in their
        // original relative order, which is what makes the prefix counts in b
        // usable for mapping positions into the children.
        int *pivot = std::stable_partition(from, to, goes_left);
        l = new wavelet_tree(from, pivot, lo, mid);
        r = new wavelet_tree(pivot, to, mid + 1, hi);
    }

    // Copying would make two trees delete the same children, so only moves are allowed.
    wavelet_tree(const wavelet_tree &) = delete;
    wavelet_tree &operator=(const wavelet_tree &) = delete;

    wavelet_tree(wavelet_tree &&other) noexcept
        : lo(other.lo), hi(other.hi), l(other.l), r(other.r), b(std::move(other.b)) {
        other.l = nullptr;
        other.r = nullptr;
    }

    wavelet_tree &operator=(wavelet_tree &&other) noexcept {
        if (this != &other) {
            delete l;
            delete r;
            lo = other.lo;
            hi = other.hi;
            l = other.l;
            r = other.r;
            b = std::move(other.b);
            other.l = nullptr;
            other.r = nullptr;
        }
        return *this;
    }

    // k-th smallest value (k is 1-based) among positions [l, r]. Returns 0 when l > r.
    // For a non-empty range it requires 1 <= l <= r <= size, and k should be in
    // [1, r - l + 1]; otherwise the result is meaningless.
    int kth(int l, int r, int k) const {
        if (l > r)
            return 0;
        if (lo == hi)
            return lo;
        const int lb = b[l - 1]; // of the first l-1 elements, how many went left
        const int rb = b[r];     // of the first r elements, how many went left
        const int in_left = rb - lb;
        if (k <= in_left)
            return this->l->kth(lb + 1, rb, k);
        return this->r->kth(l - lb, r - rb, k - in_left);
    }

    // Number of values <= k among positions [l, r]. Returns 0 when l > r.
    int LTE(int l, int r, int k) const {
        if (l > r || k < lo)
            return 0;
        if (hi <= k)
            return r - l + 1; // the whole node is <= k
        const int lb = b[l - 1];
        const int rb = b[r];
        return this->l->LTE(lb + 1, rb, k) + this->r->LTE(l - lb, r - rb, k);
    }

    // Number of values equal to k among positions [l, r]. Returns 0 when l > r.
    int count(int l, int r, int k) const {
        if (l > r || k < lo || k > hi)
            return 0;
        if (lo == hi)
            return r - l + 1;
        const int lb = b[l - 1];
        const int rb = b[r];
        // Must use the same split as the constructor to pick the right child.
        if (k <= split_point(lo, hi))
            return this->l->count(lb + 1, rb, k);
        return this->r->count(l - lb, r - rb, k);
    }

    ~wavelet_tree() {
        delete l;
        delete r;
    }

  private:
    // Floor of (low + high) / 2, computed without signed overflow. Rounding down
    // (instead of toward zero) guarantees low <= mid < high whenever low < high.
    // Both children then get a strictly smaller value range, even for negative values.
    static int split_point(int low, int high) {
        return low + static_cast<int>((static_cast<long long>(high) - low) / 2);
    }
};
// Usage example: builds a tree over {3, 1, 4, 2} and runs one query of each kind.
// Query positions are 1-based and inclusive. Expected output: 3, 1, 1 (one per line).
int main() {
    int a[] = {3, 1, 4, 2};
    const int n = static_cast<int>(sizeof(a) / sizeof(a[0]));
    wavelet_tree wtree = wavelet_tree::create(a, a + n);
    // 2nd smallest of positions 1..3 = {3, 1, 4} -> 3
    std::cout << wtree.kth(1, 3, 2) << '\n';
    // values <= 2 in positions 1..3 = {3, 1, 4} -> 1
    std::cout << wtree.LTE(1, 3, 2) << '\n';
    // occurrences of 2 in positions 1..4 = {3, 1, 4, 2} -> 1
    std::cout << wtree.count(1, 4, 2) << '\n';
}