diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index 21b12f2bea55..c841a3e4f631 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -7,6 +7,7 @@ import ( "fmt" "strings" "sync" + "time" "github.com/go-audio/audio" "github.com/gofiber/websocket/v2" @@ -187,7 +188,6 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app log.Error().Msgf("read: %s", err.Error()) break } - log.Printf("recv: %s", msg) // Parse the incoming message var incomingMsg IncomingMessage @@ -199,6 +199,8 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app switch incomingMsg.Type { case "session.update": + log.Printf("recv: %s", msg) + // Update session configurations var sessionUpdate Session if err := json.Unmarshal(incomingMsg.Session, &sessionUpdate); err != nil { @@ -258,6 +260,8 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app session.AudioBufferLock.Unlock() case "input_audio_buffer.commit": + log.Printf("recv: %s", msg) + // Commit the audio buffer to the conversation as a new item item := &Item{ ID: generateItemID(), @@ -290,6 +294,8 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app }) case "conversation.item.create": + log.Printf("recv: %s", msg) + // Handle creating new conversation items var item Item if err := json.Unmarshal(incomingMsg.Item, &item); err != nil { @@ -315,10 +321,14 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app }) case "conversation.item.delete": + log.Printf("recv: %s", msg) + // Handle deleting conversation items // Implement deletion logic as needed case "response.create": + log.Printf("recv: %s", msg) + // Handle generating a response var responseCreate ResponseCreate if len(incomingMsg.Response) > 0 { @@ -342,6 +352,8 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app }() case "conversation.item.update": + log.Printf("recv: %s", msg) + // Handle function_call_output from the client var item Item if err := json.Unmarshal(incomingMsg.Item, &item); err != nil { @@ -366,6 +378,8 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app }) case "response.cancel": + log.Printf("recv: %s", msg) + // Handle cancellation of ongoing responses // Implement cancellation logic as needed @@ -443,12 +457,19 @@ func updateSession(session *Session, update *Session, cl *config.BackendConfigLo return nil } +const ( + minMicVolume = 450 + sendToVADDelay = time.Second + maxWhisperSegmentDuration = time.Second * 25 +) + // Placeholder function to handle VAD (Voice Activity Detection) // https://github.com/snakers4/silero-vad/tree/master/examples/go // XXX: use session.ModelInterface for VAD or hook directly VAD runtime here? func handleVAD(session *Session, conversation *Conversation, c *websocket.Conn, done chan struct{}) { vadContext, cancel := context.WithCancel(context.Background()) + //var startListening time.Time go func() { <-done @@ -466,7 +487,7 @@ func handleVAD(session *Session, conversation *Conversation, c *websocket.Conn, default: // Check if there's audio data to process session.AudioBufferLock.Lock() - if len(session.InputAudioBuffer) > 0 { + if len(session.InputAudioBuffer) > 16000 { adata := sound.BytesToInt16sLE(session.InputAudioBuffer) @@ -475,37 +496,77 @@ func handleVAD(session *Session, conversation *Conversation, c *websocket.Conn, } soundIntBuffer.Data = sound.ConvertInt16ToInt(adata) + /* if len(adata) < 16000 { + log.Debug().Msgf("audio length too small %d", len(session.InputAudioBuffer)) + session.AudioBufferLock.Unlock() + continue + } */ + + float32Data := soundIntBuffer.AsFloat32Buffer().Data + resp, err := session.ModelInterface.VAD(vadContext, &proto.VADRequest{ - Audio: soundIntBuffer.AsFloat32Buffer().Data, + Audio: float32Data, }) if err != nil { log.Error().Msgf("failed to process audio: %s", err.Error()) - sendError(c, "processing_error", "Failed to process audio", "", "") + sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "") session.AudioBufferLock.Unlock() continue } speechStart, speechEnd := float32(0), float32(0) + + /* + volume := sound.CalculateRMS16(adata) + if volume > minMicVolume { + startListening = time.Now() + } + + if time.Since(startListening) < sendToVADDelay && time.Since(startListening) < maxWhisperSegmentDuration { + log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer)) + + session.AudioBufferLock.Unlock() + log.Debug().Msg("speech is ongoing") + + continue + } + */ + + if len(resp.Segments) == 0 { + log.Debug().Msg("VAD detected no speech activity") + log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer)) + + session.InputAudioBuffer = nil + log.Debug().Msgf("audio length(after) %d", len(session.InputAudioBuffer)) + + session.AudioBufferLock.Unlock() + continue + } + + log.Debug().Msgf("VAD detected %d segments", len(resp.Segments)) + log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer)) + + speechStart = resp.Segments[0].Start + log.Debug().Msgf("speech starts at %0.2fs", speechStart) + for _, s := range resp.Segments { - log.Debug().Msgf("speech starts at %0.2fs", s.Start) - speechStart = s.Start if s.End > 0 { log.Debug().Msgf("speech ends at %0.2fs", s.End) speechEnd = s.End - } else { - continue } } - if speechEnd == 0 && speechStart != 0 { + if speechEnd == 0 { + log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer)) + session.AudioBufferLock.Unlock() - log.Debug().Msg("speech is ongoing") + log.Debug().Msg("speech is ongoing, no end found ?") continue } // Handle when input is too long without a voice activity (reset the buffer) if speechStart == 0 && speechEnd == 0 { - log.Debug().Msg("VAD detected no speech activity") + // log.Debug().Msg("VAD detected no speech activity") session.InputAudioBuffer = nil session.AudioBufferLock.Unlock() continue diff --git a/go.mod b/go.mod index 898d72717691..23eb998e2a66 100644 --- a/go.mod +++ b/go.mod @@ -104,6 +104,7 @@ require ( github.com/labstack/gommon v0.4.2 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e // indirect github.com/pion/datachannel v1.5.8 // indirect github.com/pion/dtls/v2 v2.2.12 // indirect github.com/pion/ice/v2 v2.3.34 // indirect diff --git a/pkg/sound/float32.go b/pkg/sound/float32.go index 8909bb2869cc..f42a04e53abb 100644 --- a/pkg/sound/float32.go +++ b/pkg/sound/float32.go @@ -5,14 +5,6 @@ import ( "math" ) -func BytesToFloat32Array(aBytes []byte) []float32 { - aArr := make([]float32, 3) - for i := 0; i < 3; i++ { - aArr[i] = BytesFloat32(aBytes[i*4:]) - } - return aArr -} - func BytesFloat32(bytes []byte) float32 { bits := binary.LittleEndian.Uint32(bytes) float := math.Float32frombits(bits) diff --git a/pkg/sound/int16.go b/pkg/sound/int16.go index 55e1c2f160ac..237c805ce5b5 100644 --- a/pkg/sound/int16.go +++ b/pkg/sound/int16.go @@ -1,5 +1,7 @@ package sound +import "math" + /* MIT License @@ -8,6 +10,17 @@ Copyright (c) 2024 Xbozon */ +// calculateRMS16 calculates the root mean square of the audio buffer for int16 samples. +func CalculateRMS16(buffer []int16) float64 { + var sumSquares float64 + for _, sample := range buffer { + val := float64(sample) // Convert int16 to float64 for calculation + sumSquares += val * val + } + meanSquares := sumSquares / float64(len(buffer)) + return math.Sqrt(meanSquares) +} + func ResampleInt16(input []int16, inputRate, outputRate int) []int16 { // Calculate the resampling ratio ratio := float64(inputRate) / float64(outputRate)