diff --git a/internal/core/plugin_manager/aws_manager/full_duplex_simulator.go b/internal/core/plugin_manager/aws_manager/full_duplex_simulator.go index 6497596..47c0264 100644 --- a/internal/core/plugin_manager/aws_manager/full_duplex_simulator.go +++ b/internal/core/plugin_manager/aws_manager/full_duplex_simulator.go @@ -3,6 +3,7 @@ package aws_manager import ( "context" "errors" + "fmt" "io" "net" "net/http" @@ -27,8 +28,10 @@ type FullDuplexSimulator struct { baseurl *url.URL // single connection max alive time - sending_connection_max_alive_time time.Duration - receiving_connection_max_alive_time time.Duration + sending_connection_max_alive_time time.Duration + receiving_connection_max_alive_time time.Duration + target_sending_connection_max_alive_time time.Duration + target_receiving_connection_max_alive_time time.Duration // how many transactions are alive alive_transactions int32 @@ -40,7 +43,8 @@ type FullDuplexSimulator struct { connection_restarts int32 // sent bytes - sent_bytes int64 + sent_bytes int64 + current_request_sent_bytes int32 // received bytes received_bytes int64 @@ -60,6 +64,10 @@ type FullDuplexSimulator struct { // max retries max_retries int + // max sending single request sending bytes + max_sending_bytes int32 + // max receiving single request receiving bytes + max_receiving_bytes int32 // request id request_id string @@ -70,6 +78,7 @@ type FullDuplexSimulator struct { // is sending connection alive sending_connection_alive int32 sending_routine_lock sync.Mutex + sending_lock sync.Mutex virtual_sending_connection_alive int32 // receiving routine lock @@ -87,22 +96,92 @@ type FullDuplexSimulator struct { client *http.Client } +type FullDuplexSimulatorOption struct { + // MaxRetries, default 10 + MaxRetries int + // SendingConnectionMaxAliveTime, default 60s + SendingConnectionMaxAliveTime time.Duration + // TargetSendingConnectionMaxAliveTime, default 80s + TargetSendingConnectionMaxAliveTime time.Duration + // ReceivingConnectionMaxAliveTime, default 80s + ReceivingConnectionMaxAliveTime time.Duration + // TargetReceivingConnectionMaxAliveTime, default 60s + TargetReceivingConnectionMaxAliveTime time.Duration + // MaxSingleRequestSendingBytes, default 5 * 1024 * 1024 + MaxSingleRequestSendingBytes int32 + // MaxSingleRequestReceivingBytes, default 5 * 1024 * 1024 + MaxSingleRequestReceivingBytes int32 +} + +func (opt *FullDuplexSimulatorOption) defaultOption() error { + if opt.MaxRetries == 0 { + opt.MaxRetries = 10 + } + + if opt.SendingConnectionMaxAliveTime == 0 { + opt.SendingConnectionMaxAliveTime = 60 * time.Second + } + + if opt.ReceivingConnectionMaxAliveTime == 0 { + opt.ReceivingConnectionMaxAliveTime = 80 * time.Second + } + + if opt.TargetSendingConnectionMaxAliveTime == 0 { + opt.TargetSendingConnectionMaxAliveTime = 80 * time.Second + } + + if opt.TargetReceivingConnectionMaxAliveTime == 0 { + opt.TargetReceivingConnectionMaxAliveTime = 60 * time.Second + } + + if opt.MaxSingleRequestSendingBytes == 0 { + opt.MaxSingleRequestSendingBytes = 5 * 1024 * 1024 + } + + if opt.MaxSingleRequestReceivingBytes == 0 { + opt.MaxSingleRequestReceivingBytes = 5 * 1024 * 1024 + } + + // target receiving connection max alive time should be larger than receiving connection max alive time + if opt.TargetReceivingConnectionMaxAliveTime < opt.ReceivingConnectionMaxAliveTime { + return errors.New("target receiving connection max alive time should be larger than receiving connection max alive time") + } + + // sending connection max alive time should be larger than target sending connection max alive time + if opt.SendingConnectionMaxAliveTime < opt.TargetSendingConnectionMaxAliveTime { + return errors.New("sending connection max alive time should be larger than target sending connection max alive time") + } + + return nil +} + func NewFullDuplexSimulator( baseurl string, - sending_connection_max_alive_time time.Duration, - receiving_connection_max_alive_time time.Duration, + opt *FullDuplexSimulatorOption, ) (*FullDuplexSimulator, error) { u, err := url.Parse(baseurl) if err != nil { return nil, err } + if opt == nil { + opt = &FullDuplexSimulatorOption{} + } + + if err := opt.defaultOption(); err != nil { + return nil, err + } + return &FullDuplexSimulator{ - baseurl: u, - sending_connection_max_alive_time: sending_connection_max_alive_time, - receiving_connection_max_alive_time: receiving_connection_max_alive_time, - max_retries: 10, - request_id: strings.RandomString(32), + baseurl: u, + sending_connection_max_alive_time: opt.SendingConnectionMaxAliveTime, + target_sending_connection_max_alive_time: opt.TargetSendingConnectionMaxAliveTime, + receiving_connection_max_alive_time: opt.ReceivingConnectionMaxAliveTime, + target_receiving_connection_max_alive_time: opt.TargetReceivingConnectionMaxAliveTime, + max_sending_bytes: opt.MaxSingleRequestSendingBytes, + max_receiving_bytes: opt.MaxSingleRequestReceivingBytes, + max_retries: opt.MaxRetries, + request_id: strings.RandomString(32), // using keep alive to reduce the connection reset client: &http.Client{ @@ -117,21 +196,52 @@ func NewFullDuplexSimulator( }, nil } -// send data to server +// send data to server, it's thread-safe func (s *FullDuplexSimulator) Send(data []byte, timeout ...time.Duration) error { + s.sending_lock.Lock() + defer s.sending_lock.Unlock() + + // split data into max 1024 bytes + for len(data) > 0 { + chunk := data + if len(chunk) > 1024 { + chunk = chunk[:1024] + } + + data = data[len(chunk):] + if err := s.send(chunk, timeout...); err != nil { + return err + } + } + + return nil +} + +func (s *FullDuplexSimulator) send(data []byte, timeout ...time.Duration) error { + started := time.Now() + timeout_duration := time.Second * 10 if len(timeout) > 0 { timeout_duration = timeout[0] } - started := time.Now() - for time.Since(started) < timeout_duration { if atomic.LoadInt32(&s.sending_connection_alive) != 1 { time.Sleep(time.Millisecond * 50) continue } + if atomic.AddInt32(&s.current_request_sent_bytes, int32(len(data))) > s.max_sending_bytes { + // reached max sending bytes, close current connection, and start a new one + s.sending_pipe_lock.Lock() + if s.sending_pipeline != nil { + s.sending_pipeline.Close() + } + s.sending_pipe_lock.Unlock() + atomic.StoreInt32(&s.current_request_sent_bytes, 0) + continue + } + s.sending_pipe_lock.Lock() writer := s.sending_pipeline if writer == nil { @@ -146,13 +256,18 @@ func (s *FullDuplexSimulator) Send(data []byte, timeout ...time.Duration) error continue } else { atomic.AddInt64(&s.sent_bytes, int64(n)) + atomic.AddInt32(&s.current_request_sent_bytes, int32(n)) } s.sending_pipe_lock.Unlock() - return nil + break } - return errors.New("send data timeout") + if time.Since(started) > timeout_duration { + return errors.New("send data timeout") + } + + return nil } func (s *FullDuplexSimulator) On(f func(data []byte)) { @@ -224,6 +339,7 @@ func (s *FullDuplexSimulator) startSendingConnection(routine_id string) error { req.Header.Set("Content-Type", "octet-stream") req.Header.Set("Connection", "keep-alive") req.Header.Set("x-dify-plugin-request-id", s.request_id) + req.Header.Set("x-dify-plugin-max-alive-time", fmt.Sprintf("%d", s.target_receiving_connection_max_alive_time.Milliseconds())) routine.Submit(func() { s.sendingConnectionRoutine(req, routine_id) @@ -379,6 +495,7 @@ func (s *FullDuplexSimulator) receivingConnectionRoutine(routine_id string) { req.Header.Set("Content-Type", "octet-stream") req.Header.Set("Connection", "keep-alive") req.Header.Set("x-dify-plugin-request-id", s.request_id) + req.Header.Set("x-dify-plugin-max-alive-time", fmt.Sprintf("%d", s.target_sending_connection_max_alive_time.Milliseconds())) ctx, cancel := context.WithCancel(context.Background()) req = req.Clone(ctx) diff --git a/internal/core/plugin_manager/aws_manager/full_duplex_simulator_test.go b/internal/core/plugin_manager/aws_manager/full_duplex_simulator_test.go index e2d0916..eb322ca 100644 --- a/internal/core/plugin_manager/aws_manager/full_duplex_simulator_test.go +++ b/internal/core/plugin_manager/aws_manager/full_duplex_simulator_test.go @@ -45,7 +45,7 @@ func (s *S) Stop() { s.srv.Close() } -func server(recv_timeout time.Duration, send_timeout time.Duration) (*S, error) { +func server() (*S, error) { port, err := network.GetRandomPort() if err != nil { return nil, err @@ -75,6 +75,7 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (*S, error) // fmt.Println("new send request") id := c.Request.Header.Get("x-dify-plugin-request-id") + max_alive_time := c.Request.Header.Get("x-dify-plugin-max-alive-time") s.current_send_request_id = id var ch chan []byte @@ -88,7 +89,12 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (*S, error) } s.data_mu.Unlock() - time.AfterFunc(send_timeout, func() { + timeout, err := strconv.ParseInt(max_alive_time, 10, 64) + if err != nil { + timeout = 60 + } + + time.AfterFunc(time.Millisecond*time.Duration(timeout), func() { c.Request.Body.Close() }) @@ -118,6 +124,7 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (*S, error) // fmt.Println("new recv request") id := ctx.Request.Header.Get("x-dify-plugin-request-id") + max_alive_time := ctx.Request.Header.Get("x-dify-plugin-max-alive-time") s.current_recv_request_id = id var ch chan []byte @@ -137,7 +144,12 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (*S, error) ctx.Writer.Write([]byte("pong\n")) ctx.Writer.Flush() - timer := time.NewTimer(recv_timeout) + timeout, err := strconv.ParseInt(max_alive_time, 10, 64) + if err != nil { + timeout = 60 + } + + timer := time.NewTimer(time.Millisecond * time.Duration(timeout)) for { select { @@ -167,7 +179,7 @@ func TestFullDuplexSimulator_SingleSendAndReceive(t *testing.T) { log.SetShowLog(false) defer log.SetShowLog(true) - srv, err := server(time.Second*100, time.Second*100) + srv, err := server() if err != nil { t.Fatal(err) } @@ -175,7 +187,16 @@ func TestFullDuplexSimulator_SingleSendAndReceive(t *testing.T) { time.Sleep(time.Second) - simulator, err := NewFullDuplexSimulator(srv.url, time.Second*100, time.Second*100) + simulator, err := NewFullDuplexSimulator( + srv.url, &FullDuplexSimulatorOption{ + SendingConnectionMaxAliveTime: time.Second * 100, + ReceivingConnectionMaxAliveTime: time.Second * 100, + TargetSendingConnectionMaxAliveTime: time.Second * 99, + TargetReceivingConnectionMaxAliveTime: time.Second * 101, + MaxSingleRequestSendingBytes: 1024 * 1024, + MaxSingleRequestReceivingBytes: 1024 * 1024, + }, + ) if err != nil { t.Fatal(err) } @@ -222,7 +243,7 @@ func TestFullDuplexSimulator_AutoReconnect(t *testing.T) { go func() { defer wg.Done() - srv, err := server(time.Millisecond*700, time.Second*10) + srv, err := server() if err != nil { t.Fatal(err) } @@ -230,7 +251,16 @@ func TestFullDuplexSimulator_AutoReconnect(t *testing.T) { time.Sleep(time.Second) - simulator, err := NewFullDuplexSimulator(srv.url, time.Millisecond*700, time.Second*10) + simulator, err := NewFullDuplexSimulator( + srv.url, &FullDuplexSimulatorOption{ + SendingConnectionMaxAliveTime: time.Millisecond * 700, + TargetSendingConnectionMaxAliveTime: time.Millisecond * 700, + ReceivingConnectionMaxAliveTime: time.Millisecond * 10000, + TargetReceivingConnectionMaxAliveTime: time.Millisecond * 10000, + MaxSingleRequestSendingBytes: 1024 * 1024, + MaxSingleRequestReceivingBytes: 1024 * 1024, + }, + ) if err != nil { t.Fatal(err) } @@ -266,12 +296,15 @@ func TestFullDuplexSimulator_AutoReconnect(t *testing.T) { if l != 3000*5 { sent, received, restarts := simulator.GetStats() t.Errorf(fmt.Sprintf("expected: %d, actual: %d, sent: %d, received: %d, restarts: %d", 3000*5, l, sent, received, restarts)) + server_recv_count := srv.recv_count + server_send_count := srv.send_count + t.Errorf(fmt.Sprintf("server recv count: %d, server send count: %d", server_recv_count, server_send_count)) // to find which one is missing - for i := 0; i < 3000; i++ { - if !strings.Contains(recved.String(), fmt.Sprintf("%05d", i)) { - t.Errorf(fmt.Sprintf("missing: %d", i)) - } - } + // for i := 0; i < 3000; i++ { + // if !strings.Contains(recved.String(), fmt.Sprintf("%05d", i)) { + // t.Errorf(fmt.Sprintf("missing: %d", i)) + // } + // } } }() } @@ -295,7 +328,7 @@ func TestFullDuplexSimulator_MultipleTransactions(t *testing.T) { go func() { defer w.Done() - srv, err := server(time.Millisecond*700, time.Second*10) + srv, err := server() if err != nil { t.Fatal(err) } @@ -303,7 +336,16 @@ func TestFullDuplexSimulator_MultipleTransactions(t *testing.T) { time.Sleep(time.Second) - simulator, err := NewFullDuplexSimulator(srv.url, time.Millisecond*700, time.Second*10) + simulator, err := NewFullDuplexSimulator( + srv.url, &FullDuplexSimulatorOption{ + SendingConnectionMaxAliveTime: time.Millisecond * 700, + TargetSendingConnectionMaxAliveTime: time.Millisecond * 700, + ReceivingConnectionMaxAliveTime: time.Millisecond * 1000, + TargetReceivingConnectionMaxAliveTime: time.Millisecond * 1000, + MaxSingleRequestSendingBytes: 1024 * 1024, + MaxSingleRequestReceivingBytes: 1024 * 1024, + }, + ) if err != nil { t.Fatal(err) } @@ -402,3 +444,55 @@ func TestFullDuplexSimulator_MultipleTransactions(t *testing.T) { w.Wait() } + +func TestFullDuplexSimulator_SendLargeData(t *testing.T) { + log.SetShowLog(false) + defer log.SetShowLog(true) + + srv, err := server() + if err != nil { + t.Fatal(err) + } + defer srv.Stop() + + time.Sleep(time.Second) + + l := 0 + + simulator, err := NewFullDuplexSimulator( + srv.url, &FullDuplexSimulatorOption{ + SendingConnectionMaxAliveTime: time.Millisecond * 700, + TargetSendingConnectionMaxAliveTime: time.Millisecond * 700, + ReceivingConnectionMaxAliveTime: time.Millisecond * 1000, + TargetReceivingConnectionMaxAliveTime: time.Millisecond * 1000, + MaxSingleRequestSendingBytes: 5 * 1024 * 1024, + MaxSingleRequestReceivingBytes: 5 * 1024 * 1024, + }, + ) + + if err != nil { + t.Fatal(err) + } + + simulator.On(func(data []byte) { + l += len(data) + }) + + done, err := simulator.StartTransaction() + if err != nil { + t.Fatal(err) + } + defer done() + + for i := 0; i < 300; i++ { // 300MB, this process should be done in 20 seconds + if err := simulator.Send([]byte(strings.Repeat("a", 1024*1024))); err != nil { + t.Fatal(err) + } + } + + time.Sleep(time.Second * 1) + + if l != 300*1024*1024 { // 300MB + t.Fatal(fmt.Sprintf("expected: %d, actual: %d", 300*1024*1024, l)) + } +}