-
Notifications
You must be signed in to change notification settings - Fork 3
/
snippets.py
376 lines (291 loc) · 11 KB
/
snippets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
# Copyright 2024 Seznam.cz, a.s.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Snippets are created to make it easier to work with JWTs.
Snippets help with creating, verifying and signing JWTs for NATS.
"""
from __future__ import annotations

import typing as t
from abc import ABC, abstractmethod

from nkeys import from_seed, KeyPair, nkeys

from nats_jwt.nkeys_ext import create_account_pair, create_operator_pair, create_user_pair, keypair_from_pubkey
from nats_jwt.v2.account_claims import AccountClaims
from nats_jwt.v2.claims import AnyClaims, extract_payload_sig_from_jwt
from nats_jwt.v2.operator_claims import OperatorClaims
from nats_jwt.v2.user_claims import UserClaims
# Alias for an encoded JSON Web Token: a plain `str` at runtime; the
# Annotated metadata is only a human-readable description.
JWT = t.Annotated[
    str,
    "Json Web Token",
]
class MissingAttribute(Exception):
    """ Raised when an attribute required for an operation has not been set.

    The exception message names the missing attribute (e.g. ``"claims"``).
    """
def init_new_operator(name: str) -> tuple[KeyPair, str]:
    """ Create a brand-new, self-signed operator identified by `name`.

    Modelled on the NATS GitHub example: a second operator key pair is
    generated and registered as a signing key on the operator claims, and the
    claims are then signed with the operator's own key pair.

    Args:
        name: name of the operator

    Returns:
        tuple of the operator key pair (nkeys.KeyPair) and its encoded JWT (str)
    """
    operator_kp: KeyPair = create_operator_pair()
    operator_claims: OperatorClaims = OperatorClaims(operator_kp.public_key.decode(), name)

    # NOTE(review): only the public half of this signing key is kept (on the
    # claims); its seed is not returned, so callers cannot sign with it later.
    signing_kp: KeyPair = create_operator_pair()
    operator_claims.nats.signing_keys.append(signing_kp.public_key.decode())

    return operator_kp, operator_claims.encode(operator_kp)
def create_account(operator: KeyPair, name: str) -> tuple[KeyPair, str]:
    """ Create a new account identified by `name` and sign it with `operator`.

    A second account key pair is generated and registered as a signing key on
    the account claims before they are encoded.

    Args:
        operator: operator key pair used to sign the account JWT
        name: name of the account

    Returns:
        tuple of the account key pair (nkeys.KeyPair) and its encoded JWT (str)
    """
    account_kp: KeyPair = create_account_pair()
    account_claims: AccountClaims = AccountClaims(account_kp.public_key.decode(), name)

    # NOTE(review): as with operators, the signing key's seed is discarded here;
    # account signing keys use `.add` (set/dict-like container), not `.append`.
    signing_kp: KeyPair = create_account_pair()
    account_claims.nats.signing_keys.add(signing_kp.public_key.decode())

    return account_kp, account_claims.encode(operator)
class Snippet(ABC, t.Generic[AnyClaims]):
    """ Abstract class to help with creating, verifying and signing JWTs for NATS.

    Attributes:
        key_pair: key-pair of the represented entity
        seed_getter: function that returns the nats seed for a given public key
            (``None`` unless a seed or a seed getter was supplied)
        claims: claims of the represented entity (may be ``None``)

    Children should implement:
        claims_t: type of claims that will be used
        new_pair: function that will create a new key pair
    """
    key_pair: KeyPair
    seed_getter: t.Callable[[str | None], bytes | str] = None

    @property
    @abstractmethod
    def claims_t(self) -> t.Type[AnyClaims]:
        """
        Returns:
            type of claims that will be used, e.g. AccountClaims
        """
        pass

    @staticmethod
    @abstractmethod
    def new_pair() -> KeyPair:
        """
        Returns:
            a new key pair for the represented entity
        """
        pass

    claims: AnyClaims | None

    def __init__(
            self,
            jwt: JWT | None = None,
            claims: AnyClaims | None = None,
            seed: bytes | None = None,
            seed_getter: t.Callable[[str], bytes] = None,
    ):
        """ Create a new instance of a snippet.

        `key_pair` is resolved in this order:
            1. `jwt` is given and (for snippets that are also `Verifier`s) its
               claims list signing keys: the first signing key (in iteration
               order) is used. Its seed is obtained through the seed getter when
               one is available; otherwise a public-key-only key pair is built.
            2. `seed` is given: the key pair is built from that seed.
            3. Otherwise a brand-new key pair is created with `new_pair()`.

        Args:
            jwt: JWT of the entity; when given, its decoded claims replace `claims`
            claims: claims of the entity
            seed: nats seed, starting with `S`. Not raw.
            seed_getter: function that returns the nats seed for a given public
                key; not consulted when a truthy `seed` is supplied
        """
        self.claims = claims

        if seed_getter is not None or seed is not None:
            def _get_seed(public_key: str) -> bytes | str:
                # an explicit seed always wins; otherwise ask the getter for the
                # seed of this particular public key (previously the getter
                # function itself was returned instead of being called)
                if seed or seed_getter is None:
                    return seed
                return seed_getter(public_key)

            self.seed_getter = _get_seed

        if jwt is not None:
            self.claims = self.claims_t.decode_claims(jwt)
            signing_keys = self.claims.nats.signing_keys if isinstance(self, Verifier) else None
            if signing_keys:
                # signing_keys is a list for operators but set/dict-like for
                # accounts (see `.add` in create_account), so avoid indexing
                first_key = next(iter(signing_keys))
                if self.seed_getter is not None:
                    # we have everything we need
                    self.key_pair = from_seed(self.seed_getter(first_key))
                    return
                # we can get only the public key; keep it rather than letting it
                # be replaced by an unrelated freshly generated pair below
                self.key_pair = keypair_from_pubkey(first_key)
                return

        if seed is not None:
            self.key_pair = from_seed(seed)
            return

        # nothing usable was passed in, create a new instance
        self.key_pair = self.new_pair()

    def set_claims(self, claims: AnyClaims) -> None:
        """ Setter of claims for the snippet. """
        self.claims = claims

    @property
    def jwt(self) -> JWT:
        """ Return a JWT of the snippet, encoded on every access.

        The claims are signed with this snippet's own `key_pair`.

        Returns:
            JWT of the entity

        Raises:
            MissingAttribute: if claims are not set
        """
        if self.claims is None:
            raise MissingAttribute("claims")
        return self.claims.encode(self.key_pair)
class Verifier:
    """ Mixin for a snippet that can verify JWTs.

    Attributes:
        claims: claims of the entity (can be None)
        key_pair: key-pair of the represented entity
    """
    claims: AnyClaims | None
    key_pair: KeyPair

    def verify(self, jwt: JWT) -> bool:
        """ Verify that a JWT is signed by this entity.

        Without claims, the payload and signature are extracted from the JWT and
        checked against `key_pair`; with claims, verification is delegated to
        `claims.verify_jwt`.

        Args:
            jwt: JWT to verify

        Returns:
            the verification result, or False when nkeys raises
            `ErrInvalidSignature`
        """
        try:
            if self.claims is None:
                # no claims: verify the raw payload/signature with our key pair
                return self.key_pair.verify(*extract_payload_sig_from_jwt(jwt))
            # claims are set: let them verify the token
            return self.claims.verify_jwt(jwt)
        except nkeys.ErrInvalidSignature:
            return False
class Operator(Snippet, Verifier):
    """ NATS operator snippet.

    The operator's own `key_pair` is used to sign the accounts it creates.

    Attributes:
        claims_t: claims type used for operators (OperatorClaims)
        claims: claims of the operator
    """
    claims_t: t.Final[t.Type[AnyClaims]] = OperatorClaims
    claims: OperatorClaims | None

    @staticmethod
    def new_pair() -> KeyPair:
        """ Generate a fresh operator key pair.

        Returns:
            new operator key pair (nkeys.KeyPair)
        """
        return create_operator_pair()

    def create_account(self, name: str) -> "Account":
        """ Build an account that this operator will sign.

        Args:
            name: name of the account

        Returns:
            an Account snippet whose JWT is signed lazily by this operator's
            key pair when its `jwt` property is read
        """
        account_kp: KeyPair = create_account_pair()
        account_claims = AccountClaims(account_kp.public_key.decode(), name)
        return Account(claims=account_claims, signer_kp=self.key_pair, seed=account_kp.seed)
class Account(Snippet, Verifier):
    """ Snippet representing an account in NATS.

    Attributes:
        claims_t: type of claims that will be used, e.g. AccountClaims
        claims: claims of the account (None until set)
        _skp: (protected) key-pair of the operator that will sign the JWT of this account
    """
    claims: AccountClaims | None
    claims_t = AccountClaims

    def __init__(
            self,
            jwt: JWT | None = None,
            claims: AnyClaims | None = None,
            seed: bytes | None = None,
            seed_getter: t.Callable[[str], bytes] = None,
            signer_kp: KeyPair | None = None,
    ):
        """ Create a new instance of an account snippet.

        Args:
            jwt: JWT of the account; when given, its decoded claims replace `claims`
            claims: claims of the account
            seed: nats seed, starting with `S`. Not raw.
            seed_getter: function that will get the nats seed for a given public key.
            signer_kp: key-pair of the operator that will sign the JWT of this
                account; required before reading `jwt`
        """
        super().__init__(jwt, claims, seed, seed_getter)
        # kept apart from self.key_pair: the account's own key pair signs its
        # users (see create_user), while this one signs the account JWT itself
        self._skp: KeyPair | None = signer_kp

    @staticmethod
    def new_pair() -> KeyPair:
        """ Creates a new account key pair.

        Returns:
            new account key pair (instance of nkeys.KeyPair)
        """
        return create_account_pair()

    def create_user(self, name: str) -> "User":
        """ Creates a user for this account.

        Args:
            name: name of the user

        Returns:
            new user snippet, that will be signed by this account on JWT gen (jwt-gen is lazy)
        """
        ukp: KeyPair = create_user_pair()
        uc = UserClaims(ukp.public_key.decode(), name)
        return User(
            claims=uc,
            signer_kp=self.key_pair,
            seed=ukp.seed,
        )

    @property
    def jwt(self) -> JWT:
        """ Return a JWT of the account, signed by the operator key pair (`_skp`).

        Returns:
            JWT of the account, encoded on every access

        Raises:
            MissingAttribute: if claims or the signer key pair are not set
        """
        if self.claims is None:
            raise MissingAttribute("claims")
        if self._skp is None:
            raise MissingAttribute("signer key pair (_skp).")
        return self.claims.encode(self._skp)
class User(Snippet):
    """ NATS user snippet.

    A user's JWT is signed by the account key pair stored in `_skp`,
    not by the user's own `key_pair`.

    Attributes:
        claims_t: claims type used for users (UserClaims)
        claims: claims of the user
        _skp: (protected) key-pair of the account that will sign the JWT of this user
    """
    claims_t = UserClaims

    def __init__(
            self,
            jwt: JWT | None = None,
            claims: AnyClaims | None = None,
            seed: bytes | None = None,
            seed_getter: t.Callable[[str], bytes | str] = None,
            signer_kp: KeyPair | None = None,
    ):
        """ Build a user snippet.

        Args:
            jwt: JWT of the user; when given, its decoded claims replace `claims`
            claims: claims of the user
            seed: nats seed, starting with `S`. Not raw.
            seed_getter: function that will get the nats seed for a given public key.
            signer_kp: account key-pair used to sign this user's JWT
        """
        super().__init__(jwt, claims, seed, seed_getter)
        self._skp: KeyPair | None = signer_kp

    @staticmethod
    def new_pair() -> KeyPair:
        """ Generate a fresh user key pair.

        Returns:
            new user key pair (nkeys.KeyPair)
        """
        return create_user_pair()

    @property
    def jwt(self) -> JWT:
        """ Encode the user claims into a JWT signed by the account key pair.

        Raises:
            MissingAttribute: if claims or the signer key pair are not set
        """
        user_claims: UserClaims | None = self.claims
        if user_claims is None:
            raise MissingAttribute("claims")
        signer = self._skp
        if signer is None:
            raise MissingAttribute("signer key pair(_skp).")
        return user_claims.encode(signer)