diff --git a/protocol/x/vault/keeper/msg_server_deposit_to_vault_test.go b/protocol/x/vault/keeper/msg_server_deposit_to_vault_test.go index 16e120924b..b210f534ee 100644 --- a/protocol/x/vault/keeper/msg_server_deposit_to_vault_test.go +++ b/protocol/x/vault/keeper/msg_server_deposit_to_vault_test.go @@ -31,6 +31,8 @@ type DepositInstance struct { checkTxFails bool // Whether DeliverTx fails. deliverTxFails bool + // Expected owner shares for depositor above. + expectedOwnerShares *big.Rat } // DepositorSetup represents the setup of a depositor. @@ -67,14 +69,16 @@ func TestMsgDepositToVault(t *testing.T) { }, depositInstances: []DepositInstance{ { - depositor: constants.Alice_Num0, - depositAmount: big.NewInt(123), - msgSigner: constants.Alice_Num0.Owner, + depositor: constants.Alice_Num0, + depositAmount: big.NewInt(123), + msgSigner: constants.Alice_Num0.Owner, + expectedOwnerShares: big.NewRat(123, 1), }, { - depositor: constants.Alice_Num0, - depositAmount: big.NewInt(321), - msgSigner: constants.Alice_Num0.Owner, + depositor: constants.Alice_Num0, + depositAmount: big.NewInt(321), + msgSigner: constants.Alice_Num0.Owner, + expectedOwnerShares: big.NewRat(444, 1), }, }, totalSharesHistory: []*big.Rat{ @@ -100,14 +104,16 @@ func TestMsgDepositToVault(t *testing.T) { }, depositInstances: []DepositInstance{ { - depositor: constants.Alice_Num0, - depositAmount: big.NewInt(1_000), - msgSigner: constants.Alice_Num0.Owner, + depositor: constants.Alice_Num0, + depositAmount: big.NewInt(1_000), + msgSigner: constants.Alice_Num0.Owner, + expectedOwnerShares: big.NewRat(1_000, 1), }, { - depositor: constants.Bob_Num1, - depositAmount: big.NewInt(500), - msgSigner: constants.Bob_Num1.Owner, + depositor: constants.Bob_Num1, + depositAmount: big.NewInt(500), + msgSigner: constants.Bob_Num1.Owner, + expectedOwnerShares: big.NewRat(500, 1), }, }, totalSharesHistory: []*big.Rat{ @@ -133,15 +139,17 @@ func TestMsgDepositToVault(t *testing.T) { }, depositInstances: []DepositInstance{ { - depositor: constants.Alice_Num0, - depositAmount: big.NewInt(1_000), - msgSigner: constants.Alice_Num0.Owner, + depositor: constants.Alice_Num0, + depositAmount: big.NewInt(1_000), + msgSigner: constants.Alice_Num0.Owner, + expectedOwnerShares: big.NewRat(1_000, 1), }, { - depositor: constants.Bob_Num1, - depositAmount: big.NewInt(501), // Greater than balance. - msgSigner: constants.Bob_Num1.Owner, - deliverTxFails: true, + depositor: constants.Bob_Num1, + depositAmount: big.NewInt(501), // Greater than balance. + msgSigner: constants.Bob_Num1.Owner, + deliverTxFails: true, + expectedOwnerShares: nil, }, }, totalSharesHistory: []*big.Rat{ @@ -172,11 +180,13 @@ func TestMsgDepositToVault(t *testing.T) { msgSigner: constants.Alice_Num0.Owner, // Incorrect signer. checkTxFails: true, checkTxResponseContains: "does not match signer address", + expectedOwnerShares: nil, }, { - depositor: constants.Alice_Num0, - depositAmount: big.NewInt(1_000), - msgSigner: constants.Alice_Num0.Owner, + depositor: constants.Alice_Num0, + depositAmount: big.NewInt(1_000), + msgSigner: constants.Alice_Num0.Owner, + expectedOwnerShares: big.NewRat(1_000, 1), }, }, totalSharesHistory: []*big.Rat{ @@ -207,6 +217,7 @@ func TestMsgDepositToVault(t *testing.T) { msgSigner: constants.Alice_Num0.Owner, checkTxFails: true, checkTxResponseContains: "Deposit amount is invalid", + expectedOwnerShares: nil, }, { depositor: constants.Bob_Num0, @@ -214,6 +225,7 @@ func TestMsgDepositToVault(t *testing.T) { msgSigner: constants.Bob_Num0.Owner, checkTxFails: true, checkTxResponseContains: "Deposit amount is invalid", + expectedOwnerShares: nil, }, }, totalSharesHistory: []*big.Rat{ @@ -320,6 +332,17 @@ func TestMsgDepositToVault(t *testing.T) { vaulttypes.BigRatToNumShares(tc.totalSharesHistory[i]), totalShares, ) + // Check that owner shares of the depositor is as expected. + ownerShares, _ := tApp.App.VaultKeeper.GetOwnerShares( + ctx, + tc.vaultId, + depositInstance.depositor.Owner, + ) + require.Equal( + t, + vaulttypes.BigRatToNumShares(depositInstance.expectedOwnerShares), + ownerShares, + ) // Check that equity of the vault is as expected. vaultEquity, err := tApp.App.VaultKeeper.GetVaultEquity(ctx, tc.vaultId) require.NoError(t, err) diff --git a/protocol/x/vault/keeper/shares.go b/protocol/x/vault/keeper/shares.go index 20045fb1c0..5c2988741a 100644 --- a/protocol/x/vault/keeper/shares.go +++ b/protocol/x/vault/keeper/shares.go @@ -62,6 +62,54 @@ func (k Keeper) getTotalSharesIterator(ctx sdk.Context) storetypes.Iterator { return storetypes.KVStorePrefixIterator(store, []byte{}) } +// GetOwnerShares gets owner shares for an owner in a vault. +func (k Keeper) GetOwnerShares( + ctx sdk.Context, + vaultId types.VaultId, + owner string, +) (val types.NumShares, exists bool) { + store := k.getVaultOwnerSharesStore(ctx, vaultId) + + b := store.Get([]byte(owner)) + if b == nil { + return val, false + } + + k.cdc.MustUnmarshal(b, &val) + return val, true +} + +// SetOwnerShares sets owner shares for an owner in a vault. +func (k Keeper) SetOwnerShares( + ctx sdk.Context, + vaultId types.VaultId, + owner string, + ownerShares types.NumShares, +) error { + ownerSharesRat, err := ownerShares.ToBigRat() + if err != nil { + return err + } + if ownerSharesRat.Sign() < 0 { + return types.ErrNegativeShares + } + + b := k.cdc.MustMarshal(&ownerShares) + store := k.getVaultOwnerSharesStore(ctx, vaultId) + store.Set([]byte(owner), b) + + return nil +} + +// getVaultOwnerSharesStore returns the store for owner shares of a given vault. +func (k Keeper) getVaultOwnerSharesStore( + ctx sdk.Context, + vaultId types.VaultId, +) prefix.Store { + store := prefix.NewStore(ctx.KVStore(k.storeKey), []byte(types.OwnerSharesKeyPrefix)) + return prefix.NewStore(store, vaultId.ToStateKeyPrefix()) +} + // MintShares mints shares of a vault for `owner` based on `quantumsToDeposit` by: // 1. Increasing total shares of the vault. // 2. Increasing owner shares of the vault for given `owner`. @@ -127,7 +175,37 @@ func (k Keeper) MintShares( return err } - // TODO (TRA-170): Increase owner shares. + // Increase owner shares in the vault. + ownerShares, exists := k.GetOwnerShares(ctx, vaultId, owner) + if !exists { + // Set owner shares to be sharesToMint. + err := k.SetOwnerShares( + ctx, + vaultId, + owner, + types.BigRatToNumShares(sharesToMint), + ) + if err != nil { + return err + } + } else { + // Increase existing owner shares by sharesToMint. + existingOwnerShares, err := ownerShares.ToBigRat() + if err != nil { + return err + } + err = k.SetOwnerShares( + ctx, + vaultId, + owner, + types.BigRatToNumShares( + existingOwnerShares.Add(existingOwnerShares, sharesToMint), + ), + ) + if err != nil { + return err + } + } return nil } diff --git a/protocol/x/vault/keeper/shares_test.go b/protocol/x/vault/keeper/shares_test.go index 0b7accc856..b2da4e66b2 100644 --- a/protocol/x/vault/keeper/shares_test.go +++ b/protocol/x/vault/keeper/shares_test.go @@ -80,6 +80,60 @@ func TestGetSetTotalShares(t *testing.T) { ) } +func TestGetSetOwnerShares(t *testing.T) { + tApp := testapp.NewTestAppBuilder(t).Build() + ctx := tApp.InitChain() + k := tApp.App.VaultKeeper + + alice := constants.AliceAccAddress.String() + bob := constants.BobAccAddress.String() + + // Get owners shares for Alice in vault clob 0. + _, exists := k.GetOwnerShares(ctx, constants.Vault_Clob_0, alice) + require.Equal(t, false, exists) + + // Set owner shares for Alice in vault clob 0 and get. + numShares := vaulttypes.BigRatToNumShares( + big.NewRat(7, 1), + ) + err := k.SetOwnerShares(ctx, constants.Vault_Clob_0, alice, numShares) + require.NoError(t, err) + got, exists := k.GetOwnerShares(ctx, constants.Vault_Clob_0, alice) + require.Equal(t, true, exists) + require.Equal(t, numShares, got) + + // Set owner shares for Alice in vault clob 1 and then get. + numShares = vaulttypes.BigRatToNumShares( + big.NewRat(456, 3), + ) + err = k.SetOwnerShares(ctx, constants.Vault_Clob_1, alice, numShares) + require.NoError(t, err) + got, exists = k.GetOwnerShares(ctx, constants.Vault_Clob_1, alice) + require.Equal(t, true, exists) + require.Equal(t, numShares, got) + + // Set owner shares for Bob in vault clob 1. + numShares = vaulttypes.BigRatToNumShares( + big.NewRat(0, 1), + ) + err = k.SetOwnerShares(ctx, constants.Vault_Clob_1, bob, numShares) + require.NoError(t, err) + got, exists = k.GetOwnerShares(ctx, constants.Vault_Clob_1, bob) + require.Equal(t, true, exists) + require.Equal(t, numShares, got) + + // Set owner shares for Bob in vault clob 1 to a negative value. + // Should get error and total shares should remain unchanged. + numSharesInvalid := vaulttypes.BigRatToNumShares( + big.NewRat(-1, 1), + ) + err = k.SetOwnerShares(ctx, constants.Vault_Clob_1, bob, numSharesInvalid) + require.ErrorIs(t, err, vaulttypes.ErrNegativeShares) + got, exists = k.GetOwnerShares(ctx, constants.Vault_Clob_1, bob) + require.Equal(t, true, exists) + require.Equal(t, numShares, got) +} + func TestMintShares(t *testing.T) { tests := map[string]struct { /* --- Setup --- */ @@ -89,73 +143,98 @@ func TestMintShares(t *testing.T) { equity *big.Int // Existing vault TotalShares. totalShares *big.Rat + // Owner that deposits. + owner string + // Existing owner shares. + ownerShares *big.Rat // Quote quantums to deposit. quantumsToDeposit *big.Int /* --- Expectations --- */ // Expected TotalShares after minting. expectedTotalShares *big.Rat + // Expected OwnerShares after minting. + expectedOwnerShares *big.Rat // Expected error. expectedErr error }{ - "Equity 0, TotalShares 0, Deposit 1000": { + "Equity 0, TotalShares 0, OwnerShares 0, Deposit 1000": { vaultId: constants.Vault_Clob_0, equity: big.NewInt(0), totalShares: big.NewRat(0, 1), + owner: constants.AliceAccAddress.String(), + ownerShares: big.NewRat(0, 1), quantumsToDeposit: big.NewInt(1_000), // Should mint `1_000 / 1` shares. expectedTotalShares: big.NewRat(1_000, 1), + expectedOwnerShares: big.NewRat(1_000, 1), }, - "Equity 0, TotalShares non-existent, Deposit 12345654321": { + "Equity 0, TotalShares non-existent, OwnerShares non-existent, Deposit 12345654321": { vaultId: constants.Vault_Clob_0, equity: big.NewInt(0), + owner: constants.AliceAccAddress.String(), quantumsToDeposit: big.NewInt(12_345_654_321), // Should mint `12_345_654_321 / 1` shares. expectedTotalShares: big.NewRat(12_345_654_321, 1), + expectedOwnerShares: big.NewRat(12_345_654_321, 1), }, - "Equity 1000, TotalShares non-existent, Deposit 500": { + "Equity 1000, TotalShares non-existent, OwnerShares non-existent, Deposit 500": { vaultId: constants.Vault_Clob_0, equity: big.NewInt(1_000), + owner: constants.AliceAccAddress.String(), quantumsToDeposit: big.NewInt(500), // Should mint `500 / 1` shares. expectedTotalShares: big.NewRat(500, 1), + expectedOwnerShares: big.NewRat(500, 1), }, - "Equity 4000, TotalShares 5000, Deposit 1000": { + "Equity 4000, TotalShares 5000, OwnerShares 2500, Deposit 1000": { vaultId: constants.Vault_Clob_1, equity: big.NewInt(4_000), totalShares: big.NewRat(5_000, 1), + owner: constants.AliceAccAddress.String(), + ownerShares: big.NewRat(2_500, 1), quantumsToDeposit: big.NewInt(1_000), // Should mint `1_250 / 1` shares. expectedTotalShares: big.NewRat(6_250, 1), + expectedOwnerShares: big.NewRat(3_750, 1), }, - "Equity 1_000_000, TotalShares 1, Deposit 1": { + "Equity 1_000_000, TotalShares 1, OwnerShares 1/2, Deposit 1": { vaultId: constants.Vault_Clob_1, equity: big.NewInt(1_000_000), totalShares: big.NewRat(1, 1), + owner: constants.BobAccAddress.String(), + ownerShares: big.NewRat(1, 2), quantumsToDeposit: big.NewInt(1), // Should mint `1 / 1_000_000` shares. expectedTotalShares: big.NewRat(1_000_001, 1_000_000), + expectedOwnerShares: big.NewRat(500_001, 1_000_000), }, - "Equity 8000, TotalShares 4000, Deposit 455": { + "Equity 8000, TotalShares 4000, OwnerShares Deposit 455": { vaultId: constants.Vault_Clob_1, equity: big.NewInt(8_000), totalShares: big.NewRat(4_000, 1), + owner: constants.CarlAccAddress.String(), + ownerShares: big.NewRat(101, 7), quantumsToDeposit: big.NewInt(455), // Should mint `455 / 2` shares. expectedTotalShares: big.NewRat(8_455, 2), + expectedOwnerShares: big.NewRat(3_387, 14), }, - "Equity 123456, TotalShares 654321, Deposit 123456789": { + "Equity 123456, TotalShares 654321, OwnerShares 0, Deposit 123456789": { vaultId: constants.Vault_Clob_1, equity: big.NewInt(123_456), totalShares: big.NewRat(654_321, 1), + owner: constants.DaveAccAddress.String(), quantumsToDeposit: big.NewInt(123_456_789), // Should mint `26_926_789_878_423 / 41_152` shares. expectedTotalShares: big.NewRat(26_953_716_496_215, 41_152), + expectedOwnerShares: big.NewRat(26_926_789_878_423, 41_152), }, "Equity -1, TotalShares 10, Deposit 1": { vaultId: constants.Vault_Clob_1, equity: big.NewInt(-1), totalShares: big.NewRat(10, 1), + owner: constants.AliceAccAddress.String(), quantumsToDeposit: big.NewInt(1), expectedErr: vaulttypes.ErrNonPositiveEquity, }, @@ -163,12 +242,14 @@ func TestMintShares(t *testing.T) { vaultId: constants.Vault_Clob_1, equity: big.NewInt(1), totalShares: big.NewRat(1, 1), + owner: constants.AliceAccAddress.String(), quantumsToDeposit: big.NewInt(0), expectedErr: vaulttypes.ErrInvalidDepositAmount, }, "Equity 0, TotalShares non-existent, Deposit -1": { vaultId: constants.Vault_Clob_1, equity: big.NewInt(0), + owner: constants.AliceAccAddress.String(), quantumsToDeposit: big.NewInt(-1), expectedErr: vaulttypes.ErrInvalidDepositAmount, }, @@ -209,12 +290,21 @@ func TestMintShares(t *testing.T) { ) require.NoError(t, err) } + if tc.ownerShares != nil { + err := tApp.App.VaultKeeper.SetOwnerShares( + ctx, + tc.vaultId, + tc.owner, + vaulttypes.BigRatToNumShares(tc.ownerShares), + ) + require.NoError(t, err) + } // Mint shares. err := tApp.App.VaultKeeper.MintShares( ctx, tc.vaultId, - "", // TODO (TRA-170): Increase owner shares. + tc.owner, tc.quantumsToDeposit, ) if tc.expectedErr != nil { @@ -227,6 +317,9 @@ func TestMintShares(t *testing.T) { vaulttypes.BigRatToNumShares(tc.totalShares), totalShares, ) + // Check that OwnerShares is unchanged. + ownerShares, _ := tApp.App.VaultKeeper.GetOwnerShares(ctx, tc.vaultId, tc.owner) + require.Equal(t, vaulttypes.BigRatToNumShares(tc.ownerShares), ownerShares) } else { require.NoError(t, err) // Check that TotalShares is as expected. @@ -237,6 +330,14 @@ func TestMintShares(t *testing.T) { vaulttypes.BigRatToNumShares(tc.expectedTotalShares), totalShares, ) + // Check that OwnerShares is as expected. + ownerShares, exists := tApp.App.VaultKeeper.GetOwnerShares(ctx, tc.vaultId, tc.owner) + require.True(t, exists) + require.Equal( + t, + vaulttypes.BigRatToNumShares(tc.expectedOwnerShares), + ownerShares, + ) } }) } diff --git a/protocol/x/vault/types/keys.go b/protocol/x/vault/types/keys.go index 4e37abb4d9..1130a77fa0 100644 --- a/protocol/x/vault/types/keys.go +++ b/protocol/x/vault/types/keys.go @@ -14,6 +14,10 @@ const ( // TotalSharesKeyPrefix is the prefix to retrieve all TotalShares. TotalSharesKeyPrefix = "TotalShares:" + // OwnerSharesKeyPrefix is the prefix to retrieve all OwnerShares. + // OwnerShares store: vaultId VaultId -> owner string -> shares NumShares. + OwnerSharesKeyPrefix = "OwnerShares:" + // ParamsKey is the key to retrieve Params. ParamsKey = "Params" ) diff --git a/protocol/x/vault/types/vault_id.go b/protocol/x/vault/types/vault_id.go index 47b0c7a625..af070d0eae 100644 --- a/protocol/x/vault/types/vault_id.go +++ b/protocol/x/vault/types/vault_id.go @@ -20,6 +20,11 @@ func (id *VaultId) ToStateKey() []byte { return []byte(id.ToString()) } +// ToStateKeyPrefix returns the state key prefix for the vault ID. +func (id *VaultId) ToStateKeyPrefix() []byte { + return []byte(fmt.Sprintf("%s:", id.ToString())) +} + // GetVaultIdFromStateKey returns a vault ID from a given state key. func GetVaultIdFromStateKey(stateKey []byte) (*VaultId, error) { stateKeyStr := string(stateKey)