Skip to content

Commit

Permalink
Fix a bug that enum is encode to int. It should check for TLV/Binary …
Browse files Browse the repository at this point in the history
…first.
  • Loading branch information
zhenlu committed Jul 19, 2024
1 parent 45bc629 commit aeead08
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 55 deletions.
15 changes: 14 additions & 1 deletion uma/protocol/kyc_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,18 @@ func (k *KycStatus) MarshalBytes() ([]byte, error) {
}

func (k *KycStatus) UnmarshalBytes(b []byte) error {
return k.UnmarshalJSON(b)
s := string(b)
switch s {
default:
*k = KycStatusUnknown
case "UNKNOWN":
*k = KycStatusUnknown
case "NOT_VERIFIED":
*k = KycStatusNotVerified
case "PENDING":
*k = KycStatusPending
case "VERIFIED":
*k = KycStatusVerified
}
return nil
}
2 changes: 1 addition & 1 deletion uma/test/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ func TestUMAInvoiceTLVAndBech32(t *testing.T) {

bech32String, err := invoice2.ToBech32String()
require.NoError(t, err)
require.Equal(t, "uma1qqxzgen0daqxyctj9e3k7mgpy33nwcesxanx2cedvdnrqvpdxsenzced8ycnve3dxe3nzvmxvv6xyd3evcusypp3xqcrqqcnqqp4256yqyy425eqg3hkcmrpwgpqzfqyqucnqvpsxqcrqpgpqyrpkcm0d4cxc6tpde3k2w3393jk6ctfdsarqtrwv9kk2w3squpnqt3npvqnxrqudp68gurn8ghj7etcv9khqmr99e3k7mf0vdskcmrzv93kkeqfwd5kwmnpw36hyeg0e4m4j", bech32String)
require.Equal(t, "uma1qqxzgen0daqxyctj9e3k7mgpy33nwcesxanx2cedvdnrqvpdxsenzced8ycnve3dxe3nzvmxvv6xyd3evcusypp3xqcrqqcnqqp4256yqyy425eqg3hkcmrpwgpqzfqyqucnqvpsxqcrqpgpqyrpkcm0d4cxc6tpde3k2w3393jk6ctfdsarqtrwv9kk2w3squpnqt3npvy9v32jf9ryj32ypswxsar5wpen5te0v4uxzmtsd3jjucm0d5hkxctvd33xzcmtvsyhx6t8deshgatjv5sjy5ff", bech32String)

invoice3, err := umaprotocol.FromBech32String(bech32String)
require.NoError(t, err)
Expand Down
106 changes: 53 additions & 53 deletions uma/utils/tlv_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,33 +32,33 @@ func MarshalTLV(v interface{}) ([]byte, error) {

var handle func(field reflect.Value) ([]byte, error)
handle = func(field reflect.Value) ([]byte, error) {
switch field.Kind() {
case reflect.String:
return []byte(field.String()), nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return []byte(strconv.FormatInt(field.Int(), 10)), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return []byte(strconv.FormatUint(field.Uint(), 10)), nil
case reflect.Bool:
if field.Bool() {
return []byte{1}, nil
} else {
return []byte{0}, nil
}
case reflect.Ptr:
if field.IsNil() {
return nil, nil
}
return handle(reflect.Indirect(field))
case reflect.Slice:
return field.Bytes(), nil
default:
pointer := field.Addr().Interface()
if coder, ok := pointer.(TLVCodable); ok {
return coder.MarshalTLV()
} else if coder, ok := pointer.(BytesCodable); ok {
return coder.MarshalBytes()
} else {
pointer := field.Addr().Interface()
if coder, ok := pointer.(TLVCodable); ok {
return coder.MarshalTLV()
} else if coder, ok := pointer.(BytesCodable); ok {
return coder.MarshalBytes()
} else {
switch field.Kind() {
case reflect.String:
return []byte(field.String()), nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return []byte(strconv.FormatInt(field.Int(), 10)), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return []byte(strconv.FormatUint(field.Uint(), 10)), nil
case reflect.Bool:
if field.Bool() {
return []byte{1}, nil
} else {
return []byte{0}, nil
}
case reflect.Ptr:
if field.IsNil() {
return nil, nil
}
return handle(reflect.Indirect(field))
case reflect.Slice:
return field.Bytes(), nil
default:
return nil, fmt.Errorf("unsupported type %s", field.Kind())
}
}
Expand Down Expand Up @@ -119,44 +119,44 @@ func UnmarshalTLV(v interface{}, data []byte) error {
val = reflect.Indirect(val)
var handle func(field reflect.Value, value []byte) error
handle = func(field reflect.Value, value []byte) error {
switch field.Kind() {
case reflect.String:
field.SetString(string(value))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
i, err := strconv.ParseInt(string(value), 10, 64)
pointer := field.Addr().Interface()
if coder, ok := pointer.(TLVCodable); ok {
err := coder.UnmarshalTLV(value)
if err != nil {
return err
}
field.SetInt(i)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
i, err := strconv.ParseUint(string(value), 10, 64)
} else if coder, ok := pointer.(BytesCodable); ok {
err := coder.UnmarshalBytes(value)
if err != nil {
return err
}
field.SetUint(i)
case reflect.Bool:
field.SetBool(value[0] != 0)
case reflect.Ptr:
if field.IsNil() {
newValue := reflect.New(field.Type().Elem())
field.Set(newValue)
}
return handle(field.Elem(), value)
case reflect.Slice:
field.SetBytes(value)
default:
pointer := field.Addr().Interface()
if coder, ok := pointer.(TLVCodable); ok {
err := coder.UnmarshalTLV(value)
} else {
switch field.Kind() {
case reflect.String:
field.SetString(string(value))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
i, err := strconv.ParseInt(string(value), 10, 64)
if err != nil {
return err
}
} else if coder, ok := pointer.(BytesCodable); ok {
err := coder.UnmarshalBytes(value)
field.SetInt(i)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
i, err := strconv.ParseUint(string(value), 10, 64)
if err != nil {
return err
}
} else {
field.SetUint(i)
case reflect.Bool:
field.SetBool(value[0] != 0)
case reflect.Ptr:
if field.IsNil() {
newValue := reflect.New(field.Type().Elem())
field.Set(newValue)
}
return handle(field.Elem(), value)
case reflect.Slice:
field.SetBytes(value)
default:
return fmt.Errorf("unsupported type %s", field.Kind())
}
}
Expand Down

0 comments on commit aeead08

Please sign in to comment.