Skip to content

Commit

Permalink
Merge pull request #2807 from CounterpartyXCP/sighash
Browse files Browse the repository at this point in the history
Check Sighash Flag
  • Loading branch information
ouziel-slama authored Dec 9, 2024
2 parents 1411de7 + fb5a8a5 commit ece96b5
Show file tree
Hide file tree
Showing 24 changed files with 5,062 additions and 4,384 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ruff-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ jobs:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
with:
version: 0.7.4
version: 0.8.2
2 changes: 1 addition & 1 deletion .github/workflows/ruff-format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ jobs:
- uses: chartboost/ruff-action@v1
with:
args: "format --check"
version: 0.7.4
version: 0.8.2
4,634 changes: 2,324 additions & 2,310 deletions apiary.apib

Large diffs are not rendered by default.

45 changes: 23 additions & 22 deletions counterparty-core/counterpartycore/lib/api/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def check_database_version():
database.update_version(state_db)


def run_api_server(args, server_ready_value, stop_event):
def run_api_server(args, server_ready_value, stop_event, parent_pid):
logger.info("Starting API Server process...")

def handle_interrupt_signal(signum, frame):
Expand Down Expand Up @@ -513,7 +513,7 @@ def handle_interrupt_signal(signum, frame):
wsgi_server = wsgi.WSGIApplication(app, args=args)

logger.info("Starting Parent Process Checker thread...")
parent_checker = ParentProcessChecker(wsgi_server)
parent_checker = ParentProcessChecker(wsgi_server, stop_event, parent_pid)
parent_checker.start()

app.app_context().push()
Expand All @@ -539,52 +539,53 @@ def handle_interrupt_signal(signum, frame):
watcher.stop()
watcher.join()

if parent_checker is not None:
logger.trace("Stopping Parent Process Checker thread...")
parent_checker.stop()
parent_checker.join()

logger.info("API Server stopped.")


def is_process_alive(pid):
"""Check For the existence of a unix pid."""
try:
os.kill(pid, 0)
except OSError:
return False
else:
return True


# This thread is used for the following two reasons:
# 1. `docker-compose stop` does not send a SIGTERM to the child processes (in this case the API v2 process)
# 2. `process.terminate()` does not trigger a `KeyboardInterrupt` or execute the `finally` block.
class ParentProcessChecker(threading.Thread):
def __init__(self, wsgi_server):
def __init__(self, wsgi_server, stop_event, parent_pid):
super().__init__(name="ParentProcessChecker")
self.daemon = True
self.wsgi_server = wsgi_server
self.stop_event = threading.Event()
self.stop_event = stop_event
self.parent_pid = parent_pid

def run(self):
parent_pid = os.getppid()
try:
while not self.stop_event.is_set():
if os.getppid() != parent_pid:
logger.debug("Parent process is dead. Exiting...")
if self.wsgi_server is not None:
self.wsgi_server.stop()
break
while not self.stop_event.is_set() and is_process_alive(self.parent_pid):
time.sleep(1)
logger.debug("Parent process stopped. Exiting...")
if self.wsgi_server is not None:
self.wsgi_server.stop()
except KeyboardInterrupt:
pass

def stop(self):
self.stop_event.set()


class APIServer(object):
def __init__(self):
def __init__(self, stop_event):
self.process = None
self.server_ready_value = Value("I", 0)
self.stop_event = multiprocessing.Event()
self.stop_event = stop_event

def start(self, args):
if self.process is not None:
raise Exception("API Server is already running")
self.process = Process(
target=run_api_server, args=(vars(args), self.server_ready_value, self.stop_event)
target=run_api_server,
args=(vars(args), self.server_ready_value, self.stop_event, os.getpid()),
)
self.process.start()
return self.process
Expand Down
1 change: 1 addition & 0 deletions counterparty-core/counterpartycore/lib/api/api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,7 @@ def get_tx_info(tx_hex, block_index=None):
db,
deserialize.deserialize_tx(tx_hex, use_txid=use_txid),
block_index=block_index,
composing=True,
)
)
return (
Expand Down
14 changes: 10 additions & 4 deletions counterparty-core/counterpartycore/lib/deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,21 @@ def read_transaction(vds, use_txid=True):
offset_before_tx_witnesses = vds.read_cursor - start_pos
for vin in transaction["vin"]: # noqa: B007
witnesses_count = vds.read_compact_size()
for i in range(witnesses_count): # noqa: B007
witness_length = vds.read_compact_size()
witness = vds.read_bytes(witness_length)
transaction["vtxinwit"].append(witness)
if witnesses_count == 0:
transaction["vtxinwit"].append([])
else:
vin_witnesses = []
for i in range(witnesses_count): # noqa: B007
witness_length = vds.read_compact_size()
witness = vds.read_bytes(witness_length)
vin_witnesses.append(witness)
transaction["vtxinwit"].append(vin_witnesses)

transaction["lock_time"] = vds.read_uint32()
data = vds.input[start_pos : vds.read_cursor]

transaction["tx_hash"] = ib2h(double_hash(data))
transaction["tx_id"] = transaction["tx_hash"]
if transaction["segwit"]:
hash_data = data[:4] + data[6:offset_before_tx_witnesses] + data[-4:]
transaction["tx_id"] = ib2h(double_hash(hash_data))
Expand Down
157 changes: 152 additions & 5 deletions counterparty-core/counterpartycore/lib/gettxinfo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import binascii
import logging
import struct
from io import BytesIO

from counterpartycore.lib import arc4, backend, config, ledger, message_type, script, util
from counterpartycore.lib.exceptions import BTCOnlyError, DecodeError
Expand Down Expand Up @@ -146,7 +147,8 @@ def get_vin_info(vin):
if "value" in vin:
return vin["value"], vin["script_pub_key"], vin["is_segwit"]

# Note: We don't know what block the `vin` is in, and the block might have been from a while ago, so this call may not hit the cache.
# Note: We don't know what block the `vin` is in, and the block might
# have been from a while ago, so this call may not hit the cache.
vin_ctx = backend.bitcoind.get_decoded_transaction(vin["hash"])

is_segwit = len(vin_ctx["vtxinwit"]) > 0
Expand All @@ -155,11 +157,152 @@ def get_vin_info(vin):
return vout["value"], vout["script_pub_key"], is_segwit


def is_valid_der(der):
if not isinstance(der, bytes):
return False
try:
s = BytesIO(der)
compound = s.read(1)[0]
if compound != 0x30:
return False
length = s.read(1)[0]
if length + 2 != len(der):
return False
marker = s.read(1)[0]
if marker != 0x02:
return False
rlength = s.read(1)[0]
_r = int(s.read(rlength).hex(), 16)
marker = s.read(1)[0]
if marker != 0x02:
return False
slength = s.read(1)[0]
s = int(s.read(slength).hex(), 16)
if len(der) != 6 + rlength + slength:
return False
return True
except Exception:
return False


def is_valid_schnorr(schnorr):
p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141

if not isinstance(schnorr, bytes):
return False
if len(schnorr) not in [64, 65]:
return False
if len(schnorr) == 65:
schnorr = schnorr[:-1]
try:
r = int.from_bytes(schnorr[0:32], byteorder="big")
s = int.from_bytes(schnorr[32:64], byteorder="big")
except Exception:
return False
if (r >= p) or (s >= n):
return False
return True


def get_der_signature_sighash_flag(value):
if is_valid_der(value[:-1]):
return value[-1:]
return None


def get_schnorr_signature_sighash_flag(value):
if is_valid_schnorr(value):
if len(value) == 65:
return value[-1:]
return b"\x01" # SIGHASH_ALL


def collect_sighash_flags(script_sig, witnesses):
flags = []

# P2PK, P2PKH, P2MS
if script_sig != b"":
asm = script.script_to_asm(script_sig)
for item in asm:
flag = get_der_signature_sighash_flag(item)
if flag is not None:
flags.append(flag)

if len(witnesses) == 0:
return flags

witnesses = [
binascii.unhexlify(witness) if isinstance(witness, str) else witness
for witness in witnesses
]

# P2WPKH
if len(witnesses) == 2:
flag = get_der_signature_sighash_flag(witnesses[0])
if flag is not None:
flags.append(flag)
return flags

# P2TR key path spend
if len(witnesses) == 1:
flag = get_schnorr_signature_sighash_flag(witnesses[0])
if flag is not None:
flags.append(flag)
return flags

# Other cases
if len(witnesses) >= 3:
for item in witnesses:
flag = get_schnorr_signature_sighash_flag(item) or get_der_signature_sighash_flag(item)
if flag is not None:
flags.append(flag)
return flags

return flags


# class SighashFlagError(DecodeError):
class SighashFlagError(Exception):
pass


# known transactions with invalid SIGHASH flag
SIGHASH_FLAG_TRANSACTION_WHITELIST = [
"c8091f1ef768a2f00d48e6d0f7a2c2d272a5d5c8063db78bf39977adcb12e103"
]


def check_signatures_sighash_flag(decoded_tx):
if decoded_tx["tx_id"] in SIGHASH_FLAG_TRANSACTION_WHITELIST:
return

script_sig = decoded_tx["vin"][0]["script_sig"]
witnesses = []
if decoded_tx["segwit"]:
witnesses = decoded_tx["vtxinwit"][0]

flags = collect_sighash_flags(script_sig, witnesses)

if len(flags) == 0:
error = f"impossible to determine SIGHASH flag for transaction {decoded_tx['tx_id']}"
logger.debug(error)
raise SighashFlagError(error)

# first input must be signed with SIGHASH_ALL or SIGHASH_ALL|SIGHASH_ANYONECANPAY
authorized_flags = [b"\x01", b"\x81"]
for flag in flags:
if flag not in authorized_flags:
error = f"invalid SIGHASH flag for transaction {decoded_tx['tx_id']}"
logger.debug(error)
raise SighashFlagError(error)


def get_transaction_sources(decoded_tx):
sources = []
outputs_value = 0

for vin in decoded_tx["vin"][:]: # Loop through inputs.
for vin in decoded_tx["vin"]: # Loop through inputs.
vout_value, script_pubkey, _is_segwit = get_vin_info(vin)

outputs_value += vout_value
Expand Down Expand Up @@ -394,6 +537,8 @@ def get_tx_info_new(db, decoded_tx, block_index, p2sh_is_segwit=False, composing
# Collect all (unique) source addresses.
# if we haven't found them yet
if p2sh_encoding_source is None:
if not composing:
check_signatures_sighash_flag(decoded_tx)
sources, outputs_value = get_transaction_sources(decoded_tx)
if not fee_added:
fee += outputs_value
Expand Down Expand Up @@ -524,7 +669,7 @@ def get_tx_info_legacy(decoded_tx, block_index):
return source, destination, btc_amount, fee, data, []


def _get_tx_info(db, decoded_tx, block_index, p2sh_is_segwit=False):
def _get_tx_info(db, decoded_tx, block_index, p2sh_is_segwit=False, composing=False):
"""Get the transaction info. Calls one of two subfunctions depending on signature type."""
if not block_index:
block_index = util.CURRENT_BLOCK_INDEX
Expand All @@ -535,12 +680,14 @@ def _get_tx_info(db, decoded_tx, block_index, p2sh_is_segwit=False):
decoded_tx,
block_index,
p2sh_is_segwit=p2sh_is_segwit,
composing=composing,
)
elif util.enabled("multisig_addresses", block_index=block_index): # Protocol change.
return get_tx_info_new(
db,
decoded_tx,
block_index,
composing=composing,
)
else:
return get_tx_info_legacy(decoded_tx, block_index)
Expand Down Expand Up @@ -604,7 +751,7 @@ def get_utxos_info(db, decoded_tx):
]


def get_tx_info(db, decoded_tx, block_index):
def get_tx_info(db, decoded_tx, block_index, composing=False):
"""Get the transaction info. Returns normalized None data for DecodeError and BTCOnlyError."""
if util.enabled("utxo_support", block_index=block_index):
# utxos_info is a space-separated list of UTXOs, last element is the destination,
Expand All @@ -618,7 +765,7 @@ def get_tx_info(db, decoded_tx, block_index):
utxos_info = []
try:
source, destination, btc_amount, fee, data, dispensers_outs = _get_tx_info(
db, decoded_tx, block_index
db, decoded_tx, block_index, composing=composing
)
return source, destination, btc_amount, fee, data, dispensers_outs, utxos_info
except DecodeError as e: # noqa: F841
Expand Down
5 changes: 3 additions & 2 deletions counterparty-core/counterpartycore/lib/ledger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2163,7 +2163,7 @@ def _get_holders(
return holders


def holders(db, asset, exclude_empty_holders=False):
def holders(db, asset, exclude_empty_holders=False, block_index=None):
"""Return holders of the asset."""
holders = []
cursor = db.cursor()
Expand All @@ -2189,8 +2189,9 @@ def holders(db, asset, exclude_empty_holders=False):
SELECT *, rowid
FROM balances
WHERE asset = ? AND utxo IS NOT NULL
ORDER BY rowid DESC
ORDER BY utxo
"""

bindings = (asset,)
cursor.execute(query, bindings)
holders += _get_holders(
Expand Down
12 changes: 8 additions & 4 deletions counterparty-core/counterpartycore/lib/message_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,12 @@ def get_transaction_type(data: bytes, destination: str, block_index: int):
return "unknown"

if message_type_id == messages.utxo.ID:
message_data = messages.utxo.unpack(message, return_dict=True)
if util.is_utxo_format(message_data["source"]):
return "detach"
return "attach"
try:
message_data = messages.utxo.unpack(message, return_dict=True)
if util.is_utxo_format(message_data["source"]):
return "detach"
return "attach"
except Exception:
return "unknown"

return TRANSACTION_TYPE_BY_ID.get(message_type_id, "unknown")
Loading

0 comments on commit ece96b5

Please sign in to comment.