-
Notifications
You must be signed in to change notification settings - Fork 1
/
hash.go
114 lines (99 loc) · 2.93 KB
/
hash.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
// Copyright 2022, NLP Odyssey Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gomaddness
import "math"
type (
	// Hash is the data structure for the MADDNESS hash function. It stores
	// the learned balanced binary regression tree, one entry per level, and
	// the prototype vectors.
	Hash[F Float] struct {
		// TreeLevels lists the tree levels in order, starting from the root.
		TreeLevels []*HashingTreeLevel[F]
		// Prototypes holds the prototype vectors (computed by TrainHash from
		// the leaf buckets of the tree).
		Prototypes Vectors[F]
	}

	// HashingTreeLevel is one level of the binary tree from a Hash.
	HashingTreeLevel[F Float] struct {
		// SplitIndex is the vector component compared at this level.
		SplitIndex int
		// SplitThresholds holds one threshold for each node of this level.
		SplitThresholds Vector[F]
	}
)
// TrainHash learns the parameters of the MADDNESS hash function from the
// given examples and returns the trained Hash.
//
// Training starts from a single root bucket holding every example. The tree
// then grows one level at a time, with each level splitting every bucket in
// two. After four levels there are 16 leaf buckets, and their prototypes
// become the Hash prototypes.
func TrainHash[F Float](examples Vectors[F]) *Hash[F] {
	const depth = 4

	root := &Bucket[F]{
		Level:     -1,
		NodeIndex: 0,
		Vectors:   examples,
	}
	current := Buckets[F]{root}

	treeLevels := make([]*HashingTreeLevel[F], depth)
	for d := 0; d < depth; d++ {
		var lvl *HashingTreeLevel[F]
		current, lvl = nextHashingTreeLevel(current)
		treeLevels[d] = lvl
	}

	return &Hash[F]{TreeLevels: treeLevels, Prototypes: current.Prototypes()}
}
// Hash maps the given vector to a leaf index by applying the MADDNESS hash
// function.
//
// The tree is walked from the root. At each level, v[level.SplitIndex] is
// compared with the threshold of the current node. Values strictly below the
// threshold move to child 2*node, and all others (including NaN) move to
// child 2*node+1. With at most 8 levels the result lies in
// [0, 2^len(h.TreeLevels)). The uint8 index wraps around beyond that depth.
func (h *Hash[F]) Hash(v Vector[F]) uint8 {
	var node uint8
	for _, level := range h.TreeLevels {
		threshold := level.SplitThresholds[node]
		// Written as "< threshold" with an else branch rather than ">=" so
		// that NaN components keep going to the second child.
		if v[level.SplitIndex] < threshold {
			node = 2 * node
		} else {
			node = 2*node + 1
		}
	}
	return node
}
// nextHashingTreeLevel grows the hashing tree by one level.
//
// Every bucket is split in two along a single split index shared by the whole
// level. The index is chosen among the candidates returned by
// buckets.HeuristicSelectIndices as the one whose summed OptimalSplitThreshold
// loss over all buckets is lowest; ties keep the earliest candidate. Each
// bucket gets its own threshold for that index.
//
// It returns the new buckets together with the learned HashingTreeLevel.
// There are two new buckets per input bucket: the children of bucket j sit
// at positions 2j (the "less than" side) and 2j+1 (the "greater or equal"
// side).
//
// It panics with a descriptive message if buckets is non-empty but no split
// index could be selected (no candidate indices, or no candidate loss below
// +Inf, e.g. NaN losses), or if both halves of a split are empty.
func nextHashingTreeLevel[F Float](buckets Buckets[F]) (Buckets[F], *HashingTreeLevel[F]) {
	splitIndex, thresholds := bestHashingTreeSplit(buckets)
	if splitIndex < 0 && len(buckets) > 0 {
		// Without a selected index there are no thresholds to apply; fail
		// clearly instead of with an index-out-of-range panic below.
		panic("gomaddness: no split index could be selected for the next hashing tree level")
	}

	newBuckets := make(Buckets[F], 0, len(buckets)*2)
	for j, bucket := range buckets {
		lt, gte := bucket.Vectors.SplitByThreshold(splitIndex, thresholds[j])
		lt, gte = fillEmptySplitHalves(lt, gte, splitIndex)
		newBuckets = append(newBuckets,
			&Bucket[F]{Level: bucket.Level + 1, NodeIndex: j * 2, Vectors: lt},
			&Bucket[F]{Level: bucket.Level + 1, NodeIndex: j*2 + 1, Vectors: gte},
		)
	}

	nextLevel := &HashingTreeLevel[F]{
		SplitIndex:      splitIndex,
		SplitThresholds: thresholds,
	}
	return newBuckets, nextLevel
}

// bestHashingTreeSplit returns the candidate split index with the lowest
// total loss across buckets, along with the per-bucket optimal thresholds
// for that index. It returns -1 and nil thresholds if no candidate has a
// loss lower than +Inf.
func bestHashingTreeSplit[F Float](buckets Buckets[F]) (int, Vector[F]) {
	bestLoss := F(math.Inf(+1))
	bestIndex := -1
	var bestThresholds Vector[F]
	for _, splitIndex := range buckets.HeuristicSelectIndices() {
		var loss F
		thresholds := make(Vector[F], len(buckets))
		for j, bucket := range buckets {
			t, l := bucket.Vectors.OptimalSplitThreshold(splitIndex)
			thresholds[j] = t
			loss += l
		}
		if loss < bestLoss {
			bestLoss, bestIndex, bestThresholds = loss, splitIndex, thresholds
		}
	}
	return bestIndex, bestThresholds
}

// fillEmptySplitHalves makes sure neither half of a split is empty. If one
// half has no vectors, it gets a single synthetic vector copied from the
// boundary vector of the other half along splitIndex. For lt that is the
// first vector after SortByColumn; for gte it is the last. The synthetic
// vector's splitIndex component is nudged one float32 step toward the empty
// side. It panics if both halves are empty, since there is nothing to copy.
func fillEmptySplitHalves[F Float](lt, gte Vectors[F], splitIndex int) (Vectors[F], Vectors[F]) {
	if len(lt) == 0 && len(gte) == 0 {
		panic("gomaddness: cannot split an empty bucket")
	}
	// NOTE: the nudge goes through float32 even when F is float64, so the
	// step is one float32 ulp rather than one float64 ulp.
	if len(lt) == 0 {
		v := gte.Copy().SortByColumn(splitIndex)[0].Copy()
		v[splitIndex] = F(math.Nextafter32(float32(v[splitIndex]), float32(math.Inf(-1))))
		lt = Vectors[F]{v}
	}
	if len(gte) == 0 {
		v := lt.Copy().SortByColumn(splitIndex)[len(lt)-1].Copy()
		v[splitIndex] = F(math.Nextafter32(float32(v[splitIndex]), float32(math.Inf(+1))))
		gte = Vectors[F]{v}
	}
	return lt, gte
}