From 710f33923cbdc6fd3fa320e6b817d3c06a6efa11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Nowosielski?= Date: Mon, 21 Oct 2024 10:17:52 +0200 Subject: [PATCH] Add tests --- rpc/events.go | 114 +++++++++++++++++++++++++++++++++++++++++++-- rpc/events_test.go | 89 +++++++++++++++++++++++++++++++++++ rpc/handlers.go | 7 +++ rpc/transaction.go | 10 ++-- 4 files changed, 211 insertions(+), 9 deletions(-) diff --git a/rpc/events.go b/rpc/events.go index 87963551c5..7fc5dabd61 100644 --- a/rpc/events.go +++ b/rpc/events.go @@ -3,7 +3,6 @@ package rpc import ( "context" "encoding/json" - "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" @@ -23,6 +22,10 @@ type EventsArg struct { ResultPageRequest } +type SubscriptionID struct { + ID uint64 `json:"subscription_id"` +} + type EventFilter struct { FromBlock *BlockID `json:"from_block"` ToBlock *BlockID `json:"to_block"` @@ -53,10 +56,6 @@ type EventsChunk struct { ContinuationToken string `json:"continuation_token,omitempty"` } -type SubscriptionID struct { - ID uint64 `json:"subscription_id"` -} - /**************************************************** Events Handlers *****************************************************/ @@ -372,6 +371,84 @@ func (h *Handler) sendReorg(w jsonrpc.Conn, reorg *sync.ReorgData, id uint64) er return err } +func (h *Handler) SubscribeTxnStatus(ctx context.Context, txHash felt.Felt, _ *BlockID) (*SubscriptionID, *jsonrpc.Error) { + var ( + lastKnownStatus, lastSendStatus *TransactionStatus + wrapResult = func(s *TransactionStatus) *NewTransactionStatus { + return &NewTransactionStatus{ + TransactionHash: &txHash, + Status: s, + } + } + ) + + w, ok := jsonrpc.ConnFromContext(ctx) + if !ok { + return nil, jsonrpc.Err(jsonrpc.MethodNotFound, nil) + } + + id := h.idgen() + subscriptionCtx, subscriptionCtxCancel := context.WithCancel(ctx) + sub := &subscription{ + cancel: subscriptionCtxCancel, + conn: w, + } + + lastKnownStatus, rpcErr := h.TransactionStatus(subscriptionCtx, txHash) + if rpcErr != nil { + h.log.Warnw("Failed to get Tx status", "txHash", &txHash, "rpcErr", rpcErr) + return nil, rpcErr + } + + h.mu.Lock() + h.subscriptions[id] = sub + h.mu.Unlock() + + statusSub := h.txnStatus.Subscribe() + headerSub := h.newHeads.Subscribe() + sub.wg.Go(func() { + defer func() { + h.unsubscribe(sub, id) + statusSub.Unsubscribe() + headerSub.Unsubscribe() + }() + + if err := h.sendTxnStatus(sub.conn, wrapResult(lastKnownStatus), id); err != nil { + h.log.Warnw("Error while sending Txn status", "txHash", txHash) + return + } + lastSendStatus = lastKnownStatus + + for { + select { + case <-subscriptionCtx.Done(): + return + case <-headerSub.Recv(): + lastKnownStatus, rpcErr = h.TransactionStatus(subscriptionCtx, txHash) + if rpcErr != nil { + h.log.Warnw("Failed to get Tx status", "txHash", txHash, "rpcErr", rpcErr) + return + } + + if *lastKnownStatus != *lastSendStatus { + if err := h.sendTxnStatus(sub.conn, wrapResult(lastKnownStatus), id); err != nil { + h.log.Warnw("Error while sending Txn status", "txHash", txHash) + return + } + lastSendStatus = lastKnownStatus + } + + // Stop when final status reached and notified + if isFinal(lastSendStatus) { + return + } + } + } + }) + + return &SubscriptionID{ID: id}, nil +} + func (h *Handler) Unsubscribe(ctx context.Context, id uint64) (bool, *jsonrpc.Error) { w, ok := jsonrpc.ConnFromContext(ctx) if !ok { @@ -489,3 +566,30 @@ func setEventFilterRange(filter *blockchain.EventFilter, fromID, toID *BlockID, } return set(blockchain.EventFilterTo, toID) } + +type NewTransactionStatus struct { + TransactionHash *felt.Felt `json:"transaction_hash"` + Status *TransactionStatus `json:"status"` +} + +// sendHeader creates a request and sends it to the client +func (h *Handler) sendTxnStatus(w jsonrpc.Conn, status *NewTransactionStatus, id uint64) error { + resp, err := json.Marshal(jsonrpc.Request{ + Version: "2.0", + Method: "starknet_subscriptionTransactionsStatus", + Params: map[string]any{ + "subscription_id": id, + "result": status, + }, + }) + if err != nil { + return err + } + h.log.Infow("Sending Txn status", "status", string(resp)) + _, err = w.Write(resp) + return err +} + +func isFinal(status *TransactionStatus) bool { + return status.Finality == TxnStatusRejected || status.Finality == TxnStatusAcceptedOnL1 +} diff --git a/rpc/events_test.go b/rpc/events_test.go index 4f942d0cd5..e4eab0571d 100644 --- a/rpc/events_test.go +++ b/rpc/events_test.go @@ -3,10 +3,15 @@ package rpc_test import ( "context" "fmt" + "io" + "net" "net/http/httptest" "testing" "time" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/mocks" + "go.uber.org/mock/gomock" "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" @@ -648,3 +653,87 @@ func sendAndReceiveMessage(t *testing.T, ctx context.Context, conn *websocket.Co require.NoError(t, err) return string(response) } + +func TestSubscribeTxStatusAndUnsubscribe(t *testing.T) { + t.Parallel() + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + + mockReader := mocks.NewMockReader(mockCtrl) + + syncer := newFakeSyncer() + log, _ := utils.NewZapLogger(utils.INFO, false) + handler := rpc.New(mockReader, syncer, nil, "", log) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + go func() { + require.NoError(t, handler.Run(ctx)) + }() + // Technically, there's a race between goroutine above and the SubscribeNewHeads call down below. + // Sleep for a moment just in case. + time.Sleep(50 * time.Millisecond) + + serverConn, clientConn := net.Pipe() + t.Cleanup(func() { + require.NoError(t, serverConn.Close()) + require.NoError(t, clientConn.Close()) + }) + + txnHash := utils.HexToFelt(t, "0x4c5772d1914fe6ce891b64eb35bf3522aeae1315647314aac58b01137607f3f") + txn := &core.DeployTransaction{TransactionHash: txnHash, Version: (*core.TransactionVersion)(&felt.Zero)} + receipt := &core.TransactionReceipt{ + TransactionHash: txnHash, + Reverted: false, + } + + mockReader.EXPECT().TransactionByHash(txnHash).Return(txn, nil).Times(1) + mockReader.EXPECT().Receipt(txnHash).Return(receipt, nil, uint64(1), nil).Times(1) + mockReader.EXPECT().TransactionByHash(gomock.Any()).Return(nil, db.ErrKeyNotFound).AnyTimes() + + // Subscribe without setting the connection on the context. + id, rpcErr := handler.SubscribeTxnStatus(ctx, felt.Zero, nil) + require.Nil(t, id) + require.Equal(t, jsonrpc.MethodNotFound, rpcErr.Code) + + // Subscribe correctly but for the unknown transaction + subCtx := context.WithValue(ctx, jsonrpc.ConnKey{}, &fakeConn{w: serverConn}) + id, rpcErr = handler.SubscribeTxnStatus(subCtx, felt.Zero, nil) + require.Equal(t, rpc.ErrTxnHashNotFound, rpcErr) + require.Nil(t, id) + + // Subscribe correctly for the known transaction + subCtx = context.WithValue(ctx, jsonrpc.ConnKey{}, &fakeConn{w: serverConn}) + id, rpcErr = handler.SubscribeTxnStatus(subCtx, *txnHash, nil) + require.Nil(t, rpcErr) + + // Receive a block header. + time.Sleep(100 * time.Millisecond) + got := make([]byte, 0, 300) + _, err := clientConn.Read(got) + require.NoError(t, err) + require.Equal(t, "", string(got)) + + // Unsubscribe without setting the connection on the context. + ok, rpcErr := handler.Unsubscribe(ctx, id.ID) + require.Equal(t, jsonrpc.MethodNotFound, rpcErr.Code) + require.False(t, ok) + + // Unsubscribe on correct connection with the incorrect id. + ok, rpcErr = handler.Unsubscribe(subCtx, id.ID+1) + require.Equal(t, rpc.ErrSubscriptionNotFound, rpcErr) + require.False(t, ok) + + // Unsubscribe on incorrect connection with the correct id. + subCtx = context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{}) + ok, rpcErr = handler.Unsubscribe(subCtx, id.ID) + require.Equal(t, rpc.ErrSubscriptionNotFound, rpcErr) + require.False(t, ok) + + // Unsubscribe on correct connection with the correct id. + subCtx = context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn}) + ok, rpcErr = handler.Unsubscribe(subCtx, id.ID) + require.Nil(t, rpcErr) + require.True(t, ok) +} diff --git a/rpc/handlers.go b/rpc/handlers.go index e89540bae0..9bcff14b76 100644 --- a/rpc/handlers.go +++ b/rpc/handlers.go @@ -87,6 +87,7 @@ type Handler struct { newHeads *feed.Feed[*core.Header] reorgs *feed.Feed[*sync.ReorgData] pendingTxs *feed.Feed[[]core.Transaction] + txnStatus *feed.Feed[*NewTransactionStatus] idgen func() uint64 mu stdsync.Mutex // protects subscriptions. @@ -122,6 +123,7 @@ func New(bcReader blockchain.Reader, syncReader sync.Reader, virtualMachine vm.V newHeads: feed.New[*core.Header](), reorgs: feed.New[*sync.ReorgData](), pendingTxs: feed.New[[]core.Transaction](), + txnStatus: feed.New[*NewTransactionStatus](), subscriptions: make(map[uint64]*subscription), blockTraceCache: lru.NewCache[traceCacheKey, []TracedBlockTransaction](traceCacheSize), @@ -336,6 +338,11 @@ func (h *Handler) Methods() ([]jsonrpc.Method, string) { //nolint: funlen Params: []jsonrpc.Parameter{{Name: "transaction_details", Optional: true}, {Name: "sender_address", Optional: true}}, Handler: h.SubscribePendingTxs, }, + { + Name: "starknet_subscribeTransactionStatus", + Params: []jsonrpc.Parameter{{Name: "transaction_hash"}, {Name: "block"}}, + Handler: h.SubscribeTxnStatus, + }, { Name: "juno_unsubscribe", Params: []jsonrpc.Parameter{{Name: "id"}}, diff --git a/rpc/transaction.go b/rpc/transaction.go index b8269b5342..a376493d5b 100644 --- a/rpc/transaction.go +++ b/rpc/transaction.go @@ -224,8 +224,9 @@ type Transaction struct { } type TransactionStatus struct { - Finality TxnStatus `json:"finality_status"` - Execution TxnExecutionStatus `json:"execution_status,omitempty"` + Finality TxnStatus `json:"finality_status"` + Execution TxnExecutionStatus `json:"execution_status,omitempty"` + FailureReason string `json:"failure_reason,omitempty"` } type MsgFromL1 struct { @@ -600,8 +601,9 @@ func (h *Handler) TransactionStatus(ctx context.Context, hash felt.Felt) (*Trans switch txErr { case nil: return &TransactionStatus{ - Finality: TxnStatus(receipt.FinalityStatus), - Execution: receipt.ExecutionStatus, + Finality: TxnStatus(receipt.FinalityStatus), + Execution: receipt.ExecutionStatus, + FailureReason: receipt.RevertReason, }, nil case ErrTxnHashNotFound: if h.feederClient == nil {