diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index dd6f088..98dbbeb 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -6,12 +6,14 @@ kotlin = "1.9.0" kotlinCoroutines = "1.6.4" kotlinxDateTime = "0.4.0" kotlinSerializationJson = "1.4.1" +kotlinReflect = "2.0.0" ktlint = "11.3.1" ktor = "2.2.3" mavenPublish = "0.25.2" mockitoCore = "5.5.0" taskTree = "2.1.1" junit = "4.13.2" +bitcoinj-core = "0.16.3" [libraries] gradleClasspath-dokka = { module = "org.jetbrains.dokka:dokka-gradle-plugin", version.ref = "dokka" } @@ -19,6 +21,8 @@ gradleClasspath-ktlint = { module = "org.jlleitschuh.gradle:ktlint-gradle", vers gradleClasspath-kotlin = { module = "org.jetbrains.kotlin:kotlin-gradle-plugin", version.ref = "kotlin" } gradleClasspath-mavenPublish = { module = "com.vanniktech:gradle-maven-publish-plugin", version.ref = "mavenPublish" } +bitcoin-core = { module = "org.bitcoinj:bitcoinj-core", version.ref = "bitcoinj-core" } + task-tree = { module = "com.dorongold.plugins:task-tree", version.ref = "taskTree" } kotlin-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serialization-json", version.ref = "kotlinSerializationJson" } diff --git a/uma-sdk/build.gradle.kts b/uma-sdk/build.gradle.kts index 08f372e..efd4677 100644 --- a/uma-sdk/build.gradle.kts +++ b/uma-sdk/build.gradle.kts @@ -27,6 +27,7 @@ kotlin { implementation(libs.kotlinx.coroutines.core) implementation(libs.ktor.client.core) implementation(libs.jna) + implementation(libs.bitcoin.core) } } val commonTest by getting { diff --git a/uma-sdk/src/commonMain/kotlin/me/uma/protocol/CounterPartyData.kt b/uma-sdk/src/commonMain/kotlin/me/uma/protocol/CounterPartyData.kt index f088d65..d9b6c93 100644 --- a/uma-sdk/src/commonMain/kotlin/me/uma/protocol/CounterPartyData.kt +++ b/uma-sdk/src/commonMain/kotlin/me/uma/protocol/CounterPartyData.kt @@ -13,31 +13,6 @@ data class CounterPartyDataOption( typealias CounterPartyDataOptions = Map -data class InvoiceCounterPartyDataOptions( - val options: CounterPartyDataOptions -) : ByteCodeable { - override fun toBytes(): ByteArray { - val optionsString = options.map { (key, option) -> - "${key}:${ if (option.mandatory) 1 else 0}" - }.joinToString(",") - return optionsString.toByteArray(Charsets.UTF_8) - } - - companion object { - fun fromBytes(bytes: ByteArray): InvoiceCounterPartyDataOptions { - val optionsString = String(bytes) - return InvoiceCounterPartyDataOptions( - optionsString.split(",").mapNotNull { - val options = it.split(':') - if (options.size == 2) { - options[0] to CounterPartyDataOption(options[1] == "1") - } else null - }.toMap() - ) - } - } -} - fun createCounterPartyDataOptions(map: Map): CounterPartyDataOptions { return map.mapValues { CounterPartyDataOption(it.value) } } diff --git a/uma-sdk/src/commonMain/kotlin/me/uma/protocol/Invoice.kt b/uma-sdk/src/commonMain/kotlin/me/uma/protocol/Invoice.kt index 3edf182..fa65432 100644 --- a/uma-sdk/src/commonMain/kotlin/me/uma/protocol/Invoice.kt +++ b/uma-sdk/src/commonMain/kotlin/me/uma/protocol/Invoice.kt @@ -1,6 +1,9 @@ package me.uma.protocol +import io.ktor.utils.io.core.toByteArray +import me.uma.utils.ByteCodeable import me.uma.utils.TLVCodeable +import me.uma.utils.array import me.uma.utils.getBoolean import me.uma.utils.getByteCodeable import me.uma.utils.getNumber @@ -14,7 +17,6 @@ import me.uma.utils.putTLVCodeable import me.uma.utils.putNumber import me.uma.utils.putString import me.uma.utils.valueOffset -import java.nio.ByteBuffer import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.KSerializer import kotlinx.serialization.Serializable @@ -22,6 +24,9 @@ import kotlinx.serialization.builtins.ByteArraySerializer import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.encoding.Decoder import kotlinx.serialization.encoding.Encoder +import org.bitcoinj.core.Bech32 + +private const val UMA_BECH32_PREFIX = "uma" @Serializable(with = InvoiceCurrencyTLVSerializer::class) data class InvoiceCurrency( @@ -54,20 +59,12 @@ data class InvoiceCurrency( } } - override fun toTLV(): ByteArray { - val bytes = ByteBuffer.allocate( - 2 + name.length + - 2 + code.length + - 2 + symbol.length + - 3 // for int - ) - .putString(0, code) - .putString(1, name) - .putString(2, symbol) - .putNumber(3, decimals) - .array() - return bytes - } + override fun toTLV() = mutableListOf() + .putString(0, code) + .putString(1, name) + .putString(2, symbol) + .putNumber(3, decimals) + .array() } @OptIn(ExperimentalSerializationApi::class) @@ -87,23 +84,6 @@ class InvoiceCurrencyTLVSerializer: KSerializer { ) } -@OptIn(ExperimentalSerializationApi::class) -class InvoiceTLVSerializer: KSerializer { - private val delegateSerializer = ByteArraySerializer() - override val descriptor = SerialDescriptor("Invoice", delegateSerializer.descriptor) - - override fun serialize(encoder: Encoder, value: Invoice) { - encoder.encodeSerializableValue( - delegateSerializer, - value.toTLV() - ) - } - - override fun deserialize(decoder: Decoder) = Invoice.fromTLV( - decoder.decodeSerializableValue(delegateSerializer) - ) -} - @Serializable(with = InvoiceTLVSerializer::class) class Invoice( val receiverUma: String, @@ -148,10 +128,8 @@ class Invoice( // The signature of the UMA invoice val signature: ByteArray, ) : TLVCodeable { - override fun toTLV(): ByteArray { - val bytes = ByteBuffer.allocate( - 256 - ) + + override fun toTLV() = mutableListOf() .putString(0, receiverUma) .putString(1, invoiceUUID) .putNumber(2, amount) @@ -163,16 +141,12 @@ class Invoice( .putNumber(8, commentCharsAllowed) .putString(9, senderUma) .putNumber(10, invoiceLimit) - .putByteCodeable(11, KycStatusWrapper(kycStatus)) + .putByteCodeable(11, InvoiceKycStatus(kycStatus)) .putString(12, callback) .putByteArray(100, signature) .array() - return bytes - } - fun justForFun() { - - } + fun toBech32() = Bech32.encode(Bech32.Encoding.BECH32, UMA_BECH32_PREFIX, this.toTLV()) companion object { fun fromTLV(bytes: ByteArray): Invoice { @@ -194,9 +168,7 @@ class Invoice( while(offset < bytes.size) { val length = bytes[offset.lengthOffset()].toInt() when(bytes[offset].toInt()) { - 0 -> { - receiverUma = bytes.getString(offset.valueOffset(), length) - } + 0 -> receiverUma = bytes.getString(offset.valueOffset(), length) 1 -> invoiceUUID = bytes.getString(offset.valueOffset(), length) 2 -> amount = bytes.getNumber(offset.valueOffset(), length) 3 -> receivingCurrency = bytes.getTLV(offset.valueOffset(), length, InvoiceCurrency::fromTLV) as InvoiceCurrency @@ -211,7 +183,7 @@ class Invoice( 8 -> commentCharsAllowed = bytes.getNumber(offset.valueOffset(), length) 9 -> senderUma = bytes.getString(offset.valueOffset(), length) 10 -> invoiceLimit = bytes.getNumber(offset.valueOffset(), length) - 11 -> kycStatus = (bytes.getByteCodeable(offset.valueOffset(), length, KycStatusWrapper::fromBytes) as KycStatusWrapper).status + 11 -> kycStatus = (bytes.getByteCodeable(offset.valueOffset(), length, InvoiceKycStatus::fromBytes) as InvoiceKycStatus).status 12 -> callback = bytes.getString(offset.valueOffset(), length) 100 -> signature = bytes.sliceArray(offset.valueOffset()..< offset.valueOffset()+length ) @@ -235,5 +207,66 @@ class Invoice( signature = signature, ) } + + fun fromBech32(bech32String: String): Invoice { + val b32data = Bech32.decode(bech32String) + return fromTLV(b32data.data) + } + } +} + +@OptIn(ExperimentalSerializationApi::class) +class InvoiceTLVSerializer: KSerializer { + private val delegateSerializer = ByteArraySerializer() + override val descriptor = SerialDescriptor("Invoice", delegateSerializer.descriptor) + + override fun serialize(encoder: Encoder, value: Invoice) { + encoder.encodeSerializableValue( + delegateSerializer, + value.toTLV() + ) + } + + override fun deserialize(decoder: Decoder) = Invoice.fromTLV( + decoder.decodeSerializableValue(delegateSerializer) + ) +} + +data class InvoiceCounterPartyDataOptions( + val options: CounterPartyDataOptions +) : ByteCodeable { + override fun toBytes(): ByteArray { + val optionsString = options.map { (key, option) -> + "${key}:${ if (option.mandatory) 1 else 0}" + }.joinToString(",") + return optionsString.toByteArray(Charsets.UTF_8) + } + + companion object { + fun fromBytes(bytes: ByteArray): InvoiceCounterPartyDataOptions { + val optionsString = String(bytes) + return InvoiceCounterPartyDataOptions( + optionsString.split(",").mapNotNull { + val options = it.split(':') + if (options.size == 2) { + options[0] to CounterPartyDataOption(options[1] == "1") + } else null + }.toMap() + ) + } + } +} + +data class InvoiceKycStatus(val status: KycStatus): ByteCodeable { + override fun toBytes(): ByteArray { + return status.rawValue.toByteArray() + } + + companion object { + fun fromBytes(bytes: ByteArray): InvoiceKycStatus { + return InvoiceKycStatus( + KycStatus.fromRawValue(bytes.toString(Charsets.UTF_8)) + ) + } } } diff --git a/uma-sdk/src/commonMain/kotlin/me/uma/protocol/KycStatus.kt b/uma-sdk/src/commonMain/kotlin/me/uma/protocol/KycStatus.kt index 32f9728..48c5b4c 100644 --- a/uma-sdk/src/commonMain/kotlin/me/uma/protocol/KycStatus.kt +++ b/uma-sdk/src/commonMain/kotlin/me/uma/protocol/KycStatus.kt @@ -5,7 +5,6 @@ import kotlinx.serialization.Serializable import kotlinx.serialization.encodeToString import me.uma.utils.EnumSerializer import me.uma.utils.serialFormat -import okio.Utf8 @Serializable(with = KycStatusSerializer::class) enum class KycStatus(val rawValue: String) { @@ -31,20 +30,6 @@ enum class KycStatus(val rawValue: String) { } } -data class KycStatusWrapper(val status: KycStatus): ByteCodeable { - override fun toBytes(): ByteArray { - return status.rawValue.toByteArray() - } - - companion object { - fun fromBytes(bytes: ByteArray): KycStatusWrapper { - return KycStatusWrapper( - KycStatus.fromRawValue(bytes.toString(Charsets.UTF_8)) - ) - } - } -} - object KycStatusSerializer : EnumSerializer( KycStatus::class, diff --git a/uma-sdk/src/commonMain/kotlin/me/uma/utils/TLVUtils.kt b/uma-sdk/src/commonMain/kotlin/me/uma/utils/TLVUtils.kt index 2943515..944b0b6 100644 --- a/uma-sdk/src/commonMain/kotlin/me/uma/utils/TLVUtils.kt +++ b/uma-sdk/src/commonMain/kotlin/me/uma/utils/TLVUtils.kt @@ -17,67 +17,109 @@ fun Int.lengthOffset() = this + 1 fun Int.valueOffset() = this + 2 -fun ByteBuffer.putString(tag: Int, value: String): ByteBuffer { - return put(tag.toByte()) - .put(value.length.toByte()) - .put(value.toByteArray()) +fun MutableList.putString(tag: Int, value: String): MutableList { + val byteStr = value.toByteArray(Charsets.UTF_8) + add( + ByteBuffer.allocate(2 + byteStr.size) + .put(tag.toByte()) + .put(byteStr.size.toByte()) + .put(byteStr) + .array() + ) + return this } -fun ByteBuffer.putNumber(tag: Int, value: Number): ByteBuffer { - put(tag.toByte()) // insert tag - return when (value) { - is Int -> { - when (value) { - in Byte.MIN_VALUE.toInt()..Byte.MAX_VALUE.toInt() -> { - put(Byte.SIZE_BYTES.toByte()).put(value.toByte()) - } - in Short.MIN_VALUE.toInt()..Short.MAX_VALUE.toInt() -> { - put(Short.SIZE_BYTES.toByte()).putShort(value.toShort()) +fun MutableList.putNumber(tag: Int, value: Number): MutableList { + val tlvBuffer = { numberSize: Int -> + ByteBuffer + .allocate(2 + numberSize) + .put(tag.toByte()) + .put(numberSize.toByte()) + } + add( + when (value) { + is Int -> { + when (value) { + in Byte.MIN_VALUE.toInt()..Byte.MAX_VALUE.toInt() -> { + tlvBuffer(Byte.SIZE_BYTES).put(value.toByte()) + } + in Short.MIN_VALUE.toInt()..Short.MAX_VALUE.toInt() -> { + tlvBuffer(Short.SIZE_BYTES).putShort(value.toShort()) + } + else -> { + tlvBuffer(Int.SIZE_BYTES).putInt(value) + } } - else -> put(Int.SIZE_BYTES.toByte()).putInt(value) } - } - is Short -> { - when (value) { - in Byte.MIN_VALUE..Byte.MAX_VALUE -> { - put(Byte.SIZE_BYTES.toByte()).put(value.toByte()) + is Short -> { + when (value) { + in Byte.MIN_VALUE..Byte.MAX_VALUE -> { + tlvBuffer(Byte.SIZE_BYTES).put(value.toByte()) + } + else -> tlvBuffer(Short.SIZE_BYTES).putShort(value.toShort()) } - else -> put(Short.SIZE_BYTES.toByte()).putShort(value) } - } - is Byte -> put(Byte.SIZE_BYTES.toByte()).put(value.toByte()) - is Float -> put(Float.SIZE_BYTES.toByte()).putFloat(value) - is Double -> put(Double.SIZE_BYTES.toByte()).putDouble(value) - is Long -> put(Long.SIZE_BYTES.toByte()).putLong(value) - else -> throw IllegalArgumentException("Unsupported type: ${value::class.simpleName}") - } + is Byte -> tlvBuffer(Byte.SIZE_BYTES).put(value.toByte()) + is Float -> tlvBuffer(Float.SIZE_BYTES).putFloat(value) + is Double -> tlvBuffer(Double.SIZE_BYTES).putDouble(value) + is Long -> tlvBuffer(Long.SIZE_BYTES).putLong(value) + else -> throw IllegalArgumentException("Unsupported type: ${value::class.simpleName}") + }.array() + ) + return this } -fun ByteBuffer.putBoolean(tag: Int, value: Boolean): ByteBuffer { - return put(tag.toByte()) - .put(1) - .put(if (value) 1 else 0) +fun MutableList.putBoolean(tag: Int, value: Boolean): MutableList { + add( + ByteBuffer.allocate(2 + 1) + .put(tag.toByte()) + .put(1) + .put(if(value) 1 else 0) + .array() + ) + return this } -fun ByteBuffer.putByteArray(tag: Int, value: ByteArray): ByteBuffer = - put(tag.toByte()) - .put(value.size.toByte()) - .put(value) +fun MutableList.putByteArray(tag: Int, value: ByteArray): MutableList { + add( + ByteBuffer.allocate(2 + value.size) + .put(tag.toByte()) + .put(value.size.toByte()) + .put(value) + .array() + ) + return this +} -fun ByteBuffer.putByteCodeable(tag: Int, value: ByteCodeable): ByteBuffer { +fun MutableList.putByteCodeable(tag: Int, value: ByteCodeable): MutableList { val encodedBytes = value.toBytes() - return put(tag.toByte()) - .put(encodedBytes.size.toByte()) - .put(encodedBytes) + add( + ByteBuffer.allocate(2 + encodedBytes.size) + .put(tag.toByte()) + .put(encodedBytes.size.toByte()) + .put(encodedBytes) + .array() + ) + return this } -fun ByteBuffer.putTLVCodeable(tag: Int, value: TLVCodeable): ByteBuffer { +fun MutableList.putTLVCodeable(tag: Int, value: TLVCodeable): MutableList { val encodedBytes = value.toTLV() - return put(tag.toByte()) - .put(encodedBytes.size.toByte()) - .put(encodedBytes) + add( + ByteBuffer.allocate(2 + encodedBytes.size) + .put(tag.toByte()) + .put(encodedBytes.size.toByte()) + .put(encodedBytes) + .array() + ) + return this } +fun MutableList.array(): ByteArray { + val buffer = ByteBuffer.allocate(sumOf { it.size }) + forEach(buffer::put) + return buffer.array() +} fun ByteArray.getNumber(offset: Int, length: Int): Int { val buffer = ByteBuffer.wrap(slice(offset..(encoded) assertEquals("usd", result.code) @@ -38,48 +37,25 @@ class UmaTests { @Test fun `test create invoice`() = runTest { - val requiredPayerData = mapOf( - "name" to CounterPartyDataOption(false), - "email" to CounterPartyDataOption(false), - "compliance" to CounterPartyDataOption(true), - ) - val invoiceCurrency = InvoiceCurrency( - code = "USD", - name = "US Dollar", - symbol = "$", - decimals = 2, - ) - val invoice = Invoice( - receiverUma = "\$foo@bar.com", - invoiceUUID = "c7c07fec-cf00-431c-916f-6c13fc4b69f9", - amount = 1000, - receivingCurrency = invoiceCurrency, - expiration = 1000000, - isSubjectToTravelRule = true, - requiredPayerData = requiredPayerData, - commentCharsAllowed = 30, - senderUma = "\$other@uma.com", - invoiceLimit = 100, - umaVersion = "0.3", - kycStatus = KycStatus.VERIFIED, - callback = "https://example.com/callback", - signature = "signature".toByteArray(), - ) + val invoice = createInvoice() val serializedInvoice = serialFormat.encodeToString(invoice) val result = serialFormat.decodeFromString(serializedInvoice) - assertEquals("\$foo@bar.com", result.receiverUma) - assertEquals("c7c07fec-cf00-431c-916f-6c13fc4b69f9", result.invoiceUUID) - assertEquals(1000, result.amount) - assertEquals(1000000, result.expiration) - assertEquals(true, result.isSubjectToTravelRule) - assertEquals(30, result.commentCharsAllowed) - assertEquals("\$other@uma.com", result.senderUma) - assertEquals(100, result.invoiceLimit) - assertEquals("0.3", result.umaVersion) - assertEquals(KycStatus.VERIFIED, result.kycStatus) - assertEquals("https://example.com/callback", result.callback) - assertEquals(requiredPayerData, result.requiredPayerData) - assertEquals(invoiceCurrency, result.receivingCurrency) + validateInvoice(invoice, result) + } + + @Test + fun `test encode invoice as bech32`() = runTest { + val invoice = createInvoice() + val bech32str = try { + invoice.toBech32() + } catch (e: IndexOutOfBoundsException) { + "" + } + assertEquals("uma", bech32str.slice(0..2)) + + val decodedInvoice = Invoice.fromBech32(bech32str) + validateInvoice(invoice, decodedInvoice) + } @Test @@ -206,4 +182,51 @@ class UmaTests { compliancePayerData, ) } + + private fun createInvoice( + ): Invoice { + val requiredPayerData = mapOf( + "name" to CounterPartyDataOption(false), + "email" to CounterPartyDataOption(false), + "compliance" to CounterPartyDataOption(true), + ) + val invoiceCurrency = InvoiceCurrency( + code = "USD", + name = "US Dollar", + symbol = "$", + decimals = 2, + ) + return Invoice( + receiverUma = "\$foo@bar.com", + invoiceUUID = "c7c07fec-cf00-431c-916f-6c13fc4b69f9", + amount = 1000, + receivingCurrency = invoiceCurrency, + expiration = 1000000, + isSubjectToTravelRule = true, + requiredPayerData = requiredPayerData, + commentCharsAllowed = 30, + senderUma = "\$other@uma.com", + invoiceLimit = 100, + umaVersion = "0.3", + kycStatus = KycStatus.VERIFIED, + callback = "https://example.com/callback", + signature = "signature".toByteArray(), + ) + } + + private fun validateInvoice(preEncodedInvoice: Invoice, decodedInvoice: Invoice) { + assertEquals(preEncodedInvoice.receiverUma, decodedInvoice.receiverUma) + assertEquals(preEncodedInvoice.invoiceUUID, decodedInvoice.invoiceUUID) + assertEquals(preEncodedInvoice.amount, decodedInvoice.amount) + assertEquals(preEncodedInvoice.expiration, decodedInvoice.expiration) + assertEquals(preEncodedInvoice.isSubjectToTravelRule, decodedInvoice.isSubjectToTravelRule) + assertEquals(preEncodedInvoice.commentCharsAllowed, decodedInvoice.commentCharsAllowed) + assertEquals(preEncodedInvoice.senderUma, decodedInvoice.senderUma) + assertEquals(preEncodedInvoice.invoiceLimit, decodedInvoice.invoiceLimit) + assertEquals(preEncodedInvoice.umaVersion, decodedInvoice.umaVersion) + assertEquals(preEncodedInvoice.kycStatus, decodedInvoice.kycStatus) + assertEquals(preEncodedInvoice.callback, decodedInvoice.callback) + assertEquals(preEncodedInvoice.requiredPayerData, decodedInvoice.requiredPayerData) + assertEquals(preEncodedInvoice.receivingCurrency, decodedInvoice.receivingCurrency) + } }