diff --git a/pkg/app/server/binding/internal/decoder/decoder.go b/pkg/app/server/binding/internal/decoder/decoder.go index 0bd13442a..bdc85071d 100644 --- a/pkg/app/server/binding/internal/decoder/decoder.go +++ b/pkg/app/server/binding/internal/decoder/decoder.go @@ -103,7 +103,7 @@ func GetReqDecoder(rt reflect.Type, byTag string, config *DecodeConfig) (Decoder }, needValidate, nil } -func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName string, byTag string, config *DecodeConfig) ([]fieldDecoder, bool, error) { +func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName, byTag string, config *DecodeConfig) ([]fieldDecoder, bool, error) { for field.Type.Kind() == reflect.Ptr { field.Type = field.Type.Elem() } diff --git a/pkg/route/engine_test.go b/pkg/route/engine_test.go index 37a154bc2..b3e0adb30 100644 --- a/pkg/route/engine_test.go +++ b/pkg/route/engine_test.go @@ -55,6 +55,7 @@ import ( "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app/server/binding" + "github.com/cloudwego/hertz/pkg/app/server/registry" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" @@ -63,6 +64,7 @@ import ( "github.com/cloudwego/hertz/pkg/network/standard" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" + "github.com/cloudwego/hertz/pkg/protocol/suite" "github.com/cloudwego/hertz/pkg/route/param" ) @@ -854,3 +856,176 @@ func TestCustomValidator(t *testing.T) { }) performRequest(e, "GET", "/validate?a=2") } + +var errTestDeregsitry = fmt.Errorf("test deregsitry error") + +type mockDeregsitryErr struct{} + +var _ registry.Registry = &mockDeregsitryErr{} + +func (e mockDeregsitryErr) Register(*registry.Info) error { + return nil +} + +func (e mockDeregsitryErr) Deregister(*registry.Info) error { + return errTestDeregsitry +} + +func TestEngineShutdown(t *testing.T) { + defaultTransporter = standard.NewTransporter + mockCtxCallback := func(ctx context.Context) {} + // Test case 1: serve not running error + engine := NewEngine(config.NewOptions(nil)) + ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second) + defer cancel1() + err := engine.Shutdown(ctx1) + assert.DeepEqual(t, errStatusNotRunning, err) + + // Test case 2: serve successfully running and shutdown + engine = NewEngine(config.NewOptions(nil)) + engine.OnShutdown = []CtxCallback{mockCtxCallback} + go func() { + engine.Run() + }() + // wait for engine to start + time.Sleep(100 * time.Millisecond) + + ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) + defer cancel2() + err = engine.Shutdown(ctx2) + assert.Nil(t, err) + assert.DeepEqual(t, statusClosed, atomic.LoadUint32(&engine.status)) + + // Test case 3: serve successfully running and shutdown with deregistry error + engine = NewEngine(config.NewOptions(nil)) + engine.OnShutdown = []CtxCallback{mockCtxCallback} + engine.options.Registry = &mockDeregsitryErr{} + go func() { + engine.Run() + }() + // wait for engine to start + time.Sleep(100 * time.Millisecond) + + ctx3, cancel3 := context.WithTimeout(context.Background(), time.Second) + defer cancel3() + err = engine.Shutdown(ctx3) + assert.DeepEqual(t, errTestDeregsitry, err) + assert.DeepEqual(t, statusShutdown, atomic.LoadUint32(&engine.status)) +} + +type mockStreamer struct{} + +type mockProtocolServer struct{} + +func (s *mockStreamer) Serve(c context.Context, conn network.StreamConn) error { + return nil +} + +func (s *mockProtocolServer) Serve(c context.Context, conn network.Conn) error { + return nil +} + +type mockStreamConn struct { + network.StreamConn + version string +} + +var _ network.StreamConn = &mockStreamConn{} + +func (m *mockStreamConn) GetVersion() uint32 { + return network.Version1 +} + +func TestEngineServeStream(t *testing.T) { + engine := &Engine{ + options: &config.Options{ + ALPN: true, + TLS: &tls.Config{}, + }, + protocolStreamServers: map[string]protocol.StreamServer{ + suite.HTTP3: &mockStreamer{}, + }, + } + + // Test ALPN path + conn := &mockStreamConn{version: suite.HTTP3} + err := engine.ServeStream(context.Background(), conn) + assert.Nil(t, err) + + // Test default path + engine.options.ALPN = false + conn = &mockStreamConn{} + err = engine.ServeStream(context.Background(), conn) + assert.Nil(t, err) + + // Test unsupported protocol + engine.protocolStreamServers = map[string]protocol.StreamServer{} + conn = &mockStreamConn{} + err = engine.ServeStream(context.Background(), conn) + assert.DeepEqual(t, errs.ErrNotSupportProtocol, err) +} + +func TestEngineServe(t *testing.T) { + engine := NewEngine(config.NewOptions(nil)) + engine.protocolServers[suite.HTTP1] = &mockProtocolServer{} + engine.protocolServers[suite.HTTP2] = &mockProtocolServer{} + + // test H2C path + ctx := context.Background() + conn := mock.NewConn("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + engine.options.H2C = true + err := engine.Serve(ctx, conn) + assert.Nil(t, err) + + // test ALPN path + ctx = context.Background() + conn = mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") + engine.options.H2C = false + engine.options.ALPN = true + engine.options.TLS = &tls.Config{} + err = engine.Serve(ctx, conn) + assert.Nil(t, err) + + // test HTTP1 path + engine.options.ALPN = false + err = engine.Serve(ctx, conn) + assert.Nil(t, err) +} + +func TestOndata(t *testing.T) { + ctx := context.Background() + engine := NewEngine(config.NewOptions(nil)) + + // test stream conn + streamConn := &mockStreamConn{version: suite.HTTP3} + engine.protocolStreamServers[suite.HTTP3] = &mockStreamer{} + err := engine.onData(ctx, streamConn) + assert.Nil(t, err) + + // test conn + conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") + engine.protocolServers[suite.HTTP1] = &mockProtocolServer{} + err = engine.onData(ctx, conn) + assert.Nil(t, err) +} + +func TestAcquireHijackConn(t *testing.T) { + engine := &Engine{ + NoHijackConnPool: false, + } + // test conn pool + conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") + hijackConn := engine.acquireHijackConn(conn) + assert.NotNil(t, hijackConn) + assert.NotNil(t, hijackConn.Conn) + assert.DeepEqual(t, engine, hijackConn.e) + assert.DeepEqual(t, conn, hijackConn.Conn) + + // test no conn pool + engine.NoHijackConnPool = true + hijackConn = engine.acquireHijackConn(conn) + assert.NotNil(t, hijackConn) + assert.NotNil(t, hijackConn.Conn) + assert.DeepEqual(t, engine, hijackConn.e) + assert.DeepEqual(t, conn, hijackConn.Conn) +}