Skip to content

Commit

Permalink
Remove the script tree id field and use the BIP371 serialisation format
Browse files Browse the repository at this point in the history
The id field was used in tests only, was not part of the tree hash and was not serialised.
We also remove our custom serialisation format and keep only the format defined in BIP371.
  • Loading branch information
sstone committed Aug 20, 2024
1 parent 40edfd2 commit ecda613
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 84 deletions.
2 changes: 1 addition & 1 deletion src/commonMain/kotlin/fr/acinq/bitcoin/Script.kt
Original file line number Diff line number Diff line change
Expand Up @@ -1451,7 +1451,7 @@ public object Script {
require((control.size() - 33) / 32 in 0..128) { "invalid control block size" }
val leafVersion = control[0].toInt() and TAPROOT_LEAF_MASK
val internalKey = XonlyPublicKey(control.slice(1, 33).toByteArray().byteVector32())
val tapleafHash = ScriptTree.Leaf(0, script, leafVersion).hash()
val tapleafHash = ScriptTree.Leaf(script, leafVersion).hash()
this.context.executionData = this.context.executionData.copy(tapleafHash = tapleafHash)

// split input buffer into 32 bytes chunks (input buffer size MUST be a multiple of 32 !!)
Expand Down
95 changes: 28 additions & 67 deletions src/commonMain/kotlin/fr/acinq/bitcoin/ScriptTree.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,47 +23,30 @@ import kotlin.jvm.JvmStatic

/** Simple binary tree structure containing taproot spending scripts. */
public sealed class ScriptTree {
// our own tree-based binary format
public abstract fun write(output: Output): Output
public abstract fun write(output: Output, level: Int): Output

/**
* @return the tree serialised with the format defined in BIP 371
*/
public fun write(): ByteArray {
val output = ByteArrayOutput()
write(output)
return output.toByteArray()
}

// BIP373 binary format
public abstract fun writeForPSbt(output: Output, level: Int): Output

public fun writeForPsbt(): ByteArray {
val output = ByteArrayOutput()
writeForPSbt(output, 0)
write(output, 0)
return output.toByteArray()
}

/**
* Multiple spending scripts can be placed in the leaves of a taproot tree. When using one of those scripts to spend
* funds, we only need to reveal that specific script and a merkle proof that it is a leaf of the tree.
*
* @param id id that isn't used in the hash, but can be used by the caller to reference specific scripts.
* @param script serialized spending script.
* @param leafVersion tapscript version.
*/
public data class Leaf(val id: Int, val script: ByteVector, val leafVersion: Int) : ScriptTree() {
public constructor(id: Int, script: List<ScriptElt>) : this(id, script, Script.TAPROOT_LEAF_TAPSCRIPT)
public constructor(id: Int, script: List<ScriptElt>, leafVersion: Int) : this(id, Script.write(script).byteVector(), leafVersion)
public constructor(id: Int, script: String, leafVersion: Int) : this(id, ByteVector.fromHex(script), leafVersion)

public override fun write(output: Output): Output {
output.write(0)
BtcSerializer.writeVarint(id, output)
BtcSerializer.writeScript(script, output)
output.write(leafVersion)
return output
}
public data class Leaf(val script: ByteVector, val leafVersion: Int) : ScriptTree() {
public constructor(script: List<ScriptElt>) : this(script, Script.TAPROOT_LEAF_TAPSCRIPT)
public constructor(script: List<ScriptElt>, leafVersion: Int) : this(Script.write(script).byteVector(), leafVersion)
public constructor(script: String, leafVersion: Int) : this(ByteVector.fromHex(script), leafVersion)

override fun writeForPSbt(output: Output, level: Int): Output {
// id is not persisted
override fun write(output: Output, level: Int): Output {
output.write(level)
output.write(leafVersion)
BtcSerializer.writeScript(script, output)
Expand All @@ -72,16 +55,9 @@ public sealed class ScriptTree {
}

public data class Branch(val left: ScriptTree, val right: ScriptTree) : ScriptTree() {
public override fun write(output: Output): Output {
output.write(1)
left.write(output)
right.write(output)
return output
}

override fun writeForPSbt(output: Output, level: Int): Output {
left.writeForPSbt(output, level + 1)
right.writeForPSbt(output, level + 1)
override fun write(output: Output, level: Int): Output {
left.write(output, level + 1)
right.write(output, level + 1)
return output
}

Expand All @@ -104,10 +80,10 @@ public sealed class ScriptTree {
}
}

/** Return the first script leaf with the corresponding id, if any. */
public fun findScript(id: Int): Leaf? = when (this) {
is Leaf -> if (this.id == id) this else null
is Branch -> this.left.findScript(id) ?: this.right.findScript(id)
/** Return the first leaf with a matching script, if any. */
public fun findScript(script: ByteVector): Leaf? = when (this) {
is Leaf -> if (this.script == script) this else null
is Branch -> this.left.findScript(script) ?: this.right.findScript(script)
}

/**
Expand All @@ -124,54 +100,39 @@ public sealed class ScriptTree {
}

public companion object {
@JvmStatic
public fun read(input: Input): ScriptTree = when (val tag = input.read()) {
0 -> Leaf(BtcSerializer.varint(input).toInt(), BtcSerializer.script(input).byteVector(), input.read())
1 -> Branch(read(input), read(input))
else -> error("cannot deserialize script tree: invalid tag $tag")
}

@JvmStatic
public fun read(input: ByteArray): ScriptTree = read(ByteArrayInput(input))

internal fun readLeaves(input: Input, setIdToZero: Boolean = true): ArrayList<Pair<Int, ScriptTree>> {
private fun readLeaves(input: Input): ArrayList<Pair<Int, ScriptTree>> {
val leaves = arrayListOf<Pair<Int, ScriptTree>>()
var id = 0
while (input.availableBytes > 0) {
val depth = input.read()
val leafVersion = input.read()
val script = BtcSerializer.script(input).byteVector()
leaves.add(Pair(depth, Leaf(if (setIdToZero) 0 else id++, script, leafVersion)))
leaves.add(Pair(depth, Leaf(script, leafVersion)))
}
return leaves
}

internal fun merge(nodes: ArrayList<Pair<Int, ScriptTree>>): Boolean {
private fun merge(nodes: ArrayList<Pair<Int, ScriptTree>>) {
if (nodes.size > 1) {
var i = 0
while (i < nodes.size - 1) {
if (nodes[i].first == nodes[i + 1].first) {
// merge 2 consecutive nodes that are on the same level
val node = Pair(nodes[i].first - 1, Branch(nodes[i].second, nodes[i + 1].second))
nodes[i] = node
nodes.removeAt(i + 1)
return true
// and start again from the beginning (the node at the bottom-left of the tree)
i = 0
} else i++
}
}
return false
}

@JvmStatic
public fun readFromPsbt(input: Input, setIdToZero: Boolean = true): ScriptTree {
val leaves = readLeaves(input, setIdToZero)
while (merge(leaves)) {
// keep on merging
}
return when (leaves.size) {
1 -> leaves[0].second
2 -> Branch(leaves[0].second, leaves[1].second)
else -> error("cannot merge $leaves")
}
public fun read(input: Input): ScriptTree {
val leaves = readLeaves(input)
merge(leaves)
require(leaves.size == 1) { "invalid serialised script tree" }
return leaves[0].second
}
}
}
30 changes: 19 additions & 11 deletions src/commonTest/kotlin/fr/acinq/bitcoin/TaprootTestsCommon.kt
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class TaprootTestsCommon {
)

// simple script tree with a single element
val scriptTree = ScriptTree.Leaf(0, script)
val scriptTree = ScriptTree.Leaf(script)
// we choose a pubkey that does not have a corresponding private key: our funding tx can only be spent through the script path, not the key path
val internalPubkey = PublicKey.fromHex("0250929b74c1a04954b78b4b6035e97a5e078a5a0f28ec96d547bfee9ace803ac0").xOnly()

Expand Down Expand Up @@ -223,7 +223,7 @@ class TaprootTestsCommon {
PrivateKey(ByteVector32("0101010101010101010101010101010101010101010101010101010101010103"))
)
val scripts = privs.map { listOf(OP_PUSHDATA(it.publicKey().xOnly().value), OP_CHECKSIG) }
val leaves = scripts.mapIndexed { idx, script -> ScriptTree.Leaf(idx, script) }
val leaves = scripts.map { ScriptTree.Leaf(it) }
// root
// / \
// / \ #3
Expand Down Expand Up @@ -423,11 +423,26 @@ class TaprootTestsCommon {
assertContentEquals(buffer, serializedTx)
}

@Test
fun `serialize script tree -- reference test`() {
val tree =
ScriptTree.read(ByteArrayInput(Hex.decode("02c02220736e572900fe1252589a2143c8f3c79f71a0412d2353af755e9701c782694a02ac02c02220631c5f3b5832b8fbdebfb19704ceeb323c21f40f7a24f43d68ef0cc26b125969ac01c0222044faa49a0338de488c8dfffecdfb6f329f380bd566ef20c8df6d813eab1c4273ac")))
assertEquals(
ScriptTree.Branch(
ScriptTree.Branch(
ScriptTree.Leaf("20736e572900fe1252589a2143c8f3c79f71a0412d2353af755e9701c782694a02ac", 0xc0),
ScriptTree.Leaf("20631c5f3b5832b8fbdebfb19704ceeb323c21f40f7a24f43d68ef0cc26b125969ac", 0xc0),
),
ScriptTree.Leaf("2044faa49a0338de488c8dfffecdfb6f329f380bd566ef20c8df6d813eab1c4273ac", 0xc0)
), tree
)
}

@Test
fun `serialize script trees`() {
val random = kotlin.random.Random.Default

fun randomLeaf(): ScriptTree.Leaf = ScriptTree.Leaf(0, random.nextBytes(random.nextInt(1, 8)).byteVector(), random.nextInt(255))
fun randomLeaf(): ScriptTree.Leaf = ScriptTree.Leaf(random.nextBytes(random.nextInt(1, 8)).byteVector(), random.nextInt(255))

fun randomTree(maxLevel: Int): ScriptTree = when {
maxLevel == 0 -> randomLeaf()
Expand All @@ -437,20 +452,13 @@ class TaprootTestsCommon {

fun serde(input: ScriptTree): ScriptTree {
val output = ByteArrayOutput()
input.write(output)
input.write(output, 0)
return ScriptTree.read(ByteArrayInput(output.toByteArray()))
}

fun serdePsbt(input: ScriptTree): ScriptTree {
val output = ByteArrayOutput()
input.writeForPSbt(output, 0)
return ScriptTree.readFromPsbt(ByteArrayInput(output.toByteArray()))
}

(0 until 1000).forEach { _ ->
val tree = randomTree(10)
assertEquals(tree, serde(tree))
assertEquals(tree, serdePsbt(tree))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ class Musig2TestsCommon {
// The redeem script is just the refund script, generated from this policy: and_v(v:pk(user),older(refundDelay))
// It does not depend upon the user's or server's key, just the user's refund key and the refund delay.
val redeemScript = listOf(OP_PUSHDATA(userRefundPrivateKey.xOnlyPublicKey()), OP_CHECKSIGVERIFY, OP_PUSHDATA(Script.encodeNumber(refundDelay)), OP_CHECKSEQUENCEVERIFY)
val scriptTree = ScriptTree.Leaf(0, redeemScript)
val scriptTree = ScriptTree.Leaf(redeemScript)

// The internal pubkey is the musig2 aggregation of the user's and server's public keys: it does not depend upon the user's refund's key.
val internalPubKey = Musig2.aggregateKeys(listOf(userPublicKey, serverPublicKey))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1187,6 +1187,15 @@ class PsbtTestsCommon {

@Test
fun `update keypaths`() {

println(
Psbt.read(
Hex.decode(
"70736274ff01005e020000000127744ababf3027fe0d6cf23a96eee2efb188ef52301954585883e69b6624b2420000000000ffffffff0148e6052a010000002251200a8cbdc86de1ce1c0f9caeb22d6df7ced3683fe423e05d1e402a879341d6f6f5000000000001012b00f2052a010000002251205a2c2cf5b52cf31f83ad2e8da63ff03183ecd8f609c7510ae8a48e03910a07572116fe349064c98d6e2a853fa3c9b12bd8b304a19c195c60efa7ee2393046d3fa2321900772b2da75600008001000080000000800100000000000000011720fe349064c98d6e2a853fa3c9b12bd8b304a19c195c60efa7ee2393046d3fa2320001052050929b74c1a04954b78b4b6035e97a5e078a5a0f28ec96d547bfee9ace803ac001066f02c02220736e572900fe1252589a2143c8f3c79f71a0412d2353af755e9701c782694a02ac02c02220631c5f3b5832b8fbdebfb19704ceeb323c21f40f7a24f43d68ef0cc26b125969ac01c0222044faa49a0338de488c8dfffecdfb6f329f380bd566ef20c8df6d813eab1c4273ac210744faa49a0338de488c8dfffecdfb6f329f380bd566ef20c8df6d813eab1c42733901f06b798b92a10ed9a9d0bbfd3af173a53b1617da3a4159ca008216cd856b2e0e772b2da75600008001000080010000800000000003000000210750929b74c1a04954b78b4b6035e97a5e078a5a0f28ec96d547bfee9ace803ac005007c461e5d2107631c5f3b5832b8fbdebfb19704ceeb323c21f40f7a24f43d68ef0cc26b125969390118ace409889785e0ea70ceebb8e1ca892a7a78eaede0f2e296cf435961a8f4ca772b2da756000080010000800200008000000000030000002107736e572900fe1252589a2143c8f3c79f71a0412d2353af755e9701c782694a02390129a5b4915090162d759afd3fe0f93fa3326056d0b4088cb933cae7826cb8d82c772b2da7560000800100008003000080000000000300000000"
)
)
)

val priv1 = PrivateKey.fromHex("0101010101010101010101010101010101010101010101010101010101010101")
val priv2 = PrivateKey.fromHex("0202020202020202020202020202020202020202020202020202020202020202")
val utxo1 = Transaction(2, listOf(), listOf(TxOut(Satoshi(15_000_000L), Script.pay2wpkh(priv1.publicKey()))), 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,21 @@ class BIP341TestsCommon {
// When control blocks are provided, recompute them for each script tree leaf and check that they match.
assertNotNull(scriptTree)
val controlBlocks = expected["scriptPathControlBlocks"]!!.jsonArray.map { ByteVector.fromHex(it.jsonPrimitive.content) }

fun loop(tree: ScriptTree, acc: ArrayList<ScriptTree.Leaf>) {
when (tree) {
is ScriptTree.Leaf -> acc.add(tree)
is ScriptTree.Branch -> {
loop(tree.left, acc)
loop(tree.right, acc)
}
}
}
// traverse the tree from left to right and top to bottom, this is the order that is used in the reference tests
val leaves = ArrayList<ScriptTree.Leaf>()
loop(scriptTree, leaves)
controlBlocks.forEachIndexed { index, expectedControlBlock ->
val scriptLeaf = scriptTree.findScript(index)
val scriptLeaf = leaves[index]
assertNotNull(scriptLeaf)
val computedControlBlock = Script.ControlBlock.build(internalPubkey, scriptTree, scriptLeaf)
assertEquals(expectedControlBlock, computedControlBlock)
Expand All @@ -121,14 +134,13 @@ class BIP341TestsCommon {
else -> ByteVector32(input)
}

private fun scriptLeafFromJson(json: JsonElement): ScriptTree.Leaf = ScriptTree.Leaf(
id = json.jsonObject["id"]!!.jsonPrimitive.int,
private fun scriptLeafFromJson(json: JsonElement): Pair<Int, ScriptTree.Leaf> = json.jsonObject["id"]!!.jsonPrimitive.int to ScriptTree.Leaf(
script = ByteVector.fromHex(json.jsonObject["script"]!!.jsonPrimitive.content),
leafVersion = json.jsonObject["leafVersion"]!!.jsonPrimitive.int
)

fun scriptTreeFromJson(json: JsonElement): ScriptTree = when (json) {
is JsonObject -> scriptLeafFromJson(json)
is JsonObject -> scriptLeafFromJson(json).second
is JsonArray -> {
require(json.size == 2) { "script tree must contain exactly two branches: $json" }
ScriptTree.Branch(scriptTreeFromJson(json[0]), scriptTreeFromJson(json[1]))
Expand Down

0 comments on commit ecda613

Please sign in to comment.