diff --git a/.github/workflows/lint-test-build.yml b/.github/workflows/lint-test-build.yml
index 479df3a6a..30371dc10 100644
--- a/.github/workflows/lint-test-build.yml
+++ b/.github/workflows/lint-test-build.yml
@@ -335,7 +335,7 @@ jobs:
           cache: npm
       - run: npm ci
       - run: npm --workspace={discojs,discojs-node,server} run build
-      - run: npm --workspace=cli start -- -t cifar10 -u 3 -e 1
+      - run: npm --workspace=cli start -- -t cifar10 -u 3 -e 1 -r 1

   test-docs-examples:
     needs: [build-lib, build-lib-node, build-server, download-datasets]

diff --git a/cli/src/benchmark_gpt.ts b/cli/src/benchmark_gpt.ts
index 4a2b20c70..f748ddbf1 100644
--- a/cli/src/benchmark_gpt.ts
+++ b/cli/src/benchmark_gpt.ts
@@ -1,6 +1,6 @@
 import { parse } from 'ts-command-line-args';
 import type { Task } from '@epfml/discojs'
-import { fetchTasks, data, models } from '@epfml/discojs'
+import { fetchTasks, data, models, async_iterator } from '@epfml/discojs'
 import { NodeTextLoader, loadModelFromDisk } from '@epfml/discojs-node'
 import { startServer } from 'server'
@@ -57,7 +57,7 @@ async function main(args: Required<CLIArguments>): Promise<void> {
    */
   if (!benchmarkInference) {
     // Benchmark parameters
-    const epoch = 1
+    const epochsCount = 1
     const iterationsPerEpoch = 10

     const config: models.GPTConfig = {
@@ -80,10 +80,10 @@ async function main(args: Required<CLIArguments>): Promise<void> {
     console.log(`\tmodel type ${modelType} \n\tbatch size ${batchSize} \n\tcontext length ${contextLength}`)

     let epochTime = performance.now()
-    const logGenerator = model.train(preprocessedDataset, undefined, epoch)
-    for await (const logs of logGenerator) {
+    for (let epochsCounter = 0; epochsCounter < epochsCount; epochsCounter++) {
+      const [_, logs] = await async_iterator.gather(model.train(preprocessedDataset))
       epochTime = (performance.now() - epochTime)
-      const msPerToken = epochTime / (batchSize * contextLength * iterationsPerEpoch * epoch)
+      const msPerToken = epochTime / (batchSize * contextLength * iterationsPerEpoch * (epochsCounter + 1))
       console.log(`\t\tTraining time: ${msPerToken.toFixed(2)} ms/token
         ${logs.peakMemory.toFixed(2)} GB`)
     }

diff --git a/cli/src/cli.ts b/cli/src/cli.ts
index e49f2f5a4..128605fc3 100644
--- a/cli/src/cli.ts
+++ b/cli/src/cli.ts
@@ -2,7 +2,7 @@ import { List, Range } from 'immutable'
 import fs from 'node:fs/promises'

 import type { data, RoundLogs, Task } from '@epfml/discojs'
-import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs'
+import { Disco, aggregator as aggregators, async_iterator, client as clients } from '@epfml/discojs'
 import { startServer } from 'server'

 import { getTaskData } from './data.js'
@@ -23,7 +23,12 @@ async function runUser(
   const disco = new Disco(task, { scheme: "federated", client });

   let logs = List<RoundLogs>();
-  for await (const round of disco.fit(data)) logs = logs.push(round);
+  for await (const round of disco.fit(data)) {
+    const [roundGen, roundLogs] = async_iterator.split(round)
+    for await (const epoch of roundGen)
+      for await (const _ of epoch);
+    logs = logs.push(await roundLogs);
+  }

   await disco.close();
   return logs;

diff --git a/discojs/src/default_tasks/wikitext.ts b/discojs/src/default_tasks/wikitext.ts
index 85291f473..244a6baba 100644
--- a/discojs/src/default_tasks/wikitext.ts
+++ b/discojs/src/default_tasks/wikitext.ts
@@ -22,7 +22,7 @@ export const wikitext: TaskProvider = {
         modelID: 'llm-raw-model',
         preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
         scheme: 'federated',
-        epochs: 5,
+        epochs: 6,
         // Unused by wikitext because data already comes split
         // But if set to 0 then the webapp doesn't display the validation metrics
         validationSplit: 0.1,

diff --git a/discojs/src/index.ts b/discojs/src/index.ts
index 278471c56..9d44c5798 100644
--- a/discojs/src/index.ts
+++ b/discojs/src/index.ts
@@ -13,10 +13,11 @@ export { Memory, type ModelInfo, type Path, type ModelSource, Empty as EmptyMemo
 export { Disco, RoundLogs } from './training/index.js'
 export { Validator } from './validation/index.js'

-export { Model, EpochLogs } from './models/index.js'
+export { Model, BatchLogs, EpochLogs, ValidationMetrics } from './models/index.js'
 export * as models from './models/index.js'

 export * from './task/index.js'
 export * as defaultTasks from './default_tasks/index.js'
 export * from './types.js'
+export * as async_iterator from "./utils/async_iterator.js"

diff --git a/discojs/src/models/gpt/evaluate.ts b/discojs/src/models/gpt/evaluate.ts
index 1fedd420a..0165c2018 100644
--- a/discojs/src/models/gpt/evaluate.ts
+++ b/discojs/src/models/gpt/evaluate.ts
@@ -9,7 +9,7 @@ export default async function evaluate (
   model: tf.LayersModel,
   dataset: tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>,
   maxEvalBatches: number
-): Promise<Record<'val_loss' | 'val_perplexity' | 'acc' | 'val_acc', number>> {
+): Promise<Record<'val_loss' | 'val_perplexity' | 'val_acc', number>> {
   let datasetSize = 0
   let totalLoss = 0
   const acc: [number, number] = [0, 0]
@@ -53,7 +53,6 @@ export default async function evaluate (
   return {
     val_loss: loss,
     val_perplexity: Math.exp(loss),
-    acc: acc[0] / acc[1],
     val_acc: acc[0] / acc[1]
   }
 }
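The benchmark above illustrates the migration recipe for the new `Model.train` signature: the `epochs` parameter is gone, so callers drive the epoch loop themselves and drain one generator per epoch. A minimal sketch of that recipe, using only APIs introduced by this diff (the helper name and logging are ours, not part of the PR):

```ts
import { async_iterator, models } from "@epfml/discojs";

// Drive `epochsCount` epochs by hand: train() now yields per-batch logs and
// returns an EpochLogs summary once the epoch's batches are exhausted.
async function trainEpochs(
  model: models.GPT,
  dataset: Parameters<models.GPT["train"]>[0], // train()'s own Dataset type
  epochsCount: number,
): Promise<void> {
  for (let epoch = 0; epoch < epochsCount; epoch++) {
    // gather() drains the generator: yielded batch logs on the left,
    // the returned EpochLogs on the right.
    const [batchesLogs, epochLogs] = await async_iterator.gather(
      model.train(dataset),
    );
    console.log(
      `epoch ${epoch}: ${batchesLogs.size} batches, mean loss ${epochLogs.training.loss}`,
    );
  }
}
```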
diff --git a/discojs/src/models/gpt/gpt.spec.ts b/discojs/src/models/gpt/gpt.spec.ts
index 669ceb9a2..c4b55649d 100644
--- a/discojs/src/models/gpt/gpt.spec.ts
+++ b/discojs/src/models/gpt/gpt.spec.ts
@@ -36,8 +36,8 @@ describe('gpt-tfjs', function() {
     }).repeat().batch(64)

     const model = new GPT(config)
-    const logGenerator = model.train(tokenDataset, undefined, 5) // 5 epochs
-    for await (const _ of logGenerator); // Await the end of training
+    for (let i = 0; i < 5; i++)
+      for await (const _ of model.train(tokenDataset, undefined));
     const generation = await model.generate("Lorem ipsum dolor", tokenizer, 1)
     expect(generation).equal(data) // Assert that the model completes 'Lorem ipsum dolor' with 'sit'
   })

diff --git a/discojs/src/models/gpt/index.ts b/discojs/src/models/gpt/index.ts
index b9753ff97..8f7f7b2b0 100644
--- a/discojs/src/models/gpt/index.ts
+++ b/discojs/src/models/gpt/index.ts
@@ -8,10 +8,13 @@ import { PreTrainedTokenizer } from '@xenova/transformers';

 import { WeightsContainer } from '../../index.js'
 import type { Dataset } from '../../dataset/index.js'
-import { Model } from '../model.js'
+import { BatchLogs, Model, EpochLogs } from "../index.js";
+import type { Prediction, Sample } from '../model.js'
+
 import { GPTForCausalLM } from './model.js'
-import type { EpochLogs, Prediction, Sample } from '../model.js'
-import type { GPTConfig } from './config.js'
+import { DEFAULT_CONFIG, type GPTConfig } from './config.js'
+import evaluate from './evaluate.js';
+import { List } from 'immutable';

 export type GPTSerialization = {
   weights: WeightsContainer
@@ -21,9 +24,13 @@ export type GPTSerialization = {
 export class GPT extends Model {
   private readonly model: GPTForCausalLM

+  readonly #maxBatchCount: number
+
   constructor (partialConfig?: GPTConfig, layersModel?: tf.LayersModel) {
     super()
+
     this.model = new GPTForCausalLM(partialConfig, layersModel)
+    this.#maxBatchCount = partialConfig?.maxIter ?? DEFAULT_CONFIG.maxIter
   }

   /**
@@ -38,51 +45,90 @@ export class GPT extends Model {
   override async *train(
     trainingData: Dataset,
     validationData?: Dataset,
-    epochs = 1,
-  ): AsyncGenerator<EpochLogs> {
-    this.model.compile()
+  ): AsyncGenerator<BatchLogs, EpochLogs> {
+    this.model.compile();
+
+    const batches = await trainingData.iterator(); // tf.LazyIterator isn't an AsyncGenerator
+    let batchesLogs = List<BatchLogs>();
+    for (
+      let batchNumber = 0;
+      batchNumber < this.#maxBatchCount;
+      batchNumber++
+    ) {
+      const iteration = await batches.next();
+      if (iteration.done) break;
+      const batch = iteration.value;
+
+      const batchLogs = await this.#runBatch(batch);
+      tf.dispose(batch);
+
+      yield batchLogs;
+      batchesLogs = batchesLogs.push(batchLogs);
+    }
+
+    const validation = validationData && (await this.#evaluate(validationData));
+    return new EpochLogs(batchesLogs, validation);
+  }
+
+  async #runBatch(
+    batch: tf.TensorContainer,
+  ): Promise<BatchLogs> {
     let logs: tf.Logs | undefined;
-    const trainingArgs: tf.ModelFitDatasetArgs<tf.TensorContainer> = {
-      epochs: 1, // force fitDataset to do only one epoch because it is wrapped in a for loop
-      validationData,
-      callbacks: { onEpochEnd: (_, cur) => { logs = cur }},
+    await this.model.fitDataset(tf.data.array([batch]), {
+      epochs: 1,
+      verbose: 0, // don't pollute
+      callbacks: {
+        onEpochEnd: (_, cur) => {
+          logs = cur;
+        },
+      },
+    });
+    if (logs === undefined) throw new Error("batch didn't give any logs");
+
+    const { loss, acc: accuracy } = logs;
+    if (loss === undefined || isNaN(loss))
+      throw new Error("training loss is undefined or NaN");
+
+    return {
+      accuracy,
+      loss,
+      memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
     };
-    for (let epoch = 0; epoch < epochs; epoch++) {
-      await this.model.fitDataset(trainingData, trainingArgs);
-      if (logs === undefined) {
-        throw new Error("Epoch didn't gave any logs");
-      }
-      const { loss, val_acc, val_loss, peakMemory } = logs;
-      if (loss === undefined || isNaN(loss)) {
-        throw new Error("Training loss is undefined or nan");
-      }
-      const structuredLogs: EpochLogs = {
-        epoch,
-        peakMemory,
-        training: {
-          loss: logs.loss,
-          accuracy: logs.acc
-        }
-      }
+  }

-      if (validationData !== undefined) {
-        if(val_loss === undefined || isNaN(val_loss) ||
-          val_acc === undefined || isNaN(val_acc)) {
-          throw new Error("Validation accuracy or loss is undefined or nan");
+  async #evaluate(
+    dataset: Dataset,
+  ): Promise<Record<"accuracy" | "loss", number>> {
+    const evaluation = await evaluate(
+      this.model,
+      dataset.map((t) => {
+        switch (t) {
+          case null:
+          case undefined:
+            throw new Error("nullish value in dataset");
+          default:
+            // TODO unsafe cast
+            return t as { xs: tf.Tensor2D; ys: tf.Tensor3D };
         }
-        structuredLogs.validation = { accuracy: logs.val_acc, loss: logs.val_loss}
-      }
-      yield structuredLogs
-    }
+      }),
+      this.config.maxEvalBatches,
+    );
+
+    return {
+      accuracy: evaluation.val_acc,
+      loss: evaluation.val_loss,
+    };
   }

-  override predict (input: Sample): Promise<Prediction> {
-    const ret = this.model.predict(input)
+  override predict(input: Sample): Promise<Prediction> {
+    const ret = this.model.predict(input);
     if (Array.isArray(ret)) {
-      throw new Error('prediction yield many Tensors but should have only returned one')
+      throw new Error(
+        "prediction yielded many Tensors but should have only returned one",
+      );
     }
-    return Promise.resolve(ret)
+    return Promise.resolve(ret);
   }

   async generate(input: string, tokenizer: PreTrainedTokenizer, newTokens: number = 10): Promise<string> {

diff --git a/discojs/src/models/gpt/model.ts b/discojs/src/models/gpt/model.ts
index 959468b57..113c7f798 100644
--- a/discojs/src/models/gpt/model.ts
+++ b/discojs/src/models/gpt/model.ts
@@ -57,6 +57,7 @@ class GPTModel extends tf.LayersModel {
     await callbacks.onTrainBegin?.()

     for (let epoch = 1; epoch <= trainingArgs.epochs; epoch++) {
+      let accuracyFraction: [number, number] = [0, 0];
       let averageLoss = 0
       let peakMemory = 0

       let iteration = 1
@@ -69,18 +70,34 @@ class GPTModel extends tf.LayersModel {
         let weightUpdateTime = performance.now()
         await callbacks.onEpochBegin?.(epoch)
         const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D }
-        const lossFn: () => tf.Scalar = () => {
+
+        // TODO include as a tensor inside the model
+        const accTensor = tf.tidy(() => {
           const logits = this.apply(xs)
-          if (Array.isArray(logits)) {
+          if (Array.isArray(logits))
             throw new Error('model outputs too many tensor')
-          }
-          if (logits instanceof tf.SymbolicTensor) {
+          if (logits instanceof tf.SymbolicTensor)
             throw new Error('model outputs symbolic tensor')
-          }
-          return tf.losses.softmaxCrossEntropy(ys, logits)
-        }
+          return tf.metrics.categoricalAccuracy(ys, logits)
+        })
+        const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
+        const accSumTensor = accTensor.sum()
+        const accSum = await accSumTensor.array()
+        tf.dispose(accSumTensor)
+        if (typeof accSum !== 'number')
+          throw new Error('got multiple accuracy sums')
+        accuracyFraction = [accuracyFraction[0] + accSum, accuracyFraction[1] + accSize];
+        tf.dispose([accTensor])
+
         const lossTensor = tf.tidy(() => {
-          const { grads, value: lossTensor } = this.optimizer.computeGradients(lossFn)
+          const { grads, value: lossTensor } = this.optimizer.computeGradients(() => {
+            const logits = this.apply(xs)
+            if (Array.isArray(logits))
+              throw new Error('model outputs too many tensor')
+            if (logits instanceof tf.SymbolicTensor)
+              throw new Error('model outputs symbolic tensor')
+            return tf.losses.softmaxCrossEntropy(ys, logits)
+          })
           const gradsClipped = clipByGlobalNormObj(grads, 1)
           this.optimizer.applyGradients(gradsClipped)
           return lossTensor
@@ -89,6 +106,7 @@ class GPTModel extends tf.LayersModel {
         const loss = await lossTensor.array()
         averageLoss += loss
         weightUpdateTime = performance.now() - weightUpdateTime
+
         tf.dispose([xs, ys, lossTensor])

         if (
@@ -122,6 +140,7 @@
       }

       let logs: tf.Logs = {
         'loss': averageLoss / iteration,
+        'acc': accuracyFraction[0] / accuracyFraction[1],
         'peakMemory': peakMemory
       }
       if (evalDataset !== undefined) {

diff --git a/discojs/src/models/index.ts b/discojs/src/models/index.ts
index aefb9b5fa..267e41bec 100644
--- a/discojs/src/models/index.ts
+++ b/discojs/src/models/index.ts
@@ -1,4 +1,5 @@
-export { EpochLogs, Model } from './model.js'
+export { Model } from './model.js'
+export { BatchLogs, EpochLogs, ValidationMetrics } from "./logs.js";

 export { GPT } from './gpt/index.js'
 export { GPTConfig } from './gpt/config.js'

diff --git a/discojs/src/models/logs.ts b/discojs/src/models/logs.ts
new file mode 100644
index 000000000..dc7d92497
--- /dev/null
+++ b/discojs/src/models/logs.ts
@@ -0,0 +1,42 @@
+import { List } from "immutable";
+
+export interface ValidationMetrics {
+  accuracy: number;
+  loss: number;
+}
+
+export interface BatchLogs {
+  accuracy: number;
+  loss: number;
+  memoryUsage: number; // GB
+}
+
+export class EpochLogs {
+  public readonly batches: List<BatchLogs>;
+
+  constructor(
+    batches: Iterable<BatchLogs>,
+    public readonly validation?: ValidationMetrics,
+  ) {
+    this.batches = List(batches);
+  }
+
+  get training(): Record<"accuracy" | "loss", number> {
+    const sum = this.batches.reduce(
+      (acc, batch) => ({
+        accuracy: acc.accuracy + batch.accuracy,
+        loss: acc.loss + batch.loss,
+      }),
+      { loss: 0, accuracy: 0 },
+    );
+
+    return {
+      accuracy: sum.accuracy / this.batches.size,
+      loss: sum.loss / this.batches.size,
+    };
+  }
+
+  get peakMemory(): number {
+    return this.batches.map((batch) => batch.memoryUsage).max() ?? 0;
+  }
+}

diff --git a/discojs/src/models/model.ts b/discojs/src/models/model.ts
index 60b61764c..bf2f2ce8a 100644
--- a/discojs/src/models/model.ts
+++ b/discojs/src/models/model.ts
@@ -3,18 +3,7 @@ import type tf from "@tensorflow/tfjs";
 import type { WeightsContainer } from "../index.js";
 import type { Dataset } from "../dataset/index.js";

-export interface EpochLogs {
-  epoch: number; // first epoch is zero
-  training: {
-    loss: number,
-    accuracy?: number
-  };
-  validation?: {
-    loss: number,
-    accuracy: number
-  };
-  peakMemory: number;
-}
+import type { BatchLogs, EpochLogs } from "./logs.js";

 // TODO still bound to tfjs
 export type Prediction = tf.Tensor;
@@ -26,7 +15,7 @@ export type Sample = tf.Tensor;
  * Allow for various implementation of models (various train function, tensor-library, ...)
  **/
 // TODO make it typesafe: same shape of data/input/weights
-export abstract class Model implements Disposable{
+export abstract class Model implements Disposable {
   // TODO don't allow external access but upgrade train to return weights on every epoch
   /** Return training state */
   abstract get weights(): WeightsContainer;
@@ -45,14 +34,12 @@ export abstract class Model implements Disposable {
   abstract train(
     trainingData: Dataset,
     validationData?: Dataset,
-    epochs?: number,
-  ): AsyncGenerator<EpochLogs>;
+  ): AsyncGenerator<BatchLogs, EpochLogs>;

   /** Predict likely values */
   // TODO extract in separated TrainedModel?
   abstract predict(input: Sample): Promise<Prediction>;

-
   /**
    * This method is automatically called to cleanup the memory occupied by the model
    * when leaving the definition scope if the instance has been defined with the `using` keyword.
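Since `EpochLogs` is now a class with derived getters rather than a plain record, a small usage sketch may help; the numbers are fabricated, and it exercises only the `training`, `peakMemory`, and `validation` members defined in `logs.ts` above:

```ts
import { List } from "immutable";
import { EpochLogs, type BatchLogs } from "@epfml/discojs";

// Two fabricated batch records, as a Model implementation would yield them.
const batches = List.of<BatchLogs>(
  { accuracy: 0.5, loss: 1.2, memoryUsage: 0.8 },
  { accuracy: 0.7, loss: 0.9, memoryUsage: 1.1 },
);

const epoch = new EpochLogs(batches, { accuracy: 0.6, loss: 1.0 });

console.log(epoch.training);   // per-epoch averages: { accuracy: 0.6, loss: 1.05 }
console.log(epoch.peakMemory); // highest memoryUsage seen: 1.1 (GB)
console.log(epoch.validation); // the metrics passed at construction, if any
```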
diff --git a/discojs/src/models/tfjs.ts b/discojs/src/models/tfjs.ts
index 8646c34cd..f653869dc 100644
--- a/discojs/src/models/tfjs.ts
+++ b/discojs/src/models/tfjs.ts
@@ -1,10 +1,12 @@
+import { List, Map } from 'immutable'
 import * as tf from '@tensorflow/tfjs'

 import { WeightsContainer } from '../index.js'
+import type { Dataset } from '../dataset/index.js'

+import { BatchLogs, EpochLogs } from './index.js'
 import { Model } from './index.js'
-import type { EpochLogs, Prediction, Sample } from './model.js'
-import type { Dataset } from '../dataset/index.js'
+import type { Prediction, Sample } from './model.js'

 /** TensorFlow JavaScript model with standard training */
 export class TFJS extends Model {
@@ -30,52 +32,87 @@ export class TFJS extends Model {
   override async *train(
     trainingData: Dataset,
     validationData?: Dataset,
-    epochs = 1,
-  ): AsyncGenerator<EpochLogs> {
-    for (let epoch = 0; epoch < epochs; epoch++) {
-      let logs: tf.Logs | undefined;
-      let peakMemory = 0
-      await this.model.fitDataset(trainingData, {
-        epochs: 1,
-        validationData,
-        callbacks: {
-          onBatchEnd: (_) => {
-            const currentMemory = tf.memory().numBytes / 1024 / 1024 / 1024 // GB
-            if (currentMemory > peakMemory) {
-              peakMemory = currentMemory
-            }
-          },
-          onEpochEnd: (_, cur) => { logs = cur }
+  ): AsyncGenerator<BatchLogs, EpochLogs> {
+    const batches = await trainingData.iterator(); // tf.LazyIterator isn't an AsyncGenerator
+    let batchesLogs = List<BatchLogs>();
+    for (let batchNumber = 0; true; batchNumber++) {
+      const iteration = await batches.next();
+      if (iteration.done) break;
+      const batch = iteration.value;
+
+      const batchLogs = {
+        batch: batchNumber,
+        ...(await this.#runBatch(batch)),
+      };
+      tf.dispose(batch);
+
+      yield batchLogs;
+      batchesLogs = batchesLogs.push(batchLogs);
+    }
+
+    const validation = validationData && (await this.#evaluate(validationData));
+    return new EpochLogs(batchesLogs, validation);
+  }
+
+  async #runBatch(
+    batch: tf.TensorContainer,
+  ): Promise<Omit<BatchLogs, "batch">> {
+    let logs: tf.Logs | undefined;
+    await this.model.fitDataset(tf.data.array([batch]), {
+      epochs: 1,
+      verbose: 0, // don't pollute
+      callbacks: {
+        onEpochEnd: (_, cur) => {
+          logs = cur;
         },
-      });
+      },
+    });
+    if (logs === undefined) throw new Error("batch didn't give any logs");
+
+    const { loss, acc: accuracy } = logs;
+    if (loss === undefined || isNaN(loss))
+      throw new Error("training loss is undefined or NaN");
+
+    return {
+      accuracy,
+      loss,
+      memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
+    };
+  }

-      if (logs === undefined) {
-        throw new Error("Epoch didn't gave any logs");
-      }
-      const { loss, acc, val_acc, val_loss } = logs;
-      if (loss === undefined || isNaN(loss) || acc === undefined || isNaN(acc)) {
-        throw new Error("Training loss is undefined or nan");
-      }
-      const structuredLogs: EpochLogs = {
-        epoch,
-        peakMemory,
-        training: {
-          loss: logs.loss,
-          accuracy: logs.acc,
-        }
-      }
-      if (validationData !== undefined) {
-        if(val_loss === undefined || isNaN(val_loss) ||
-           val_acc === undefined || isNaN(val_acc)) {
-          throw new Error("Invalid validation logs");
+  async #evaluate(
+    dataset: Dataset,
+  ): Promise<Record<"accuracy" | "loss", number>> {
+    const evaluation = await this.model.evaluateDataset(
+      dataset.map((t) => {
+        switch (t) {
+          case null:
+          case undefined:
+            throw new Error("nullish value in dataset");
+          default:
+            return t as Exclude<tf.TensorContainer, void>;
         }
-        structuredLogs.validation = {
-          accuracy: logs.val_acc,
-          loss: logs.val_loss
-        }
-      }
-      yield structuredLogs
-    }
+      }),
+    );

+    const metricToValue = Map(
+      List(this.model.metricsNames).zip(
+        Array.isArray(evaluation)
+          ? List(await Promise.all(evaluation.map((t) => t.data())))
+          : List.of(await evaluation.data()),
+      ),
+    ).map((values) => {
+      if (values.length !== 1) throw new Error("more than one metric value");
+      return values[0];
+    });
+
+    const [accuracy, loss] = [
+      metricToValue.get("acc"),
+      metricToValue.get("loss"),
+    ];
+    if (accuracy === undefined || loss === undefined)
+      throw new Error("some needed metrics are missing");
+
+    return { accuracy, loss };
   }

   override predict (input: Sample): Promise<Prediction> {

diff --git a/discojs/src/training/disco.ts b/discojs/src/training/disco.ts
index d57baa4d3..fdcf12265 100644
--- a/discojs/src/training/disco.ts
+++ b/discojs/src/training/disco.ts
@@ -1,7 +1,10 @@
-import type { data, Logger, Memory, Task, TrainingInformation } from '../index.js'
+import { List } from 'immutable'
+
+import { BatchLogs, data, EpochLogs, Logger, Memory, Task, TrainingInformation } from '../index.js'
 import { client as clients, EmptyMemory, ConsoleLogger } from '../index.js'
 import type { Aggregator } from '../aggregator/index.js'
 import { MeanAggregator } from '../aggregator/mean.js'
+import { enumerate, split } from '../utils/async_iterator.js'

 import type { RoundLogs, Trainer } from './trainer/trainer.js'
 import { TrainerBuilder } from './trainer/trainer_builder.js'
@@ -85,7 +88,14 @@ export class Disco {
    * @param dataTuple The data tuple
    */
   // TODO RoundLogs should contain number of participants but Trainer doesn't need client
-  async *fit(dataTuple: data.DataSplit): AsyncGenerator<RoundLogs & { participants: number }> {
+  async *fit(
+    dataTuple: data.DataSplit,
+  ): AsyncGenerator<
+    AsyncGenerator<
+      AsyncGenerator<BatchLogs, EpochLogs>,
+      RoundLogs & { participants: number }
+    >
+  > {
     this.logger.success("Training started.");

     const trainData = dataTuple.train.preprocess().batch();
@@ -94,25 +104,39 @@ export class Disco {
     await this.client.connect();
     const trainer = await this.trainer;

-    for await (const roundLogs of trainer.fitModel(trainData.dataset, validationData.dataset)) {
-      let msg = `Round: ${roundLogs.round}\n`
-      for (const epochLogs of roundLogs.epochs.values()) {
-        msg += `  Epoch: ${epochLogs.epoch}\n`
-        msg += `    Training loss: ${epochLogs.training.loss}\n`
-        if (epochLogs.training.accuracy !== undefined) {
-          msg += `    Training accuracy: ${epochLogs.training.accuracy}\n`
-        }
-        if (epochLogs.validation !== undefined) {
-          msg += `    Validation loss: ${epochLogs.validation.loss}\n`
-          msg += `    Validation accuracy: ${epochLogs.validation.accuracy}\n`
+    for await (const [round, epochs] of enumerate(
+      trainer.fitModel(trainData.dataset, validationData.dataset),
+    )) {
+      yield async function* (this: Disco) {
+        let epochsLogs = List<EpochLogs>();
+        for await (const [epoch, batches] of enumerate(epochs)) {
+          const [gen, returnedEpochLogs] = split(batches);
+
+          yield gen;
+          const epochLogs = await returnedEpochLogs;
+          epochsLogs = epochsLogs.push(epochLogs);
+
+          this.logger.success(
+            [
+              `Round: ${round}`,
+              `  Epoch: ${epoch}`,
+              `  Training loss: ${epochLogs.training.loss}`,
+              `  Training accuracy: ${epochLogs.training.accuracy}`,
+              epochLogs.validation !== undefined
+                ? `  Validation loss: ${epochLogs.validation.loss}`
+                : "",
+              epochLogs.validation !== undefined
+                ? `  Validation accuracy: ${epochLogs.validation.accuracy}`
+                : "",
+            ].join("\n"),
+          );
         }
-      }
-      this.logger.success(msg)

-      yield {
-        ...roundLogs,
-        participants: this.client.nodes.size + 1 // add ourself
-      }
+        return {
+          epochs: epochsLogs,
+          participants: this.client.nodes.size + 1, // add ourself
+        };
+      }.bind(this)();
     }

     this.logger.success("Training finished.");

diff --git a/discojs/src/training/trainer/trainer.ts b/discojs/src/training/trainer/trainer.ts
index ed1f5bd5b..79ae937a1 100644
--- a/discojs/src/training/trainer/trainer.ts
+++ b/discojs/src/training/trainer/trainer.ts
@@ -2,11 +2,10 @@ import type tf from "@tensorflow/tfjs";
 import { List } from "immutable";

 import type { Model, Task } from "../../index.js";
-
-import { EpochLogs } from "../../models/model.js";
+import * as async_iterator from "../../utils/async_iterator.js";
+import { BatchLogs, EpochLogs } from "../../models/index.js";

 export interface RoundLogs {
-  round: number;
   epochs: List<EpochLogs>;
 }
@@ -22,7 +21,10 @@ export abstract class Trainer {
   readonly #roundDuration: number;
   readonly #epochs: number;

-  private training?: AsyncGenerator<EpochLogs>;
+  private training?: AsyncGenerator<
+    AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>,
+    void
+  >;

   constructor(
     task: Task,
   ) {
     this.#roundDuration = task.trainingInformation.roundDuration;
     this.#epochs = task.trainingInformation.epochs;
+
+    if (!Number.isInteger(this.#epochs / this.#roundDuration))
+      throw new Error(`round duration doesn't divide epochs`);
   }

   protected abstract onRoundBegin(round: number): Promise<void>;
@@ -49,34 +54,52 @@ export abstract class Trainer {
   async *fitModel(
     dataset: tf.data.Dataset<tf.TensorContainer>,
     valDataset: tf.data.Dataset<tf.TensorContainer>,
-  ): AsyncGenerator<RoundLogs> {
-    if (this.training !== undefined) {
+  ): AsyncGenerator<
+    AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>,
+    void
+  > {
+    if (this.training !== undefined)
       throw new Error(
         "training already running, cancel it before launching a new one",
       );
-    }

-    await this.onRoundBegin(0);
+    try {
+      this.training = this.#runRounds(dataset, valDataset);
+      yield* this.training;
+    } finally {
+      this.training = undefined;
+    }
+  }

-    this.training = this.model.train(dataset, valDataset, this.#epochs);
+  async *#runRounds(
+    dataset: tf.data.Dataset<tf.TensorContainer>,
+    valDataset: tf.data.Dataset<tf.TensorContainer>,
+  ): AsyncGenerator<
+    AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>,
+    void
+  > {
+    const totalRound = Math.trunc(this.#epochs / this.#roundDuration);
+    for (let round = 0; round < totalRound; round++) {
+      await this.onRoundBegin(round);
+      yield this.#runRound(dataset, valDataset);
+      await this.onRoundEnd(round);
+    }
+  }

-    for await (const logs of this.training) {
-      // for now, round (sharing on network) == epoch (full pass over local data)
-      yield {
-        round: logs.epoch,
-        epochs: List.of(logs),
-      };
+  async *#runRound(
+    dataset: tf.data.Dataset<tf.TensorContainer>,
+    valDataset: tf.data.Dataset<tf.TensorContainer>,
+  ): AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs> {
+    let epochsLogs = List<EpochLogs>();
+    for (let epoch = 0; epoch < this.#roundDuration; epoch++) {
+      const [gen, epochLogs] = async_iterator.split(
+        this.model.train(dataset, valDataset),
+      );

-      if (logs.epoch % this.#roundDuration === 0) {
-        const round = Math.trunc(logs.epoch / this.#roundDuration);
-        await this.onRoundEnd(round);
-        await this.onRoundBegin(round);
-      }
+      yield gen;
+      epochsLogs = epochsLogs.push(await epochLogs);
     }

-    const round = Math.trunc(this.#epochs / this.#roundDuration);
-    await this.onRoundEnd(round);
-
-    this.training = undefined;
+    return { epochs: epochsLogs };
   }
 }

diff --git a/discojs/src/utils/async_iterator.spec.ts b/discojs/src/utils/async_iterator.spec.ts
new file mode 100644
index 000000000..af4a34602
--- /dev/null
+++ b/discojs/src/utils/async_iterator.spec.ts
@@ -0,0 +1,63 @@
+import { expect } from "chai";
+
+import { split, gather } from "./async_iterator.js";
+
+// Array.fromAsync not yet widely available (2024)
+async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
+  const ret: T[] = [];
+  for await (const e of iter) ret.push(e);
+  return ret;
+}
+
+describe("gather", () => {
+  it("returns generator value", async () => {
+    const [yielded, returned] = await gather(
+      // eslint-disable-next-line @typescript-eslint/require-await
+      (async function* () {
+        yield "yield";
+        return "return";
+      })(),
+    );
+
+    expect(yielded.toArray()).to.have.same.ordered.members(["yield"]);
+    expect(returned).to.equals("return");
+  });
+});
+
+describe("split", () => {
+  it("returns both iterator and return value", async () => {
+    const [gen, ret] = split(
+      // eslint-disable-next-line @typescript-eslint/require-await
+      (async function* () {
+        yield "yield";
+        return "return";
+      })(),
+    );
+
+    expect(await arrayFromAsync(gen)).to.have.same.ordered.members(["yield"]);
+    expect(await ret).to.equals("return");
+  });
+
+  it("throws returned when iterator throws", async () => {
+    const [gen, ret] = split(
+      // eslint-disable-next-line require-yield, @typescript-eslint/require-await
+      (async function* () {
+        throw new Error();
+      })(),
+    );
+
+    try {
+      for await (const _ of gen);
+    } catch {
+      // expected
+    }
+
+    try {
+      await ret;
+    } catch {
+      return; // all good
+    }
+
+    expect(false, "should have thrown").to.be.true;
+  });
+});

diff --git a/discojs/src/utils/async_iterator.ts b/discojs/src/utils/async_iterator.ts
new file mode 100644
index 000000000..0a5e6e676
--- /dev/null
+++ b/discojs/src/utils/async_iterator.ts
@@ -0,0 +1,79 @@
+import { List } from "immutable";
+
+// `Promise.withResolvers` not yet widely deployed
+function PromiseWithResolvers<T>(): [
+  Promise<T>,
+  (_: T) => void,
+  (_: unknown) => void,
+] {
+  let resolve: (_: T) => void, reject: (_: unknown) => void;
+  resolve = reject = () => {
+    // should not happen as Promise executors run on creation
+    throw new Error("race condition triggered");
+  };
+
+  const promise = new Promise<T>((res, rej) => {
+    resolve = res;
+    reject = rej;
+  });
+
+  return [promise, resolve, reject];
+}
+
+/**
+ * Split the yielded values from the return value
+ *
+ * You need to consume the iterator for the returned promise to resolve
+ **/
+export function split<T, U>(
+  iter: AsyncIterator<T, U>,
+): [AsyncGenerator<T, U>, Promise<U>] {
+  const [returnPromise, returnResolve, returnReject] =
+    PromiseWithResolvers<U>();
+
+  return [
+    (async function* () {
+      try {
+        while (true) {
+          const v = await iter.next();
+          if (!v.done) {
+            yield v.value;
+            continue;
+          }
+
+          returnResolve(v.value);
+          return v.value;
+        }
+      } catch (e) {
+        returnReject(e);
+        throw e;
+      }
+    })(),
+    returnPromise,
+  ];
+}
+
+/** Zip an iterator with an infinite counter */
+export function enumerate<T, U>(
+  iter: AsyncIterator<T, U> | Iterator<T, U>,
+): AsyncGenerator<[number, T], U> {
+  return (async function* () {
+    for (let i = 0; ; i++) {
+      const v = await iter.next();
+      if (v.done) return v.value;
+      yield [i, v.value];
+    }
+  })();
+}
+
+/** Run the whole iterator to collect both the yielded values and the return value */
+export async function gather<T, U>(
+  iter: AsyncIterator<T, U>,
+): Promise<[List<T>, U]> {
+  let elems = List<T>();
+  for (;;) {
+    const v = await iter.next();
+    if (v.done) return [elems, v.value];
+    elems = elems.push(v.value);
+  }
+}
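Before moving to the examples, here is a quick illustration of how these three helpers behave, using a stand-in generator shaped like `Model.train` (yields batches, returns a summary); the `demo` function is ours, not part of the diff:

```ts
import { async_iterator } from "@epfml/discojs";

// Stand-in generator: yields two "batches", then returns a summary.
async function* epoch(): AsyncGenerator<number, string> {
  yield 1;
  yield 2;
  return "epoch done";
}

async function demo(): Promise<void> {
  // split(): iterate the yields now, await the return value once drained.
  const [batches, summary] = async_iterator.split(epoch());
  for await (const [i, batch] of async_iterator.enumerate(batches))
    console.log(`batch #${i}: ${batch}`);
  console.log(await summary); // "epoch done"

  // gather(): drain everything in one call.
  const [allBatches, returned] = await async_iterator.gather(epoch());
  console.log(allBatches.toArray(), returned); // [ 1, 2 ] "epoch done"
}

void demo();
```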
diff --git a/docs/examples/README.md b/docs/examples/README.md
index 2abbc7df8..9d5b0cffc 100644
--- a/docs/examples/README.md
+++ b/docs/examples/README.md
@@ -22,7 +22,9 @@ As you can see in `training.ts` a client is represented by a `Disco` object:

 ```js
 const disco = new Disco(task, { url, scheme: "federated" });

-await disco.fit(dataset); // Start training on the dataset
+for await (const round of disco.fit(dataset))
+  for await (const epoch of round)
+    for await (const batch of epoch);

 await disco.close();
 ```

diff --git a/docs/examples/training.ts b/docs/examples/training.ts
index 585685368..d2ab3936d 100644
--- a/docs/examples/training.ts
+++ b/docs/examples/training.ts
@@ -12,9 +12,13 @@ import { startServer } from 'server'
 async function runUser (url: URL, task: Task, dataset: data.DataSplit): Promise<void> {
   // Create Disco object associated with the server url, the training scheme
   const disco = new Disco(task, { url, scheme: 'federated' })
-  for await (const _ of disco.fit(dataset)); // Start training on the dataset

-  // Stop training and disconnect from the remote server
+  // Run training on the dataset
+  for await (const round of disco.fit(dataset))
+    for await (const epoch of round)
+      for await (const _ of epoch);
+
+  // Disconnect from the remote server
   await disco.close()
 }

diff --git a/docs/examples/wikitext.ts b/docs/examples/wikitext.ts
index b3b15e6ae..5367a9180 100644
--- a/docs/examples/wikitext.ts
+++ b/docs/examples/wikitext.ts
@@ -29,7 +29,9 @@ async function main(): Promise<void> {
   const aggregator = new aggregators.MeanAggregator()
   const client = new clients.federated.FederatedClient(url, task, aggregator)
   const disco = new Disco(task, { scheme: 'federated', client, aggregator })
-  for await (const _ of disco.fit(dataset));
+  for await (const round of disco.fit(dataset))
+    for await (const epoch of round)
+      for await (const _ of epoch);

   // Get the model and complete the prompt
   if (aggregator.model === undefined) {

diff --git a/server/tests/e2e/federated.spec.ts b/server/tests/e2e/federated.spec.ts
index bbe789da4..63255a96a 100644
--- a/server/tests/e2e/federated.spec.ts
+++ b/server/tests/e2e/federated.spec.ts
@@ -7,7 +7,8 @@ import { assert, expect } from 'chai'
 import type { RoundLogs, WeightsContainer } from '@epfml/discojs'
 import {
   Disco, client as clients, data,
-  aggregator as aggregators, defaultTasks
+  aggregator as aggregators, defaultTasks,
+  async_iterator
 } from '@epfml/discojs'
 import { NodeImageLoader, NodeTabularLoader, NodeTextLoader } from '@epfml/discojs-node'
@@ -52,7 +53,7 @@ describe("end-to-end federated", function () {
     const files = [DATASET_DIR + 'titanic_train.csv']

     const titanicTask = defaultTasks.titanic.getTask()
-    titanicTask.trainingInformation.epochs = 5
+    titanicTask.trainingInformation.epochs = titanicTask.trainingInformation.roundDuration = 5

     const data = await (new NodeTabularLoader(titanicTask, ',').loadAll(
       files,
       {
@@ -68,7 +69,10 @@ describe("end-to-end federated", function () {
     let logs = List<RoundLogs>()
     for await (const round of disco.fit(data)) {
-      logs = logs.push(round)
+      const [roundGen, roundLogs] = async_iterator.split(round)
+      for await (const epoch of roundGen)
+        for await (const _ of epoch);
+      logs = logs.push(await roundLogs)
     }

     await disco.close()
@@ -99,7 +103,10 @@ describe("end-to-end federated", function () {
     let logs = List<RoundLogs>()
     for await (const round of disco.fit(dataSplit)) {
-      logs = logs.push(round)
+      const [roundGen, roundLogs] = async_iterator.split(round)
+      for await (const epoch of roundGen)
+        for await (const _ of epoch);
+      logs = logs.push(await roundLogs)
     }

     await disco.close()
federated", function () { const negativeLabels = files[1].map(_ => 'COVID-Negative') const labels = positiveLabels.concat(negativeLabels) const lusCovidTask = defaultTasks.lusCovid.getTask() - lusCovidTask.trainingInformation.epochs = 15 + lusCovidTask.trainingInformation.epochs = 16 + lusCovidTask.trainingInformation.roundDuration = 4 const data = await new NodeImageLoader(lusCovidTask) .loadAll(files.flat(), { labels, channels: 3 }) @@ -131,7 +139,10 @@ describe("end-to-end federated", function () { let logs = List() for await (const round of disco.fit(data)) { - logs = logs.push(round) + const [roundGen, roundLogs] = async_iterator.split(round) + for await (const epoch of roundGen) + for await (const _ of epoch); + logs = logs.push(await roundLogs) } await disco.close() diff --git a/webapp/cypress/e2e/library.cy.ts b/webapp/cypress/e2e/library.cy.ts index ea4b204c0..5931d132d 100644 --- a/webapp/cypress/e2e/library.cy.ts +++ b/webapp/cypress/e2e/library.cy.ts @@ -19,7 +19,7 @@ describe("model library", () => { function setupForTask(taskProvider: TaskProvider): void { cy.intercept({ hostname: "server", pathname: "/tasks" }, (req) => { const task = taskProvider.getTask(); - task.trainingInformation.epochs = 3; + task.trainingInformation.epochs = task.trainingInformation.roundDuration = 3; req.reply([task]); }); @@ -103,7 +103,7 @@ describe("model library", () => { cy.contains("button", "next").click(); cy.contains("button", "train alone").click(); - cy.contains("h6", "current round") + cy.contains("h6", "current epoch") .next({ timeout: 10_000 }) .should("have.text", "3"); cy.contains("button", "next").click(); diff --git a/webapp/cypress/e2e/training.cy.ts b/webapp/cypress/e2e/training.cy.ts index fa4df1338..484720558 100644 --- a/webapp/cypress/e2e/training.cy.ts +++ b/webapp/cypress/e2e/training.cy.ts @@ -57,8 +57,8 @@ describe("training page", () => { cy.contains("button", "train alone").click(); cy.contains("h6", "current round") - .next({ timeout: 30_000 }) - .should("have.text", "20"); + .next({ timeout: 40_000 }) + .should("have.text", "2"); cy.contains("button", "next").click(); cy.contains("button", "test model").click(); diff --git a/webapp/src/charts.ts b/webapp/src/charts.ts deleted file mode 100644 index 4e83049e3..000000000 --- a/webapp/src/charts.ts +++ /dev/null @@ -1,73 +0,0 @@ -const chartOptions = { - chart: { - id: 'realtime', - width: 'auto', - height: 'auto', - // type: 'area', - animations: { - enabled: true, - easing: 'linear', - dynamicAnimation: { - speed: 1000 - } - }, - toolbar: { - show: false - }, - zoom: { - enabled: false - } - }, - dataLabels: { - enabled: false - }, - colors: [ - '#6096BA' - ], - fill: { - colors: ['#E2E8F0'], - type: 'solid', - opacity: 0.6 - }, - stroke: { - curve: 'smooth' - }, - markers: { - size: 0.5 - }, - grid: { - xaxis: { - lines: { - show: false - } - }, - yaxis: { - lines: { - show: false - } - } - }, - yaxis: { - max: 100, - min: 0, - labels: { - show: true, - formatter: function (value: number) { - return value.toFixed(0); - } - } - }, - xaxis: { - labels: { - show: false - } - }, - legend: { - show: false - }, - tooltip: { - enabled: true - } -} - -export { chartOptions } diff --git a/webapp/src/components/training/Trainer.vue b/webapp/src/components/training/Trainer.vue index 1f7099980..5ad48e7b6 100644 --- a/webapp/src/components/training/Trainer.vue +++ b/webapp/src/components/training/Trainer.vue @@ -30,7 +30,9 @@
@@ -42,12 +44,8 @@
 import { List } from "immutable";
 import { ref, computed } from "vue";

-import type { RoundLogs, Task } from "@epfml/discojs";
-import {
-  data,
-  EmptyMemory,
-  Disco,
-} from "@epfml/discojs";
+import type { BatchLogs, EpochLogs, RoundLogs, Task } from "@epfml/discojs";
+import { async_iterator, data, EmptyMemory, Disco } from "@epfml/discojs";
 import { IndexedDB } from "@epfml/discojs-web";

 import { getClient } from '@/clients'
@@ -69,19 +67,40 @@ const props = defineProps<{
 const displayModelCaching = ref(true)

 const trainingGenerator =
-  ref<AsyncGenerator<RoundLogs & { participants: number }>>();
-const logs = ref(List<RoundLogs & { participants: number }>());
+  ref<
+    AsyncGenerator<
+      AsyncGenerator<
+        AsyncGenerator<BatchLogs, EpochLogs>,
+        RoundLogs & { participants: number }
+      >
+    >
+  >();
+const roundGenerator =
+  ref<
+    AsyncGenerator<
+      AsyncGenerator<BatchLogs, EpochLogs>,
+      RoundLogs & { participants: number }
+    >
+  >();
+const epochGenerator = ref<AsyncGenerator<BatchLogs, EpochLogs>>();
+const roundsLogs = ref(List<RoundLogs & { participants: number }>());
+const epochsOfRoundLogs = ref(List<EpochLogs>());
+const batchesOfEpochLogs = ref(List<BatchLogs>());
 const messages = ref(List());

 const hasValidationData = computed(
   () => props.task.trainingInformation.validationSplit > 0,
 );

+const stopper = new Error("stop training")
+
 async function startTraining(distributed: boolean): Promise<void> {
   // Reset training information before starting a new training
   trainingGenerator.value = undefined
-  logs.value = List()
-  messages.value = List()
+  roundsLogs.value = List()
+  epochsOfRoundLogs.value = List()
+  batchesOfEpochLogs.value = List()
+  messages.value = List()

   let dataset: data.DataSplit;
   try {
@@ -130,16 +149,29 @@ async function startTraining(distributed: boolean): Promise<void> {
   try {
     displayModelCaching.value = false // hide model caching buttons during training
     trainingGenerator.value = disco.fit(dataset);
-    logs.value = List();
-    for await (const roundLogs of trainingGenerator.value)
-      logs.value = logs.value.push(roundLogs);
-    if (trainingGenerator.value === undefined) {
-      toaster.info("Training stopped");
-      return;
+    roundsLogs.value = List()
+    for await (const round of trainingGenerator.value) {
+      const [roundGen, roundLogs] = async_iterator.split(round)
+
+      roundGenerator.value = roundGen
+      epochsOfRoundLogs.value = List()
+      for await (const epoch of roundGenerator.value) {
+        const [epochGen, epochLogs] = async_iterator.split(epoch)
+
+        epochGenerator.value = epochGen
+        batchesOfEpochLogs.value = List()
+        for await (const batch of epochGenerator.value)
+          batchesOfEpochLogs.value = batchesOfEpochLogs.value.push(batch);
+        epochsOfRoundLogs.value = epochsOfRoundLogs.value.push(await epochLogs)
+      }
+      roundsLogs.value = roundsLogs.value.push(await roundLogs)
     }
   } catch (e) {
-    if (e instanceof Error && e.message.includes("greater than WebGL maximum on this browser")) {
+    if (e === stopper) {
+      toaster.info("Training stopped");
+      return
+    } else if (e instanceof Error && e.message.includes("greater than WebGL maximum on this browser")) {
       toaster.error("Unfortunately your browser doesn't support training this task.<br/>If you are on Firefox try using Chrome instead.")
     } else if (e instanceof Error && e.message.includes("loss is undefined or nan")) {
       toaster.error("Training is not converging. Data potentially needs better preprocessing.")
@@ -157,10 +189,13 @@ async function startTraining(distributed: boolean): Promise<void> {
 }

 async function stopTraining(): Promise<void> {
-  const generator = trainingGenerator.value;
-  if (generator === undefined) return;
-
-  trainingGenerator.value = undefined;
-  generator.return();
+  trainingGenerator.value?.throw(stopper);
+  trainingGenerator.value = undefined;
+
+  roundGenerator.value?.throw(stopper);
+  roundGenerator.value = undefined;
+
+  epochGenerator.value?.throw(stopper);
+  epochGenerator.value = undefined;
 }

diff --git a/webapp/src/components/training/TrainingInformation.vue b/webapp/src/components/training/TrainingInformation.vue
index 04b19b852..1ca56d9cc 100644
--- a/webapp/src/components/training/TrainingInformation.vue
+++ b/webapp/src/components/training/TrainingInformation.vue
@@ -4,7 +4,21 @@
[template markup for this hunk was lost in extraction]
@@ -34,12 +48,10 @@
[template markup for this hunk was lost in extraction]
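Taken together, the consumer-facing change of this diff is that `disco.fit` now yields one generator per round, which yields one generator per epoch, which yields per-batch logs. A minimal consumer, mirroring the updated tests and webapp code above (task and dataset setup elided; `runAndLog` is our name, not the library's):

```ts
import { Disco, async_iterator } from "@epfml/discojs";
import type { data } from "@epfml/discojs";

async function runAndLog(disco: Disco, dataset: data.DataSplit): Promise<void> {
  for await (const round of disco.fit(dataset)) {
    // split() lets us iterate the epochs now and await the round summary after.
    const [epochs, roundLogs] = async_iterator.split(round);
    for await (const epoch of epochs) {
      const [batches, epochLogs] = async_iterator.split(epoch);
      for await (const batch of batches)
        console.log(`batch loss: ${batch.loss}`);
      console.log(`epoch mean loss: ${(await epochLogs).training.loss}`);
    }
    console.log(`participants this round: ${(await roundLogs).participants}`);
  }
  await disco.close();
}
```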