From da0cfc146545c1ec220901094c98e32690f255af Mon Sep 17 00:00:00 2001 From: Giacomo Caironi <30932677+giacomocaironi@users.noreply.github.com> Date: Sun, 3 Dec 2023 16:17:35 +0100 Subject: [PATCH] Check transaction amounts (#121) --- btclib/script/engine/__init__.py | 9 ++++++- tests/script_engine/test_transactions.py | 30 ++++++++++++++++++++---- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/btclib/script/engine/__init__.py b/btclib/script/engine/__init__.py index a15a298d..532470d3 100644 --- a/btclib/script/engine/__init__.py +++ b/btclib/script/engine/__init__.py @@ -180,12 +180,19 @@ def verify_input(prevouts: list[TxOut], tx: Tx, i: int, flags: list[str]) -> Non raise BTClibValueError() +def verify_amounts(prevouts: list[TxOut], tx: Tx) -> None: + if sum(x.value for x in tx.vout) > sum(x.value for x in prevouts): + raise BTClibValueError("Invalid transaction amounts") + + def verify_transaction( - prevouts: list[TxOut], tx: Tx, flags: list | None = None + prevouts: list[TxOut], tx: Tx, flags: list | None = None, check_amounts=True ) -> None: if flags is None: flags = ALL_FLAGS[:] if len(prevouts) != len(tx.vin): raise BTClibValueError() + if check_amounts: + verify_amounts(prevouts, tx) for i in range(len(prevouts)): verify_input(prevouts, tx, i, flags) diff --git a/tests/script_engine/test_transactions.py b/tests/script_engine/test_transactions.py index 90596dad..edb010fc 100644 --- a/tests/script_engine/test_transactions.py +++ b/tests/script_engine/test_transactions.py @@ -16,9 +16,9 @@ import pytest from btclib.exceptions import BTClibValueError -from btclib.script.engine import verify_input, verify_transaction +from btclib.script.engine import verify_amounts, verify_input, verify_transaction from btclib.script.witness import Witness -from btclib.tx.tx import Tx +from btclib.tx import OutPoint, Tx, TxIn from btclib.tx.tx_out import ScriptPubKey, TxOut from tests.script_engine import parse_script @@ -107,13 +107,19 @@ def test_valid_legacy() -> None: if f in flags: flags.remove(f) + check_amounts = True + prevouts = [] for i in x[0]: amount = 0 if len(i) == 3 else i[3] + if not amount: + check_amounts = False script_pub_key = parse_script(i[2]) prevouts.append(TxOut(amount, ScriptPubKey(script_pub_key))) - verify_transaction(prevouts, tx, flags if flags != ["NONE"] else None) + verify_transaction( + prevouts, tx, flags if flags != ["NONE"] else None, check_amounts + ) def test_invalid_legacy() -> None: @@ -136,13 +142,29 @@ def test_invalid_legacy() -> None: flags = x[2].split(",") # different flags handling + check_amounts = True + prevouts = [] for i in x[0]: amount = 0 if len(i) == 3 else i[3] + if not amount: + check_amounts = False with warnings.catch_warnings(): warnings.simplefilter("ignore") script_pub_key = parse_script(i[2]) prevouts.append(TxOut(amount, ScriptPubKey(script_pub_key))) with pytest.raises((BTClibValueError, IndexError, KeyError)): - verify_transaction(prevouts, tx, flags if flags != ["NONE"] else None) + verify_transaction( + prevouts, tx, flags if flags != ["NONE"] else None, check_amounts + ) + + +def test_invalid_amount() -> None: + prevout = TxOut(0, ScriptPubKey("")) + + tx = Tx(vin=[TxIn(OutPoint(b"1" * 32, 1))], vout=[TxOut(10, ScriptPubKey(""))]) + + # Output amount greater than sum of inputs + with pytest.raises(BTClibValueError): + verify_amounts([prevout], tx)