Skip to content

Commit

Permalink
more robust approach
Browse files Browse the repository at this point in the history
Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler committed Jan 14, 2025
1 parent 9a09820 commit f272605
Showing 1 changed file with 40 additions and 22 deletions.
62 changes: 40 additions & 22 deletions core/http/endpoints/openai/realtime.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package openai

import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"os"
"strings"
"sync"
"time"

"github.com/go-audio/wav"

"github.com/go-audio/audio"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
Expand Down Expand Up @@ -488,21 +492,8 @@ func updateSession(session *Session, update *Session, cl *config.BackendConfigLo
}

const (
minMicVolume = 450
sendToVADDelay = time.Second
)

type VADState int

const (
StateSilence VADState = iota
StateSpeaking
)

const (
// tune these thresholds to taste
SpeechFramesThreshold = 3 // must see X consecutive speech results to confirm "start"
SilenceFramesThreshold = 5 // must see X consecutive silence results to confirm "end"
sendToVADDelay = 2 * time.Second
silenceThreshold = 2 * time.Second
)

// handleVAD is a goroutine that listens for audio data from the client,
Expand Down Expand Up @@ -534,14 +525,18 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
copy(allAudio, session.InputAudioBuffer)
session.AudioBufferLock.Unlock()

// 2) If there's no audio at all, just continue
if len(allAudio) == 0 {
// 2) If there's no audio at all, or too little audio buffered so far, just continue
if len(allAudio) == 0 || len(allAudio) < 32000 {
continue
}

// 3) Run VAD on the entire audio so far
segments, err := runVAD(vadContext, session, allAudio)
if err != nil {
if err.Error() == "unexpected speech end" {
log.Debug().Msg("VAD cancelled")
continue
}
log.Error().Msgf("failed to process audio: %s", err.Error())
sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
// handle or log error, continue
Expand All @@ -550,7 +545,7 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio

segCount := len(segments)

if len(segments) == 0 && !speaking && time.Since(timeOfLastNewSeg) > 1*time.Second {
if len(segments) == 0 && !speaking && time.Since(timeOfLastNewSeg) > silenceThreshold {
// no speech detected, and we haven't seen a new segment for longer than silenceThreshold
// clean up input
session.AudioBufferLock.Lock()
Expand All @@ -569,8 +564,11 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
}

// 5) If speaking, but we haven't seen a new segment for longer than sendToVADDelay => finalize
if speaking && time.Since(timeOfLastNewSeg) > 1*time.Second {
if speaking && time.Since(timeOfLastNewSeg) > sendToVADDelay {
log.Debug().Msgf("Detected end of speech segment")
session.AudioBufferLock.Lock()
session.InputAudioBuffer = nil
session.AudioBufferLock.Unlock()
// user has presumably stopped talking
commitUtterance(allAudio, cfg, evaluator, session, conv, c)
// reset state
Expand Down Expand Up @@ -608,18 +606,38 @@ func commitUtterance(utt []byte, cfg *config.BackendConfig, evaluator *templates
Item: item,
})

// Optionally trigger the response generation
// save chunk to disk
f, err := os.CreateTemp("", "audio-*.wav")
if err != nil {
log.Error().Msgf("failed to create temp file: %s", err.Error())
return
}
defer f.Close()
//defer os.Remove(f.Name())
log.Debug().Msgf("Writing to %s\n", f.Name())

f.Write(utt)

Check warning

Code scanning / gosec

Errors unhandled. Warning

Errors unhandled.
f.Sync()

Check warning

Code scanning / gosec

Errors unhandled. Warning

Errors unhandled.

// trigger the response generation
generateResponse(cfg, evaluator, session, conv, ResponseCreate{}, c, websocket.TextMessage)
}

// runVAD is a helper that calls your model's VAD method, returning
// runVAD is a helper that calls the model's VAD method, returning
// the detected speech segments (none when no speech is detected)
func runVAD(ctx context.Context, session *Session, chunk []byte) ([]*proto.VADSegment, error) {

adata := sound.BytesToInt16sLE(chunk)

// Resample from 24kHz to 16kHz
// adata = sound.ResampleInt16(adata, 24000, 16000)
adata = sound.ResampleInt16(adata, 24000, 16000)

dec := wav.NewDecoder(bytes.NewReader(chunk))
dur, err := dec.Duration()
if err != nil {
fmt.Printf("failed to get duration: %s\n", err)
}
fmt.Printf("duration: %s\n", dur)

soundIntBuffer := &audio.IntBuffer{
Format: &audio.Format{SampleRate: 16000, NumChannels: 1},
Expand Down

0 comments on commit f272605

Please sign in to comment.