From aeead08df9b5752f677a656ba65ed718ca196128 Mon Sep 17 00:00:00 2001 From: Zhen Lu Date: Fri, 19 Jul 2024 16:10:46 -0700 Subject: [PATCH] Fix a bug that enum is encode to int. It should check for TLV/Binary first. --- uma/protocol/kyc_status.go | 15 +++++- uma/test/protocol_test.go | 2 +- uma/utils/tlv_utils.go | 106 ++++++++++++++++++------------------- 3 files changed, 68 insertions(+), 55 deletions(-) diff --git a/uma/protocol/kyc_status.go b/uma/protocol/kyc_status.go index e537c02..88b8a5a 100644 --- a/uma/protocol/kyc_status.go +++ b/uma/protocol/kyc_status.go @@ -58,5 +58,18 @@ func (k *KycStatus) MarshalBytes() ([]byte, error) { } func (k *KycStatus) UnmarshalBytes(b []byte) error { - return k.UnmarshalJSON(b) + s := string(b) + switch s { + default: + *k = KycStatusUnknown + case "UNKNOWN": + *k = KycStatusUnknown + case "NOT_VERIFIED": + *k = KycStatusNotVerified + case "PENDING": + *k = KycStatusPending + case "VERIFIED": + *k = KycStatusVerified + } + return nil } diff --git a/uma/test/protocol_test.go b/uma/test/protocol_test.go index e95dc42..8e09f3d 100644 --- a/uma/test/protocol_test.go +++ b/uma/test/protocol_test.go @@ -308,7 +308,7 @@ func TestUMAInvoiceTLVAndBech32(t *testing.T) { bech32String, err := invoice2.ToBech32String() require.NoError(t, err) - require.Equal(t, "uma1qqxzgen0daqxyctj9e3k7mgpy33nwcesxanx2cedvdnrqvpdxsenzced8ycnve3dxe3nzvmxvv6xyd3evcusypp3xqcrqqcnqqp4256yqyy425eqg3hkcmrpwgpqzfqyqucnqvpsxqcrqpgpqyrpkcm0d4cxc6tpde3k2w3393jk6ctfdsarqtrwv9kk2w3squpnqt3npvqnxrqudp68gurn8ghj7etcv9khqmr99e3k7mf0vdskcmrzv93kkeqfwd5kwmnpw36hyeg0e4m4j", bech32String) + require.Equal(t, "uma1qqxzgen0daqxyctj9e3k7mgpy33nwcesxanx2cedvdnrqvpdxsenzced8ycnve3dxe3nzvmxvv6xyd3evcusypp3xqcrqqcnqqp4256yqyy425eqg3hkcmrpwgpqzfqyqucnqvpsxqcrqpgpqyrpkcm0d4cxc6tpde3k2w3393jk6ctfdsarqtrwv9kk2w3squpnqt3npvy9v32jf9ryj32ypswxsar5wpen5te0v4uxzmtsd3jjucm0d5hkxctvd33xzcmtvsyhx6t8deshgatjv5sjy5ff", bech32String) invoice3, err := umaprotocol.FromBech32String(bech32String) require.NoError(t, err) diff --git a/uma/utils/tlv_utils.go b/uma/utils/tlv_utils.go index 2a7303c..5643287 100644 --- a/uma/utils/tlv_utils.go +++ b/uma/utils/tlv_utils.go @@ -32,33 +32,33 @@ func MarshalTLV(v interface{}) ([]byte, error) { var handle func(field reflect.Value) ([]byte, error) handle = func(field reflect.Value) ([]byte, error) { - switch field.Kind() { - case reflect.String: - return []byte(field.String()), nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return []byte(strconv.FormatInt(field.Int(), 10)), nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return []byte(strconv.FormatUint(field.Uint(), 10)), nil - case reflect.Bool: - if field.Bool() { - return []byte{1}, nil - } else { - return []byte{0}, nil - } - case reflect.Ptr: - if field.IsNil() { - return nil, nil - } - return handle(reflect.Indirect(field)) - case reflect.Slice: - return field.Bytes(), nil - default: - pointer := field.Addr().Interface() - if coder, ok := pointer.(TLVCodable); ok { - return coder.MarshalTLV() - } else if coder, ok := pointer.(BytesCodable); ok { - return coder.MarshalBytes() - } else { + pointer := field.Addr().Interface() + if coder, ok := pointer.(TLVCodable); ok { + return coder.MarshalTLV() + } else if coder, ok := pointer.(BytesCodable); ok { + return coder.MarshalBytes() + } else { + switch field.Kind() { + case reflect.String: + return []byte(field.String()), nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return []byte(strconv.FormatInt(field.Int(), 10)), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return []byte(strconv.FormatUint(field.Uint(), 10)), nil + case reflect.Bool: + if field.Bool() { + return []byte{1}, nil + } else { + return []byte{0}, nil + } + case reflect.Ptr: + if field.IsNil() { + return nil, nil + } + return handle(reflect.Indirect(field)) + case reflect.Slice: + return field.Bytes(), nil + default: return nil, fmt.Errorf("unsupported type %s", field.Kind()) } } @@ -119,44 +119,44 @@ func UnmarshalTLV(v interface{}, data []byte) error { val = reflect.Indirect(val) var handle func(field reflect.Value, value []byte) error handle = func(field reflect.Value, value []byte) error { - switch field.Kind() { - case reflect.String: - field.SetString(string(value)) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - i, err := strconv.ParseInt(string(value), 10, 64) + pointer := field.Addr().Interface() + if coder, ok := pointer.(TLVCodable); ok { + err := coder.UnmarshalTLV(value) if err != nil { return err } - field.SetInt(i) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - i, err := strconv.ParseUint(string(value), 10, 64) + } else if coder, ok := pointer.(BytesCodable); ok { + err := coder.UnmarshalBytes(value) if err != nil { return err } - field.SetUint(i) - case reflect.Bool: - field.SetBool(value[0] != 0) - case reflect.Ptr: - if field.IsNil() { - newValue := reflect.New(field.Type().Elem()) - field.Set(newValue) - } - return handle(field.Elem(), value) - case reflect.Slice: - field.SetBytes(value) - default: - pointer := field.Addr().Interface() - if coder, ok := pointer.(TLVCodable); ok { - err := coder.UnmarshalTLV(value) + } else { + switch field.Kind() { + case reflect.String: + field.SetString(string(value)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i, err := strconv.ParseInt(string(value), 10, 64) if err != nil { return err } - } else if coder, ok := pointer.(BytesCodable); ok { - err := coder.UnmarshalBytes(value) + field.SetInt(i) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + i, err := strconv.ParseUint(string(value), 10, 64) if err != nil { return err } - } else { + field.SetUint(i) + case reflect.Bool: + field.SetBool(value[0] != 0) + case reflect.Ptr: + if field.IsNil() { + newValue := reflect.New(field.Type().Elem()) + field.Set(newValue) + } + return handle(field.Elem(), value) + case reflect.Slice: + field.SetBytes(value) + default: return fmt.Errorf("unsupported type %s", field.Kind()) } }