Skip to content

Commit

Permalink
Merge #1562: JWT follow-up
Browse files Browse the repository at this point in the history
c88429d JWT authority fixes (roshii)

Pull request description:

  - Fix JWT unit tests (`async` was somehow making tests always successful therefore hiding some issues)
  - Fix `WWW-Authenticate` header construction (#1480 (comment))
  - Encode wallet names with base64 in scopes to allow for space delimited names (#1480 (comment), joinmarket-webui/jam#663 (comment))
  - Fix syntax errors in OpenAPI RPC documentation (#1559)

Top commit has no ACKs.

Tree-SHA512: 6625c4c457c4caf3b4979505334c955bec50fcc0b01707e313dc772571c5c8c8b3ca359a18b5e67f1b0d0eb9b2b7c234ae9716d785234e8de0f3bfb76d53d29a
  • Loading branch information
kristapsk committed Oct 6, 2023
2 parents b27c86e + c88429d commit 28c8413
Showing 6 changed files with 81 additions and 65 deletions.
9 changes: 2 additions & 7 deletions docs/api/wallet-rpc.yaml
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@ paths:
On initially creating, unlocking or recovering a wallet, store both the refresh and access tokens, the latter is valid for only 30 minutes (must be used for any authenticated call) while the former is for 4 hours (can only be used in the refresh request parameters). Use /token endpoint on a regular basis to get new access and refresh tokens, ideally before access token expiration to avoid authentication errors and in any case, before refresh token expiration. The newly issued tokens must be used in subsequent calls since operation invalidates previously issued tokens.
responses:
'200':
$ref: '#/components/responses/RefreshToken-200-OK'
$ref: '#/components/responses/Token-200-OK'
'400':
$ref: '#/components/responses/400-BadRequest'
requestBody:
@@ -579,11 +579,6 @@ paths:
required: true
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/GetSeedResponse'
responses:
'200':
$ref: '#/components/responses/GetSeed-200-OK'
@@ -684,7 +679,7 @@ components:
token_type:
type: string
expires_in:
type: int
type: integer
scope:
type: string
refresh_token:
12 changes: 8 additions & 4 deletions src/jmclient/auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import os
from base64 import b64encode

import jwt

@@ -19,6 +20,9 @@ def get_random_key(size: int = 16) -> str:
return bintohex(os.urandom(size))


def b64str(s: str) -> str:
return b64encode(s.encode()).decode()

class JMTokenAuthority:
"""Manage authorization tokens."""

@@ -57,13 +61,13 @@ def verify(self, token: str, *, token_type: str = "access"):
if not self._scope <= token_claims:
raise InvalidScopeError

def add_to_scope(self, *args: str):
def add_to_scope(self, *args: str, encoded: bool = True):
for arg in args:
self._scope.add(arg)
self._scope.add(b64str(arg) if encoded else arg)

def discard_from_scope(self, *args: str):
def discard_from_scope(self, *args: str, encoded: bool = True):
for arg in args:
self._scope.discard(arg)
self._scope.discard(b64str(arg) if encoded else arg)

@property
def scope(self):
10 changes: 5 additions & 5 deletions src/jmclient/wallet_rpc.py
Original file line number Diff line number Diff line change
@@ -280,10 +280,10 @@ def stopSubServices(self):
self.taker_finished(False)

def auth_err(self, request, error, description=None):
request.setHeader("WWW-Authenticate", "Bearer")
request.setHeader("WWW-Authenticate", f'error="{error}"')
value = f'Bearer, error="{error}"'
if description is not None:
request.setHeader("WWW-Authenticate", f'error_description="{description}"')
value += f', error_description="{description}"'
request.setHeader("WWW-Authenticate", value)
return

def err(self, request, message):
@@ -305,7 +305,7 @@ def invalid_credentials(self, request, failure):
@app.handle_errors(InvalidToken)
def invalid_token(self, request, failure):
request.setResponseCode(401)
return self.auth_err(request, "invalid_token", str(failure))
return self.auth_err(request, "invalid_token", failure.getErrorMessage())

@app.handle_errors(InsufficientScope)
def insufficient_scope(self, request, failure):
@@ -643,7 +643,7 @@ def _mkerr(err, description=""):
"The requested scope is invalid, unknown, malformed, "
"or exceeds the scope granted by the resource owner.",
)
except auth.ExpiredSignatureError:
except Exception:
return _mkerr(
"invalid_grant",
f"The provided {grant_type} is invalid, revoked, "
18 changes: 12 additions & 6 deletions test/jmclient/test_auth.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,12 @@
import jwt
import pytest

from jmclient.auth import ExpiredSignatureError, InvalidScopeError, JMTokenAuthority
from jmclient.auth import (
ExpiredSignatureError,
InvalidScopeError,
JMTokenAuthority,
b64str,
)


class TestJMTokenAuthority:
@@ -17,7 +22,7 @@ class TestJMTokenAuthority:
refresh_sig = copy.copy(token_auth.signature_key["refresh"])

validity = datetime.timedelta(hours=1)
scope = f"walletrpc {wallet_name}"
scope = f"walletrpc {b64str(wallet_name)}"

@pytest.mark.parametrize(
"sig, token_type", [(access_sig, "access"), (refresh_sig, "refresh")]
@@ -83,15 +88,16 @@ def scope_equals(scope):

def test_scope_operation(self):
assert "walletrpc" in self.token_auth._scope
assert self.wallet_name in self.token_auth._scope
assert b64str(self.wallet_name) in self.token_auth._scope

scope = copy.copy(self.token_auth._scope)
s = "new_wallet"

self.token_auth.add_to_scope(s)
assert scope < self.token_auth._scope
assert s in self.token_auth._scope
assert b64str(s) in self.token_auth._scope

self.token_auth.discard_from_scope(s, "walletrpc")
self.token_auth.discard_from_scope(s)
self.token_auth.discard_from_scope("walletrpc", encoded=False)
assert scope > self.token_auth._scope
assert s not in self.token_auth._scope
assert b64str(s) not in self.token_auth._scope
84 changes: 47 additions & 37 deletions test/jmclient/test_wallet_rpc.py
Original file line number Diff line number Diff line change
@@ -29,8 +29,7 @@
from commontest import make_wallets
from test_coinjoin import make_wallets_to_list, sync_wallets

from test_websocket import (ClientTProtocol, test_tx_hex_1,
test_tx_hex_txid, test_token_authority)
from test_websocket import ClientTProtocol, test_tx_hex_1, test_tx_hex_txid

pytestmark = pytest.mark.usefixtures("setup_regtest_bitcoind")

@@ -41,10 +40,6 @@
jlog = get_log()

class JMWalletDaemonT(JMWalletDaemon):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.token = test_token_authority

def check_cookie(self, request, *args, **kwargs):
if self.auth_disabled:
return True
@@ -220,6 +215,7 @@ def test_notif(self):
"ws://127.0.0.1:"+str(self.wss_port),
delay=0.1, callbackfn=self.fire_tx_notif)
self.client_factory.protocol = ClientNotifTestProto
self.client_factory.protocol.ACCESS_TOKEN = self.daemon.token.issue()["token"].encode("utf8")
self.client_connector = connectWS(self.client_factory)
self.attempt_receipt_counter = 0
return task.deferLater(reactor, 0.0, self.wait_to_receive)
@@ -754,22 +750,28 @@ def process_get_seed_response(self, response, code):


class TrialTestWRPC_JWT(WalletRPCTestBase, unittest.TestCase):
@defer.inlineCallbacks
def do_request(self, agent, method, addr, body, handler, token):
headers = Headers({"Authorization": ["Bearer " + token]})
response = yield agent.request(method, addr, headers, bodyProducer=body)
handler(response)

def get_token(self, grant_type: str, status: str = "valid"):
now, delta = datetime.datetime.utcnow(), datetime.timedelta(hours=1)
exp = now - delta if status == "expired" else now + delta

scope = f"walletrpc {self.daemon.wallet_name}"
if status == "invalid_scope":
scope = "walletrpc another_wallet"
scope = status

alg = test_token_authority.SIGNATURE_ALGORITHM
alg = self.daemon.token.SIGNATURE_ALGORITHM
if status == "invalid_alg":
alg = ({"HS256", "HS384", "HS512"} - {alg}).pop()

t = jwt.encode(
{"exp": exp, "scope": scope},
test_token_authority.signature_key[grant_type],
algorithm=test_token_authority.SIGNATURE_ALGORITHM,
self.daemon.token.signature_key[grant_type],
algorithm=alg,
)

if status == "invalid_sig":
@@ -792,22 +794,23 @@ def get_token(self, grant_type: str, status: str = "valid"):

return t

def authorized_response_handler(self, response, code):
assert code == 200
def authorized_response_handler(self, response):
assert response.code == 200

def forbidden_response_handler(self, response, code):
assert code == 403
assert "insufficient_scope" in response.headers.get("WWW-Authenticate")
def forbidden_response_handler(self, response):
assert response.code == 403
assert "insufficient_scope" in response.headers.getRawHeaders("WWW-Authenticate").pop()

def unauthorized_response_handler(self, response, code):
assert code == 401
assert "Bearer" in response.headers.get("WWW-Authenticate")
def unauthorized_response_handler(self, response):
assert response.code == 401
assert "Bearer" in response.headers.getRawHeaders("WWW-Authenticate").pop()

def expired_access_token_response_handler(self, response, code):
self.unauthorized_response_handler(response, code)
assert "expired" in response.headers.get("WWW-Authenticate")
def expired_access_token_response_handler(self, response):
self.unauthorized_response_handler(response)
assert "expired" in response.headers.getRawHeaders("WWW-Authenticate").pop()

async def test_jwt_authentication(self):
@defer.inlineCallbacks
def test_jwt_authentication(self):
"""Test JWT authentication and authorization"""

agent = get_nontor_agent()
@@ -828,31 +831,37 @@ async def test_jwt_authentication(self):
}[responde_handler]
token = self.get_token("access", access_token_status)

await self.do_request(agent, b"GET", addr, None, handler, token)
yield self.do_request(agent, b"GET", addr, None, handler, token)

def successful_refresh_response_handler(self, response, code):
self.authorized_response_handler(response, code)
json_body = json.loads(response.decode("utf-8"))
@defer.inlineCallbacks
def successful_refresh_response_handler(self, response):
self.authorized_response_handler(response)
body = yield readBody(response)
json_body = json.loads(body.decode("utf-8"))
assert {"token", "refresh_token", "expires_in", "token_type", "scope"} <= set(
json_body.keys()
)

@defer.inlineCallbacks
def failed_refresh_response_handler(
self, response, code, *, message=None, error_description=None
self, response, *, message=None, error_description=None
):
assert code == 400
json_body = json.loads(response.decode("utf-8"))
assert response.code == 400
body = yield readBody(response)
json_body = json.loads(body.decode("utf-8"))
if message is not None:
assert json_body.get("message") == message
if error_description is not None:
assert error_description in json_body.get("error_description")

async def do_refresh_request(self, body, handler, token):
@defer.inlineCallbacks
def do_refresh_request(self, body, handler, token):
agent = get_nontor_agent()
addr = (self.get_route_root() + "/token").encode()
body = BytesProducer(json.dumps(body).encode())
await self.do_request(agent, b"POST", addr, body, handler, token)
yield self.do_request(agent, b"POST", addr, body, handler, token)

@defer.inlineCallbacks
def test_refresh_token_request(self):
"""Test token endpoint with valid refresh token"""
for access_token_status, request_status, error in [
@@ -864,7 +873,7 @@ def test_refresh_token_request(self):
if error is None:
handler = self.successful_refresh_response_handler
else:
handler = functools.partialmethod(
handler = functools.partial(
self.failed_refresh_response_handler, message=error
)

@@ -877,23 +886,24 @@ def test_refresh_token_request(self):
if request_status == "unsupported_grant_type":
body["grant_type"] = "joinmarket"

self.do_refresh_request(
yield self.do_refresh_request(
body, handler, self.get_token("access", access_token_status)
)

async def test_refresh_token(self):
@defer.inlineCallbacks
def test_refresh_token(self):
"""Test refresh token endpoint"""
for refresh_token_status, error in [
("expired", "expired"),
("invalid_scope", "invalid_scope"),
("invalid_sig", "invalid_grant"),
]:
if error == "expired":
handler = functools.partialmethod(
handler = functools.partial(
self.failed_refresh_response_handler, error_description=error
)
else:
handler = functools.partialmethod(
handler = functools.partial(
self.failed_refresh_response_handler, message=error
)

@@ -902,7 +912,7 @@ async def test_refresh_token(self):
"refresh_token": self.get_token("refresh", refresh_token_status),
}

self.do_refresh_request(body, handler, self.get_token("access"))
yield self.do_refresh_request(body, handler, self.get_token("access"))


"""
13 changes: 7 additions & 6 deletions test/jmclient/test_websocket.py
Original file line number Diff line number Diff line change
@@ -21,19 +21,20 @@
test_tx_hex_txid = "ca606efc5ba8f6669ba15e9262e5d38e745345ea96106d5a919688d1ff0da0cc"

# Shared JWT token authority for test:
test_token_authority = JMTokenAuthority("dummywallet")
token_authority = JMTokenAuthority()


class ClientTProtocol(WebSocketClientProtocol):
"""
Simple client that connects to a WebSocket server, send a HELLO
message every 2 seconds and print everything it receives.
"""

ACCESS_TOKEN = token_authority.issue()["token"].encode("utf8")

def sendAuth(self):
""" Our server will not broadcast
to us unless we authenticate.
"""
self.sendMessage(test_token_authority.issue()["token"].encode('utf8'))
"""Our server will not broadcast to us unless we authenticate."""
self.sendMessage(self.ACCESS_TOKEN)

def onOpen(self):
# auth on startup
@@ -65,7 +66,7 @@ def setUp(self):
free_ports = get_free_tcp_ports(1)
self.wss_port = free_ports[0]
self.wss_url = "ws://127.0.0.1:" + str(self.wss_port)
self.wss_factory = JmwalletdWebSocketServerFactory(self.wss_url, test_token_authority)
self.wss_factory = JmwalletdWebSocketServerFactory(self.wss_url, token_authority)
self.wss_factory.protocol = JmwalletdWebSocketServerProtocol
self.listeningport = listenWS(self.wss_factory, contextFactory=None)
self.test_tx = CTransaction.deserialize(hextobin(test_tx_hex_1))

0 comments on commit 28c8413

Please sign in to comment.