diff --git a/traverse.go b/traverse.go index 657e2c4..90073e2 100644 --- a/traverse.go +++ b/traverse.go @@ -95,6 +95,13 @@ func (r *Reader) NetworksWithin(network *net.IPNet, options ...NetworksOption) * } pointer, bit := r.traverseTree(ip, 0, uint(prefixLength)) + + // We could skip this when bit >= prefixLength if we assume that the network + // passed in is in canonical form. However, given that this may not be the + // case, it is safest to always take the mask. If this is hot code at some + // point, we could eliminate the allocation of the net.IPMask by zeroing + // out the bits in ip directly. + ip = ip.Mask(net.CIDRMask(bit, len(ip)*8)) networks.nodes = []netNode{ { ip: ip, diff --git a/traverse_test.go b/traverse_test.go index 62e3547..00edfce 100644 --- a/traverse_test.go +++ b/traverse_test.go @@ -3,6 +3,8 @@ package maxminddb import ( "fmt" "net" + "strconv" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -71,6 +73,8 @@ var tests = []networkTest{ }, }, { + // This is intentionally in non-canonical form to test + // that we handle it correctly. Network: "1.1.1.1/30", Database: "ipv4", Expected: []string{ @@ -78,6 +82,13 @@ var tests = []networkTest{ "1.1.1.2/31", }, }, + { + Network: "1.1.1.2/31", + Database: "ipv4", + Expected: []string{ + "1.1.1.2/31", + }, + }, { Network: "1.1.1.1/32", Database: "ipv4", @@ -85,6 +96,27 @@ var tests = []networkTest{ "1.1.1.1/32", }, }, + { + Network: "1.1.1.2/32", + Database: "ipv4", + Expected: []string{ + "1.1.1.2/31", + }, + }, + { + Network: "1.1.1.3/32", + Database: "ipv4", + Expected: []string{ + "1.1.1.2/31", + }, + }, + { + Network: "1.1.1.19/32", + Database: "ipv4", + Expected: []string{ + "1.1.1.16/28", + }, + }, { Network: "255.255.255.0/24", Database: "ipv4", @@ -234,28 +266,51 @@ var tests = []networkTest{ func TestNetworksWithin(t *testing.T) { for _, v := range tests { for _, recordSize := range []uint{24, 28, 32} { - fileName := testFile(fmt.Sprintf("MaxMind-DB-test-%s-%d.mmdb", v.Database, recordSize)) - reader, err := Open(fileName) - require.NoError(t, err, "unexpected error while opening database: %v", err) + name := fmt.Sprintf( + "%s-%d: %s, options: %v", + v.Database, + recordSize, + v.Network, + len(v.Options) != 0, + ) + t.Run(name, func(t *testing.T) { + fileName := testFile(fmt.Sprintf("MaxMind-DB-test-%s-%d.mmdb", v.Database, recordSize)) + reader, err := Open(fileName) + require.NoError(t, err, "unexpected error while opening database: %v", err) - _, network, err := net.ParseCIDR(v.Network) - require.NoError(t, err) - n := reader.NetworksWithin(network, v.Options...) - var innerIPs []string + // We are purposely not using net.ParseCIDR so that we can pass in + // values that aren't in canonical form. + parts := strings.Split(v.Network, "/") + ip := net.ParseIP(parts[0]) + if v := ip.To4(); v != nil { + ip = v + } + prefixLength, err := strconv.Atoi(parts[1]) + require.NoError(t, err) + mask := net.CIDRMask(prefixLength, len(ip)*8) + network := &net.IPNet{ + IP: ip, + Mask: mask, + } - for n.Next() { - record := struct { - IP string `maxminddb:"ip"` - }{} - network, err := n.Network(&record) require.NoError(t, err) - innerIPs = append(innerIPs, network.String()) - } + n := reader.NetworksWithin(network, v.Options...) + var innerIPs []string - assert.Equal(t, v.Expected, innerIPs) - require.NoError(t, n.Err()) + for n.Next() { + record := struct { + IP string `maxminddb:"ip"` + }{} + network, err := n.Network(&record) + require.NoError(t, err) + innerIPs = append(innerIPs, network.String()) + } - require.NoError(t, reader.Close()) + assert.Equal(t, v.Expected, innerIPs) + require.NoError(t, n.Err()) + + require.NoError(t, reader.Close()) + }) } } }