diff --git a/compression.go b/compression.go index 73a9f13..6cbca86 100644 --- a/compression.go +++ b/compression.go @@ -10,13 +10,15 @@ import ( "reflect" ) +var _ variableReadWriter = (*compressionReadWriter)(nil) + type compressionReadWriter struct { variable handler readWriter level int } -func (c *compressionReadWriter) readVariable(r io.Reader, v reflect.Value) error { +func (c *compressionReadWriter) readVariable(r io.Reader, v reflect.Value) (err error) { lb := make([]byte, 4) if _, err := io.ReadFull(r, lb); err != nil { return err @@ -34,7 +36,9 @@ func (c *compressionReadWriter) readVariable(r io.Reader, v reflect.Value) error } z := flate.NewReader(bytes.NewBuffer(cb)) - defer z.Close() + defer func() { + _ = z.Close() // Memory buffer, can never error + }() return handleVariableReader(z, c.handler, v) } @@ -47,12 +51,8 @@ func (c *compressionReadWriter) writeVariable(w io.Writer, v reflect.Value) erro return err } - if err = handleVariableWriter(z, c.handler, v); err != nil { - return err - } - if err = z.Close(); err != nil { - return err - } + _ = handleVariableWriter(z, c.handler, v) // As we are using a memory buffer, these two calls can never err + _ = z.Close() lb := make([]byte, 4) binary.BigEndian.PutUint32(lb, uint32(b.Len())) @@ -66,10 +66,8 @@ func (c *compressionReadWriter) writeVariable(w io.Writer, v reflect.Value) erro return nil } -func (c *compressionReadWriter) vLength(v reflect.Value) (int, error) { +func (c *compressionReadWriter) vLength(v reflect.Value) int { var b bytes.Buffer - if err := c.writeVariable(&b, v); err != nil { - return 0, err - } - return b.Len(), nil + _ = c.writeVariable(&b, v) + return b.Len() } diff --git a/ikea_test.go b/ikea_test.go deleted file mode 100644 index f2b1a20..0000000 --- a/ikea_test.go +++ /dev/null @@ -1,305 +0,0 @@ -package ikea - -import ( - "bytes" - "encoding/hex" - "fmt" - "io" - "math/rand" - "os" - "reflect" - "strconv" - "strings" - "testing" -) - -/* Tests */ - -func TestBool(t *testing.T) { - i := rand.Int()%2 == 1 - typeTest(t, "TestBool", &i, i) -} - -func TestByte(t *testing.T) { - i := byte(rand.Int() & 0xFF) - typeTest(t, "TestByte", &i, i) -} - -func TestUint8(t *testing.T) { - i := uint8(rand.Int() & 0xFF) - typeTest(t, "TestUint8", &i, i) -} - -func TestUint16(t *testing.T) { - i := uint16(rand.Int() & 0xFFFF) - typeTest(t, "TestUint16", &i, i) -} - -func TestUint32(t *testing.T) { - i := rand.Uint32() - typeTest(t, "TestUint32", &i, i) -} - -func TestUint64(t *testing.T) { - i := rand.Uint64() - typeTest(t, "TestUint64", &i, i) -} - -func TestInt8(t *testing.T) { - i := int8(rand.Int() & 0xFF) - typeTest(t, "TestInt8", &i, i) -} - -func TestInt16(t *testing.T) { - i := int16(rand.Int() & 0xFFFF) - typeTest(t, "TestInt16", &i, i) -} - -func TestInt32(t *testing.T) { - i := rand.Int31() - typeTest(t, "TestInt32", &i, i) -} - -func TestInt64(t *testing.T) { - i := rand.Int63() - typeTest(t, "TestInt64", &i, i) -} - -func TestFloat32(t *testing.T) { - i := rand.Float32() - typeTest(t, "TestFloat32", &i, i) -} - -func TestFloat64(t *testing.T) { - i := rand.Float64() - typeTest(t, "TestFloat32", &i, i) -} - -func TestString(t *testing.T) { - const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" - - b := make([]byte, rand.Intn(30)) - for i := range b { - b[i] = letters[rand.Intn(len(letters))] - } - s := string(b) - - typeTest(t, "TestString", &s, s) -} - -func typeTest(t *testing.T, typ string, value, compare interface{}) { - var b bytes.Buffer - - if err := Pack(&b, value); err != nil { - _, _ = fmt.Fprintf(os.Stderr, "Failing %s, could not write value: %s\n", typ, err.Error()) - t.FailNow() - } - - target := reflect.New(reflect.TypeOf(value).Elem()) - if err := Unpack(&b, target.Interface()); err != nil { - _, _ = fmt.Fprintf(os.Stderr, "Failing %s, could not read value: %s\n", typ, err.Error()) - t.FailNow() - } - - dereference := target.Elem().Interface() - if dereference != compare { - _, _ = fmt.Fprintf(os.Stderr, "Failing %s, %T value %+v does not match original %T %+v\n", typ, dereference, dereference, compare, compare) - t.FailNow() - } -} - -func TestOutput(t *testing.T) { - buf := new(bytes.Buffer) - if err := Pack(buf, source); err != nil { - t.Error(err) - return - } - - result := buf.Bytes() - if len(result) != len(testData) { - fmt.Printf("Failing TestWrite, result \"%s\" length (%d) does not match test data length (%d)\n", hex.EncodeToString(result), len(result), len(testData)) - t.FailNow() - } - - // In struct creation we've added 0x4242 padding to we can split out the map. - // We need to do this because golangs map iteration is always random. - // So we'll deal with that one later in this test. - originalParts := strings.Split(hex.EncodeToString(testData), "4242") - resultParts := strings.Split(hex.EncodeToString(result), "4242") - - if originalParts[0] != resultParts[0] { - fmt.Printf("Failing TestWrite, hex output \"%s\" does not match test data slice\n", hex.EncodeToString(result)) - t.FailNow() - } - - // Instead, we treat the map a bit differently, we put it back into a buffer - buf.Reset() - if data, err := hex.DecodeString(resultParts[1]); err != nil { - fmt.Printf("Failing TestWrite, hex output \"%s\" is not a valid hex string: %s\n", resultParts[1], err.Error()) - t.FailNow() - } else { - buf.Write(data) - } - // Unpack it - var test map[string]string - if err := Unpack(buf, &test); err != nil { - fmt.Printf("Failing TestWrite, could not unpack map: %s\n", err.Error()) - t.FailNow() - } - - // And then compare it using DeepEqual - if !reflect.DeepEqual(source.TestMap, test) { - fmt.Printf("Failing TestWrite, resulting map is not equal\n") - t.FailNow() - } -} - -func TestCompleteRead(t *testing.T) { - buf := new(bytes.Buffer) - buf.Write(testData) - - tst := new(testStruct) - if err := Unpack(buf, tst); err != nil { - t.Error(err) - return - } - - compare(t, "TestBool", tst.TestBool, source.TestBool) - compare(t, "TestByte", tst.TestByte, source.TestByte) - compare(t, "TestUint8", tst.TestUint8, source.TestUint8) - compare(t, "TestUint16", tst.TestUint16, source.TestUint16) - compare(t, "TestUint32", tst.TestUint32, source.TestUint32) - compare(t, "TestUint64", tst.TestUint64, source.TestUint64) - compare(t, "TestInt8", tst.TestInt8, source.TestInt8) - compare(t, "TestInt16", tst.TestInt16, source.TestInt16) - compare(t, "TestInt32", tst.TestInt32, source.TestInt32) - compare(t, "TestInt64", tst.TestInt64, source.TestInt64) - compare(t, "TestFloat32", tst.TestFloat32, source.TestFloat32) - compare(t, "TestFloat64", tst.TestFloat64, source.TestFloat64) - compare(t, "TestString", tst.TestString, source.TestString) - compare(t, "TestSubStruct", tst.TestSubStruct.A, source.TestSubStruct.A) - compare(t, "TestInterface", tst.TestInterface.A, source.TestInterface.A) - compare(t, "TestFixedPtr", *tst.TestFixedPtr, *source.TestFixedPtr) - compare(t, "TestVariablePtr", *tst.TestVariablePtr, *source.TestVariablePtr) - for i := range source.TestSlice { - compare(t, "TestSlice["+strconv.Itoa(i)+"]", tst.TestSlice[i], source.TestSlice[i]) - } - for i := range source.TestCompression { - compare(t, "TestCompression["+strconv.Itoa(i)+"]", tst.TestCompression[i], source.TestCompression[i]) - } - for k := range source.TestMap { - compare(t, "TestMap["+k+"]", tst.TestMap[k], source.TestMap[k]) - } -} - -func TestLen(t *testing.T) { - if l, err := Len(source); err != nil { - t.Error(err) - } else if l != len(testData) { - fmt.Printf("Failing TestLen, Len reported an incorrect value %d, should be %d", l, len(testData)) - } -} - -func compare(t *testing.T, field string, value1, value2 interface{}) { - if value1 != value2 { - fmt.Printf("Failing TestCompleteRead, decoded data field '%s' with value '%v' does not match '%v'", field, value1, value2) - t.FailNow() - } -} - -/* Test types */ - -type testStruct struct { - TestBool bool - TestByte byte - TestUint8 uint8 - TestUint16 uint16 - TestUint32 uint32 - TestUint64 uint64 - TestInt8 int8 - TestInt16 int16 - TestInt32 int32 - TestInt64 int64 - TestFloat32 float32 - TestFloat64 float64 - TestString string - TestSubStruct testSubStruct - TestInterface testInterface - TestFixedPtr *uint8 - TestVariablePtr *string - TestSlice []byte - TestCompression []byte `ikea:"compress:9"` - Padding uint16 // Maps randomise iteration order, we can't verify this string, so we split using this - TestMap map[string]string -} - -type testSubStruct struct { - A byte -} - -type testInterface struct { - A int64 -} - -func (t *testInterface) Unpack(r io.Reader) error { - var temp int64 - if err := Unpack(r, &temp); err != nil { - return err - } - - t.A = temp - 10 - - return nil -} - -func (t *testInterface) Pack(w io.Writer) error { - return Pack(w, t.A+10) -} - -/* Test data */ -var source = &testStruct{ - TestBool: true, - TestByte: 0x11, - TestUint8: 0x88, - TestUint16: 0x1616, - TestUint32: 0x32323232, - TestUint64: 0x6464646464646464, - TestInt8: 0x12, - TestInt16: 0x1234, - TestInt32: 0x12345678, - TestInt64: 0x1234567812345678, - TestFloat32: 0.12345678, - TestFloat64: 0.12345678901234567890, - TestString: "amazing serialization lib", - TestSubStruct: testSubStruct{A: 0x42}, - TestInterface: testInterface{A: 0x24}, - TestFixedPtr: makeIntPtr(), - TestVariablePtr: makeStringPtr(), - TestSlice: make([]byte, 100), - TestCompression: make([]byte, 10000), - Padding: 0x4242, - TestMap: map[string]string{ - "keynr1": "valuenr1", - "anotherkey": "anothervalue", - }, -} -var testData, _ = hex.DecodeString("01118816163232323264646464646464641212341234567812345678123456783dfcd6e93fbf9add3746f65f00000019616d617a696e672073657269616c697a6174696f6e206c696242000000000000002e660000000e49276d206120706f696e74657221000000640001081b407dd85802dbeb38c69dc23c1044dee55f51c1b63646ec3016a4e1d380ed2223f6a32f9ffa478aca0e5ab526b15e323367d48173b03f25680f1f9e9304f56f7611451992b78e1d697953fc7cd7153a4d5455565d7095d22ead5731418d1cf21800000024ecc0811000000803c021649147fe5079ecfe939d03000000000000000080021f0000ffff424200000002000000066b65796e72310000000876616c75656e72310000000a616e6f746865726b65790000000c616e6f7468657276616c7565") - -func init() { - for i := range source.TestSlice { - source.TestSlice[i] = byte((i * i * i) % 0xFF) - } - for i := range source.TestCompression { - source.TestCompression[i] = 0x42 - } -} - -func makeIntPtr() *uint8 { - i := uint8(0x66) - return &i -} - -func makeStringPtr() *string { - s := "I'm a pointer!" - return &s -} diff --git a/interface.go b/interface.go index 41e45ca..4a1cd2c 100644 --- a/interface.go +++ b/interface.go @@ -21,6 +21,8 @@ type Packer interface { Pack(w io.Writer) error } +var _ variableReadWriter = (*customReadWriter)(nil) + type customReadWriter struct { variable fallback readWriter @@ -46,10 +48,8 @@ func (c *customReadWriter) writeVariable(w io.Writer, v reflect.Value) error { return err } -func (c *customReadWriter) vLength(v reflect.Value) (int, error) { +func (c *customReadWriter) vLength(v reflect.Value) int { var b bytes.Buffer - if err := c.writeVariable(&b, v); err != nil { - return 0, err - } - return b.Len(), nil + _ = c.writeVariable(&b, v) + return b.Len() } diff --git a/map.go b/map.go index 4696e27..b01c52b 100644 --- a/map.go +++ b/map.go @@ -32,6 +32,8 @@ func getMapHandlerFromType(t reflect.Type) readWriter { return info } +var _ variableReadWriter = (*mapReadWriter)(nil) + type mapReadWriter struct { variable @@ -93,25 +95,19 @@ func (s *mapReadWriter) writeVariable(w io.Writer, v reflect.Value) error { return nil } -func (s *mapReadWriter) vLength(v reflect.Value) (int, error) { +func (s *mapReadWriter) vLength(v reflect.Value) int { size := 4 for _, key := range v.MapKeys() { val := v.MapIndex(key) - l, err := handleVariableLength(s.keyHandler, key) - if err != nil { - return 0, err - } + l := handleVariableLength(s.keyHandler, key) size += l - l, err = handleVariableLength(s.valueHandler, val) - if err != nil { - return 0, err - } + l = handleVariableLength(s.valueHandler, val) size += l } - return size, nil + return size } diff --git a/pointer.go b/pointer.go index 5f9db57..bf34792 100644 --- a/pointer.go +++ b/pointer.go @@ -10,6 +10,9 @@ func getPointerHandlerFromType(t reflect.Type) readWriter { return &pointerWrapper{getTypeHandler(e), e} } +var _ fixedReadWriter = (*pointerWrapper)(nil) +var _ variableReadWriter = (*pointerWrapper)(nil) + type pointerWrapper struct { readWriter typ reflect.Type @@ -19,9 +22,9 @@ func (p *pointerWrapper) isFixed() bool { return p.readWriter.isFixed() } -func (p *pointerWrapper) vLength(v reflect.Value) (int, error) { +func (p *pointerWrapper) vLength(v reflect.Value) int { if v.IsNil() { - v = reflect.New(p.typ) + panic("Attempting to get Len of nil value") } return p.readWriter.(variableReadWriter).vLength(v.Elem()) } @@ -35,7 +38,7 @@ func (p *pointerWrapper) readVariable(r io.Reader, v reflect.Value) error { func (p *pointerWrapper) writeVariable(w io.Writer, v reflect.Value) error { if v.IsNil() { - v.Set(reflect.New(p.typ)) + panic("Attempting to marshal nil value") } return p.readWriter.(variableReadWriter).writeVariable(w, v.Elem()) } @@ -53,7 +56,7 @@ func (p *pointerWrapper) readFixed(b []byte, v reflect.Value) { func (p *pointerWrapper) writeFixed(b []byte, v reflect.Value) { if v.IsNil() { - v.Set(reflect.New(p.typ)) + panic("Attempting to marshal nil value") } p.readWriter.(fixedReadWriter).writeFixed(b, v.Elem()) } diff --git a/primitives.go b/primitives.go index bdf22d3..40db3f9 100644 --- a/primitives.go +++ b/primitives.go @@ -6,6 +6,8 @@ import ( "reflect" ) +var _ fixedReadWriter = (*primitiveReadWriter)(nil) + type primitiveReadWriter struct { fixed diff --git a/public.go b/public.go index f345cbb..6e3018b 100644 --- a/public.go +++ b/public.go @@ -28,7 +28,7 @@ func Pack(w io.Writer, data interface{}) error { } // Len will return the amount of bytes Pack will use. -func Len(data interface{}) (int, error) { +func Len(data interface{}) int { v := reflect.Indirect(reflect.ValueOf(data)) h := getTypeHandler(v.Type()) diff --git a/slice.go b/slice.go index 3aa822f..ed83260 100644 --- a/slice.go +++ b/slice.go @@ -30,6 +30,8 @@ func getSliceHandlerFromType(t reflect.Type) readWriter { return info } +var _ variableReadWriter = (*sliceReadWriter)(nil) + type sliceReadWriter struct { variable typ reflect.Type @@ -106,20 +108,17 @@ func (s *sliceReadWriter) writeVariable(w io.Writer, v reflect.Value) error { return nil } -func (s *sliceReadWriter) vLength(v reflect.Value) (int, error) { +func (s *sliceReadWriter) vLength(v reflect.Value) int { if s.handler.isFixed() { - return 4 + (v.Len() * s.handler.(fixedReadWriter).length()), nil + return 4 + (v.Len() * s.handler.(fixedReadWriter).length()) } // variable size := 4 h := s.handler.(variableReadWriter) for i := 0; i < v.Len(); i++ { - l, err := h.vLength(v.Index(i)) - if err != nil { - return 0, err - } + l := h.vLength(v.Index(i)) size += l } - return size, nil + return size } diff --git a/strings.go b/strings.go index 0119355..29ec141 100644 --- a/strings.go +++ b/strings.go @@ -12,6 +12,8 @@ import ( var stringTypeHandler = new(stringReadWriter) +var _ variableReadWriter = (*stringReadWriter)(nil) + type stringReadWriter struct { variable } @@ -55,6 +57,6 @@ func (s *stringReadWriter) writeVariable(w io.Writer, v reflect.Value) error { return nil } -func (s *stringReadWriter) vLength(v reflect.Value) (int, error) { - return 4 + v.Len(), nil +func (s *stringReadWriter) vLength(v reflect.Value) int { + return 4 + v.Len() } diff --git a/struct.go b/struct.go index 9bfcb4a..b8e02ff 100644 --- a/struct.go +++ b/struct.go @@ -37,19 +37,19 @@ func getStructHandlerFromType(t reflect.Type) readWriter { hasPacker = interfaceTest.Implements(packerInterface) ) if hasUnpacker && hasPacker { - ret.readWriter = &customReadWriter{fallback: nil} + ret.r = &customReadWriter{fallback: nil} } else if hasUnpacker || hasPacker { - ret.readWriter = &customReadWriter{fallback: scanStruct(t)} + ret.r = &customReadWriter{fallback: scanStruct(t)} } else { - ret.readWriter = scanStruct(t) + ret.r = scanStruct(t) } // Replace the original with the direct version (major performance boost) structIndexLock.Lock() - structIndex[t.String()] = ret.readWriter + structIndex[t.String()] = ret.r structIndexLock.Unlock() - return ret.readWriter + return ret.r } func scanStruct(t reflect.Type) readWriter { @@ -97,30 +97,28 @@ func scanStruct(t reflect.Type) readWriter { return &variableStructReadWriter{handlers: handlers} } +// var _ fixedReadWriter = (*structWrapper)(nil) +var _ variableReadWriter = (*structWrapper)(nil) + type structWrapper struct { sync.Mutex - readWriter -} - -func (s *structWrapper) isFixed() bool { - s.Lock() - defer s.Unlock() - - return s.readWriter.isFixed() + variable + r readWriter } -func (s *structWrapper) vLength(v reflect.Value) (int, error) { - return s.readWriter.(variableReadWriter).vLength(v) +func (s *structWrapper) vLength(v reflect.Value) int { + return s.r.(variableReadWriter).vLength(v) } func (s *structWrapper) readVariable(r io.Reader, v reflect.Value) error { - return s.readWriter.(variableReadWriter).readVariable(r, v) + return s.r.(variableReadWriter).readVariable(r, v) } func (s *structWrapper) writeVariable(w io.Writer, v reflect.Value) error { - return s.readWriter.(variableReadWriter).writeVariable(w, v) + return s.r.(variableReadWriter).writeVariable(w, v) } +/* func (s *structWrapper) length() int { return s.readWriter.(fixedReadWriter).length() } @@ -131,7 +129,9 @@ func (s *structWrapper) readFixed(b []byte, v reflect.Value) { func (s *structWrapper) writeFixed(b []byte, v reflect.Value) { s.readWriter.(fixedReadWriter).writeFixed(b, v) -} +}*/ + +var _ fixedReadWriter = (*fixedStructReadWriter)(nil) type fixedStructReadWriter struct { fixed @@ -162,6 +162,8 @@ func (s *fixedStructReadWriter) writeFixed(data []byte, v reflect.Value) { } } +var _ variableReadWriter = (*variableStructReadWriter)(nil) + type variableStructReadWriter struct { variable @@ -188,16 +190,12 @@ func (h *variableStructReadWriter) writeVariable(w io.Writer, v reflect.Value) e return nil } -func (h *variableStructReadWriter) vLength(v reflect.Value) (int, error) { +func (h *variableStructReadWriter) vLength(v reflect.Value) int { size := 0 for i, handler := range h.handlers { - l, err := handleVariableLength(handler, v.Field(i)) - if err != nil { - return 0, err - } - size += l + size += handleVariableLength(handler, v.Field(i)) } - return size, nil + return size } diff --git a/tests/errors_test.go b/tests/errors_test.go new file mode 100644 index 0000000..9d0da98 --- /dev/null +++ b/tests/errors_test.go @@ -0,0 +1,219 @@ +package tests + +import ( + "bytes" + "errors" + "math" + "testing" + + "github.com/ikkerens/ikeapack" +) + +func TestReadPointer(t *testing.T) { + defer func() { + if recover() == nil { + t.FailNow() + } + }() + + var i int + _ = ikea.Unpack(nil, i) +} + +func TestUseInt(t *testing.T) { + defer func() { + if recover() == nil { + t.FailNow() + } + }() + + var i int + _ = ikea.Unpack(nil, &i) // int is not supported +} + +func TestUseUint(t *testing.T) { + defer func() { + if recover() == nil { + t.FailNow() + } + }() + + var ui uint + _ = ikea.Unpack(nil, &ui) // uint is not supported +} + +func TestUnsupportedType(t *testing.T) { + defer func() { + if recover() == nil { + t.FailNow() + } + }() + + var ui complex64 + _ = ikea.Unpack(nil, &ui) // uint is not supported +} + +func TestVariableLengthOverflow(t *testing.T) { + overflow := new(bytes.Buffer) + _ = ikea.Pack(overflow, uint32(math.MaxInt32+1)) + var t1 struct { + Data struct{} `ikea:"compress:9"` + } + if err := ikea.Unpack(overflow, &t1); err == nil { + t.FailNow() + } + + overflow.Reset() + _ = ikea.Pack(overflow, uint32(math.MaxInt32+1)) + var t2 map[string]struct{} + if err := ikea.Unpack(overflow, &t2); err == nil { + t.FailNow() + } + + overflow.Reset() + _ = ikea.Pack(overflow, uint32(math.MaxInt32+1)) + var t3 []struct{} + if err := ikea.Unpack(overflow, &t3); err == nil { + t.FailNow() + } + + overflow.Reset() + _ = ikea.Pack(overflow, uint32(math.MaxInt32+1)) + var t4 string + if err := ikea.Unpack(overflow, &t4); err == nil { + t.FailNow() + } +} + +func TestCompressionInitError(t *testing.T) { + s1 := struct { + Data []byte `ikea:"compress:10"` + }{make([]byte, 10)} + if err := ikea.Pack(new(bytes.Buffer), &s1); err == nil { + t.Fail() + } + + s2 := struct { + Data []byte `ikea:"compress:a"` + }{make([]byte, 10)} + defer func() { + if recover() == nil { + t.FailNow() + } + }() + _ = ikea.Pack(new(bytes.Buffer), &s2) +} + +func TestInvalidUTF8(t *testing.T) { + var invalid string + b := bytes.NewBuffer([]byte{0x00, 0x00, 0x00, 0x01, 0xF1}) + if ikea.Unpack(b, &invalid) == nil { + t.FailNow() + } +} + +func TestPackFixedNil(t *testing.T) { + defer func() { + if recover() == nil { + t.FailNow() + } + }() + + s := struct { + A *uint32 + }{} + _ = ikea.Pack(new(bytes.Buffer), &s) +} + +func TestPackVariableNil(t *testing.T) { + defer func() { + if recover() == nil { + t.FailNow() + } + }() + + s := struct { + A *string + }{} + _ = ikea.Pack(new(bytes.Buffer), &s) +} + +func TestLenFixedNil(t *testing.T) { + s := struct { + A *uint32 + }{} + ikea.Len(&s) // Unlike all other nil values, this should succeed +} + +func TestLenVariableNil(t *testing.T) { + defer func() { + if recover() == nil { + t.FailNow() + } + }() + + s := struct { + A *string + }{} + ikea.Len(&s) +} + +func TestReadErrors(t *testing.T) { + e := new(errorStream) + + // Reading error tests + tst := new(testStruct) + err := errors.New("start of errors") + for err != nil { + err = ikea.Unpack(e, tst) + if err == nil && e.pass != len(testData) { + t.FailNow() + } + e.Reset() + } +} + +func TestWriteErrors(t *testing.T) { + e := new(errorStream) + + // Reading error tests + err := errors.New("start of errors") + for err != nil { + err = ikea.Pack(e, source) + if err == nil && e.pass != len(testData) { + t.FailNow() + } + e.Reset() + } +} + +type errorStream struct { + pointer int + pass int +} + +func (s *errorStream) Read(p []byte) (n int, err error) { + if s.pointer+len(p) > s.pass { + s.pointer += len(p) + return 0, errors.New("test error") + } + + copy(p, testData[s.pointer:s.pointer+len(p)]) + s.pointer += len(p) + return len(p), nil +} + +func (s *errorStream) Write(p []byte) (n int, err error) { + if s.pointer+len(p) > s.pass { + s.pointer += len(p) + return 0, errors.New("test error") + } + + s.pointer += len(p) + return len(p), nil +} + +func (s *errorStream) Reset() { + s.pass = s.pointer + s.pointer = 0 +} diff --git a/tests/full_test.go b/tests/full_test.go new file mode 100644 index 0000000..48934c3 --- /dev/null +++ b/tests/full_test.go @@ -0,0 +1,129 @@ +package tests + +import ( + "bytes" + "encoding/hex" + "fmt" + "io" + "reflect" + "strconv" + "strings" + "testing" + + "github.com/ikkerens/ikeapack" +) + +/* Tests */ + +func TestOutput(t *testing.T) { + buf := new(bytes.Buffer) + if err := ikea.Pack(buf, source); err != nil { + t.Error(err) + return + } + + result := buf.Bytes() + if len(result) != len(testData) { + fmt.Printf("Failing TestWrite, result \"%s\" length (%d) does not match test data length (%d)\n", hex.EncodeToString(result), len(result), len(testData)) + t.FailNow() + } + + // In struct creation we've added 0x4242 padding to we can split out the map. + // We need to do this because golangs map iteration is always random. + // So we'll deal with that one later in this test. + originalParts := strings.Split(hex.EncodeToString(testData), "4242") + resultParts := strings.Split(hex.EncodeToString(result), "4242") + + if originalParts[0] != resultParts[0] { + fmt.Printf("Failing TestWrite, hex output \"%s\" does not match test data slice\n", hex.EncodeToString(result)) + t.FailNow() + } + + // Instead, we treat the map a bit differently, we put it back into a buffer + buf.Reset() + if data, err := hex.DecodeString(resultParts[1]); err != nil { + fmt.Printf("Failing TestWrite, hex output \"%s\" is not a valid hex string: %s\n", resultParts[1], err.Error()) + t.FailNow() + } else { + buf.Write(data) + } + // Unpack it + var test map[string]string + if err := ikea.Unpack(buf, &test); err != nil { + fmt.Printf("Failing TestWrite, could not unpack map: %s\n", err.Error()) + t.FailNow() + } + + // And then compare it using DeepEqual + if !reflect.DeepEqual(source.TestMap, test) { + fmt.Printf("Failing TestWrite, resulting map is not equal\n") + t.FailNow() + } +} + +func TestCompleteRead(t *testing.T) { + buf := new(bytes.Buffer) + buf.Write(testData) + + tst := new(testStruct) + if err := ikea.Unpack(buf, tst); err != nil { + t.Error(err) + return + } + + compare(t, "TestBool", tst.TestBool, source.TestBool) + compare(t, "TestByte", tst.TestByte, source.TestByte) + compare(t, "TestUint8", tst.TestUint8, source.TestUint8) + compare(t, "TestUint16", tst.TestUint16, source.TestUint16) + compare(t, "TestUint32", tst.TestUint32, source.TestUint32) + compare(t, "TestUint64", tst.TestUint64, source.TestUint64) + compare(t, "TestInt8", tst.TestInt8, source.TestInt8) + compare(t, "TestInt16", tst.TestInt16, source.TestInt16) + compare(t, "TestInt32", tst.TestInt32, source.TestInt32) + compare(t, "TestInt64", tst.TestInt64, source.TestInt64) + compare(t, "TestFloat32", tst.TestFloat32, source.TestFloat32) + compare(t, "TestFloat64", tst.TestFloat64, source.TestFloat64) + compare(t, "TestString", tst.TestString, source.TestString) + compare(t, "TestSubStruct", tst.TestSubStruct.A, source.TestSubStruct.A) + compare(t, "TestInterface", tst.TestInterface.A, source.TestInterface.A) + compare(t, "TestFixedPtr", *tst.TestFixedPtr, *source.TestFixedPtr) + compare(t, "TestVariablePtr", *tst.TestVariablePtr, *source.TestVariablePtr) + for i := range source.TestSlice { + compare(t, "TestSlice["+strconv.Itoa(i)+"]", tst.TestSlice[i], source.TestSlice[i]) + } + for i := range source.TestCompression { + compare(t, "TestCompression["+strconv.Itoa(i)+"]", tst.TestCompression[i], source.TestCompression[i]) + } + for k := range source.TestMap { + compare(t, "TestMap["+k+"]", tst.TestMap[k], source.TestMap[k]) + } +} + +func TestLen(t *testing.T) { + if l := ikea.Len(source); l != len(testData) { + fmt.Printf("Failing TestLen, Len reported an incorrect value %d, should be %d", l, len(testData)) + t.FailNow() + } +} + +func compare(t *testing.T, field string, value1, value2 interface{}) { + if value1 != value2 { + t.Errorf("Failing TestCompleteRead, decoded data field '%s' with value '%v' does not match '%v'", field, value1, value2) + } +} + +type testPackerOnly struct { + A uint8 +} + +func (p *testPackerOnly) Pack(w io.Writer) error { + return ikea.Pack(w, &p.A) +} + +type testUnpackerOnly struct { + A uint8 +} + +func (p *testUnpackerOnly) Unpack(r io.Reader) error { + return ikea.Unpack(r, &p.A) +} diff --git a/tests/primitives_test.go b/tests/primitives_test.go new file mode 100644 index 0000000..e4a1b67 --- /dev/null +++ b/tests/primitives_test.go @@ -0,0 +1,105 @@ +package tests + +import ( + "bytes" + "fmt" + "math/rand" + "os" + "reflect" + "testing" + + "github.com/ikkerens/ikeapack" +) + +func TestBool(t *testing.T) { + i := rand.Int()%2 == 1 + typeTest(t, "TestBool", &i, i) +} + +func TestByte(t *testing.T) { + i := byte(rand.Int() & 0xFF) + typeTest(t, "TestByte", &i, i) +} + +func TestUint8(t *testing.T) { + i := uint8(rand.Int() & 0xFF) + typeTest(t, "TestUint8", &i, i) +} + +func TestUint16(t *testing.T) { + i := uint16(rand.Int() & 0xFFFF) + typeTest(t, "TestUint16", &i, i) +} + +func TestUint32(t *testing.T) { + i := rand.Uint32() + typeTest(t, "TestUint32", &i, i) +} + +func TestUint64(t *testing.T) { + i := rand.Uint64() + typeTest(t, "TestUint64", &i, i) +} + +func TestInt8(t *testing.T) { + i := int8(rand.Int() & 0xFF) + typeTest(t, "TestInt8", &i, i) +} + +func TestInt16(t *testing.T) { + i := int16(rand.Int() & 0xFFFF) + typeTest(t, "TestInt16", &i, i) +} + +func TestInt32(t *testing.T) { + i := rand.Int31() + typeTest(t, "TestInt32", &i, i) +} + +func TestInt64(t *testing.T) { + i := rand.Int63() + typeTest(t, "TestInt64", &i, i) +} + +func TestFloat32(t *testing.T) { + i := rand.Float32() + typeTest(t, "TestFloat32", &i, i) +} + +func TestFloat64(t *testing.T) { + i := rand.Float64() + typeTest(t, "TestFloat32", &i, i) +} + +func TestString(t *testing.T) { + const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + + b := make([]byte, rand.Intn(30)) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + s := string(b) + + typeTest(t, "TestString", &s, s) +} + +func typeTest(t *testing.T, typ string, value, compare interface{}) { + var b bytes.Buffer + + if err := ikea.Pack(&b, value); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "Failing %s, could not write value: %s\n", typ, err.Error()) + t.FailNow() + } + + target := reflect.New(reflect.TypeOf(value).Elem()) + if err := ikea.Unpack(&b, target.Interface()); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "Failing %s, could not read value: %s\n", typ, err.Error()) + t.FailNow() + } + + dereference := target.Elem().Interface() + if dereference != compare { + _, _ = fmt.Fprintf(os.Stderr, "Failing %s, %T value %+v does not match original %T %+v\n", typ, dereference, dereference, compare, compare) + t.FailNow() + } +} diff --git a/tests/testdata_test.go b/tests/testdata_test.go new file mode 100644 index 0000000..e404109 --- /dev/null +++ b/tests/testdata_test.go @@ -0,0 +1,112 @@ +package tests + +import ( + "encoding/hex" + "io" + + "github.com/ikkerens/ikeapack" +) + +/* Test types */ + +type testStruct struct { + TestBool bool + TestByte byte + TestUint8 uint8 + TestUint16 uint16 + TestUint32 uint32 + TestUint64 uint64 + TestInt8 int8 + TestInt16 int16 + TestInt32 int32 + TestInt64 int64 + TestFloat32 float32 + TestFloat64 float64 + TestString string + TestSubStruct testSubStruct + TestPackerOnly testPackerOnly + TestUnpackerOnly testUnpackerOnly + TestInterface testInterface + TestFixedPtr *uint8 + TestVariablePtr *string + TestSlice []byte + TestVariableSlice []string + TestCompression []byte `ikea:"compress:9"` + Padding uint16 // Maps randomise iteration order, we can't verify this string, so we split using this + TestMap map[string]string +} + +type testSubStruct struct { + A byte + B []testSubStruct +} + +type testInterface struct { + A int64 +} + +func (t *testInterface) Unpack(r io.Reader) error { + var temp int64 + if err := ikea.Unpack(r, &temp); err != nil { + return err + } + + t.A = temp - 10 + + return nil +} + +func (t *testInterface) Pack(w io.Writer) error { + return ikea.Pack(w, t.A+10) +} + +/* Test data */ +var source = &testStruct{ + TestBool: true, + TestByte: 0x11, + TestUint8: 0x88, + TestUint16: 0x1616, + TestUint32: 0x32323232, + TestUint64: 0x6464646464646464, + TestInt8: 0x12, + TestInt16: 0x1234, + TestInt32: 0x12345678, + TestInt64: 0x1234567812345678, + TestFloat32: 0.12345678, + TestFloat64: 0.12345678901234567890, + TestString: "amazing serialization lib", + TestSubStruct: testSubStruct{A: 0x42, B: []testSubStruct{{A: 0x24}, {A: 0x42}}}, + TestPackerOnly: testPackerOnly{A: 0x24}, + TestUnpackerOnly: testUnpackerOnly{A: 0x42}, + TestInterface: testInterface{A: 0x24}, + TestFixedPtr: makeIntPtr(), + TestVariablePtr: makeStringPtr(), + TestSlice: make([]byte, 100), + TestVariableSlice: []string{"a", "bc", "def"}, + TestCompression: make([]byte, 10000), + Padding: 0x4242, + TestMap: map[string]string{ + "keynr1": "valuenr1", + "anotherkey": "anothervalue", + }, +} +var testData, _ = hex.DecodeString("01118816163232323264646464646464641212341234567812345678123456783dfcd6e93fbf9add3746f65f00000019616d617a696e672073657269616c697a6174696f6e206c69624200000002240000000042000000002442000000000000002e660000000e49276d206120706f696e74657221000000640001081b407dd85802dbeb38c69dc23c1044dee55f51c1b63646ec3016a4e1d380ed2223f6a32f9ffa478aca0e5ab526b15e323367d48173b03f25680f1f9e9304f56f7611451992b78e1d697953fc7cd7153a4d5455565d7095d22ead5731418d1cf2180000000300000001610000000262630000000364656600000024ecc0811000000803c021649147fe5079ecfe939d03000000000000000080021f0000ffff424200000002000000066b65796e72310000000876616c75656e72310000000a616e6f746865726b65790000000c616e6f7468657276616c7565") + +func init() { + for i := range source.TestSlice { + source.TestSlice[i] = byte((i * i * i) % 0xFF) + } + for i := range source.TestCompression { + source.TestCompression[i] = 0x42 + } +} + +func makeIntPtr() *uint8 { + i := uint8(0x66) + return &i +} + +func makeStringPtr() *string { + s := "I'm a pointer!" + return &s +} diff --git a/types.go b/types.go index ebf2244..435b018 100644 --- a/types.go +++ b/types.go @@ -23,7 +23,7 @@ type fixedReadWriter interface { type variableReadWriter interface { readWriter - vLength(reflect.Value) (int, error) + vLength(reflect.Value) int readVariable(io.Reader, reflect.Value) error diff --git a/variablehandlers.go b/variablehandlers.go index c02bb11..b79dc93 100644 --- a/variablehandlers.go +++ b/variablehandlers.go @@ -43,13 +43,11 @@ func handleVariableWriter(w io.Writer, h readWriter, v reflect.Value) error { return nil } -func handleVariableLength(h readWriter, v reflect.Value) (int, error) { +func handleVariableLength(h readWriter, v reflect.Value) int { if h.isFixed() { - hl := h.(fixedReadWriter) - return hl.length(), nil + return h.(fixedReadWriter).length() } // variable - hl := h.(variableReadWriter) - return hl.vLength(v) + return h.(variableReadWriter).vLength(v) }