Skip to content

Commit

Permalink
check inputs sighash flags
Browse files Browse the repository at this point in the history
  • Loading branch information
Ouziel committed Dec 5, 2024
1 parent bff6f8a commit daba261
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 2 deletions.
67 changes: 66 additions & 1 deletion counterparty-core/counterpartycore/lib/gettxinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ def get_vin_info(vin):
if "value" in vin:
return vin["value"], vin["script_pub_key"], vin["is_segwit"]

# Note: We don't know what block the `vin` is in, and the block might have been from a while ago, so this call may not hit the cache.
# Note: We don't know what block the `vin` is in, and the block might
# have been from a while ago, so this call may not hit the cache.
vin_ctx = backend.bitcoind.get_decoded_transaction(vin["hash"])

is_segwit = len(vin_ctx["vtxinwit"]) > 0
Expand All @@ -155,6 +156,64 @@ def get_vin_info(vin):
return vout["value"], vout["script_pub_key"], is_segwit


def is_der_signature_and_not_sighash_all(value):
if isinstance(value, str):
value = binascii.unhexlify(value)
if not isinstance(value, bytes):
return False
if not (
value.startswith(binascii.unhexlify("3044"))
or value.startswith(binascii.unhexlify("3045"))
and (len(value) == 71 or len(value) == 72)
):
return False
if not value.endswith(b"\x01"): # 01 is SIGHASH_ALL
return True
return False


def is_schnorr_signature_and_not_sighash_all(value):
if isinstance(value, str):
value = binascii.unhexlify(value)
if not isinstance(value, bytes):
return False
# sighash flag is optionnal for schnorr signature
if len(value) not in [64, 65]:
return False
# all flags except 0x01 or no flag are invalid
if len(value) == 65 and not value.endswith(b"\x01"):
return True
return False


# We use the following heuristic to check the SIGHASH flag:
# - We look for all items that have the characteristics of a DER encoded signature
# (starting with 3044 or 3045 and of length 70 or 71) in the witnesses and scripts of all inputs.
# - If one of these items ends with something other than '01' (SIGHASH_ALL) the transaction is invalid
# - If the witnesses contain an odd number of elements we assume that one of them
# is a schnorr signature for a P2TR input: if one of the items is 65 in length and
# ends with something other than '01' (SIGHASH_ALL) the transaction is invalid


def check_witnesses_sighash(decoded_tx):
if not decoded_tx["segwit"]:
return

for item in decoded_tx["vtxinwit"]:
if is_der_signature_and_not_sighash_all(item):
raise DecodeError("invalid SIGHASH flag")
# if there is an odd number of items, we assume than one of the item is a schnorr signature
# for a P2TR key path spend
if len(decoded_tx["vtxinwit"]) % 2 == 1 and is_schnorr_signature_and_not_sighash_all(item):
raise DecodeError("invalid SIGHASH flag")


def check_script_sighash(asm):
for item in asm:
if is_der_signature_and_not_sighash_all(item):
raise DecodeError("invalid SIGHASH flag")


def get_transaction_sources(decoded_tx):
sources = []
outputs_value = 0
Expand All @@ -166,6 +225,8 @@ def get_transaction_sources(decoded_tx):

asm = script.script_to_asm(script_pubkey)

check_script_sighash(asm)

if asm[-1] == OP_CHECKSIG: # noqa: F405
new_source, new_data = decode_checksig(asm, decoded_tx)
if new_data or not new_source:
Expand Down Expand Up @@ -391,6 +452,10 @@ def get_tx_info_new(db, decoded_tx, block_index, p2sh_is_segwit=False, composing
else:
raise BTCOnlyError("no data and not unspendable")

# check for invalid SIGHASH flags in witness data
# each in input is also checked in get_transaction_sources()
check_witnesses_sighash(decoded_tx)

# Collect all (unique) source addresses.
# if we haven't found them yet
if p2sh_encoding_source is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def run_scenarios(serve=False, wsgi_server="gunicorn"):
print(regtest_node_thread.node.server_out.getvalue())
raise e
finally:
# print(regtest_node_thread.node.server_out.getvalue())
print(regtest_node_thread.node.server_out.getvalue())
regtest_node_thread.stop()


Expand Down

0 comments on commit daba261

Please sign in to comment.