Skip to content

Commit

Permalink
Simples changes that are required to support taproot LN channels (#129)
Browse files Browse the repository at this point in the history
* Add method to serialize/deserialise taproot script trees

* Export function to create taproot sessions

* Export methods to compare x-only public keys

* Remove the script tree id field and use the BIP371 serialisation format

The id field was used in tests only, was not part of the tree hash and was not serialised.
We also remove our custom serialisation format and keep only the format defined in BIP371.
  • Loading branch information
sstone authored Aug 20, 2024
1 parent e52c575 commit de10997
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,15 @@ public object LexicographicalOrdering {
@JvmStatic
public fun isLessThan(a: PublicKey, b: PublicKey): Boolean = isLessThan(a.value, b.value)

@JvmStatic
public fun isLessThan(a: XonlyPublicKey, b: XonlyPublicKey): Boolean = isLessThan(a.value, b.value)

@JvmStatic
public fun compare(a: PublicKey, b: PublicKey): Int = if (a == b) 0 else if (isLessThan(a, b)) -1 else 1

@JvmStatic
public fun compare(a: XonlyPublicKey, b: XonlyPublicKey): Int = if (a == b) 0 else if (isLessThan(a, b)) -1 else 1

/**
* @param tx input transaction
* @return the input tx with inputs and outputs sorted in lexicographical order
Expand Down
2 changes: 1 addition & 1 deletion src/commonMain/kotlin/fr/acinq/bitcoin/Script.kt
Original file line number Diff line number Diff line change
Expand Up @@ -1451,7 +1451,7 @@ public object Script {
require((control.size() - 33) / 32 in 0..128) { "invalid control block size" }
val leafVersion = control[0].toInt() and TAPROOT_LEAF_MASK
val internalKey = XonlyPublicKey(control.slice(1, 33).toByteArray().byteVector32())
val tapleafHash = ScriptTree.Leaf(0, script, leafVersion).hash()
val tapleafHash = ScriptTree.Leaf(script, leafVersion).hash()
this.context.executionData = this.context.executionData.copy(tapleafHash = tapleafHash)

// split input buffer into 32 bytes chunks (input buffer size MUST be a multiple of 32 !!)
Expand Down
82 changes: 73 additions & 9 deletions src/commonMain/kotlin/fr/acinq/bitcoin/ScriptTree.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,50 @@
*/
package fr.acinq.bitcoin

import fr.acinq.bitcoin.io.ByteArrayInput
import fr.acinq.bitcoin.io.ByteArrayOutput
import fr.acinq.bitcoin.io.Input
import fr.acinq.bitcoin.io.Output
import kotlin.jvm.JvmStatic

/** Simple binary tree structure containing taproot spending scripts. */
public sealed class ScriptTree {
public abstract fun write(output: Output, level: Int): Unit

/**
* @return the tree serialised with the format defined in BIP 371
*/
public fun write(): ByteArray {
val output = ByteArrayOutput()
write(output, 0)
return output.toByteArray()
}

/**
* Multiple spending scripts can be placed in the leaves of a taproot tree. When using one of those scripts to spend
* funds, we only need to reveal that specific script and a merkle proof that it is a leaf of the tree.
*
* @param id id that isn't used in the hash, but can be used by the caller to reference specific scripts.
* @param script serialized spending script.
* @param leafVersion tapscript version.
*/
public data class Leaf(val id: Int, val script: ByteVector, val leafVersion: Int) : ScriptTree() {
public constructor(id: Int, script: List<ScriptElt>) : this(id, script, Script.TAPROOT_LEAF_TAPSCRIPT)
public constructor(id: Int, script: List<ScriptElt>, leafVersion: Int) : this(id, Script.write(script).byteVector(), leafVersion)
public data class Leaf(val script: ByteVector, val leafVersion: Int) : ScriptTree() {
public constructor(script: List<ScriptElt>) : this(script, Script.TAPROOT_LEAF_TAPSCRIPT)
public constructor(script: List<ScriptElt>, leafVersion: Int) : this(Script.write(script).byteVector(), leafVersion)
public constructor(script: String, leafVersion: Int) : this(ByteVector.fromHex(script), leafVersion)

override fun write(output: Output, level: Int): Unit {
output.write(level)
output.write(leafVersion)
BtcSerializer.writeScript(script, output)
}
}

public data class Branch(val left: ScriptTree, val right: ScriptTree) : ScriptTree()
public data class Branch(val left: ScriptTree, val right: ScriptTree) : ScriptTree() {
override fun write(output: Output, level: Int): Unit {
left.write(output, level + 1)
right.write(output, level + 1)
}
}

/** Compute the merkle root of the script tree. */
public fun hash(): ByteVector32 = when (this) {
Expand All @@ -42,6 +68,7 @@ public sealed class ScriptTree {
BtcSerializer.writeScript(this.script, buffer)
Crypto.taggedHash(buffer.toByteArray(), "TapLeaf")
}

is Branch -> {
val h1 = this.left.hash()
val h2 = this.right.hash()
Expand All @@ -50,10 +77,10 @@ public sealed class ScriptTree {
}
}

/** Return the first script leaf with the corresponding id, if any. */
public fun findScript(id: Int): Leaf? = when (this) {
is Leaf -> if (this.id == id) this else null
is Branch -> this.left.findScript(id) ?: this.right.findScript(id)
/** Return the first leaf with a matching script, if any. */
public fun findScript(script: ByteVector): Leaf? = when (this) {
is Leaf -> if (this.script == script) this else null
is Branch -> this.left.findScript(script) ?: this.right.findScript(script)
}

/**
Expand All @@ -68,4 +95,41 @@ public sealed class ScriptTree {
}
return loop(this, ByteArray(0))
}

public companion object {
private fun readLeaves(input: Input): ArrayList<Pair<Int, ScriptTree>> {
val leaves = arrayListOf<Pair<Int, ScriptTree>>()
while (input.availableBytes > 0) {
val depth = input.read()
val leafVersion = input.read()
val script = BtcSerializer.script(input).byteVector()
leaves.add(Pair(depth, Leaf(script, leafVersion)))
}
return leaves
}

private fun merge(nodes: ArrayList<Pair<Int, ScriptTree>>) {
if (nodes.size > 1) {
var i = 0
while (i < nodes.size - 1) {
if (nodes[i].first == nodes[i + 1].first) {
// merge 2 consecutive nodes that are on the same level
val node = Pair(nodes[i].first - 1, Branch(nodes[i].second, nodes[i + 1].second))
nodes[i] = node
nodes.removeAt(i + 1)
// and start again from the beginning (the node at the bottom-left of the tree)
i = 0
} else i++
}
}
}

@JvmStatic
public fun read(input: Input): ScriptTree {
val leaves = readLeaves(input)
merge(leaves)
require(leaves.size == 1) { "invalid serialised script tree" }
return leaves[0].second
}
}
}
13 changes: 12 additions & 1 deletion src/commonMain/kotlin/fr/acinq/bitcoin/crypto/musig2/Musig2.kt
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,18 @@ public object Musig2 {
return SecretNonce.generate(sessionId, privateKey, privateKey.publicKey(), message = null, keyAggCache, extraInput = null)
}

private fun taprootSession(tx: Transaction, inputIndex: Int, inputs: List<TxOut>, publicKeys: List<PublicKey>, publicNonces: List<IndividualNonce>, scriptTree: ScriptTree?): Either<Throwable, Session> {
/**
* Create a musig2 session for a given transaction input.
*
* @param tx transaction
* @param inputIndex transaction input index
* @param inputs outputs spent by this transaction
* @param publicKeys signers' public keys
* @param publicNonces signers' public nonces
* @param scriptTree tapscript tree of the transaction's input, if it has script paths.
*/
@JvmStatic
public fun taprootSession(tx: Transaction, inputIndex: Int, inputs: List<TxOut>, publicKeys: List<PublicKey>, publicNonces: List<IndividualNonce>, scriptTree: ScriptTree?): Either<Throwable, Session> {
return IndividualNonce.aggregate(publicNonces).flatMap { aggregateNonce ->
val (aggregatePublicKey, keyAggCache) = KeyAggCache.create(publicKeys)
val tweak = when (scriptTree) {
Expand Down
43 changes: 41 additions & 2 deletions src/commonTest/kotlin/fr/acinq/bitcoin/TaprootTestsCommon.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package fr.acinq.bitcoin
import fr.acinq.bitcoin.Bech32.hrp
import fr.acinq.bitcoin.Bitcoin.addressToPublicKeyScript
import fr.acinq.bitcoin.Transaction.Companion.hashForSigningSchnorr
import fr.acinq.bitcoin.io.ByteArrayInput
import fr.acinq.bitcoin.io.ByteArrayOutput
import fr.acinq.bitcoin.reference.TransactionTestsCommon.Companion.resourcesDir
import fr.acinq.secp256k1.Hex
import fr.acinq.secp256k1.Secp256k1
Expand Down Expand Up @@ -175,7 +177,7 @@ class TaprootTestsCommon {
)

// simple script tree with a single element
val scriptTree = ScriptTree.Leaf(0, script)
val scriptTree = ScriptTree.Leaf(script)
// we choose a pubkey that does not have a corresponding private key: our funding tx can only be spent through the script path, not the key path
val internalPubkey = PublicKey.fromHex("0250929b74c1a04954b78b4b6035e97a5e078a5a0f28ec96d547bfee9ace803ac0").xOnly()

Expand Down Expand Up @@ -221,7 +223,7 @@ class TaprootTestsCommon {
PrivateKey(ByteVector32("0101010101010101010101010101010101010101010101010101010101010103"))
)
val scripts = privs.map { listOf(OP_PUSHDATA(it.publicKey().xOnly().value), OP_CHECKSIG) }
val leaves = scripts.mapIndexed { idx, script -> ScriptTree.Leaf(idx, script) }
val leaves = scripts.map { ScriptTree.Leaf(it) }
// root
// / \
// / \ #3
Expand Down Expand Up @@ -420,4 +422,41 @@ class TaprootTestsCommon {
val serializedTx = Transaction.write(tx)
assertContentEquals(buffer, serializedTx)
}

@Test
fun `serialize script tree -- reference test`() {
val tree =
ScriptTree.read(ByteArrayInput(Hex.decode("02c02220736e572900fe1252589a2143c8f3c79f71a0412d2353af755e9701c782694a02ac02c02220631c5f3b5832b8fbdebfb19704ceeb323c21f40f7a24f43d68ef0cc26b125969ac01c0222044faa49a0338de488c8dfffecdfb6f329f380bd566ef20c8df6d813eab1c4273ac")))
assertEquals(
ScriptTree.Branch(
ScriptTree.Branch(
ScriptTree.Leaf("20736e572900fe1252589a2143c8f3c79f71a0412d2353af755e9701c782694a02ac", 0xc0),
ScriptTree.Leaf("20631c5f3b5832b8fbdebfb19704ceeb323c21f40f7a24f43d68ef0cc26b125969ac", 0xc0),
),
ScriptTree.Leaf("2044faa49a0338de488c8dfffecdfb6f329f380bd566ef20c8df6d813eab1c4273ac", 0xc0)
), tree
)
}

@Test
fun `serialize script trees`() {
val random = kotlin.random.Random.Default

fun randomLeaf(): ScriptTree.Leaf = ScriptTree.Leaf(random.nextBytes(random.nextInt(1, 8)).byteVector(), random.nextInt(255))

fun randomTree(maxLevel: Int): ScriptTree = when {
maxLevel == 0 -> randomLeaf()
random.nextBoolean() -> randomLeaf()
else -> ScriptTree.Branch(randomTree(maxLevel - 1), randomTree(maxLevel - 1))
}

fun serde(input: ScriptTree): ScriptTree {
return ScriptTree.read(ByteArrayInput(input.write()))
}

(0 until 1000).forEach { _ ->
val tree = randomTree(10)
assertEquals(tree, serde(tree))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ class Musig2TestsCommon {
// The redeem script is just the refund script, generated from this policy: and_v(v:pk(user),older(refundDelay))
// It does not depend upon the user's or server's key, just the user's refund key and the refund delay.
val redeemScript = listOf(OP_PUSHDATA(userRefundPrivateKey.xOnlyPublicKey()), OP_CHECKSIGVERIFY, OP_PUSHDATA(Script.encodeNumber(refundDelay)), OP_CHECKSEQUENCEVERIFY)
val scriptTree = ScriptTree.Leaf(0, redeemScript)
val scriptTree = ScriptTree.Leaf(redeemScript)

// The internal pubkey is the musig2 aggregation of the user's and server's public keys: it does not depend upon the user's refund's key.
val internalPubKey = Musig2.aggregateKeys(listOf(userPublicKey, serverPublicKey))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,21 @@ class BIP341TestsCommon {
// When control blocks are provided, recompute them for each script tree leaf and check that they match.
assertNotNull(scriptTree)
val controlBlocks = expected["scriptPathControlBlocks"]!!.jsonArray.map { ByteVector.fromHex(it.jsonPrimitive.content) }

fun loop(tree: ScriptTree, acc: ArrayList<ScriptTree.Leaf>) {
when (tree) {
is ScriptTree.Leaf -> acc.add(tree)
is ScriptTree.Branch -> {
loop(tree.left, acc)
loop(tree.right, acc)
}
}
}
// traverse the tree from left to right and top to bottom, this is the order that is used in the reference tests
val leaves = ArrayList<ScriptTree.Leaf>()
loop(scriptTree, leaves)
controlBlocks.forEachIndexed { index, expectedControlBlock ->
val scriptLeaf = scriptTree.findScript(index)
val scriptLeaf = leaves[index]
assertNotNull(scriptLeaf)
val computedControlBlock = Script.ControlBlock.build(internalPubkey, scriptTree, scriptLeaf)
assertEquals(expectedControlBlock, computedControlBlock)
Expand All @@ -121,14 +134,13 @@ class BIP341TestsCommon {
else -> ByteVector32(input)
}

private fun scriptLeafFromJson(json: JsonElement): ScriptTree.Leaf = ScriptTree.Leaf(
id = json.jsonObject["id"]!!.jsonPrimitive.int,
private fun scriptLeafFromJson(json: JsonElement): Pair<Int, ScriptTree.Leaf> = json.jsonObject["id"]!!.jsonPrimitive.int to ScriptTree.Leaf(
script = ByteVector.fromHex(json.jsonObject["script"]!!.jsonPrimitive.content),
leafVersion = json.jsonObject["leafVersion"]!!.jsonPrimitive.int
)

fun scriptTreeFromJson(json: JsonElement): ScriptTree = when (json) {
is JsonObject -> scriptLeafFromJson(json)
is JsonObject -> scriptLeafFromJson(json).second
is JsonArray -> {
require(json.size == 2) { "script tree must contain exactly two branches: $json" }
ScriptTree.Branch(scriptTreeFromJson(json[0]), scriptTreeFromJson(json[1]))
Expand Down

0 comments on commit de10997

Please sign in to comment.