diff --git a/cli/src/benchmark_gpt.ts b/cli/src/benchmark_gpt.ts
index 4a2b20c70..c0662fe77 100644
--- a/cli/src/benchmark_gpt.ts
+++ b/cli/src/benchmark_gpt.ts
@@ -3,6 +3,7 @@ import type { Task } from '@epfml/discojs'
 import { fetchTasks, data, models } from '@epfml/discojs'
 import { NodeTextLoader, loadModelFromDisk } from '@epfml/discojs-node'
 import { startServer } from 'server'
+import { get_return_value } from './utils.js';
 
 interface CLIArguments{
     modelType?: string; // 'gpt-nano', 'gpt-micro', 'gpt-mini', 'gpt2'
@@ -80,8 +81,9 @@ async function main(args: Required<CLIArguments>): Promise<void> {
     console.log(`\tmodel type ${modelType} \n\tbatch size ${batchSize} \n\tcontext length ${contextLength}`)
     let epochTime = performance.now()
-    const logGenerator = model.train(preprocessedDataset, undefined, epoch)
-    for await (const logs of logGenerator) {
+    const rounds = model.train(preprocessedDataset, undefined, epoch)
+    for (const round of rounds) {
+        const logs = await get_return_value(round)
         epochTime = (performance.now() - epochTime)
         const msPerToken = epochTime / (batchSize * contextLength * iterationsPerEpoch * epoch)
         console.log(`\t\tTraining time: ${msPerToken.toFixed(2)} ms/token
 ${logs.peakMemory.toFixed(2)} GB`)
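// Editor's sketch, not part of the patch: an alternative to the get_return_value helper
// used above that also reports per-batch progress while an epoch runs. BatchLogs and
// EpochLogs are the log types introduced further down (discojs/src/models/logs.ts);
// importing them from the package root is an assumption.
import type { BatchLogs, EpochLogs } from "@epfml/discojs";

async function drainEpoch(
  epoch: AsyncGenerator<BatchLogs, EpochLogs>,
): Promise<EpochLogs> {
  // Pulling on the generator is what actually runs the training batches.
  let result = await epoch.next();
  while (!result.done) {
    console.log(`\t\tbatch ${result.value.batch}: loss ${result.value.loss.toFixed(3)}`);
    result = await epoch.next();
  }
  return result.value; // the EpochLogs summary is the generator's return value
}

// Possible usage in main() above, instead of `const logs = await get_return_value(round)`:
//   for (const round of model.train(preprocessedDataset, undefined, epoch)) {
//     const logs = await drainEpoch(round)
//     ...
//   }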
diff --git a/cli/src/utils.ts b/cli/src/utils.ts
new file mode 100644
index 000000000..da1bb729e
--- /dev/null
+++ b/cli/src/utils.ts
@@ -0,0 +1,9 @@
+export async function get_return_value<T>(
+  iter: AsyncIterator<unknown, T>,
+): Promise<T> {
+  for (;;) {
+    const v = await iter.next();
+    if (!v.done) continue;
+    return v.value;
+  }
+}
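// Editor's sketch, not part of the patch: what get_return_value does, on a toy async
// generator. The yields are driven and discarded; only the generator's return value is
// surfaced. The relative import assumes the snippet sits next to cli/src/utils.ts.
import { get_return_value } from "./utils.js";

async function* countTo(n: number): AsyncGenerator<number, string> {
  for (let i = 1; i <= n; i++) yield i;
  return `counted to ${n}`;
}

// Resolves to "counted to 3" after pulling 1, 2 and 3 from the generator.
const summary: Promise<string> = get_return_value(countTo(3));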
diff --git a/discojs/src/default_tasks/wikitext.ts b/discojs/src/default_tasks/wikitext.ts
index 19b69a3af..e559a47dd 100644
--- a/discojs/src/default_tasks/wikitext.ts
+++ b/discojs/src/default_tasks/wikitext.ts
@@ -20,7 +20,7 @@ export const wikitext: TaskProvider = {
       modelID: 'wikitext-103-raw-model',
       preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
       scheme: 'federated',
-      epochs: 5,
+      epochs: 10,
       // Unused by wikitext because data already comes split
       // But if set to 0 then the webapp doesn't display the validation metrics
       validationSplit: 0.1,
diff --git a/discojs/src/models/gpt/index.ts b/discojs/src/models/gpt/index.ts
index b9753ff97..32a8bbe13 100644
--- a/discojs/src/models/gpt/index.ts
+++ b/discojs/src/models/gpt/index.ts
@@ -2,15 +2,17 @@
  * this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
  **/
 
+import { List, Map } from 'immutable';
 import * as tf from '@tensorflow/tfjs'
 import { PreTrainedTokenizer } from '@xenova/transformers';
 
 import { WeightsContainer } from '../../index.js'
 import type { Dataset } from '../../dataset/index.js'
-import { Model } from '../model.js'
+import { BatchLogs, EpochLogs, Model } from "../index.js";
+import type { Prediction, Sample } from '../model.js'
+
 import { GPTForCausalLM } from './model.js'
-import type { EpochLogs, Prediction, Sample } from '../model.js'
 import type { GPTConfig } from './config.js'
 
 export type GPTSerialization = {
@@ -35,54 +37,108 @@
    * @param epochs the number of passes of the training dataset
    * @param tracker
    */
-  override async *train(
+  override *train(
     trainingData: Dataset,
     validationData?: Dataset,
     epochs = 1,
-  ): AsyncGenerator<EpochLogs> {
-    this.model.compile()
+  ): Generator<AsyncGenerator<BatchLogs, EpochLogs>, void> {
+    this.model.compile();
+
+    for (let epoch = 0; epoch < epochs; epoch++)
+      yield async function* (this: GPT) {
+        let batchesLogs = List<BatchLogs>();
+
+        const batches = await trainingData.iterator(); // tf.LazyIterator isn't an AsyncGenerator
+        for (let batchNumber = 0; true; batchNumber++) {
+          const iteration = await batches.next();
+          if (iteration.done) break;
+          const batch = iteration.value;
+
+          const batchLogs = {
+            batch: batchNumber,
+            ...(await this.#runBatch(batch)),
+          };
+
+          yield batchLogs;
+          batchesLogs = batchesLogs.push(batchLogs);
+        }
+
+        const validation =
+          validationData && (await this.#evaluate(validationData));
+
+        return new EpochLogs(epoch, batchesLogs, validation);
+      }.bind(this)();
+  }
+
+  async #runBatch(
+    batch: tf.TensorContainer,
+  ): Promise<Omit<BatchLogs, "batch">> {
     let logs: tf.Logs | undefined;
-    const trainingArgs: tf.ModelFitDatasetArgs<tf.TensorContainer> = {
-      epochs: 1, // force fitDataset to do only one epoch because it is wrapped in a for loop
-      validationData,
-      callbacks: { onEpochEnd: (_, cur) => { logs = cur }},
+    await this.model.fitDataset(tf.data.array([batch]), {
+      epochs: 1,
+      callbacks: {
+        onEpochEnd: (_, cur) => {
+          logs = cur;
+        },
+      },
+    });
+    if (logs === undefined) throw new Error("batch didn't give any logs");
+
+    const { loss, acc: accuracy } = logs;
+    if (loss === undefined || isNaN(loss))
+      throw new Error("training loss is undefined or NaN");
+
+    return {
+      accuracy,
+      loss,
+      memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
     };
-    for (let epoch = 0; epoch < epochs; epoch++) {
-      await this.model.fitDataset(trainingData, trainingArgs);
-      if (logs === undefined) {
-        throw new Error("Epoch didn't gave any logs");
-      }
-      const { loss, val_acc, val_loss, peakMemory } = logs;
-      if (loss === undefined || isNaN(loss)) {
-        throw new Error("Training loss is undefined or nan");
-      }
-      const structuredLogs: EpochLogs = {
-        epoch,
-        peakMemory,
-        training: {
-          loss: logs.loss,
-          accuracy: logs.acc
-        }
-      }
+  }
 
-      if (validationData !== undefined) {
-        if(val_loss === undefined || isNaN(val_loss) ||
-          val_acc === undefined || isNaN(val_acc)) {
-          throw new Error("Validation accuracy or loss is undefined or nan");
+  async #evaluate(
+    dataset: Dataset,
+  ): Promise<Record<"accuracy" | "loss", number>> {
+    const evaluation = await this.model.evaluateDataset(
+      dataset.map((t) => {
+        switch (t) {
+          case null:
+          case undefined:
+            throw new Error("nullish value in dataset");
+          default:
+            return t as Exclude<tf.TensorContainer, void>;
         }
-        structuredLogs.validation = { accuracy: logs.val_acc, loss: logs.val_loss}
-      }
-      yield structuredLogs
-    }
+      }),
+    );
+    const metricToValue = Map(
+      List(this.model.metricsNames).zip(
+        Array.isArray(evaluation)
+          ? List(await Promise.all(evaluation.map((t) => t.data())))
+          : List.of(await evaluation.data()),
+      ),
+    ).map((values) => {
+      if (values.length !== 1) throw new Error("more than one metric value");
+      return values[0];
+    });
+
+    const [accuracy, loss] = [
+      metricToValue.get("acc"),
+      metricToValue.get("loss"),
+    ];
+    if (accuracy === undefined || loss === undefined)
+      throw new Error("some needed metrics are missing");
+
+    return { accuracy, loss };
   }
 
-  override predict (input: Sample): Promise<Prediction> {
-    const ret = this.model.predict(input)
+  override predict(input: Sample): Promise<Prediction> {
+    const ret = this.model.predict(input);
     if (Array.isArray(ret)) {
-      throw new Error('prediction yield many Tensors but should have only returned one')
+      throw new Error(
+        "prediction yield many Tensors but should have only returned one",
+      );
     }
-    return Promise.resolve(ret)
+    return Promise.resolve(ret);
   }
 
   async generate(input: string, tokenizer: PreTrainedTokenizer, newTokens: number = 10): Promise<string> {
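// Editor's sketch, not part of the patch: the List/Map zip pattern that #evaluate above
// relies on, shown with plain numbers instead of the tf.Scalar values returned by
// evaluateDataset. The ordering of metricsNames is what ties each value to its metric.
import { List, Map } from "immutable";

function namedMetrics(
  metricsNames: string[],
  values: number[],
): Record<"accuracy" | "loss", number> {
  const metricToValue = Map(List(metricsNames).zip(List(values)));

  const accuracy = metricToValue.get("acc");
  const loss = metricToValue.get("loss");
  if (accuracy === undefined || loss === undefined)
    throw new Error("some needed metrics are missing");

  return { accuracy, loss };
}

// namedMetrics(["loss", "acc"], [2.5, 0.125]) -> { accuracy: 0.125, loss: 2.5 }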
diff --git a/discojs/src/models/index.ts b/discojs/src/models/index.ts
index aefb9b5fa..6e384b352 100644
--- a/discojs/src/models/index.ts
+++ b/discojs/src/models/index.ts
@@ -1,4 +1,5 @@
-export { EpochLogs, Model } from './model.js'
+export { Model } from './model.js'
+export { BatchLogs, EpochLogs } from "./logs.js";
 
 export { GPT } from './gpt/index.js'
 export { GPTConfig } from './gpt/config.js'
diff --git a/discojs/src/models/logs.ts b/discojs/src/models/logs.ts
new file mode 100644
index 000000000..f993d9928
--- /dev/null
+++ b/discojs/src/models/logs.ts
@@ -0,0 +1,34 @@
+import { List } from "immutable";
+
+export interface BatchLogs {
+  batch: number; // first batch is zero
+  accuracy: number;
+  loss: number;
+  memoryUsage: number; // GB
+}
+
+export class EpochLogs {
+  public readonly batches: List<BatchLogs>;
+
+  constructor(
+    public readonly epoch: number, // first epoch is zero
+    batches: Iterable<BatchLogs>,
+    public readonly validation?: Record<"accuracy" | "loss", number>,
+  ) {
+    this.batches = List(batches);
+  }
+
+  get training(): Record<"accuracy" | "loss", number> {
+    return this.batches.reduce(
+      (acc, batch) => ({
+        accuracy: acc.accuracy + batch.accuracy,
+        loss: acc.loss + batch.loss,
+      }),
+      { loss: 0, accuracy: 0 },
+    );
+  }
+
+  get peakMemory(): number {
+    return this.batches.map((batch) => batch.memoryUsage).max() ?? 0;
+  }
+}
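// Editor's sketch, not part of the patch: what the new EpochLogs accessors report for a
// hand-written list of batch logs. Importing from the package root is an assumption.
// Note that `training` accumulates the raw per-batch values, so a per-batch average
// still needs a division by `batches.size`.
import { EpochLogs, type BatchLogs } from "@epfml/discojs";

const batches: BatchLogs[] = [
  { batch: 0, accuracy: 0.25, loss: 2.5, memoryUsage: 1.25 },
  { batch: 1, accuracy: 0.75, loss: 1.5, memoryUsage: 1.5 },
];
const epochLogs = new EpochLogs(0, batches, { accuracy: 0.5, loss: 1.75 });

console.log(epochLogs.training); // { accuracy: 1, loss: 4 }
console.log(epochLogs.training.loss / epochLogs.batches.size); // 2, the mean batch loss
console.log(epochLogs.peakMemory); // 1.5 (GB), the largest per-batch memoryUsage
console.log(epochLogs.validation?.accuracy); // 0.5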
diff --git a/discojs/src/models/model.ts b/discojs/src/models/model.ts
index 60b61764c..61b8f2f1c 100644
--- a/discojs/src/models/model.ts
+++ b/discojs/src/models/model.ts
@@ -3,18 +3,7 @@ import type tf from "@tensorflow/tfjs";
 import type { WeightsContainer } from "../index.js";
 import type { Dataset } from "../dataset/index.js";
 
-export interface EpochLogs {
-  epoch: number; // first epoch is zero
-  training: {
-    loss: number,
-    accuracy?: number
-  };
-  validation?: {
-    loss: number,
-    accuracy: number
-  };
-  peakMemory: number;
-}
+import type { BatchLogs, EpochLogs } from "./logs.js";
 
 // TODO still bound to tfjs
 export type Prediction = tf.Tensor;
@@ -26,7 +15,7 @@ export type Sample = tf.Tensor;
  * Allow for various implementation of models (various train function, tensor-library, ...)
  **/
 // TODO make it typesafe: same shape of data/input/weights
-export abstract class Model implements Disposable{
+export abstract class Model implements Disposable {
   // TODO don't allow external access but upgrade train to return weights on every epoch
   /** Return training state */
   abstract get weights(): WeightsContainer;
@@ -46,13 +35,12 @@ export abstract class Model implements Disposable{
     trainingData: Dataset,
     validationData?: Dataset,
     epochs?: number,
-  ): AsyncGenerator<EpochLogs>;
+  ): Generator<AsyncGenerator<BatchLogs, EpochLogs>, void>;
 
   /** Predict likely values */
   // TODO extract in separated TrainedModel?
   abstract predict(input: Sample): Promise<Prediction>;
 
-
   /**
    * This method is automatically called to cleanup the memory occupied by the model
    * when leaving the definition scope if the instance has been defined with the `using` keyword.
diff --git a/discojs/src/models/tfjs.ts b/discojs/src/models/tfjs.ts
index 8646c34cd..06289daf3 100644
--- a/discojs/src/models/tfjs.ts
+++ b/discojs/src/models/tfjs.ts
@@ -1,9 +1,11 @@
+import { List, Map } from 'immutable'
 import * as tf from '@tensorflow/tfjs'
 
 import { WeightsContainer } from '../index.js'
 
-import { Model } from './index.js'
-import type { EpochLogs, Prediction, Sample } from './model.js'
+import type { BatchLogs } from './index.js'
+import { EpochLogs, Model } from './index.js'
+import type { Prediction, Sample } from './model.js'
 import type { Dataset } from '../dataset/index.js'
 
 /** TensorFlow JavaScript model with standard training */
@@ -27,55 +29,96 @@ export class TFJS extends Model {
     this.model.setWeights(ws.weights)
   }
 
-  override async *train(
+  override *train(
     trainingData: Dataset,
     validationData?: Dataset,
     epochs = 1,
-  ): AsyncGenerator<EpochLogs> {
+  ): Generator<AsyncGenerator<BatchLogs, EpochLogs>, void> {
     for (let epoch = 0; epoch < epochs; epoch++) {
-      let logs: tf.Logs | undefined;
-      let peakMemory = 0
-      await this.model.fitDataset(trainingData, {
-        epochs: 1,
-        validationData,
-        callbacks: {
-          onBatchEnd: (_) => {
-            const currentMemory = tf.memory().numBytes / 1024 / 1024 / 1024 // GB
-            if (currentMemory > peakMemory) {
-              peakMemory = currentMemory
-            }
-          },
-          onEpochEnd: (_, cur) => { logs = cur }
+      yield async function* (this: TFJS) {
+        let batchesLogs = List<BatchLogs>();
+
+        const batches = await trainingData.iterator(); // tf.LazyIterator isn't an AsyncGenerator
+        for (let batchNumber = 0; true; batchNumber++) {
+          const iteration = await batches.next();
+          if (iteration.done) break;
+          const batch = iteration.value;
+
+          const batchLogs = {
+            batch: batchNumber,
+            ...(await this.#runBatch(batch)),
+          };
+
+          yield batchLogs;
+          batchesLogs = batchesLogs.push(batchLogs);
+        }
+
+        const validation =
+          validationData && (await this.#evaluate(validationData));
+
+        return new EpochLogs(epoch, batchesLogs, validation);
+      }.bind(this)()
+    }
+  }
+
+  async #runBatch(
+    batch: tf.TensorContainer,
+  ): Promise<Omit<BatchLogs, "batch">> {
+    let logs: tf.Logs | undefined;
+    await this.model.fitDataset(tf.data.array([batch]), {
+      epochs: 1,
+      callbacks: {
+        onEpochEnd: (_, cur) => {
+          logs = cur;
         },
-      });
+      },
+    });
+    if (logs === undefined) throw new Error("batch didn't give any logs");
+
+    const { loss, acc: accuracy } = logs;
+    if (loss === undefined || isNaN(loss))
+      throw new Error("training loss is undefined or NaN");
+
+    return {
+      accuracy,
+      loss,
+      memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
+    };
+  }
 
-      if (logs === undefined) {
-        throw new Error("Epoch didn't gave any logs");
-      }
-      const { loss, acc, val_acc, val_loss } = logs;
-      if (loss === undefined || isNaN(loss) || acc === undefined || isNaN(acc)) {
-        throw new Error("Training loss is undefined or nan");
-      }
-      const structuredLogs: EpochLogs = {
-        epoch,
-        peakMemory,
-        training: {
-          loss: logs.loss,
-          accuracy: logs.acc,
-        }
-      }
-      if (validationData !== undefined) {
-        if(val_loss === undefined || isNaN(val_loss) ||
-          val_acc === undefined || isNaN(val_acc)) {
-          throw new Error("Invalid validation logs");
+  async #evaluate(
+    dataset: Dataset,
+  ): Promise<Record<"accuracy" | "loss", number>> {
+    const evaluation = await this.model.evaluateDataset(
+      dataset.map((t) => {
+        switch (t) {
+          case null:
+          case undefined:
+            throw new Error("nullish value in dataset");
+          default:
+            return t as Exclude<tf.TensorContainer, void>;
         }
-        structuredLogs.validation = {
-          accuracy: logs.val_acc,
-          loss: logs.val_loss
-        }
-      }
-      yield structuredLogs
-    }
+      }),
+    );
+    const metricToValue = Map(
+      List(this.model.metricsNames).zip(
+        Array.isArray(evaluation)
+          ? List(await Promise.all(evaluation.map((t) => t.data())))
+          : List.of(await evaluation.data()),
+      ),
+    ).map((values) => {
+      if (values.length !== 1) throw new Error("more than one metric value");
+      return values[0];
+    });
+
+    const [accuracy, loss] = [
+      metricToValue.get("acc"),
+      metricToValue.get("loss"),
+    ];
+    if (accuracy === undefined || loss === undefined)
+      throw new Error("some needed metrics are missing");
+
+    return { accuracy, loss };
   }
 
   override predict (input: Sample): Promise<Prediction> {
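// Editor's sketch, not part of the patch: the single-batch fitDataset trick that the new
// #runBatch uses, in isolation. Wrapping one batch in tf.data.array and fitting it for a
// single epoch performs exactly one training step, and onEpochEnd then reports the
// metrics of that batch. Assumes `model` has already been compiled.
import * as tf from "@tensorflow/tfjs";

async function singleStep(
  model: tf.LayersModel,
  batch: { xs: tf.Tensor; ys: tf.Tensor },
): Promise<tf.Logs | undefined> {
  let logs: tf.Logs | undefined;
  await model.fitDataset(tf.data.array([batch]), {
    epochs: 1,
    callbacks: {
      onEpochEnd: (_, cur) => {
        logs = cur;
      },
    },
  });
  return logs;
}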
diff --git a/discojs/src/training/disco.ts b/discojs/src/training/disco.ts
index d57baa4d3..7c8b9d759 100644
--- a/discojs/src/training/disco.ts
+++ b/discojs/src/training/disco.ts
@@ -2,6 +2,7 @@ import type { data, Logger, Memory, Task, TrainingInformation } from '../index.js'
 import { client as clients, EmptyMemory, ConsoleLogger } from '../index.js'
 import type { Aggregator } from '../aggregator/index.js'
 import { MeanAggregator } from '../aggregator/mean.js'
+import * as async_iterator from '../utils/async_iterator.js'
 
 import type { RoundLogs, Trainer } from './trainer/trainer.js'
 import { TrainerBuilder } from './trainer/trainer_builder.js'
@@ -94,7 +95,9 @@ export class Disco {
     await this.client.connect();
     const trainer = await this.trainer;
 
-    for await (const roundLogs of trainer.fitModel(trainData.dataset, validationData.dataset)) {
+    for await (const round of trainer.fitModel(trainData.dataset, validationData.dataset)) {
+      const roundLogs = await async_iterator.get_return_value(round)
+
       let msg = `Round: ${roundLogs.round}\n`
       for (const epochLogs of roundLogs.epochs.values()) {
         msg += `  Epoch: ${epochLogs.epoch}\n`
diff --git a/discojs/src/training/trainer/trainer.ts b/discojs/src/training/trainer/trainer.ts
index ed1f5bd5b..824b872f7 100644
--- a/discojs/src/training/trainer/trainer.ts
+++ b/discojs/src/training/trainer/trainer.ts
@@ -3,7 +3,8 @@
 import { List } from "immutable";
 
 import type { Model, Task } from "../../index.js";
-import { EpochLogs } from "../../models/model.js";
+import { BatchLogs, EpochLogs } from "../../models/index.js";
+import * as async_iterator from "../../utils/async_iterator.js";
 
 export interface RoundLogs {
   round: number;
@@ -22,7 +23,10 @@ export abstract class Trainer {
   readonly #roundDuration: number;
   readonly #epochs: number;
 
-  private training?: AsyncGenerator<EpochLogs>;
+  private training?: AsyncGenerator<
+    AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>,
+    void
+  >;
 
   constructor(
     task: Task,
   ) {
     this.#roundDuration = task.trainingInformation.roundDuration;
     this.#epochs = task.trainingInformation.epochs;
+
+    if (!Number.isInteger(this.#epochs / this.#roundDuration))
+      throw new Error(`round duration doesn't divide epochs: ${this.#epochs}/${this.#roundDuration} == ${this.#epochs / this.#roundDuration}`)
   }
 
   protected abstract onRoundBegin(round: number): Promise<void>;
@@ -49,34 +56,54 @@
   async *fitModel(
     dataset: tf.data.Dataset<tf.TensorContainer>,
     valDataset: tf.data.Dataset<tf.TensorContainer>,
-  ): AsyncGenerator<RoundLogs> {
-    if (this.training !== undefined) {
+  ): AsyncGenerator<
+    AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>,
+    void
+  > {
+    if (this.training !== undefined)
       throw new Error(
         "training already running, cancel it before launching a new one",
       );
+
+    try {
+      this.training = this.#runRounds(dataset, valDataset);
+      yield *this.training
+    } finally {
+      this.training = undefined;
     }
-    await this.onRoundBegin(0);
+  }
+
+  async *#runRounds(
+    dataset: tf.data.Dataset<tf.TensorContainer>,
+    valDataset: tf.data.Dataset<tf.TensorContainer>,
+  ): AsyncGenerator<
+    AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>,
+    void
+  > {
+    const totalRound = Math.trunc(this.#epochs / this.#roundDuration);
+    for (let round = 0; round < totalRound; round++) {
+      await this.onRoundBegin(round);
-    this.training = this.model.train(dataset, valDataset, this.#epochs);
+      const training = this.model.train(dataset, valDataset, this.#epochs);
-    for await (const logs of this.training) {
-      // for now, round (sharing on network) == epoch (full pass over local data)
-      yield {
-        round: logs.epoch,
-        epochs: List.of(logs),
-      };
+      yield (async function* () {
+        let epochs = List<EpochLogs>();
-      if (logs.epoch % this.#roundDuration === 0) {
-        const round = Math.trunc(logs.epoch / this.#roundDuration);
-        await this.onRoundEnd(round);
-        await this.onRoundBegin(round);
-      }
-    }
+        for await (const batch of training) {
+          const [gen, ret] = async_iterator.split(batch);
+
+          yield (async function* () {
+            yield* gen;
+            return await ret;
+          })();
-    const round = Math.trunc(this.#epochs / this.#roundDuration);
-    await this.onRoundEnd(round);
+          epochs = epochs.push(await ret);
+        }
-    this.training = undefined;
+        return { round, epochs };
+      })();
+
+      await this.onRoundEnd(round);
+    }
   }
 }
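// Editor's sketch, not part of the patch: the round arithmetic the Trainer now enforces.
// With the wikitext task above set to epochs: 10 and, say, a roundDuration of 2 (the
// actual value is task-specific), training is cut into 5 rounds of 2 epochs each; a
// roundDuration that does not divide epochs is rejected in the constructor.
function countRounds(epochs: number, roundDuration: number): number {
  if (!Number.isInteger(epochs / roundDuration))
    throw new Error(`round duration ${roundDuration} doesn't divide epochs ${epochs}`);
  return Math.trunc(epochs / roundDuration);
}

// countRounds(10, 2) === 5
// countRounds(10, 3) throws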
diff --git a/discojs/src/utils/async_iterator.ts b/discojs/src/utils/async_iterator.ts
new file mode 100644
index 000000000..95dd0c7f2
--- /dev/null
+++ b/discojs/src/utils/async_iterator.ts
@@ -0,0 +1,55 @@
+export async function get_return_value<T>(
+  iter: AsyncIterator<unknown, T>,
+): Promise<T> {
+  for (;;) {
+    const v = await iter.next();
+    if (!v.done) continue;
+    return v.value;
+  }
+}
+
+// `Promise.withResolvers` not widely deployed
+function PromiseWithResolvers<T>(): [
+  Promise<T>,
+  (_: T) => void,
+  (_: unknown) => void,
+] {
+  let resolve: (_: T) => void, reject: (_: unknown) => void;
+  resolve = reject = () => {
+    // should not happen as Promise are run on creation
+    throw new Error("race condition triggered");
+  };
+
+  const promise = new Promise<T>((res, rej) => {
+    resolve = res;
+    reject = rej;
+  });
+
+  return [promise, resolve, reject];
+}
+
+/**
+ * Split yields from return value
+ *
+ * You need to consume the iterator to resolve the returned value
+ **/
+export function split<T, U>(
+  iter: AsyncIterator<T, U>,
+): [AsyncGenerator<T, U>, Promise<U>] {
+  const [returnPromise, returnResolve, _] = PromiseWithResolvers<U>();
+
+  return [
+    (async function* () {
+      while (true) {
+        const v = await iter.next();
+        if (!v.done) {
+          yield v.value;
+          continue;
+        }
+
+        returnResolve(v.value);
+        return v.value;
+      }
+    })(),
+    returnPromise,
+  ];
+}
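// Editor's sketch, not part of the patch: how split() separates a generator's yields from
// its return value, on a toy generator. The returned promise only settles once the
// yielded side has been consumed to completion, which is why the Trainer above only
// awaits `ret` after yielding the wrapped generator to its consumer. The relative import
// assumes the snippet sits next to discojs/src/utils/async_iterator.ts.
import { split } from "./async_iterator.js";

async function* work(): AsyncGenerator<number, string> {
  yield 1;
  yield 2;
  return "all done";
}

async function demo(): Promise<void> {
  const [gen, ret] = split(work());

  for await (const n of gen) console.log(n); // 1, then 2
  console.log(await ret); // "all done"
}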