From f6ee1cf3d1cebde4fa27a9a3df69209361e48c56 Mon Sep 17 00:00:00 2001 From: Tian Date: Wed, 27 Mar 2024 11:48:53 -0400 Subject: [PATCH] change vault state key to not use serialization (#1248) --- protocol/x/vault/keeper/orders.go | 13 +-- protocol/x/vault/types/vault_id.go | 37 +++++++- protocol/x/vault/types/vault_id_test.go | 107 +++++++++++++++++++++++- 3 files changed, 145 insertions(+), 12 deletions(-) diff --git a/protocol/x/vault/keeper/orders.go b/protocol/x/vault/keeper/orders.go index a1f2610bac..768758340b 100644 --- a/protocol/x/vault/keeper/orders.go +++ b/protocol/x/vault/keeper/orders.go @@ -38,8 +38,11 @@ func (k Keeper) RefreshAllVaultOrders(ctx sdk.Context) { totalSharesIterator := k.getTotalSharesIterator(ctx) defer totalSharesIterator.Close() for ; totalSharesIterator.Valid(); totalSharesIterator.Next() { - var vaultId types.VaultId - k.cdc.MustUnmarshal(totalSharesIterator.Key(), &vaultId) + vaultId, err := types.GetVaultIdFromStateKey(totalSharesIterator.Key()) + if err != nil { + log.ErrorLogWithError(ctx, "Failed to get vault ID from state key", err) + continue + } var totalShares types.NumShares k.cdc.MustUnmarshal(totalSharesIterator.Value(), &totalShares) @@ -53,12 +56,12 @@ func (k Keeper) RefreshAllVaultOrders(ctx sdk.Context) { // Currently only supported vault type is CLOB. switch vaultId.Type { case types.VaultType_VAULT_TYPE_CLOB: - err := k.RefreshVaultClobOrders(ctx, vaultId) + err := k.RefreshVaultClobOrders(ctx, *vaultId) if err != nil { - log.ErrorLogWithError(ctx, "Failed to refresh vault clob orders", err, "vaultId", vaultId) + log.ErrorLogWithError(ctx, "Failed to refresh vault clob orders", err, "vaultId", *vaultId) } default: - log.ErrorLog(ctx, "Failed to refresh vault orders: unknown vault type", "vaultId", vaultId) + log.ErrorLog(ctx, "Failed to refresh vault orders: unknown vault type", "vaultId", *vaultId) } } } diff --git a/protocol/x/vault/types/vault_id.go b/protocol/x/vault/types/vault_id.go index aae07e82be..47b0c7a625 100644 --- a/protocol/x/vault/types/vault_id.go +++ b/protocol/x/vault/types/vault_id.go @@ -2,19 +2,50 @@ package types import ( fmt "fmt" + "strconv" + "strings" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" "github.com/dydxprotocol/v4-chain/protocol/lib/metrics" satypes "github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types" ) +// ToString returns the string representation of a vault ID. +func (id *VaultId) ToString() string { + return fmt.Sprintf("%s-%d", id.Type, id.Number) +} + // ToStateKey returns the state key for the vault ID. func (id *VaultId) ToStateKey() []byte { - b, err := id.Marshal() + return []byte(id.ToString()) +} + +// GetVaultIdFromStateKey returns a vault ID from a given state key. +func GetVaultIdFromStateKey(stateKey []byte) (*VaultId, error) { + stateKeyStr := string(stateKey) + + // Split state key string into type and number. + split := strings.Split(stateKeyStr, "-") + if len(split) != 2 { + return nil, fmt.Errorf("stateKey in string must follow format - but got %s", stateKeyStr) + } + + // Parse vault type. + vaultTypeInt, exists := VaultType_value[split[0]] + if !exists { + return nil, fmt.Errorf("unknown vault type: %s", split[0]) + } + + // Parse vault number. + number, err := strconv.ParseUint(split[1], 10, 32) if err != nil { - panic(err) + return nil, fmt.Errorf("failed to parse number: %s", err.Error()) } - return b + + return &VaultId{ + Type: VaultType(vaultTypeInt), + Number: uint32(number), + }, nil } // ToModuleAccountAddress returns the module account address for the vault ID diff --git a/protocol/x/vault/types/vault_id_test.go b/protocol/x/vault/types/vault_id_test.go index cc4fef9b38..629d727d89 100644 --- a/protocol/x/vault/types/vault_id_test.go +++ b/protocol/x/vault/types/vault_id_test.go @@ -6,15 +6,114 @@ import ( authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" "github.com/dydxprotocol/v4-chain/protocol/testutil/constants" satypes "github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types" + "github.com/dydxprotocol/v4-chain/protocol/x/vault/types" "github.com/stretchr/testify/require" ) +func TestToString(t *testing.T) { + tests := map[string]struct { + // Vault ID. + vaultId types.VaultId + // Expected string. + expectedStr string + }{ + "Vault for Clob Pair 0": { + vaultId: constants.Vault_Clob_0, + expectedStr: "VAULT_TYPE_CLOB-0", + }, + "Vault for Clob Pair 1": { + vaultId: constants.Vault_Clob_1, + expectedStr: "VAULT_TYPE_CLOB-1", + }, + "Vault, missing type and number": { + vaultId: types.VaultId{}, + expectedStr: "VAULT_TYPE_UNSPECIFIED-0", + }, + "Vault, missing type": { + vaultId: types.VaultId{ + Number: 1, + }, + expectedStr: "VAULT_TYPE_UNSPECIFIED-1", + }, + "Vault, missing number": { + vaultId: types.VaultId{ + Type: types.VaultType_VAULT_TYPE_CLOB, + }, + expectedStr: "VAULT_TYPE_CLOB-0", + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + require.Equal( + t, + tc.vaultId.ToString(), + tc.expectedStr, + ) + }) + } +} + func TestToStateKey(t *testing.T) { - b, _ := constants.Vault_Clob_0.Marshal() - require.Equal(t, b, constants.Vault_Clob_0.ToStateKey()) + require.Equal( + t, + []byte("VAULT_TYPE_CLOB-0"), + constants.Vault_Clob_0.ToStateKey(), + ) + + require.Equal( + t, + []byte("VAULT_TYPE_CLOB-1"), + constants.Vault_Clob_1.ToStateKey(), + ) +} + +func TestGetVaultIdFromStateKey(t *testing.T) { + tests := map[string]struct { + // State key. + stateKey []byte + // Expected vault ID. + expectedVaultId types.VaultId + // Expected error. + expectedErr string + }{ + "Vault for Clob Pair 0": { + stateKey: []byte("VAULT_TYPE_CLOB-0"), + expectedVaultId: constants.Vault_Clob_0, + }, + "Vault for Clob Pair 1": { + stateKey: []byte("VAULT_TYPE_CLOB-1"), + expectedVaultId: constants.Vault_Clob_1, + }, + "Empty bytes": { + stateKey: []byte{}, + expectedErr: "stateKey in string must follow format -", + }, + "Nil bytes": { + stateKey: nil, + expectedErr: "stateKey in string must follow format -", + }, + "Non-existent vault type": { + stateKey: []byte("VAULT_TYPE_SPOT-1"), + expectedErr: "unknown vault type", + }, + "Malformed vault number": { + stateKey: []byte("VAULT_TYPE_CLOB-abc"), + expectedErr: "failed to parse number", + }, + } - b, _ = constants.Vault_Clob_1.Marshal() - require.Equal(t, b, constants.Vault_Clob_1.ToStateKey()) + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + vaultId, err := types.GetVaultIdFromStateKey(tc.stateKey) + if tc.expectedErr != "" { + require.ErrorContains(t, err, tc.expectedErr) + } else { + require.NoError(t, err) + require.Equal(t, tc.expectedVaultId, *vaultId) + } + }) + } } func TestToModuleAccountAddress(t *testing.T) {