Skip to content

Commit

Permalink
discojs/model: expose batch generator
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Jun 21, 2024
1 parent 76befd7 commit c86b033
Show file tree
Hide file tree
Showing 11 changed files with 341 additions and 123 deletions.
6 changes: 4 additions & 2 deletions cli/src/benchmark_gpt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import type { Task } from '@epfml/discojs'
import { fetchTasks, data, models } from '@epfml/discojs'
import { NodeTextLoader, loadModelFromDisk } from '@epfml/discojs-node'
import { startServer } from 'server'
import { get_return_value } from './utils.js';

interface CLIArguments{
modelType?: string; // 'gpt-nano', 'gpt-micro', 'gpt-mini', 'gpt2'
Expand Down Expand Up @@ -80,8 +81,9 @@ async function main(args: Required<CLIArguments>): Promise<void> {
console.log(`\tmodel type ${modelType} \n\tbatch size ${batchSize} \n\tcontext length ${contextLength}`)

let epochTime = performance.now()
const logGenerator = model.train(preprocessedDataset, undefined, epoch)
for await (const logs of logGenerator) {
const rounds = model.train(preprocessedDataset, undefined, epoch)
for (const round of rounds) {
const logs = await get_return_value(round)
epochTime = (performance.now() - epochTime)
const msPerToken = epochTime / (batchSize * contextLength * iterationsPerEpoch * epoch)
console.log(`\t\tTraining time: ${msPerToken.toFixed(2)} ms/token <br> ${logs.peakMemory.toFixed(2)} GB`)
Expand Down
9 changes: 9 additions & 0 deletions cli/src/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/**
 * Drain an async iterator and give back the value it returned.
 *
 * Every yielded value is discarded; only the value carried by the
 * final (`done`) result is kept.
 *
 * @param iter the async iterator to exhaust
 * @returns the return value of the iterator
 */
export async function get_return_value<T>(
  iter: AsyncIterator<unknown, T>,
): Promise<T> {
  let step = await iter.next();
  // skip over everything that is merely yielded
  while (!step.done) step = await iter.next();
  return step.value;
}
2 changes: 1 addition & 1 deletion discojs/src/default_tasks/wikitext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export const wikitext: TaskProvider = {
modelID: 'wikitext-103-raw-model',
preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
scheme: 'federated',
epochs: 5,
epochs: 10,
// Unused by wikitext because data already comes split
// But if set to 0 then the webapp doesn't display the validation metrics
validationSplit: 0.1,
Expand Down
132 changes: 94 additions & 38 deletions discojs/src/models/gpt/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@
* this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
**/

import { List, Map } from 'immutable';
import * as tf from '@tensorflow/tfjs'
import { PreTrainedTokenizer } from '@xenova/transformers';

import { WeightsContainer } from '../../index.js'
import type { Dataset } from '../../dataset/index.js'

import { Model } from '../model.js'
import { BatchLogs, EpochLogs, Model } from "../index.js";
import type { Prediction, Sample } from '../model.js'

import { GPTForCausalLM } from './model.js'
import type { EpochLogs, Prediction, Sample } from '../model.js'
import type { GPTConfig } from './config.js'

export type GPTSerialization = {
Expand All @@ -35,54 +37,108 @@ export class GPT extends Model {
* @param epochs the number of passes of the training dataset
* @param tracker
*/
override async *train(
override *train(
trainingData: Dataset,
validationData?: Dataset,
epochs = 1,
): AsyncGenerator<EpochLogs, void> {
this.model.compile()
): Generator<AsyncGenerator<BatchLogs, EpochLogs>, void> {
this.model.compile();

for (let epoch = 0; epoch < epochs; epoch++)
yield async function* (this: GPT) {
let batchesLogs = List<BatchLogs>();

const batches = await trainingData.iterator(); // tf.LazyIterator isn't an AsyncGenerator
for (let batchNumber = 0; true; batchNumber++) {
const iteration = await batches.next();
if (iteration.done) break;
const batch = iteration.value;

const batchLogs = {
batch: batchNumber,
...(await this.#runBatch(batch)),
};

yield batchLogs;
batchesLogs = batchesLogs.push(batchLogs);
}

const validation =
validationData && (await this.#evaluate(validationData));

return new EpochLogs(epoch, batchesLogs, validation);
}.bind(this)();
}

async #runBatch(
batch: tf.TensorContainer,
): Promise<Omit<BatchLogs, "batch">> {
let logs: tf.Logs | undefined;
const trainingArgs: tf.ModelFitDatasetArgs<tf.TensorContainer> = {
epochs: 1, // force fitDataset to do only one epoch because it is wrapped in a for loop
validationData,
callbacks: { onEpochEnd: (_, cur) => { logs = cur }},
await this.model.fitDataset(tf.data.array([batch]), {
epochs: 1,
callbacks: {
onEpochEnd: (_, cur) => {
logs = cur;
},
},
});
if (logs === undefined) throw new Error("batch didn't gave any logs");

const { loss, acc: accuracy } = logs;
if (loss === undefined || isNaN(loss))
throw new Error("training loss is undefined or NaN");

return {
accuracy,
loss,
memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
};
for (let epoch = 0; epoch < epochs; epoch++) {
await this.model.fitDataset(trainingData, trainingArgs);
if (logs === undefined) {
throw new Error("Epoch didn't gave any logs");
}
const { loss, val_acc, val_loss, peakMemory } = logs;
if (loss === undefined || isNaN(loss)) {
throw new Error("Training loss is undefined or nan");
}
const structuredLogs: EpochLogs = {
epoch,
peakMemory,
training: {
loss: logs.loss,
accuracy: logs.acc
}
}
}

if (validationData !== undefined) {
if(val_loss === undefined || isNaN(val_loss) ||
val_acc === undefined || isNaN(val_acc)) {
throw new Error("Validation accuracy or loss is undefined or nan");
async #evaluate(
dataset: Dataset,
): Promise<Record<"accuracy" | "loss", number>> {
const evaluation = await this.model.evaluateDataset(
dataset.map((t) => {
switch (t) {
case null:
case undefined:
throw new Error("nullish value in dataset");
default:
return t as Exclude<tf.TensorContainer, void>;
}
structuredLogs.validation = { accuracy: logs.val_acc, loss: logs.val_loss}
}
yield structuredLogs
}
}),
);
const metricToValue = Map(
List(this.model.metricsNames).zip(
Array.isArray(evaluation)
? List(await Promise.all(evaluation.map((t) => t.data())))
: List.of(await evaluation.data()),
),
).map((values) => {
if (values.length !== 1) throw new Error("more than one metric value");
return values[0];
});

const [accuracy, loss] = [
metricToValue.get("acc"),
metricToValue.get("loss"),
];
if (accuracy === undefined || loss === undefined)
throw new Error("some needed metrics are missing");

return { accuracy, loss };
}

override predict (input: Sample): Promise<Prediction> {
const ret = this.model.predict(input)
override predict(input: Sample): Promise<Prediction> {
const ret = this.model.predict(input);
if (Array.isArray(ret)) {
throw new Error('prediction yield many Tensors but should have only returned one')
throw new Error(
"prediction yield many Tensors but should have only returned one",
);
}

return Promise.resolve(ret)
return Promise.resolve(ret);
}

async generate(input: string, tokenizer: PreTrainedTokenizer, newTokens: number = 10): Promise<string> {
Expand Down
3 changes: 2 additions & 1 deletion discojs/src/models/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export { EpochLogs, Model } from './model.js'
export { Model } from './model.js'
export { BatchLogs, EpochLogs } from "./logs.js";

export { GPT } from './gpt/index.js'
export { GPTConfig } from './gpt/config.js'
Expand Down
34 changes: 34 additions & 0 deletions discojs/src/models/logs.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import { List } from "immutable";

/** Metrics recorded after training on a single batch. */
export interface BatchLogs {
  /** Index of the batch within its epoch; first batch is zero. */
  batch: number;
  /** Training accuracy measured on this batch. */
  accuracy: number;
  /** Training loss measured on this batch. */
  loss: number;
  /** Memory in use when the batch finished, in GB. */
  memoryUsage: number;
}

/**
 * Logs gathered over one training epoch: the logs of each of its batches
 * plus, optionally, the metrics computed on the validation data.
 */
export class EpochLogs {
  /** Logs of every batch of this epoch, in the order they were given. */
  public readonly batches: List<BatchLogs>;

  /**
   * @param epoch index of the epoch; first epoch is zero
   * @param batches logs of each batch run during this epoch
   * @param validation accuracy and loss on the validation data, if any
   */
  constructor(
    public readonly epoch: number, // first epoch is zero
    batches: Iterable<BatchLogs>,
    public readonly validation?: Record<"accuracy" | "loss", number>,
  ) {
    this.batches = List(batches);
  }

  /**
   * Training metrics of the epoch, as the mean of the per-batch values.
   *
   * Every batch weighs the same, whatever its number of samples.
   *
   * @returns mean accuracy and loss, both zero when there are no batches
   */
  get training(): Record<"accuracy" | "loss", number> {
    const count = this.batches.size;
    // nothing to average: keep zeros rather than dividing by zero
    if (count === 0) return { accuracy: 0, loss: 0 };

    const sum = this.batches.reduce(
      (acc, batch) => ({
        accuracy: acc.accuracy + batch.accuracy,
        loss: acc.loss + batch.loss,
      }),
      { loss: 0, accuracy: 0 },
    );

    // a plain sum would grow with the number of batches (accuracy above 1)
    return { accuracy: sum.accuracy / count, loss: sum.loss / count };
  }

  /**
   * Highest memory usage reported by any batch, in GB.
   *
   * @returns the maximum `memoryUsage`, zero when there are no batches
   */
  get peakMemory(): number {
    return this.batches.map((batch) => batch.memoryUsage).max() ?? 0;
  }
}
18 changes: 3 additions & 15 deletions discojs/src/models/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,7 @@ import type tf from "@tensorflow/tfjs";
import type { WeightsContainer } from "../index.js";
import type { Dataset } from "../dataset/index.js";

export interface EpochLogs {
epoch: number; // first epoch is zero
training: {
loss: number,
accuracy?: number
};
validation?: {
loss: number,
accuracy: number
};
peakMemory: number;
}
import type { BatchLogs, EpochLogs } from "./logs.js";

// TODO still bound to tfjs
export type Prediction = tf.Tensor;
Expand All @@ -26,7 +15,7 @@ export type Sample = tf.Tensor;
* Allow for various implementations of models (various train functions, tensor libraries, ...)
**/
// TODO make it typesafe: same shape of data/input/weights
export abstract class Model implements Disposable{
export abstract class Model implements Disposable {
// TODO don't allow external access but upgrade train to return weights on every epoch
/** Return training state */
abstract get weights(): WeightsContainer;
Expand All @@ -46,13 +35,12 @@ export abstract class Model implements Disposable{
trainingData: Dataset,
validationData?: Dataset,
epochs?: number,
): AsyncGenerator<EpochLogs, void>;
): Generator<AsyncGenerator<BatchLogs, EpochLogs>, void>;

/** Predict likely values */
// TODO extract in separated TrainedModel?
abstract predict(input: Sample): Promise<Prediction>;


/**
* This method is automatically called to cleanup the memory occupied by the model
* when leaving the definition scope if the instance has been defined with the `using` keyword.
Expand Down
Loading

0 comments on commit c86b033

Please sign in to comment.