-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathspike.py
232 lines (207 loc) · 7.05 KB
/
spike.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
from __future__ import (
print_function,
)
from contextlib import (
contextmanager,
)
from time import (
time,
)
from sys import (
argv,
stderr,
)
import attr
from challenge_bypass_ristretto import (
random_signing_key,
RandomToken,
PublicKey,
BlindedToken,
BatchDLEQProof,
SignedToken,
TokenPreimage,
VerificationSignature,
)
def debug(*args, **kwargs):
    """Print a progress/diagnostic message to standard error.

    Accepts the same arguments as ``print`` except ``file``, which is fixed
    to ``stderr`` (also passing ``file`` raises ``TypeError``).  Keeping
    diagnostics off stdout leaves stdout free for the timing CSV rows.
    """
    print(*args, file=stderr, **kwargs)
@attr.s
class Client(object):
    """The token-requesting side of the protocol."""

    # The issuing server's public key.  TokenRequest.redeem reads it (via
    # ``self.client``) when checking the server's batch DLEQ proof.
    signing_public_key = attr.ib()

    def request(self, count):
        """Create ``count`` fresh random tokens and blind them for the server.

        :param count: How many tokens to generate.
        :return: A ``(TokenRequest, marshaled_blinded_tokens)`` pair.  The
            ``TokenRequest`` keeps the tokens and their blinded forms on the
            client for the later redeem step; the second element holds the
            base64 encodings of the blinded tokens, to be sent to the server.
        """
        debug("generating tokens")
        tokens = [RandomToken.create() for _ in range(count)]

        debug("blinding tokens")
        blinded = [tok.blind() for tok in tokens]

        debug("marshaling blinded tokens")
        wire_blinded = [item.encode_base64() for item in blinded]

        pending = TokenRequest(self, tokens, blinded)
        return pending, wire_blinded
@attr.s
class TokenRequest(object):
    """Client-side state for one batch of tokens awaiting server signatures.

    Created by ``Client.request``.  ``redeem`` turns the server's response
    into marshaled passes bound to a message.
    """

    # The Client that created this request; its signing_public_key is used
    # to check the server's batch DLEQ proof.
    client = attr.ib()
    # The client's unblinded random tokens, in request order.
    tokens = attr.ib()
    # The blinded forms of ``tokens``, in the same order.
    blinded_tokens = attr.ib()

    def redeem(self, message, marshaled_signed_tokens, marshaled_proof):
        """Check the server's response and produce passes bound to ``message``.

        :param message: Request-binding data each pass is signed over; the
            verifying server must use the same bytes.
        :param marshaled_signed_tokens: Base64 signed tokens from the server,
            presumably in the same order as ``blinded_tokens`` (TODO confirm
            the library requires this ordering).
        :param marshaled_proof: Base64 batch DLEQ proof from the server.
        :return: A list of ``(preimage, signature)`` pairs of base64
            encodings, one per unblinded token.

        Any exception from ``invalid_or_unblind`` propagates; presumably that
        is how an invalid proof is reported (TODO confirm).  The decoded proof
        is released either way.
        """
        debug("decoding signed tokens")
        clients_signed_tokens = [
            SignedToken.decode_base64(marshaled_signed_token)
            for marshaled_signed_token in marshaled_signed_tokens
        ]

        debug("decoding batch dleq proof")
        clients_proof = BatchDLEQProof.decode_base64(marshaled_proof)
        try:
            debug("validating batch dleq proof and unblinding tokens")
            clients_unblinded_tokens = clients_proof.invalid_or_unblind(
                self.tokens,
                self.blinded_tokens,
                clients_signed_tokens,
                self.client.signing_public_key,
            )
        finally:
            # Release the proof explicitly, the same way Server.issue releases
            # the proof it creates, so it is not leaked even when validation
            # fails.
            debug("releasing batch dleq proof")
            clients_proof.destroy()

        debug("getting token preimages")
        clients_preimages = [
            token.preimage() for token in clients_unblinded_tokens
        ]

        debug("deriving verification keys")
        clients_verification_keys = [
            token.derive_verification_key_sha512()
            for token in clients_unblinded_tokens
        ]

        # "Passes" are tuples of token preimages and verification signatures.
        # Sign eagerly so the work happens under the matching debug message.
        debug("signing message with keys")
        clients_signatures = [
            verification_key.sign_sha512(message)
            for verification_key in clients_verification_keys
        ]

        debug("encoding passes")
        return [
            (token_preimage.encode_base64(), sig.encode_base64())
            for token_preimage, sig in zip(clients_preimages, clients_signatures)
        ]
@attr.s
class Server(object):
    """The token-issuing and pass-verifying side of the protocol."""

    # Signing key used both to sign blinded tokens (``issue``) and to
    # re-derive unblinded tokens from presented preimages (``verify``).
    signing_key = attr.ib()

    def issue(self, marshaled_blinded_tokens):
        """Sign a batch of blinded tokens and prove it was done with our key.

        :param marshaled_blinded_tokens: Base64 blinded tokens from a client.
        :return: ``(marshaled_signed_tokens, marshaled_proof)``, i.e. the
            base64 signed tokens in input order and the base64 batch DLEQ proof.
        """
        debug("unmarshaling blinded tokens")
        blinded = [
            BlindedToken.decode_base64(raw) for raw in marshaled_blinded_tokens
        ]

        debug("signing blinded tokens")
        signed = [self.signing_key.sign(item) for item in blinded]

        debug("encoded signed tokens")
        wire_signed = [item.encode_base64() for item in signed]

        debug("generating batch dleq proof")
        proof = BatchDLEQProof.create(self.signing_key, blinded, signed)
        try:
            debug("marshaling batch dleq proof")
            wire_proof = proof.encode_base64()
        finally:
            # The proof is released explicitly once encoded, even if
            # encoding raises.
            debug("releasing batch dleq proof")
            proof.destroy()
        return wire_signed, wire_proof

    def verify(self, message, marshaled_passes):
        """Check that every pass carries a valid signature over ``message``.

        :param message: Request-binding data; must equal what the client signed.
        :param marshaled_passes: Iterable of ``(preimage, signature)`` base64
            pairs, as produced by ``TokenRequest.redeem``.
        :return: A human-readable summary string.
        :raises Exception: If ``invalid_sha512`` flags any pass as invalid.
        """
        debug("decoding passes")
        decoded = [
            (
                TokenPreimage.decode_base64(raw_preimage),
                VerificationSignature.decode_base64(raw_sig),
            )
            for raw_preimage, raw_sig in marshaled_passes
        ]

        debug("re-deriving unblinded tokens")
        rederived = [
            self.signing_key.rederive_unblinded_token(preimage)
            for preimage, _ in decoded
        ]
        signatures = [signature for _, signature in decoded]

        debug("deriving verification keys")
        keys = [token.derive_verification_key_sha512() for token in rederived]

        debug("validating verification signatures")
        # NOTE: The client and server must agree on a message somehow.
        # One approach is to derive the message from RPC parameters
        # trivially visible to both client and server (what method are you
        # calling, what arguments did you pass, etc).
        verdicts = [
            key.invalid_sha512(signature, message)
            for key, signature in zip(keys, signatures)
        ]
        if any(verdicts):
            debug("found invalid signature")
            raise Exception("One or more passes was invalid")
        return "Issued, redeemed, and verified {} tokens.".format(len(decoded))
@contextmanager
def timing(label, count):
    """Time the ``with`` body and print one ``label,count,milliseconds`` row.

    The row goes to standard output with the elapsed time formatted to two
    decimal places.  If the body raises, no row is printed and the exception
    propagates.

    :param label: Name of the phase being timed (first CSV column).
    :param count: Value written verbatim in the second CSV column.
    """
    try:
        # perf_counter is monotonic and high-resolution, so the interval
        # cannot be skewed (or go negative) because of wall-clock adjustments.
        from time import perf_counter as clock
    except ImportError:  # Python 2 has no perf_counter; fall back to time().
        clock = time
    before = clock()
    yield
    elapsed_ms = (clock() - before) * 1000
    print("{},{},{:0.2f}".format(label, count, elapsed_ms))
def main(count=b"100"):
    """Run one request/issue/redeem/verify round trip and print timings.

    :param count: Number of tokens to process.  Accepts anything ``int()``
        understands, such as a ``str`` from the command line or the default
        ``bytes`` value.  It is normalised to ``int`` once up front.

    Prints a ``label,count,milliseconds`` header, one timing row per phase,
    and finally the summary string returned by ``Server.verify``.  Progress
    messages go to standard error via ``debug``.
    """
    # Normalise once so every phase gets the same integer and the CSV count
    # column reads "100" rather than "b'100'" when the bytes default is used
    # on Python 3.
    count = int(count)

    # From the protocol, "R".  From the PrivacyPass explanation, "request
    # binding data".
    message = b"allocate_buckets ABCDEFGH"

    debug("generating signing key")
    server = Server(random_signing_key())

    debug("extracting public key")
    # NOTE: Client must obtain the server's public key in some manner it
    # considers reliable.  If server can give a different public key to each
    # client then it can completely defeat PrivacyPass privacy properties.
    client = Client(PublicKey.from_signing_key(server.signing_key))

    print("label,count,milliseconds")
    with timing("request", count):
        request, marshaled_blinded_tokens = client.request(count)
    with timing("issue", count):
        marshaled_signed_tokens, marshaled_proof = server.issue(
            marshaled_blinded_tokens,
        )
    with timing("redeem", count):
        marshaled_passes = request.redeem(
            message, marshaled_signed_tokens, marshaled_proof,
        )
    with timing("verify", count):
        result = server.verify(message, marshaled_passes)
    print(result)
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not on import.
    main(*argv[1:])