From d7894426397886a4865dc083e1998aef7e5adc75 Mon Sep 17 00:00:00 2001 From: Salvatore Ingala <6681844+bigspider@users.noreply.github.com> Date: Fri, 1 Mar 2024 15:13:28 +0100 Subject: [PATCH] Compute aggregate xpub for musig() in descriptors in the python client library --- bitcoin_client/ledger_bitcoin/bip0327.py | 177 +++++++++++++++++++++++ bitcoin_client/ledger_bitcoin/client.py | 56 ++++++- 2 files changed, 232 insertions(+), 1 deletion(-) create mode 100644 bitcoin_client/ledger_bitcoin/bip0327.py diff --git a/bitcoin_client/ledger_bitcoin/bip0327.py b/bitcoin_client/ledger_bitcoin/bip0327.py new file mode 100644 index 000000000..8d4680791 --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/bip0327.py @@ -0,0 +1,177 @@ +# extracted from the BIP327 reference implementation: https://github.com/bitcoin/bips/blob/b3701faef2bdb98a0d7ace4eedbeefa2da4c89ed/bip-0327/reference.py + +# Only contains the key aggregation part of the library + +# The code in this source file is distributed under the BSD-3-Clause. + +# autopep8: off + +from typing import List, Optional, Tuple, NewType, NamedTuple +import hashlib + +# +# The following helper functions were copied from the BIP-340 reference implementation: +# https://github.com/bitcoin/bips/blob/master/bip-0340/reference.py +# + +p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F +n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 + +# Points are tuples of X and Y coordinates and the point at infinity is +# represented by the None keyword. +G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8) + +Point = Tuple[int, int] + +# This implementation can be sped up by storing the midstate after hashing +# tag_hash instead of rehashing it all the time. +def tagged_hash(tag: str, msg: bytes) -> bytes: + tag_hash = hashlib.sha256(tag.encode()).digest() + return hashlib.sha256(tag_hash + tag_hash + msg).digest() + +def is_infinite(P: Optional[Point]) -> bool: + return P is None + +def x(P: Point) -> int: + assert not is_infinite(P) + return P[0] + +def y(P: Point) -> int: + assert not is_infinite(P) + return P[1] + +def point_add(P1: Optional[Point], P2: Optional[Point]) -> Optional[Point]: + if P1 is None: + return P2 + if P2 is None: + return P1 + if (x(P1) == x(P2)) and (y(P1) != y(P2)): + return None + if P1 == P2: + lam = (3 * x(P1) * x(P1) * pow(2 * y(P1), p - 2, p)) % p + else: + lam = ((y(P2) - y(P1)) * pow(x(P2) - x(P1), p - 2, p)) % p + x3 = (lam * lam - x(P1) - x(P2)) % p + return (x3, (lam * (x(P1) - x3) - y(P1)) % p) + +def point_mul(P: Optional[Point], n: int) -> Optional[Point]: + R = None + for i in range(256): + if (n >> i) & 1: + R = point_add(R, P) + P = point_add(P, P) + return R + +def bytes_from_int(x: int) -> bytes: + return x.to_bytes(32, byteorder="big") + +def lift_x(b: bytes) -> Optional[Point]: + x = int_from_bytes(b) + if x >= p: + return None + y_sq = (pow(x, 3, p) + 7) % p + y = pow(y_sq, (p + 1) // 4, p) + if pow(y, 2, p) != y_sq: + return None + return (x, y if y & 1 == 0 else p-y) + +def int_from_bytes(b: bytes) -> int: + return int.from_bytes(b, byteorder="big") + +def has_even_y(P: Point) -> bool: + assert not is_infinite(P) + return y(P) % 2 == 0 + +# +# End of helper functions copied from BIP-340 reference implementation. +# + +PlainPk = NewType('PlainPk', bytes) +XonlyPk = NewType('XonlyPk', bytes) + +# There are two types of exceptions that can be raised by this implementation: +# - ValueError for indicating that an input doesn't conform to some function +# precondition (e.g. an input array is the wrong length, a serialized +# representation doesn't have the correct format). +# - InvalidContributionError for indicating that a signer (or the +# aggregator) is misbehaving in the protocol. +# +# Assertions are used to (1) satisfy the type-checking system, and (2) check for +# inconvenient events that can't happen except with negligible probability (e.g. +# output of a hash function is 0) and can't be manually triggered by any +# signer. + +# This exception is raised if a party (signer or nonce aggregator) sends invalid +# values. Actual implementations should not crash when receiving invalid +# contributions. Instead, they should hold the offending party accountable. +class InvalidContributionError(Exception): + def __init__(self, signer, contrib): + self.signer = signer + # contrib is one of "pubkey", "pubnonce", "aggnonce", or "psig". + self.contrib = contrib + +infinity = None + +def xbytes(P: Point) -> bytes: + return bytes_from_int(x(P)) + +def cbytes(P: Point) -> bytes: + a = b'\x02' if has_even_y(P) else b'\x03' + return a + xbytes(P) + +def point_negate(P: Optional[Point]) -> Optional[Point]: + if P is None: + return P + return (x(P), p - y(P)) + +def cpoint(x: bytes) -> Point: + if len(x) != 33: + raise ValueError('x is not a valid compressed point.') + P = lift_x(x[1:33]) + if P is None: + raise ValueError('x is not a valid compressed point.') + if x[0] == 2: + return P + elif x[0] == 3: + P = point_negate(P) + assert P is not None + return P + else: + raise ValueError('x is not a valid compressed point.') + +KeyAggContext = NamedTuple('KeyAggContext', [('Q', Point), + ('gacc', int), + ('tacc', int)]) + +def key_agg(pubkeys: List[PlainPk]) -> KeyAggContext: + pk2 = get_second_key(pubkeys) + u = len(pubkeys) + Q = infinity + for i in range(u): + try: + P_i = cpoint(pubkeys[i]) + except ValueError: + raise InvalidContributionError(i, "pubkey") + a_i = key_agg_coeff_internal(pubkeys, pubkeys[i], pk2) + Q = point_add(Q, point_mul(P_i, a_i)) + # Q is not the point at infinity except with negligible probability. + assert(Q is not None) + gacc = 1 + tacc = 0 + return KeyAggContext(Q, gacc, tacc) + +def hash_keys(pubkeys: List[PlainPk]) -> bytes: + return tagged_hash('KeyAgg list', b''.join(pubkeys)) + +def get_second_key(pubkeys: List[PlainPk]) -> PlainPk: + u = len(pubkeys) + for j in range(1, u): + if pubkeys[j] != pubkeys[0]: + return pubkeys[j] + return PlainPk(b'\x00'*33) + +def key_agg_coeff_internal(pubkeys: List[PlainPk], pk_: PlainPk, pk2: PlainPk) -> int: + L = hash_keys(pubkeys) + if pk_ == pk2: + return 1 + return int_from_bytes(tagged_hash('KeyAgg coefficient', L + pk_)) % n diff --git a/bitcoin_client/ledger_bitcoin/client.py b/bitcoin_client/ledger_bitcoin/client.py index 351370320..94c22aed8 100644 --- a/bitcoin_client/ledger_bitcoin/client.py +++ b/bitcoin_client/ledger_bitcoin/client.py @@ -3,6 +3,7 @@ import base64 from io import BytesIO, BufferedReader +from .embit import base58 from .embit.base import EmbitError from .embit.descriptor import Descriptor from .embit.networks import NETWORKS @@ -17,9 +18,10 @@ from .merkle import get_merkleized_map_commitment from .wallet import WalletPolicy, WalletType from .psbt import PSBT, normalize_psbt -from . import segwit_addr from ._serialize import deser_string +from .bip0327 import key_agg, cbytes + def parse_stream_to_map(f: BufferedReader) -> Mapping[bytes, bytes]: result = {} @@ -39,6 +41,53 @@ def parse_stream_to_map(f: BufferedReader) -> Mapping[bytes, bytes]: return result +def aggr_xpub(pubkeys: List[bytes], chain: Chain) -> str: + BIP_MUSIG_CHAINCODE = bytes.fromhex( + "868087ca02a6f974c4598924c36b57762d32cb45717167e300622c7167e38965") + ctx = key_agg(pubkeys) + compressed_pubkey = cbytes(ctx.Q) + + # Serialize according to BIP-32 + if chain == Chain.MAIN: + version = 0x0488B21E + else: + version = 0x043587CF + + return base58.encode_check(b''.join([ + version.to_bytes(4, byteorder='big'), + b'\x00', # depth + b'\x00\x00\x00\x00', # parent fingerprint + b'\x00\x00\x00\x00', # child number + BIP_MUSIG_CHAINCODE, + compressed_pubkey + ])) + + +# Given a valid descriptor, replaces each musig() (if any) with the +# corresponding synthetic xpub/tpub. +def replace_musigs(desc: str, chain: Chain) -> str: + while True: + musig_start = desc.find("musig(") + if musig_start == -1: + break + musig_end = desc.find(")", musig_start) + if musig_end == -1: + raise ValueError("Invalid descriptor template") + + key_and_origs = desc[musig_start+6:musig_end].split(",") + pubkeys = [] + for key_orig in key_and_origs: + orig_end = key_orig.find("]") + xpub = key_orig if orig_end == -1 else key_orig[orig_end+1:] + pubkeys.append(base58.decode_check(xpub)[-33:]) + + # replace with the aggregate xpub + desc = desc[:musig_start] + \ + aggr_xpub(pubkeys, chain) + desc[musig_end+1:] + + return desc + + def _make_partial_signature(pubkey_augm: bytes, signature: bytes) -> PartialSignature: if len(pubkey_augm) == 64: # tapscript spend: pubkey_augm is the concatenation of: @@ -273,6 +322,11 @@ def sign_message(self, message: Union[str, bytes], bip32_path: str) -> str: def _derive_address_for_policy(self, wallet: WalletPolicy, change: bool, address_index: int) -> Optional[str]: desc_str = wallet.get_descriptor(change) + + # Since embit does not support musig() in descriptors, we replace each + # occurrence with the corresponding aggregated xpub + desc_str = replace_musigs(desc_str, self.chain) + try: desc = Descriptor.from_string(desc_str)