diff --git a/cashu/wallet/utils.py b/cashu/wallet/utils.py new file mode 100644 index 00000000..b5b1115b --- /dev/null +++ b/cashu/wallet/utils.py @@ -0,0 +1,10 @@ +def sanitize_url(url: str) -> str: + # extract host from url and lower case it, remove trailing slash from url + protocol = url.split("://")[0] + host = url.split("://")[1].split("/")[0].lower() + path = ( + url.split("://")[1].split("/", 1)[1].rstrip("/") + if "/" in url.split("://")[1] + else "" + ) + return f"{protocol}://{host}{'/' + path if path else ''}" diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index 30bb878f..8a255d5c 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -8,9 +8,6 @@ from bip32 import BIP32 from loguru import logger -from cashu.core.crypto.keys import derive_keyset_id -from cashu.core.json_rpc.base import JSONRPCSubscriptionKinds - from ..core.base import ( BlindedMessage, BlindedSignature, @@ -22,10 +19,16 @@ WalletKeyset, ) from ..core.crypto import b_dhke +from ..core.crypto.keys import derive_keyset_id from ..core.crypto.secp import PrivateKey, PublicKey from ..core.db import Database from ..core.errors import KeysetNotFoundError -from ..core.helpers import amount_summary, calculate_number_of_blank_outputs, sum_proofs +from ..core.helpers import ( + amount_summary, + calculate_number_of_blank_outputs, + sum_proofs, +) +from ..core.json_rpc.base import JSONRPCSubscriptionKinds from ..core.migrations import migrate_databases from ..core.models import ( PostCheckStateResponse, @@ -34,7 +37,8 @@ from ..core.p2pk import Secret from ..core.settings import settings from ..core.split import amount_split -from ..wallet.crud import ( +from . import migrations +from .crud import ( bump_secret_derivation, get_keysets, get_proofs, @@ -48,7 +52,6 @@ update_lightning_invoice, update_proof, ) -from . import migrations from .htlc import WalletHTLC from .mint_info import MintInfo from .p2pk import WalletP2PK @@ -56,6 +59,7 @@ from .secrets import WalletSecrets from .subscriptions import SubscriptionManager from .transactions import WalletTransactions +from .utils import sanitize_url from .v1_api import LedgerAPI @@ -107,7 +111,8 @@ def __init__(self, url: str, db: str, name: str = "no_name", unit: str = "sat"): self.proofs: List[Proof] = [] self.name = name self.unit = Unit[unit] - url = url.rstrip("/") + url = sanitize_url(url) + super().__init__(url=url, db=self.db) logger.debug("Wallet initialized") logger.debug(f"Mint URL: {url}") diff --git a/tests/test_wallet_utils.py b/tests/test_wallet_utils.py new file mode 100644 index 00000000..534db8df --- /dev/null +++ b/tests/test_wallet_utils.py @@ -0,0 +1,63 @@ +from typing import List, Union + +from cashu.core.errors import CashuError +from cashu.wallet.utils import sanitize_url + + +async def assert_err(f, msg: Union[str, CashuError]): + """Compute f() and expect an error message 'msg'.""" + try: + await f + except Exception as exc: + error_message: str = str(exc.args[0]) + if isinstance(msg, CashuError): + if msg.detail not in error_message: + raise Exception( + f"CashuError. Expected error: {msg.detail}, got: {error_message}" + ) + return + if msg not in error_message: + raise Exception(f"Expected error: {msg}, got: {error_message}") + return + raise Exception(f"Expected error: {msg}, got no error") + + +async def assert_err_multiple(f, msgs: List[str]): + """Compute f() and expect an error message 'msg'.""" + try: + await f + except Exception as exc: + for msg in msgs: + if msg in str(exc.args[0]): + return + raise Exception(f"Expected error: {msgs}, got: {exc.args[0]}") + raise Exception(f"Expected error: {msgs}, got no error") + + +def test_sanitize_url(): + url = "https://localhost:3338" + assert sanitize_url(url) == "https://localhost:3338" + + url = "https://mint.com:3338" + assert sanitize_url(url) == "https://mint.com:3338" + + url = "https://Mint.com:3338" + assert sanitize_url(url) == "https://mint.com:3338" + + url = "https://mint.com:3338/" + assert sanitize_url(url) == "https://mint.com:3338" + + url = "https://mint.com:3338/abc" + assert sanitize_url(url) == "https://mint.com:3338/abc" + + url = "https://mint.com:3338/Abc" + assert sanitize_url(url) == "https://mint.com:3338/Abc" + + url = "https://mint.com:3338/abc/" + assert sanitize_url(url) == "https://mint.com:3338/abc" + + url = "https://mint.com:3338/Abc/" + assert sanitize_url(url) == "https://mint.com:3338/Abc" + + url = "https://Mint.com:3338/Abc/def" + assert sanitize_url(url) == "https://mint.com:3338/Abc/def"