diff --git a/counterparty-core/counterpartycore/lib/gettxinfo.py b/counterparty-core/counterpartycore/lib/gettxinfo.py index 2b9299b588..7f579baeeb 100644 --- a/counterparty-core/counterpartycore/lib/gettxinfo.py +++ b/counterparty-core/counterpartycore/lib/gettxinfo.py @@ -146,7 +146,8 @@ def get_vin_info(vin): if "value" in vin: return vin["value"], vin["script_pub_key"], vin["is_segwit"] - # Note: We don't know what block the `vin` is in, and the block might have been from a while ago, so this call may not hit the cache. + # Note: We don't know what block the `vin` is in, and the block might + # have been from a while ago, so this call may not hit the cache. vin_ctx = backend.bitcoind.get_decoded_transaction(vin["hash"]) is_segwit = len(vin_ctx["vtxinwit"]) > 0 @@ -155,6 +156,64 @@ def get_vin_info(vin): return vout["value"], vout["script_pub_key"], is_segwit +def is_der_signature_and_not_sighash_all(value): + if isinstance(value, str): + value = binascii.unhexlify(value) + if not isinstance(value, bytes): + return False + if not ( + value.startswith(binascii.unhexlify("3044")) + or value.startswith(binascii.unhexlify("3045")) + and (len(value) == 71 or len(value) == 72) + ): + return False + if not value.endswith(b"\x01"): # 01 is SIGHASH_ALL + return True + return False + + +def is_schnorr_signature_and_not_sighash_all(value): + if isinstance(value, str): + value = binascii.unhexlify(value) + if not isinstance(value, bytes): + return False + # sighash flag is optionnal for schnorr signature + if len(value) not in [64, 65]: + return False + # all flags except 0x01 or no flag are invalid + if len(value) == 65 and not value.endswith(b"\x01"): + return True + return False + + +# We use the following heuristic to check the SIGHASH flag: +# - We look for all items that have the characteristics of a DER encoded signature +# (starting with 3044 or 3045 and of length 70 or 71) in the witnesses and scripts of all inputs. +# - If one of these items ends with something other than '01' (SIGHASH_ALL) the transaction is invalid +# - If the witnesses contain an odd number of elements we assume that one of them +# is a schnorr signature for a P2TR input: if one of the items is 65 in length and +# ends with something other than '01' (SIGHASH_ALL) the transaction is invalid + + +def check_witnesses_sighash(decoded_tx): + if not decoded_tx["segwit"]: + return + + for item in decoded_tx["vtxinwit"]: + if is_der_signature_and_not_sighash_all(item): + raise DecodeError("invalid SIGHASH flag") + # if there is an odd number of items, we assume than one of the item is a schnorr signature + # for a P2TR key path spend + if len(decoded_tx["vtxinwit"]) % 2 == 1 and is_schnorr_signature_and_not_sighash_all(item): + raise DecodeError("invalid SIGHASH flag") + + +def check_script_sighash(asm): + for item in asm: + if is_der_signature_and_not_sighash_all(item): + raise DecodeError("invalid SIGHASH flag") + + def get_transaction_sources(decoded_tx): sources = [] outputs_value = 0 @@ -166,6 +225,8 @@ def get_transaction_sources(decoded_tx): asm = script.script_to_asm(script_pubkey) + check_script_sighash(asm) + if asm[-1] == OP_CHECKSIG: # noqa: F405 new_source, new_data = decode_checksig(asm, decoded_tx) if new_data or not new_source: @@ -391,6 +452,10 @@ def get_tx_info_new(db, decoded_tx, block_index, p2sh_is_segwit=False, composing else: raise BTCOnlyError("no data and not unspendable") + # check for invalid SIGHASH flags in witness data + # each in input is also checked in get_transaction_sources() + check_witnesses_sighash(decoded_tx) + # Collect all (unique) source addresses. # if we haven't found them yet if p2sh_encoding_source is None: diff --git a/counterparty-core/counterpartycore/test/regtest/testscenarios.py b/counterparty-core/counterpartycore/test/regtest/testscenarios.py index ebb1a1a51a..f852718427 100644 --- a/counterparty-core/counterpartycore/test/regtest/testscenarios.py +++ b/counterparty-core/counterpartycore/test/regtest/testscenarios.py @@ -438,7 +438,7 @@ def run_scenarios(serve=False, wsgi_server="gunicorn"): print(regtest_node_thread.node.server_out.getvalue()) raise e finally: - # print(regtest_node_thread.node.server_out.getvalue()) + print(regtest_node_thread.node.server_out.getvalue()) regtest_node_thread.stop()