Skip to content

Commit

Permalink
Fixing convert number functions for Longs (#56)
Browse files Browse the repository at this point in the history
# Long Serialization
There was an issue serializing longs in UMA Invoices - 
most often, the expiration field is a long, since it represents a unix timestamp 

This change converts the field to a long and handles serialization/deserialization of said long
  • Loading branch information
matthappens authored Sep 5, 2024
2 parents 9559b6b + bcaa590 commit 826b5a8
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 32 deletions.
12 changes: 4 additions & 8 deletions uma-sdk/src/commonMain/kotlin/me/uma/UmaProtocolHelper.kt
Original file line number Diff line number Diff line change
Expand Up @@ -893,10 +893,7 @@ class UmaProtocolHelper @JvmOverloads constructor(
return identifier.substring(atIndex + 1)
}

fun verifyUmaInvoice(
invoice: Invoice,
pubKeyResponse: PubKeyResponse,
): Boolean {
fun verifyUmaInvoice(invoice: Invoice, pubKeyResponse: PubKeyResponse): Boolean {
return invoice.signature?.let { signature ->
verifySignature(
invoice.toSignablePayload(),
Expand All @@ -909,11 +906,10 @@ class UmaProtocolHelper @JvmOverloads constructor(
fun getInvoice(
receiverUma: String,
invoiceUUID: String,
amount: Int,
amount: Long,
receivingCurrency: InvoiceCurrency,
expiration: Int,
expiration: Long,
isSubjectToTravelRule: Boolean,
umaVersion: String,
commentCharsAllowed: Int? = null,
senderUma: String? = null,
invoiceLimit: Int? = null,
Expand All @@ -929,7 +925,7 @@ class UmaProtocolHelper @JvmOverloads constructor(
receivingCurrency = receivingCurrency,
expiration = expiration,
isSubjectToTravelRule = isSubjectToTravelRule,
umaVersion = umaVersion,
umaVersion = UMA_VERSION_STRING,
commentCharsAllowed = commentCharsAllowed,
senderUma = senderUma,
invoiceLimit = invoiceLimit,
Expand Down
18 changes: 9 additions & 9 deletions uma-sdk/src/commonMain/kotlin/me/uma/protocol/Invoice.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ data class InvoiceCurrency(
0 -> code = bytes.getString(offset.valueOffset(), length)
1 -> name = bytes.getString(offset.valueOffset(), length)
2 -> symbol = bytes.getString(offset.valueOffset(), length)
3 -> decimals = bytes.getNumber(offset.valueOffset(), length)
3 -> decimals = bytes.getInt(offset.valueOffset(), length)
}
offset = offset.valueOffset() + length
}
Expand Down Expand Up @@ -73,11 +73,11 @@ data class Invoice(
/** Invoice UUID Served as both the identifier of the UMA invoice, and the validation of proof of payment.*/
val invoiceUUID: String,
/** The amount of invoice to be paid in the smallest unit of the ReceivingCurrency. */
val amount: Int,
val amount: Long,
/** The currency of the invoice */
val receivingCurrency: InvoiceCurrency,
/** The unix timestamp the UMA invoice expires */
val expiration: Int,
val expiration: Long,
/** Indicates whether the VASP is a financial institution that requires travel rule information. */
val isSubjectToTravelRule: Boolean,
/** RequiredPayerData the data about the payer that the sending VASP must provide in order to send a payment. */
Expand Down Expand Up @@ -141,12 +141,12 @@ data class Invoice(
when (bytes[offset].toInt()) {
0 -> ib.receiverUma = bytes.getString(offset.valueOffset(), length)
1 -> ib.invoiceUUID = bytes.getString(offset.valueOffset(), length)
2 -> ib.amount = bytes.getNumber(offset.valueOffset(), length)
2 -> ib.amount = bytes.getLong(offset.valueOffset(), length)
3 ->
ib.receivingCurrency =
bytes.getTLV(offset.valueOffset(), length, InvoiceCurrency::fromTLV) as InvoiceCurrency

4 -> ib.expiration = bytes.getNumber(offset.valueOffset(), length)
4 -> ib.expiration = bytes.getLong(offset.valueOffset(), length)
5 -> ib.isSubjectToTravelRule = bytes.getBoolean(offset.valueOffset())
6 ->
ib.requiredPayerData =
Expand All @@ -159,9 +159,9 @@ data class Invoice(
).options

7 -> ib.umaVersion = bytes.getString(offset.valueOffset(), length)
8 -> ib.commentCharsAllowed = bytes.getNumber(offset.valueOffset(), length)
8 -> ib.commentCharsAllowed = bytes.getInt(offset.valueOffset(), length)
9 -> ib.senderUma = bytes.getString(offset.valueOffset(), length)
10 -> ib.invoiceLimit = bytes.getNumber(offset.valueOffset(), length)
10 -> ib.invoiceLimit = bytes.getInt(offset.valueOffset(), length)
11 ->
ib.kycStatus = (
bytes.getByteCodeable(
Expand Down Expand Up @@ -193,9 +193,9 @@ data class Invoice(
class InvoiceBuilder {
var receiverUma: String? = null
var invoiceUUID: String? = null
var amount: Int? = null
var amount: Long? = null
var receivingCurrency: InvoiceCurrency? = null
var expiration: Int? = null
var expiration: Long? = null
var isSubjectToTravelRule: Boolean? = null
var requiredPayerData: CounterPartyDataOptions? = null
var umaVersion: String? = null
Expand Down
40 changes: 37 additions & 3 deletions uma-sdk/src/commonMain/kotlin/me/uma/utils/TLVUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,23 @@ fun MutableList<ByteArray>.putNumber(tag: Int, value: Number?): MutableList<Byte
}
add(
when (value) {
is Long -> {
when (value) {
in Byte.MIN_VALUE.toInt()..Byte.MAX_VALUE.toInt() -> {
tlvBuffer(Byte.SIZE_BYTES).put(value.toByte())
}
in Short.MIN_VALUE.toInt()..Short.MAX_VALUE.toInt() -> {
tlvBuffer(Short.SIZE_BYTES).putShort(value.toShort())
}
in Int.MIN_VALUE..Int.MAX_VALUE -> {
tlvBuffer(Int.SIZE_BYTES).putInt(value.toInt())
}
else -> {
tlvBuffer(Long.SIZE_BYTES).putLong(value)
}
}
}

is Int -> {
when (value) {
in Byte.MIN_VALUE.toInt()..Byte.MAX_VALUE.toInt() -> {
Expand All @@ -54,6 +71,7 @@ fun MutableList<ByteArray>.putNumber(tag: Int, value: Number?): MutableList<Byte
}
}
}

is Short -> {
when (value) {
in Byte.MIN_VALUE..Byte.MAX_VALUE -> {
Expand All @@ -62,12 +80,12 @@ fun MutableList<ByteArray>.putNumber(tag: Int, value: Number?): MutableList<Byte
else -> tlvBuffer(Short.SIZE_BYTES).putShort(value.toShort())
}
}

is Byte -> tlvBuffer(Byte.SIZE_BYTES).put(value.toByte())
is Float -> tlvBuffer(Float.SIZE_BYTES).putFloat(value)
is Double -> tlvBuffer(Double.SIZE_BYTES).putDouble(value)
is Long -> tlvBuffer(Long.SIZE_BYTES).putLong(value)
else -> throw IllegalArgumentException("Unsupported type: ${value::class.simpleName}")
}.array()
}.array(),
)
return this
}
Expand Down Expand Up @@ -128,12 +146,28 @@ fun MutableList<ByteArray>.array(): ByteArray {
return buffer.array()
}

fun ByteArray.getNumber(offset: Int, length: Int): Int {
fun ByteArray.getInt(offset: Int, length: Int): Int {
return getNumber(offset, length).toInt()
}

fun ByteArray.getLong(offset: Int, length: Int): Long {
return getNumber(offset, length).toLong()
}

/**
* in Invoice's TLV, the numeric fields are stored in their smallest possible representation (ie, 9L would
* be stored as a single Byte)
* what this means is that when deserializing, we can't simply call buffer.getLong() for Long fields,
* as the encoded field may be as little as 1 byte, triggering a Buffer Underflow Exception.
* Instead, we read the value based on its byte length, and then case it to a Long or Int in a wrapper function
*/
private fun ByteArray.getNumber(offset: Int, length: Int): Number {
val buffer = ByteBuffer.wrap(slice(offset..<offset + length).toByteArray())
return when (length) {
1 -> this[offset].toInt()
2 -> buffer.getShort().toInt()
4 -> buffer.getInt()
8 -> buffer.getLong()
else -> this[offset].toInt()
}
}
Expand Down
31 changes: 19 additions & 12 deletions uma-sdk/src/commonTest/kotlin/me/uma/UmaTests.kt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ class UmaTests {
validateInvoice(invoice, result)
}

@Test
fun `correctly serialized timestamps in invoices`() = runTest {
val timestamp = System.currentTimeMillis()
val invoice = createInvoice(timestamp)
val serializedInvoice = serialFormat.encodeToString(invoice)
val result = serialFormat.decodeFromString<Invoice>(serializedInvoice)
assertEquals(result.expiration, timestamp)
}

@Test
fun `deserializing an Invoice with missing required fields triggers error`() = runTest {
val exception =
Expand Down Expand Up @@ -88,7 +97,7 @@ class UmaTests {
assertEquals("\$[email protected]", decodedInvoice.receiverUma)
assertEquals("c7c07fec-cf00-431c-916f-6c13fc4b69f9", decodedInvoice.invoiceUUID)
assertEquals(1000, decodedInvoice.amount)
assertEquals(1000000, decodedInvoice.expiration)
assertEquals(1000000L, decodedInvoice.expiration)
assertEquals(true, decodedInvoice.isSubjectToTravelRule)
assertEquals("0.3", decodedInvoice.umaVersion)
assertEquals(KycStatus.VERIFIED, decodedInvoice.kycStatus)
Expand All @@ -113,10 +122,9 @@ class UmaTests {
commentCharsAllowed = null,
senderUma = null,
invoiceLimit = null,
umaVersion = "0.3",
kycStatus = KycStatus.VERIFIED,
callback = "https://example.com/callback",
privateSigningKey = keys.privateKey
privateSigningKey = keys.privateKey,
)
assertTrue(UmaProtocolHelper().verifyUmaInvoice(invoice, PubKeyResponse(keys.publicKey, keys.publicKey)))
}
Expand All @@ -136,12 +144,11 @@ class UmaTests {
utxoCallback = "https://example.com/utxo",
travelRuleInfo = "travel rule info",
travelRuleFormat = TravelRuleFormat("someFormat", "1.0"),
requestedPayeeData =
createCounterPartyDataOptions(
"email" to true,
"name" to false,
"compliance" to true,
),
requestedPayeeData = createCounterPartyDataOptions(
"email" to true,
"name" to false,
"compliance" to true,
),
receiverUmaVersion = "1.0",
)
assertTrue(payreq is PayRequestV1)
Expand Down Expand Up @@ -251,7 +258,7 @@ class UmaTests {
)
}

private fun createInvoice(): Invoice {
private fun createInvoice(timestamp: Long? = null): Invoice {
val requiredPayerData =
mapOf(
"name" to CounterPartyDataOption(false),
Expand All @@ -269,9 +276,9 @@ class UmaTests {
return Invoice(
receiverUma = "\$[email protected]",
invoiceUUID = "c7c07fec-cf00-431c-916f-6c13fc4b69f9",
amount = 1000,
amount = 1000L,
receivingCurrency = invoiceCurrency,
expiration = 1000000,
expiration = timestamp ?: 1000000L,
isSubjectToTravelRule = true,
requiredPayerData = requiredPayerData,
commentCharsAllowed = null,
Expand Down

0 comments on commit 826b5a8

Please sign in to comment.