Skip to content

Commit

Permalink
snap server implements handler methods directly
Browse files Browse the repository at this point in the history
  • Loading branch information
pnowosie committed Sep 5, 2024
1 parent 18e92f7 commit 7b04b44
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 260 deletions.
199 changes: 114 additions & 85 deletions sync/snap_server.go → p2p/snap_server.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package sync
package p2p

import (
"context"
"google.golang.org/protobuf/proto"
"math/big"

"github.com/NethermindEth/juno/adapters/core2p2p"
Expand Down Expand Up @@ -43,18 +43,21 @@ type ClassRangeStreamingResult struct {
RangeProof *spec.PatriciaRangeProof
}

// TODO: delete, duplicate of SnapProvider
type SnapServer interface {
GetContractRange(ctx context.Context, request *spec.ContractRangeRequest) (iter.Seq[*ContractRangeStreamingResult], error)
GetStorageRange(ctx context.Context, request *StorageRangeRequest) (iter.Seq[*StorageRangeStreamingResult], error)
GetClassRange(ctx context.Context, request *spec.ClassRangeRequest) (iter.Seq[*ClassRangeStreamingResult], error)
GetClasses(ctx context.Context, classHashes []*felt.Felt) ([]*spec.Class, error)
GetClassRange(request *spec.ClassRangeRequest) (iter.Seq[proto.Message], error)
GetContractRange(request *spec.ContractRangeRequest) (iter.Seq[proto.Message], error)
GetStorageRange(request *spec.ContractStorageRequest) (iter.Seq[proto.Message], error)
GetClasses(request *spec.ClassHashesRequest) (iter.Seq[proto.Message], error)
}

type SnapServerBlockchain interface {
GetStateForStateRoot(stateRoot *felt.Felt) (*core.State, error)
GetClasses(felts []*felt.Felt) ([]core.Class, error)
}

type yieldFunc = func(proto.Message) bool

var _ SnapServerBlockchain = (*blockchain.Blockchain)(nil)

func NewSnapServer(blockchain SnapServerBlockchain) SnapServer {
Expand All @@ -67,7 +70,7 @@ type snapServer struct {
blockchain SnapServerBlockchain
}

func determineMaxNodes(specifiedMaxNodes uint64) uint64 {
func determineMaxNodes(specifiedMaxNodes uint32) uint32 {
const (
defaultMaxNodes = 1024 * 16
maxNodePerRequest = 1024 * 1024 // I just want it to process faster
Expand All @@ -84,8 +87,12 @@ func determineMaxNodes(specifiedMaxNodes uint64) uint64 {
return maxNodePerRequest
}

func (b *snapServer) GetClassRange(ctx context.Context, request *spec.ClassRangeRequest) (iter.Seq[*ClassRangeStreamingResult], error) {
return func(yield func(*ClassRangeStreamingResult) bool) {
func (b *snapServer) GetClassRange(request *spec.ClassRangeRequest) (iter.Seq[proto.Message], error) {
var finMsg proto.Message = &spec.ClassRangeResponse{
Responses: &spec.ClassRangeResponse_Fin{},
}

return func(yield yieldFunc) {
stateRoot := p2p2core.AdaptHash(request.Root)

s, err := b.blockchain.GetStateForStateRoot(stateRoot)
Expand Down Expand Up @@ -119,7 +126,7 @@ func (b *snapServer) GetClassRange(ctx context.Context, request *spec.ClassRange
}

classkeys := []*felt.Felt{}
proofs, finished, err := iterateWithLimit(ctrie, startAddr, limitAddr, determineMaxNodes(uint64(request.ChunksPerProof)),
proofs, finished, err := iterateWithLimit(ctrie, startAddr, limitAddr, determineMaxNodes(request.ChunksPerProof),
func(key, value *felt.Felt) error {
classkeys = append(classkeys, key)
return nil
Expand All @@ -143,26 +150,32 @@ func (b *snapServer) GetClassRange(ctx context.Context, request *spec.ClassRange
response.Classes = append(response.Classes, core2p2p.AdaptClass(coreclass))
}

shouldContinue := yield(&ClassRangeStreamingResult{
ContractsRoot: contractRoot,
ClassesRoot: classRoot,
Range: response,
RangeProof: Core2P2pProof(proofs),
})
clsMsg := &spec.ClassRangeResponse{
ContractsRoot: core2p2p.AdaptHash(contractRoot),
ClassesRoot: core2p2p.AdaptHash(classRoot),
Responses: &spec.ClassRangeResponse_Classes{
Classes: response,
},
RangeProof: Core2P2pProof(proofs),
}

shouldContinue := yield(clsMsg)
if finished || !shouldContinue {
break
}
startAddr = classkeys[len(classkeys)-1]
}

yield(finMsg)
}, nil
}

func (b *snapServer) GetContractRange(
ctx context.Context,
request *spec.ContractRangeRequest,
) (iter.Seq[*ContractRangeStreamingResult], error) {
return func(yield func(*ContractRangeStreamingResult) bool) {
func (b *snapServer) GetContractRange(request *spec.ContractRangeRequest) (iter.Seq[proto.Message], error) {
var finMsg proto.Message = &spec.ContractRangeResponse{
Responses: &spec.ContractRangeResponse_Fin{},
}

return func(yield yieldFunc) {
stateRoot := p2p2core.AdaptHash(request.StateRoot)

s, err := b.blockchain.GetStateForStateRoot(stateRoot)
Expand All @@ -189,7 +202,7 @@ func (b *snapServer) GetContractRange(
states := []*spec.ContractState{}

for {
proofs, finished, err := iterateWithLimit(strie, startAddr, limitAddr, determineMaxNodes(uint64(request.ChunksPerProof)),
proofs, finished, err := iterateWithLimit(strie, startAddr, limitAddr, determineMaxNodes(request.ChunksPerProof),
func(key, value *felt.Felt) error {
classHash, err := s.ContractClassHash(key)
if err != nil {
Expand Down Expand Up @@ -225,61 +238,71 @@ func (b *snapServer) GetContractRange(
return
}

shouldContinue := yield(&ContractRangeStreamingResult{
ContractsRoot: contractRoot,
ClassesRoot: classRoot,
Range: states,
cntrMsg := &spec.ContractRangeResponse{
Root: request.StateRoot,
ContractsRoot: core2p2p.AdaptHash(contractRoot),
ClassesRoot: core2p2p.AdaptHash(classRoot),
RangeProof: Core2P2pProof(proofs),
})
Responses: &spec.ContractRangeResponse_Range{
Range: &spec.ContractRange{
State: states,
},
},
}

shouldContinue := yield(cntrMsg)
if finished || !shouldContinue {
break
}
}

yield(finMsg)
}, nil
}

func (b *snapServer) GetStorageRange(ctx context.Context, request *StorageRangeRequest) (iter.Seq[*StorageRangeStreamingResult], error) {
return func(yield func(*StorageRangeStreamingResult) bool) {
stateRoot := request.StateRoot
func (b *snapServer) GetStorageRange(request *spec.ContractStorageRequest) (iter.Seq[proto.Message], error) {
var finMsg proto.Message = &spec.ContractStorageResponse{
Responses: &spec.ContractStorageResponse_Fin{},
}

return func(yield yieldFunc) {
stateRoot := p2p2core.AdaptHash(request.StateRoot)

s, err := b.blockchain.GetStateForStateRoot(stateRoot)
if err != nil {
log.Error("error getting state for state root", "err", err)
return
}

contractRoot, classRoot, err := s.StateAndClassRoot()
if err != nil {
log.Error("error getting state and class root", "err", err)
return
}
var curNodeLimit uint32 = 1000000

var curNodeLimit int64 = 1000000

for _, query := range request.Queries {
if ctxerr := ctx.Err(); ctxerr != nil {
break
}

contractLimit := uint64(curNodeLimit)
for _, query := range request.Query {
contractLimit := curNodeLimit

strie, err := s.StorageTrieForAddr(p2p2core.AdaptAddress(query.Address))
if err != nil {
log.Error("error getting storage trie for address", "addr", query.Address.String(), "err", err)
return
}

handled, err := b.handleStorageRangeRequest(ctx, strie, query, request.ChunkPerProof, contractLimit,
func(values []*spec.ContractStoredValue, proofs []trie.ProofNode) {
yield(&StorageRangeStreamingResult{
ContractsRoot: contractRoot,
ClassesRoot: classRoot,
StorageAddr: p2p2core.AdaptAddress(query.Address),
Range: values,
RangeProof: Core2P2pProof(proofs),
})
handled, err := b.handleStorageRangeRequest(strie, query, request.ChunksPerProof, contractLimit,
func(values []*spec.ContractStoredValue, proofs []trie.ProofNode) bool {
stoMsg := &spec.ContractStorageResponse{
StateRoot: request.StateRoot,
ContractAddress: query.Address,
RangeProof: Core2P2pProof(proofs),
Responses: &spec.ContractStorageResponse_Storage{
Storage: &spec.ContractStorage{
KeyValue: values,
},
},
}
if !yield(stoMsg) {
return false
}
return true
})

if err != nil {
log.Error("error handling storage range request", "err", err)
return
Expand All @@ -291,42 +314,50 @@ func (b *snapServer) GetStorageRange(ctx context.Context, request *StorageRangeR
break
}
}

yield(finMsg)
}, nil
}

// GetStorageRangeStd TODO: move/change 👆 - just to check it can work on spec structs
func (b *snapServer) GetStorageRangeStd(ctx context.Context, request *spec.ContractStorageRequest) (iter.Seq[*StorageRangeStreamingResult], error) {
req := &StorageRangeRequest{
StateRoot: p2p2core.AdaptHash(request.StateRoot),
Queries: request.Query,
ChunkPerProof: uint64(request.ChunksPerProof),
func (b *snapServer) GetClasses(request *spec.ClassHashesRequest) (iter.Seq[proto.Message], error) {
var finMsg proto.Message = &spec.ClassesResponse{
ClassMessage: &spec.ClassesResponse_Fin{},
}
return b.GetStorageRange(ctx, req)
}

func (b *snapServer) GetClasses(ctx context.Context, felts []*felt.Felt) ([]*spec.Class, error) {
classes := make([]*spec.Class, len(felts))
coreClasses, err := b.blockchain.GetClasses(felts)
if err != nil {
return nil, err
}
return func(yield yieldFunc) {
felts := make([]*felt.Felt, len(request.ClassHashes))
for _, hash := range request.ClassHashes {
felts = append(felts, p2p2core.AdaptHash(hash))
}

for i, class := range coreClasses {
classes[i] = core2p2p.AdaptClass(class)
}
coreClasses, err := b.blockchain.GetClasses(felts)
if err != nil {
log.Error("error getting classes", "err", err)
return
}

for _, cls := range coreClasses {
clsMsg := &spec.ClassesResponse{
ClassMessage: &spec.ClassesResponse_Class{
Class: core2p2p.AdaptClass(cls),
},
}
if !yield(clsMsg) {
break
}
}

return classes, nil
yield(finMsg)
}, nil
}

func (b *snapServer) handleStorageRangeRequest(
ctx context.Context,
stTrie *trie.Trie,
request *spec.StorageRangeQuery,
maxChunkPerProof uint64,
nodeLimit uint64,
yield func([]*spec.ContractStoredValue, []trie.ProofNode),
) (int64, error) {
totalSent := int64(0)
maxChunkPerProof, nodeLimit uint32,
yield func([]*spec.ContractStoredValue, []trie.ProofNode) bool,
) (uint32, error) {
totalSent := 0
finished := false
startAddr := p2p2core.AdaptFelt(request.Start.Key)
var endAddr *felt.Felt = nil
Expand All @@ -335,10 +366,6 @@ func (b *snapServer) handleStorageRangeRequest(
}

for !finished {
if ctxErr := ctx.Err(); ctxErr != nil {
return totalSent, ctxErr
}

response := []*spec.ContractStoredValue{}

limit := maxChunkPerProof
Expand All @@ -365,31 +392,33 @@ func (b *snapServer) handleStorageRangeRequest(
finished = true
}

yield(response, proofs)
if !yield(response, proofs) {
finished = true
}

totalSent += int64(len(response))
totalSent += len(response)
nodeLimit -= limit

asBint := startAddr.BigInt(big.NewInt(0))
asBint = asBint.Add(asBint, big.NewInt(1))
startAddr = startAddr.SetBigInt(asBint)
}

return totalSent, nil
return uint32(totalSent), nil
}

func iterateWithLimit(
srcTrie *trie.Trie,
startAddr *felt.Felt,
limitAddr *felt.Felt,
maxNodes uint64,
maxNodes uint32,
consumer func(key, value *felt.Felt) error,
) ([]trie.ProofNode, bool, error) {
pathes := make([]*felt.Felt, 0)
hashes := make([]*felt.Felt, 0)

// TODO: Verify class trie
count := uint64(0)
count := uint32(0)
proof, finished, err := srcTrie.IterateAndGenerateProof(startAddr, func(key *felt.Felt, value *felt.Felt) (bool, error) {
// Need at least one.
if limitAddr != nil && key.Cmp(limitAddr) > 1 && count > 0 {
Expand Down
Loading

0 comments on commit 7b04b44

Please sign in to comment.