Skip to content

Commit

Permalink
Don't return Either when not necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
sstone committed Jan 31, 2024
1 parent eb10a83 commit bc6e398
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 45 deletions.
26 changes: 10 additions & 16 deletions src/commonMain/kotlin/fr/acinq/bitcoin/crypto/musig2/Musig2.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import fr.acinq.bitcoin.*
import fr.acinq.bitcoin.utils.Either
import fr.acinq.secp256k1.Hex
import fr.acinq.secp256k1.Secp256k1
import kotlin.jvm.JvmOverloads
import kotlin.jvm.JvmStatic

/**
Expand Down Expand Up @@ -43,12 +44,11 @@ public data class KeyAggCache(val data: ByteVector) {
* @return a new (if cache was null) or updated cache, and the aggregated public key
*/
@JvmStatic
public fun add(pubkeys: List<PublicKey>, cache: KeyAggCache? = null): Either<Throwable, Pair<XonlyPublicKey, KeyAggCache>> = try {
@JvmOverloads
public fun add(pubkeys: List<PublicKey>, cache: KeyAggCache? = null): Pair<XonlyPublicKey, KeyAggCache> {
val localCache = cache?.data?.toByteArray() ?: ByteArray(Secp256k1.MUSIG2_PUBLIC_KEYAGG_CACHE_SIZE)
val aggkey = Secp256k1.musigPubkeyAgg(pubkeys.map { it.value.toByteArray() }.toTypedArray(), localCache)
Either.Right(Pair(XonlyPublicKey(aggkey.byteVector32()), KeyAggCache(localCache.byteVector())))
} catch (t: Throwable) {
Either.Left(t)
return Pair(XonlyPublicKey(aggkey.byteVector32()), KeyAggCache(localCache.byteVector()))
}
}
}
Expand All @@ -69,10 +69,8 @@ public data class Session(val data: ByteVector) {
* @param aggCache key aggregation cache
* @return a Musig2 partial signature
*/
public fun sign(secretNonce: SecretNonce, pk: PrivateKey, aggCache: KeyAggCache): Either<Throwable, ByteVector32> = try {
Either.Right(Secp256k1.musigPartialSign(secretNonce.data.toByteArray(), pk.value.toByteArray(), aggCache.data.toByteArray(), toByteArray()).byteVector32())
} catch (t: Throwable) {
Either.Left(t)
public fun sign(secretNonce: SecretNonce, pk: PrivateKey, aggCache: KeyAggCache): ByteVector32 {
return Secp256k1.musigPartialSign(secretNonce.data.toByteArray(), pk.value.toByteArray(), aggCache.data.toByteArray(), toByteArray()).byteVector32()
}

/**
Expand Down Expand Up @@ -107,11 +105,9 @@ public data class Session(val data: ByteVector) {
* @return a Musig signing session
*/
@JvmStatic
public fun build(aggregatedNonce: AggregatedNonce, msg: ByteVector32, cache: KeyAggCache): Either<Throwable, Session> = try {
public fun build(aggregatedNonce: AggregatedNonce, msg: ByteVector32, cache: KeyAggCache): Session {
val session = Secp256k1.musigNonceProcess(aggregatedNonce.toByteArray(), msg.toByteArray(), cache.data.toByteArray())
Either.Right(Session(session.byteVector()))
} catch (t: Throwable) {
Either.Left(t)
return Session(session.byteVector())
}
}
}
Expand Down Expand Up @@ -139,13 +135,11 @@ public data class SecretNonce(val data: ByteVector) {
* @return a (secret nonce, public nonce) tuple
*/
@JvmStatic
public fun generate(sessionId: ByteVector32, seckey: PrivateKey?, pubkey: PublicKey, msg: ByteVector32?, cache: KeyAggCache?, extraInput: ByteVector32?): Either<Throwable, Pair<SecretNonce, IndividualNonce>> = try {
public fun generate(sessionId: ByteVector32, seckey: PrivateKey?, pubkey: PublicKey, msg: ByteVector32?, cache: KeyAggCache?, extraInput: ByteVector32?): Pair<SecretNonce, IndividualNonce> {
val nonce = Secp256k1.musigNonceGen(sessionId.toByteArray(), seckey?.value?.toByteArray(), pubkey.value.toByteArray(), msg?.toByteArray(), cache?.data?.toByteArray(), extraInput?.toByteArray())
val secretNonce = SecretNonce(nonce.copyOfRange(0, Secp256k1.MUSIG2_SECRET_NONCE_SIZE))
val publicNonce = IndividualNonce(nonce.copyOfRange(Secp256k1.MUSIG2_SECRET_NONCE_SIZE, Secp256k1.MUSIG2_SECRET_NONCE_SIZE + Secp256k1.MUSIG2_PUBLIC_NONCE_SIZE))
Either.Right(Pair(secretNonce, publicNonce))
} catch (t: Throwable) {
Either.Left(t)
return Pair(secretNonce, publicNonce)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ class Musig2TestsCommon {
tests.jsonObject["valid_test_cases"]!!.jsonArray.forEach {
val keyIndices = it.jsonObject["key_indices"]!!.jsonArray.map { it.jsonPrimitive.int }
val expected = XonlyPublicKey(ByteVector32.fromValidHex(it.jsonObject["expected"]!!.jsonPrimitive.content))
val (aggkey, _) = KeyAggCache.add(keyIndices.map { pubkeys[it] }).right!!
val (aggkey, _) = KeyAggCache.add(keyIndices.map { pubkeys[it] })
assertEquals(expected, aggkey)
}
tests.jsonObject["error_test_cases"]!!.jsonArray.forEach {
val keyIndices = it.jsonObject["key_indices"]!!.jsonArray.map { it.jsonPrimitive.int }
val tweakIndices = it.jsonObject["tweak_indices"]!!.jsonArray.map { it.jsonPrimitive.int }
val isXonly = it.jsonObject["is_xonly"]!!.jsonArray.map { it.jsonPrimitive.boolean }
assertFails {
var (_, cache) = KeyAggCache.add(keyIndices.map { pubkeys[it] }).right!!
var (_, cache) = KeyAggCache.add(keyIndices.map { pubkeys[it] })
tweakIndices.zip(isXonly).forEach { cache = cache.tweak(tweaks[it.first], it.second).right!!.first }
}
}
Expand All @@ -48,7 +48,7 @@ class Musig2TestsCommon {
//val expectedSecnonce = SecretNonce(it.jsonObject["expected_secnonce"]!!.jsonPrimitive.content)
val expectedPubnonce = IndividualNonce(it.jsonObject["expected_pubnonce"]!!.jsonPrimitive.content)
if (aggpk == null) {
val (_, pubnonce) = SecretNonce.generate(randprime, sk, pk, msg?.byteVector32(), null, extraInput?.byteVector32()).right!!
val (_, pubnonce) = SecretNonce.generate(randprime, sk, pk, msg?.byteVector32(), null, extraInput?.byteVector32())
// assertEquals(expectedSecnonce, secnonce)
assertEquals(expectedPubnonce, pubnonce)
}
Expand Down Expand Up @@ -92,13 +92,13 @@ class Musig2TestsCommon {
val isXonly = it.jsonObject["is_xonly"]!!.jsonArray.map { it.jsonPrimitive.boolean }
assertEquals(AggregatedNonce(it.jsonObject["aggnonce"]!!.jsonPrimitive.content), aggnonce)
val cache = run {
var (_, c) = KeyAggCache.add(keyIndices.map { pubkeys[it] }).right!!
var (_, c) = KeyAggCache.add(keyIndices.map { pubkeys[it] })
tweakIndices.zip(isXonly).map { tweaks[it.first] to it.second }.forEach { (tweak, isXonly) ->
c = c.tweak(tweak, isXonly).right!!.first
}
c
}
val session = Session.build(aggnonce, msg, cache).right!!
val session = Session.build(aggnonce, msg, cache)
val aggsig = session.add(psigIndices.map { psigs[it] }).right!!
assertEquals(expected, aggsig)
}
Expand All @@ -111,13 +111,13 @@ class Musig2TestsCommon {
val isXonly = it.jsonObject["is_xonly"]!!.jsonArray.map { it.jsonPrimitive.boolean }
assertEquals(AggregatedNonce(it.jsonObject["aggnonce"]!!.jsonPrimitive.content), aggnonce)
val cache = run {
var (_, c) = KeyAggCache.add(keyIndices.map { pubkeys[it] }).right!!
var (_, c) = KeyAggCache.add(keyIndices.map { pubkeys[it] })
tweakIndices.zip(isXonly).map { tweaks[it.first] to it.second }.forEach { (tweak, isXonly) ->
c = c.tweak(tweak, isXonly).right!!.first
}
c
}
val session = Session.build(aggnonce, msg, cache).right!!
val session = Session.build(aggnonce, msg, cache)
assertTrue {
session.add(psigIndices.map { psigs[it] }).isLeft
}
Expand All @@ -141,23 +141,23 @@ class Musig2TestsCommon {

val aggsig = run {
val nonces = privkeys.map {
SecretNonce.generate(random.nextBytes(32).byteVector32(), it, it.publicKey(), null, null, null).right!!
SecretNonce.generate(random.nextBytes(32).byteVector32(), it, it.publicKey(), null, null, null)
}
val secnonces = nonces.map { it.first }
val pubnonces = nonces.map { it.second }

// aggregate public nonces
val aggnonce = IndividualNonce.aggregate(pubnonces).right!!
val cache = run {
val (_, c) = KeyAggCache.add(pubkeys).right!!
val (_, c) = KeyAggCache.add(pubkeys)
val (c1, _) = c.tweak(plainTweak, false).right!!
val (c2, _) = c1.tweak(xonlyTweak, true).right!!
c2
}
val session = Session.build(aggnonce, msg, cache).right!!
val session = Session.build(aggnonce, msg, cache)
// create partial signatures
val psigs = privkeys.indices.map {
session.sign(secnonces[it], privkeys[it], cache).right!!
session.sign(secnonces[it], privkeys[it], cache)
}

// verify partial signatures
Expand All @@ -171,7 +171,7 @@ class Musig2TestsCommon {

// aggregate public keys
val aggpub = run {
val (_, c) = KeyAggCache.add(pubkeys).right!!
val (_, c) = KeyAggCache.add(pubkeys)
val (c1, _) = c.tweak(plainTweak, false).right!!
val (_, p) = c1.tweak(xonlyTweak, true).right!!
p
Expand All @@ -189,7 +189,7 @@ class Musig2TestsCommon {
val bobPubKey = bobPrivKey.publicKey()

// Alice and Bob exchange public keys and agree on a common aggregated key
val (internalPubKey, cache) = KeyAggCache.add(listOf(alicePubKey, bobPubKey)).right!!
val (internalPubKey, cache) = KeyAggCache.add(listOf(alicePubKey, bobPubKey))
// we use the standard BIP86 tweak
val commonPubKey = internalPubKey.outputKey(Crypto.TaprootTweak.NoScriptTweak).first

Expand All @@ -201,17 +201,17 @@ class Musig2TestsCommon {

val commonSig = run {
val random = Random.Default
val aliceNonce = SecretNonce.generate(random.nextBytes(32).byteVector32(), alicePrivKey, alicePubKey, null, cache, null).right!!
val bobNonce = SecretNonce.generate(random.nextBytes(32).byteVector32(), bobPrivKey, bobPubKey, null, null, null).right!!
val aliceNonce = SecretNonce.generate(random.nextBytes(32).byteVector32(), alicePrivKey, alicePubKey, null, cache, null)
val bobNonce = SecretNonce.generate(random.nextBytes(32).byteVector32(), bobPrivKey, bobPubKey, null, null, null)

val aggnonce = IndividualNonce.aggregate(listOf(aliceNonce.second, bobNonce.second)).right!!
val msg = Transaction.hashForSigningSchnorr(spendingTx, 0, listOf(tx.txOut[0]), SigHash.SIGHASH_DEFAULT, SigVersion.SIGVERSION_TAPROOT)

// we use the same ctx for Alice and Bob, they both know all the public keys that are used here
val (cache1, _) = cache.tweak(internalPubKey.tweak(Crypto.TaprootTweak.NoScriptTweak), true).right!!
val session = Session.build(aggnonce, msg, cache1).right!!
val aliceSig = session.sign(aliceNonce.first, alicePrivKey, cache1).right!!
val bobSig = session.sign(bobNonce.first, bobPrivKey, cache1).right!!
val session = Session.build(aggnonce, msg, cache1)
val aliceSig = session.sign(aliceNonce.first, alicePrivKey, cache1)
val bobSig = session.sign(bobNonce.first, bobPrivKey, cache1)
session.add(listOf(aliceSig, bobSig)).right!!
}

Expand All @@ -236,7 +236,7 @@ class Musig2TestsCommon {
val merkleRoot = scriptTree.hash()

// the internal pubkey is the musig2 aggregation of the user's and server's public keys: it does not depend upon the user's refund's key
val (internalPubKey, cache) = KeyAggCache.add(listOf(userPrivateKey.publicKey(), serverPrivateKey.publicKey())).right!!
val (internalPubKey, cache) = KeyAggCache.add(listOf(userPrivateKey.publicKey(), serverPrivateKey.publicKey()))

// it is tweaked with the script's merkle root to get the pubkey that will be exposed
val pubkeyScript: List<ScriptElt> = Script.pay2tr(internalPubKey, merkleRoot)
Expand All @@ -258,25 +258,22 @@ class Musig2TestsCommon {
)
// this is the beginning of an interactive musig2 signing session. if user and server are disconnected before they have exchanged partial
// signatures they will have to start again with fresh nonces
val userNonce = SecretNonce.generate(random.nextBytes(32).byteVector32(), userPrivateKey, userPrivateKey.publicKey(), null, cache, null).right!!
val serverNonce = SecretNonce.generate(random.nextBytes(32).byteVector32(), serverPrivateKey, serverPrivateKey.publicKey(), null, cache, null).right!!
val userNonce = SecretNonce.generate(random.nextBytes(32).byteVector32(), userPrivateKey, userPrivateKey.publicKey(), null, cache, null)
val serverNonce = SecretNonce.generate(random.nextBytes(32).byteVector32(), serverPrivateKey, serverPrivateKey.publicKey(), null, cache, null)

val txHash = Transaction.hashForSigningSchnorr(tx, 0, swapInTx.txOut, SigHash.SIGHASH_DEFAULT, SigVersion.SIGVERSION_TAPROOT)

val commonSig = IndividualNonce.aggregate(listOf(userNonce.second, serverNonce.second))
.flatMap { commonNonce ->
cache.tweak(internalPubKey.tweak(Crypto.TaprootTweak.ScriptTweak(merkleRoot)), true)
.flatMap { (cache1, _) ->
Session.build(commonNonce, txHash, cache1)
.flatMap { session ->
session.sign(userNonce.first, userPrivateKey, cache1)
.flatMap { userSig ->
session.sign(serverNonce.first, serverPrivateKey, cache1)
.flatMap { serverSig -> session.add(listOf(userSig, serverSig)) }
}
}
val session = Session.build(commonNonce, txHash, cache1)
val userSig = session.sign(userNonce.first, userPrivateKey, cache1)
val serverSig = session.sign(serverNonce.first, serverPrivateKey, cache1)
session.add(listOf(userSig, serverSig))
}
}

val signedTx = tx.updateWitness(0, ScriptWitness(listOf(commonSig.right!!)))
Transaction.correctlySpends(signedTx, swapInTx, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
}
Expand Down

0 comments on commit bc6e398

Please sign in to comment.