From 63302cc37b5a91b6e14a9e62037fbc97ec064529 Mon Sep 17 00:00:00 2001 From: Zhen Lu Date: Thu, 11 Jul 2024 18:16:54 -0700 Subject: [PATCH] Initial Invoice implementation (#36) * Invoice implementation * Sort the strings before encoding so it's deterministic * Fix the other test --- go.mod | 1 + go.sum | 19 ++++ uma/protocol/counter_party_data.go | 39 ++++++++ uma/protocol/invoice.go | 94 +++++++++++++++++++ uma/protocol/kyc_status.go | 8 ++ uma/test/protocol_test.go | 58 ++++++++++++ uma/test/tlv_utils_test.go | 29 ++++-- uma/utils/tlv_utils.go | 146 +++++++++++++++++------------ 8 files changed, 324 insertions(+), 70 deletions(-) create mode 100644 uma/protocol/invoice.go diff --git a/go.mod b/go.mod index 07d1f84..f9ef677 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( ) require ( + github.com/btcsuite/btcutil v1.0.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/ethereum/go-ethereum v1.13.8 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index c2ea465..60ce072 100644 --- a/go.sum +++ b/go.sum @@ -89,6 +89,7 @@ github.com/VictoriaMetrics/fastcache v1.6.0/go.mod h1:0qHz5QP0GMX4pfmMA/zt5RgfNu github.com/VictoriaMetrics/fastcache v1.12.1/go.mod h1:tX04vaqcNoQeGLD+ra5pU5sWkuxnzWhEzLwhP9w653o= github.com/aclements/go-gg v0.0.0-20170118225347-6dbb4e4fefb0/go.mod h1:55qNq4vcpkIuHowELi5C8e+1yUHtoLoOUR9QU5j7Tes= github.com/aclements/go-moremath v0.0.0-20210112150236-f10218a38794/go.mod h1:7e+I0LQFUI9AXWxOfsQROs9xPhoJtbsyWcjJqDd4KPY= +github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= github.com/ajstarks/svgo v0.0.0-20210923152817-c3b6e2f0c527/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= @@ -142,8 +143,18 @@ github.com/bmizerany/pat v0.0.0-20170815010413-6226ea591a40/go.mod h1:8rLXio+Wji github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps= github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= +github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ= github.com/btcsuite/btcd/btcec/v2 v2.2.0/go.mod h1:U7MHm051Al6XmscBQ0BoNydpOTsFAn707034b5nY8zU= github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc= +github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA= +github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d/go.mod h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg= +github.com/btcsuite/btcutil v1.0.2 h1:9iZ1Terx9fMIOtq1VrwdqfsATL9MC2l8ZrUY6YZ2uts= +github.com/btcsuite/btcutil v1.0.2/go.mod h1:j9HUFwoQRsZL3V4n+qG+CUnEGHOarIxfC3Le2Yhbcts= +github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd/go.mod h1:HHNXQzUsZCxOoE+CPiyCTO6x34Zs86zZUiwtpXoGdtg= +github.com/btcsuite/goleveldb v0.0.0-20160330041536-7834afc9e8cd/go.mod h1:F+uVaaLLH7j4eDXPRvw78tMflu7Ie2bzYOH4Y8rRKBY= +github.com/btcsuite/snappy-go v0.0.0-20151229074030-0bdef8d06723/go.mod h1:8woku9dyThutzjeg+3xrA5iCpBRH8XEEg3lh6TiUghc= +github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792/go.mod h1:ghJtEyQwv5/p4Mg4C0fgbePVuGr935/5ddU9Z3TmDRY= +github.com/btcsuite/winsvc v1.0.0/go.mod h1:jsenWakMcC0zFBFurPLEAyrnc/teJEM1O46fmI40EZs= github.com/c-bata/go-prompt v0.2.2/go.mod h1:VzqtzE2ksDBcdln8G7mk2RX9QyGjH+OVqOCSiVIqS34= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/cp v0.1.0/go.mod h1:SOGHArjBr4JWaSDEVpWpo/hNg6RoKrls6Oh40hiwW+s= @@ -204,6 +215,7 @@ github.com/crate-crypto/go-kzg-4844 v0.7.0/go.mod h1:1kMhvPgI0Ky3yIa+9lFySEBUBXk github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/cyberdelia/templates v0.0.0-20141128023046-ca7fffd4298c/go.mod h1:GyV+0YP4qX0UQ7r2MoYZ+AvYDp12OF5yg4q8rGnyNh4= github.com/dave/jennifer v1.2.0/go.mod h1:fIb+770HOpJ2fmN9EPPKOqm1vMGhB+TwXKMZhrIygKg= +github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -533,11 +545,13 @@ github.com/iris-contrib/schema v0.0.6/go.mod h1:iYszG0IOsuIsfzjymw1kMzTL8YQcCWlm github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= github.com/jedisct1/go-minisign v0.0.0-20190909160543-45766022959e/go.mod h1:G1CVv03EnqU1wYL2dFwXxW2An0az9JTl/ZsqXQeBlkU= github.com/jedisct1/go-minisign v0.0.0-20230811132847-661be99b8267/go.mod h1:h1nSAbGFqGVzn6Jyl1R/iCcBUHN4g+gW1u9CoBTrb9E= +github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= +github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -581,6 +595,7 @@ github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvW github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23/go.mod h1:J+Gs4SYgM6CZQHDETBtE9HaSEkGmuNXF86RwHhHUvq4= github.com/klauspost/compress v1.4.0/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/compress v1.8.2/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/compress v1.9.0/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= @@ -712,6 +727,7 @@ github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+ github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.10.3/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0= @@ -719,6 +735,7 @@ github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9k github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/ginkgo/v2 v2.0.0/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= +github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/onsi/gomega v1.17.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= @@ -952,6 +969,7 @@ go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9i go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= +golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -963,6 +981,7 @@ golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190909091759-094676da4a83/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200115085410-6d4e4cb37c7d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= diff --git a/uma/protocol/counter_party_data.go b/uma/protocol/counter_party_data.go index 6ec0cfb..7b7fc81 100644 --- a/uma/protocol/counter_party_data.go +++ b/uma/protocol/counter_party_data.go @@ -1,5 +1,11 @@ package protocol +import ( + "fmt" + "sort" + "strings" +) + type CounterPartyDataOption struct { Mandatory bool `json:"mandatory"` } @@ -22,3 +28,36 @@ const ( func (c CounterPartyDataField) String() string { return string(c) } + +func (c *CounterPartyDataOptions) MarshalBytes() ([]byte, error) { + pairs := make([]string, 0, len(*c)) + for k, v := range *c { + str := k + ":" + if v.Mandatory { + str += "1" + } else { + str += "0" + } + pairs = append(pairs, str) + } + + sort.Strings(pairs) + + result := []byte(strings.Join(pairs, ",")) + return result, nil +} + +func (c *CounterPartyDataOptions) UnmarshalBytes(data []byte) error { + *c = make(CounterPartyDataOptions) + pairs := strings.Split(string(data), ",") + for _, pair := range pairs { + parts := strings.Split(pair, ":") + if len(parts) != 2 { + return fmt.Errorf("invalid pair: %s", pair) + } + (*c)[parts[0]] = CounterPartyDataOption{ + Mandatory: parts[1] == "1", + } + } + return nil +} diff --git a/uma/protocol/invoice.go b/uma/protocol/invoice.go new file mode 100644 index 0000000..0c7b560 --- /dev/null +++ b/uma/protocol/invoice.go @@ -0,0 +1,94 @@ +package protocol + +import ( + "fmt" + + "github.com/btcsuite/btcutil/bech32" + "github.com/uma-universal-money-address/uma-go-sdk/uma/utils" +) + +type InvoiceCurrency struct { + // Code is the ISO 4217 (if applicable) currency code (eg. "USD"). For cryptocurrencies, this will be a ticker + // symbol, such as BTC for Bitcoin. + Code string `tlv:"0"` + + // Name is the full display name of the currency (eg. US Dollars). + Name string `tlv:"1"` + + // Symbol is the symbol of the currency (eg. $ for USD). + Symbol string `tlv:"2"` +} + +func (c *InvoiceCurrency) MarshalTLV() ([]byte, error) { + return utils.MarshalTLV(c) +} + +func (c *InvoiceCurrency) UnmarshalTLV(data []byte) error { + return utils.UnmarshalTLV(c, data) +} + +type UmaInvoice struct { + // Receiving UMA address + ReceiverUma string `tlv:"0"` + + // Invoice UUID Served as both the identifier of the UMA invoice, and the validation of proof of payment. + InvoiceUUID string `tlv:"1"` + + // The amount of invoice to be paid in the smalest unit of the ReceivingCurrency. + Amount uint64 `tlv:"2"` + + // The currency of the invoice + ReceivingCurrency InvoiceCurrency `tlv:"3"` + + // The unix timestamp the UMA invoice expires + Expiration uint64 `tlv:"4"` + + // Indicates whether the VASP is a financial institution that requires travel rule information. + IsSubjectToTravelRule bool `tlv:"5"` + + // RequiredPayerData the data about the payer that the sending VASP must provide in order to send a payment. + RequiredPayerData *CounterPartyDataOptions `tlv:"6"` + + // UmaVersion is a list of UMA versions that the VASP supports for this transaction. It should be + // containing the lowest minor version of each major version it supported, separated by commas. + UmaVersion string `tlv:"7"` + + // CommentCharsAllowed is the number of characters that the sender can include in the comment field of the pay request. + CommentCharsAllowed *int `tlv:"8"` + + // The sender's UMA address. If this field presents, the UMA invoice should directly go to the sending VASP instead of showing in other formats. + SenderUma *string `tlv:"9"` + + // The maximum number of the invoice can be paid + InvoiceLimit *uint64 `tlv:"10"` + + // KYC status of the receiver, default is verified. + KycStatus *KycStatus `tlv:"11"` + + // The signature of the UMA invoice + Signature *[]byte `tlv:"100"` +} + +func (i *UmaInvoice) MarshalTLV() ([]byte, error) { + return utils.MarshalTLV(i) +} + +func (i *UmaInvoice) UnmarshalTLV(data []byte) error { + return utils.UnmarshalTLV(i, data) +} + +func (i *UmaInvoice) ToBech32String() (string, error) { + if i.Signature == nil { + return "", fmt.Errorf("signature is required to encode to bech32") + } + tlv, err := i.MarshalTLV() + if err != nil { + return "", err + } + conv, err := bech32.ConvertBits(tlv, 8, 5, true) + if err != nil { + return "", err + } + + return bech32.Encode("uma", conv) +} diff --git a/uma/protocol/kyc_status.go b/uma/protocol/kyc_status.go index 5dfc941..e537c02 100644 --- a/uma/protocol/kyc_status.go +++ b/uma/protocol/kyc_status.go @@ -52,3 +52,11 @@ func (k KycStatus) MarshalJSON() ([]byte, error) { s := k.StringValue() return json.Marshal(s) } + +func (k *KycStatus) MarshalBytes() ([]byte, error) { + return []byte(k.StringValue()), nil +} + +func (k *KycStatus) UnmarshalBytes(b []byte) error { + return k.UnmarshalJSON(b) +} diff --git a/uma/test/protocol_test.go b/uma/test/protocol_test.go index 0ecc324..448d1e7 100644 --- a/uma/test/protocol_test.go +++ b/uma/test/protocol_test.go @@ -251,3 +251,61 @@ usLY8crt6ys3KQ== require.NoError(t, err) require.Equal(t, keysOnlyPubKeyResponse, reserializedPubKeyResponse) } + +func TestBinaryCodableForCounterPartyDataOptions(t *testing.T) { + counterPartyDataOptions := umaprotocol.CounterPartyDataOptions{ + "name": umaprotocol.CounterPartyDataOption{Mandatory: false}, + "email": umaprotocol.CounterPartyDataOption{Mandatory: false}, + "compliance": umaprotocol.CounterPartyDataOption{Mandatory: true}, + } + result, err := counterPartyDataOptions.MarshalBytes() + require.NoError(t, err) + + resultStr := string(result) + require.Equal(t, "compliance:1,email:0,name:0", resultStr) + + counterPartyDataOptions2 := umaprotocol.CounterPartyDataOptions{} + err = counterPartyDataOptions2.UnmarshalBytes([]byte(resultStr)) + require.NoError(t, err) + require.Equal(t, counterPartyDataOptions, counterPartyDataOptions2) +} + +func TestUnsignInvoiceTLVCoding(t *testing.T) { + kyc := umaprotocol.KycStatusVerified + signature := []byte("signature") + invoicenvoice := umaprotocol.UmaInvoice{ + ReceiverUma: "$foo@bar.com", + InvoiceUUID: "c7c07fec-cf00-431c-916f-6c13fc4b69f9", + Amount: 1000, + ReceivingCurrency: umaprotocol.InvoiceCurrency{ + Code: "USD", + Name: "US Dollar", + Symbol: "$", + }, + Expiration: 1000000, + IsSubjectToTravelRule: true, + RequiredPayerData: &umaprotocol.CounterPartyDataOptions{ + "name": umaprotocol.CounterPartyDataOption{Mandatory: false}, + "email": umaprotocol.CounterPartyDataOption{Mandatory: false}, + "compliance": umaprotocol.CounterPartyDataOption{Mandatory: true}, + }, + UmaVersion: "0.3", + CommentCharsAllowed: nil, + SenderUma: nil, + InvoiceLimit: nil, + KycStatus: &kyc, + Signature: &signature, + } + + invoiceTLV, err := invoicenvoice.MarshalTLV() + require.NoError(t, err) + + invoice2 := umaprotocol.UmaInvoice{} + err = invoice2.UnmarshalTLV(invoiceTLV) + require.NoError(t, err) + require.Equal(t, invoicenvoice, invoice2) + + bech32String, err := invoice2.ToBech32String() + require.NoError(t, err) + require.Equal(t, "uma1qqxzgen0daqxyctj9e3k7mgpy33nwcesxanx2cedvdnrqvpdxsenzced8ycnve3dxe3nzvmxvv6xyd3evcusypp3xqcrqqcnqqp4256yqyy425eqg3hkcmrpwgpqzfqyqucnqvpsxqcrqpgpqyrpkcm0d4cxc6tpde3k2w3393jk6ctfdsarqtrwv9kk2w3squpnqt3npvqnxeqfwd5kwmnpw36hyega7x5zz", bech32String) +} diff --git a/uma/test/tlv_utils_test.go b/uma/test/tlv_utils_test.go index 91010f3..229ab8f 100644 --- a/uma/test/tlv_utils_test.go +++ b/uma/test/tlv_utils_test.go @@ -20,10 +20,12 @@ func (b *BinaryCodableStruct) UnmarshalBytes(data []byte) error { } type TLVUtilsTests struct { - StringField string `tlv:"0"` - IntField int `tlv:"1"` - BoolField bool `tlv:"2"` - UInt64Field uint64 `tlv:"3"` + StringField string `tlv:"0"` + IntField int `tlv:"1"` + BoolField bool `tlv:"2"` + UInt64Field uint64 `tlv:"3"` + OptionalStringField *string `tlv:"6"` + OptionalEmptyStringField *string `tlv:"7"` } func (d *TLVUtilsTests) MarshalTLV() ([]byte, error) { @@ -35,11 +37,14 @@ func (d *TLVUtilsTests) UnmarshalTLV(data []byte) error { } func TestSimpleTLVCoder(t *testing.T) { + str := "optional" tlvUtilsTests := TLVUtilsTests{ - StringField: "hello", - IntField: 42, - BoolField: true, - UInt64Field: 123, + StringField: "hello", + IntField: 42, + BoolField: true, + UInt64Field: 123, + OptionalStringField: &str, + OptionalEmptyStringField: nil, } data, err := tlvUtilsTests.MarshalTLV() @@ -68,6 +73,14 @@ func TestSimpleTLVCoder(t *testing.T) { if tlvUtilsTests.UInt64Field != tlvUtilsTests2.UInt64Field { t.Fatalf("expected %d, got %d", tlvUtilsTests.UInt64Field, tlvUtilsTests2.UInt64Field) } + + if *tlvUtilsTests.OptionalStringField != *tlvUtilsTests2.OptionalStringField { + t.Fatalf("expected %s, got %s", *tlvUtilsTests.OptionalStringField, *tlvUtilsTests2.OptionalStringField) + } + + if tlvUtilsTests2.OptionalEmptyStringField != nil { + t.Fatalf("expected optional empty string field to be nil") + } } type NestedTLVUtilsTests struct { diff --git a/uma/utils/tlv_utils.go b/uma/utils/tlv_utils.go index 636f6ee..7f19812 100644 --- a/uma/utils/tlv_utils.go +++ b/uma/utils/tlv_utils.go @@ -20,7 +20,7 @@ type TLVCodable interface { // MarshalTLV marshals a struct to TLV. // It will marshals all the field with tag "tlv". -// The "tlv" tag value will be the type of the field. +// The tagged value will be the type of func MarshalTLV(v interface{}) ([]byte, error) { val := reflect.ValueOf(v) if val.Kind() != reflect.Ptr || val.Elem().Kind() != reflect.Struct { @@ -30,6 +30,40 @@ func MarshalTLV(v interface{}) ([]byte, error) { val = reflect.Indirect(val) typ := val.Type() + var handle func(field reflect.Value) ([]byte, error) + handle = func(field reflect.Value) ([]byte, error) { + switch field.Kind() { + case reflect.String: + return []byte(field.String()), nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return []byte(strconv.FormatInt(field.Int(), 10)), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return []byte(strconv.FormatUint(field.Uint(), 10)), nil + case reflect.Bool: + if field.Bool() { + return []byte{1}, nil + } else { + return []byte{0}, nil + } + case reflect.Ptr: + if field.IsNil() { + return nil, nil + } + return handle(reflect.Indirect(field)) + case reflect.Slice: + return field.Bytes(), nil + default: + pointer := field.Addr().Interface() + if coder, ok := pointer.(TLVCodable); ok { + return coder.MarshalTLV() + } else if coder, ok := pointer.(BytesCodable); ok { + return coder.MarshalBytes() + } else { + return nil, fmt.Errorf("unsupported type %s", field.Kind()) + } + } + } + var result []byte for i := 0; i < val.NumField(); i++ { field := val.Field(i) @@ -42,40 +76,13 @@ func MarshalTLV(v interface{}) ([]byte, error) { return nil, err } - var content []byte - - switch field.Kind() { - case reflect.String: - content = []byte(field.String()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - content = []byte(strconv.FormatInt(field.Int(), 10)) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - content = []byte(strconv.FormatUint(field.Uint(), 10)) - case reflect.Struct: - pointer := field.Addr().Interface() - if coder, ok := pointer.(TLVCodable); ok { - content, err = coder.MarshalTLV() - if err != nil { - return nil, err - } - } else if coder, ok := pointer.(BytesCodable); ok { - content, err = coder.MarshalBytes() - if err != nil { - return nil, err - } - } else { - return nil, fmt.Errorf("unsupported struct type %s", field.Type().Name()) - } - case reflect.Bool: - if field.Bool() { - content = []byte{1} - } else { - content = []byte{0} - } - default: - return nil, fmt.Errorf("unsupported type %s", field.Kind()) + content, err := handle(field) + if err != nil { + return nil, err + } + if content == nil { + continue } - result = append(result, byte(tlv)) result = append(result, byte(len(content))) result = append(result, content...) @@ -85,7 +92,6 @@ func MarshalTLV(v interface{}) ([]byte, error) { // UnmarshalTLV unmarshals a struct from TLV. // It will unmarshals all the field with tag "tlv". -// The "tlv" tag value will be the type of the field. func UnmarshalTLV(v interface{}, data []byte) error { result := make(map[byte][]byte) for i := 0; i < len(data); { @@ -111,56 +117,72 @@ func UnmarshalTLV(v interface{}, data []byte) error { return fmt.Errorf("unmarshal requires a pointer to a struct") } val = reflect.Indirect(val) - for i := 0; i < val.NumField(); i++ { - field := val.Field(i) - tag := val.Type().Field(i).Tag.Get("tlv") - if tag == "" { - continue - } - tlv, err := strconv.Atoi(tag) - if err != nil { - return err - } - - content, ok := result[byte(tlv)] - if !ok { - continue - } - + var handle func(field reflect.Value, value []byte) error + handle = func(field reflect.Value, value []byte) error { switch field.Kind() { case reflect.String: - field.SetString(string(content)) + field.SetString(string(value)) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - i, err := strconv.ParseInt(string(content), 10, 64) + i, err := strconv.ParseInt(string(value), 10, 64) if err != nil { return err } field.SetInt(i) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - i, err := strconv.ParseUint(string(content), 10, 64) + i, err := strconv.ParseUint(string(value), 10, 64) if err != nil { return err } field.SetUint(i) - case reflect.Struct: + case reflect.Bool: + field.SetBool(value[0] != 0) + case reflect.Ptr: + if field.IsNil() { + newValue := reflect.New(field.Type().Elem()) + field.Set(newValue) + } + return handle(field.Elem(), value) + case reflect.Slice: + field.SetBytes(value) + default: pointer := field.Addr().Interface() if coder, ok := pointer.(TLVCodable); ok { - err := coder.UnmarshalTLV(content) + err := coder.UnmarshalTLV(value) if err != nil { return err } } else if coder, ok := pointer.(BytesCodable); ok { - err := coder.UnmarshalBytes(content) + err := coder.UnmarshalBytes(value) if err != nil { return err } } else { - return fmt.Errorf("unsupported struct type %s", field.Type().Name()) + return fmt.Errorf("unsupported type %s", field.Kind()) } - case reflect.Bool: - field.SetBool(content[0] != 0) - default: - return fmt.Errorf("unsupported type %s", field.Kind()) + } + return nil + } + + for i := 0; i < val.NumField(); i++ { + field := val.Field(i) + tag := val.Type().Field(i).Tag.Get("tlv") + if tag == "" { + continue + } + tlv, err := strconv.Atoi(tag) + if err != nil { + return err + } + + content, ok := result[byte(tlv)] + if !ok { + continue + } + + err = handle(field, content) + + if err != nil { + return err } }