diff --git a/.env.example b/.env.example index d3b11cc9..7cf46c0c 100644 --- a/.env.example +++ b/.env.example @@ -24,6 +24,9 @@ NOSTR_RELAYS=["wss://nostr-pub.wellorder.net"] # Wallet API port API_PORT=4448 +# Wallet default unit +WALLET_UNIT="sat" + # --------- MINT --------- # Network @@ -38,8 +41,17 @@ MINT_INFO_CONTACT=[["email","contact@me.com"], ["twitter","@me"], ["nostr", "np MINT_INFO_MOTD="Message to users" MINT_PRIVATE_KEY=supersecretprivatekey -# increment derivation path to rotate to a new keyset -MINT_DERIVATION_PATH="0/0/0/0" + +# Increment derivation path to rotate to a new keyset +# Example: m/0'/0'/0' -> m/0'/0'/1' +MINT_DERIVATION_PATH="m/0'/0'/0'" + +# Multiple derivation paths and units. Unit is parsed from the derivation path. +# m/0'/0'/0' is "sat" (default) +# m/0'/1'/0' is "msat" +# m/0'/2'/0' is "usd" +# In this example, we have 2 keysets for sat, 1 for msat and 1 for usd +# MINT_DERIVATION_PATH_LIST=["m/0'/0'/0'", "m/0'/0'/1'", "m/0'/1'/0'", "m/0'/2'/0'"] MINT_DATABASE=data/mint diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 873fce80..82296e13 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,8 +17,7 @@ jobs: poetry-version: ["1.7.1"] mint-cache-secrets: ["false", "true"] mint-only-deprecated: ["false", "true"] - # db-url: ["", "postgres://cashu:cashu@localhost:5432/test"] # TODO: Postgres test not working - db-url: [""] + mint-database: ["./test_data/test_mint", "postgres://cashu:cashu@localhost:5432/cashu"] backend-wallet-class: ["FakeWallet"] uses: ./.github/workflows/tests.yml with: @@ -27,6 +26,7 @@ jobs: poetry-version: ${{ matrix.poetry-version }} mint-cache-secrets: ${{ matrix.mint-cache-secrets }} mint-only-deprecated: ${{ matrix.mint-only-deprecated }} + mint-database: ${{ matrix.mint-database }} regtest: uses: ./.github/workflows/regtest.yml strategy: @@ -38,3 +38,4 @@ jobs: with: python-version: ${{ matrix.python-version }} backend-wallet-class: ${{ matrix.backend-wallet-class }} + mint-database: "./test_data/test_mint" diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml new file mode 100644 index 00000000..ded34be9 --- /dev/null +++ b/.github/workflows/docker.yaml @@ -0,0 +1,38 @@ +name: Docker Build + +on: + push: + release: + types: [published] + +jobs: + build-and-push: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Determine Tag + id: get_tag + run: | + if [[ "${{ github.event_name }}" == "push" ]]; then + echo "::set-output name=tag::latest" + elif [[ "${{ github.event_name }}" == "release" ]]; then + echo "::set-output name=tag::${{ github.event.release.tag_name }}" + fi + + - name: Build and push on release + uses: docker/build-push-action@v5 + with: + context: . 
+ push: ${{ github.event_name == 'release' }} + tags: ${{ secrets.DOCKER_USERNAME }}/${{ github.event.repository.name }}:${{ steps.get_tag.outputs.tag }} diff --git a/.github/workflows/regtest.yml b/.github/workflows/regtest.yml index 48653736..46e2d8ed 100644 --- a/.github/workflows/regtest.yml +++ b/.github/workflows/regtest.yml @@ -12,7 +12,7 @@ on: os-version: default: "ubuntu-latest" type: string - db-url: + mint-database: default: "" type: string backend-wallet-class: @@ -23,6 +23,20 @@ jobs: regtest: runs-on: ${{ inputs.os-version }} timeout-minutes: 10 + services: + postgres: + image: postgres:latest + env: + POSTGRES_USER: cashu + POSTGRES_PASSWORD: cashu + POSTGRES_DB: cashu + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 steps: - uses: actions/checkout@v3 @@ -47,7 +61,7 @@ jobs: WALLET_NAME: test_wallet MINT_HOST: localhost MINT_PORT: 3337 - MINT_DATABASE: ${{ inputs.db-url }} + MINT_TEST_DATABASE: ${{ inputs.mint-database }} TOR: false MINT_LIGHTNING_BACKEND: ${{ inputs.backend-wallet-class }} MINT_LNBITS_ENDPOINT: http://localhost:5001 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6659b06a..cf1a176b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,7 +9,7 @@ on: poetry-version: default: "1.7.1" type: string - db-url: + mint-database: default: "" type: string os: @@ -24,7 +24,7 @@ on: jobs: poetry: - name: Run (mint-cache-secrets ${{ inputs.mint-cache-secrets }}, mint-only-deprecated ${{ inputs.mint-only-deprecated }}) + name: Run (mint-cache-secrets ${{ inputs.mint-cache-secrets }}, mint-only-deprecated ${{ inputs.mint-only-deprecated }}, mint-database ${{ inputs.mint-database }}) runs-on: ${{ inputs.os }} services: postgres: @@ -53,7 +53,7 @@ jobs: WALLET_NAME: test_wallet MINT_HOST: localhost MINT_PORT: 3337 - MINT_DATABASE: ${{ inputs.db-url }} + MINT_TEST_DATABASE: ${{ inputs.mint-database }} MINT_CACHE_SECRETS: ${{ inputs.mint-cache-secrets }} DEBUG_MINT_ONLY_DEPRECATED: ${{ inputs.mint-only-deprecated }} TOR: false diff --git a/Dockerfile b/Dockerfile index 6d65364a..da3eb565 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.9-slim +FROM python:3.10-slim RUN apt-get update RUN apt-get install -y curl python3-dev autoconf g++ RUN apt-get install -y libpq-dev diff --git a/Makefile b/Makefile index 7bc6d6e9..b9e79292 100644 --- a/Makefile +++ b/Makefile @@ -59,3 +59,11 @@ install-pre-commit-hook: pre-commit: poetry run pre-commit run --all-files + +docker-build: + rm -rf docker-build || true + mkdir -p docker-build + git clone . docker-build + cd docker-build + docker buildx build -f Dockerfile -t cashubtc/nutshell:0.15.0 --platform linux/amd64 . 
+ # docker push cashubtc/nutshell:0.15.0 diff --git a/cashu/core/base.py b/cashu/core/base.py index 8abc69f5..b32933b2 100644 --- a/cashu/core/base.py +++ b/cashu/core/base.py @@ -9,6 +9,8 @@ from loguru import logger from pydantic import BaseModel, Field +from .crypto.aes import AESCipher +from .crypto.b_dhke import hash_to_curve from .crypto.keys import ( derive_keys, derive_keys_sha256, @@ -87,8 +89,9 @@ class Proof(BaseModel): id: Union[None, str] = "" amount: int = 0 secret: str = "" # secret or message to be blinded and signed + Y: str = "" # hash_to_curve(secret) C: str = "" # signature on secret, unblinded by wallet - dleq: Union[DLEQWallet, None] = None # DLEQ proof + dleq: Optional[DLEQWallet] = None # DLEQ proof witness: Union[None, str] = "" # witness for spending condition # whether this proof is reserved for sending, used for coin management in the wallet @@ -105,6 +108,11 @@ class Proof(BaseModel): None # holds the id of the melt operation that destroyed this proof ) + def __init__(self, **data): + super().__init__(**data) + if not self.Y: + self.Y = hash_to_curve(self.secret.encode("utf-8")).serialize().hex() + @classmethod def from_dict(cls, proof_dict: dict): if proof_dict.get("dleq") and isinstance(proof_dict["dleq"], str): @@ -218,10 +226,37 @@ class MeltQuote(BaseModel): amount: int fee_reserve: int paid: bool - created_time: int = 0 - paid_time: int = 0 + created_time: Union[int, None] = None + paid_time: Union[int, None] = None fee_paid: int = 0 proof: str = "" + expiry: Optional[int] = None + + @classmethod + def from_row(cls, row: Row): + try: + created_time = int(row["created_time"]) if row["created_time"] else None + paid_time = int(row["paid_time"]) if row["paid_time"] else None + except Exception: + created_time = ( + int(row["created_time"].timestamp()) if row["created_time"] else None + ) + paid_time = int(row["paid_time"].timestamp()) if row["paid_time"] else None + + return cls( + quote=row["quote"], + method=row["method"], + request=row["request"], + checking_id=row["checking_id"], + unit=row["unit"], + amount=row["amount"], + fee_reserve=row["fee_reserve"], + paid=row["paid"], + created_time=created_time, + paid_time=paid_time, + fee_paid=row["fee_paid"], + proof=row["proof"], + ) class MintQuote(BaseModel): @@ -233,9 +268,35 @@ class MintQuote(BaseModel): amount: int paid: bool issued: bool - created_time: int = 0 - paid_time: int = 0 - expiry: int = 0 + created_time: Union[int, None] = None + paid_time: Union[int, None] = None + expiry: Optional[int] = None + + @classmethod + def from_row(cls, row: Row): + + try: + # SQLITE: row is timestamp (string) + created_time = int(row["created_time"]) if row["created_time"] else None + paid_time = int(row["paid_time"]) if row["paid_time"] else None + except Exception: + # POSTGRES: row is datetime.datetime + created_time = ( + int(row["created_time"].timestamp()) if row["created_time"] else None + ) + paid_time = int(row["paid_time"].timestamp()) if row["paid_time"] else None + return cls( + quote=row["quote"], + method=row["method"], + request=row["request"], + checking_id=row["checking_id"], + unit=row["unit"], + amount=row["amount"], + paid=row["paid"], + issued=row["issued"], + created_time=created_time, + paid_time=paid_time, + ) # ------- API ------- @@ -309,7 +370,7 @@ class PostMintQuoteResponse(BaseModel): quote: str # quote id request: str # input payment request paid: bool # whether the request has been paid - expiry: int # expiry of the quote + expiry: Optional[int] # expiry of the quote # ------- 
API: MINT ------- @@ -356,6 +417,7 @@ class PostMeltQuoteResponse(BaseModel): amount: int # input amount fee_reserve: int # input fee reserve paid: bool # whether the request has been paid + expiry: Optional[int] # expiry of the quote # ------- API: MELT ------- @@ -365,7 +427,7 @@ class PostMeltRequest(BaseModel): quote: str = Field(..., max_length=settings.mint_max_request_length) # quote id inputs: List[Proof] = Field(..., max_items=settings.mint_max_request_length) outputs: Union[List[BlindedMessage], None] = Field( - ..., max_items=settings.mint_max_request_length + None, max_items=settings.mint_max_request_length ) @@ -379,7 +441,7 @@ class PostMeltRequest_deprecated(BaseModel): proofs: List[Proof] = Field(..., max_items=settings.mint_max_request_length) pr: str = Field(..., max_length=settings.mint_max_request_length) outputs: Union[List[BlindedMessage], None] = Field( - ..., max_items=settings.mint_max_request_length + None, max_items=settings.mint_max_request_length ) @@ -641,35 +703,50 @@ class MintKeyset: unit: Unit derivation_path: str seed: Optional[str] = None - public_keys: Union[Dict[int, PublicKey], None] = None - valid_from: Union[str, None] = None - valid_to: Union[str, None] = None - first_seen: Union[str, None] = None - version: Union[str, None] = None + encrypted_seed: Optional[str] = None + seed_encryption_method: Optional[str] = None + public_keys: Optional[Dict[int, PublicKey]] = None + valid_from: Optional[str] = None + valid_to: Optional[str] = None + first_seen: Optional[str] = None + version: Optional[str] = None duplicate_keyset_id: Optional[str] = None # BACKWARDS COMPATIBILITY < 0.15.0 def __init__( self, *, - id="", - valid_from=None, - valid_to=None, - first_seen=None, - active=None, + derivation_path: str, seed: Optional[str] = None, - derivation_path: Optional[str] = None, + encrypted_seed: Optional[str] = None, + seed_encryption_method: Optional[str] = None, + valid_from: Optional[str] = None, + valid_to: Optional[str] = None, + first_seen: Optional[str] = None, + active: Optional[bool] = None, unit: Optional[str] = None, - version: str = "0", + version: Optional[str] = None, + id: str = "", ): - self.derivation_path = derivation_path or "" - self.seed = seed + self.derivation_path = derivation_path + + if encrypted_seed and not settings.mint_seed_decryption_key: + raise Exception("MINT_SEED_DECRYPTION_KEY not set, but seed is encrypted.") + if settings.mint_seed_decryption_key and encrypted_seed: + self.seed = AESCipher(settings.mint_seed_decryption_key).decrypt( + encrypted_seed + ) + else: + self.seed = seed + + assert self.seed, "seed not set" + self.id = id self.valid_from = valid_from self.valid_to = valid_to self.first_seen = first_seen self.active = bool(active) if active is not None else False - self.version = version + self.version = version or settings.version self.version_tuple = tuple( [int(i) for i in self.version.split(".")] if self.version else [] @@ -677,7 +754,7 @@ def __init__( # infer unit from derivation path if not unit: - logger.warning( + logger.trace( f"Unit for keyset {self.derivation_path} not set – attempting to parse" " from derivation path" ) @@ -685,9 +762,9 @@ def __init__( self.unit = Unit( int(self.derivation_path.split("/")[2].replace("'", "")) ) - logger.warning(f"Inferred unit: {self.unit.name}") + logger.trace(f"Inferred unit: {self.unit.name}") except Exception: - logger.warning( + logger.trace( "Could not infer unit from derivation path" f" {self.derivation_path} – assuming 'sat'" ) @@ -696,10 +773,12 @@ def 
__init__( self.unit = Unit[unit] # generate keys from seed - if self.seed and self.derivation_path: - self.generate_keys() + assert self.seed, "seed not set" + assert self.derivation_path, "derivation path not set" + + self.generate_keys() - logger.debug(f"Keyset id: {self.id} ({self.unit.name})") + logger.trace(f"Loaded keyset id: {self.id} ({self.unit.name})") @property def public_keys_hex(self) -> Dict[int, str]: @@ -720,14 +799,14 @@ def generate_keys(self): self.seed, self.derivation_path ) self.public_keys = derive_pubkeys(self.private_keys) # type: ignore - logger.warning( + logger.trace( f"WARNING: Using weak key derivation for keyset {self.id} (backwards" " compatibility < 0.12)" ) self.id = derive_keyset_id_deprecated(self.public_keys) # type: ignore elif self.version_tuple < (0, 15): self.private_keys = derive_keys_sha256(self.seed, self.derivation_path) - logger.warning( + logger.trace( f"WARNING: Using non-bip32 derivation for keyset {self.id} (backwards" " compatibility < 0.15)" ) diff --git a/cashu/core/crypto/aes.py b/cashu/core/crypto/aes.py new file mode 100644 index 00000000..7f856bf7 --- /dev/null +++ b/cashu/core/crypto/aes.py @@ -0,0 +1,65 @@ +import base64 +from hashlib import sha256 + +from Cryptodome import Random +from Cryptodome.Cipher import AES + +BLOCK_SIZE = 16 + + +class AESCipher: + """This class is compatible with crypto-js/aes.js + + Encrypt and decrypt in Javascript using: + import AES from "crypto-js/aes.js"; + import Utf8 from "crypto-js/enc-utf8.js"; + AES.encrypt(decrypted, password).toString() + AES.decrypt(encrypted, password).toString(Utf8); + + """ + + def __init__(self, key: str, description=""): + self.key: str = key + self.description = description + " " + + def pad(self, data): + length = BLOCK_SIZE - (len(data) % BLOCK_SIZE) + return data + (chr(length) * length).encode() + + def unpad(self, data): + return data[: -(data[-1] if isinstance(data[-1], int) else ord(data[-1]))] + + def bytes_to_key(self, data, salt, output=48): + # extended from https://gist.github.com/gsakkis/4546068 + assert len(salt) == 8, len(salt) + data += salt + key = sha256(data).digest() + final_key = key + while len(final_key) < output: + key = sha256(key + data).digest() + final_key += key + return final_key[:output] + + def decrypt(self, encrypted: str) -> str: # type: ignore + """Decrypts a string using AES-256-CBC.""" + encrypted = base64.urlsafe_b64decode(encrypted) # type: ignore + assert encrypted[0:8] == b"Salted__" + salt = encrypted[8:16] + key_iv = self.bytes_to_key(self.key.encode(), salt, 32 + 16) + key = key_iv[:32] + iv = key_iv[32:] + aes = AES.new(key, AES.MODE_CBC, iv) + try: + return self.unpad(aes.decrypt(encrypted[16:])).decode() # type: ignore + except UnicodeDecodeError: + raise ValueError("Wrong passphrase") + + def encrypt(self, message: bytes) -> str: + salt = Random.new().read(8) + key_iv = self.bytes_to_key(self.key.encode(), salt, 32 + 16) + key = key_iv[:32] + iv = key_iv[32:] + aes = AES.new(key, AES.MODE_CBC, iv) + return base64.urlsafe_b64encode( + b"Salted__" + salt + aes.encrypt(self.pad(message)) + ).decode() diff --git a/cashu/core/crypto/b_dhke.py b/cashu/core/crypto/b_dhke.py index e8706239..4abf1b77 100644 --- a/cashu/core/crypto/b_dhke.py +++ b/cashu/core/crypto/b_dhke.py @@ -71,6 +71,36 @@ def hash_to_curve(message: bytes) -> PublicKey: return point +DOMAIN_SEPARATOR = b"Secp256k1_HashToCurve_Cashu_" + + +def hash_to_curve_domain_separated(message: bytes) -> PublicKey: + """Generates a secp256k1 point from a message. 
+ + The point is generated by hashing the message with a domain separator and then + iteratively trying to compute a point from the hash. An increasing uint32 counter + (byte order little endian) is appended to the hash until a point is found that lies on the curve. + + The chance of finding a valid point is 50% for every iteration. The maximum number of iterations + is 2**16. If no valid point is found after 2**16 iterations, a ValueError is raised (this should + never happen in practice). + + The domain separator is b"Secp256k1_HashToCurve_Cashu_" or + bytes.fromhex("536563703235366b315f48617368546f43757276655f43617368755f"). + """ + msg_to_hash = hashlib.sha256(DOMAIN_SEPARATOR + message).digest() + counter = 0 + while counter < 2**16: + _hash = hashlib.sha256(msg_to_hash + counter.to_bytes(4, "little")).digest() + try: + # will error if point does not lie on curve + return PublicKey(b"\x02" + _hash, raw=True) + except Exception: + counter += 1 + # it should never reach this point + raise ValueError("No valid point found") + + def step1_alice( secret_msg: str, blinding_factor: Optional[PrivateKey] = None ) -> tuple[PublicKey, PrivateKey]: @@ -80,6 +110,15 @@ def step1_alice( return B_, r +def step1_alice_domain_separated( + secret_msg: str, blinding_factor: Optional[PrivateKey] = None +) -> tuple[PublicKey, PrivateKey]: + Y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8")) + r = blinding_factor or PrivateKey() + B_: PublicKey = Y + r.pubkey # type: ignore + return B_, r + + def step2_bob(B_: PublicKey, a: PrivateKey) -> Tuple[PublicKey, PrivateKey, PrivateKey]: C_: PublicKey = B_.mult(a) # type: ignore # produce dleq proof @@ -94,7 +133,13 @@ def step3_alice(C_: PublicKey, r: PrivateKey, A: PublicKey) -> PublicKey: def verify(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool: Y: PublicKey = hash_to_curve(secret_msg.encode("utf-8")) - return C == Y.mult(a) # type: ignore + valid = C == Y.mult(a) # type: ignore + # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1 + if not valid: + Y1: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8")) + return C == Y1.mult(a) # type: ignore + # END: BACKWARDS COMPATIBILITY < 0.15.1 + return valid def hash_e(*publickeys: PublicKey) -> bytes: @@ -149,7 +194,14 @@ def carol_verify_dleq( Y: PublicKey = hash_to_curve(secret_msg.encode("utf-8")) C_: PublicKey = C + A.mult(r) # type: ignore B_: PublicKey = Y + r.pubkey # type: ignore - return alice_verify_dleq(B_, C_, e, s, A) + valid = alice_verify_dleq(B_, C_, e, s, A) + # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1 + if not valid: + Y1: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8")) + B_1: PublicKey = Y1 + r.pubkey # type: ignore + return alice_verify_dleq(B_1, C_, e, s, A) + # END: BACKWARDS COMPATIBILITY < 0.15.1 + return valid # Below is a test of a simple positive and negative case diff --git a/cashu/core/db.py b/cashu/core/db.py index 1f96910a..1863c605 100644 --- a/cashu/core/db.py +++ b/cashu/core/db.py @@ -31,7 +31,8 @@ def timestamp_now(self) -> str: if self.type in {POSTGRES, COCKROACH}: return "now()" elif self.type == SQLITE: - return "(strftime('%s', 'now'))" + # return "(strftime('%s', 'now'))" + return str(int(time.time())) return "" @property @@ -204,6 +205,26 @@ def lock_table(db: Database, table: str) -> str: return "" +def timestamp_from_seconds( + db: Database, seconds: Union[int, float, None] +) -> Union[str, None]: + if seconds is None: + return None + seconds = int(seconds) + if db.type in {POSTGRES, COCKROACH}: + return 
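# A minimal, dependency-free sketch of the domain-separated hash_to_curve loop
# defined in b_dhke.py above. Assumption: the pure-Python quadratic-residue check
# below stands in for the secp256k1 binding's point validation; the real code
# simply attempts PublicKey(b"\x02" + _hash, raw=True) on each candidate hash.
import hashlib

DOMAIN_SEPARATOR = b"Secp256k1_HashToCurve_Cashu_"
P = 2**256 - 2**32 - 977  # secp256k1 field prime


def x_is_on_curve(x: int) -> bool:
    # A compressed point 02||x is valid iff x < P and x^3 + 7 is a square mod P.
    if x >= P:
        return False
    y_sq = (pow(x, 3, P) + 7) % P
    y = pow(y_sq, (P + 1) // 4, P)  # modular square root, valid because P % 4 == 3
    return pow(y, 2, P) == y_sq


def hash_to_curve_compressed(message: bytes) -> bytes:
    msg_to_hash = hashlib.sha256(DOMAIN_SEPARATOR + message).digest()
    counter = 0
    while counter < 2**16:
        _hash = hashlib.sha256(msg_to_hash + counter.to_bytes(4, "little")).digest()
        if x_is_on_curve(int.from_bytes(_hash, "big")):
            return b"\x02" + _hash  # compressed point, as returned above
        counter += 1
    raise ValueError("No valid point found")


print(hash_to_curve_compressed(b"test_message").hex())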
datetime.datetime.fromtimestamp(seconds).strftime("%Y-%m-%d %H:%M:%S") + elif db.type == SQLITE: + return str(seconds) + return None + + +def timestamp_now(db: Database) -> str: + timestamp = timestamp_from_seconds(db, time.time()) + if timestamp is None: + raise Exception("Timestamp is None") + return timestamp + + @asynccontextmanager async def get_db_connection(db: Database, conn: Optional[Connection] = None): """Either yield the existing database connection or create a new one. diff --git a/cashu/core/errors.py b/cashu/core/errors.py index fa2ca4af..d36614a4 100644 --- a/cashu/core/errors.py +++ b/cashu/core/errors.py @@ -12,7 +12,7 @@ def __init__(self, detail, code=0): class NotAllowedError(CashuError): - detail = "Not allowed." + detail = "not allowed" code = 10000 def __init__(self, detail: Optional[str] = None, code: Optional[int] = None): @@ -20,7 +20,7 @@ def __init__(self, detail: Optional[str] = None, code: Optional[int] = None): class TransactionError(CashuError): - detail = "Transaction error." + detail = "transaction error" code = 11000 def __init__(self, detail: Optional[str] = None, code: Optional[int] = None): @@ -36,7 +36,7 @@ def __init__(self): class SecretTooLongError(TransactionError): - detail = "Secret too long." + detail = "secret too long" code = 11003 def __init__(self): @@ -44,7 +44,7 @@ def __init__(self): class NoSecretInProofsError(TransactionError): - detail = "No secret in proofs." + detail = "no secret in proofs" code = 11004 def __init__(self): @@ -52,7 +52,7 @@ def __init__(self): class KeysetError(CashuError): - detail = "Keyset error." + detail = "keyset error" code = 12000 def __init__(self, detail: Optional[str] = None, code: Optional[int] = None): @@ -60,7 +60,7 @@ def __init__(self, detail: Optional[str] = None, code: Optional[int] = None): class KeysetNotFoundError(KeysetError): - detail = "Keyset not found." + detail = "keyset not found" code = 12001 def __init__(self): @@ -68,15 +68,15 @@ def __init__(self): class LightningError(CashuError): - detail = "Lightning error." + detail = "Lightning error" code = 20000 def __init__(self, detail: Optional[str] = None, code: Optional[int] = None): super().__init__(detail or self.detail, code=code or self.code) -class InvoiceNotPaidError(CashuError): - detail = "Lightning invoice not paid yet." +class QuoteNotPaidError(CashuError): + detail = "quote not paid" code = 20001 def __init__(self): diff --git a/cashu/core/migrations.py b/cashu/core/migrations.py index ce3e2bfe..7bab5eeb 100644 --- a/cashu/core/migrations.py +++ b/cashu/core/migrations.py @@ -1,8 +1,44 @@ +import os import re +import time from loguru import logger from ..core.db import COCKROACH, POSTGRES, SQLITE, Database, table_with_schema +from ..core.settings import settings + + +async def backup_database(db: Database, version: int = 0) -> str: + # for postgres: use pg_dump + # for sqlite: use sqlite3 + + # skip backups if db_backup_path is None + # and if version is 0 (fresh database) + if not settings.db_backup_path or not version: + return "" + + filename = f"backup_{db.name}_{int(time.time())}_v{version}" + try: + # create backup directory if it doesn't exist + os.makedirs(os.path.join(settings.db_backup_path), exist_ok=True) + except Exception as e: + logger.error( + f"Error creating backup directory: {e}. Run with BACKUP_DB_MIGRATION=False" + " to disable backups before database migrations." 
+ ) + raise e + filepath = os.path.join(settings.db_backup_path, filename) + + if db.type == SQLITE: + filepath = f"{filepath}.sqlite3" + logger.info(f"Creating {db.type} backup of {db.name} db to {filepath}") + os.system(f"cp {db.path} {filepath}") + elif db.type in {POSTGRES, COCKROACH}: + filepath = f"{filepath}.dump" + logger.info(f"Creating {db.type} backup of {db.name} db to {filepath}") + os.system(f"pg_dump --dbname={db.db_location} --file={filepath}") + + return filepath async def migrate_databases(db: Database, migrations_module): @@ -19,6 +55,21 @@ async def set_migration_version(conn, db_name, version): async def run_migration(db, migrations_module): db_name = migrations_module.__name__.split(".")[-2] + # we first check whether any migration is needed and create a backup if so + migration_needed = False + for key, migrate in migrations_module.__dict__.items(): + match = matcher.match(key) + if match: + version = int(match.group(1)) + if version > current_versions.get(db_name, 0): + migration_needed = True + break + if migration_needed and settings.db_backup_path: + logger.debug(f"Creating backup of {db_name} db") + current_version = current_versions.get(db_name, 0) + await backup_database(db, current_version) + + # then we run the migrations for key, migrate in migrations_module.__dict__.items(): match = matcher.match(key) if match: diff --git a/cashu/core/p2pk.py b/cashu/core/p2pk.py index fab14967..f42a3a95 100644 --- a/cashu/core/p2pk.py +++ b/cashu/core/p2pk.py @@ -68,16 +68,16 @@ def n_sigs(self) -> Union[None, int]: return int(n_sigs) if n_sigs else None -def sign_p2pk_sign(message: bytes, private_key: PrivateKey): +def sign_p2pk_sign(message: bytes, private_key: PrivateKey) -> bytes: # ecdsa version # signature = private_key.ecdsa_serialize(private_key.ecdsa_sign(message)) signature = private_key.schnorr_sign( hashlib.sha256(message).digest(), None, raw=True ) - return signature.hex() + return signature -def verify_p2pk_signature(message: bytes, pubkey: PublicKey, signature: bytes): +def verify_p2pk_signature(message: bytes, pubkey: PublicKey, signature: bytes) -> bool: # ecdsa version # return pubkey.ecdsa_verify(message, pubkey.ecdsa_deserialize(signature)) return pubkey.schnorr_verify( diff --git a/cashu/core/settings.py b/cashu/core/settings.py index 689287ea..a7d126e3 100644 --- a/cashu/core/settings.py +++ b/cashu/core/settings.py @@ -17,7 +17,7 @@ def find_env_file(): if not os.path.isfile(env_file): env_file = os.path.join(str(Path.home()), ".cashu", ".env") if os.path.isfile(env_file): - env.read_env(env_file) + env.read_env(env_file, recurse=False, override=True) else: env_file = "" return env_file @@ -45,21 +45,50 @@ class EnvSettings(CashuSettings): cashu_dir: str = Field(default=os.path.join(str(Path.home()), ".cashu")) debug_profiling: bool = Field(default=False) debug_mint_only_deprecated: bool = Field(default=False) + db_backup_path: Optional[str] = Field(default=None) class MintSettings(CashuSettings): mint_private_key: str = Field(default=None) + mint_seed_decryption_key: str = Field(default=None) mint_derivation_path: str = Field(default="m/0'/0'/0'") mint_derivation_path_list: List[str] = Field(default=[]) mint_listen_host: str = Field(default="127.0.0.1") mint_listen_port: int = Field(default=3338) mint_lightning_backend: str = Field(default="LNbitsWallet") mint_database: str = Field(default="data/mint") - mint_peg_out_only: bool = Field(default=False) - mint_max_peg_in: int = Field(default=None) - mint_max_peg_out: int = Field(default=None) - 
mint_max_request_length: int = Field(default=1000) - mint_max_balance: int = Field(default=None) + mint_test_database: str = Field(default="test_data/test_mint") + mint_peg_out_only: bool = Field( + default=False, + title="Peg-out only", + description="Mint allows no mint operations.", + ) + mint_max_peg_in: int = Field( + default=None, + title="Maximum peg-in", + description="Maximum amount for a mint operation.", + ) + mint_max_peg_out: int = Field( + default=None, + title="Maximum peg-out", + description="Maximum amount for a melt operation.", + ) + mint_max_request_length: int = Field( + default=1000, + title="Maximum request length", + description="Maximum length of REST API request arrays.", + ) + mint_max_balance: int = Field( + default=None, title="Maximum mint balance", description="Maximum mint balance." + ) + mint_duplicate_keysets: bool = Field( + default=True, + title="Duplicate keysets", + description=( + "Whether to duplicate keysets for backwards compatibility before v1 API" + " (Nutshell 0.15.0)." + ), + ) mint_lnbits_endpoint: str = Field(default=None) mint_lnbits_key: str = Field(default=None) @@ -84,7 +113,7 @@ class MintInformation(CashuSettings): class WalletSettings(CashuSettings): - tor: bool = Field(default=True) + tor: bool = Field(default=False) socks_host: str = Field(default=None) # deprecated socks_port: int = Field(default=9050) # deprecated socks_proxy: str = Field(default=None) @@ -94,7 +123,7 @@ class WalletSettings(CashuSettings): mint_port: int = Field(default=3338) wallet_name: str = Field(default="wallet") wallet_unit: str = Field(default="sat") - + wallet_domain_separation: bool = Field(default=False) api_port: int = Field(default=4448) api_host: str = Field(default="127.0.0.1") @@ -110,6 +139,7 @@ class WalletSettings(CashuSettings): ) locktime_delta_seconds: int = Field(default=86400) # 1 day + proofs_batch_size: int = Field(default=1000) class LndRestFundingSource(MintSettings): diff --git a/cashu/lightning/fake.py b/cashu/lightning/fake.py index dc6a6d91..8da840ee 100644 --- a/cashu/lightning/fake.py +++ b/cashu/lightning/fake.py @@ -66,8 +66,7 @@ async def create_invoice( else: tags.add(TagChar.description, memo or "") - if expiry: - tags.add(TagChar.expire_time, expiry) + tags.add(TagChar.expire_time, expiry or 3600) if payment_secret: secret = payment_secret.hex() diff --git a/cashu/mint/conditions.py b/cashu/mint/conditions.py index d48c06ef..051947c3 100644 --- a/cashu/mint/conditions.py +++ b/cashu/mint/conditions.py @@ -83,7 +83,7 @@ def _verify_p2pk_spending_conditions(self, proof: Proof, secret: Secret) -> bool logger.trace(f"verifying signature {input_sig} by pubkey {pubkey}.") logger.trace(f"Message: {p2pk_secret.serialize().encode('utf-8')}") if verify_p2pk_signature( - message=p2pk_secret.serialize().encode("utf-8"), + message=proof.secret.encode("utf-8"), pubkey=PublicKey(bytes.fromhex(pubkey), raw=True), signature=bytes.fromhex(input_sig), ): @@ -154,7 +154,7 @@ def _verify_htlc_spending_conditions(self, proof: Proof, secret: Secret) -> bool assert signature, TransactionError("no HTLC refund signature provided") for pubkey in refund_pubkeys: if verify_p2pk_signature( - message=htlc_secret.serialize().encode("utf-8"), + message=proof.secret.encode("utf-8"), pubkey=PublicKey(bytes.fromhex(pubkey), raw=True), signature=bytes.fromhex(signature), ): @@ -181,7 +181,7 @@ def _verify_htlc_spending_conditions(self, proof: Proof, secret: Secret) -> bool assert signature, TransactionError("HTLC no hash lock signatures provided.") for pubkey in 
hashlock_pubkeys: if verify_p2pk_signature( - message=htlc_secret.serialize().encode("utf-8"), + message=proof.secret.encode("utf-8"), pubkey=PublicKey(bytes.fromhex(pubkey), raw=True), signature=bytes.fromhex(signature), ): @@ -305,7 +305,7 @@ def _verify_output_p2pk_spending_conditions( for sig in p2pksigs: for pubkey in pubkeys: if verify_p2pk_signature( - message=output.B_.encode("utf-8"), + message=bytes.fromhex(output.B_), pubkey=PublicKey(bytes.fromhex(pubkey), raw=True), signature=bytes.fromhex(sig), ): diff --git a/cashu/mint/crud.py b/cashu/mint/crud.py index 2d9660a3..853ff094 100644 --- a/cashu/mint/crud.py +++ b/cashu/mint/crud.py @@ -1,4 +1,3 @@ -import time from abc import ABC, abstractmethod from typing import Any, List, Optional @@ -9,7 +8,13 @@ MintQuote, Proof, ) -from ..core.db import Connection, Database, table_with_schema +from ..core.db import ( + Connection, + Database, + table_with_schema, + timestamp_from_seconds, + timestamp_now, +) class LedgerCrud(ABC): @@ -27,6 +32,7 @@ async def get_keyset( db: Database, id: str = "", derivation_path: str = "", + seed: str = "", conn: Optional[Connection] = None, ) -> List[MintKeyset]: ... @@ -41,8 +47,8 @@ async def get_spent_proofs( async def get_proof_used( self, *, + Y: str, db: Database, - secret: str, conn: Optional[Connection] = None, ) -> Optional[Proof]: ... @@ -59,6 +65,7 @@ async def invalidate_proof( async def get_proofs_pending( self, *, + proofs: List[Proof], db: Database, conn: Optional[Connection] = None, ) -> List[Proof]: ... @@ -223,7 +230,7 @@ async def store_promise( e, s, id, - int(time.time()), + timestamp_now(db), ), ) @@ -265,28 +272,34 @@ async def invalidate_proof( await (conn or db).execute( f""" INSERT INTO {table_with_schema(db, 'proofs_used')} - (amount, C, secret, id, witness, created) - VALUES (?, ?, ?, ?, ?, ?) + (amount, C, secret, Y, id, witness, created) + VALUES (?, ?, ?, ?, ?, ?, ?) """, ( proof.amount, proof.C, proof.secret, + proof.Y, proof.id, proof.witness, - int(time.time()), + timestamp_now(db), ), ) async def get_proofs_pending( self, *, + proofs: List[Proof], db: Database, conn: Optional[Connection] = None, ) -> List[Proof]: - rows = await (conn or db).fetchall(f""" + rows = await (conn or db).fetchall( + f""" SELECT * from {table_with_schema(db, 'proofs_pending')} - """) + WHERE Y IN ({','.join(['?']*len(proofs))}) + """, + tuple(proof.Y for proof in proofs), + ) return [Proof(**r) for r in rows] async def set_proof_pending( @@ -300,14 +313,15 @@ async def set_proof_pending( await (conn or db).execute( f""" INSERT INTO {table_with_schema(db, 'proofs_pending')} - (amount, C, secret, created) - VALUES (?, ?, ?, ?) + (amount, C, secret, Y, created) + VALUES (?, ?, ?, ?, ?) 
""", ( proof.amount, - str(proof.C), - str(proof.secret), - int(time.time()), + proof.C, + proof.secret, + proof.Y, + timestamp_now(db), ), ) @@ -348,8 +362,8 @@ async def store_mint_quote( quote.amount, quote.issued, quote.paid, - quote.created_time, - quote.paid_time, + timestamp_from_seconds(db, quote.created_time), + timestamp_from_seconds(db, quote.paid_time), ), ) @@ -367,7 +381,7 @@ async def get_mint_quote( """, (quote_id,), ) - return MintQuote(**dict(row)) if row else None + return MintQuote.from_row(row) if row else None async def get_mint_quote_by_checking_id( self, @@ -383,7 +397,7 @@ async def get_mint_quote_by_checking_id( """, (checking_id,), ) - return MintQuote(**dict(row)) if row else None + return MintQuote.from_row(row) if row else None async def update_mint_quote( self, @@ -398,7 +412,7 @@ async def update_mint_quote( ( quote.issued, quote.paid, - quote.paid_time, + timestamp_from_seconds(db, quote.paid_time), quote.quote, ), ) @@ -442,8 +456,8 @@ async def store_melt_quote( quote.amount, quote.fee_reserve or 0, quote.paid, - quote.created_time, - quote.paid_time, + timestamp_from_seconds(db, quote.created_time), + timestamp_from_seconds(db, quote.paid_time), quote.fee_paid, quote.proof, ), @@ -481,7 +495,7 @@ async def get_melt_quote( ) if row is None: return None - return MeltQuote(**dict(row)) if row else None + return MeltQuote.from_row(row) if row else None async def update_melt_quote( self, @@ -496,7 +510,7 @@ async def update_melt_quote( ( quote.paid, quote.fee_paid, - quote.paid_time, + timestamp_from_seconds(db, quote.paid_time), quote.proof, quote.quote, ), @@ -512,16 +526,18 @@ async def store_keyset( await (conn or db).execute( # type: ignore f""" INSERT INTO {table_with_schema(db, 'keysets')} - (id, seed, derivation_path, valid_from, valid_to, first_seen, active, version, unit) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + (id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( keyset.id, keyset.seed, + keyset.encrypted_seed, + keyset.seed_encryption_method, keyset.derivation_path, - keyset.valid_from or int(time.time()), - keyset.valid_to or int(time.time()), - keyset.first_seen or int(time.time()), + keyset.valid_from or timestamp_now(db), + keyset.valid_to or timestamp_now(db), + keyset.first_seen or timestamp_now(db), True, keyset.version, keyset.unit.name, @@ -545,6 +561,7 @@ async def get_keyset( db: Database, id: Optional[str] = None, derivation_path: Optional[str] = None, + seed: Optional[str] = None, unit: Optional[str] = None, active: Optional[bool] = None, conn: Optional[Connection] = None, @@ -560,6 +577,9 @@ async def get_keyset( if derivation_path is not None: clauses.append("derivation_path = ?") values.append(derivation_path) + if seed is not None: + clauses.append("seed = ?") + values.append(seed) if unit is not None: clauses.append("unit = ?") values.append(unit) @@ -578,15 +598,16 @@ async def get_keyset( async def get_proof_used( self, + *, + Y: str, db: Database, - secret: str, conn: Optional[Connection] = None, ) -> Optional[Proof]: row = await (conn or db).fetchone( f""" SELECT * from {table_with_schema(db, 'proofs_used')} - WHERE secret = ? + WHERE Y = ? 
""", - (secret,), + (Y,), ) return Proof(**row) if row else None diff --git a/cashu/mint/decrypt.py b/cashu/mint/decrypt.py new file mode 100644 index 00000000..71247bf6 --- /dev/null +++ b/cashu/mint/decrypt.py @@ -0,0 +1,151 @@ +import click + +try: + from ..core.crypto.aes import AESCipher +except ImportError: + # for the CLI to work + from cashu.core.crypto.aes import AESCipher +import asyncio +from functools import wraps + +from cashu.core.db import Database, table_with_schema +from cashu.core.migrations import migrate_databases +from cashu.core.settings import settings +from cashu.mint import migrations +from cashu.mint.crud import LedgerCrudSqlite +from cashu.mint.ledger import Ledger + + +# https://github.com/pallets/click/issues/85#issuecomment-503464628 +def coro(f): + @wraps(f) + def wrapper(*args, **kwargs): + return asyncio.run(f(*args, **kwargs)) + + return wrapper + + +@click.group() +def cli(): + """Ledger Decrypt CLI""" + pass + + +@cli.command() +@click.option("--message", prompt=True, help="The message to encrypt.") +@click.option( + "--key", + prompt=True, + hide_input=True, + confirmation_prompt=True, + help="The encryption key.", +) +def encrypt(message, key): + """Encrypt a message.""" + aes = AESCipher(key) + encrypted_message = aes.encrypt(message.encode()) + click.echo(f"Encrypted message: {encrypted_message}") + + +@cli.command() +@click.option("--encrypted", prompt=True, help="The encrypted message to decrypt.") +@click.option( + "--key", + prompt=True, + hide_input=True, + help="The decryption key.", +) +def decrypt(encrypted, key): + """Decrypt a message.""" + aes = AESCipher(key) + decrypted_message = aes.decrypt(encrypted) + click.echo(f"Decrypted message: {decrypted_message}") + + +# command to migrate the database to encrypted seeds +@cli.command() +@coro +@click.option("--no-dry-run", is_flag=True, help="Dry run.", default=False) +async def migrate(no_dry_run): + """Migrate the database to encrypted seeds.""" + ledger = Ledger( + db=Database("mint", settings.mint_database), + seed=settings.mint_private_key, + seed_decryption_key=settings.mint_seed_decryption_key, + derivation_path=settings.mint_derivation_path, + backends={}, + crud=LedgerCrudSqlite(), + ) + assert settings.mint_seed_decryption_key, "MINT_SEED_DECRYPTION_KEY not set." + assert ( + len(settings.mint_seed_decryption_key) > 12 + ), "MINT_SEED_DECRYPTION_KEY is too short, must be at least 12 characters." 
+ click.echo( + "Decryption key:" + f" {settings.mint_seed_decryption_key[0]}{'*'*10}{settings.mint_seed_decryption_key[-1]}" + ) + + aes = AESCipher(settings.mint_seed_decryption_key) + + click.echo("Making sure that db is migrated to latest version first.") + await migrate_databases(ledger.db, migrations) + + # get all keysets + async with ledger.db.connect() as conn: + rows = await conn.fetchall( + f"SELECT * FROM {table_with_schema(ledger.db, 'keysets')} WHERE seed IS NOT" + " NULL" + ) + click.echo(f"Found {len(rows)} keysets in database.") + keysets_all = [dict(**row) for row in rows] + keysets_migrate = [] + # encrypt the seeds + for keyset_dict in keysets_all: + if keyset_dict["seed"] and not keyset_dict["encrypted_seed"]: + keyset_dict["encrypted_seed"] = aes.encrypt(keyset_dict["seed"].encode()) + keyset_dict["seed_encryption_method"] = "aes" + keysets_migrate.append(keyset_dict) + else: + click.echo(f"Skipping keyset {keyset_dict['id']}: already migrated.") + + click.echo(f"There are {len(keysets_migrate)} keysets to migrate.") + + for keyset_dict in keysets_migrate: + click.echo(f"Keyset {keyset_dict['id']}") + click.echo(f" Encrypted seed: {keyset_dict['encrypted_seed']}") + click.echo(f" Encryption method: {keyset_dict['seed_encryption_method']}") + decryption_success_str = ( + "✅" + if aes.decrypt(keyset_dict["encrypted_seed"]) == keyset_dict["seed"] + else "❌" + ) + click.echo(f" Seed decryption test: {decryption_success_str}") + + if not no_dry_run: + click.echo( + "This was a dry run. Use --no-dry-run to apply the changes to the database." + ) + if no_dry_run and keysets_migrate: + click.confirm( + "Are you sure you want to continue? Before you continue, make sure to have" + " a backup of your keysets database table.", + abort=True, + ) + click.echo("Updating keysets in the database.") + async with ledger.db.connect() as conn: + for keyset_dict in keysets_migrate: + click.echo(f"Updating keyset {keyset_dict['id']}") + await conn.execute( + f"UPDATE {table_with_schema(ledger.db, 'keysets')} SET seed=''," + " encrypted_seed = ?, seed_encryption_method = ? 
WHERE id = ?", + ( + keyset_dict["encrypted_seed"], + keyset_dict["seed_encryption_method"], + keyset_dict["id"], + ), + ) + click.echo("✅ Migration complete.") + + +if __name__ == "__main__": + cli() diff --git a/cashu/mint/ledger.py b/cashu/mint/ledger.py index 78360984..23502830 100644 --- a/cashu/mint/ledger.py +++ b/cashu/mint/ledger.py @@ -25,6 +25,7 @@ Unit, ) from ..core.crypto import b_dhke +from ..core.crypto.aes import AESCipher from ..core.crypto.keys import ( derive_keyset_id, derive_keyset_id_deprecated, @@ -38,6 +39,7 @@ KeysetNotFoundError, LightningError, NotAllowedError, + QuoteNotPaidError, TransactionError, ) from ..core.helpers import sum_proofs @@ -67,10 +69,18 @@ def __init__( db: Database, seed: str, backends: Mapping[Method, Mapping[Unit, LightningBackend]], + seed_decryption_key: Optional[str] = None, derivation_path="", crud=LedgerCrudSqlite(), ): - self.master_key = seed + assert seed, "seed not set" + + # decrypt seed if seed_decryption_key is set + self.master_key = ( + AESCipher(seed_decryption_key).decrypt(seed) + if seed_decryption_key + else seed + ) self.derivation_path = derivation_path self.db = db @@ -81,7 +91,14 @@ def __init__( # ------- KEYS ------- - async def activate_keyset(self, derivation_path, autosave=True) -> MintKeyset: + async def activate_keyset( + self, + *, + derivation_path: str, + seed: Optional[str] = None, + version: Optional[str] = None, + autosave=True, + ) -> MintKeyset: """Load the keyset for a derivation path if it already exists. If not generate new one and store in the db. Args: @@ -91,29 +108,33 @@ async def activate_keyset(self, derivation_path, autosave=True) -> MintKeyset: Returns: MintKeyset: Keyset """ - logger.debug(f"Activating keyset for derivation path {derivation_path}") + assert derivation_path, "derivation path not set" + seed = seed or self.master_key + tmp_keyset_local = MintKeyset( + seed=seed, + derivation_path=derivation_path, + version=version or settings.version, + ) + logger.debug( + f"Activating keyset for derivation path {derivation_path} with id" + f" {tmp_keyset_local.id}." 
+ ) # load the keyset from db logger.trace(f"crud: loading keyset for {derivation_path}") - tmp_keyset_local: List[MintKeyset] = await self.crud.get_keyset( - derivation_path=derivation_path, db=self.db + tmp_keysets_local: List[MintKeyset] = await self.crud.get_keyset( + id=tmp_keyset_local.id, db=self.db ) - logger.trace(f"crud: loaded {len(tmp_keyset_local)} keysets") - if tmp_keyset_local: + logger.trace(f"crud: loaded {len(tmp_keysets_local)} keysets") + if tmp_keysets_local: # we have a keyset with this derivation path in the database - keyset = tmp_keyset_local[0] - # we keys are not stored in the database but only their derivation path - # so we might need to generate the keys for keysets loaded from the database - if not len(keyset.private_keys): - keyset.generate_keys() - + keyset = tmp_keysets_local[0] else: - logger.trace(f"crud: no keyset for {derivation_path}") # no keyset for this derivation path yet # we create a new keyset (keys will be generated at instantiation) keyset = MintKeyset( - seed=self.master_key, + seed=seed or self.master_key, derivation_path=derivation_path, - version=settings.version, + version=version or settings.version, ) logger.debug(f"Generated new keyset {keyset.id}.") if autosave: @@ -135,42 +156,35 @@ async def activate_keyset(self, derivation_path, autosave=True) -> MintKeyset: logger.debug(f"Loaded keyset {keyset.id}") return keyset - async def init_keysets(self, autosave=True) -> None: + async def init_keysets( + self, autosave: bool = True, duplicate_keysets: Optional[bool] = None + ) -> None: """Initializes all keysets of the mint from the db. Loads all past keysets from db - and generate their keys. Then load the current keyset. + and generate their keys. Then activate the current keyset set by self.derivation_path. Args: autosave (bool, optional): Whether the current keyset should be saved if it is - not in the database yet. Will be passed to `self.activate_keyset` where it is - generated from `self.derivation_path`. Defaults to True. + not in the database yet. Will be passed to `self.activate_keyset` where it is + generated from `self.derivation_path`. Defaults to True. + duplicate_keysets (bool, optional): Whether to duplicate new keysets and compute + their old keyset id, and duplicate old keysets and compute their new keyset id. + Defaults to False. """ - # load all past keysets from db + # load all past keysets from db, the keys will be generated at instantiation tmp_keysets: List[MintKeyset] = await self.crud.get_keyset(db=self.db) - logger.debug( - f"Loaded {len(tmp_keysets)} keysets from database. Generating keys..." 
- ) - # add keysets from db to current keysets + + # add keysets from db to memory for k in tmp_keysets: self.keysets[k.id] = k - # generate keys for all keysets in the database - for _, v in self.keysets.items(): - # if we already generated the keys for this keyset, skip - if v.id and v.public_keys and len(v.public_keys): - continue - logger.trace(f"Generating keys for keyset {v.id}") - v.seed = self.master_key - v.generate_keys() - - logger.info(f"Initialized {len(self.keysets)} keysets from the database.") + logger.info(f"Loaded {len(self.keysets)} keysets from database.") # activate the current keyset set by self.derivation_path - self.keyset = await self.activate_keyset(self.derivation_path, autosave) - logger.info( - "Activated keysets from database:" - f" {[f'{k} ({v.unit.name})' for k, v in self.keysets.items()]}" - ) - logger.info(f"Current keyset: {self.keyset.id}") + if self.derivation_path: + self.keyset = await self.activate_keyset( + derivation_path=self.derivation_path, autosave=autosave + ) + logger.info(f"Current keyset: {self.keyset.id}") # check that we have a least one active keyset assert any([k.active for k in self.keysets.values()]), "No active keyset found." @@ -178,18 +192,22 @@ async def init_keysets(self, autosave=True) -> None: # BEGIN BACKWARDS COMPATIBILITY < 0.15.0 # we duplicate new keysets and compute their old keyset id, and # we duplicate old keysets and compute their new keyset id - for _, keyset in copy.copy(self.keysets).items(): - keyset_copy = copy.copy(keyset) - assert keyset_copy.public_keys - if keyset.version_tuple >= (0, 15): - keyset_copy.id = derive_keyset_id_deprecated(keyset_copy.public_keys) - else: - keyset_copy.id = derive_keyset_id(keyset_copy.public_keys) - keyset_copy.duplicate_keyset_id = keyset.id - self.keysets[keyset_copy.id] = keyset_copy - # remember which keyset this keyset was duplicated from - logger.debug(f"Duplicated keyset id {keyset.id} -> {keyset_copy.id}") - + if ( + duplicate_keysets is None and settings.mint_duplicate_keysets + ) or duplicate_keysets: + for _, keyset in copy.copy(self.keysets).items(): + keyset_copy = copy.copy(keyset) + assert keyset_copy.public_keys + if keyset.version_tuple >= (0, 15): + keyset_copy.id = derive_keyset_id_deprecated( + keyset_copy.public_keys + ) + else: + keyset_copy.id = derive_keyset_id(keyset_copy.public_keys) + keyset_copy.duplicate_keyset_id = keyset.id + self.keysets[keyset_copy.id] = keyset_copy + # remember which keyset this keyset was duplicated from + logger.debug(f"Duplicated keyset id {keyset.id} -> {keyset_copy.id}") # END BACKWARDS COMPATIBILITY < 0.15.0 def get_keyset(self, keyset_id: Optional[str] = None) -> Dict[int, str]: @@ -215,7 +233,7 @@ async def _invalidate_proofs( proofs (List[Proof]): Proofs to add to known secret table. conn: (Optional[Connection], optional): Database connection to reuse. Will create a new one if not given. Defaults to None. """ - self.spent_proofs.update({p.secret: p for p in proofs}) + self.spent_proofs.update({p.Y: p for p in proofs}) async with get_db_connection(self.db, conn) as conn: # store in db for p in proofs: @@ -295,6 +313,7 @@ async def mint_quote(self, quote_request: PostMintQuoteRequest) -> MintQuote: MintQuote: Mint quote object. """ logger.trace("called request_mint") + assert quote_request.amount > 0, "amount must be positive" if settings.mint_max_peg_in and quote_request.amount > settings.mint_max_peg_in: raise NotAllowedError( f"Maximum mint amount is {settings.mint_max_peg_in} sat." 
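# A small sketch of the Y-index now used for spent and pending proofs, assuming
# it runs inside the nutshell repo (cashu.core.base and cashu.core.crypto.b_dhke
# from this changeset are importable); the secret value is made up.
from cashu.core.base import Proof
from cashu.core.crypto.b_dhke import hash_to_curve

proof = Proof(
    amount=8,
    secret="407915bc212be61a77e3e6d2aeb4c727980bda51cd06a6afc29e2861768a7837",
)
# Proof.__init__ fills in Y = hash_to_curve(secret) when it is not given, so the
# proofs_used / proofs_pending tables can be keyed and queried by Y instead of secret.
assert proof.Y == hash_to_curve(proof.secret.encode("utf-8")).serialize().hex()
print(proof.Y)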
@@ -334,7 +353,7 @@ async def mint_quote(self, quote_request: PostMintQuoteRequest) -> MintQuote: issued=False, paid=False, created_time=int(time.time()), - expiry=invoice_obj.expiry or 0, + expiry=invoice_obj.expiry, ) await self.crud.store_mint_quote( quote=quote, @@ -368,6 +387,7 @@ async def get_mint_quote(self, quote_id: str) -> MintQuote: if status.paid: logger.trace(f"Setting quote {quote_id} as paid") quote.paid = True + quote.paid_time = int(time.time()) await self.crud.update_mint_quote(quote=quote, db=self.db) return quote @@ -404,7 +424,7 @@ async def mint( ) # create a new lock if it doesn't exist async with self.locks[quote_id]: quote = await self.get_mint_quote(quote_id=quote_id) - assert quote.paid, "quote not paid" + assert quote.paid, QuoteNotPaidError() assert not quote.issued, "quote already issued" assert ( quote.amount == sum_amount_outputs @@ -487,6 +507,7 @@ async def melt_quote( paid=False, fee_reserve=payment_quote.fee.to(unit).amount, created_time=int(time.time()), + expiry=invoice_obj.expiry, ) await self.crud.store_melt_quote(quote=quote, db=self.db) return PostMeltQuoteResponse( @@ -494,6 +515,7 @@ async def melt_quote( amount=quote.amount, fee_reserve=quote.fee_reserve, paid=quote.paid, + expiry=quote.expiry, ) async def get_melt_quote(self, quote_id: str) -> MeltQuote: @@ -593,6 +615,7 @@ async def melt_mint_settle_internally(self, melt_quote: MeltQuote) -> MeltQuote: await self.crud.update_melt_quote(quote=melt_quote, db=self.db) mint_quote.paid = True + mint_quote.paid_time = melt_quote.paid_time await self.crud.update_mint_quote(quote=mint_quote, db=self.db) return melt_quote @@ -625,7 +648,7 @@ async def melt( # make sure that the outputs (for fee return) are in the same unit as the quote if outputs: - await self._verify_outputs(outputs) + await self._verify_outputs(outputs, skip_amount_check=True) assert outputs[0].id, "output id not set" outputs_unit = self.keysets[outputs[0].id].unit assert melt_quote.unit == outputs_unit.name, ( @@ -852,7 +875,7 @@ async def load_used_proofs(self) -> None: logger.debug("Loading used proofs into memory") spent_proofs_list = await self.crud.get_spent_proofs(db=self.db) or [] logger.debug(f"Loaded {len(spent_proofs_list)} used proofs") - self.spent_proofs = {p.secret: p for p in spent_proofs_list} + self.spent_proofs = {p.Y: p for p in spent_proofs_list} async def check_proofs_state(self, secrets: List[str]) -> List[ProofState]: """Checks if provided proofs are spend or are pending. 
@@ -870,19 +893,25 @@ async def check_proofs_state(self, secrets: List[str]) -> List[ProofState]: List[bool]: List of which proof are pending (True if pending, else False) """ states: List[ProofState] = [] - proofs_spent = await self._get_proofs_spent(secrets) - proofs_pending = await self._get_proofs_pending(secrets) + proofs_spent_idx_secret = await self._get_proofs_spent_idx_secret(secrets) + proofs_pending_idx_secret = await self._get_proofs_pending_idx_secret(secrets) for secret in secrets: - if secret not in proofs_spent and secret not in proofs_pending: + if ( + secret not in proofs_spent_idx_secret + and secret not in proofs_pending_idx_secret + ): states.append(ProofState(secret=secret, state=SpentState.unspent)) - elif secret not in proofs_spent and secret in proofs_pending: + elif ( + secret not in proofs_spent_idx_secret + and secret in proofs_pending_idx_secret + ): states.append(ProofState(secret=secret, state=SpentState.pending)) else: states.append( ProofState( secret=secret, state=SpentState.spent, - witness=proofs_spent[secret].witness, + witness=proofs_spent_idx_secret[secret].witness, ) ) return states @@ -901,13 +930,13 @@ async def _set_proofs_pending(self, proofs: List[Proof]) -> None: async with self.proofs_pending_lock: async with self.db.connect() as conn: await self._validate_proofs_pending(proofs, conn) - for p in proofs: - try: + try: + for p in proofs: await self.crud.set_proof_pending( proof=p, db=self.db, conn=conn ) - except Exception: - raise TransactionError("proofs already pending.") + except Exception: + raise TransactionError("Failed to set proofs pending.") async def _unset_proofs_pending(self, proofs: List[Proof]) -> None: """Deletes proofs from pending table. @@ -931,8 +960,9 @@ async def _validate_proofs_pending( Raises: Exception: At least one of the proofs is in the pending table. 
""" - proofs_pending = await self.crud.get_proofs_pending(db=self.db, conn=conn) - for p in proofs: - for pp in proofs_pending: - if p.secret == pp.secret: - raise TransactionError("proofs are pending.") + assert ( + len( + await self.crud.get_proofs_pending(proofs=proofs, db=self.db, conn=conn) + ) + == 0 + ), TransactionError("proofs are pending.") diff --git a/cashu/mint/migrations.py b/cashu/mint/migrations.py index 72160405..90323507 100644 --- a/cashu/mint/migrations.py +++ b/cashu/mint/migrations.py @@ -1,6 +1,5 @@ -import time - -from ..core.db import Connection, Database, table_with_schema +from ..core.base import Proof +from ..core.db import Connection, Database, table_with_schema, timestamp_now from ..core.settings import settings @@ -50,34 +49,46 @@ async def m001_initial(db: Database): """) -async def m002_add_balance_views(db: Database): - async with db.connect() as conn: - await conn.execute(f""" - CREATE VIEW {table_with_schema(db, 'balance_issued')} AS - SELECT COALESCE(SUM(s), 0) AS balance FROM ( - SELECT SUM(amount) AS s - FROM {table_with_schema(db, 'promises')} - WHERE amount > 0 - ) AS balance_issued; - """) +async def drop_balance_views(db: Database, conn: Connection): + await conn.execute(f"DROP VIEW IF EXISTS {table_with_schema(db, 'balance')}") + await conn.execute(f"DROP VIEW IF EXISTS {table_with_schema(db, 'balance_issued')}") + await conn.execute( + f"DROP VIEW IF EXISTS {table_with_schema(db, 'balance_redeemed')}" + ) - await conn.execute(f""" - CREATE VIEW {table_with_schema(db, 'balance_redeemed')} AS - SELECT COALESCE(SUM(s), 0) AS balance FROM ( - SELECT SUM(amount) AS s - FROM {table_with_schema(db, 'proofs_used')} - WHERE amount > 0 - ) AS balance_redeemed; - """) - await conn.execute(f""" - CREATE VIEW {table_with_schema(db, 'balance')} AS - SELECT s_issued - s_used FROM ( - SELECT bi.balance AS s_issued, bu.balance AS s_used - FROM {table_with_schema(db, 'balance_issued')} bi - CROSS JOIN {table_with_schema(db, 'balance_redeemed')} bu - ) AS balance; - """) +async def create_balance_views(db: Database, conn: Connection): + await conn.execute(f""" + CREATE VIEW {table_with_schema(db, 'balance_issued')} AS + SELECT COALESCE(SUM(s), 0) AS balance FROM ( + SELECT SUM(amount) AS s + FROM {table_with_schema(db, 'promises')} + WHERE amount > 0 + ) AS balance_issued; + """) + + await conn.execute(f""" + CREATE VIEW {table_with_schema(db, 'balance_redeemed')} AS + SELECT COALESCE(SUM(s), 0) AS balance FROM ( + SELECT SUM(amount) AS s + FROM {table_with_schema(db, 'proofs_used')} + WHERE amount > 0 + ) AS balance_redeemed; + """) + + await conn.execute(f""" + CREATE VIEW {table_with_schema(db, 'balance')} AS + SELECT s_issued - s_used FROM ( + SELECT bi.balance AS s_issued, bu.balance AS s_used + FROM {table_with_schema(db, 'balance_issued')} bi + CROSS JOIN {table_with_schema(db, 'balance_redeemed')} bu + ) AS balance; + """) + + +async def m002_add_balance_views(db: Database): + async with db.connect() as conn: + await create_balance_views(db, conn) async def m003_mint_keysets(db: Database): @@ -185,15 +196,6 @@ async def m008_promises_dleq(db: Database): async def m009_add_out_to_invoices(db: Database): # column in invoices for marking whether the invoice is incoming (out=False) or outgoing (out=True) async with db.connect() as conn: - # we have to drop the balance views first and recreate them later - await conn.execute(f"DROP VIEW IF EXISTS {table_with_schema(db, 'balance')}") - await conn.execute( - f"DROP VIEW IF EXISTS {table_with_schema(db, 
'balance_issued')}" - ) - await conn.execute( - f"DROP VIEW IF EXISTS {table_with_schema(db, 'balance_redeemed')}" - ) - # rename column pr to bolt11 await conn.execute( f"ALTER TABLE {table_with_schema(db, 'invoices')} RENAME COLUMN pr TO" @@ -204,10 +206,6 @@ async def m009_add_out_to_invoices(db: Database): f"ALTER TABLE {table_with_schema(db, 'invoices')} RENAME COLUMN hash TO id" ) - # recreate balance views - await m002_add_balance_views(db) - - async with db.connect() as conn: await conn.execute( f"ALTER TABLE {table_with_schema(db, 'invoices')} ADD COLUMN out BOOL" ) @@ -234,7 +232,7 @@ async def m011_add_quote_tables(db: Database): ) await conn.execute( f"UPDATE {table_with_schema(db, table)} SET created =" - f" '{int(time.time())}'" + f" '{timestamp_now(db)}'" ) # add column "witness" to table proofs_used @@ -298,9 +296,214 @@ async def m011_add_quote_tables(db: Database): await conn.execute( f"INSERT INTO {table_with_schema(db, 'mint_quotes')} (quote, method," " request, checking_id, unit, amount, paid, issued, created_time," - " paid_time) SELECT id, 'bolt11', bolt11, payment_hash, 'sat', amount," - f" False, issued, created, 0 FROM {table_with_schema(db, 'invoices')} " + " paid_time) SELECT id, 'bolt11', bolt11, COALESCE(payment_hash, 'None')," + f" 'sat', amount, False, issued, COALESCE(created, '{timestamp_now(db)}')," + f" NULL FROM {table_with_schema(db, 'invoices')} " ) # drop table invoices await conn.execute(f"DROP TABLE {table_with_schema(db, 'invoices')}") + + +async def m012_keysets_uniqueness_with_seed(db: Database): + # copy table keysets to keysets_old, create a new table keysets + # with the same columns but with a unique constraint on (seed, derivation_path) + # and copy the data from keysets_old to keysets, then drop keysets_old + async with db.connect() as conn: + await conn.execute( + f"DROP TABLE IF EXISTS {table_with_schema(db, 'keysets_old')}" + ) + await conn.execute( + f"CREATE TABLE {table_with_schema(db, 'keysets_old')} AS" + f" SELECT * FROM {table_with_schema(db, 'keysets')}" + ) + await conn.execute(f"DROP TABLE {table_with_schema(db, 'keysets')}") + await conn.execute(f""" + CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'keysets')} ( + id TEXT NOT NULL, + derivation_path TEXT, + seed TEXT, + valid_from TIMESTAMP, + valid_to TIMESTAMP, + first_seen TIMESTAMP, + active BOOL DEFAULT TRUE, + version TEXT, + unit TEXT, + + UNIQUE (seed, derivation_path) + + ); + """) + await conn.execute( + f"INSERT INTO {table_with_schema(db, 'keysets')} (id," + " derivation_path, valid_from, valid_to, first_seen," + " active, version, seed, unit) SELECT id, derivation_path," + " valid_from, valid_to, first_seen, active, version, seed," + f" unit FROM {table_with_schema(db, 'keysets_old')}" + ) + await conn.execute(f"DROP TABLE {table_with_schema(db, 'keysets_old')}") + + +async def m013_keysets_add_encrypted_seed(db: Database): + async with db.connect() as conn: + # set keysets table unique constraint to id + # copy table keysets to keysets_old, create a new table keysets + # with the same columns but with a unique constraint on id + # and copy the data from keysets_old to keysets, then drop keysets_old + await conn.execute( + f"DROP TABLE IF EXISTS {table_with_schema(db, 'keysets_old')}" + ) + await conn.execute( + f"CREATE TABLE {table_with_schema(db, 'keysets_old')} AS" + f" SELECT * FROM {table_with_schema(db, 'keysets')}" + ) + await conn.execute(f"DROP TABLE {table_with_schema(db, 'keysets')}") + await conn.execute(f""" + CREATE TABLE IF NOT EXISTS 
{table_with_schema(db, 'keysets')} ( + id TEXT NOT NULL, + derivation_path TEXT, + seed TEXT, + valid_from TIMESTAMP, + valid_to TIMESTAMP, + first_seen TIMESTAMP, + active BOOL DEFAULT TRUE, + version TEXT, + unit TEXT, + + UNIQUE (id) + + ); + """) + await conn.execute( + f"INSERT INTO {table_with_schema(db, 'keysets')} (id," + " derivation_path, valid_from, valid_to, first_seen," + " active, version, seed, unit) SELECT id, derivation_path," + " valid_from, valid_to, first_seen, active, version, seed," + f" unit FROM {table_with_schema(db, 'keysets_old')}" + ) + await conn.execute(f"DROP TABLE {table_with_schema(db, 'keysets_old')}") + + # add columns encrypted_seed and seed_encryption_method to keysets + await conn.execute( + f"ALTER TABLE {table_with_schema(db, 'keysets')} ADD COLUMN encrypted_seed" + " TEXT" + ) + await conn.execute( + f"ALTER TABLE {table_with_schema(db, 'keysets')} ADD COLUMN" + " seed_encryption_method TEXT" + ) + + +async def m014_proofs_add_Y_column(db: Database): + # get all proofs_used and proofs_pending from the database and compute Y for each of them + async with db.connect() as conn: + rows = await conn.fetchall( + f"SELECT * FROM {table_with_schema(db, 'proofs_used')}" + ) + # Proof() will compute Y from secret upon initialization + proofs_used = [Proof(**r) for r in rows] + + rows = await conn.fetchall( + f"SELECT * FROM {table_with_schema(db, 'proofs_pending')}" + ) + proofs_pending = [Proof(**r) for r in rows] + async with db.connect() as conn: + # we have to drop the balance views first and recreate them later + await drop_balance_views(db, conn) + + await conn.execute( + f"ALTER TABLE {table_with_schema(db, 'proofs_used')} ADD COLUMN Y TEXT" + ) + for proof in proofs_used: + await conn.execute( + f"UPDATE {table_with_schema(db, 'proofs_used')} SET Y = '{proof.Y}'" + f" WHERE secret = '{proof.secret}'" + ) + # Copy proofs_used to proofs_used_old and create a new table proofs_used + # with the same columns but with a unique constraint on (Y) + # and copy the data from proofs_used_old to proofs_used, then drop proofs_used_old + await conn.execute( + f"DROP TABLE IF EXISTS {table_with_schema(db, 'proofs_used_old')}" + ) + await conn.execute( + f"CREATE TABLE {table_with_schema(db, 'proofs_used_old')} AS" + f" SELECT * FROM {table_with_schema(db, 'proofs_used')}" + ) + await conn.execute(f"DROP TABLE {table_with_schema(db, 'proofs_used')}") + await conn.execute(f""" + CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'proofs_used')} ( + amount INTEGER NOT NULL, + C TEXT NOT NULL, + secret TEXT NOT NULL, + id TEXT, + Y TEXT, + created TIMESTAMP, + witness TEXT, + + UNIQUE (Y) + + ); + """) + await conn.execute( + f"INSERT INTO {table_with_schema(db, 'proofs_used')} (amount, C, " + "secret, id, Y, created, witness) SELECT amount, C, secret, id, Y," + f" created, witness FROM {table_with_schema(db, 'proofs_used_old')}" + ) + await conn.execute(f"DROP TABLE {table_with_schema(db, 'proofs_used_old')}") + + # add column Y to proofs_pending + await conn.execute( + f"ALTER TABLE {table_with_schema(db, 'proofs_pending')} ADD COLUMN Y TEXT" + ) + for proof in proofs_pending: + await conn.execute( + f"UPDATE {table_with_schema(db, 'proofs_pending')} SET Y = '{proof.Y}'" + f" WHERE secret = '{proof.secret}'" + ) + + # Copy proofs_pending to proofs_pending_old and create a new table proofs_pending + # with the same columns but with a unique constraint on (Y) + # and copy the data from proofs_pending_old to proofs_pending, then drop proofs_pending_old + await 
conn.execute( + f"DROP TABLE IF EXISTS {table_with_schema(db, 'proofs_pending_old')}" + ) + + await conn.execute( + f"CREATE TABLE {table_with_schema(db, 'proofs_pending_old')} AS" + f" SELECT * FROM {table_with_schema(db, 'proofs_pending')}" + ) + + await conn.execute(f"DROP TABLE {table_with_schema(db, 'proofs_pending')}") + await conn.execute(f""" + CREATE TABLE IF NOT EXISTS {table_with_schema(db, 'proofs_pending')} ( + amount INTEGER NOT NULL, + C TEXT NOT NULL, + secret TEXT NOT NULL, + Y TEXT, + id TEXT, + created TIMESTAMP, + + UNIQUE (Y) + + ); + """) + await conn.execute( + f"INSERT INTO {table_with_schema(db, 'proofs_pending')} (amount, C, " + "secret, Y, id, created) SELECT amount, C, secret, Y, id, created" + f" FROM {table_with_schema(db, 'proofs_pending_old')}" + ) + + await conn.execute(f"DROP TABLE {table_with_schema(db, 'proofs_pending_old')}") + + # recreate the balance views + await create_balance_views(db, conn) + + +async def m015_add_index_Y_to_proofs_used(db: Database): + # create index on proofs_used table for Y + async with db.connect() as conn: + await conn.execute( + "CREATE INDEX IF NOT EXISTS" + " proofs_used_Y_idx ON" + f" {table_with_schema(db, 'proofs_used')} (Y)" + ) diff --git a/cashu/mint/router.py b/cashu/mint/router.py index 5cf303b6..d0cac90a 100644 --- a/cashu/mint/router.py +++ b/cashu/mint/router.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List -from fastapi import APIRouter, Request +from fastapi import APIRouter from loguru import logger from ..core.base import ( @@ -50,6 +50,7 @@ async def info() -> GetInfoResponse: mint_features: Dict[int, Dict[str, Any]] = { 4: dict( methods=method_unit_pairs, + disabled=False, ), 5: dict( methods=method_unit_pairs, @@ -80,8 +81,7 @@ async def info() -> GetInfoResponse: name="Mint public keys", summary="Get the public keys of the newest mint keyset", response_description=( - "A dictionary of all supported token values of the mint and their associated" - " public key of the current keyset." + "All supported token values their associated public keys for all active keysets" ), response_model=KeysResponse, ) @@ -107,12 +107,12 @@ async def keys(): name="Keyset public keys", summary="Public keys of a specific keyset", response_description=( - "A dictionary of all supported token values of the mint and their associated" + "All supported token values of the mint and their associated" " public key for a specific keyset." ), response_model=KeysResponse, ) -async def keyset_keys(keyset_id: str, request: Request) -> KeysResponse: +async def keyset_keys(keyset_id: str) -> KeysResponse: """ Get the public keys of the mint from a specific keyset id. """ @@ -127,7 +127,7 @@ async def keyset_keys(keyset_id: str, request: Request) -> KeysResponse: keyset = ledger.keysets.get(keyset_id) if keyset is None: - raise CashuError(code=0, detail="Keyset not found.") + raise CashuError(code=0, detail="keyset not found") keyset_for_response = KeysResponseKeyset( id=keyset.id, @@ -172,12 +172,6 @@ async def mint_quote(payload: PostMintQuoteRequest) -> PostMintQuoteResponse: Call `POST /v1/mint/bolt11` after paying the invoice. 
""" logger.trace(f"> POST /v1/mint/quote/bolt11: payload={payload}") - amount = payload.amount - if amount > 21_000_000 * 100_000_000 or amount <= 0: - raise CashuError(code=0, detail="Amount must be a valid amount of sat.") - if settings.mint_peg_out_only: - raise CashuError(code=0, detail="Mint does not allow minting new tokens.") - quote = await ledger.mint_quote(payload) resp = PostMintQuoteResponse( request=quote.request, @@ -190,7 +184,7 @@ async def mint_quote(payload: PostMintQuoteRequest) -> PostMintQuoteResponse: @router.get( - "/v1/mint/quote/{quote}", + "/v1/mint/quote/bolt11/{quote}", summary="Get mint quote", response_model=PostMintQuoteResponse, response_description="Get an existing mint quote to check its status.", @@ -199,7 +193,7 @@ async def get_mint_quote(quote: str) -> PostMintQuoteResponse: """ Get mint quote state. """ - logger.trace(f"> POST /v1/mint/quote/{quote}") + logger.trace(f"> GET /v1/mint/quote/bolt11/{quote}") mint_quote = await ledger.get_mint_quote(quote) resp = PostMintQuoteResponse( quote=mint_quote.quote, @@ -207,14 +201,14 @@ async def get_mint_quote(quote: str) -> PostMintQuoteResponse: paid=mint_quote.paid, expiry=mint_quote.expiry, ) - logger.trace(f"< POST /v1/mint/quote/{quote}") + logger.trace(f"< GET /v1/mint/quote/bolt11/{quote}") return resp @router.post( "/v1/mint/bolt11", - name="Mint tokens", - summary="Mint tokens in exchange for a Bitcoin payment that the user has made", + name="Mint tokens with a Lightning payment", + summary="Mint tokens by paying a bolt11 Lightning invoice.", response_model=PostMintResponse, response_description=( "A list of blinded signatures that can be used to create proofs." @@ -253,7 +247,7 @@ async def get_melt_quote(payload: PostMeltQuoteRequest) -> PostMeltQuoteResponse @router.get( - "/v1/melt/quote/{quote}", + "/v1/melt/quote/bolt11/{quote}", summary="Get melt quote", response_model=PostMeltQuoteResponse, response_description="Get an existing melt quote to check its status.", @@ -262,15 +256,16 @@ async def melt_quote(quote: str) -> PostMeltQuoteResponse: """ Get melt quote state. """ - logger.trace(f"> POST /v1/melt/quote/{quote}") + logger.trace(f"> GET /v1/melt/quote/bolt11/{quote}") melt_quote = await ledger.get_melt_quote(quote) resp = PostMeltQuoteResponse( quote=melt_quote.quote, amount=melt_quote.amount, fee_reserve=melt_quote.fee_reserve, paid=melt_quote.paid, + expiry=melt_quote.expiry, ) - logger.trace(f"< POST /v1/melt/quote/{quote}") + logger.trace(f"< GET /v1/melt/quote/bolt11/{quote}") return resp @@ -311,7 +306,7 @@ async def melt(payload: PostMeltRequest) -> PostMeltResponse: "An array of blinded signatures that can be used to create proofs." 
), ) -async def split( +async def swap( payload: PostSplitRequest, ) -> PostSplitResponse: """ diff --git a/cashu/mint/router_deprecated.py b/cashu/mint/router_deprecated.py index b1d4be64..a2ac71b7 100644 --- a/cashu/mint/router_deprecated.py +++ b/cashu/mint/router_deprecated.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Dict, List, Optional from fastapi import APIRouter from loguru import logger @@ -70,7 +70,7 @@ async def info() -> GetInfoResponse_deprecated: response_model=KeysResponse_deprecated, deprecated=True, ) -async def keys_deprecated(): +async def keys_deprecated() -> Dict[str, str]: """This endpoint returns a dictionary of all supported token values of the mint and their associated public key.""" logger.trace("> GET /keys") keyset = ledger.get_keyset() @@ -86,10 +86,10 @@ async def keys_deprecated(): "A dictionary of all supported token values of the mint and their associated" " public key for a specific keyset." ), - response_model=KeysResponse_deprecated, + response_model=Dict[str, str], deprecated=True, ) -async def keyset_deprecated(idBase64Urlsafe: str): +async def keyset_deprecated(idBase64Urlsafe: str) -> Dict[str, str]: """ Get the public keys of the mint from a specific keyset id. The id is encoded in idBase64Urlsafe (by a wallet) and is converted back to @@ -323,7 +323,7 @@ async def split_deprecated( ), deprecated=True, ) -async def check_spendable( +async def check_spendable_deprecated( payload: CheckSpendableRequest_deprecated, ) -> CheckSpendableResponse_deprecated: """Check whether a secret has been spent already or not.""" diff --git a/cashu/mint/startup.py b/cashu/mint/startup.py index acc78559..f094f97b 100644 --- a/cashu/mint/startup.py +++ b/cashu/mint/startup.py @@ -16,6 +16,18 @@ logger.debug("Enviroment Settings:") for key, value in settings.dict().items(): + if key in [ + "mint_private_key", + "mint_seed_decryption_key", + "nostr_private_key", + "mint_lnbits_key", + "mint_strike_key", + "mint_lnd_rest_macaroon", + "mint_lnd_rest_admin_macaroon", + "mint_lnd_rest_invoice_macaroon", + "mint_corelightning_rest_macaroon", + ]: + value = "********" if value is not None else None logger.debug(f"{key}: {value}") wallets_module = importlib.import_module("cashu.lightning") @@ -39,6 +51,7 @@ ledger = Ledger( db=Database("mint", settings.mint_database), seed=settings.mint_private_key, + seed_decryption_key=settings.mint_seed_decryption_key, derivation_path=settings.mint_derivation_path, backends=backends, crud=LedgerCrudSqlite(), @@ -56,7 +69,7 @@ async def rotate_keys(n_seconds=60): incremented_derivation_path = ( "/".join(ledger.derivation_path.split("/")[:-1]) + f"/{i}" ) - await ledger.activate_keyset(incremented_derivation_path) + await ledger.activate_keyset(derivation_path=incremented_derivation_path) logger.info(f"Current keyset: {ledger.keyset.id}") await asyncio.sleep(n_seconds) @@ -68,7 +81,7 @@ async def start_mint_init(): await ledger.init_keysets() for derivation_path in settings.mint_derivation_path_list: - await ledger.activate_keyset(derivation_path) + await ledger.activate_keyset(derivation_path=derivation_path) for method in ledger.backends: for unit in ledger.backends[method]: diff --git a/cashu/mint/verification.py b/cashu/mint/verification.py index 06eb8b72..e3da1d54 100644 --- a/cashu/mint/verification.py +++ b/cashu/mint/verification.py @@ -51,8 +51,10 @@ async def verify_inputs_and_outputs( """ # Verify inputs # Verify proofs are spendable - spent_proofs = await self._get_proofs_spent([p.secret for p in 
proofs]) - if not len(spent_proofs) == 0: + if ( + not len(await self._get_proofs_spent_idx_secret([p.secret for p in proofs])) + == 0 + ): raise TokenAlreadySpentError() # Verify amounts of inputs if not all([self._verify_amount(p.amount) for p in proofs]): @@ -96,7 +98,9 @@ async def verify_inputs_and_outputs( if outputs and not self._verify_output_spending_conditions(proofs, outputs): raise TransactionError("validation of output spending conditions failed.") - async def _verify_outputs(self, outputs: List[BlindedMessage]): + async def _verify_outputs( + self, outputs: List[BlindedMessage], skip_amount_check=False + ): """Verify that the outputs are valid.""" logger.trace(f"Verifying {len(outputs)} outputs.") # Verify all outputs have the same keyset id @@ -108,8 +112,10 @@ async def _verify_outputs(self, outputs: List[BlindedMessage]): if not self.keysets[outputs[0].id].active: raise TransactionError("keyset id inactive.") # Verify amounts of outputs - if not all([self._verify_amount(o.amount) for o in outputs]): - raise TransactionError("invalid amount.") + # we skip the amount check for NUT-8 change outputs (which can have amount 0) + if not skip_amount_check: + if not all([self._verify_amount(o.amount) for o in outputs]): + raise TransactionError("invalid amount.") # verify that only unique outputs were used if not self._verify_no_duplicate_outputs(outputs): raise TransactionError("duplicate outputs.") @@ -137,27 +143,35 @@ async def _check_outputs_issued_before(self, outputs: List[BlindedMessage]): result.append(False if promise is None else True) return result - async def _get_proofs_pending(self, secrets: List[str]) -> Dict[str, Proof]: + async def _get_proofs_pending_idx_secret( + self, secrets: List[str] + ) -> Dict[str, Proof]: """Returns only those proofs that are pending.""" - all_proofs_pending = await self.crud.get_proofs_pending(db=self.db) + all_proofs_pending = await self.crud.get_proofs_pending( + proofs=[Proof(secret=s) for s in secrets], db=self.db + ) proofs_pending = list(filter(lambda p: p.secret in secrets, all_proofs_pending)) proofs_pending_dict = {p.secret: p for p in proofs_pending} return proofs_pending_dict - async def _get_proofs_spent(self, secrets: List[str]) -> Dict[str, Proof]: + async def _get_proofs_spent_idx_secret( + self, secrets: List[str] + ) -> Dict[str, Proof]: """Returns all proofs that are spent.""" + proofs = [Proof(secret=s) for s in secrets] proofs_spent: List[Proof] = [] if settings.mint_cache_secrets: # check used secrets in memory - for secret in secrets: - if secret in self.spent_proofs: - proofs_spent.append(self.spent_proofs[secret]) + for proof in proofs: + spent_proof = self.spent_proofs.get(proof.Y) + if spent_proof: + proofs_spent.append(spent_proof) else: # check used secrets in database async with self.db.connect() as conn: - for secret in secrets: + for proof in proofs: spent_proof = await self.crud.get_proof_used( - db=self.db, secret=secret, conn=conn + db=self.db, Y=proof.Y, conn=conn ) if spent_proof: proofs_spent.append(spent_proof) diff --git a/cashu/wallet/cli/cli.py b/cashu/wallet/cli/cli.py index 2d1d177c..b2010b86 100644 --- a/cashu/wallet/cli/cli.py +++ b/cashu/wallet/cli/cli.py @@ -14,10 +14,9 @@ from click import Context from loguru import logger -from cashu.core.logging import configure_logger - from ...core.base import TokenV3, Unit from ...core.helpers import sum_proofs +from ...core.logging import configure_logger from ...core.settings import settings from ...nostr.client.client import NostrClient from 
...tor.tor import TorProxy @@ -577,7 +576,12 @@ async def burn(ctx: Context, token: str, all: bool, force: bool, delete: str): if delete: await wallet.invalidate(proofs) else: - await wallet.invalidate(proofs, check_spendable=True) + # invalidate proofs in batches + for _proofs in [ + proofs[i : i + settings.proofs_batch_size] + for i in range(0, len(proofs), settings.proofs_batch_size) + ]: + await wallet.invalidate(_proofs, check_spendable=True) print_balance(ctx) diff --git a/cashu/wallet/nostr.py b/cashu/wallet/nostr.py index b6099cc9..56dfe73f 100644 --- a/cashu/wallet/nostr.py +++ b/cashu/wallet/nostr.py @@ -6,8 +6,7 @@ from httpx import ConnectError from loguru import logger -from cashu.core.base import TokenV3 - +from ..core.base import TokenV3 from ..core.settings import settings from ..nostr.client.client import NostrClient from ..nostr.event import Event diff --git a/cashu/wallet/p2pk.py b/cashu/wallet/p2pk.py index a2d824c7..aff057c0 100644 --- a/cashu/wallet/p2pk.py +++ b/cashu/wallet/p2pk.py @@ -79,7 +79,7 @@ async def sign_p2pk_proofs(self, proofs: List[Proof]) -> List[str]: sign_p2pk_sign( message=proof.secret.encode("utf-8"), private_key=private_key, - ) + ).hex() for proof in proofs ] logger.debug(f"Signatures: {signatures}") @@ -93,9 +93,9 @@ async def sign_p2pk_outputs(self, outputs: List[BlindedMessage]) -> List[str]: assert private_key.pubkey return [ sign_p2pk_sign( - message=output.B_.encode("utf-8"), + message=bytes.fromhex(output.B_), private_key=private_key, - ) + ).hex() for output in outputs ] diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index 8ace2478..aef3dc13 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -157,7 +157,12 @@ def raise_on_error_request( Raises: Exception: if the response contains an error """ - resp_dict = resp.json() + try: + resp_dict = resp.json() + except json.JSONDecodeError: + # if we can't decode the response, raise for status + resp.raise_for_status() + return if "detail" in resp_dict: logger.trace(f"Error from mint: {resp_dict}") error_message = f"Mint Error: {resp_dict['detail']}" @@ -543,6 +548,7 @@ async def melt_quote(self, payment_request: str) -> PostMeltQuoteResponse: amount=invoice_obj.amount_msat // 1000, fee_reserve=ret.fee or 0, paid=False, + expiry=invoice_obj.expiry, ) # END backwards compatibility < 0.15.0 self.raise_on_error_request(resp) @@ -1120,7 +1126,14 @@ async def _construct_proofs( C = b_dhke.step3_alice( C_, r, self.keysets[promise.id].public_keys[promise.amount] ) - B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs + # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1 + if not settings.wallet_domain_separation: + B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs + # END: BACKWARDS COMPATIBILITY < 0.15.1 + else: + B_, r = b_dhke.step1_alice_domain_separated( + secret, r + ) # recompute B_ for dleq proofs proof = Proof( id=promise.id, @@ -1183,7 +1196,12 @@ def _construct_outputs( rs_ = [None] * len(amounts) if not rs else rs rs_return: List[PrivateKey] = [] for secret, amount, r in zip(secrets, amounts, rs_): - B_, r = b_dhke.step1_alice(secret, r or None) + # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1 + if not settings.wallet_domain_separation: + B_, r = b_dhke.step1_alice(secret, r or None) + # END: BACKWARDS COMPATIBILITY < 0.15.1 + else: + B_, r = b_dhke.step1_alice_domain_separated(secret, r or None) rs_return.append(r) output = BlindedMessage( amount=amount, B_=B_.serialize().hex(), id=self.keyset_id diff --git a/setup.py b/setup.py index 
d9d8e3d8..52574dc7 100644 --- a/setup.py +++ b/setup.py @@ -14,12 +14,12 @@ setuptools.setup( name="cashu", version="0.15.0", - description="Ecash wallet and mint for Bitcoin Lightning", + description="Ecash wallet and mint", long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/cashubtc/cashu", author="Calle", - author_email="calle@protonmail.com", + author_email="callebtc@protonmail.com", license="MIT", packages=setuptools.find_namespace_packages(), classifiers=[ diff --git a/tests/conftest.py b/tests/conftest.py index 82aa4ab0..f6887cb2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import asyncio +import importlib import multiprocessing import os import shutil @@ -34,17 +35,21 @@ settings.fakewallet_brr = True settings.fakewallet_delay_payment = False settings.fakewallet_stochastic_invoice = False -settings.mint_database = "./test_data/test_mint" +assert ( + settings.mint_test_database != settings.mint_database +), "Test database is the same as the main database" +settings.mint_database = settings.mint_test_database settings.mint_derivation_path = "m/0'/0'/0'" settings.mint_derivation_path_list = [] settings.mint_private_key = "TEST_PRIVATE_KEY" +settings.mint_seed_decryption_key = "" settings.mint_max_balance = 0 assert "test" in settings.cashu_dir shutil.rmtree(settings.cashu_dir, ignore_errors=True) Path(settings.cashu_dir).mkdir(parents=True, exist_ok=True) -from cashu.mint.startup import lightning_backend # noqa +# from cashu.mint.startup import lightning_backend # noqa @pytest.fixture(scope="session") @@ -98,7 +103,15 @@ async def start_mint_init(ledger: Ledger): db_file = os.path.join(settings.mint_database, "mint.sqlite3") if os.path.exists(db_file): os.remove(db_file) - + else: + # clear postgres database + db = Database("mint", settings.mint_database) + async with db.connect() as conn: + await conn.execute("DROP SCHEMA public CASCADE;") + await conn.execute("CREATE SCHEMA public;") + + wallets_module = importlib.import_module("cashu.lightning") + lightning_backend = getattr(wallets_module, settings.mint_lightning_backend)() backends = { Method.bolt11: {Unit.sat: lightning_backend}, } diff --git a/tests/helpers.py b/tests/helpers.py index 0cd36a9f..94fe729b 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -30,6 +30,8 @@ async def get_random_invoice_data(): is_fake: bool = WALLET.__class__.__name__ == "FakeWallet" is_regtest: bool = not is_fake is_deprecated_api_only = settings.debug_mint_only_deprecated +is_github_actions = os.getenv("GITHUB_ACTIONS") == "true" +is_postgres = settings.mint_database.startswith("postgres") docker_lightning_cli = [ "docker", diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 00000000..3b2af374 --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,69 @@ +import datetime +import os +import time + +import pytest + +from cashu.core import db +from cashu.core.db import Connection, timestamp_now +from cashu.core.migrations import backup_database +from cashu.core.settings import settings +from cashu.mint.ledger import Ledger +from tests.helpers import is_github_actions, is_postgres + + +@pytest.mark.asyncio +@pytest.mark.skipif( + is_github_actions and is_postgres, + reason=( + "Fails on GitHub Actions because pg_dump is not the same version as postgres" + ), +) +async def test_backup_db_migration(ledger: Ledger): + settings.db_backup_path = "./test_data/backups/" + filepath = await backup_database(ledger.db, 999) + assert os.path.exists(filepath) + + 
+@pytest.mark.asyncio +async def test_timestamp_now(ledger: Ledger): + ts = timestamp_now(ledger.db) + if ledger.db.type == db.SQLITE: + assert isinstance(ts, str) + assert int(ts) <= time.time() + elif ledger.db.type in {db.POSTGRES, db.COCKROACH}: + assert isinstance(ts, str) + datetime.datetime.strptime(ts, "%Y-%m-%d %H:%M:%S") + + +@pytest.mark.asyncio +async def test_get_connection(ledger: Ledger): + async with ledger.db.connect() as conn: + assert isinstance(conn, Connection) + + +@pytest.mark.asyncio +async def test_db_tables(ledger: Ledger): + async with ledger.db.connect() as conn: + if ledger.db.type == db.SQLITE: + tables_res = await conn.execute( + "SELECT name FROM sqlite_master WHERE type='table';" + ) + elif ledger.db.type in {db.POSTGRES, db.COCKROACH}: + tables_res = await conn.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema =" + " 'public';" + ) + tables = [t[0] for t in await tables_res.fetchall()] + tables_expected = [ + "dbversions", + "keysets", + "proofs_used", + "proofs_pending", + "melt_quotes", + "mint_quotes", + "mint_pubkeys", + "promises", + ] + for table in tables_expected: + assert table in tables diff --git a/tests/test_mint_api.py b/tests/test_mint_api.py index 14de13b2..22d8e30c 100644 --- a/tests/test_mint_api.py +++ b/tests/test_mint_api.py @@ -182,6 +182,15 @@ async def test_mint_quote(ledger: Ledger): assert result["request"] invoice = bolt11.decode(result["request"]) assert invoice.amount_msat == 100 * 1000 + assert result["expiry"] == invoice.expiry + + # get mint quote again from api + response = httpx.get( + f"{BASE_URL}/v1/mint/quote/bolt11/{result['quote']}", + ) + assert response.status_code == 200, f"{response.url} {response.status_code}" + result2 = response.json() + assert result2["quote"] == result["quote"] @pytest.mark.asyncio @@ -235,6 +244,16 @@ async def test_melt_quote_internal(ledger: Ledger, wallet: Wallet): assert result["amount"] == 64 # TODO: internal invoice, fee should be 0 assert result["fee_reserve"] == 0 + invoice_obj = bolt11.decode(request) + assert result["expiry"] == invoice_obj.expiry + + # get melt quote again from api + response = httpx.get( + f"{BASE_URL}/v1/melt/quote/bolt11/{result['quote']}", + ) + assert response.status_code == 200, f"{response.url} {response.status_code}" + result2 = response.json() + assert result2["quote"] == result["quote"] @pytest.mark.asyncio diff --git a/tests/test_mint_api_deprecated.py b/tests/test_mint_api_deprecated.py index 12676ebc..fcdcf8cd 100644 --- a/tests/test_mint_api_deprecated.py +++ b/tests/test_mint_api_deprecated.py @@ -186,6 +186,44 @@ async def test_melt_internal(ledger: Ledger, wallet: Wallet): assert result["paid"] is True +@pytest.mark.asyncio +async def test_melt_internal_no_change_outputs(ledger: Ledger, wallet: Wallet): + # Clients without NUT-08 will not send change outputs + # internal invoice + invoice = await wallet.request_mint(64) + pay_if_regtest(invoice.bolt11) + await wallet.mint(64, id=invoice.id) + assert wallet.balance == 64 + + # create invoice to melt to + invoice = await wallet.request_mint(64) + + invoice_payment_request = invoice.bolt11 + + quote = await wallet.melt_quote(invoice_payment_request) + assert quote.amount == 64 + assert quote.fee_reserve == 0 + + inputs_payload = [p.to_dict() for p in wallet.proofs] + + # outputs for change + secrets, rs, derivation_paths = await wallet.generate_n_secrets(1) + outputs, rs = wallet._construct_outputs([2], secrets, rs) + + response = httpx.post( + f"{BASE_URL}/melt", + 
json={ + "pr": invoice_payment_request, + "proofs": inputs_payload, + }, + timeout=None, + ) + assert response.status_code == 200, f"{response.url} {response.status_code}" + result = response.json() + assert result.get("preimage") is not None + assert result["paid"] is True + + @pytest.mark.asyncio @pytest.mark.skipif( is_fake, diff --git a/tests/test_mint_db.py b/tests/test_mint_db.py new file mode 100644 index 00000000..1d006a2f --- /dev/null +++ b/tests/test_mint_db.py @@ -0,0 +1,71 @@ +import pytest +import pytest_asyncio + +from cashu.core.base import PostMeltQuoteRequest +from cashu.mint.ledger import Ledger +from cashu.wallet.wallet import Wallet +from cashu.wallet.wallet import Wallet as Wallet1 +from tests.conftest import SERVER_ENDPOINT +from tests.helpers import is_postgres + + +async def assert_err(f, msg): + """Compute f() and expect an error message 'msg'.""" + try: + await f + except Exception as exc: + if msg not in str(exc.args[0]): + raise Exception(f"Expected error: {msg}, got: {exc.args[0]}") + return + raise Exception(f"Expected error: {msg}, got no error") + + +@pytest_asyncio.fixture(scope="function") +async def wallet1(ledger: Ledger): + wallet1 = await Wallet1.with_db( + url=SERVER_ENDPOINT, + db="test_data/wallet1", + name="wallet1", + ) + await wallet1.load_mint() + yield wallet1 + + +@pytest.mark.asyncio +async def test_mint_quote(wallet1: Wallet, ledger: Ledger): + invoice = await wallet1.request_mint(128) + assert invoice is not None + quote = await ledger.crud.get_mint_quote(quote_id=invoice.id, db=ledger.db) + assert quote is not None + assert quote.quote == invoice.id + assert quote.amount == 128 + assert quote.unit == "sat" + assert not quote.paid + assert quote.checking_id == invoice.payment_hash + assert quote.paid_time is None + assert quote.created_time + + +@pytest.mark.asyncio +async def test_melt_quote(wallet1: Wallet, ledger: Ledger): + invoice = await wallet1.request_mint(128) + assert invoice is not None + melt_quote = await ledger.melt_quote( + PostMeltQuoteRequest(request=invoice.bolt11, unit="sat") + ) + quote = await ledger.crud.get_melt_quote(quote_id=melt_quote.quote, db=ledger.db) + assert quote is not None + assert quote.quote == melt_quote.quote + assert quote.amount == 128 + assert quote.unit == "sat" + assert not quote.paid + assert quote.checking_id == invoice.payment_hash + assert quote.paid_time is None + assert quote.created_time + + +@pytest.mark.asyncio +@pytest.mark.skipif(not is_postgres, reason="only works with Postgres") +async def test_postgres_working(): + assert is_postgres + assert True diff --git a/tests/test_mint_init.py b/tests/test_mint_init.py new file mode 100644 index 00000000..77f111b8 --- /dev/null +++ b/tests/test_mint_init.py @@ -0,0 +1,128 @@ +from typing import List + +import pytest + +from cashu.core.base import Proof +from cashu.core.crypto.aes import AESCipher +from cashu.core.db import Database +from cashu.core.settings import settings +from cashu.mint.crud import LedgerCrudSqlite +from cashu.mint.ledger import Ledger + +SEED = "TEST_PRIVATE_KEY" +DERIVATION_PATH = "m/0'/0'/0'" +DECRYPTON_KEY = "testdecryptionkey" +ENCRYPTED_SEED = "U2FsdGVkX1_7UU_-nVBMBWDy_9yDu4KeYb7MH8cJTYQGD4RWl82PALH8j-HKzTrI" + + +async def assert_err(f, msg): + """Compute f() and expect an error message 'msg'.""" + try: + await f + except Exception as exc: + assert exc.args[0] == msg, Exception( + f"Expected error: {msg}, got: {exc.args[0]}" + ) + + +def assert_amt(proofs: List[Proof], expected: int): + """Assert amounts the proofs 
contain.""" + assert [p.amount for p in proofs] == expected + + +@pytest.mark.asyncio +async def test_init_keysets_with_duplicates(ledger: Ledger): + ledger.keysets = {} + await ledger.init_keysets(duplicate_keysets=True) + assert len(ledger.keysets) == 2 + + +@pytest.mark.asyncio +async def test_init_keysets_with_duplicates_via_settings(ledger: Ledger): + ledger.keysets = {} + settings.mint_duplicate_keysets = True + await ledger.init_keysets() + assert len(ledger.keysets) == 2 + + +@pytest.mark.asyncio +async def test_init_keysets_without_duplicates(ledger: Ledger): + ledger.keysets = {} + await ledger.init_keysets(duplicate_keysets=False) + assert len(ledger.keysets) == 1 + + +@pytest.mark.asyncio +async def test_init_keysets_without_duplicates_via_settings(ledger: Ledger): + ledger.keysets = {} + settings.mint_duplicate_keysets = False + await ledger.init_keysets() + assert len(ledger.keysets) == 1 + + +@pytest.mark.asyncio +async def test_ledger_encrypt(): + aes = AESCipher(DECRYPTON_KEY) + encrypted = aes.encrypt(SEED.encode()) + assert aes.decrypt(encrypted) == SEED + + +@pytest.mark.asyncio +async def test_ledger_decrypt(): + aes = AESCipher(DECRYPTON_KEY) + assert aes.decrypt(ENCRYPTED_SEED) == SEED + + +@pytest.mark.asyncio +async def test_decrypt_seed(): + ledger = Ledger( + db=Database("mint", settings.mint_database), + seed=SEED, + seed_decryption_key=None, + derivation_path=DERIVATION_PATH, + backends={}, + crud=LedgerCrudSqlite(), + ) + await ledger.init_keysets() + assert ledger.keyset.seed == SEED + private_key_1 = ( + ledger.keysets[list(ledger.keysets.keys())[0]].private_keys[1].serialize() + ) + assert ( + private_key_1 + == "8300050453f08e6ead1296bb864e905bd46761beed22b81110fae0751d84604d" + ) + pubkeys = ledger.keysets[list(ledger.keysets.keys())[0]].public_keys + assert pubkeys + assert ( + pubkeys[1].serialize().hex() + == "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" + ) + + ledger_encrypted = Ledger( + db=Database("mint", settings.mint_database), + seed=ENCRYPTED_SEED, + seed_decryption_key=DECRYPTON_KEY, + derivation_path=DERIVATION_PATH, + backends={}, + crud=LedgerCrudSqlite(), + ) + await ledger_encrypted.init_keysets() + assert ledger_encrypted.keyset.seed == SEED + private_key_1 = ( + ledger_encrypted.keysets[list(ledger_encrypted.keysets.keys())[0]] + .private_keys[1] + .serialize() + ) + assert ( + private_key_1 + == "8300050453f08e6ead1296bb864e905bd46761beed22b81110fae0751d84604d" + ) + pubkeys_encrypted = ledger_encrypted.keysets[ + list(ledger_encrypted.keysets.keys())[0] + ].public_keys + assert pubkeys_encrypted + assert ( + pubkeys_encrypted[1].serialize().hex() + == "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" + ) diff --git a/tests/test_mint_keysets.py b/tests/test_mint_keysets.py new file mode 100644 index 00000000..6085ea76 --- /dev/null +++ b/tests/test_mint_keysets.py @@ -0,0 +1,73 @@ +import pytest + +from cashu.core.base import MintKeyset +from cashu.core.settings import settings +from tests.test_mint_init import DECRYPTON_KEY, DERIVATION_PATH, ENCRYPTED_SEED, SEED + + +async def assert_err(f, msg): + """Compute f() and expect an error message 'msg'.""" + try: + await f + except Exception as exc: + if msg not in str(exc.args[0]): + raise Exception(f"Expected error: {msg}, got: {exc.args[0]}") + return + raise Exception(f"Expected error: {msg}, got no error") + + +@pytest.mark.asyncio +async def test_keyset_0_15_0(): + keyset = MintKeyset(seed=SEED, derivation_path=DERIVATION_PATH, 
version="0.15.0") + assert len(keyset.public_keys_hex) == settings.max_order + assert keyset.seed == "TEST_PRIVATE_KEY" + assert keyset.derivation_path == "m/0'/0'/0'" + assert ( + keyset.public_keys_hex[1] + == "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" + ) + assert keyset.id == "009a1f293253e41e" + + +@pytest.mark.asyncio +async def test_keyset_0_14_0(): + keyset = MintKeyset(seed=SEED, derivation_path=DERIVATION_PATH, version="0.14.0") + assert len(keyset.public_keys_hex) == settings.max_order + assert keyset.seed == "TEST_PRIVATE_KEY" + assert keyset.derivation_path == "m/0'/0'/0'" + assert ( + keyset.public_keys_hex[1] + == "036d6f3adf897e88e16ece3bffb2ce57a0b635fa76f2e46dbe7c636a937cd3c2f2" + ) + assert keyset.id == "xnI+Y0j7cT1/" + + +@pytest.mark.asyncio +async def test_keyset_0_11_0(): + keyset = MintKeyset(seed=SEED, derivation_path=DERIVATION_PATH, version="0.11.0") + assert len(keyset.public_keys_hex) == settings.max_order + assert keyset.seed == "TEST_PRIVATE_KEY" + assert keyset.derivation_path == "m/0'/0'/0'" + assert ( + keyset.public_keys_hex[1] + == "026b714529f157d4c3de5a93e3a67618475711889b6434a497ae6ad8ace6682120" + ) + assert keyset.id == "Zkdws9zWxNc4" + + +@pytest.mark.asyncio +async def test_keyset_0_15_0_encrypted(): + settings.mint_seed_decryption_key = DECRYPTON_KEY + keyset = MintKeyset( + encrypted_seed=ENCRYPTED_SEED, + derivation_path=DERIVATION_PATH, + version="0.15.0", + ) + assert len(keyset.public_keys_hex) == settings.max_order + assert keyset.seed == "TEST_PRIVATE_KEY" + assert keyset.derivation_path == "m/0'/0'/0'" + assert ( + keyset.public_keys_hex[1] + == "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" + ) + assert keyset.id == "009a1f293253e41e" diff --git a/tests/test_wallet.py b/tests/test_wallet.py index 15c253dc..78195f6d 100644 --- a/tests/test_wallet.py +++ b/tests/test_wallet.py @@ -13,7 +13,13 @@ from cashu.wallet.wallet import Wallet as Wallet1 from cashu.wallet.wallet import Wallet as Wallet2 from tests.conftest import SERVER_ENDPOINT -from tests.helpers import get_real_invoice, is_fake, is_regtest, pay_if_regtest +from tests.helpers import ( + get_real_invoice, + is_fake, + is_github_actions, + is_regtest, + pay_if_regtest, +) async def assert_err(f, msg: Union[str, CashuError]): @@ -349,12 +355,30 @@ async def test_duplicate_proofs_double_spent(wallet1: Wallet): doublespend = await wallet1.mint(64, id=invoice.id) await assert_err( wallet1.split(wallet1.proofs + doublespend, 20), - "Mint Error: proofs already pending.", + "Mint Error: Failed to set proofs pending.", ) assert wallet1.balance == 64 assert wallet1.available_balance == 64 +@pytest.mark.asyncio +@pytest.mark.skipif(is_github_actions, reason="GITHUB_ACTIONS") +async def test_split_race_condition(wallet1: Wallet): + invoice = await wallet1.request_mint(64) + pay_if_regtest(invoice.bolt11) + await wallet1.mint(64, id=invoice.id) + # run two splits in parallel + import asyncio + + await assert_err( + asyncio.gather( + wallet1.split(wallet1.proofs, 20), + wallet1.split(wallet1.proofs, 20), + ), + "proofs are pending.", + ) + + @pytest.mark.asyncio async def test_send_and_redeem(wallet1: Wallet, wallet2: Wallet): invoice = await wallet1.request_mint(64) diff --git a/tests/test_cli.py b/tests/test_wallet_cli.py similarity index 100% rename from tests/test_cli.py rename to tests/test_wallet_cli.py
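
For readers following the keyset and proof table migrations in cashu/mint/migrations.py above (m012 through m014), the recurring pattern is: snapshot the table to <table>_old, recreate the table with the new UNIQUE constraint, copy the rows back, then drop the snapshot. Below is a minimal, self-contained sketch of that pattern using only the standard-library sqlite3 module; the table contents and the simplified two-column schema are invented for illustration and are not the mint's actual schema or data.

import sqlite3

con = sqlite3.connect(":memory:")
cur = con.cursor()
cur.execute("CREATE TABLE keysets (id TEXT NOT NULL, derivation_path TEXT)")
cur.execute("INSERT INTO keysets VALUES (?, ?)", ("009a1f", "m/0'/0'/0'"))
# 1) snapshot the existing rows
cur.execute("CREATE TABLE keysets_old AS SELECT * FROM keysets")
# 2) drop and recreate the table with the new UNIQUE constraint
cur.execute("DROP TABLE keysets")
cur.execute("CREATE TABLE keysets (id TEXT NOT NULL, derivation_path TEXT, UNIQUE (id))")
# 3) copy the rows back and discard the snapshot
cur.execute("INSERT INTO keysets SELECT id, derivation_path FROM keysets_old")
cur.execute("DROP TABLE keysets_old")
assert cur.execute("SELECT id, derivation_path FROM keysets").fetchall() == [("009a1f", "m/0'/0'/0'")]
con.close()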
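
The verification changes above (_get_proofs_spent_idx_secret in cashu/mint/verification.py) switch the spent-proof cache lookup from the proof's secret to its Y value. The following rough sketch only shows the shape of that lookup: Y is derived from the secret (in the real code it is a hash-to-curve point, per the migration comment "Proof() will compute Y from secret upon initialization"), but here it is faked with SHA-256 purely for illustration, and DemoProof and spent_subset are invented names, not part of the codebase.

import hashlib
from dataclasses import dataclass
from typing import Dict, List


def fake_Y(secret: str) -> str:
    # stand-in for the real hash-to-curve derivation; illustration only
    return hashlib.sha256(secret.encode()).hexdigest()


@dataclass
class DemoProof:
    secret: str
    Y: str = ""

    def __post_init__(self) -> None:
        if not self.Y:
            self.Y = fake_Y(self.secret)


def spent_subset(proofs: List[DemoProof], spent_by_Y: Dict[str, DemoProof]) -> List[DemoProof]:
    # look each proof up by its Y key instead of scanning by secret
    return [p for p in proofs if p.Y in spent_by_Y]


already_spent = DemoProof("a")
cache = {already_spent.Y: already_spent}
assert [p.secret for p in spent_subset([DemoProof("a"), DemoProof("b")], cache)] == ["a"]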
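
The burn command in cashu/wallet/cli/cli.py above now invalidates proofs in batches of settings.proofs_batch_size rather than all at once. The slicing idiom it relies on is sketched below with invented names (chunked, a batch size of 100); this illustrates the idiom only and is not code from the wallet. Batching keeps each check-spendable request sent to the mint bounded in size when a wallet holds many proofs.

from typing import Iterator, List, TypeVar

T = TypeVar("T")


def chunked(items: List[T], batch_size: int) -> Iterator[List[T]]:
    # yield successive slices of at most batch_size elements
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]


sizes = [len(batch) for batch in chunked(list(range(250)), 100)]
assert sizes == [100, 100, 50]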