diff --git a/format/plain/writer.go b/format/plain/writer.go index 49f417e..deeb5a1 100644 --- a/format/plain/writer.go +++ b/format/plain/writer.go @@ -42,13 +42,6 @@ func NewWriter(meta *model.Meta) (*Writer, error) { meta: meta, } - ret.buffer = bytes.NewBuffer([]byte{}) - ret.iw = ret.buffer - - if err := ret.Header(); err != nil { - return nil, err - } - return ret, nil } @@ -76,7 +69,12 @@ func (w *Writer) SetOption(option interface{}) error { // Insert adds the given IP information into the writer. func (w *Writer) Insert(info *model.IPInfo) error { if w.iw == nil { - return errors.ErrNilWriter + w.buffer = bytes.NewBuffer([]byte{}) + w.iw = w.buffer + + if err := w.Header(); err != nil { + return err + } } for _, ipNet := range info.IPNet.IPNets() { diff --git a/format/zxinc/reader.go b/format/zxinc/reader.go index 7af7c82..55e718d 100644 --- a/format/zxinc/reader.go +++ b/format/zxinc/reader.go @@ -45,7 +45,7 @@ func NewReader(file string) (*Reader, error) { meta := &model.Meta{ MetaVersion: model.MetaVersion, Format: DBFormat, - IPVersion: model.IPv4, + IPVersion: model.IPv6, Fields: FullFields, } meta.AddCommonFieldAlias(CommonFieldsAlias) diff --git a/ipnet/ip.go b/ipnet/ip.go index 860eb9c..23fe064 100644 --- a/ipnet/ip.go +++ b/ipnet/ip.go @@ -78,32 +78,46 @@ func Uint64ToIP2(high, low uint64) net.IP { } } -// adjustIP modifies the IP by the given delta, which can be positive or negative. -func adjustIP(ip net.IP, delta int) net.IP { - ipVals := make([]byte, len(ip)) - copy(ipVals, ip) - i := len(ipVals) - 1 - for i >= 0 { - sum := int(ipVals[i]) + delta - if 0 <= sum && sum <= 255 { - ipVals[i] = byte(sum) +// PrevIP returns the IP immediately before the given IP. +func PrevIP(ip net.IP) net.IP { + res := make(net.IP, len(ip)) + copy(res, ip) + + // Ensure it's a pure IPv4 or IPv6 + if ip.To4() != nil { + ip = ip.To4() + res = res[len(res)-4:] + } + + for i := len(ip) - 1; i >= 0; i-- { + if res[i] > 0 { + res[i]-- break } - // Adjust the next byte and continue - delta, ipVals[i] = sum/256, byte(sum%256) - i-- + res[i] = 0xff } - return net.IP(ipVals) -} - -// PrevIP returns the IP immediately before the given IP. -func PrevIP(ip net.IP) net.IP { - return adjustIP(ip, -1) + return res } // NextIP returns the IP immediately after the given IP. func NextIP(ip net.IP) net.IP { - return adjustIP(ip, 1) + res := make(net.IP, len(ip)) + copy(res, ip) + + // Ensure it's a pure IPv4 or IPv6 + if ip.To4() != nil { + ip = ip.To4() + res = res[len(res)-4:] + } + + for i := len(ip) - 1; i >= 0; i-- { + if res[i] < 0xff { + res[i]++ + break + } + res[i] = 0 + } + return res } // IPLess compares two IPs and returns true if the first IP is less than the second. diff --git a/ipnet/ip_test.go b/ipnet/ip_test.go index 62ffb70..18d1107 100644 --- a/ipnet/ip_test.go +++ b/ipnet/ip_test.go @@ -18,6 +18,7 @@ package ipnet import ( "math/rand" + "net" "testing" "github.com/stretchr/testify/assert" @@ -43,3 +44,43 @@ func TestIPv4StrToUint32(t *testing.T) { i = IPv4StrToUint32("fake") ast.Equal(uint32(0), i) } + +func TestPrevIP(t *testing.T) { + tests := []struct { + input net.IP + expected net.IP + }{ + {net.ParseIP("192.168.1.1"), net.ParseIP("192.168.1.0")}, + {net.ParseIP("192.168.1.0"), net.ParseIP("192.168.0.255")}, + {net.ParseIP("0.0.0.0"), net.ParseIP("255.255.255.255")}, + {net.ParseIP("::1"), net.ParseIP("::")}, + {net.ParseIP("::"), net.ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")}, + } + + for _, test := range tests { + output := PrevIP(test.input) + if !output.Equal(test.expected) { + t.Errorf("For input %v, expected %v, but got %v", test.input, test.expected, output) + } + } +} + +func TestNextIP(t *testing.T) { + tests := []struct { + input net.IP + expected net.IP + }{ + {net.ParseIP("192.168.1.0"), net.ParseIP("192.168.1.1")}, + {net.ParseIP("192.168.0.255"), net.ParseIP("192.168.1.0")}, + {net.ParseIP("255.255.255.255"), net.ParseIP("0.0.0.0")}, + {net.ParseIP("::"), net.ParseIP("::1")}, + {net.ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"), net.ParseIP("::")}, + } + + for _, test := range tests { + output := NextIP(test.input) + if !output.Equal(test.expected) { + t.Errorf("For input %v, expected %v, but got %v", test.input, test.expected, output) + } + } +}