forked from coinbase/kryptology
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhelpers.go
181 lines (157 loc) · 5.48 KB
/
helpers.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
//
// Copyright Coinbase, Inc. All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//
package bulletproof
import (
"github.com/pkg/errors"
"github.com/coinbase/kryptology/pkg/core/curves"
)
// innerProduct computes the dot product of the scalar vectors a and b and
// returns it as a single scalar.
//
// An error is returned when the vectors differ in length or are empty. The
// accumulator is seeded with a[0].Zero(), so the result is produced by the
// same scalar implementation as the inputs.
func innerProduct(a, b []curves.Scalar) (curves.Scalar, error) {
	n := len(a)
	switch {
	case n != len(b):
		return nil, errors.New("length of scalar vectors must be the same")
	case n < 1:
		return nil, errors.New("length of vectors must be at least one")
	}

	acc := a[0].Zero()
	for i := 0; i < n; i++ {
		// acc = a[i]*b[i] + acc
		acc = a[i].MulAdd(b[i], acc)
	}
	return acc, nil
}
// splitPointVector divides points into two halves of equal length and returns
// them as (first half, second half).
//
// An error is returned when points is empty or holds an odd number of
// elements. The halves are sub-slices of points, not copies.
func splitPointVector(points []curves.Point) ([]curves.Point, []curves.Point, error) {
	total := len(points)
	if total == 0 {
		return nil, nil, errors.New("length of points must be at least one")
	}
	if total%2 == 1 {
		return nil, nil, errors.New("length of points must be even")
	}

	mid := total / 2
	return points[:mid], points[mid:], nil
}
// splitScalarVector divides scalars into two halves of equal length and
// returns them as (first half, second half).
//
// An error is returned when scalars is empty or holds an odd number of
// elements. The halves are sub-slices of scalars, not copies.
func splitScalarVector(scalars []curves.Scalar) ([]curves.Scalar, []curves.Scalar, error) {
	total := len(scalars)
	if total == 0 {
		return nil, nil, errors.New("length of scalars must be at least one")
	}
	if total%2 == 1 {
		return nil, nil, errors.New("length of scalars must be even")
	}

	mid := total / 2
	return scalars[:mid], scalars[mid:], nil
}
// multiplyScalarToPointVector returns a new vector whose i-th element is
// g[i].Mul(x). The input vector g is not modified.
func multiplyScalarToPointVector(x curves.Scalar, g []curves.Point) []curves.Point {
	scaled := make([]curves.Point, len(g))
	for i := range g {
		scaled[i] = g[i].Mul(x)
	}
	return scaled
}
// multiplyScalarToScalarVector returns a new vector whose i-th element is
// a[i].Mul(x). The input vector a is not modified.
func multiplyScalarToScalarVector(x curves.Scalar, a []curves.Scalar) []curves.Scalar {
	scaled := make([]curves.Scalar, 0, len(a))
	for _, elem := range a {
		scaled = append(scaled, elem.Mul(x))
	}
	return scaled
}
// multiplyPairwisePointVectors combines the point vectors g and h element by
// element and returns the resulting vector.
//
// Each output element is g[i].Add(h[i]); "multiply" in the name presumably
// follows multiplicative notation for the group operation. An error is
// returned when the two vectors differ in length.
func multiplyPairwisePointVectors(g, h []curves.Point) ([]curves.Point, error) {
	n := len(g)
	if len(h) != n {
		return nil, errors.New("length of point vectors must be the same")
	}

	combined := make([]curves.Point, n)
	for i := 0; i < n; i++ {
		combined[i] = g[i].Add(h[i])
	}
	return combined, nil
}
// multiplyPairwiseScalarVectors takes two lists of scalars (a, b) and performs
// a pairwise multiplication, returning a list of scalars whose i-th element is
// a[i].Mul(b[i]).
//
// An error is returned when the two vectors differ in length.
func multiplyPairwiseScalarVectors(a, b []curves.Scalar) ([]curves.Scalar, error) {
	if len(a) != len(b) {
		// The inputs are scalar vectors, so the message names scalars, matching
		// the other scalar-vector helpers in this file.
		return nil, errors.New("length of scalar vectors must be the same")
	}
	product := make([]curves.Scalar, len(a))
	for i, aElem := range a {
		product[i] = aElem.Mul(b[i])
	}
	return product, nil
}
// addPairwiseScalarVectors returns the element-wise sum of the scalar vectors
// a and b: the i-th output element is a[i].Add(b[i]).
//
// An error is returned when the two vectors differ in length.
func addPairwiseScalarVectors(a, b []curves.Scalar) ([]curves.Scalar, error) {
	n := len(a)
	if n != len(b) {
		return nil, errors.New("length of scalar vectors must be the same")
	}

	totals := make([]curves.Scalar, n)
	for i := 0; i < n; i++ {
		totals[i] = a[i].Add(b[i])
	}
	return totals, nil
}
// subtractPairwiseScalarVectors returns the element-wise difference of the
// scalar vectors a and b: the i-th output element is a[i].Sub(b[i]).
//
// An error is returned when the two vectors differ in length.
func subtractPairwiseScalarVectors(a, b []curves.Scalar) ([]curves.Scalar, error) {
	if len(b) != len(a) {
		return nil, errors.New("length of scalar vectors must be the same")
	}

	deltas := make([]curves.Scalar, len(a))
	for i := range deltas {
		deltas[i] = a[i].Sub(b[i])
	}
	return deltas, nil
}
// invertScalars returns a new vector holding the result of Invert() for every
// element of xs, in the same order.
//
// If any element's Invert call fails, the wrapped error is returned and no
// partial result is exposed.
func invertScalars(xs []curves.Scalar) ([]curves.Scalar, error) {
	inverses := make([]curves.Scalar, 0, len(xs))
	for _, x := range xs {
		inv, err := x.Invert()
		if err != nil {
			return nil, errors.Wrap(err, "bulletproof helpers invertx")
		}
		inverses = append(inverses, inv)
	}
	return inverses, nil
}
// isPowerOfTwo reports whether i is a positive power of two (1, 2, 4, 8, ...).
//
// Zero and negative values are rejected explicitly: the bit trick
// i&(i-1) == 0 on its own is also satisfied by 0 and by the minimum int,
// neither of which is a power of two.
func isPowerOfTwo(i int) bool {
	return i > 0 && i&(i-1) == 0
}
// get2nVector returns the scalar vector 2^n = [1, 2, 4, ... 2^(length-1)].
// See k^n and 2^n definitions on pg 12 of https://eprint.iacr.org/2017/1066.pdf
//
// The first element is curve.Scalar.One() and each following element is the
// Double() of its predecessor. A length of zero yields an empty vector.
func get2nVector(length int, curve curves.Curve) []curves.Scalar {
	vector2n := make([]curves.Scalar, length)
	if length == 0 {
		// Nothing to fill; writing element 0 would index out of range.
		return vector2n
	}
	vector2n[0] = curve.Scalar.One()
	for i := 1; i < length; i++ {
		vector2n[i] = vector2n[i-1].Double()
	}
	return vector2n
}
// get1nVector returns a scalar vector of the given length in which every
// element is curve.Scalar.One().
func get1nVector(length int, curve curves.Curve) []curves.Scalar {
	ones := make([]curves.Scalar, length)
	for i := range ones {
		ones[i] = curve.Scalar.One()
	}
	return ones
}
// getknVector returns the scalar vector k^n = [1, k, k^2, ... k^(length-1)].
// See the k^n definition on pg 12 of https://eprint.iacr.org/2017/1066.pdf
//
// The first element is curve.Scalar.One(), the second is k itself, and each
// following element is its predecessor multiplied by k. Lengths of zero and
// one are handled without indexing past the end of the vector: they yield an
// empty vector and [1] respectively.
func getknVector(k curves.Scalar, length int, curve curves.Curve) []curves.Scalar {
	vectorkn := make([]curves.Scalar, length)
	if length == 0 {
		// Nothing to fill; writing element 0 would index out of range.
		return vectorkn
	}
	vectorkn[0] = curve.Scalar.One()
	if length == 1 {
		// Only k^0 was requested; writing element 1 would index out of range.
		return vectorkn
	}
	vectorkn[1] = k
	for i := 2; i < length; i++ {
		vectorkn[i] = vectorkn[i-1].Mul(k)
	}
	return vectorkn
}