discojs/model: expose batch generator
tharvik committed Jun 26, 2024
1 parent 8b58308 commit 91a2844
Showing 28 changed files with 721 additions and 381 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint-test-build.yml
@@ -335,7 +335,7 @@ jobs:
cache: npm
- run: npm ci
- run: npm --workspace={discojs,discojs-node,server} run build
-      - run: npm --workspace=cli start -- -t cifar10 -u 3 -e 1
+      - run: npm --workspace=cli start -- -t cifar10 -u 3 -e 1 -r 1

test-docs-examples:
needs: [build-lib, build-lib-node, build-server, download-datasets]
10 changes: 5 additions & 5 deletions cli/src/benchmark_gpt.ts
@@ -1,6 +1,6 @@
import { parse } from 'ts-command-line-args';
import type { Task } from '@epfml/discojs'
-import { fetchTasks, data, models } from '@epfml/discojs'
+import { fetchTasks, data, models, async_iterator } from '@epfml/discojs'
import { NodeTextLoader, loadModelFromDisk } from '@epfml/discojs-node'
import { startServer } from 'server'

@@ -57,7 +57,7 @@ async function main(args: Required<CLIArguments>): Promise<void> {
*/
if (!benchmarkInference) {
// Benchmark parameters
-    const epoch = 1
+    const epochsCount = 1
const iterationsPerEpoch = 10

const config: models.GPTConfig = {
@@ -80,10 +80,10 @@ async function main(args: Required<CLIArguments>): Promise<void> {
console.log(`\tmodel type ${modelType} \n\tbatch size ${batchSize} \n\tcontext length ${contextLength}`)

let epochTime = performance.now()
-    const logGenerator = model.train(preprocessedDataset, undefined, epoch)
-    for await (const logs of logGenerator) {
+    for (let epochsCounter = 0; epochsCounter < epochsCount; epochsCounter++) {
+      const [_, logs] = await async_iterator.gather(model.train(preprocessedDataset))
epochTime = (performance.now() - epochTime)
-      const msPerToken = epochTime / (batchSize * contextLength * iterationsPerEpoch * epoch)
+      const msPerToken = epochTime / (batchSize * contextLength * iterationsPerEpoch * epochsCounter)
console.log(`\t\tTraining time: ${msPerToken.toFixed(2)} ms/token <br> ${logs.peakMemory.toFixed(2)} GB`)
}

9 changes: 7 additions & 2 deletions cli/src/cli.ts
@@ -2,7 +2,7 @@ import { List, Range } from 'immutable'
import fs from 'node:fs/promises'

import type { data, RoundLogs, Task } from '@epfml/discojs'
-import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs'
+import { Disco, aggregator as aggregators, async_iterator, client as clients } from '@epfml/discojs'
import { startServer } from 'server'

import { getTaskData } from './data.js'
@@ -23,7 +23,12 @@ async function runUser(
const disco = new Disco(task, { scheme: "federated", client });

let logs = List<RoundLogs>();
-  for await (const round of disco.fit(data)) logs = logs.push(round);
+  for await (const round of disco.fit(data)) {
+    const [roundGen, roundLogs] = async_iterator.split(round)
+    for await (const epoch of roundGen)
+      for await (const _ of epoch);
+    logs = logs.push(await roundLogs);
+  }

await disco.close();
return logs;
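The `async_iterator` helpers used above (`split` here, `gather` in benchmark_gpt.ts) live in `discojs/src/utils/async_iterator.ts`, which this commit adds to the package's public exports; that file's diff is not shown on this page. The following is a minimal sketch of the behavior implied by the call sites, not the actual implementation: `gather` drains an `AsyncGenerator<T, R>` and hands back both its yielded values and its return value, while `split` separates the yielded values from a promise of the return value.

```ts
// Sketch only: behavior inferred from the call sites in this diff,
// not copied from utils/async_iterator.ts.

async function gather<T, R>(iter: AsyncGenerator<T, R>): Promise<[T[], R]> {
  const yielded: T[] = [];
  for (;;) {
    const step = await iter.next();
    if (step.done === true) return [yielded, step.value];
    yielded.push(step.value);
  }
}

function split<T, R>(iter: AsyncGenerator<T, R>): [AsyncGenerator<T, void>, Promise<R>] {
  let resolve!: (r: R) => void;
  let reject!: (e: unknown) => void;
  const returned = new Promise<R>((res, rej) => { resolve = res; reject = rej; });

  async function* yieldedOnly(): AsyncGenerator<T, void> {
    try {
      for (;;) {
        const step = await iter.next();
        if (step.done === true) { resolve(step.value); return; }
        yield step.value;
      }
    } catch (e) {
      reject(e);
      throw e;
    }
  }

  // Note: the promise only settles once the generator has been fully drained.
  return [yieldedOnly(), returned];
}
```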
2 changes: 1 addition & 1 deletion discojs/src/default_tasks/wikitext.ts
@@ -22,7 +22,7 @@ export const wikitext: TaskProvider = {
modelID: 'llm-raw-model',
preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
scheme: 'federated',
-    epochs: 5,
+    epochs: 6,
// Unused by wikitext because data already comes split
// But if set to 0 then the webapp doesn't display the validation metrics
validationSplit: 0.1,
3 changes: 2 additions & 1 deletion discojs/src/index.ts
@@ -13,10 +13,11 @@ export { Memory, type ModelInfo, type Path, type ModelSource, Empty as EmptyMemo
export { Disco, RoundLogs } from './training/index.js'
export { Validator } from './validation/index.js'

-export { Model, EpochLogs } from './models/index.js'
+export { Model, BatchLogs, EpochLogs, ValidationMetrics } from './models/index.js'
export * as models from './models/index.js'

export * from './task/index.js'
export * as defaultTasks from './default_tasks/index.js'

export * from './types.js'
+export * as async_iterator from "./utils/async_iterator.js"
3 changes: 1 addition & 2 deletions discojs/src/models/gpt/evaluate.ts
@@ -9,7 +9,7 @@ export default async function evaluate (
model: tf.LayersModel,
dataset: tf.data.Dataset<DataPoint>,
maxEvalBatches: number
-): Promise<Record<'acc' | 'val_acc' | 'val_loss' | 'val_perplexity', number>> {
+): Promise<Record<'val_acc' | 'val_loss' | 'val_perplexity', number>> {
let datasetSize = 0
let totalLoss = 0
const acc: [number, number] = [0, 0]
@@ -53,7 +53,6 @@ export default async function evaluate (
return {
val_loss: loss,
val_perplexity: Math.exp(loss),
-    acc: acc[0] / acc[1],
val_acc: acc[0] / acc[1]
}
}
4 changes: 2 additions & 2 deletions discojs/src/models/gpt/gpt.spec.ts
@@ -36,8 +36,8 @@ describe('gpt-tfjs', function() {
}).repeat().batch(64)

const model = new GPT(config)
-    const logGenerator = model.train(tokenDataset, undefined, 5) // 5 epochs
-    for await (const _ of logGenerator); // Await the end of training
+    for (let i = 0; i < 5; i++)
+      for await (const _ of model.train(tokenDataset, undefined));
const generation = await model.generate("Lorem ipsum dolor", tokenizer, 1)
expect(generation).equal(data) // Assert that the model completes 'Lorem ipsum dolor' with 'sit'
})
124 changes: 85 additions & 39 deletions discojs/src/models/gpt/index.ts
@@ -8,10 +8,13 @@ import { PreTrainedTokenizer } from '@xenova/transformers';
import { WeightsContainer } from '../../index.js'
import type { Dataset } from '../../dataset/index.js'

-import { Model } from '../model.js'
+import { BatchLogs, Model, EpochLogs } from "../index.js";
+import type { Prediction, Sample } from '../model.js'

 import { GPTForCausalLM } from './model.js'
-import type { EpochLogs, Prediction, Sample } from '../model.js'
-import type { GPTConfig } from './config.js'
+import { DEFAULT_CONFIG, type GPTConfig } from './config.js'
+import evaluate from './evaluate.js';
+import { List } from 'immutable';

export type GPTSerialization = {
weights: WeightsContainer
@@ -21,9 +24,13 @@ export type GPTSerialization = {
export class GPT extends Model {
private readonly model: GPTForCausalLM

+  readonly #maxBatchCount: number

constructor (partialConfig?: GPTConfig, layersModel?: tf.LayersModel) {
super()

this.model = new GPTForCausalLM(partialConfig, layersModel)
+    this.#maxBatchCount = partialConfig?.maxIter ?? DEFAULT_CONFIG.maxIter
}

/**
Expand All @@ -38,51 +45,90 @@ export class GPT extends Model {
   override async *train(
     trainingData: Dataset,
     validationData?: Dataset,
-    epochs = 1,
-  ): AsyncGenerator<EpochLogs, void> {
-    this.model.compile()
-
-    let logs: tf.Logs | undefined;
-    const trainingArgs: tf.ModelFitDatasetArgs<tf.TensorContainer> = {
-      epochs: 1, // force fitDataset to do only one epoch because it is wrapped in a for loop
-      validationData,
-      callbacks: { onEpochEnd: (_, cur) => { logs = cur }},
-    }
-    for (let epoch = 0; epoch < epochs; epoch++) {
-      await this.model.fitDataset(trainingData, trainingArgs);
-      if (logs === undefined) {
-        throw new Error("Epoch didn't gave any logs");
-      }
-      const { loss, val_acc, val_loss, peakMemory } = logs;
-      if (loss === undefined || isNaN(loss)) {
-        throw new Error("Training loss is undefined or nan");
-      }
-      const structuredLogs: EpochLogs = {
-        epoch,
-        peakMemory,
-        training: {
-          loss: logs.loss,
-          accuracy: logs.acc
-        }
-      }
-
-      if (validationData !== undefined) {
-        if(val_loss === undefined || isNaN(val_loss) ||
-          val_acc === undefined || isNaN(val_acc)) {
-          throw new Error("Validation accuracy or loss is undefined or nan");
-        }
-        structuredLogs.validation = { accuracy: logs.val_acc, loss: logs.val_loss}
-      }
-      yield structuredLogs
-    }
-  }
+  ): AsyncGenerator<BatchLogs, EpochLogs> {
+    this.model.compile();
+
+    const batches = await trainingData.iterator(); // tf.LazyIterator isn't an AsyncGenerator
+    let batchesLogs = List<BatchLogs>();
+    for (
+      let batchNumber = 0;
+      batchNumber < this.#maxBatchCount;
+      batchNumber++
+    ) {
+      const iteration = await batches.next();
+      if (iteration.done) break;
+      const batch = iteration.value;
+
+      const batchLogs = await this.#runBatch(batch);
+      tf.dispose(batch);
+
+      yield batchLogs;
+      batchesLogs = batchesLogs.push(batchLogs);
+    }
+
+    const validation = validationData && (await this.#evaluate(validationData));
+    return new EpochLogs(batchesLogs, validation);
+  }
+
+  async #runBatch(
+    batch: tf.TensorContainer,
+  ): Promise<BatchLogs> {
+    let logs: tf.Logs | undefined;
+    await this.model.fitDataset(tf.data.array([batch]), {
+      epochs: 1,
+      verbose: 0, // don't pollute
+      callbacks: {
+        onEpochEnd: (_, cur) => {
+          logs = cur;
+        },
+      },
+    });
+    if (logs === undefined) throw new Error("batch didn't gave any logs");
+
+    const { loss, acc: accuracy } = logs;
+    if (loss === undefined || isNaN(loss))
+      throw new Error("training loss is undefined or NaN");
+
+    return {
+      accuracy,
+      loss,
+      memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
+    };
+  }
+
+  async #evaluate(
+    dataset: Dataset,
+  ): Promise<Record<"accuracy" | "loss", number>> {
+    const evaluation = await evaluate(
+      this.model,
+      dataset.map((t) => {
+        switch (t) {
+          case null:
+          case undefined:
+            throw new Error("nullish value in dataset");
+          default:
+            // TODO unsafe cast
+            return t as { xs: tf.Tensor2D; ys: tf.Tensor3D };
+        }
+      }),
+      this.config.maxEvalBatches,
+    );
+
+    return {
+      accuracy: evaluation.val_acc,
+      loss: evaluation.val_loss,
+    };
+  }

-  override predict (input: Sample): Promise<Prediction> {
-    const ret = this.model.predict(input)
+  override predict(input: Sample): Promise<Prediction> {
+    const ret = this.model.predict(input);
     if (Array.isArray(ret)) {
-      throw new Error('prediction yield many Tensors but should have only returned one')
+      throw new Error(
+        "prediction yield many Tensors but should have only returned one",
+      );
     }

-    return Promise.resolve(ret)
+    return Promise.resolve(ret);
   }

async generate(input: string, tokenizer: PreTrainedTokenizer, newTokens: number = 10): Promise<string> {
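After this change, `GPT.train()` yields one `BatchLogs` per processed batch, stops after `maxIter` batches (the new `#maxBatchCount`), and returns the epoch's aggregated `EpochLogs` instead of yielding per-epoch logs, so callers drive the epoch loop themselves. A usage sketch under the `gather` behavior assumed earlier; `model` and `dataset` are placeholders for whatever the caller has set up:

```ts
import { async_iterator, models } from "@epfml/discojs";

// Sketch: drive a fixed number of epochs by calling train() once per epoch and
// draining its per-batch generator. Stream with `for await` over train()
// instead if per-batch progress reporting is wanted.
async function trainForEpochs(
  model: models.GPT,
  dataset: Parameters<models.GPT["train"]>[0],
  epochs = 3,
): Promise<void> {
  for (let epoch = 0; epoch < epochs; epoch++) {
    const [, epochLogs] = await async_iterator.gather(model.train(dataset));
    console.log(
      `epoch ${epoch}: mean training loss ${epochLogs.training.loss.toFixed(3)}, ` +
        `peak memory ${epochLogs.peakMemory.toFixed(2)} GB`,
    );
  }
}
```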
35 changes: 27 additions & 8 deletions discojs/src/models/gpt/model.ts
@@ -57,6 +57,7 @@ class GPTModel extends tf.LayersModel {
await callbacks.onTrainBegin?.()

for (let epoch = 1; epoch <= trainingArgs.epochs; epoch++) {
+      let accuracyFraction: [number, number] = [0, 0];
let averageLoss = 0
let peakMemory = 0
let iteration = 1
@@ -69,18 +70,34 @@ class GPTModel extends tf.LayersModel {
let weightUpdateTime = performance.now()
await callbacks.onEpochBegin?.(epoch)
const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D }
-        const lossFn: () => tf.Scalar = () => {
-          const logits = this.apply(xs)
-          if (Array.isArray(logits)) {
-            throw new Error('model outputs too many tensor')
-          }
-          if (logits instanceof tf.SymbolicTensor) {
-            throw new Error('model outputs symbolic tensor')
-          }
-          return tf.losses.softmaxCrossEntropy(ys, logits)
-        }
+
+        // TODO include as a tensor inside the model
+        const accTensor = tf.tidy(() => {
+          const logits = this.apply(xs)
+          if (Array.isArray(logits))
+            throw new Error('model outputs too many tensor')
+          if (logits instanceof tf.SymbolicTensor)
+            throw new Error('model outputs symbolic tensor')
+          return tf.metrics.categoricalAccuracy(ys, logits)
+        })
+        const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
+        const accSumTensor = accTensor.sum()
+        const accSum = await accSumTensor.array()
+        tf.dispose(accSumTensor)
+        if (typeof accSum !== 'number')
+          throw new Error('got multiple accuracy sum')
+        accuracyFraction = [accuracyFraction[0] + accSum, accuracyFraction[1] + accSize];
+        tf.dispose([accTensor])
+
         const lossTensor = tf.tidy(() => {
-          const { grads, value: lossTensor } = this.optimizer.computeGradients(lossFn)
+          const { grads, value: lossTensor } = this.optimizer.computeGradients(() => {
+            const logits = this.apply(xs)
+            if (Array.isArray(logits))
+              throw new Error('model outputs too many tensor')
+            if (logits instanceof tf.SymbolicTensor)
+              throw new Error('model outputs symbolic tensor')
+            return tf.losses.softmaxCrossEntropy(ys, logits)
+          })
const gradsClipped = clipByGlobalNormObj(grads, 1)
this.optimizer.applyGradients(gradsClipped)
return lossTensor
Expand All @@ -89,6 +106,7 @@ class GPTModel extends tf.LayersModel {
const loss = await lossTensor.array()
averageLoss += loss
weightUpdateTime = performance.now() - weightUpdateTime

tf.dispose([xs, ys, lossTensor])

if (
Expand Down Expand Up @@ -122,6 +140,7 @@ class GPTModel extends tf.LayersModel {
}
let logs: tf.Logs = {
'loss': averageLoss / iteration,
+          'acc': accuracyFraction[0] / accuracyFraction[1],
'peakMemory': peakMemory
}
if (evalDataset !== undefined) {
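The new `accuracyFraction` bookkeeping keeps a running `[correct, total]` pair across the epoch: `tf.metrics.categoricalAccuracy` produces one 0-or-1 entry per sequence position, so dividing the summed tensor by its element count gives mean token accuracy, which the epoch-end logs now expose as `acc`. A small self-contained illustration of that identity (made-up numbers, not project code):

```ts
import * as tf from "@tensorflow/tfjs";

// One sequence of two positions over a vocabulary of size 2.
const ys = tf.tensor3d([[[0, 1], [1, 0]]]);             // one-hot targets
const logits = tf.tensor3d([[[0.2, 0.8], [0.9, 0.1]]]); // model outputs

// Shape [1, 2]: a 0/1 correctness flag per (sequence, position).
const acc = tf.metrics.categoricalAccuracy(ys, logits);

const correct = acc.sum().dataSync()[0];            // 2
const total = acc.shape.reduce((l, r) => l * r, 1); // 2
console.log(correct / total);                       // 1 — both tokens predicted correctly
```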
3 changes: 2 additions & 1 deletion discojs/src/models/index.ts
@@ -1,4 +1,5 @@
-export { EpochLogs, Model } from './model.js'
+export { Model } from './model.js'
+export { BatchLogs, EpochLogs, ValidationMetrics } from "./logs.js";

export { GPT } from './gpt/index.js'
export { GPTConfig } from './gpt/config.js'
42 changes: 42 additions & 0 deletions discojs/src/models/logs.ts
@@ -0,0 +1,42 @@
import { List } from "immutable";

export interface ValidationMetrics {
accuracy: number;
loss: number;
}

export interface BatchLogs {
accuracy: number;
loss: number;
memoryUsage: number; // GB
}

export class EpochLogs {
public readonly batches: List<BatchLogs>;

constructor(
batches: Iterable<BatchLogs>,
public readonly validation?: ValidationMetrics,
) {
this.batches = List(batches);
}

get training(): Record<"accuracy" | "loss", number> {
const sum = this.batches.reduce(
(acc, batch) => ({
accuracy: acc.accuracy + batch.accuracy,
loss: acc.loss + batch.loss,
}),
{ loss: 0, accuracy: 0 },
);

return {
accuracy: sum.accuracy / this.batches.size,
loss: sum.loss / this.batches.size,
};
}

get peakMemory(): number {
return this.batches.map((batch) => batch.memoryUsage).max() ?? 0;
}
}
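`EpochLogs` stores the raw per-batch logs and derives its aggregates on demand: `training` averages batch loss and accuracy, while `peakMemory` takes the maximum over per-batch memory usage. A quick usage sketch with made-up numbers:

```ts
import { EpochLogs } from "@epfml/discojs";

const epoch = new EpochLogs(
  [
    { accuracy: 0.4, loss: 1.2, memoryUsage: 1.1 },
    { accuracy: 0.6, loss: 0.8, memoryUsage: 1.3 },
  ],
  { accuracy: 0.5, loss: 1.0 }, // optional validation metrics
);

console.log(epoch.training);   // { accuracy: 0.5, loss: 1 }
console.log(epoch.peakMemory); // 1.3
```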
(The remaining 17 of the 28 changed files are not shown here.)
