Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: prerelease fixes #688

Merged
merged 18 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lint-test-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ jobs:
cache: npm
- run: npm ci
- run: npm --workspace={discojs,discojs-node,server} run build
- run: npm --workspace=cli start -- -t cifar10 -u 3 -e 1
- run: npm --workspace=cli start -- -t cifar10 -u 3 -e 1 -r 1

test-docs-examples:
needs: [build-lib, build-lib-node, build-server, download-datasets]
Expand Down
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ ___
DISCO aims to enable open-access and easy-use distributed training which is
- :tornado: efficient ([R1](https://github.com/epfml/powergossip), [R2](https://github.com/epfml/ChocoSGD))
- :lock: privacy-preserving ([R3](https://eprint.iacr.org/2017/281.pdf), [R4](https://arxiv.org/abs/2006.04747))
- :hammer_and_wrench: fault-tolerant and dynamic over time ([R5](https://arxiv.org/abs/1910.12308))
- :ninja: robust to malicious actors and data poisoning ([R6](https://arxiv.org/abs/2012.10333), [R7](https://arxiv.org/abs/2006.09365))
- :apple: :banana: interpretable in imperfectly interoperable data distributions ([R8](https://arxiv.org/abs/2107.06580))
- :mirror: personalizable ([R9](https://arxiv.org/abs/2103.00710))
- :hammer_and_wrench: fault-tolerant and dynamic over time ([R5](https://arxiv.org/abs/2106.06639), [R6](https://arxiv.org/abs/2206.08307))
- :ninja: robust to malicious actors and data poisoning ([R7](https://arxiv.org/abs/2012.10333), [R8](https://arxiv.org/abs/2006.09365))
- :apple: :banana: interpretable in imperfectly interoperable data distributions ([R9](https://arxiv.org/abs/2107.06580))
- :mirror: personalizable ([R10](https://arxiv.org/abs/2103.00710))
- :carrot: fairly incentivize participation


Expand Down
12 changes: 6 additions & 6 deletions cli/src/benchmark_gpt.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { parse } from 'ts-command-line-args';
import type { Task } from '@epfml/discojs'
import { fetchTasks, data, models } from '@epfml/discojs'
import { fetchTasks, data, models, async_iterator } from '@epfml/discojs'
import { NodeTextLoader, loadModelFromDisk } from '@epfml/discojs-node'
import { startServer } from 'server'

Expand Down Expand Up @@ -49,15 +49,15 @@ async function main(args: Required<CLIArguments>): Promise<void> {

// Fetch the wikitext task from the server
const tasks = await fetchTasks(url)
const task = tasks.get('wikitext-103')
const task = tasks.get('llm_task')
if (task === undefined) { throw new Error('task not found') }

/**
* Training benchmark
*/
if (!benchmarkInference) {
// Benchmark parameters
const epoch = 1
const epochsCount = 1
const iterationsPerEpoch = 10

const config: models.GPTConfig = {
Expand All @@ -80,10 +80,10 @@ async function main(args: Required<CLIArguments>): Promise<void> {
console.log(`\tmodel type ${modelType} \n\tbatch size ${batchSize} \n\tcontext length ${contextLength}`)

let epochTime = performance.now()
const logGenerator = model.train(preprocessedDataset, undefined, epoch)
for await (const logs of logGenerator) {
for (let epochsCounter = 1; epochsCounter <= epochsCount; epochsCounter++) {
const [_, logs] = await async_iterator.gather(model.train(preprocessedDataset))
epochTime = (performance.now() - epochTime)
const msPerToken = epochTime / (batchSize * contextLength * iterationsPerEpoch * epoch)
const msPerToken = epochTime / (batchSize * contextLength * iterationsPerEpoch * epochsCounter)
console.log(`\t\tTraining time: ${msPerToken.toFixed(2)} ms/token <br> ${logs.peakMemory.toFixed(2)} GB`)
}

Expand Down
11 changes: 8 additions & 3 deletions cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ import { startServer } from 'server'
import { getTaskData } from './data.js'
import { args } from './args.js'

// Array.fromAsync not yet widely used (2024)
async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
const ret: T[] = [];
for await (const e of iter) ret.push(e);
return ret;
}

async function runUser(
task: Task,
url: URL,
Expand All @@ -22,9 +29,7 @@ async function runUser(
// force the federated scheme
const disco = new Disco(task, { scheme: "federated", client });

let logs = List<RoundLogs>();
for await (const round of disco.fit(data)) logs = logs.push(round);

const logs = List(await arrayFromAsync(disco.trainByRound(data)));
await disco.close();
return logs;
}
Expand Down
2 changes: 1 addition & 1 deletion discojs/src/default_tasks/wikitext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export const wikitext: TaskProvider = {
modelID: 'llm-raw-model',
preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
scheme: 'federated',
epochs: 5,
epochs: 6,
// Unused by wikitext because data already comes split
// But if set to 0 then the webapp doesn't display the validation metrics
validationSplit: 0.1,
Expand Down
3 changes: 2 additions & 1 deletion discojs/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ export { Memory, type ModelInfo, type Path, type ModelSource, Empty as EmptyMemo
export { Disco, RoundLogs } from './training/index.js'
export { Validator } from './validation/index.js'

export { Model, EpochLogs } from './models/index.js'
export { Model, BatchLogs, EpochLogs, ValidationMetrics } from './models/index.js'
export * as models from './models/index.js'

export * from './task/index.js'
export * as defaultTasks from './default_tasks/index.js'

export * from './types.js'
export * as async_iterator from "./utils/async_iterator.js"
3 changes: 1 addition & 2 deletions discojs/src/models/gpt/evaluate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export default async function evaluate (
model: tf.LayersModel,
dataset: tf.data.Dataset<DataPoint>,
maxEvalBatches: number
): Promise<Record<'acc' | 'val_acc' | 'val_loss' | 'val_perplexity', number>> {
): Promise<Record<'val_acc' | 'val_loss' | 'val_perplexity', number>> {
let datasetSize = 0
let totalLoss = 0
const acc: [number, number] = [0, 0]
Expand Down Expand Up @@ -53,7 +53,6 @@ export default async function evaluate (
return {
val_loss: loss,
val_perplexity: Math.exp(loss),
acc: acc[0] / acc[1],
val_acc: acc[0] / acc[1]
}
}
4 changes: 2 additions & 2 deletions discojs/src/models/gpt/gpt.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ describe('gpt-tfjs', function() {
}).repeat().batch(64)

const model = new GPT(config)
const logGenerator = model.train(tokenDataset, undefined, 5) // 5 epochs
for await (const _ of logGenerator); // Await the end of training
for (let i = 0; i < 5; i++)
for await (const _ of model.train(tokenDataset, undefined));
const generation = await model.generate("Lorem ipsum dolor", tokenizer, 1)
expect(generation).equal(data) // Assert that the model completes 'Lorem ipsum dolor' with 'sit'
})
Expand Down
124 changes: 85 additions & 39 deletions discojs/src/models/gpt/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ import { PreTrainedTokenizer } from '@xenova/transformers';
import { WeightsContainer } from '../../index.js'
import type { Dataset } from '../../dataset/index.js'

import { Model } from '../model.js'
import { BatchLogs, Model, EpochLogs } from "../index.js";
import type { Prediction, Sample } from '../model.js'

import { GPTForCausalLM } from './model.js'
import type { EpochLogs, Prediction, Sample } from '../model.js'
import type { GPTConfig } from './config.js'
import { DEFAULT_CONFIG, type GPTConfig } from './config.js'
import evaluate from './evaluate.js';
import { List } from 'immutable';

export type GPTSerialization = {
weights: WeightsContainer
Expand All @@ -21,9 +24,13 @@ export type GPTSerialization = {
export class GPT extends Model {
private readonly model: GPTForCausalLM

readonly #maxBatchCount: number

constructor (partialConfig?: GPTConfig, layersModel?: tf.LayersModel) {
super()

this.model = new GPTForCausalLM(partialConfig, layersModel)
this.#maxBatchCount = partialConfig?.maxIter ?? DEFAULT_CONFIG.maxIter
}

/**
Expand All @@ -38,51 +45,90 @@ export class GPT extends Model {
override async *train(
trainingData: Dataset,
validationData?: Dataset,
epochs = 1,
): AsyncGenerator<EpochLogs, void> {
this.model.compile()
): AsyncGenerator<BatchLogs, EpochLogs> {
this.model.compile();

const batches = await trainingData.iterator(); // tf.LazyIterator isn't an AsyncGenerator
let batchesLogs = List<BatchLogs>();
for (
let batchNumber = 0;
batchNumber < this.#maxBatchCount;
batchNumber++
) {
const iteration = await batches.next();
if (iteration.done) break;
const batch = iteration.value;

const batchLogs = await this.#runBatch(batch);
tf.dispose(batch);

yield batchLogs;
batchesLogs = batchesLogs.push(batchLogs);
}

const validation = validationData && (await this.#evaluate(validationData));
tharvik marked this conversation as resolved.
Show resolved Hide resolved
return new EpochLogs(batchesLogs, validation);
}

async #runBatch(
batch: tf.TensorContainer,
): Promise<BatchLogs> {
let logs: tf.Logs | undefined;
const trainingArgs: tf.ModelFitDatasetArgs<tf.TensorContainer> = {
epochs: 1, // force fitDataset to do only one epoch because it is wrapped in a for loop
validationData,
callbacks: { onEpochEnd: (_, cur) => { logs = cur }},
await this.model.fitDataset(tf.data.array([batch]), {
epochs: 1,
verbose: 0, // don't pollute
callbacks: {
onEpochEnd: (_, cur) => {
logs = cur;
},
},
});
if (logs === undefined) throw new Error("batch didn't gave any logs");

const { loss, acc: accuracy } = logs;
if (loss === undefined || isNaN(loss))
throw new Error("training loss is undefined or NaN");

return {
accuracy,
loss,
memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
tharvik marked this conversation as resolved.
Show resolved Hide resolved
};
for (let epoch = 0; epoch < epochs; epoch++) {
await this.model.fitDataset(trainingData, trainingArgs);
if (logs === undefined) {
throw new Error("Epoch didn't gave any logs");
}
const { loss, val_acc, val_loss, peakMemory } = logs;
if (loss === undefined || isNaN(loss)) {
throw new Error("Training loss is undefined or nan");
}
const structuredLogs: EpochLogs = {
epoch,
peakMemory,
training: {
loss: logs.loss,
accuracy: logs.acc
}
}
}

if (validationData !== undefined) {
if(val_loss === undefined || isNaN(val_loss) ||
val_acc === undefined || isNaN(val_acc)) {
throw new Error("Validation accuracy or loss is undefined or nan");
async #evaluate(
dataset: Dataset,
): Promise<Record<"accuracy" | "loss", number>> {
const evaluation = await evaluate(
this.model,
dataset.map((t) => {
switch (t) {
case null:
case undefined:
throw new Error("nullish value in dataset");
default:
// TODO unsafe cast
return t as { xs: tf.Tensor2D; ys: tf.Tensor3D };
}
structuredLogs.validation = { accuracy: logs.val_acc, loss: logs.val_loss}
}
yield structuredLogs
}
}),
this.config.maxEvalBatches,
);

return {
accuracy: evaluation.val_acc,
loss: evaluation.val_loss,
};
}

override predict (input: Sample): Promise<Prediction> {
const ret = this.model.predict(input)
override predict(input: Sample): Promise<Prediction> {
const ret = this.model.predict(input);
if (Array.isArray(ret)) {
throw new Error('prediction yield many Tensors but should have only returned one')
throw new Error(
"prediction yield many Tensors but should have only returned one",
);
}

return Promise.resolve(ret)
return Promise.resolve(ret);
}

async generate(input: string, tokenizer: PreTrainedTokenizer, newTokens: number = 10): Promise<string> {
Expand Down
46 changes: 31 additions & 15 deletions discojs/src/models/gpt/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,30 +57,48 @@ class GPTModel extends tf.LayersModel {
await callbacks.onTrainBegin?.()

for (let epoch = 1; epoch <= trainingArgs.epochs; epoch++) {
tharvik marked this conversation as resolved.
Show resolved Hide resolved
let accuracyFraction: [number, number] = [0, 0];
let averageLoss = 0
let peakMemory = 0
let iteration = 1
const iterator = await dataset.iterator()
let preprocessingTime = performance.now()
let next = await iterator.next()
preprocessingTime = performance.now() - preprocessingTime

while (next.done !== true && iteration <= this.config.maxIter) {
let weightUpdateTime = performance.now()
await callbacks.onEpochBegin?.(epoch)
const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D }
const lossFn: () => tf.Scalar = () => {

let preprocessingTime = performance.now()
await Promise.all([xs.data(), ys.data()])
preprocessingTime = performance.now() - preprocessingTime

// TODO include as a tensor inside the model
const accTensor = tf.tidy(() => {
const logits = this.apply(xs)
if (Array.isArray(logits)) {
if (Array.isArray(logits))
throw new Error('model outputs too many tensor')
}
if (logits instanceof tf.SymbolicTensor) {
if (logits instanceof tf.SymbolicTensor)
throw new Error('model outputs symbolic tensor')
}
return tf.losses.softmaxCrossEntropy(ys, logits)
}
return tf.metrics.categoricalAccuracy(ys, logits)
})
const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
const accSumTensor = accTensor.sum()
const accSum = await accSumTensor.array()
tf.dispose(accSumTensor)
if (typeof accSum !== 'number')
throw new Error('got multiple accuracy sum')
accuracyFraction = [accuracyFraction[0] + accSum, accuracyFraction[1] + accSize];
tf.dispose([accTensor])

const lossTensor = tf.tidy(() => {
const { grads, value: lossTensor } = this.optimizer.computeGradients(lossFn)
const { grads, value: lossTensor } = this.optimizer.computeGradients(() => {
const logits = this.apply(xs)
tharvik marked this conversation as resolved.
Show resolved Hide resolved
if (Array.isArray(logits))
throw new Error('model outputs too many tensor')
if (logits instanceof tf.SymbolicTensor)
throw new Error('model outputs symbolic tensor')
return tf.losses.softmaxCrossEntropy(ys, logits)
})
const gradsClipped = clipByGlobalNormObj(grads, 1)
this.optimizer.applyGradients(gradsClipped)
return lossTensor
Expand All @@ -89,6 +107,7 @@ class GPTModel extends tf.LayersModel {
const loss = await lossTensor.array()
averageLoss += loss
weightUpdateTime = performance.now() - weightUpdateTime

tf.dispose([xs, ys, lossTensor])

if (
Expand All @@ -100,9 +119,6 @@ class GPTModel extends tf.LayersModel {
console.log(iterationLogs)
}
const memory = tf.memory().numBytes / 1024 / 1024 / 1024
if (memory > peakMemory) {
peakMemory = memory
}
console.log(
`Epoch: ${epoch}`,
`\tStep: ${iteration} / ${this.config.maxIter}`,
Expand All @@ -122,7 +138,7 @@ class GPTModel extends tf.LayersModel {
}
let logs: tf.Logs = {
'loss': averageLoss / iteration,
'peakMemory': peakMemory
'acc': accuracyFraction[0] / accuracyFraction[1],
}
if (evalDataset !== undefined) {
logs = { ...logs, ...await evaluate(this, evalDataset, this.config.maxEvalBatches) }
Expand Down
3 changes: 2 additions & 1 deletion discojs/src/models/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export { EpochLogs, Model } from './model.js'
export { Model } from './model.js'
export { BatchLogs, EpochLogs, ValidationMetrics } from "./logs.js";

export { GPT } from './gpt/index.js'
export { GPTConfig } from './gpt/config.js'
Expand Down
Loading