Skip to content

Commit

Permalink
feat: implement maxEdgeSilence for ssml
Browse files Browse the repository at this point in the history
  • Loading branch information
airenas committed Oct 31, 2024
1 parent a1f37e7 commit 2fab707
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 23 deletions.
51 changes: 29 additions & 22 deletions internal/pkg/processor/joinAudio.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func join(parts []*synthesizer.TTSDataPart, suffix []byte, maxEdgeSilenceMilis i
endDurHops := part.DefaultSilence / 2
if maxEdgeSilenceMilis > -1 {
endDurHops = min(endDurHops, toHops(maxEdgeSilenceMilis, part.Step, res.sampleRate()))
}
}
_, endSil, _ = calcPauseWithEnds(0, endSil, endDurHops)
}
} else if i < len(parts)-1 {
Expand Down Expand Up @@ -160,14 +160,6 @@ func join(parts []*synthesizer.TTSDataPart, suffix []byte, maxEdgeSilenceMilis i
return bufRes.String(), float64(res.size) / float64(bitsRate), res.sampleRate(), nil
}

// toHops converts a duration in milliseconds into a hop count, computed as
// maxEdgeSilenceMilis * sampleRate / 1000 / step (presumably sampleRate is in
// samples per second and step is the number of samples per hop - confirm
// against the callers).
//
// The value is rounded to the nearest whole hop. Plain truncation would
// systematically under-count (e.g. 30ms at 22050/256 is 2.58 hops and must
// become 3, not 2), which disagrees with the rounding used for pauses.
//
// A zero step returns 0 so that the division is never performed by zero.
func toHops(maxEdgeSilenceMilis int64, step int, sampleRate uint32) int {
	if step == 0 {
		return 0
	}
	samples := float64(maxEdgeSilenceMilis) * float64(sampleRate) / 1000
	return int(math.Round(samples / float64(step)))
}

func calculateDurations(aLen int, samplesPerSec uint32) time.Duration {
if samplesPerSec == 0 {
return 0
Expand Down Expand Up @@ -269,7 +261,7 @@ func (p *joinSSMLAudio) Process(data *synthesizer.TTSData) error {
}
}
}
data.Audio, data.AudioLenSeconds, data.SampleRate, err = joinSSML(data, suffix)
data.Audio, data.AudioLenSeconds, data.SampleRate, err = joinSSML(data, suffix, data.Input.MaxEdgeSilenceMillis)
if err != nil {
return errors.Wrap(err, "can't join audio")
}
Expand All @@ -289,17 +281,19 @@ type nextWriteData struct {
durationAdd time.Duration // time to add to the part
}

func joinSSML(data *synthesizer.TTSData, suffix []byte) (string, float64 /*sampleRate*/, uint32, error) {
func joinSSML(data *synthesizer.TTSData, suffix []byte, maxEdgeSilenceMillis int64) (string, float64 /*len*/, uint32 /*sampleRate*/, error) {
res := &wavWriter{}
wd := &nextWriteData{}
wd.pause = time.Duration(0)
writeF := func(part *synthesizer.TTSDataPart) error {
first := true
writeF := func(part *synthesizer.TTSDataPart, last bool) error {
step, defaultSil, pause := 0, 0, time.Duration(0)
endSil, startSil := 0, 0
var decoded []byte
var err error
if part != nil {
decoded, err = base64.StdEncoding.DecodeString(part.Audio)
res.init(decoded)
if err != nil {
return err
}
Expand All @@ -310,19 +304,24 @@ func joinSSML(data *synthesizer.TTSData, suffix []byte) (string, float64 /*sampl
res.header = wav.TakeHeader(decoded)
}
startSil = getStartSilSize(part.TranscribedSymbols, part.Durations)
// endSil = getEndSilSize(part.TranscribedSymbols, part.Durations)
step = part.Step
defaultSil = part.DefaultSilence
if first && maxEdgeSilenceMillis > -1 {
defaultSil = min(defaultSil, toHops(maxEdgeSilenceMillis, part.Step, res.sampleRate()))
}
first = false
}
if wd.part != nil {
endSil = getEndSilSize(wd.part.TranscribedSymbols, wd.part.Durations)
// startSil = getStartSilSize(wd.part.TranscribedSymbols, wd.part.Durations)
step = wd.part.Step
defaultSil = wd.part.DefaultSilence
if last && maxEdgeSilenceMillis > -1 {
defaultSil = min(defaultSil, toHops(maxEdgeSilenceMillis, wd.part.Step, res.sampleRate()))
}
}
pause = 0
if wd.isPause {
startSil, endSil, pause = fixPause(startSil, endSil, wd.pause, step)
startSil, endSil, pause = fixPause(startSil, endSil, wd.pause, step, res.sampleRate())
} else {
startSil, endSil, _ = calcPauseWithEnds(startSil, endSil, defaultSil)
}
Expand Down Expand Up @@ -366,14 +365,14 @@ func joinSSML(data *synthesizer.TTSData, suffix []byte) (string, float64 /*sampl
wd.isPause = true
case synthesizer.SSMLText:
for _, part := range dp.Parts {
err := writeF(part)
err := writeF(part, false /*last*/)
if err != nil {
return "", 0, 0, err
}
}
}
}
if err := writeF(nil); err != nil {
if err := writeF(nil, suffix == nil /*last*/); err != nil {
return "", 0, 0, err
}
if res.size == 0 {
Expand All @@ -400,7 +399,7 @@ func joinSSML(data *synthesizer.TTSData, suffix []byte) (string, float64 /*sampl
if bitsRate == 0 {
return "", 0, 0, errors.New("can't extract bits rate from header")
}
return bufRes.String(), float64(res.size) / float64(bitsRate), wav.GetSampleRate(res.header), nil
return bufRes.String(), float64(res.size) / float64(bitsRate), res.sampleRate(), nil
}

func calcPauseWithEnds(s1, s2, pause int) (int, int, int) {
Expand All @@ -415,21 +414,29 @@ func calcPauseWithEnds(s1, s2, pause int) (int, int, int) {
return s1 - max(pause-r, 0), s2 - r, 0
}

func fixPause(s1, s2 int, pause time.Duration, step int) (int, int, time.Duration) {
func fixPause(s1, s2 int, pause time.Duration, step int, sampleRate uint32) (int, int, time.Duration) {
if step == 0 {
return s1, s2, pause
}
millisInHops := float64(step) / float64(22.050)
pauseHops := int(math.Round(float64(pause.Milliseconds()) / millisInHops))
millisInHops := 1000 * float64(step) / float64(sampleRate)
pauseHops := toHops(pause.Milliseconds(), step, sampleRate)
r1, r2, rp := calcPauseWithEnds(s1, s2, pauseHops)
return r1, r2, time.Millisecond * time.Duration(int(float64(rp)*millisInHops))
}

// toHops converts millis (a duration in milliseconds) into a hop count:
// millis * sampleRate / 1000 gives the length in samples, which is then
// divided by step and rounded to the nearest integer.
// It returns 0 when step is 0, so no division by zero can happen.
func toHops(millis int64, step int, sampleRate uint32) int {
	if step == 0 {
		return 0
	}
	samples := float64(millis) * float64(sampleRate) / 1000
	hops := math.Round(samples / float64(step))
	return int(hops)
}

func appendPause(res *wavWriter, pause time.Duration) error {
if res.header == nil {
return errors.New("no wav data before pause")
}
c, err := writePause(&res.buf, wav.GetSampleRate(res.header), wav.GetBitsPerSample(res.header), pause)
c, err := writePause(&res.buf, res.sampleRate(), res.bitsPerSample(), pause)
if err != nil {
return err
}
Expand Down
29 changes: 28 additions & 1 deletion internal/pkg/processor/joinAudio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ func Test_fixPause(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1, got2 := fixPause(tt.args.s1, tt.args.s2, tt.args.pause, tt.args.step)
got, got1, got2 := fixPause(tt.args.s1, tt.args.s2, tt.args.pause, tt.args.step, 22050)
if got != tt.want {
t.Errorf("fixPause() got = %v, want %v", got, tt.want)
}
Expand Down Expand Up @@ -466,3 +466,30 @@ func Test_isSil(t *testing.T) {
})
}
}

// Test_toHops verifies the milliseconds-to-hops conversion for a fixed
// step of 256 and a sample rate of 22050, including the cases where the
// fractional hop count has to be rounded down (40ms) and up (30ms).
func Test_toHops(t *testing.T) {
	// every case uses the same step and sample rate
	const (
		hop  = 256
		rate = uint32(22050)
	)
	cases := []struct {
		name   string
		millis int64
		want   int
	}{
		{"1s", 1000, 86},
		{"0.5s", 500, 43},
		{"0s", 0, 0},
		{"10ms", 10, 1},
		{"40ms round", 40, 3},
		{"30ms round", 30, 3},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := toHops(tc.millis, hop, rate)
			if got != tc.want {
				t.Errorf("toHops() = %v, want %v", got, tc.want)
			}
		})
	}
}

0 comments on commit 2fab707

Please sign in to comment.