diff --git a/key/key.go b/key/key.go index 9a8808f..d75a503 100644 --- a/key/key.go +++ b/key/key.go @@ -3,12 +3,16 @@ package key import ( "bytes" "encoding/hex" + "errors" "fmt" "math" "github.com/plprobelab/go-kademlia/kad" ) +// ErrInvalidDataLength is the error returned when attempting to construct a key from binary data of the wrong length. +var ErrInvalidDataLength = errors.New("invalid data length") + const bitPanicMsg = "bit index out of range" // Key256 is a 256-bit Kademlia key. @@ -21,7 +25,7 @@ var _ kad.Key[Key256] = Key256{} // NewKey256 returns a 256-bit Kademlia key whose bits are set from the supplied bytes. func NewKey256(data []byte) Key256 { if len(data) != 32 { - panic("invalid data length for key") + panic(ErrInvalidDataLength) } var b [32]byte copy(b[:], data) @@ -86,7 +90,15 @@ func (k Key256) CommonPrefixLength(o Key256) int { // Compare compares the numeric value of the key with another key of the same type. func (k Key256) Compare(o Key256) int { - return bytes.Compare(k.b[:], o.b[:]) + if k.b != nil && o.b != nil { + return bytes.Compare(k.b[:], o.b[:]) + } + + var zero [32]byte + if k.b == nil { + return bytes.Compare(zero[:], o.b[:]) + } + return bytes.Compare(zero[:], k.b[:]) } // HexString returns a string containing the hexadecimal representation of the key. @@ -97,6 +109,16 @@ func (k Key256) HexString() string { return hex.EncodeToString(k.b[:]) } +// MarshalBinary marshals the key into a byte slice. +// The bytes may be passed to NewKey256 to construct a new key with the same value. +func (k Key256) MarshalBinary() ([]byte, error) { + buf := make([]byte, 32) + if k.b != nil { + copy(buf, (*k.b)[:]) + } + return buf, nil +} + // Key32 is a 32-bit Kademlia key, suitable for testing and simulation of small networks. type Key32 uint32 diff --git a/key/key_test.go b/key/key_test.go index ad2d6ed..86b6baa 100644 --- a/key/key_test.go +++ b/key/key_test.go @@ -34,6 +34,8 @@ func TestKey256(t *testing.T) { } tester.RunTests(t) + + testBinaryMarshaler(t, tester.KeyX, NewKey256) } func TestKey32(t *testing.T) { @@ -79,6 +81,7 @@ func TestBitStrKey7(t *testing.T) { tester.RunTests(t) } +// KeyTester tests a kad.Key's implementation type KeyTester[K kad.Key[K]] struct { // Key 0 is zero Key0 K @@ -127,6 +130,12 @@ func (kt *KeyTester[K]) TestXor(t *testing.T) { xored = kt.Key1.Xor(kt.Key2) require.Equal(t, kt.Key1xor2, xored) + + var empty K // zero value of key + xored = kt.Key0.Xor(empty) + require.Equal(t, kt.Key0, xored) + xored = empty.Xor(kt.Key0) + require.Equal(t, kt.Key0, xored) } func (kt *KeyTester[K]) TestCommonPrefixLength(t *testing.T) { @@ -141,6 +150,12 @@ func (kt *KeyTester[K]) TestCommonPrefixLength(t *testing.T) { cpl = kt.Key0.CommonPrefixLength(kt.Key010) require.Equal(t, 1, cpl) + + var empty K // zero value of key + cpl = kt.Key0.CommonPrefixLength(empty) + require.Equal(t, kt.Key0.BitLen(), cpl) + cpl = empty.CommonPrefixLength(kt.Key0) + require.Equal(t, kt.Key0.BitLen(), cpl) } func (kt *KeyTester[K]) TestCompare(t *testing.T) { @@ -167,6 +182,12 @@ func (kt *KeyTester[K]) TestCompare(t *testing.T) { res = kt.Key1.Compare(kt.Key2) require.Equal(t, -1, res) + + var empty K // zero value of key + res = kt.Key0.Compare(empty) + require.Equal(t, 0, res) + res = empty.Compare(kt.Key0) + require.Equal(t, 0, res) } func (kt *KeyTester[K]) TestBit(t *testing.T) { @@ -184,6 +205,11 @@ func (kt *KeyTester[K]) TestBit(t *testing.T) { } require.Equal(t, uint(1), kt.Key2.Bit(kt.Key2.BitLen()-2), fmt.Sprintf("Key1.Bit(%d)=%d", kt.Key2.BitLen()-2, kt.Key2.BitLen()-2)) require.Equal(t, uint(0), kt.Key2.Bit(kt.Key2.BitLen()-1), fmt.Sprintf("Key1.Bit(%d)=%d", kt.Key2.BitLen()-2, kt.Key2.BitLen()-1)) + + var empty K // zero value of key + for i := 0; i < empty.BitLen(); i++ { + require.Equal(t, uint(0), empty.Bit(i), fmt.Sprintf("empty.Bit(%d)=%d", i, kt.Key0.Bit(i))) + } } func (kt *KeyTester[K]) TestBitString(t *testing.T) { @@ -230,6 +256,21 @@ func (kt *KeyTester[K]) TestHexString(t *testing.T) { } } +// testBinaryMarshaler tests the behaviour of a kad.Key implementation that also implements the BinaryMarshaler interface +func testBinaryMarshaler[K interface { + kad.Key[K] + MarshalBinary() ([]byte, error) +}](t *testing.T, k K, newFunc func([]byte) K, +) { + b, err := k.MarshalBinary() + require.NoError(t, err) + + other := newFunc(b) + + res := k.Compare(other) + require.Equal(t, 0, res) +} + // BitStrKey is a key represented by a string of 1's and 0's type BitStrKey string @@ -253,6 +294,12 @@ func (k BitStrKey) Bit(i int) uint { func (k BitStrKey) Xor(o BitStrKey) BitStrKey { if len(k) != len(o) { + if len(k) == 0 && o.isZero() { + return BitStrKey(o) + } + if len(o) == 0 && k.isZero() { + return BitStrKey(k) + } panic("BitStrKey: other key has different length") } buf := make([]byte, len(k)) @@ -268,6 +315,12 @@ func (k BitStrKey) Xor(o BitStrKey) BitStrKey { func (k BitStrKey) CommonPrefixLength(o BitStrKey) int { if len(k) != len(o) { + if len(k) == 0 && o.isZero() { + return len(o) + } + if len(o) == 0 && k.isZero() { + return len(k) + } panic("BitStrKey: other key has different length") } for i := 0; i < len(k); i++ { @@ -280,6 +333,12 @@ func (k BitStrKey) CommonPrefixLength(o BitStrKey) int { func (k BitStrKey) Compare(o BitStrKey) int { if len(k) != len(o) { + if len(k) == 0 && o.isZero() { + return 0 + } + if len(o) == 0 && k.isZero() { + return 0 + } panic("BitStrKey: other key has different length") } for i := 0; i < len(k); i++ { @@ -292,3 +351,12 @@ func (k BitStrKey) Compare(o BitStrKey) int { } return 0 } + +func (k BitStrKey) isZero() bool { + for i := 0; i < len(k); i++ { + if k[i] != '0' { + return false + } + } + return true +}