Skip to content

Commit

Permalink
change msg type to sdk.msg
Browse files Browse the repository at this point in the history
  • Loading branch information
satawatnack committed May 27, 2024
1 parent 7a4d5c9 commit b4248d1
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 38 deletions.
3 changes: 1 addition & 2 deletions client/client.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package client

import (
oracletypes "github.com/bandprotocol/chain/v2/x/oracle/types"
ctypes "github.com/cometbft/cometbft/rpc/core/types"
"github.com/cosmos/cosmos-sdk/client"
"github.com/cosmos/cosmos-sdk/crypto/keyring"
Expand All @@ -16,6 +15,6 @@ type Client interface {
GetBlockResult(height int64) (*ctypes.ResultBlockResults, error)
QueryRequestFailureReason(id uint64) (string, error)
GetBalance(account sdk.AccAddress) (uint64, error)
SendRequest(msg *oracletypes.MsgRequestData, gasPrice float64, key keyring.Record) (*sdk.TxResponse, error)
SendRequest(msg sdk.Msg, gasPrice float64, key keyring.Record) (*sdk.TxResponse, error)
GetRequestProofByID(reqID uint64) ([]byte, error)
}
15 changes: 7 additions & 8 deletions client/mock/client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion client/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ func (c RPC) GetBalance(_ sdk.AccAddress) (uint64, error) {
return 0, nil
}

func (c RPC) SendRequest(msg *oracletypes.MsgRequestData, gasPrice float64, key keyring.Record) (*sdk.TxResponse, error) {
func (c RPC) SendRequest(msg sdk.Msg, gasPrice float64, key keyring.Record) (*sdk.TxResponse, error) {
// Get account to get nonce of sender first
addr, err := key.GetAddress()
if err != nil {
Expand Down
25 changes: 15 additions & 10 deletions examples/requester/example.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,20 @@ import (

type ChainConfig struct {
ChainID string `yaml:"chain_id" mapstructure:"chain_id"`
RPC string `yaml:"rpc" mapstructure:"rpc"`
Fee string `yaml:"fee" mapstructure:"fee"`
Timeout time.Duration `yaml:"timeout" mapstructure:"timeout"`
RPC string `yaml:"rpc" mapstructure:"rpc"`
Fee string `yaml:"fee" mapstructure:"fee"`
Timeout time.Duration `yaml:"timeout" mapstructure:"timeout"`
}

type RequestConfig struct {
OracleScriptID int `yaml:"oracle_script_id" mapstructure:"oracle_script_id"`
Calldata string `yaml:"calldata" mapstructure:"calldata"`
Mnemonic string `yaml:"mnemonic" mapstructure:"mnemonic"`
Calldata string `yaml:"calldata" mapstructure:"calldata"`
Mnemonic string `yaml:"mnemonic" mapstructure:"mnemonic"`
}

type Config struct {
Chain ChainConfig `yaml:"chain" mapstructure:"chain"`
Request RequestConfig `yaml:"request" mapstructure:"request"`
Chain ChainConfig `yaml:"chain" mapstructure:"chain"`
Request RequestConfig `yaml:"request" mapstructure:"request"`
LogLevel string `yaml:"log_level" mapstructure:"log_level"`
SDK *sdk.Config
}
Expand Down Expand Up @@ -122,7 +122,12 @@ func requestOracleData(
adjustExecuteGasHandler := gas.NewInsufficientExecuteGasHandler(1.3, log)

retrySenderMw := middleware.New(
s.FailedRequestsCh(), senderCh, parser.IntoSenderTaskHandler, retryHandler, delayHandler, adjustPrepareGasHandler,
s.FailedRequestsCh(),
senderCh,
parser.IntoSenderTaskHandler,
retryHandler,
delayHandler,
adjustPrepareGasHandler,
)
retryRequestMw := middleware.New(
w.FailedRequestsCh(),
Expand Down Expand Up @@ -150,7 +155,7 @@ func requestOracleData(
go senderToRequestMw.Start()
go resolveMw.Start()

senderCh <- sender.NewTask(1, msg)
senderCh <- sender.NewTask(1, &msg)

select {
case <-time.After(100 * time.Second):
Expand Down Expand Up @@ -192,7 +197,7 @@ func getSigningResult(

func main() {
// Setup
config_file := GetEnv("CONFIG_FILE", "example_band_laozi.yaml")
config_file := GetEnv("CONFIG_FILE", "example_local.yaml")
config, err := GetConfig(config_file)
if err != nil {
panic(err)
Expand Down
20 changes: 16 additions & 4 deletions requester/middleware/handlers/gas/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gas
import (
"errors"

oracletypes "github.com/bandprotocol/chain/v2/x/oracle/types"
"github.com/bandprotocol/go-band-sdk/requester/middleware"
"github.com/bandprotocol/go-band-sdk/requester/sender"
"github.com/bandprotocol/go-band-sdk/requester/types"
Expand Down Expand Up @@ -31,8 +32,14 @@ func (m *InsufficientPrepareGasHandler) Handle(
req, err := next(ctx)

if errors.Is(&ctx.Err, &types.ErrOutOfPrepareGas) {
req.Msg.PrepareGas = uint64(float64(req.Msg.PrepareGas) * m.gasMultiplier)
m.logger.Info("bump prepare gas", "bumping request %d prepare gas to %d", req.ID(), req.Msg.PrepareGas)
msg, ok := req.Msg.(*oracletypes.MsgRequestData)
if !ok {
return req, errors.New("message type is not MsgRequestData")
}

msg.PrepareGas = uint64(float64(msg.PrepareGas) * m.gasMultiplier)
req.Msg = msg
m.logger.Info("bump prepare gas", "bumping request %d prepare gas to %d", req.ID(), msg.PrepareGas)
}
return req, err
}
Expand All @@ -53,8 +60,13 @@ func (m *InsufficientExecuteGasHandler) Handle(
task, err := next(ctx)

if errors.Is(&ctx.Err, &types.ErrOutOfExecuteGas) {
task.Msg.ExecuteGas = uint64(float64(task.Msg.ExecuteGas) * m.gasMultiplier)
m.logger.Info("bump execute gas", "bumping request %d execute gas to %d", task.ID(), task.Msg.ExecuteGas)
msg, ok := task.Msg.(*oracletypes.MsgRequestData)
if !ok {
return task, errors.New("message type is not MsgRequestData")
}
msg.ExecuteGas = uint64(float64(msg.ExecuteGas) * m.gasMultiplier)
task.Msg = msg
m.logger.Info("bump execute gas", "bumping request %d execute gas to %d", task.ID(), msg.ExecuteGas)
}

return task, err
Expand Down
6 changes: 3 additions & 3 deletions requester/middleware/handlers/retry/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestCounterHandler(t *testing.T) {
factory := setupFactory(3)
handler := retry.NewCounterHandler[types.Task, types.Task](factory)

mockTask := sender.NewTask(uint64(1), oracletypes.MsgRequestData{})
mockTask := sender.NewTask(uint64(1), &oracletypes.MsgRequestData{})

_, err := handler.Handle(
mockTask, func(ctx types.Task) (types.Task, error) {
Expand All @@ -37,7 +37,7 @@ func TestCounterHandlerWithMaxRetry(t *testing.T) {
factory := setupFactory(1)
handler := retry.NewCounterHandler[types.Task, types.Task](factory)

mockTask := sender.NewTask(uint64(1), oracletypes.MsgRequestData{})
mockTask := sender.NewTask(uint64(1), &oracletypes.MsgRequestData{})

parser := func(ctx types.Task) (types.Task, error) {
return ctx, nil
Expand All @@ -55,7 +55,7 @@ func TestResolverHandler(t *testing.T) {
counter := retry.NewCounterHandler[types.Task, types.Task](factory)
resolver := retry.NewResolverHandler[types.Task, types.Task](factory)

mockTask := sender.NewTask(uint64(1), oracletypes.MsgRequestData{})
mockTask := sender.NewTask(uint64(1), &oracletypes.MsgRequestData{})

parser := func(ctx types.Task) (types.Task, error) {
return ctx, nil
Expand Down
14 changes: 12 additions & 2 deletions requester/middleware/parser/parser.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package parser

import (
"fmt"

sdk "github.com/cosmos/cosmos-sdk/types"

oracletypes "github.com/bandprotocol/chain/v2/x/oracle/types"
"github.com/bandprotocol/go-band-sdk/client"
"github.com/bandprotocol/go-band-sdk/requester/sender"
"github.com/bandprotocol/go-band-sdk/requester/watcher/request"
Expand All @@ -13,7 +18,12 @@ func IntoRequestWatcherTaskHandler(ctx sender.SuccessResponse) (request.Task, er
return request.Task{}, err
}

return request.NewTask(ctx.ID(), requestID, ctx.Msg), nil
msg, ok := ctx.Msg.(*oracletypes.MsgRequestData)
if !ok {
return request.Task{}, fmt.Errorf("message type is not MsgRequestData")
}

return request.NewTask(ctx.ID(), requestID, *msg), nil
}

func IntoSenderTaskHandler(ctx sender.FailResponse) (sender.Task, error) {
Expand All @@ -25,5 +35,5 @@ func IntoSigningWatcherTaskHandler(ctx signing.FailResponse) (signing.Task, erro
}

func IntoSenderTaskHandlerFromRequest(ctx request.FailResponse) (sender.Task, error) {
return sender.NewTask(ctx.ID(), ctx.Msg), nil
return sender.NewTask(ctx.ID(), sdk.Msg(&ctx.Msg)), nil
}
5 changes: 2 additions & 3 deletions requester/sender/msgs.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package sender

import (
oracletypes "github.com/bandprotocol/chain/v2/x/oracle/types"
sdk "github.com/cosmos/cosmos-sdk/types"

"github.com/bandprotocol/go-band-sdk/requester/types"
Expand All @@ -15,10 +14,10 @@ var (

type Task struct {
id uint64
Msg oracletypes.MsgRequestData
Msg sdk.Msg
}

func NewTask(id uint64, msg oracletypes.MsgRequestData) Task {
func NewTask(id uint64, msg sdk.Msg) Task {
return Task{
id: id,
Msg: msg,
Expand Down
10 changes: 8 additions & 2 deletions requester/sender/sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/cosmos/cosmos-sdk/crypto/keyring"
sdk "github.com/cosmos/cosmos-sdk/types"

oracletypes "github.com/bandprotocol/chain/v2/x/oracle/types"
"github.com/bandprotocol/go-band-sdk/client"
"github.com/bandprotocol/go-band-sdk/requester/types"
"github.com/bandprotocol/go-band-sdk/utils/logging"
Expand Down Expand Up @@ -93,10 +94,15 @@ func (s *Sender) request(task Task, key keyring.Record) {
s.logger.Error("Sender", "failed to get address from key: %s", err.Error())
return
}
task.Msg.Sender = addr.String()

// Change the sender with the actual sender if the message is Oracle MsgRequestData
if msg, ok := task.Msg.(*oracletypes.MsgRequestData); ok {
msg.Sender = addr.String()
task.Msg = msg
}

// Attempt to send the request
resp, err := s.client.SendRequest(&task.Msg, s.gasPrice, key)
resp, err := s.client.SendRequest(task.Msg, s.gasPrice, key)
// Handle error
if err != nil {
s.logger.Error("Sender", "failed to broadcast task ID(%d) with error: %s", task.ID(), err.Error())
Expand Down
6 changes: 3 additions & 3 deletions requester/sender/sender_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestSenderWithSuccess(t *testing.T) {
mockClient.EXPECT().GetTx("abc").Return(&mockResult, nil).Times(1)

mockLogger := mocklogging.NewLogger()
mockTask := sender.NewTask(1, types.MsgRequestData{})
mockTask := sender.NewTask(1, &types.MsgRequestData{})

// Create channels
requestQueueCh := make(chan sender.Task, 1)
Expand Down Expand Up @@ -120,7 +120,7 @@ func TestSenderWithFailure(t *testing.T) {
mockClient.EXPECT().SendRequest(gomock.Any(), 1.0, gomock.Any()).Return(&mockResult, nil).Times(1)

mockLogger := mocklogging.NewLogger()
mockTask := sender.NewTask(1, types.MsgRequestData{})
mockTask := sender.NewTask(1, &types.MsgRequestData{})

// Create channels
requestQueueCh := make(chan sender.Task, 1)
Expand Down Expand Up @@ -160,7 +160,7 @@ func TestSenderWithClientError(t *testing.T) {
mockClient.EXPECT().SendRequest(gomock.Any(), 1.0, gomock.Any()).Return(nil, fmt.Errorf("error")).Times(1)

mockLogger := mocklogging.NewLogger()
mockTask := sender.NewTask(1, types.MsgRequestData{})
mockTask := sender.NewTask(1, &types.MsgRequestData{})

// Create channels
requestQueueCh := make(chan sender.Task, 1)
Expand Down

0 comments on commit b4248d1

Please sign in to comment.