From 083762c17c3290f948c098fc83e1f4ec826fe8ce Mon Sep 17 00:00:00 2001
From: Tian <tian@dydx.exchange>
Date: Wed, 27 Mar 2024 15:50:35 -0400
Subject: [PATCH] add vault owner shares (#1253)

---
 .../msg_server_deposit_to_vault_test.go       |  67 ++++++----
 protocol/x/vault/keeper/shares.go             |  80 +++++++++++-
 protocol/x/vault/keeper/shares_test.go        | 117 ++++++++++++++++--
 protocol/x/vault/types/keys.go                |   4 +
 protocol/x/vault/types/vault_id.go            |   5 +
 5 files changed, 242 insertions(+), 31 deletions(-)

diff --git a/protocol/x/vault/keeper/msg_server_deposit_to_vault_test.go b/protocol/x/vault/keeper/msg_server_deposit_to_vault_test.go
index 16e120924b..b210f534ee 100644
--- a/protocol/x/vault/keeper/msg_server_deposit_to_vault_test.go
+++ b/protocol/x/vault/keeper/msg_server_deposit_to_vault_test.go
@@ -31,6 +31,8 @@ type DepositInstance struct {
 	checkTxFails bool
 	// Whether DeliverTx fails.
 	deliverTxFails bool
+	// Expected owner shares for depositor above.
+	expectedOwnerShares *big.Rat
 }
 
 // DepositorSetup represents the setup of a depositor.
@@ -67,14 +69,16 @@ func TestMsgDepositToVault(t *testing.T) {
 			},
 			depositInstances: []DepositInstance{
 				{
-					depositor:     constants.Alice_Num0,
-					depositAmount: big.NewInt(123),
-					msgSigner:     constants.Alice_Num0.Owner,
+					depositor:           constants.Alice_Num0,
+					depositAmount:       big.NewInt(123),
+					msgSigner:           constants.Alice_Num0.Owner,
+					expectedOwnerShares: big.NewRat(123, 1),
 				},
 				{
-					depositor:     constants.Alice_Num0,
-					depositAmount: big.NewInt(321),
-					msgSigner:     constants.Alice_Num0.Owner,
+					depositor:           constants.Alice_Num0,
+					depositAmount:       big.NewInt(321),
+					msgSigner:           constants.Alice_Num0.Owner,
+					expectedOwnerShares: big.NewRat(444, 1),
 				},
 			},
 			totalSharesHistory: []*big.Rat{
@@ -100,14 +104,16 @@ func TestMsgDepositToVault(t *testing.T) {
 			},
 			depositInstances: []DepositInstance{
 				{
-					depositor:     constants.Alice_Num0,
-					depositAmount: big.NewInt(1_000),
-					msgSigner:     constants.Alice_Num0.Owner,
+					depositor:           constants.Alice_Num0,
+					depositAmount:       big.NewInt(1_000),
+					msgSigner:           constants.Alice_Num0.Owner,
+					expectedOwnerShares: big.NewRat(1_000, 1),
 				},
 				{
-					depositor:     constants.Bob_Num1,
-					depositAmount: big.NewInt(500),
-					msgSigner:     constants.Bob_Num1.Owner,
+					depositor:           constants.Bob_Num1,
+					depositAmount:       big.NewInt(500),
+					msgSigner:           constants.Bob_Num1.Owner,
+					expectedOwnerShares: big.NewRat(500, 1),
 				},
 			},
 			totalSharesHistory: []*big.Rat{
@@ -133,15 +139,17 @@ func TestMsgDepositToVault(t *testing.T) {
 			},
 			depositInstances: []DepositInstance{
 				{
-					depositor:     constants.Alice_Num0,
-					depositAmount: big.NewInt(1_000),
-					msgSigner:     constants.Alice_Num0.Owner,
+					depositor:           constants.Alice_Num0,
+					depositAmount:       big.NewInt(1_000),
+					msgSigner:           constants.Alice_Num0.Owner,
+					expectedOwnerShares: big.NewRat(1_000, 1),
 				},
 				{
-					depositor:      constants.Bob_Num1,
-					depositAmount:  big.NewInt(501), // Greater than balance.
-					msgSigner:      constants.Bob_Num1.Owner,
-					deliverTxFails: true,
+					depositor:           constants.Bob_Num1,
+					depositAmount:       big.NewInt(501), // Greater than balance.
+					msgSigner:           constants.Bob_Num1.Owner,
+					deliverTxFails:      true,
+					expectedOwnerShares: nil,
 				},
 			},
 			totalSharesHistory: []*big.Rat{
@@ -172,11 +180,13 @@ func TestMsgDepositToVault(t *testing.T) {
 					msgSigner:               constants.Alice_Num0.Owner, // Incorrect signer.
 					checkTxFails:            true,
 					checkTxResponseContains: "does not match signer address",
+					expectedOwnerShares:     nil,
 				},
 				{
-					depositor:     constants.Alice_Num0,
-					depositAmount: big.NewInt(1_000),
-					msgSigner:     constants.Alice_Num0.Owner,
+					depositor:           constants.Alice_Num0,
+					depositAmount:       big.NewInt(1_000),
+					msgSigner:           constants.Alice_Num0.Owner,
+					expectedOwnerShares: big.NewRat(1_000, 1),
 				},
 			},
 			totalSharesHistory: []*big.Rat{
@@ -207,6 +217,7 @@ func TestMsgDepositToVault(t *testing.T) {
 					msgSigner:               constants.Alice_Num0.Owner,
 					checkTxFails:            true,
 					checkTxResponseContains: "Deposit amount is invalid",
+					expectedOwnerShares:     nil,
 				},
 				{
 					depositor:               constants.Bob_Num0,
@@ -214,6 +225,7 @@ func TestMsgDepositToVault(t *testing.T) {
 					msgSigner:               constants.Bob_Num0.Owner,
 					checkTxFails:            true,
 					checkTxResponseContains: "Deposit amount is invalid",
+					expectedOwnerShares:     nil,
 				},
 			},
 			totalSharesHistory: []*big.Rat{
@@ -320,6 +332,17 @@ func TestMsgDepositToVault(t *testing.T) {
 					vaulttypes.BigRatToNumShares(tc.totalSharesHistory[i]),
 					totalShares,
 				)
+				// Check that owner shares of the depositor is as expected.
+				ownerShares, _ := tApp.App.VaultKeeper.GetOwnerShares(
+					ctx,
+					tc.vaultId,
+					depositInstance.depositor.Owner,
+				)
+				require.Equal(
+					t,
+					vaulttypes.BigRatToNumShares(depositInstance.expectedOwnerShares),
+					ownerShares,
+				)
 				// Check that equity of the vault is as expected.
 				vaultEquity, err := tApp.App.VaultKeeper.GetVaultEquity(ctx, tc.vaultId)
 				require.NoError(t, err)
diff --git a/protocol/x/vault/keeper/shares.go b/protocol/x/vault/keeper/shares.go
index 20045fb1c0..5c2988741a 100644
--- a/protocol/x/vault/keeper/shares.go
+++ b/protocol/x/vault/keeper/shares.go
@@ -62,6 +62,54 @@ func (k Keeper) getTotalSharesIterator(ctx sdk.Context) storetypes.Iterator {
 	return storetypes.KVStorePrefixIterator(store, []byte{})
 }
 
+// GetOwnerShares gets owner shares for an owner in a vault.
+func (k Keeper) GetOwnerShares(
+	ctx sdk.Context,
+	vaultId types.VaultId,
+	owner string,
+) (val types.NumShares, exists bool) {
+	store := k.getVaultOwnerSharesStore(ctx, vaultId)
+
+	b := store.Get([]byte(owner))
+	if b == nil {
+		return val, false
+	}
+
+	k.cdc.MustUnmarshal(b, &val)
+	return val, true
+}
+
+// SetOwnerShares sets owner shares for an owner in a vault.
+func (k Keeper) SetOwnerShares(
+	ctx sdk.Context,
+	vaultId types.VaultId,
+	owner string,
+	ownerShares types.NumShares,
+) error {
+	ownerSharesRat, err := ownerShares.ToBigRat()
+	if err != nil {
+		return err
+	}
+	if ownerSharesRat.Sign() < 0 {
+		return types.ErrNegativeShares
+	}
+
+	b := k.cdc.MustMarshal(&ownerShares)
+	store := k.getVaultOwnerSharesStore(ctx, vaultId)
+	store.Set([]byte(owner), b)
+
+	return nil
+}
+
+// getVaultOwnerSharesStore returns the store for owner shares of a given vault.
+func (k Keeper) getVaultOwnerSharesStore(
+	ctx sdk.Context,
+	vaultId types.VaultId,
+) prefix.Store {
+	store := prefix.NewStore(ctx.KVStore(k.storeKey), []byte(types.OwnerSharesKeyPrefix))
+	return prefix.NewStore(store, vaultId.ToStateKeyPrefix())
+}
+
 // MintShares mints shares of a vault for `owner` based on `quantumsToDeposit` by:
 // 1. Increasing total shares of the vault.
 // 2. Increasing owner shares of the vault for given `owner`.
@@ -127,7 +175,37 @@ func (k Keeper) MintShares(
 		return err
 	}
 
-	// TODO (TRA-170): Increase owner shares.
+	// Increase owner shares in the vault.
+	ownerShares, exists := k.GetOwnerShares(ctx, vaultId, owner)
+	if !exists {
+		// Set owner shares to be sharesToMint.
+		err := k.SetOwnerShares(
+			ctx,
+			vaultId,
+			owner,
+			types.BigRatToNumShares(sharesToMint),
+		)
+		if err != nil {
+			return err
+		}
+	} else {
+		// Increase existing owner shares by sharesToMint.
+		existingOwnerShares, err := ownerShares.ToBigRat()
+		if err != nil {
+			return err
+		}
+		err = k.SetOwnerShares(
+			ctx,
+			vaultId,
+			owner,
+			types.BigRatToNumShares(
+				existingOwnerShares.Add(existingOwnerShares, sharesToMint),
+			),
+		)
+		if err != nil {
+			return err
+		}
+	}
 
 	return nil
 }
diff --git a/protocol/x/vault/keeper/shares_test.go b/protocol/x/vault/keeper/shares_test.go
index 0b7accc856..b2da4e66b2 100644
--- a/protocol/x/vault/keeper/shares_test.go
+++ b/protocol/x/vault/keeper/shares_test.go
@@ -80,6 +80,60 @@ func TestGetSetTotalShares(t *testing.T) {
 	)
 }
 
+func TestGetSetOwnerShares(t *testing.T) {
+	tApp := testapp.NewTestAppBuilder(t).Build()
+	ctx := tApp.InitChain()
+	k := tApp.App.VaultKeeper
+
+	alice := constants.AliceAccAddress.String()
+	bob := constants.BobAccAddress.String()
+
+	// Get owners shares for Alice in vault clob 0.
+	_, exists := k.GetOwnerShares(ctx, constants.Vault_Clob_0, alice)
+	require.Equal(t, false, exists)
+
+	// Set owner shares for Alice in vault clob 0 and get.
+	numShares := vaulttypes.BigRatToNumShares(
+		big.NewRat(7, 1),
+	)
+	err := k.SetOwnerShares(ctx, constants.Vault_Clob_0, alice, numShares)
+	require.NoError(t, err)
+	got, exists := k.GetOwnerShares(ctx, constants.Vault_Clob_0, alice)
+	require.Equal(t, true, exists)
+	require.Equal(t, numShares, got)
+
+	// Set owner shares for Alice in vault clob 1 and then get.
+	numShares = vaulttypes.BigRatToNumShares(
+		big.NewRat(456, 3),
+	)
+	err = k.SetOwnerShares(ctx, constants.Vault_Clob_1, alice, numShares)
+	require.NoError(t, err)
+	got, exists = k.GetOwnerShares(ctx, constants.Vault_Clob_1, alice)
+	require.Equal(t, true, exists)
+	require.Equal(t, numShares, got)
+
+	// Set owner shares for Bob in vault clob 1.
+	numShares = vaulttypes.BigRatToNumShares(
+		big.NewRat(0, 1),
+	)
+	err = k.SetOwnerShares(ctx, constants.Vault_Clob_1, bob, numShares)
+	require.NoError(t, err)
+	got, exists = k.GetOwnerShares(ctx, constants.Vault_Clob_1, bob)
+	require.Equal(t, true, exists)
+	require.Equal(t, numShares, got)
+
+	// Set owner shares for Bob in vault clob 1 to a negative value.
+	// Should get error and total shares should remain unchanged.
+	numSharesInvalid := vaulttypes.BigRatToNumShares(
+		big.NewRat(-1, 1),
+	)
+	err = k.SetOwnerShares(ctx, constants.Vault_Clob_1, bob, numSharesInvalid)
+	require.ErrorIs(t, err, vaulttypes.ErrNegativeShares)
+	got, exists = k.GetOwnerShares(ctx, constants.Vault_Clob_1, bob)
+	require.Equal(t, true, exists)
+	require.Equal(t, numShares, got)
+}
+
 func TestMintShares(t *testing.T) {
 	tests := map[string]struct {
 		/* --- Setup --- */
@@ -89,73 +143,98 @@ func TestMintShares(t *testing.T) {
 		equity *big.Int
 		// Existing vault TotalShares.
 		totalShares *big.Rat
+		// Owner that deposits.
+		owner string
+		// Existing owner shares.
+		ownerShares *big.Rat
 		// Quote quantums to deposit.
 		quantumsToDeposit *big.Int
 
 		/* --- Expectations --- */
 		// Expected TotalShares after minting.
 		expectedTotalShares *big.Rat
+		// Expected OwnerShares after minting.
+		expectedOwnerShares *big.Rat
 		// Expected error.
 		expectedErr error
 	}{
-		"Equity 0, TotalShares 0, Deposit 1000": {
+		"Equity 0, TotalShares 0, OwnerShares 0, Deposit 1000": {
 			vaultId:           constants.Vault_Clob_0,
 			equity:            big.NewInt(0),
 			totalShares:       big.NewRat(0, 1),
+			owner:             constants.AliceAccAddress.String(),
+			ownerShares:       big.NewRat(0, 1),
 			quantumsToDeposit: big.NewInt(1_000),
 			// Should mint `1_000 / 1` shares.
 			expectedTotalShares: big.NewRat(1_000, 1),
+			expectedOwnerShares: big.NewRat(1_000, 1),
 		},
-		"Equity 0, TotalShares non-existent, Deposit 12345654321": {
+		"Equity 0, TotalShares non-existent, OwnerShares non-existent, Deposit 12345654321": {
 			vaultId:           constants.Vault_Clob_0,
 			equity:            big.NewInt(0),
+			owner:             constants.AliceAccAddress.String(),
 			quantumsToDeposit: big.NewInt(12_345_654_321),
 			// Should mint `12_345_654_321 / 1` shares.
 			expectedTotalShares: big.NewRat(12_345_654_321, 1),
+			expectedOwnerShares: big.NewRat(12_345_654_321, 1),
 		},
-		"Equity 1000, TotalShares non-existent, Deposit 500": {
+		"Equity 1000, TotalShares non-existent, OwnerShares non-existent, Deposit 500": {
 			vaultId:           constants.Vault_Clob_0,
 			equity:            big.NewInt(1_000),
+			owner:             constants.AliceAccAddress.String(),
 			quantumsToDeposit: big.NewInt(500),
 			// Should mint `500 / 1` shares.
 			expectedTotalShares: big.NewRat(500, 1),
+			expectedOwnerShares: big.NewRat(500, 1),
 		},
-		"Equity 4000, TotalShares 5000, Deposit 1000": {
+		"Equity 4000, TotalShares 5000, OwnerShares 2500, Deposit 1000": {
 			vaultId:           constants.Vault_Clob_1,
 			equity:            big.NewInt(4_000),
 			totalShares:       big.NewRat(5_000, 1),
+			owner:             constants.AliceAccAddress.String(),
+			ownerShares:       big.NewRat(2_500, 1),
 			quantumsToDeposit: big.NewInt(1_000),
 			// Should mint `1_250 / 1` shares.
 			expectedTotalShares: big.NewRat(6_250, 1),
+			expectedOwnerShares: big.NewRat(3_750, 1),
 		},
-		"Equity 1_000_000, TotalShares 1, Deposit 1": {
+		"Equity 1_000_000, TotalShares 1, OwnerShares 1/2, Deposit 1": {
 			vaultId:           constants.Vault_Clob_1,
 			equity:            big.NewInt(1_000_000),
 			totalShares:       big.NewRat(1, 1),
+			owner:             constants.BobAccAddress.String(),
+			ownerShares:       big.NewRat(1, 2),
 			quantumsToDeposit: big.NewInt(1),
 			// Should mint `1 / 1_000_000` shares.
 			expectedTotalShares: big.NewRat(1_000_001, 1_000_000),
+			expectedOwnerShares: big.NewRat(500_001, 1_000_000),
 		},
-		"Equity 8000, TotalShares 4000, Deposit 455": {
+		"Equity 8000, TotalShares 4000, OwnerShares  Deposit 455": {
 			vaultId:           constants.Vault_Clob_1,
 			equity:            big.NewInt(8_000),
 			totalShares:       big.NewRat(4_000, 1),
+			owner:             constants.CarlAccAddress.String(),
+			ownerShares:       big.NewRat(101, 7),
 			quantumsToDeposit: big.NewInt(455),
 			// Should mint `455 / 2` shares.
 			expectedTotalShares: big.NewRat(8_455, 2),
+			expectedOwnerShares: big.NewRat(3_387, 14),
 		},
-		"Equity 123456, TotalShares 654321, Deposit 123456789": {
+		"Equity 123456, TotalShares 654321, OwnerShares 0, Deposit 123456789": {
 			vaultId:           constants.Vault_Clob_1,
 			equity:            big.NewInt(123_456),
 			totalShares:       big.NewRat(654_321, 1),
+			owner:             constants.DaveAccAddress.String(),
 			quantumsToDeposit: big.NewInt(123_456_789),
 			// Should mint `26_926_789_878_423 / 41_152` shares.
 			expectedTotalShares: big.NewRat(26_953_716_496_215, 41_152),
+			expectedOwnerShares: big.NewRat(26_926_789_878_423, 41_152),
 		},
 		"Equity -1, TotalShares 10, Deposit 1": {
 			vaultId:           constants.Vault_Clob_1,
 			equity:            big.NewInt(-1),
 			totalShares:       big.NewRat(10, 1),
+			owner:             constants.AliceAccAddress.String(),
 			quantumsToDeposit: big.NewInt(1),
 			expectedErr:       vaulttypes.ErrNonPositiveEquity,
 		},
@@ -163,12 +242,14 @@ func TestMintShares(t *testing.T) {
 			vaultId:           constants.Vault_Clob_1,
 			equity:            big.NewInt(1),
 			totalShares:       big.NewRat(1, 1),
+			owner:             constants.AliceAccAddress.String(),
 			quantumsToDeposit: big.NewInt(0),
 			expectedErr:       vaulttypes.ErrInvalidDepositAmount,
 		},
 		"Equity 0, TotalShares non-existent, Deposit -1": {
 			vaultId:           constants.Vault_Clob_1,
 			equity:            big.NewInt(0),
+			owner:             constants.AliceAccAddress.String(),
 			quantumsToDeposit: big.NewInt(-1),
 			expectedErr:       vaulttypes.ErrInvalidDepositAmount,
 		},
@@ -209,12 +290,21 @@ func TestMintShares(t *testing.T) {
 				)
 				require.NoError(t, err)
 			}
+			if tc.ownerShares != nil {
+				err := tApp.App.VaultKeeper.SetOwnerShares(
+					ctx,
+					tc.vaultId,
+					tc.owner,
+					vaulttypes.BigRatToNumShares(tc.ownerShares),
+				)
+				require.NoError(t, err)
+			}
 
 			// Mint shares.
 			err := tApp.App.VaultKeeper.MintShares(
 				ctx,
 				tc.vaultId,
-				"", // TODO (TRA-170): Increase owner shares.
+				tc.owner,
 				tc.quantumsToDeposit,
 			)
 			if tc.expectedErr != nil {
@@ -227,6 +317,9 @@ func TestMintShares(t *testing.T) {
 					vaulttypes.BigRatToNumShares(tc.totalShares),
 					totalShares,
 				)
+				// Check that OwnerShares is unchanged.
+				ownerShares, _ := tApp.App.VaultKeeper.GetOwnerShares(ctx, tc.vaultId, tc.owner)
+				require.Equal(t, vaulttypes.BigRatToNumShares(tc.ownerShares), ownerShares)
 			} else {
 				require.NoError(t, err)
 				// Check that TotalShares is as expected.
@@ -237,6 +330,14 @@ func TestMintShares(t *testing.T) {
 					vaulttypes.BigRatToNumShares(tc.expectedTotalShares),
 					totalShares,
 				)
+				// Check that OwnerShares is as expected.
+				ownerShares, exists := tApp.App.VaultKeeper.GetOwnerShares(ctx, tc.vaultId, tc.owner)
+				require.True(t, exists)
+				require.Equal(
+					t,
+					vaulttypes.BigRatToNumShares(tc.expectedOwnerShares),
+					ownerShares,
+				)
 			}
 		})
 	}
diff --git a/protocol/x/vault/types/keys.go b/protocol/x/vault/types/keys.go
index 4e37abb4d9..1130a77fa0 100644
--- a/protocol/x/vault/types/keys.go
+++ b/protocol/x/vault/types/keys.go
@@ -14,6 +14,10 @@ const (
 	// TotalSharesKeyPrefix is the prefix to retrieve all TotalShares.
 	TotalSharesKeyPrefix = "TotalShares:"
 
+	// OwnerSharesKeyPrefix is the prefix to retrieve all OwnerShares.
+	// OwnerShares store: vaultId VaultId -> owner string -> shares NumShares.
+	OwnerSharesKeyPrefix = "OwnerShares:"
+
 	// ParamsKey is the key to retrieve Params.
 	ParamsKey = "Params"
 )
diff --git a/protocol/x/vault/types/vault_id.go b/protocol/x/vault/types/vault_id.go
index 47b0c7a625..af070d0eae 100644
--- a/protocol/x/vault/types/vault_id.go
+++ b/protocol/x/vault/types/vault_id.go
@@ -20,6 +20,11 @@ func (id *VaultId) ToStateKey() []byte {
 	return []byte(id.ToString())
 }
 
+// ToStateKeyPrefix returns the state key prefix for the vault ID.
+func (id *VaultId) ToStateKeyPrefix() []byte {
+	return []byte(fmt.Sprintf("%s:", id.ToString()))
+}
+
 // GetVaultIdFromStateKey returns a vault ID from a given state key.
 func GetVaultIdFromStateKey(stateKey []byte) (*VaultId, error) {
 	stateKeyStr := string(stateKey)