Skip to content

Commit

Permalink
feat(transcription): extract run concept into its own class and run more benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
lutangar committed May 2, 2024
1 parent 22e97df commit c5eb336
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 51 deletions.
53 changes: 31 additions & 22 deletions packages/tests/src/transcription/benchmark.spec.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { createLogger } from 'winston'
import short, { UUID } from 'short-uuid'
import { performance, PerformanceObserver } from 'node:perf_hooks'
// import { CpuInfo, CpuUsage } from 'node:os'
import { rm, mkdir } from 'node:fs/promises'
Expand All @@ -8,12 +9,9 @@ import {
transcriberFactory,
TranscriptFile,
TranscriptFileEvaluator,
TranscriptionEngine
TranscriptionEngine, TranscriptionRun
} from '@peertube/peertube-transcription'

const WER_TOLERANCE = 0.01
const CER_TOLERANCE = 0.001

interface TestResult {
uuid: string
WER: number
Expand All @@ -29,10 +27,12 @@ interface TestResult {
// memoryUsages: Record<number, MemoryUsage> // https://nodejs.org/docs/latest-v18.x/api/process.html#processmemoryusage
}

const benchmarkReducer = (benchmark: Record<string, Partial<TestResult>> = {}, engineName: string, testResult: Partial<TestResult>) => ({
type Benchmark = Record<UUID, Partial<TestResult>>

const benchmarkReducer = (benchmark: Benchmark = {}, uuid: string, testResult: Partial<TestResult>) => ({
...benchmark,
[engineName]: {
...benchmark[engineName],
[uuid]: {
...benchmark[uuid],
...testResult
}
})
Expand All @@ -57,6 +57,10 @@ describe('Transcribers benchmark', function () {
'whisper-ctranslate2',
'whisper-timestamped'
]
const models = [
'tiny',
'small'
]

const transcriptDirectory = buildAbsoluteFixturePath('transcription/benchmark/')
const mediaFilePath = buildAbsoluteFixturePath('transcription/videos/communiquer-lors-dune-classe-transplantee.mp4')
Expand All @@ -75,10 +79,9 @@ describe('Transcribers benchmark', function () {
items
.getEntries()
.forEach((entry) => {
const engineName = transcribers.find(transcriberName => entry.name.includes(transcriberName))
const { uuid } = TranscriptionRun.extractFromId(entry.name)

benchmark = benchmarkReducer(benchmark, engineName, {
uuid: entry.name,
benchmark = benchmarkReducer(benchmark, uuid, {
duration: entry.duration
})
})
Expand All @@ -87,23 +90,29 @@ describe('Transcribers benchmark', function () {
})

transcribers.forEach(function (transcriberName) {
it(`Run ${transcriberName} transcriber benchmark without issue`, async function () {
this.timeout(45000)
describe(`Creates a ${transcriberName} transcriber for the benchmark`, function () {
const transcriber = transcriberFactory.createFromEngineName(
transcriberName,
createLogger(),
transcriptDirectory
)
const model = { name: 'tiny' }
const transcriptFile = await transcriber.transcribe(mediaFilePath, model, 'fr', 'txt')
const evaluator = new TranscriptFileEvaluator(referenceTranscriptFile, transcriptFile)
await new Promise(resolve => setTimeout(resolve, 1))

benchmark = benchmarkReducer(benchmark, transcriberName, {
engine: transcriber.engine,
WER: await evaluator.wer(),
CER: await evaluator.cer(),
model: model.name

models.forEach((modelName) => {
it(`Run ${transcriberName} transcriber benchmark with ${modelName} model`, async function () {
this.timeout(1000000)
const model = { name: modelName }
const uuid = short.uuid()
const transcriptFile = await transcriber.transcribe(mediaFilePath, model, 'fr', 'txt', uuid)
const evaluator = new TranscriptFileEvaluator(referenceTranscriptFile, transcriptFile)
await new Promise(resolve => setTimeout(resolve, 1))

benchmark = benchmarkReducer(benchmark, uuid, {
engine: transcriber.engine,
WER: await evaluator.wer(),
CER: await evaluator.cer(),
model: model.name
})
})
})
})
})
Expand Down
13 changes: 10 additions & 3 deletions packages/tests/src/transcription/transcription-run.spec.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
/* eslint-disable @typescript-eslint/no-unused-expressions */
import { expect } from 'chai'
import { TranscriptionRun } from '@peertube/peertube-transcription'
import { UUID } from 'short-uuid'

describe('Transcription run', function () {
const supposedlyValidIds = [
'a44521d0-0fb8-4ade-8002-3385545c3318_openai-whisper_tiny',
'a44521d0-0fb8-4ade-8002-3385545c3318_openai-whisper_openai/tiny'
'a44521d0-0fb8-4ade-8002-3385545c3318_openai-whisper_openai/tiny',
'0f229848-b709-4373-a49c-80dcc0d39e2a_whisper-ctranslate2_tiny'
]

it(`matches the list of supposedly valid ids`, function () {
supposedlyValidIds.forEach((id) => {
expect(id.match(TranscriptionRun.RUN_ID_MASK)).to.be.ok
expect(TranscriptionRun.extractFromId(id)).to.be.ok
})
})

Expand All @@ -31,17 +34,21 @@ describe('Transcription run', function () {
})

it(`extracts information from a run id`, function () {
// Because it's a "Branded primitive"
// https://github.com/microsoft/TypeScript/wiki/FAQ#can-i-make-a-type-alias-nominal
const expectedUuid = 'a44521d0-0fb8-4ade-8002-3385545c3318' as UUID
const runId = TranscriptionRun.createId({
name: 'engine-name',
binary: '/bin/engine-name',
requirements: [],
type: 'binary',
supportedModelFormats: []
}, { name: 'openai/tiny' })
}, { name: 'openai/tiny' }, expectedUuid)

expect(runId.match(TranscriptionRun.RUN_ID_MASK)).to.be.ok

const { engineName, modelName } = TranscriptionRun.extractFromId(runId)
const { uuid, engineName, modelName } = TranscriptionRun.extractFromId(runId)
expect(uuid).to.equals(expectedUuid)
expect(engineName).to.equals('engine-name')
expect(modelName).to.equals('openai/tiny')

Expand Down
13 changes: 9 additions & 4 deletions packages/transcription/src/abstract-transcriber.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { createLogger, Logger } from 'winston'
import short, { UUID } from 'short-uuid'
import { join } from 'node:path'
import { existsSync } from 'node:fs'
import { PerformanceObserver } from 'node:perf_hooks'
import { createLogger, Logger } from 'winston'
import { root } from '@peertube/peertube-node-utils'
import { TranscriptionEngine } from './transcription-engine.js'
import { TranscriptionModel } from './transcription-model.js'
Expand Down Expand Up @@ -29,8 +30,11 @@ export abstract class AbstractTranscriber {
this.performanceObserver = performanceObserver
}

startRun (model: TranscriptionModel) {
this.run = new TranscriptionRun(this.engine, model, this.logger)
createRun (model: TranscriptionModel, uuid = short.uuid()) {
this.run = new TranscriptionRun(this.engine, model, this.logger, uuid)
}

startRun () {
this.run.start()
}

Expand All @@ -55,6 +59,7 @@ export abstract class AbstractTranscriber {
mediaFilePath: string,
model: TranscriptionModel,
language: string,
format: TranscriptFormat
format: TranscriptFormat,
runId: UUID
): Promise<TranscriptFile>
}
29 changes: 16 additions & 13 deletions packages/transcription/src/transcription-run.ts
Original file line number Diff line number Diff line change
@@ -1,31 +1,34 @@
import short from 'short-uuid'
import short, { UUID } from 'short-uuid'
import { createLogger, Logger } from 'winston'
import { TranscriptionModel } from './transcription-model.js'
import { TranscriptionEngine } from './transcription-engine.js'

export class TranscriptionRun {
id: string
uuid: UUID
engine: TranscriptionEngine
model: TranscriptionModel
logger: Logger

static RUN_ID_MASK = /^([a-zA-Z0-9-]+)_([a-zA-Z0-9-]+)_([a-zA-Z0-9-/]+)/gm
static RUN_ID_MASK = /^([a-z0-9-]+)_([a-z0-9-]+)_([a-z0-9-/]+)/i

constructor (engine: TranscriptionEngine, model: TranscriptionModel, logger = createLogger()) {
constructor (engine: TranscriptionEngine, model: TranscriptionModel, logger = createLogger(), uuid?: UUID) {
this.uuid = uuid
this.engine = engine
this.model = model

this.id = TranscriptionRun.createId(engine, model)
this.logger = logger
}

static createId (engine: TranscriptionEngine, model: TranscriptionModel) {
return `${short.uuid()}_${engine.name}_${model.name}`
static createId (engine: TranscriptionEngine, model: TranscriptionModel, uuid = short.uuid()) {
return `${uuid}_${engine.name}_${model.name}`
}

static extractFromId (runId: string) {
const [ , id, engineName, modelName ] = TranscriptionRun.RUN_ID_MASK.exec(runId)
return { id, engineName, modelName }
const [ , uuid, engineName, modelName ] = TranscriptionRun.RUN_ID_MASK.exec(runId)
return { uuid, engineName, modelName }
}

/**
 * Full identifier of this run, rebuilt on every access from the run's uuid, engine and model
 * through `TranscriptionRun.createId` (format: `<uuid>_<engine name>_<model name>`).
 * It is used as the performance measure name and as the prefix of the start/end mark names.
 *
 * NOTE(review): the constructor accepts an optional `uuid`, and `createId` falls back to
 * `short.uuid()` when given `undefined`. In that case each access would return a different id,
 * so the "-started" and "-ended" mark names would no longer match. Confirm that every caller
 * provides a uuid.
 */
get runId () {
return TranscriptionRun.createId(this.engine, this.model, this.uuid)
}

start () {
Expand All @@ -36,7 +39,7 @@ export class TranscriptionRun {
try {
performance.mark(this.getEndPerformanceMarkName())
performance.measure(
this.id,
this.runId,
this.getStartPerformanceMarkName(),
this.getEndPerformanceMarkName()
)
Expand All @@ -46,10 +49,10 @@ export class TranscriptionRun {
}

getStartPerformanceMarkName () {
return `${this.id}-started`
return `${this.runId}-started`
}

getEndPerformanceMarkName () {
return `${this.id}-ended`
return `${this.runId}-ended`
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { $ } from 'execa'
import short, { UUID } from 'short-uuid'
import { join } from 'path'
import { lstat } from 'node:fs/promises'
import { OpenaiTranscriber } from './openai-transcriber.js'
Expand All @@ -7,13 +8,12 @@ import { TranscriptFile, TranscriptFormat } from '../../transcript/index.js'
import { getFileInfo } from '../../file-utils.js'

export class Ctranslate2Transcriber extends OpenaiTranscriber {
public static readonly MODEL_FILENAME = 'model.bin'

async transcribe (
mediaFilePath: string,
model: TranscriptionModel = { name: 'tiny' },
language: string = 'en',
format: TranscriptFormat = 'vtt'
format: TranscriptFormat = 'vtt',
runId: UUID = short.uuid()
): Promise<TranscriptFile> {
// Shall we run the command with `{ shell: true }` to get the same error as in sh ?
// ex: ENOENT => Command not found
Expand All @@ -25,7 +25,8 @@ export class Ctranslate2Transcriber extends OpenaiTranscriber {
}
const modelArgs = model.path ? [ '--model_directory', model.path ] : [ '--model', model.name ]

this.startRun(model)
this.createRun(model, runId)
this.startRun()
await $$`${this.engine.binary} ${[
mediaFilePath,
...modelArgs,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { join } from 'path'
import { $ } from 'execa'
import short, { UUID } from 'short-uuid'
import { TranscriptionModel } from '../../transcription-model.js'
import { TranscriptFile, TranscriptFormat } from '../../transcript/index.js'
import { AbstractTranscriber } from '../../abstract-transcriber.js'
Expand All @@ -10,14 +11,16 @@ export class OpenaiTranscriber extends AbstractTranscriber {
mediaFilePath: string,
model: TranscriptionModel = { name: 'tiny' },
language: string = 'en',
format: TranscriptFormat = 'vtt'
format: TranscriptFormat = 'vtt',
runId: UUID = short.uuid()
): Promise<TranscriptFile> {
// Shall we run the command with `{ shell: true }` to get the same error as in sh ?
// ex: ENOENT => Command not found
const $$ = $({ verbose: true })
const { baseName } = getFileInfo(mediaFilePath)

this.startRun(model)
this.createRun(model, runId)
this.startRun()
await $$`${this.engine.binary} ${[
mediaFilePath,
'--model',
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { $ } from 'execa'
import short, { UUID } from 'short-uuid'
import assert from 'node:assert'
import { join } from 'node:path'
import { existsSync } from 'node:fs'
Expand All @@ -13,12 +14,14 @@ export class WhisperTimestampedTranscriber extends OpenaiTranscriber {
mediaFilePath: string,
model: TranscriptionModel,
language: string,
format: TranscriptFormat = 'vtt'
format: TranscriptFormat = 'vtt',
runId: UUID = short.uuid()
): Promise<TranscriptFile> {
const $$ = $({ verbose: true })
const { baseName, name } = getFileInfo(mediaFilePath)

this.startRun(model)
this.createRun(model, runId)
this.startRun()
await $$`${this.engine.binary} ${[
mediaFilePath,
'--model',
Expand All @@ -32,7 +35,8 @@ export class WhisperTimestampedTranscriber extends OpenaiTranscriber {

const internalTranscriptPath = join(this.transcriptDirectory, `${name}.${format}`)
const transcriptPath = join(this.transcriptDirectory, `${baseName}.${format}`)
// Whisper timestamped is supposed to output file with the video file extension ex: video.mp4.vtt
// Whisper timestamped outputs files with the video file extension by default, ex: video.mp4.vtt
// @see https://github.com/linto-ai/whisper-timestamped/issues/189
assert(existsSync(internalTranscriptPath), `${internalTranscriptPath} file doesn't exist.`)
await rename(internalTranscriptPath, transcriptPath)

Expand Down

0 comments on commit c5eb336

Please sign in to comment.