Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
p-offtermatt committed Dec 11, 2023
1 parent a052b21 commit 60589a1
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 46 deletions.
2 changes: 1 addition & 1 deletion tests/mbt/driver/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
)

const (
P = "provider"
PROVIDER = "provider"
)

// getIndexOfString returns the index of the first occurrence of the given string
Expand Down
23 changes: 13 additions & 10 deletions tests/mbt/driver/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"fmt"
"log"
"math"
"strings"
"testing"
Expand Down Expand Up @@ -125,7 +126,7 @@ func (s *Driver) consumerPower(i int64, chain ChainId) (int64, error) {
// consumerTokens returns the number of tokens that the validator with
// id (ix) i has delegated to it in total on the provider chain
func (s *Driver) providerPower(i int64) (int64, error) {
v, found := s.providerStakingKeeper().GetValidator(s.ctx(P), s.validator(i))
v, found := s.providerStakingKeeper().GetValidator(s.ctx(PROVIDER), s.validator(i))
if !found {
return 0, fmt.Errorf("validator with id %v not found on provider", i)
} else {
Expand All @@ -134,7 +135,7 @@ func (s *Driver) providerPower(i int64) (int64, error) {
}

func (s *Driver) providerValidatorSet() []stakingtypes.Validator {
return s.providerStakingKeeper().GetAllValidators(s.ctx(P))
return s.providerStakingKeeper().GetAllValidators(s.ctx(PROVIDER))
}

func (s *Driver) consumerValidatorSet(chain ChainId) []consumertypes.CrossChainValidator {
Expand All @@ -149,7 +150,7 @@ func (s *Driver) delegate(val, amt int64) {
d := s.delegator()
v := s.validator(val)
msg := stakingtypes.NewMsgDelegate(d, v, coin)
server.Delegate(sdk.WrapSDKContext(s.ctx(P)), msg)
server.Delegate(sdk.WrapSDKContext(s.ctx(PROVIDER)), msg)
}

// undelegate undelegates amt tokens from validator val
Expand All @@ -160,17 +161,19 @@ func (s *Driver) undelegate(val, amt int64) {
d := s.delegator()
v := s.validator(val)
msg := stakingtypes.NewMsgUndelegate(d, v, coin)
server.Undelegate(sdk.WrapSDKContext(s.ctx(P)), msg)
providerStaking.GetAllDelegations(s.ctx(P))
server.Undelegate(sdk.WrapSDKContext(s.ctx(PROVIDER)), msg)
}

// packetQueue returns the queued packets from sender to receiver,
// where either sender or receiver must be the provider.
func (s *Driver) packetQueue(sender, receiver ChainId) []simibc.Packet {
var path *simibc.RelayedPath
if sender == P {
if sender == PROVIDER {
path = s.path(receiver)
} else {
if receiver != PROVIDER {
log.Fatalf("either receiver '%v' or sender '%v' should be provider '%v', but neither is", sender, receiver, PROVIDER)
}
path = s.path(sender)
}
outboxes := path.Outboxes
Expand Down Expand Up @@ -209,7 +212,7 @@ func (s *Driver) getStateString() string {
}

func (s *Driver) isProviderChain(chain ChainId) bool {
return chain == P
return chain == PROVIDER
}

func (s *Driver) getChainStateString(chain ChainId) string {
Expand Down Expand Up @@ -275,7 +278,7 @@ func (s *Driver) getChainStateString(chain ChainId) string {
}

outboxInfo.WriteString("IncomingPackets: \n")
incoming := s.path(chain).Outboxes.OutboxPackets[P]
incoming := s.path(chain).Outboxes.OutboxPackets[PROVIDER]
for _, packet := range incoming {
outboxInfo.WriteString(fmt.Sprintf("%v\n", packet.Packet.String()))
}
Expand All @@ -289,7 +292,7 @@ func (s *Driver) getChainStateString(chain ChainId) string {
}

outboxInfo.WriteString("IncomingAcks: \n")
incomingAcks := s.path(chain).Outboxes.OutboxAcks[P]
incomingAcks := s.path(chain).Outboxes.OutboxAcks[PROVIDER]
for _, packet := range incomingAcks {
outboxInfo.WriteString(fmt.Sprintf("%v\n", packet.Packet.String()))
}
Expand Down Expand Up @@ -375,7 +378,7 @@ func (s *Driver) DeliverPacketToConsumer(recipient ChainId, expectError bool) {
// It updates the client before delivering the packet.
// Since the channel is ordered, the packet that is delivered is the first packet in the outbox.
func (s *Driver) DeliverPacketFromConsumer(sender ChainId, expectError bool) {
s.path(sender).DeliverPackets(P, 1, expectError) // deliver to the provider
s.path(sender).DeliverPackets(PROVIDER, 1, expectError) // deliver to the provider
}

// DeliverAcks delivers, for each path,
Expand Down
12 changes: 6 additions & 6 deletions tests/mbt/driver/mbt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func RunItfTrace(t *testing.T, path string) {
}
consumerChainId := consumer.ChainId

driver.path(ChainId(consumerChainId)).AddClientHeader(Provider, driver.providerHeader())
driver.path(ChainId(consumerChainId)).AddClientHeader(PROVIDER, driver.providerHeader())
err := driver.path(ChainId(consumerChainId)).UpdateClient(consumerChainId, false)
require.True(t, err == nil, "Error updating client from %v on provider: %v", consumerChainId, err)
}
Expand All @@ -289,7 +289,7 @@ func RunItfTrace(t *testing.T, path string) {
// update the client on the provider
consumerHeader := driver.chain(ChainId(consumerChain)).LastHeader
driver.path(ChainId(consumerChain)).AddClientHeader(consumerChain, consumerHeader)
err := driver.path(ChainId(consumerChain)).UpdateClient(Provider, false)
err := driver.path(ChainId(consumerChain)).UpdateClient(PROVIDER, false)
require.True(t, err == nil, "Error updating client from %v on provider: %v", consumerChain, err)

case "DeliverVscPacket":
Expand Down Expand Up @@ -439,8 +439,8 @@ func ComparePacketQueues(
timeOffset time.Time,
) {
t.Helper()
ComparePacketQueue(t, driver, currentModelState, Provider, consumer, timeOffset)
ComparePacketQueue(t, driver, currentModelState, consumer, Provider, timeOffset)
ComparePacketQueue(t, driver, currentModelState, PROVIDER, consumer, timeOffset)
ComparePacketQueue(t, driver, currentModelState, consumer, PROVIDER, timeOffset)
}

func ComparePacketQueue(
Expand Down Expand Up @@ -578,8 +578,8 @@ func (s *Stats) EnterStats(driver *Driver) {
// max number of in-flight packets
inFlightPackets := 0
for _, consumer := range driver.runningConsumers() {
inFlightPackets += len(driver.packetQueue(Provider, ChainId(consumer.ChainId)))
inFlightPackets += len(driver.packetQueue(ChainId(consumer.ChainId), Provider))
inFlightPackets += len(driver.packetQueue(PROVIDER, ChainId(consumer.ChainId)))
inFlightPackets += len(driver.packetQueue(ChainId(consumer.ChainId), PROVIDER))
}
if inFlightPackets > s.maxNumInFlightPackets {
s.maxNumInFlightPackets = inFlightPackets
Expand Down
14 changes: 2 additions & 12 deletions tests/mbt/driver/model_viewer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (
// This file contains logic to process
// and access parts of the current state of the Quint trace.

const Provider = "provider"

func ProviderState(curStateExpr itf.MapExprType) itf.MapExprType {
return curStateExpr["providerState"].Value.(itf.MapExprType)
}
Expand All @@ -18,7 +16,7 @@ func ConsumerState(curStateExpr itf.MapExprType, consumer string) itf.MapExprTyp
}

func State(curStateExpr itf.MapExprType, chain string) itf.MapExprType {
if chain == Provider {
if chain == PROVIDER {
return ProviderState(curStateExpr)
} else {
return ConsumerState(curStateExpr, chain)
Expand All @@ -38,18 +36,14 @@ func HistoricalValidatorSet(curStateExpr itf.MapExprType, chain string, index in
return history[index].Value.(itf.MapExprType)
}

func LastTime(curStateExpr itf.MapExprType, chain string) int64 {
return ChainState(curStateExpr, chain)["lastTimestamp"].Value.(int64)
}

func RunningTime(curStateExpr itf.MapExprType, chain string) int64 {
return ChainState(curStateExpr, chain)["runningTimestamp"].Value.(int64)
}

// PacketQueue returns the queued packets between sender and receiver.
// Either sender or receiver need to be the provider.
func PacketQueue(curStateExpr itf.MapExprType, sender, receiver string) itf.ListExprType {
if sender == Provider {
if sender == PROVIDER {
packetQueue := ProviderState(curStateExpr)["outstandingPacketsToConsumer"].Value.(itf.MapExprType)[receiver]
if packetQueue.Value == nil {
return itf.ListExprType{}
Expand Down Expand Up @@ -80,10 +74,6 @@ func ConsumerStatus(curStateExpr itf.MapExprType, consumer string) string {
return ProviderState(curStateExpr)["consumerStatus"].Value.(itf.MapExprType)[consumer].Value.(string)
}

func LocalClientExpired(curStateExpr itf.MapExprType, consumer string) bool {
return ConsumerState(curStateExpr, consumer)["localClientExpired"].Value.(bool)
}

func GetTimeoutForPacket(packetExpr itf.MapExprType) int64 {
return packetExpr["timeoutTime"].Value.(int64)
}
Expand Down
24 changes: 7 additions & 17 deletions tests/mbt/driver/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ const (
INITIAL_ACCOUNT_BALANCE = 1000000000

// Parameters used in the staking module
StakingParamsMaxEntries = 10000
StakingParamsMaxValidators = 100
STAKING_PARAMS_MAX_ENTRIES = 10000
STAKING_PARAMS_MAX_VALS = 100
)

// Parameters used by CometBFT
var (
ConsensusParams = cmttypes.DefaultConsensusParams()
CONSENSUS_PARAMS = cmttypes.DefaultConsensusParams()
)

// Given a map from node names to voting powers, create a validator set with the right voting powers.
Expand Down Expand Up @@ -111,8 +111,6 @@ func getAppBytesAndSenders(
acc := authtypes.NewBaseAccount(pk.PubKey().Address().Bytes(), pk.PubKey(), uint64(i), 0)

// Give enough funds for many delegations
// Extra units are to delegate to extra validators created later
// in order to bond them and still have INITIAL_DELEGATOR_TOKENS remaining
bal := banktypes.Balance{
Address: acc.GetAddress().String(),
Coins: sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom,
Expand Down Expand Up @@ -211,8 +209,8 @@ func getAppBytesAndSenders(
}

// Set model parameters
genesisStaking.Params.MaxEntries = StakingParamsMaxEntries
genesisStaking.Params.MaxValidators = StakingParamsMaxValidators
genesisStaking.Params.MaxEntries = STAKING_PARAMS_MAX_ENTRIES
genesisStaking.Params.MaxValidators = STAKING_PARAMS_MAX_VALS
genesisStaking.Params.UnbondingTime = modelParams.UnbondingPeriodPerChain[ChainId(chainID)]
genesisStaking = *stakingtypes.NewGenesisState(genesisStaking.Params, stakingValidators, delegations)
genesis[stakingtypes.ModuleName] = app.AppCodec().MustMarshalJSON(&genesisStaking)
Expand Down Expand Up @@ -259,7 +257,7 @@ func newChain(

stateBytes, senderAccounts := getAppBytesAndSenders(chainID, modelParams, app, genesis, validators, nodes, valNames)

protoConsParams := ConsensusParams.ToProto()
protoConsParams := CONSENSUS_PARAMS.ToProto()
app.InitChain(
abcitypes.RequestInitChain{
ChainId: chainID,
Expand Down Expand Up @@ -336,14 +334,6 @@ func (s *Driver) ConfigureNewPath(consumerChain, providerChain *ibctesting.TestC
// Create the Consumer chain ID mapping in the provider state
s.providerKeeper().SetConsumerClientId(providerChain.GetContext(), consumerChain.ChainID, providerEndPoint.ClientID)

// create consumer key assignment
// for _, val := range s.providerValidatorSet(ChainId(providerChain.ChainID)) {
// pubKey, err := val.TmConsPublicKey()
// require.NoError(s.t, err, "Error getting consensus pubkey for validator %v", val)

// err = s.providerKeeper().AssignConsumerKey(providerChain.GetContext(), consumerChain.ChainID, val, pubKey)
// }

// Configure and create the client on the consumer
tmCfg = consumerEndPoint.ClientConfig.(*ibctesting.TendermintConfig)
tmCfg.UnbondingPeriod = params.UnbondingPeriodPerChain[consumerChainId]
Expand Down Expand Up @@ -371,7 +361,7 @@ func (s *Driver) ConfigureNewPath(consumerChain, providerChain *ibctesting.TestC
string(consumerChainId),
consumerGenesisForProvider)

// Client ID is set in InitGenesis and we treat it as a block box. So
// Client ID is set in InitGenesis and we treat it as a black box. So
// must query it to use it with the endpoint.
clientID, _ := s.consumerKeeper(consumerChainId).GetProviderClientID(s.ctx(consumerChainId))
consumerEndPoint.ClientID = clientID
Expand Down

0 comments on commit 60589a1

Please sign in to comment.