Skip to content

Commit

Permalink
fix(transcription): activate language detection
Browse files Browse the repository at this point in the history
Forbid transcript creation without a language.
Add `languageDetection` flag to an engine and some assertions.

Fix an issue in `whisper-ctranslate2`:
Softcatala/whisper-ctranslate2#93
  • Loading branch information
lutangar committed May 7, 2024
1 parent 67a921f commit 5fffa52
Show file tree
Hide file tree
Showing 13 changed files with 98 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ describe('Open AI Whisper transcriber', function () {
requirements: [],
type: 'binary',
binary: 'whisper',
supportedModelFormats: [ 'PyTorch' ]
supportedModelFormats: [ 'PyTorch' ],
languageDetection: true
},
createLogger(),
transcriptDirectory
Expand Down Expand Up @@ -76,12 +77,12 @@ You
})

it('May transcribe a media file using a local PyTorch model', async function () {
this.timeout(2 * 1000 * 60)
this.timeout(3 * 1000 * 60)
await transcriber.transcribe({ mediaFilePath: frVideoPath, model: TranscriptionModel.fromPath(buildAbsoluteFixturePath('transcription/models/tiny.pt')), language: 'en' })
})

it('May transcribe a media file in french', async function () {
this.timeout(2 * 1000 * 60)
this.timeout(3 * 1000 * 60)
const transcript = await transcriber.transcribe({ mediaFilePath: frVideoPath, language: 'fr', format: 'txt' })
expect(await transcript.equals(new TranscriptFile({
path: join(transcriptDirectory, 'communiquer-lors-dune-classe-transplantee.txt'),
Expand All @@ -104,8 +105,14 @@ Ensuite, il pourront lire et commenter ce de leurs camarades ou répondre aux co
)
})

it('Guesses the video language if not provided', async function () {
this.timeout(3 * 1000 * 60)
const transcript = await transcriber.transcribe({ mediaFilePath: frVideoPath })
expect(transcript.language).to.equals('fr')
})

it('May transcribe a media file in french with small model', async function () {
this.timeout(5 * 1000 * 60)
this.timeout(6 * 1000 * 60)
const transcript = await transcriber.transcribe({ mediaFilePath: frVideoPath, language: 'fr', format: 'txt', model: new WhisperBuiltinModel('small') })
expect(await transcript.equals(new TranscriptFile({
path: join(transcriptDirectory, 'communiquer-lors-dune-classe-transplantee.txt'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ describe('Linto timestamped Whisper transcriber', function () {
requirements: [],
type: 'binary',
binary: 'whisper_timestamped',
supportedModelFormats: [ 'PyTorch' ]
supportedModelFormats: [ 'PyTorch' ],
languageDetection: true
},
createLogger(),
transcriptDirectory
Expand Down Expand Up @@ -84,6 +85,7 @@ you
})

it('May transcribe a media file using a local PyTorch model file', async function () {
this.timeout(2 * 1000 * 60)
await transcriber.transcribe({ mediaFilePath: frVideoPath, model: TranscriptionModel.fromPath(buildAbsoluteFixturePath('transcription/models/tiny.pt')), language: 'en' })
})

Expand Down Expand Up @@ -124,6 +126,12 @@ Ensuite, il pourront lire et commenter ce de leur camarade, ou répondre au comm
)
})

it('Guesses the video language if not provided', async function () {
this.timeout(2 * 1000 * 60)
const transcript = await transcriber.transcribe({ mediaFilePath: frVideoPath })
expect(transcript.language).to.equals('fr')
})

it('Should produce a text transcript similar to openai-whisper implementation', async function () {
this.timeout(5 * 1000 * 60)
const transcribeArgs: WhisperTranscribeArgs = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ describe('Whisper CTranslate2 transcriber', function () {
requirements: [],
type: 'binary',
binary: 'whisper-ctranslate2',
supportedModelFormats: []
supportedModelFormats: [],
languageDetection: true
},
createLogger(),
transcriptDirectory
Expand All @@ -36,7 +37,7 @@ describe('Whisper CTranslate2 transcriber', function () {

it('Should transcribe a media file and provide a valid path to a transcript file in `vtt` format by default', async function () {
const transcript = await transcriber.transcribe({ mediaFilePath: shortVideoPath, language: 'en' })
expect(await transcript.equals(new TranscriptFile({ path: join(transcriptDirectory, 'video_short.vtt') }))).to.be.true
expect(await transcript.equals(new TranscriptFile({ path: join(transcriptDirectory, 'video_short.vtt'), language: 'en' }))).to.be.true
expect(await readFile(transcript.path, 'utf8')).to.equal(
`WEBVTT
Expand All @@ -51,7 +52,8 @@ You
const transcript = await transcriber.transcribe({ mediaFilePath: shortVideoPath, language: 'en', format: 'srt' })
expect(await transcript.equals(new TranscriptFile({
path: join(transcriptDirectory, 'video_short.srt'),
format: 'srt'
format: 'srt',
language: 'en'
}))).to.be.true

expect(await readFile(transcript.path, 'utf8')).to.equal(
Expand All @@ -67,7 +69,8 @@ You
const transcript = await transcriber.transcribe({ mediaFilePath: shortVideoPath, language: 'en', format: 'txt' })
expect(await transcript.equals(new TranscriptFile({
path: join(transcriptDirectory, 'video_short.txt'),
format: 'txt'
format: 'txt',
language: 'en'
}))).to.be.true

expect(await transcript.read()).to.equal(`You
Expand All @@ -84,7 +87,8 @@ You
})
expect(await transcript.equals(new TranscriptFile({
path: join(transcriptDirectory, 'video_short.txt'),
format: 'txt'
format: 'txt',
language: 'en'
}))).to.be.true

expect(await transcript.read()).to.equal(`You
Expand Down Expand Up @@ -114,6 +118,12 @@ Ensuite, il pourront lire et commenter ce de leur camarade, on répondra au comm
)
})

it('Guesses the video language if not provided', async function () {
this.timeout(2 * 1000 * 60)
const transcript = await transcriber.transcribe({ mediaFilePath: frVideoPath })
expect(transcript.language).to.equals('fr')
})

it('Should produce a text transcript similar to openai-whisper implementation', async function () {
this.timeout(5 * 1000 * 60)
const transcribeArgs: WhisperTranscribeArgs = {
Expand Down
3 changes: 2 additions & 1 deletion packages/transcription/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ import { OpenaiTranscriber } from '@peertube/peertube-transcription'
// create a transcriber powered by OpenAI Whisper CLI
const transcriber = new OpenaiTranscriber({
name: 'openai-whisper',
binary: 'whisper'
binary: 'whisper',
languageDetection: true,
});

const transcriptFile = await transcriber.transcribe({
Expand Down
11 changes: 4 additions & 7 deletions packages/transcription/src/abstract-transcriber.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { createLogger, Logger } from 'winston'
import short, { SUUID } from 'short-uuid'
import { join } from 'node:path'
import { existsSync } from 'node:fs'
import { PerformanceObserver } from 'node:perf_hooks'
import { root } from '@peertube/peertube-node-utils'
import { TranscriptionEngine } from './transcription-engine.js'
Expand Down Expand Up @@ -51,12 +50,10 @@ export abstract class AbstractTranscriber {
delete this.run
}

detectLanguage () {
return Promise.resolve('')
}

loadModel (model: TranscriptionModel) {
if (existsSync(model.path)) { /* empty */ }
/**
 * Ensures a transcription may proceed with the given language argument.
 *
 * If the engine cannot detect the media language on its own
 * (`engine.languageDetection` is unset), the caller must supply one explicitly.
 *
 * @param language optional ISO language code supplied by the caller
 * @throws Error when no language is given and the engine lacks language detection
 */
assertLanguageDetectionAvailable (language?: string) {
  if (!this.engine.languageDetection && !language) {
    // Fixed typo in the user-facing message: "must me" -> "must be"
    throw new Error(`Language detection isn't available in ${this.engine.name}. A language must be provided explicitly.`)
  }
}

supports (model: TranscriptionModel) {
Expand Down
13 changes: 0 additions & 13 deletions packages/transcription/src/file-utils.ts

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
export type TranscriptFormat = 'txt' | 'vtt' | 'srt'
export type TranscriptFormat = 'txt' | 'vtt' | 'srt' | 'json'

export type TranscriptFileInterface = { path: string, language?: string, format: TranscriptFormat }
4 changes: 2 additions & 2 deletions packages/transcription/src/transcript/transcript-file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import { TranscriptFileEvaluator } from './transcript-file-evaluator.js'

export class TranscriptFile implements TranscriptFileInterface {
path: string
language: string = 'en'
language: string
format: TranscriptFormat = 'vtt'

constructor ({ path, language = 'en', format = 'vtt' }: { path: string, language?: string, format?: TranscriptFormat }) {
constructor ({ path, language, format = 'vtt' }: { path: string, language: string, format?: TranscriptFormat }) {
statSync(path)

this.path = path
Expand Down
1 change: 1 addition & 0 deletions packages/transcription/src/transcription-engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ export class TranscriptionEngine {
license?: string
forgeURL?: string
supportedModelFormats: ModelFormat[]
languageDetection?: true
// There could be a default models.
// There could be a list of default models

Expand Down
19 changes: 6 additions & 13 deletions packages/transcription/src/whisper/engines.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,6 @@ export const engines: TranscriptionEngine[] = [
license : 'MIT',
supportedModelFormats: [ 'ONNX' ]
},
// {
// name : 'transformers',
// description : 'High-performance inference of OpenAI\'s Whisper automatic speech recognition model',
// type: 'binary',
// language : 'python',
// requirements : [],
// forgeURL : '',
// license : '',
// supportedModelFormats: [ 'ONNX' ]
// },
{
name: 'openai-whisper',
description: 'High-performance inference of OpenAI\'s Whisper automatic speech recognition model',
Expand All @@ -31,7 +21,8 @@ export const engines: TranscriptionEngine[] = [
binary: 'whisper',
forgeURL: 'https://github.com/openai/whisper',
license: 'MIT',
supportedModelFormats: [ 'PyTorch' ]
supportedModelFormats: [ 'PyTorch' ],
languageDetection: true
},
{
name: 'whisper-ctranslate2',
Expand All @@ -42,7 +33,8 @@ export const engines: TranscriptionEngine[] = [
binary: 'whisper-ctranslate2',
forgeURL: 'https://github.com/openai/whisper',
license: 'MIT',
supportedModelFormats: [ 'CTranslate2' ]
supportedModelFormats: [ 'CTranslate2' ],
languageDetection: true
},
{
name: 'whisper-timestamped',
Expand All @@ -53,6 +45,7 @@ export const engines: TranscriptionEngine[] = [
binary: 'whisper_timestamped',
forgeURL: 'https://github.com/openai/whisper',
license: 'MIT',
supportedModelFormats: [ 'CTranslate2' ]
supportedModelFormats: [ 'CTranslate2' ],
languageDetection: true
}
]
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import { $ } from 'execa'
import short from 'short-uuid'
import { join } from 'path'
import { lstat } from 'node:fs/promises'
import { OpenaiTranscriber, WhisperTranscribeArgs } from './openai-transcriber.js'
import { TranscriptFile } from '../../transcript/index.js'
import { getFileInfo } from '../../file-utils.js'
import { WhisperBuiltinModel } from '../whisper-builtin-model.js'

export class Ctranslate2Transcriber extends OpenaiTranscriber {
Expand All @@ -15,34 +13,35 @@ export class Ctranslate2Transcriber extends OpenaiTranscriber {
format = 'vtt',
runId = short.generate()
}: WhisperTranscribeArgs): Promise<TranscriptFile> {
this.assertLanguageDetectionAvailable(language)

// Shall we run the command with `{ shell: true }` to get the same error as in sh ?
// ex: ENOENT => Command not found
const $$ = $({ verbose: true })
const { baseName } = getFileInfo(mediaFilePath)

if (model.path) {
await lstat(model.path).then(stats => stats.isDirectory())
}

const modelArg = model.path ? [ '--model_directory', model.path ] : [ '--model', model.name ]
const languageArg = language ? [ '--language', language ] : []
const modelArgs = model.path ? [ '--model_directory', model.path ] : [ '--model', model.name ]
const languageArgs = language ? [ '--language', language ] : []

this.createRun(runId)
this.startRun()
await $$`${this.engine.binary} ${[
mediaFilePath,
...modelArg,
...modelArgs,
'--output_format',
format,
'all',
'--output_dir',
this.transcriptDirectory,
...languageArg
...languageArgs
]}`
this.stopRun()

return new TranscriptFile({
language,
path: join(this.transcriptDirectory, `${baseName}.${format}`),
language: language || await this.getDetectedLanguage(mediaFilePath),
path: this.getTranscriptFilePath(mediaFilePath, format),
format
})
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import { join } from 'path'
import { $ } from 'execa'
import short from 'short-uuid'
import { TranscriptFile } from '../../transcript/index.js'
import { TranscriptFile, TranscriptFormat } from '../../transcript/index.js'
import { AbstractTranscriber, TranscribeArgs } from '../../abstract-transcriber.js'
import { getFileInfo } from '../../file-utils.js'
import { WhisperBuiltinModel } from '../whisper-builtin-model.js'
import { TranscriptionModel } from '../../transcription-model.js'
import { readFile } from 'node:fs/promises'
import { parse } from 'node:path'

export type WhisperTranscribeArgs = Omit<TranscribeArgs, 'model'> & { model?: TranscriptionModel }

Expand All @@ -17,11 +18,12 @@ export class OpenaiTranscriber extends AbstractTranscriber {
format = 'vtt',
runId = short.generate()
}: WhisperTranscribeArgs): Promise<TranscriptFile> {
this.assertLanguageDetectionAvailable(language)

// Shall we run the command with `{ shell: true }` to get the same error as in sh ?
// ex: ENOENT => Command not found
const $$ = $({ verbose: true })
const { baseName } = getFileInfo(mediaFilePath)
const languageArg = language ? [ '--language', language ] : []
const languageArgs = language ? [ '--language', language ] : []

this.createRun(runId)
this.startRun()
Expand All @@ -30,17 +32,31 @@ export class OpenaiTranscriber extends AbstractTranscriber {
'--model',
model?.path || model.name,
'--output_format',
format,
'all',
'--output_dir',
this.transcriptDirectory,
...languageArg
...languageArgs
]}`
this.stopRun()

return new TranscriptFile({
language,
path: join(this.transcriptDirectory, `${baseName}.${format}`),
language: language || await this.getDetectedLanguage(mediaFilePath),
path: this.getTranscriptFilePath(mediaFilePath, format),
format
})
}

/**
 * Reads back the language the engine detected for this media file.
 *
 * Relies on the `json` transcript having been produced alongside the requested
 * format (the transcribe command is run with `--output_format all`), since the
 * JSON output carries a `language` field.
 *
 * @param mediaFilePath media file whose sibling JSON transcript is inspected
 * @returns the detected language code stored in the JSON transcript
 */
async getDetectedLanguage (mediaFilePath: string) {
  const transcriptMetadata = await this.readJsonTranscriptFile(mediaFilePath)

  return transcriptMetadata.language
}

/**
 * Loads and parses the JSON transcript produced for a media file.
 *
 * @param mediaFilePath media file whose `<name>.json` transcript is read
 * @returns the parsed JSON transcript content
 */
async readJsonTranscriptFile (mediaFilePath: string) {
  const rawJson = await readFile(this.getTranscriptFilePath(mediaFilePath, 'json'), 'utf8')

  return JSON.parse(rawJson)
}

/**
 * Builds the path of the transcript file the engine writes for a media file:
 * `<transcriptDirectory>/<media basename without extension>.<format>`.
 *
 * @param mediaFilePath source media file path
 * @param format transcript format used as the file extension
 * @returns absolute/relative path depending on `this.transcriptDirectory`
 */
getTranscriptFilePath (mediaFilePath: string, format: TranscriptFormat) {
  const { name: baseName } = parse(mediaFilePath)

  return join(this.transcriptDirectory, `${baseName}.${format}`)
}
}
Loading

0 comments on commit 5fffa52

Please sign in to comment.