diff --git a/CHANGELOG.md b/CHANGELOG.md index 3bea6c343..472c0aad2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ The following emojis are used to highlight certain changes: ### Added +- `routing/http`: added support for address and protocol filtering to the delegated routing server ([IPIP-484](https://github.com/ipfs/specs/pull/484)) [#671](https://github.com/ipfs/boxo/pull/671) + ### Changed ### Removed diff --git a/routing/http/client/client_test.go b/routing/http/client/client_test.go index 590deed11..732861797 100644 --- a/routing/http/client/client_test.go +++ b/routing/http/client/client_test.go @@ -228,9 +228,7 @@ func TestClient_FindProviders(t *testing.T) { } bitswapRecord := makeBitswapRecord() - bitswapProviders := []iter.Result[types.Record]{ - {Val: &bitswapRecord}, - } + peerRecordFromBitswapRecord := types.FromBitswapRecord(&bitswapRecord) cases := []struct { name string @@ -254,8 +252,8 @@ func TestClient_FindProviders(t *testing.T) { }, { name: "happy case (with deprecated bitswap schema)", - routerResult: bitswapProviders, - expResult: bitswapProviders, + routerResult: []iter.Result[types.Record]{{Val: &bitswapRecord}}, + expResult: []iter.Result[types.Record]{{Val: peerRecordFromBitswapRecord}}, expStreamingResponse: true, }, { diff --git a/routing/http/server/filters.go b/routing/http/server/filters.go new file mode 100644 index 000000000..bb5dfa0d5 --- /dev/null +++ b/routing/http/server/filters.go @@ -0,0 +1,197 @@ +package server + +import ( + "reflect" + "slices" + "strings" + + "github.com/ipfs/boxo/routing/http/types" + "github.com/ipfs/boxo/routing/http/types/iter" + "github.com/multiformats/go-multiaddr" +) + +// filters implements IPIP-0484 + +func parseFilter(param string) []string { + if param == "" { + return nil + } + return strings.Split(strings.ToLower(param), ",") +} + +// applyFiltersToIter applies the filters to the given iterator and returns a new iterator. +// +// The function iterates over the input iterator, applying the specified filters to each record. +// It supports both positive and negative filters for both addresses and protocols. +// +// Parameters: +// - recordsIter: An iterator of types.Record to be filtered. +// - filterAddrs: A slice of strings representing the address filter criteria. +// - filterProtocols: A slice of strings representing the protocol filter criteria. +func applyFiltersToIter(recordsIter iter.ResultIter[types.Record], filterAddrs, filterProtocols []string) iter.ResultIter[types.Record] { + mappedIter := iter.Map(recordsIter, func(v iter.Result[types.Record]) iter.Result[types.Record] { + if v.Err != nil || v.Val == nil { + return v + } + + switch v.Val.GetSchema() { + case types.SchemaPeer: + record, ok := v.Val.(*types.PeerRecord) + if !ok { + logger.Errorw("problem casting find providers record", "Schema", v.Val.GetSchema(), "Type", reflect.TypeOf(v).String()) + // drop failed type assertion + return iter.Result[types.Record]{} + } + + record = applyFilters(record, filterAddrs, filterProtocols) + if record == nil { + return iter.Result[types.Record]{} + } + v.Val = record + + //lint:ignore SA1019 // ignore staticcheck + case types.SchemaBitswap: + //lint:ignore SA1019 // ignore staticcheck + record, ok := v.Val.(*types.BitswapRecord) + if !ok { + logger.Errorw("problem casting find providers record", "Schema", v.Val.GetSchema(), "Type", reflect.TypeOf(v).String()) + // drop failed type assertion + return iter.Result[types.Record]{} + } + peerRecord := types.FromBitswapRecord(record) + peerRecord = applyFilters(peerRecord, filterAddrs, filterProtocols) + if peerRecord == nil { + return iter.Result[types.Record]{} + } + v.Val = peerRecord + } + return v + }) + + // filter out nil results and errors + filteredIter := iter.Filter(mappedIter, func(v iter.Result[types.Record]) bool { + return v.Err == nil && v.Val != nil + }) + + return filteredIter +} + +// Applies the filters. Returns nil if the provider does not pass the protocols filter +// The address filter is more complicated because it potentially modifies the Addrs slice. +func applyFilters(provider *types.PeerRecord, filterAddrs, filterProtocols []string) *types.PeerRecord { + if len(filterAddrs) == 0 && len(filterProtocols) == 0 { + return provider + } + + if !protocolsAllowed(provider.Protocols, filterProtocols) { + // If the provider doesn't match any of the passed protocols, the provider is omitted from the response. + return nil + } + + // return untouched if there's no filter or filterAddrsQuery contains "unknown" and provider has no addrs + if len(filterAddrs) == 0 || (len(provider.Addrs) == 0 && slices.Contains(filterAddrs, "unknown")) { + return provider + } + + filteredAddrs := applyAddrFilter(provider.Addrs, filterAddrs) + + // If filtering resulted in no addrs, omit the provider + if len(filteredAddrs) == 0 { + return nil + } + + provider.Addrs = filteredAddrs + return provider +} + +// applyAddrFilter filters a list of multiaddresses based on the provided filter query. +// +// Parameters: +// - addrs: A slice of types.Multiaddr to be filtered. +// - filterAddrsQuery: A slice of strings representing the filter criteria. +// +// The function supports both positive and negative filters: +// - Positive filters (e.g., "tcp", "udp") include addresses that match the specified protocols. +// - Negative filters (e.g., "!tcp", "!udp") exclude addresses that match the specified protocols. +// +// If no filters are provided, the original list of addresses is returned unchanged. +// If only negative filters are provided, addresses not matching any negative filter are included. +// If positive filters are provided, only addresses matching at least one positive filter (and no negative filters) are included. +// If both positive and negative filters are provided, the address must match at least one positive filter and no negative filters to be included. +// +// Returns: +// A new slice of types.Multiaddr containing only the addresses that pass the filter criteria. +func applyAddrFilter(addrs []types.Multiaddr, filterAddrsQuery []string) []types.Multiaddr { + if len(filterAddrsQuery) == 0 { + return addrs + } + + var filteredAddrs []types.Multiaddr + var positiveFilters, negativeFilters []multiaddr.Protocol + + // Separate positive and negative filters + for _, filter := range filterAddrsQuery { + if strings.HasPrefix(filter, "!") { + negativeFilters = append(negativeFilters, multiaddr.ProtocolWithName(filter[1:])) + } else { + positiveFilters = append(positiveFilters, multiaddr.ProtocolWithName(filter)) + } + } + + for _, addr := range addrs { + protocols := addr.Protocols() + + // Check negative filters + if containsAny(protocols, negativeFilters) { + continue + } + + // If no positive filters or matches a positive filter, include the address + if len(positiveFilters) == 0 || containsAny(protocols, positiveFilters) { + filteredAddrs = append(filteredAddrs, addr) + } + } + + return filteredAddrs +} + +// Helper function to check if protocols contain any of the filters +func containsAny(protocols []multiaddr.Protocol, filters []multiaddr.Protocol) bool { + for _, filter := range filters { + if containsProtocol(protocols, filter) { + return true + } + } + return false +} + +func containsProtocol(protos []multiaddr.Protocol, proto multiaddr.Protocol) bool { + for _, p := range protos { + if p.Code == proto.Code { + return true + } + } + return false +} + +// protocolsAllowed returns true if the peerProtocols are allowed by the filter protocols. +func protocolsAllowed(peerProtocols []string, filterProtocols []string) bool { + if len(filterProtocols) == 0 { + // If no filter is passed, do not filter + return true + } + + for _, filterProtocol := range filterProtocols { + if filterProtocol == "unknown" && len(peerProtocols) == 0 { + return true + } + + for _, peerProtocol := range peerProtocols { + if strings.EqualFold(peerProtocol, filterProtocol) { + return true + } + + } + } + return false +} diff --git a/routing/http/server/filters_test.go b/routing/http/server/filters_test.go new file mode 100644 index 000000000..078e4aa96 --- /dev/null +++ b/routing/http/server/filters_test.go @@ -0,0 +1,326 @@ +package server + +import ( + "testing" + + "github.com/ipfs/boxo/routing/http/types" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestApplyAddrFilter(t *testing.T) { + // Create some test multiaddrs + addr1, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001/p2p/QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt") + addr2, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/udp/4001/quic/p2p/QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt") + addr3, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001/ws/p2p/QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt") + addr4, _ := multiaddr.NewMultiaddr("/ip4/102.101.1.1/tcp/4001/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit/p2p/QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt") + addr5, _ := multiaddr.NewMultiaddr("/ip4/102.101.1.1/udp/4001/quic-v1/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit/p2p/QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt") + addr6, _ := multiaddr.NewMultiaddr("/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/certhash/uEiD9f05PrY82lovP4gOFonmY7sO0E7_jyovt9p2LEcAS-Q/certhash/uEiBtGJsNz-PcywwXOVzEYeQQloQiHMqDqdj18t2Fe4GTLQ/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit/p2p/QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt") + addr7, _ := multiaddr.NewMultiaddr("/dns4/ny5.bootstrap.libp2p.io/tcp/443/wss/p2p/QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt") + addr8, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/udp/4001/quic-v1/webtransport/certhash/uEiAMrMcVWFNiqtSeRXZTwHTac4p9WcGh5hg8kVBzTC1JTA/certhash/uEiA4dfvbbbnBIYalhp1OpW1Bk-nuWIKSy21ol6vPea67Cw/p2p/QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt") + + addrs := []types.Multiaddr{ + {Multiaddr: addr1}, + {Multiaddr: addr2}, + {Multiaddr: addr3}, + {Multiaddr: addr4}, + {Multiaddr: addr5}, + {Multiaddr: addr6}, + {Multiaddr: addr7}, + {Multiaddr: addr8}, + } + + testCases := []struct { + name string + filterAddrs []string + expectedAddrs []types.Multiaddr + }{ + { + name: "No filter", + filterAddrs: []string{}, + expectedAddrs: addrs, + }, + { + name: "Filter TCP", + filterAddrs: []string{"tcp"}, + expectedAddrs: []types.Multiaddr{{Multiaddr: addr1}, {Multiaddr: addr3}, {Multiaddr: addr4}, {Multiaddr: addr7}}, + }, + { + name: "Filter UDP", + filterAddrs: []string{"udp"}, + expectedAddrs: []types.Multiaddr{{Multiaddr: addr2}, {Multiaddr: addr5}, {Multiaddr: addr6}, {Multiaddr: addr8}}, + }, + { + name: "Filter WebSocket", + filterAddrs: []string{"ws"}, + expectedAddrs: []types.Multiaddr{{Multiaddr: addr3}}, + }, + { + name: "Exclude TCP", + filterAddrs: []string{"!tcp"}, + expectedAddrs: []types.Multiaddr{{Multiaddr: addr2}, {Multiaddr: addr5}, {Multiaddr: addr6}, {Multiaddr: addr8}}, + }, + { + name: "Filter TCP addresses that don't have WebSocket and p2p-circuit", + filterAddrs: []string{"tcp", "!ws", "!wss", "!p2p-circuit"}, + expectedAddrs: []types.Multiaddr{{Multiaddr: addr1}}, + }, + { + name: "Include WebTransport and exclude p2p-circuit", + filterAddrs: []string{"webtransport", "!p2p-circuit"}, + expectedAddrs: []types.Multiaddr{{Multiaddr: addr8}}, + }, + { + name: "empty for unknown protocol nae", + filterAddrs: []string{"fakeproto"}, + expectedAddrs: []types.Multiaddr{}, + }, + { + name: "Include WebTransport but ignore unknown protocol name", + filterAddrs: []string{"webtransport", "fakeproto"}, + expectedAddrs: []types.Multiaddr{{Multiaddr: addr6}, {Multiaddr: addr8}}, + }, + { + name: "Multiple filters", + filterAddrs: []string{"tcp", "ws"}, + expectedAddrs: []types.Multiaddr{{Multiaddr: addr1}, {Multiaddr: addr3}, {Multiaddr: addr4}, {Multiaddr: addr7}}, + }, + { + name: "Multiple negative filters", + filterAddrs: []string{"!tcp", "!ws"}, + expectedAddrs: []types.Multiaddr{{Multiaddr: addr2}, {Multiaddr: addr5}, {Multiaddr: addr6}, {Multiaddr: addr8}}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := applyAddrFilter(addrs, tc.filterAddrs) + assert.Equal(t, len(tc.expectedAddrs), len(result), "Unexpected number of addresses after filtering") + + // Check that each expected address is in the result + for _, expectedAddr := range tc.expectedAddrs { + found := false + for _, resultAddr := range result { + if expectedAddr.Multiaddr.Equal(resultAddr.Multiaddr) { + found = true + break + } + } + assert.True(t, found, "Expected address not found in test %s result: %s", tc.name, expectedAddr.Multiaddr) + } + + // Check that each result address is in the expected list + for _, resultAddr := range result { + found := false + for _, expectedAddr := range tc.expectedAddrs { + if resultAddr.Multiaddr.Equal(expectedAddr.Multiaddr) { + found = true + break + } + } + assert.True(t, found, "Unexpected address found in test %s result: %s", tc.name, resultAddr.Multiaddr) + } + }) + } +} + +func TestProtocolsAllowed(t *testing.T) { + testCases := []struct { + name string + peerProtocols []string + filterProtocols []string + expected bool + }{ + { + name: "No filter", + peerProtocols: []string{"transport-bitswap", "transport-ipfs-gateway-http"}, + filterProtocols: []string{}, + expected: true, + }, + { + name: "Single matching protocol", + peerProtocols: []string{"transport-bitswap", "transport-ipfs-gateway-http"}, + filterProtocols: []string{"transport-bitswap"}, + expected: true, + }, + { + name: "Single non-matching protocol", + peerProtocols: []string{"transport-bitswap", "transport-ipfs-gateway-http"}, + filterProtocols: []string{"transport-graphsync-filecoinv1"}, + expected: false, + }, + { + name: "Multiple protocols, one match", + peerProtocols: []string{"transport-bitswap", "transport-ipfs-gateway-http"}, + filterProtocols: []string{"transport-graphsync-filecoinv1", "transport-ipfs-gateway-http"}, + expected: true, + }, + { + name: "Unknown protocol for empty peer protocols", + peerProtocols: []string{}, + filterProtocols: []string{"unknown"}, + expected: true, + }, + { + name: "Unknown protocol for non-empty peer protocols", + peerProtocols: []string{"transport-bitswap"}, + filterProtocols: []string{"unknown"}, + expected: false, + }, + { + name: "Unknown or specific protocol for matching non-empty peer protocols", + peerProtocols: []string{"transport-bitswap"}, + filterProtocols: []string{"unknown", "transport-bitswap", "transport-ipfs-gateway-http"}, + expected: true, + }, + { + name: "Unknown or specific protocol for matching empty peer protocols", + peerProtocols: []string{}, + filterProtocols: []string{"unknown", "transport-bitswap", "transport-ipfs-gateway-http"}, + expected: true, + }, + { + name: "Unknown or specific protocol for not matching non-empty peer protocols", + peerProtocols: []string{"transport-graphsync-filecoinv1"}, + filterProtocols: []string{"unknown", "transport-bitswap", "transport-ipfs-gateway-http"}, + expected: false, + }, + { + name: "Case insensitive match", + peerProtocols: []string{"TRANSPORT-BITSWAP", "Transport-IPFS-Gateway-HTTP"}, + filterProtocols: []string{"transport-bitswap"}, + expected: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := protocolsAllowed(tc.peerProtocols, tc.filterProtocols) + assert.Equal(t, tc.expected, result, "Unexpected result for test case: %s", tc.name) + }) + } +} + +func TestApplyFilters(t *testing.T) { + pid, err := peer.Decode("12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn") + require.NoError(t, err) + + tests := []struct { + name string + provider *types.PeerRecord + filterAddrs []string + filterProtocols []string + expected *types.PeerRecord + }{ + { + name: "No filters", + provider: &types.PeerRecord{ + ID: &pid, + Addrs: []types.Multiaddr{ + mustMultiaddr(t, "/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"), + }, + Protocols: []string{"transport-ipfs-gateway-http"}, + }, + filterAddrs: []string{}, + filterProtocols: []string{}, + expected: &types.PeerRecord{ + ID: &pid, + Addrs: []types.Multiaddr{ + mustMultiaddr(t, "/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"), + }, + Protocols: []string{"transport-ipfs-gateway-http"}, + }, + }, + { + name: "Protocol filter", + provider: &types.PeerRecord{ + ID: &pid, + Addrs: []types.Multiaddr{ + mustMultiaddr(t, "/ip4/127.0.0.1/tcp/4001"), + mustMultiaddr(t, "/ip4/127.0.0.1/udp/4001/quic-v1"), + mustMultiaddr(t, "/ip4/127.0.0.1/tcp/4001/ws"), + mustMultiaddr(t, "/ip4/102.101.1.1/tcp/4001/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"), + }, + Protocols: []string{"transport-ipfs-gateway-http"}, + }, + filterAddrs: []string{}, + filterProtocols: []string{"transport-ipfs-gateway-http", "transport-bitswap"}, + expected: &types.PeerRecord{ + ID: &pid, + Addrs: []types.Multiaddr{ + mustMultiaddr(t, "/ip4/127.0.0.1/tcp/4001"), + mustMultiaddr(t, "/ip4/127.0.0.1/udp/4001/quic-v1"), + mustMultiaddr(t, "/ip4/127.0.0.1/tcp/4001/ws"), + mustMultiaddr(t, "/ip4/102.101.1.1/tcp/4001/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"), + }, + Protocols: []string{"transport-ipfs-gateway-http"}, + }, + }, + { + name: "Address filter", + provider: &types.PeerRecord{ + ID: &pid, + Addrs: []types.Multiaddr{ + mustMultiaddr(t, "/ip4/127.0.0.1/tcp/4001"), + mustMultiaddr(t, "/ip4/127.0.0.1/udp/4001/quic-v1"), + mustMultiaddr(t, "/ip4/127.0.0.1/tcp/4001/ws"), + mustMultiaddr(t, "/ip4/127.0.0.1/udp/4001/webrtc-direct/certhash/uEiCZqN653gMqxrWNmYuNg7Emwb-wvtsuzGE3XD6rypViZA"), + mustMultiaddr(t, "/ip4/102.101.1.1/tcp/4001/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"), + }, + Protocols: []string{"transport-ipfs-gateway-http"}, + }, + filterAddrs: []string{"webtransport", "wss", "webrtc-direct", "!p2p-circuit"}, + filterProtocols: []string{"transport-ipfs-gateway-http", "transport-bitswap"}, + expected: &types.PeerRecord{ + ID: &pid, + Addrs: []types.Multiaddr{ + mustMultiaddr(t, "/ip4/127.0.0.1/udp/4001/webrtc-direct/certhash/uEiCZqN653gMqxrWNmYuNg7Emwb-wvtsuzGE3XD6rypViZA"), + mustMultiaddr(t, "/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"), + }, + Protocols: []string{"transport-ipfs-gateway-http"}, + }, + }, + { + name: "Unknown protocol filter", + provider: &types.PeerRecord{ + ID: &pid, + Addrs: []types.Multiaddr{ + mustMultiaddr(t, "/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"), + }, + }, + filterAddrs: []string{}, + filterProtocols: []string{"unknown"}, + expected: &types.PeerRecord{ + ID: &pid, + Addrs: []types.Multiaddr{ + mustMultiaddr(t, "/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"), + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := applyFilters(tt.provider, tt.filterAddrs, tt.filterProtocols) + assert.Equal(t, tt.expected, result) + }) + } +} + +func mustMultiaddr(t *testing.T, s string) types.Multiaddr { + addr, err := multiaddr.NewMultiaddr(s) + if err != nil { + t.Fatalf("Failed to create multiaddr: %v", err) + } + return types.Multiaddr{Multiaddr: addr} +} diff --git a/routing/http/server/server.go b/routing/http/server/server.go index 1e1a84770..3cdcefee0 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -194,6 +194,11 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { return } + // Parse query parameters + query := httpReq.URL.Query() + filterAddrs := parseFilter(query.Get("filter-addrs")) + filterProtocols := parseFilter(query.Get("filter-protocols")) + mediaType, err := s.detectResponseType(httpReq) if err != nil { writeErr(w, "FindProviders", http.StatusBadRequest, err) @@ -201,7 +206,7 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { } var ( - handlerFunc func(w http.ResponseWriter, provIter iter.ResultIter[types.Record]) + handlerFunc func(w http.ResponseWriter, provIter iter.ResultIter[types.Record], filterAddrs, filterProtocols []string) recordsLimit int ) @@ -224,13 +229,14 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { } } - handlerFunc(w, provIter) + handlerFunc(w, provIter, filterAddrs, filterProtocols) } -func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ResultIter[types.Record]) { +func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ResultIter[types.Record], filterAddrs, filterProtocols []string) { defer provIter.Close() - providers, err := iter.ReadAllResults(provIter) + filteredIter := applyFiltersToIter(provIter, filterAddrs, filterProtocols) + providers, err := iter.ReadAllResults(filteredIter) if err != nil { writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error: %w", err)) return @@ -240,9 +246,10 @@ func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ResultIt Providers: providers, }) } +func (s *server) findProvidersNDJSON(w http.ResponseWriter, provIter iter.ResultIter[types.Record], filterAddrs, filterProtocols []string) { + filteredIter := applyFiltersToIter(provIter, filterAddrs, filterProtocols) -func (s *server) findProvidersNDJSON(w http.ResponseWriter, provIter iter.ResultIter[types.Record]) { - writeResultsIterNDJSON(w, provIter) + writeResultsIterNDJSON(w, filteredIter) } func (s *server) findPeers(w http.ResponseWriter, r *http.Request) { @@ -277,6 +284,10 @@ func (s *server) findPeers(w http.ResponseWriter, r *http.Request) { return } + query := r.URL.Query() + filterAddrs := parseFilter(query.Get("filter-addrs")) + filterProtocols := parseFilter(query.Get("filter-protocols")) + mediaType, err := s.detectResponseType(r) if err != nil { writeErr(w, "FindPeers", http.StatusBadRequest, err) @@ -284,7 +295,7 @@ func (s *server) findPeers(w http.ResponseWriter, r *http.Request) { } var ( - handlerFunc func(w http.ResponseWriter, provIter iter.ResultIter[*types.PeerRecord]) + handlerFunc func(w http.ResponseWriter, provIter iter.ResultIter[*types.PeerRecord], filterAddrs, filterProtocols []string) recordsLimit int ) @@ -307,7 +318,7 @@ func (s *server) findPeers(w http.ResponseWriter, r *http.Request) { } } - handlerFunc(w, provIter) + handlerFunc(w, provIter, filterAddrs, filterProtocols) } func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) { @@ -369,10 +380,33 @@ func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) { writeJSONResult(w, "Provide", resp) } -func (s *server) findPeersJSON(w http.ResponseWriter, peersIter iter.ResultIter[*types.PeerRecord]) { +func (s *server) findPeersJSON(w http.ResponseWriter, peersIter iter.ResultIter[*types.PeerRecord], filterAddrs, filterProtocols []string) { defer peersIter.Close() - peers, err := iter.ReadAllResults(peersIter) + // Convert PeerRecord to Record so that we can reuse the filtering logic from findProviders + mappedIter := iter.Map(peersIter, func(v iter.Result[*types.PeerRecord]) iter.Result[types.Record] { + if v.Err != nil || v.Val == nil { + return iter.Result[types.Record]{Err: v.Err} + } + + var record types.Record = v.Val + return iter.Result[types.Record]{Val: record} + }) + + filteredIter := applyFiltersToIter(mappedIter, filterAddrs, filterProtocols) + + // Convert Record back to PeerRecord 🙃 + finalIter := iter.Map(filteredIter, func(v iter.Result[types.Record]) iter.Result[*types.PeerRecord] { + if v.Err != nil || v.Val == nil { + return iter.Result[*types.PeerRecord]{Err: v.Err} + } + + var record *types.PeerRecord = v.Val.(*types.PeerRecord) + return iter.Result[*types.PeerRecord]{Val: record} + }) + + peers, err := iter.ReadAllResults(finalIter) + if err != nil { writeErr(w, "FindPeers", http.StatusInternalServerError, fmt.Errorf("delegate error: %w", err)) return @@ -383,8 +417,19 @@ func (s *server) findPeersJSON(w http.ResponseWriter, peersIter iter.ResultIter[ }) } -func (s *server) findPeersNDJSON(w http.ResponseWriter, peersIter iter.ResultIter[*types.PeerRecord]) { - writeResultsIterNDJSON(w, peersIter) +func (s *server) findPeersNDJSON(w http.ResponseWriter, peersIter iter.ResultIter[*types.PeerRecord], filterAddrs, filterProtocols []string) { + // Convert PeerRecord to Record so that we can reuse the filtering logic from findProviders + mappedIter := iter.Map(peersIter, func(v iter.Result[*types.PeerRecord]) iter.Result[types.Record] { + if v.Err != nil || v.Val == nil { + return iter.Result[types.Record]{Err: v.Err} + } + + var record types.Record = v.Val + return iter.Result[types.Record]{Val: record} + }) + + filteredIter := applyFiltersToIter(mappedIter, filterAddrs, filterProtocols) + writeResultsIterNDJSON(w, filteredIter) } func (s *server) GetIPNS(w http.ResponseWriter, r *http.Request) { @@ -572,7 +617,7 @@ func logErr(method, msg string, err error) { logger.Infow(msg, "Method", method, "Error", err) } -func writeResultsIterNDJSON[T any](w http.ResponseWriter, resultIter iter.ResultIter[T]) { +func writeResultsIterNDJSON[T types.Record](w http.ResponseWriter, resultIter iter.ResultIter[T]) { defer resultIter.Close() w.Header().Set("Content-Type", mediaTypeNDJSON) diff --git a/routing/http/server/server_test.go b/routing/http/server/server_test.go index 3f4e7906a..772f79999 100644 --- a/routing/http/server/server_test.go +++ b/routing/http/server/server_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/rand" + "fmt" "io" "net/http" "net/http/httptest" @@ -22,6 +23,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/routing" b58 "github.com/mr-tron/base58/base58" + "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -93,6 +95,13 @@ func TestProviders(t *testing.T) { pid2Str := "12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz" cidStr := "bafkreifjjcie6lypi6ny7amxnfftagclbuxndqonfipmb64f2km2devei4" + addr1, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001") + addr2, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/udp/4001/quic-v1") + addr3, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001/ws") + addr4, _ := multiaddr.NewMultiaddr("/ip4/102.101.1.1/tcp/4001/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit") + addr5, _ := multiaddr.NewMultiaddr("/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit") + addr6, _ := multiaddr.NewMultiaddr("/ip4/8.8.8.8/udp/4001/quic-v1/webtransport") + pid, err := peer.Decode(pidStr) require.NoError(t, err) pid2, err := peer.Decode(pid2Str) @@ -101,7 +110,7 @@ func TestProviders(t *testing.T) { cid, err := cid.Decode(cidStr) require.NoError(t, err) - runTest := func(t *testing.T, contentType string, empty bool, expectedStream bool, expectedBody string) { + runTest := func(t *testing.T, contentType string, filterAddrs, filterProtocols string, empty bool, expectedStream bool, expectedBody string) { t.Parallel() var results *iter.SliceIter[iter.Result[types.Record]] @@ -114,16 +123,22 @@ func TestProviders(t *testing.T) { Schema: types.SchemaPeer, ID: &pid, Protocols: []string{"transport-bitswap"}, + Addrs: []types.Multiaddr{ + {Multiaddr: addr1}, + {Multiaddr: addr2}, + {Multiaddr: addr3}, + {Multiaddr: addr4}, + {Multiaddr: addr5}, + {Multiaddr: addr6}, + }, + }}, + {Val: &types.PeerRecord{ + Schema: types.SchemaPeer, + ID: &pid2, + Protocols: []string{"transport-ipfs-gateway-http"}, Addrs: []types.Multiaddr{}, }}, - //lint:ignore SA1019 // ignore staticcheck - {Val: &types.BitswapRecord{ - //lint:ignore SA1019 // ignore staticcheck - Schema: types.SchemaBitswap, - ID: &pid2, - Protocol: "transport-bitswap", - Addrs: []types.Multiaddr{}, - }}}, + }, ) } @@ -136,7 +151,17 @@ func TestProviders(t *testing.T) { limit = DefaultStreamingRecordsLimit } router.On("FindProviders", mock.Anything, cid, limit).Return(results, nil) - urlStr := serverAddr + "/routing/v1/providers/" + cidStr + + urlStr := fmt.Sprintf("%s/routing/v1/providers/%s", serverAddr, cidStr) + if filterAddrs != "" || filterProtocols != "" { + urlStr += "?" + if filterAddrs != "" { + urlStr = fmt.Sprintf("%s&filter-addrs=%s", urlStr, filterAddrs) + } + if filterProtocols != "" { + urlStr = fmt.Sprintf("%s&filter-protocols=%s", urlStr, filterProtocols) + } + } req, err := http.NewRequest(http.MethodGet, urlStr, nil) require.NoError(t, err) @@ -174,29 +199,55 @@ func TestProviders(t *testing.T) { } t.Run("JSON Response", func(t *testing.T) { - runTest(t, mediaTypeJSON, false, false, `{"Providers":[{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"},{"Schema":"bitswap","Protocol":"transport-bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"}]}`) + runTest(t, mediaTypeJSON, "", "", false, false, `{"Providers":[{"Addrs":["/ip4/127.0.0.1/tcp/4001","/ip4/127.0.0.1/udp/4001/quic-v1","/ip4/127.0.0.1/tcp/4001/ws","/ip4/102.101.1.1/tcp/4001/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit","/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit","/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"},{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Protocols":["transport-ipfs-gateway-http"],"Schema":"peer"}]}`) + }) + + t.Run("JSON Response with addr filtering including unknown", func(t *testing.T) { + runTest(t, mediaTypeJSON, "webtransport,!p2p-circuit,unknown", "", false, false, `{"Providers":[{"Addrs":["/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"},{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Protocols":["transport-ipfs-gateway-http"],"Schema":"peer"}]}`) + }) + + t.Run("JSON Response with addr filtering", func(t *testing.T) { + runTest(t, mediaTypeJSON, "webtransport,!p2p-circuit", "", false, false, `{"Providers":[{"Addrs":["/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"}]}`) + }) + + t.Run("JSON Response with protocol and addr filtering", func(t *testing.T) { + runTest(t, mediaTypeJSON, "quic-v1", "transport-bitswap", false, false, + `{"Providers":[{"Addrs":["/ip4/127.0.0.1/udp/4001/quic-v1","/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit","/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"}]}`) + }) + + t.Run("JSON Response with protocol filtering", func(t *testing.T) { + runTest(t, mediaTypeJSON, "", "transport-ipfs-gateway-http", false, false, + `{"Providers":[{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Protocols":["transport-ipfs-gateway-http"],"Schema":"peer"}]}`) }) t.Run("Empty JSON Response", func(t *testing.T) { - runTest(t, mediaTypeJSON, true, false, `{"Providers":null}`) + runTest(t, mediaTypeJSON, "", "", true, false, `{"Providers":null}`) }) t.Run("Wildcard Accept header defaults to JSON Response", func(t *testing.T) { accept := "text/html,*/*" - runTest(t, accept, true, false, `{"Providers":null}`) + runTest(t, accept, "", "", true, false, `{"Providers":null}`) }) t.Run("Missing Accept header defaults to JSON Response", func(t *testing.T) { accept := "" - runTest(t, accept, true, false, `{"Providers":null}`) + runTest(t, accept, "", "", true, false, `{"Providers":null}`) }) t.Run("NDJSON Response", func(t *testing.T) { - runTest(t, mediaTypeNDJSON, false, true, `{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"}`+"\n"+`{"Schema":"bitswap","Protocol":"transport-bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"}`+"\n") + runTest(t, mediaTypeNDJSON, "", "", false, true, `{"Addrs":["/ip4/127.0.0.1/tcp/4001","/ip4/127.0.0.1/udp/4001/quic-v1","/ip4/127.0.0.1/tcp/4001/ws","/ip4/102.101.1.1/tcp/4001/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit","/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit","/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"}`+"\n"+`{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Protocols":["transport-ipfs-gateway-http"],"Schema":"peer"}`+"\n") + }) + + t.Run("NDJSON Response with addr filtering", func(t *testing.T) { + runTest(t, mediaTypeNDJSON, "webtransport,!p2p-circuit,unknown", "", false, true, `{"Addrs":["/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"}`+"\n"+`{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Protocols":["transport-ipfs-gateway-http"],"Schema":"peer"}`+"\n") + }) + + t.Run("NDJSON Response with addr filtering", func(t *testing.T) { + runTest(t, mediaTypeNDJSON, "webtransport,!p2p-circuit,unknown", "", false, true, `{"Addrs":["/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"}`+"\n"+`{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Protocols":["transport-ipfs-gateway-http"],"Schema":"peer"}`+"\n") }) t.Run("Empty NDJSON Response", func(t *testing.T) { - runTest(t, mediaTypeNDJSON, true, true, "") + runTest(t, mediaTypeNDJSON, "", "", true, true, "") }) t.Run("404 when router returns routing.ErrNotFound", func(t *testing.T) { @@ -217,10 +268,21 @@ func TestProviders(t *testing.T) { } func TestPeers(t *testing.T) { - makeRequest := func(t *testing.T, router *mockContentRouter, contentType, arg string) *http.Response { + makeRequest := func(t *testing.T, router *mockContentRouter, contentType, arg, filterAddrs, filterProtocols string) *http.Response { server := httptest.NewServer(Handler(router)) t.Cleanup(server.Close) - req, err := http.NewRequest(http.MethodGet, "http://"+server.Listener.Addr().String()+"/routing/v1/peers/"+arg, nil) + + urlStr := fmt.Sprintf("http://%s/routing/v1/peers/%s", server.Listener.Addr().String(), arg) + if filterAddrs != "" || filterProtocols != "" { + urlStr += "?" + if filterAddrs != "" { + urlStr = fmt.Sprintf("%s&filter-addrs=%s", urlStr, filterAddrs) + } + if filterProtocols != "" { + urlStr = fmt.Sprintf("%s&filter-protocols=%s", urlStr, filterProtocols) + } + } + req, err := http.NewRequest(http.MethodGet, urlStr, nil) require.NoError(t, err) if contentType != "" { req.Header.Set("Accept", contentType) @@ -234,7 +296,7 @@ func TestPeers(t *testing.T) { t.Parallel() router := &mockContentRouter{} - resp := makeRequest(t, router, mediaTypeJSON, "nonpeerid") + resp := makeRequest(t, router, mediaTypeJSON, "nonpeerid", "", "") require.Equal(t, 400, resp.StatusCode) }) @@ -247,7 +309,7 @@ func TestPeers(t *testing.T) { router := &mockContentRouter{} router.On("FindPeers", mock.Anything, pid, DefaultRecordsLimit).Return(results, nil) - resp := makeRequest(t, router, mediaTypeJSON, peer.ToCid(pid).String()) + resp := makeRequest(t, router, mediaTypeJSON, peer.ToCid(pid).String(), "", "") require.Equal(t, 404, resp.StatusCode) require.Equal(t, mediaTypeJSON, resp.Header.Get("Content-Type")) @@ -267,7 +329,7 @@ func TestPeers(t *testing.T) { router.On("FindPeers", mock.Anything, pid, DefaultRecordsLimit).Return(results, nil) // Simulate request with Accept header that includes wildcard match - resp := makeRequest(t, router, "text/html,*/*", peer.ToCid(pid).String()) + resp := makeRequest(t, router, "text/html,*/*", peer.ToCid(pid).String(), "", "") // Expect response to default to application/json require.Equal(t, 404, resp.StatusCode) @@ -285,7 +347,7 @@ func TestPeers(t *testing.T) { router.On("FindPeers", mock.Anything, pid, DefaultRecordsLimit).Return(results, nil) // Simulate request without Accept header - resp := makeRequest(t, router, "", peer.ToCid(pid).String()) + resp := makeRequest(t, router, "", peer.ToCid(pid).String(), "", "") // Expect response to default to application/json require.Equal(t, 404, resp.StatusCode) @@ -301,7 +363,7 @@ func TestPeers(t *testing.T) { router.On("FindPeers", mock.Anything, pid, DefaultRecordsLimit).Return(nil, routing.ErrNotFound) // Simulate request without Accept header - resp := makeRequest(t, router, "", peer.ToCid(pid).String()) + resp := makeRequest(t, router, "", peer.ToCid(pid).String(), "", "") // Expect response to default to application/json require.Equal(t, 404, resp.StatusCode) @@ -331,7 +393,7 @@ func TestPeers(t *testing.T) { router.On("FindPeers", mock.Anything, pid, DefaultRecordsLimit).Return(results, nil) libp2pKeyCID := peer.ToCid(pid).String() - resp := makeRequest(t, router, mediaTypeJSON, libp2pKeyCID) + resp := makeRequest(t, router, mediaTypeJSON, libp2pKeyCID, "", "") require.Equal(t, 200, resp.StatusCode) require.Equal(t, mediaTypeJSON, resp.Header.Get("Content-Type")) @@ -347,6 +409,110 @@ func TestPeers(t *testing.T) { require.Equal(t, expectedBody, string(body)) }) + t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body and headers (JSON) with filter-addrs", func(t *testing.T) { + t.Parallel() + + addr1, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001") + addr2, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/udp/4001/quic-v1") + addr3, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001/ws") + addr4, _ := multiaddr.NewMultiaddr("/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit") + addr5, _ := multiaddr.NewMultiaddr("/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu") + _, pid := makeEd25519PeerID(t) + _, pid2 := makeEd25519PeerID(t) + results := iter.FromSlice([]iter.Result[*types.PeerRecord]{ + {Val: &types.PeerRecord{ + Schema: types.SchemaPeer, + ID: &pid, + Protocols: []string{"transport-bitswap", "transport-foo"}, + Addrs: []types.Multiaddr{ + {Multiaddr: addr1}, + {Multiaddr: addr2}, + {Multiaddr: addr3}, + {Multiaddr: addr4}, + }, + }}, + {Val: &types.PeerRecord{ + Schema: types.SchemaPeer, + ID: &pid2, + Protocols: []string{"transport-foo"}, + Addrs: []types.Multiaddr{ + {Multiaddr: addr5}, + }, + }}, + }) + + router := &mockContentRouter{} + router.On("FindPeers", mock.Anything, pid, DefaultRecordsLimit).Return(results, nil) + + libp2pKeyCID := peer.ToCid(pid).String() + resp := makeRequest(t, router, mediaTypeJSON, libp2pKeyCID, "tcp", "") + require.Equal(t, 200, resp.StatusCode) + + require.Equal(t, mediaTypeJSON, resp.Header.Get("Content-Type")) + require.Equal(t, "Accept", resp.Header.Get("Vary")) + require.Equal(t, "public, max-age=300, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control")) + + requireCloseToNow(t, resp.Header.Get("Last-Modified")) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + expectedBody := `{"Peers":[{"Addrs":["/ip4/127.0.0.1/tcp/4001","/ip4/127.0.0.1/tcp/4001/ws"],"ID":"` + pid.String() + `","Protocols":["transport-bitswap","transport-foo"],"Schema":"peer"}]}` + require.Equal(t, expectedBody, string(body)) + }) + + t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body and headers (JSON) with filter-protocols", func(t *testing.T) { + t.Parallel() + + addr1, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001") + addr2, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/udp/4001/quic-v1") + addr3, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001/ws") + addr4, _ := multiaddr.NewMultiaddr("/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit") + addr5, _ := multiaddr.NewMultiaddr("/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu") + _, pid := makeEd25519PeerID(t) + _, pid2 := makeEd25519PeerID(t) + results := iter.FromSlice([]iter.Result[*types.PeerRecord]{ + {Val: &types.PeerRecord{ + Schema: types.SchemaPeer, + ID: &pid, + Protocols: []string{"transport-bitswap", "transport-foo"}, + Addrs: []types.Multiaddr{ + {Multiaddr: addr1}, + {Multiaddr: addr2}, + {Multiaddr: addr3}, + {Multiaddr: addr4}, + }, + }}, + {Val: &types.PeerRecord{ + Schema: types.SchemaPeer, + ID: &pid2, + Protocols: []string{"transport-foo"}, + Addrs: []types.Multiaddr{ + {Multiaddr: addr5}, + }, + }}, + }) + + router := &mockContentRouter{} + router.On("FindPeers", mock.Anything, pid, DefaultRecordsLimit).Return(results, nil) + + libp2pKeyCID := peer.ToCid(pid).String() + resp := makeRequest(t, router, mediaTypeJSON, libp2pKeyCID, "", "transport-bitswap") + require.Equal(t, 200, resp.StatusCode) + + require.Equal(t, mediaTypeJSON, resp.Header.Get("Content-Type")) + require.Equal(t, "Accept", resp.Header.Get("Vary")) + require.Equal(t, "public, max-age=300, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control")) + + requireCloseToNow(t, resp.Header.Get("Last-Modified")) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + expectedBody := `{"Peers":[{"Addrs":["/ip4/127.0.0.1/tcp/4001","/ip4/127.0.0.1/udp/4001/quic-v1","/ip4/127.0.0.1/tcp/4001/ws","/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"],"ID":"` + pid.String() + `","Protocols":["transport-bitswap","transport-foo"],"Schema":"peer"}]}` + require.Equal(t, expectedBody, string(body)) + }) + t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 404 with correct body and headers (No Results, NDJSON)", func(t *testing.T) { t.Parallel() @@ -356,7 +522,7 @@ func TestPeers(t *testing.T) { router := &mockContentRouter{} router.On("FindPeers", mock.Anything, pid, DefaultStreamingRecordsLimit).Return(results, nil) - resp := makeRequest(t, router, mediaTypeNDJSON, peer.ToCid(pid).String()) + resp := makeRequest(t, router, mediaTypeNDJSON, peer.ToCid(pid).String(), "", "") require.Equal(t, 404, resp.StatusCode) require.Equal(t, mediaTypeNDJSON, resp.Header.Get("Content-Type")) @@ -389,7 +555,7 @@ func TestPeers(t *testing.T) { router.On("FindPeers", mock.Anything, pid, DefaultStreamingRecordsLimit).Return(results, nil) libp2pKeyCID := peer.ToCid(pid).String() - resp := makeRequest(t, router, mediaTypeNDJSON, libp2pKeyCID) + resp := makeRequest(t, router, mediaTypeNDJSON, libp2pKeyCID, "", "") require.Equal(t, 200, resp.StatusCode) require.Equal(t, mediaTypeNDJSON, resp.Header.Get("Content-Type")) @@ -451,7 +617,7 @@ func TestPeers(t *testing.T) { router := &mockContentRouter{} router.On("FindPeers", mock.Anything, pid, DefaultStreamingRecordsLimit).Return(iter.FromSlice(results), nil) - resp := makeRequest(t, router, mediaTypeNDJSON, peerIDStr) + resp := makeRequest(t, router, mediaTypeNDJSON, peerIDStr, "", "") require.Equal(t, 200, resp.StatusCode) require.Equal(t, mediaTypeNDJSON, resp.Header.Get("Content-Type")) @@ -471,7 +637,7 @@ func TestPeers(t *testing.T) { router := &mockContentRouter{} router.On("FindPeers", mock.Anything, pid, DefaultRecordsLimit).Return(iter.FromSlice(results), nil) - resp := makeRequest(t, router, mediaTypeJSON, peerIDStr) + resp := makeRequest(t, router, mediaTypeJSON, peerIDStr, "", "") require.Equal(t, 200, resp.StatusCode) require.Equal(t, mediaTypeJSON, resp.Header.Get("Content-Type")) diff --git a/routing/http/types/iter/filter.go b/routing/http/types/iter/filter.go new file mode 100644 index 000000000..628997811 --- /dev/null +++ b/routing/http/types/iter/filter.go @@ -0,0 +1,43 @@ +package iter + +// Filter returns an iterator that filters out values that don't satisfy the predicate f. +func Filter[T any](iter Iter[T], f func(t T) bool) *FilterIter[T] { + return &FilterIter[T]{iter: iter, f: f} +} + +type FilterIter[T any] struct { + iter Iter[T] + f func(T) bool + + done bool + val T +} + +func (f *FilterIter[T]) Next() bool { + if f.done { + return false + } + + ok := f.iter.Next() + f.done = !ok + + if f.done { + return false + } + + f.val = f.iter.Val() + + if f.f(f.val) { + return true + } + + return f.Next() +} + +func (f *FilterIter[T]) Val() T { + return f.val +} + +func (f *FilterIter[T]) Close() error { + return f.iter.Close() +} diff --git a/routing/http/types/iter/filter_test.go b/routing/http/types/iter/filter_test.go new file mode 100644 index 000000000..6d170285e --- /dev/null +++ b/routing/http/types/iter/filter_test.go @@ -0,0 +1,41 @@ +package iter + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFilter(t *testing.T) { + for _, c := range []struct { + input Iter[int] + f func(int) bool + expResults []int + }{ + { + input: FromSlice([]int{1, 2, 3, 4}), + f: func(i int) bool { return i%2 == 0 }, + expResults: []int{2, 4}, + }, + { + input: FromSlice([]int{}), + f: func(i int) bool { return i%2 == 0 }, + expResults: nil, + }, + { + input: FromSlice([]int{1, 3, 5, 100}), + f: func(i int) bool { return i > 2 }, + expResults: []int{3, 5, 100}, + }, + } { + t.Run(fmt.Sprintf("%v", c.input), func(t *testing.T) { + iter := Filter(c.input, c.f) + var res []int + for iter.Next() { + res = append(res, iter.Val()) + } + assert.Equal(t, c.expResults, res) + }) + } +} diff --git a/routing/http/types/record_peer.go b/routing/http/types/record_peer.go index 76bd810e0..cb4a04fca 100644 --- a/routing/http/types/record_peer.go +++ b/routing/http/types/record_peer.go @@ -79,3 +79,13 @@ func (pr PeerRecord) MarshalJSON() ([]byte, error) { return drjson.MarshalJSONBytes(m) } + +func FromBitswapRecord(br *BitswapRecord) *PeerRecord { + return &PeerRecord{ + Schema: SchemaPeer, + ID: br.ID, + Addrs: br.Addrs, + Protocols: []string{br.Protocol}, + Extra: map[string]json.RawMessage{}, + } +}