Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pnowosie committed Oct 22, 2024
1 parent 66c0c1f commit 710f339
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 9 deletions.
114 changes: 109 additions & 5 deletions rpc/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package rpc
import (
"context"
"encoding/json"

"github.com/NethermindEth/juno/blockchain"
"github.com/NethermindEth/juno/core"
"github.com/NethermindEth/juno/core/felt"
Expand All @@ -23,6 +22,10 @@ type EventsArg struct {
ResultPageRequest
}

type SubscriptionID struct {
ID uint64 `json:"subscription_id"`
}

type EventFilter struct {
FromBlock *BlockID `json:"from_block"`
ToBlock *BlockID `json:"to_block"`
Expand Down Expand Up @@ -53,10 +56,6 @@ type EventsChunk struct {
ContinuationToken string `json:"continuation_token,omitempty"`
}

type SubscriptionID struct {
ID uint64 `json:"subscription_id"`
}

/****************************************************
Events Handlers
*****************************************************/
Expand Down Expand Up @@ -372,6 +371,84 @@ func (h *Handler) sendReorg(w jsonrpc.Conn, reorg *sync.ReorgData, id uint64) er
return err
}

func (h *Handler) SubscribeTxnStatus(ctx context.Context, txHash felt.Felt, _ *BlockID) (*SubscriptionID, *jsonrpc.Error) {
var (
lastKnownStatus, lastSendStatus *TransactionStatus
wrapResult = func(s *TransactionStatus) *NewTransactionStatus {
return &NewTransactionStatus{
TransactionHash: &txHash,
Status: s,
}
}
)

w, ok := jsonrpc.ConnFromContext(ctx)
if !ok {
return nil, jsonrpc.Err(jsonrpc.MethodNotFound, nil)
}

id := h.idgen()
subscriptionCtx, subscriptionCtxCancel := context.WithCancel(ctx)
sub := &subscription{
cancel: subscriptionCtxCancel,
conn: w,
}

lastKnownStatus, rpcErr := h.TransactionStatus(subscriptionCtx, txHash)
if rpcErr != nil {
h.log.Warnw("Failed to get Tx status", "txHash", &txHash, "rpcErr", rpcErr)
return nil, rpcErr
}

h.mu.Lock()
h.subscriptions[id] = sub
h.mu.Unlock()

statusSub := h.txnStatus.Subscribe()
headerSub := h.newHeads.Subscribe()
sub.wg.Go(func() {
defer func() {
h.unsubscribe(sub, id)
statusSub.Unsubscribe()
headerSub.Unsubscribe()
}()

if err := h.sendTxnStatus(sub.conn, wrapResult(lastKnownStatus), id); err != nil {
h.log.Warnw("Error while sending Txn status", "txHash", txHash)
return
}
lastSendStatus = lastKnownStatus

for {
select {
case <-subscriptionCtx.Done():
return
case <-headerSub.Recv():
lastKnownStatus, rpcErr = h.TransactionStatus(subscriptionCtx, txHash)
if rpcErr != nil {
h.log.Warnw("Failed to get Tx status", "txHash", txHash, "rpcErr", rpcErr)
return
}

if *lastKnownStatus != *lastSendStatus {
if err := h.sendTxnStatus(sub.conn, wrapResult(lastKnownStatus), id); err != nil {
h.log.Warnw("Error while sending Txn status", "txHash", txHash)
return
}
lastSendStatus = lastKnownStatus
}

// Stop when final status reached and notified
if isFinal(lastSendStatus) {
return
}
}
}
})

return &SubscriptionID{ID: id}, nil
}

func (h *Handler) Unsubscribe(ctx context.Context, id uint64) (bool, *jsonrpc.Error) {
w, ok := jsonrpc.ConnFromContext(ctx)
if !ok {
Expand Down Expand Up @@ -489,3 +566,30 @@ func setEventFilterRange(filter *blockchain.EventFilter, fromID, toID *BlockID,
}
return set(blockchain.EventFilterTo, toID)
}

type NewTransactionStatus struct {
TransactionHash *felt.Felt `json:"transaction_hash"`
Status *TransactionStatus `json:"status"`
}

// sendHeader creates a request and sends it to the client
func (h *Handler) sendTxnStatus(w jsonrpc.Conn, status *NewTransactionStatus, id uint64) error {
resp, err := json.Marshal(jsonrpc.Request{
Version: "2.0",
Method: "starknet_subscriptionTransactionsStatus",
Params: map[string]any{
"subscription_id": id,
"result": status,
},
})
if err != nil {
return err
}
h.log.Infow("Sending Txn status", "status", string(resp))
_, err = w.Write(resp)
return err
}

func isFinal(status *TransactionStatus) bool {
return status.Finality == TxnStatusRejected || status.Finality == TxnStatusAcceptedOnL1
}
89 changes: 89 additions & 0 deletions rpc/events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@ package rpc_test
import (
"context"
"fmt"
"io"
"net"
"net/http/httptest"
"testing"
"time"

"github.com/NethermindEth/juno/db"
"github.com/NethermindEth/juno/mocks"
"go.uber.org/mock/gomock"
"github.com/NethermindEth/juno/blockchain"
"github.com/NethermindEth/juno/clients/feeder"
"github.com/NethermindEth/juno/core"
Expand Down Expand Up @@ -648,3 +653,87 @@ func sendAndReceiveMessage(t *testing.T, ctx context.Context, conn *websocket.Co
require.NoError(t, err)
return string(response)
}

func TestSubscribeTxStatusAndUnsubscribe(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
t.Cleanup(mockCtrl.Finish)

mockReader := mocks.NewMockReader(mockCtrl)

syncer := newFakeSyncer()
log, _ := utils.NewZapLogger(utils.INFO, false)
handler := rpc.New(mockReader, syncer, nil, "", log)

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

go func() {
require.NoError(t, handler.Run(ctx))
}()
// Technically, there's a race between goroutine above and the SubscribeNewHeads call down below.
// Sleep for a moment just in case.
time.Sleep(50 * time.Millisecond)

serverConn, clientConn := net.Pipe()
t.Cleanup(func() {
require.NoError(t, serverConn.Close())
require.NoError(t, clientConn.Close())
})

txnHash := utils.HexToFelt(t, "0x4c5772d1914fe6ce891b64eb35bf3522aeae1315647314aac58b01137607f3f")
txn := &core.DeployTransaction{TransactionHash: txnHash, Version: (*core.TransactionVersion)(&felt.Zero)}
receipt := &core.TransactionReceipt{
TransactionHash: txnHash,
Reverted: false,
}

mockReader.EXPECT().TransactionByHash(txnHash).Return(txn, nil).Times(1)
mockReader.EXPECT().Receipt(txnHash).Return(receipt, nil, uint64(1), nil).Times(1)
mockReader.EXPECT().TransactionByHash(gomock.Any()).Return(nil, db.ErrKeyNotFound).AnyTimes()

// Subscribe without setting the connection on the context.
id, rpcErr := handler.SubscribeTxnStatus(ctx, felt.Zero, nil)
require.Nil(t, id)
require.Equal(t, jsonrpc.MethodNotFound, rpcErr.Code)

// Subscribe correctly but for the unknown transaction
subCtx := context.WithValue(ctx, jsonrpc.ConnKey{}, &fakeConn{w: serverConn})
id, rpcErr = handler.SubscribeTxnStatus(subCtx, felt.Zero, nil)
require.Equal(t, rpc.ErrTxnHashNotFound, rpcErr)
require.Nil(t, id)

// Subscribe correctly for the known transaction
subCtx = context.WithValue(ctx, jsonrpc.ConnKey{}, &fakeConn{w: serverConn})
id, rpcErr = handler.SubscribeTxnStatus(subCtx, *txnHash, nil)
require.Nil(t, rpcErr)

// Receive a block header.
time.Sleep(100 * time.Millisecond)
got := make([]byte, 0, 300)
_, err := clientConn.Read(got)
require.NoError(t, err)
require.Equal(t, "", string(got))

// Unsubscribe without setting the connection on the context.
ok, rpcErr := handler.Unsubscribe(ctx, id.ID)
require.Equal(t, jsonrpc.MethodNotFound, rpcErr.Code)
require.False(t, ok)

// Unsubscribe on correct connection with the incorrect id.
ok, rpcErr = handler.Unsubscribe(subCtx, id.ID+1)
require.Equal(t, rpc.ErrSubscriptionNotFound, rpcErr)
require.False(t, ok)

// Unsubscribe on incorrect connection with the correct id.
subCtx = context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{})
ok, rpcErr = handler.Unsubscribe(subCtx, id.ID)
require.Equal(t, rpc.ErrSubscriptionNotFound, rpcErr)
require.False(t, ok)

// Unsubscribe on correct connection with the correct id.
subCtx = context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn})
ok, rpcErr = handler.Unsubscribe(subCtx, id.ID)
require.Nil(t, rpcErr)
require.True(t, ok)
}
7 changes: 7 additions & 0 deletions rpc/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ type Handler struct {
newHeads *feed.Feed[*core.Header]
reorgs *feed.Feed[*sync.ReorgData]
pendingTxs *feed.Feed[[]core.Transaction]
txnStatus *feed.Feed[*NewTransactionStatus]

idgen func() uint64
mu stdsync.Mutex // protects subscriptions.
Expand Down Expand Up @@ -122,6 +123,7 @@ func New(bcReader blockchain.Reader, syncReader sync.Reader, virtualMachine vm.V
newHeads: feed.New[*core.Header](),
reorgs: feed.New[*sync.ReorgData](),
pendingTxs: feed.New[[]core.Transaction](),
txnStatus: feed.New[*NewTransactionStatus](),
subscriptions: make(map[uint64]*subscription),

blockTraceCache: lru.NewCache[traceCacheKey, []TracedBlockTransaction](traceCacheSize),
Expand Down Expand Up @@ -336,6 +338,11 @@ func (h *Handler) Methods() ([]jsonrpc.Method, string) { //nolint: funlen
Params: []jsonrpc.Parameter{{Name: "transaction_details", Optional: true}, {Name: "sender_address", Optional: true}},
Handler: h.SubscribePendingTxs,
},
{
Name: "starknet_subscribeTransactionStatus",
Params: []jsonrpc.Parameter{{Name: "transaction_hash"}, {Name: "block"}},
Handler: h.SubscribeTxnStatus,
},
{
Name: "juno_unsubscribe",
Params: []jsonrpc.Parameter{{Name: "id"}},
Expand Down
10 changes: 6 additions & 4 deletions rpc/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,9 @@ type Transaction struct {
}

type TransactionStatus struct {
Finality TxnStatus `json:"finality_status"`
Execution TxnExecutionStatus `json:"execution_status,omitempty"`
Finality TxnStatus `json:"finality_status"`
Execution TxnExecutionStatus `json:"execution_status,omitempty"`
FailureReason string `json:"failure_reason,omitempty"`
}

type MsgFromL1 struct {
Expand Down Expand Up @@ -600,8 +601,9 @@ func (h *Handler) TransactionStatus(ctx context.Context, hash felt.Felt) (*Trans
switch txErr {
case nil:
return &TransactionStatus{
Finality: TxnStatus(receipt.FinalityStatus),
Execution: receipt.ExecutionStatus,
Finality: TxnStatus(receipt.FinalityStatus),
Execution: receipt.ExecutionStatus,
FailureReason: receipt.RevertReason,
}, nil
case ErrTxnHashNotFound:
if h.feederClient == nil {
Expand Down

0 comments on commit 710f339

Please sign in to comment.