-
-
Notifications
You must be signed in to change notification settings - Fork 50
/
api_utils.go
132 lines (123 loc) · 2.53 KB
/
api_utils.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package tensor
import (
"math"
"math/rand"
"reflect"
"sort"
"github.com/chewxy/math32"
)
// SortIndex returns the indices that would arrange in in increasing order,
// in the spirit of numpy's argsort. The input slice is left unchanged.
//
// The relative order of indices whose elements compare equal is not
// guaranteed; use SortIndexStable when that order must be preserved.
// Unsupported input types cause a panic.
func SortIndex(in interface{}) (out []int) {
	sorter := sort.Slice
	out = sortIndex(in, sorter)
	return out
}
// SortIndexStable is the stable variant of SortIndex. It returns the indices
// that would arrange in in increasing order. Indices of elements that compare
// equal keep their original relative order. The input slice is left unchanged.
// Unsupported input types cause a panic.
func SortIndexStable(in interface{}) (out []int) {
	stableSorter := sort.SliceStable
	out = sortIndex(in, stableSorter)
	return out
}
// sortIndex returns the permutation of indices that orders in ascending.
// It builds the identity index slice [0, n) and reorders it with sortFunc
// (sort.Slice or sort.SliceStable). The comparator looks elements up through
// the index slice, so the input itself is never modified.
//
// Supported inputs are []int, []float64, []float32, []string, and any
// sort.Interface, which is ordered by its own Less method. Any other type
// panics.
//
// For floating-point slices, NaNs are ordered after every other value, as
// numpy's argsort does. A plain "<" is not a strict weak ordering once NaNs
// are present, and the sort would then be free to return an arbitrary
// permutation.
func sortIndex(in interface{}, sortFunc func(x interface{}, less func(i int, j int) bool)) (out []int) {
	switch list := in.(type) {
	case []int:
		out = sortIndexIdentity(len(list))
		sortFunc(out, func(i, j int) bool {
			return list[out[i]] < list[out[j]]
		})
	case []float64:
		out = sortIndexIdentity(len(list))
		sortFunc(out, func(i, j int) bool {
			a, b := list[out[i]], list[out[j]]
			// Non-NaN sorts before NaN; NaNs are mutually equivalent.
			return a < b || (!math.IsNaN(a) && math.IsNaN(b))
		})
	case []float32:
		out = sortIndexIdentity(len(list))
		sortFunc(out, func(i, j int) bool {
			a, b := list[out[i]], list[out[j]]
			// Widening to float64 preserves NaN-ness.
			return a < b || (!math.IsNaN(float64(a)) && math.IsNaN(float64(b)))
		})
	case []string:
		out = sortIndexIdentity(len(list))
		sortFunc(out, func(i, j int) bool {
			return list[out[i]] < list[out[j]]
		})
	case sort.Interface:
		out = sortIndexIdentity(list.Len())
		sortFunc(out, func(i, j int) bool {
			return list.Less(out[i], out[j])
		})
	default:
		panic("The slice type is not currently supported.")
	}
	return out
}

// sortIndexIdentity returns the identity permutation [0, 1, ..., n-1].
func sortIndexIdentity(n int) []int {
	idx := make([]int, n)
	for i := range idx {
		idx[i] = i
	}
	return idx
}
// SampleIndex draws a random index from the discrete distribution described
// by in and returns it. Randomness comes from the top-level math/rand
// functions.
//
// Supported inputs:
//   - []int: integer weights. A value is drawn uniformly from [0, total)
//     and mapped back through the running sum, so index i is chosen with
//     probability in[i]/total.
//   - []float64, and a *Dense holding Float64 or Float32 data: probabilities
//     compared against a uniform draw from [0, 1). They are therefore
//     expected to sum to 1 (assumption — not checked). The first index whose
//     running sum exceeds the draw is returned. If a NaN or ±Inf entry is
//     reached first, its index is returned, so a corrupt distribution is
//     surfaced rather than skipped. If rounding leaves the total just below
//     the draw, the last index with a positive probability is returned.
//
// SampleIndex returns -1 when no index can be chosen. This covers an empty
// input, integer weights whose total is not positive, and probabilities with
// no positive entry. Any other input type, or any other *Dense dtype, panics.
func SampleIndex(in interface{}) int {
	switch list := in.(type) {
	case []int:
		return sampleIntWeights(list)
	case []float64:
		return sampleFloat64Probs(list)
	case *Dense:
		switch list.t.Kind() {
		case reflect.Float64:
			return sampleFloat64Probs(list.Float64s())
		case reflect.Float32:
			return sampleFloat32Probs(list.Float32s())
		default:
			panic("not yet implemented")
		}
	default:
		panic("Not yet implemented")
	}
}

// sampleIntWeights picks index i with probability weights[i]/sum(weights),
// or returns -1 if the total weight is not positive.
func sampleIntWeights(weights []int) int {
	total := 0
	for _, w := range weights {
		total += w
	}
	if total <= 0 {
		return -1
	}
	r := rand.Intn(total) // uniform in [0, total), so it is always reached below
	sum := 0
	for i, w := range weights {
		sum += w
		if sum > r {
			return i
		}
	}
	return -1 // only reachable with negative weights
}

// sampleFloat64Probs returns the first index whose running sum of probs
// exceeds a uniform draw from [0, 1). It stops early at a NaN/±Inf entry, and
// falls back to the last positive entry (or -1) if the draw is never exceeded.
func sampleFloat64Probs(probs []float64) int {
	r := rand.Float64()
	var sum float64
	last := -1
	for i, p := range probs {
		if math.IsNaN(p) || math.IsInf(p, 0) {
			return i
		}
		sum += p
		if sum > r {
			return i
		}
		if p > 0 {
			last = i
		}
	}
	return last
}

// sampleFloat32Probs is the float32 counterpart of sampleFloat64Probs.
func sampleFloat32Probs(probs []float32) int {
	r := rand.Float32()
	var sum float32
	last := -1
	for i, p := range probs {
		if math32.IsNaN(p) || math32.IsInf(p, 0) {
			return i
		}
		sum += p
		if sum > r {
			return i
		}
		if p > 0 {
			last = i
		}
	}
	return last
}