Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: framework agnostic preprocessing #781

Merged
merged 31 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit. Hold Shift + click to select a range
a512f5a
server/tests: bump timeout
tharvik Sep 11, 2024
0968620
server/test/status: avoid arbitrary wait
tharvik Oct 22, 2024
c19b42c
discojs/processing: expand
tharvik Aug 23, 2024
8266afc
discojs/dataset: allow subclass
tharvik Sep 4, 2024
33e0c7c
discojs/dataset: add cache
tharvik Sep 4, 2024
76ea663
discojs/dataset: resolve batch in parallel
tharvik Sep 4, 2024
7433b0b
discojs/model: generic on datatype
tharvik Aug 29, 2024
ce187ba
discojs/disco: option to keep preprocessed
tharvik Sep 5, 2024
a52865b
discojs/types: use better names
tharvik Sep 10, 2024
977b66f
discojs/validator: flatten
tharvik Sep 10, 2024
35d09c5
discojs/processing: add validator specifics
tharvik Sep 10, 2024
e09b6ff
discojs/types/text: single output token
tharvik Sep 10, 2024
654c4b6
*: bump deps
tharvik Sep 13, 2024
160f0e1
discojs/model: generic predict
tharvik Sep 11, 2024
2cb4864
discojs/tabular: single output
tharvik Sep 13, 2024
b73d764
discojs: rm old processing
tharvik Sep 13, 2024
f02b7e8
discojs/task: generic on datatype
tharvik Sep 13, 2024
707447f
discojs: simplify types
tharvik Sep 23, 2024
e457cd7
webapp/tsconfig: add cypress
tharvik Sep 30, 2024
32715f2
discojs/dataset/image: document
tharvik Oct 1, 2024
d16c283
*: fix typo "excepted" -> "expected"
tharvik Oct 3, 2024
225bd9a
discojs/dataset: clearer `this` binding
tharvik Oct 3, 2024
e9e46eb
server/tasks: drop remote model loading
tharvik Oct 4, 2024
815614e
discojs/model/tfjs: drop text support
tharvik Oct 4, 2024
2ec74df
discojs/dataset: add unbatch
tharvik Oct 4, 2024
98a64c2
discojs/validator: output Inferred
tharvik Oct 4, 2024
6a86f83
discojs/model/gpt: commonize predict
tharvik Oct 4, 2024
3f53cdb
discojs/disco: document preprocessOnce
tharvik Oct 23, 2024
4f0bc55
discojs/types: add DataFormat namespace
tharvik Oct 23, 2024
d4b1c25
webapp/testing: fix correct color
tharvik Oct 28, 2024
f680523
webapp/training: nicer graphs
tharvik Oct 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cli/src/args.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import { parse } from 'ts-command-line-args'
import { Map, Set } from 'immutable'

import type { TaskProvider } from '@epfml/discojs'
import type { DataType, TaskProvider } from "@epfml/discojs";
import { defaultTasks } from '@epfml/discojs'

interface BenchmarkArguments {
provider: TaskProvider
provider: TaskProvider<DataType>
numberOfUsers: number
epochs: number
roundDuration: number
Expand Down Expand Up @@ -37,7 +37,7 @@ const unsafeArgs = parse<BenchmarkUnsafeArguments>(
)

const supportedTasks = Map(
Set.of(
Set.of<TaskProvider<"image"> | TaskProvider<"tabular">>(
defaultTasks.cifar10,
defaultTasks.lusCovid,
defaultTasks.simpleFace,
Expand Down Expand Up @@ -67,6 +67,6 @@ export const args: BenchmarkArguments = {

return task;
},
getModel: provider.getModel,
getModel: () => provider.getModel(),
tharvik marked this conversation as resolved.
Show resolved Hide resolved
},
};
56 changes: 16 additions & 40 deletions cli/src/benchmark_gpt.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { List } from "immutable";
import { parse } from "ts-command-line-args";
import * as tf from "@tensorflow/tfjs"
import { AutoTokenizer } from "@xenova/transformers";

import { fetchTasks, models, async_iterator, defaultTasks, processing } from "@epfml/discojs";
import { fetchTasks, models, async_iterator, defaultTasks, processing, Task } from "@epfml/discojs";
import { loadModelFromDisk, loadText } from '@epfml/discojs-node'

import { Server } from "server";
Expand Down Expand Up @@ -41,15 +41,6 @@ const args = { ...defaultArgs, ...parsedArgs }
* Benchmark results are reported in https://github.com/epfml/disco/pull/659
*/

function intoTFGenerator<T extends tf.TensorContainer>(
iter: AsyncIterable<T>,
): tf.data.Dataset<T> {
// @ts-expect-error generator
return tf.data.generator(async function* () {
yield* iter;
});
}

async function main(args: Required<CLIArguments>): Promise<void> {
const { inference: benchmarkInference, modelType,
contextLength, batchSize, modelPath } = args
Expand All @@ -59,7 +50,7 @@ async function main(args: Required<CLIArguments>): Promise<void> {

// Fetch the wikitext task from the server
const tasks = await fetchTasks(url)
const task = tasks.get('llm_task')
const task = tasks.get('llm_task') as Task<'text'> | undefined
if (task === undefined) { throw new Error('task not found') }

const tokenizerName = task.trainingInformation.tokenizer
Expand Down Expand Up @@ -89,32 +80,10 @@ async function main(args: Required<CLIArguments>): Promise<void> {
const dataset = loadText('../datasets/wikitext/wiki.train.tokens')

const maxLength = task.trainingInformation.maxSequenceLength ?? (tokenizer.model_max_length as number) + 1
// TODO will be easier when preprocessing is redone
const preprocessedDataset = intoTFGenerator(
dataset
.map((line) =>
processing.tokenizeAndLeftPad(line, tokenizer, maxLength),
)
.batch(batchSize)
.map((batch) =>
tf.tidy(() => ({
xs: tf.tensor2d(
batch.map((tokens) => tokens.slice(0, -1)).toArray(),
),
ys: tf.stack(
batch
.map(
(tokens) =>
tf.oneHot(
tokens.slice(1),
tokenizer.model.vocab.length + 1,
) as tf.Tensor2D,
)
.toArray(),
) as tf.Tensor3D,
})),
),
);
const preprocessedDataset = dataset
.map((line) => processing.tokenizeAndLeftPad(line, tokenizer, maxLength))
.map((tokens) => [tokens.pop(), tokens.last()] as [List<number>, number])
.batch(batchSize);

// Init and train the model
const model = new models.GPT(config)
Expand Down Expand Up @@ -143,13 +112,20 @@ async function main(args: Required<CLIArguments>): Promise<void> {
const iterations = 10
console.log("Generating", nbNewTokens, "new tokens")

let tokens = List(
(tokenizer(prompt, { return_tensor: false }) as { input_ids: number[] })
.input_ids,
);

let inferenceTime = 0
for (let i = 0; i < iterations; i++) {
const timeStart = performance.now()
const _ = await model.generate(prompt, tokenizer, nbNewTokens)
for (let n = 0; n < nbNewTokens; n++) {
const next: number = (await model.predict(List.of(tokens))).first();
tokens = tokens.push(next)
}
inferenceTime += performance.now() - timeStart
}
// Overall average includes tokenization, token sampling and de-tokenization
console.log(`Inference time: ${(inferenceTime/ nbNewTokens / iterations).toFixed(2)} ms/token`)
}
await new Promise((resolve, reject) => {
Expand Down
20 changes: 15 additions & 5 deletions cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@ import "@tensorflow/tfjs-node"
import { List, Range } from 'immutable'
import fs from 'node:fs/promises'

import type { RoundLogs, Task, TaskProvider, TypedLabeledDataset } from '@epfml/discojs'
import type {
Dataset,
DataFormat,
DataType,
RoundLogs,
Task,
TaskProvider,
} from "@epfml/discojs";
import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs'
import { Server } from 'server'

Expand All @@ -18,10 +25,10 @@ async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
return ret;
}

async function runUser(
task: Task,
async function runUser<D extends DataType>(
task: Task<D>,
url: URL,
data: TypedLabeledDataset,
data: Dataset<DataFormat.Raw[D]>,
): Promise<List<RoundLogs>> {
const trainingScheme = task.trainingInformation.scheme
const aggregator = aggregators.getAggregator(task)
Expand All @@ -34,7 +41,10 @@ async function runUser(
return logs;
}

async function main (provider: TaskProvider, numberOfUsers: number): Promise<void> {
async function main<D extends DataType>(
provider: TaskProvider<D>,
numberOfUsers: number,
): Promise<void> {
const task = provider.getTask()
console.log(`Started ${task.trainingInformation.scheme} training of ${task.id}`)
console.log({ args })
Expand Down
36 changes: 20 additions & 16 deletions cli/src/data.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import path from "node:path";

import type { Dataset, Image, Task, TypedLabeledDataset } from "@epfml/discojs";
import type {
Dataset,
DataFormat,
DataType,
Image,
Task,
} from "@epfml/discojs";
import { loadCSV, loadImagesInDir } from "@epfml/discojs-node";
import { Repeat } from "immutable";

async function loadSimpleFaceData(): Promise<Dataset<[Image, string]>> {
async function loadSimpleFaceData(): Promise<Dataset<DataFormat.Raw["image"]>> {
const folder = path.join("..", "datasets", "simple_face");

const [adults, childs]: Dataset<[Image, string]>[] = [
Expand All @@ -15,7 +21,7 @@ async function loadSimpleFaceData(): Promise<Dataset<[Image, string]>> {
return adults.chain(childs);
}

async function loadLusCovidData(): Promise<Dataset<[Image, string]>> {
async function loadLusCovidData(): Promise<Dataset<DataFormat.Raw["image"]>> {
const folder = path.join("..", "datasets", "lus_covid");

const [positive, negative]: Dataset<[Image, string]>[] = [
Expand All @@ -30,24 +36,22 @@ async function loadLusCovidData(): Promise<Dataset<[Image, string]>> {
return positive.chain(negative);
}

export async function getTaskData(task: Task): Promise<TypedLabeledDataset> {
export async function getTaskData<D extends DataType>(
task: Task<D>,
): Promise<Dataset<DataFormat.Raw[D]>> {
switch (task.id) {
case "simple_face":
return ["image", await loadSimpleFaceData()];
return (await loadSimpleFaceData()) as Dataset<DataFormat.Raw[D]>;
case "titanic":
return [
"tabular",
loadCSV(path.join("..", "datasets", "titanic_train.csv")),
];
return loadCSV(
path.join("..", "datasets", "titanic_train.csv"),
) as Dataset<DataFormat.Raw[D]>;
case "cifar10":
return [
"image",
(await loadImagesInDir(path.join("..", "datasets", "CIFAR10"))).zip(
Repeat("cat"),
),
];
return (
await loadImagesInDir(path.join("..", "datasets", "CIFAR10"))
).zip(Repeat("cat")) as Dataset<DataFormat.Raw[D]>;
case "lus_covid":
return ["image", await loadLusCovidData()];
return (await loadLusCovidData()) as Dataset<DataFormat.Raw[D]>;
default:
throw new Error(`Data loader for ${task.id} not implemented.`);
}
Expand Down
2 changes: 1 addition & 1 deletion discojs-node/src/loaders/csv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ export function load(path: string): Dataset<Partial<Record<string, string>>> {

for await (const row of stream) {
if (!isRecordOfString(row))
throw new Error("excepted object of string to string");
throw new Error("expected object of string to string");
yield row;
}
});
Expand Down
32 changes: 19 additions & 13 deletions discojs-node/src/model_loader.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import fs from 'node:fs/promises'
import { serialization, models } from '@epfml/discojs'
import fs from "node:fs/promises";

export async function saveModelToDisk(model: models.Model, modelFolder: string, modelFileName: string): Promise<void> {
try {
await fs.access(modelFolder)
} catch {
await fs.mkdir(modelFolder)
}
const encoded = await serialization.model.encode(model)
await fs.writeFile(`${modelFolder}/${modelFileName}`, encoded)
import type { models, DataType } from "@epfml/discojs";
import { serialization } from "@epfml/discojs";

export async function saveModelToDisk(
model: models.Model<DataType>,
modelFolder: string,
modelFileName: string,
): Promise<void> {
const encoded = await serialization.model.encode(model);

await fs.mkdir(modelFolder, { recursive: true });
await fs.writeFile(`${modelFolder}/${modelFileName}`, encoded);
}

export async function loadModelFromDisk(modelPath: string): Promise<models.Model> {
const content = await fs.readFile(modelPath)
return await serialization.model.decode(content) as models.GPT
export async function loadModelFromDisk(
modelPath: string,
): Promise<models.Model<DataType>> {
const content = await fs.readFile(modelPath);

return await serialization.model.decode(content);
}
2 changes: 1 addition & 1 deletion discojs-web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@
"@types/papaparse": "5",
"jsdom": "25",
"nodemon": "3",
"vitest": "1"
"vitest": "2"
}
}
2 changes: 1 addition & 1 deletion discojs-web/src/loaders/csv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export function load(file: File): Dataset<Partial<Record<string, string>>> {

const rows = results.data.map((row) => {
if (!isRecordOfString(row))
throw new Error("excepted object of string to string");
throw new Error("expected object of string to string");

return row;
});
Expand Down
4 changes: 3 additions & 1 deletion discojs/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
},
"homepage": "https://github.com/epfml/disco#readme",
"dependencies": {
"@jimp/core": "1",
"@jimp/plugin-resize": "1",
"@msgpack/msgpack": "^3.0.0-beta2",
"@tensorflow/tfjs": "4",
"@xenova/transformers": "2",
Expand All @@ -31,7 +33,7 @@
},
"devDependencies": {
"@tensorflow/tfjs-node": "4",
"@types/chai": "4",
"@types/chai": "5",
"@types/mocha": "10",
"@types/simple-peer": "9",
"chai": "5",
Expand Down
9 changes: 6 additions & 3 deletions discojs/src/aggregator/get.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import type { Task } from '../index.js'
import type { DataType, Task } from '../index.js'
import { aggregator } from '../index.js'

type AggregatorOptions = Partial<{
scheme: Task['trainingInformation']['scheme'], // if undefined, fallback on task.trainingInformation.scheme
scheme: Task<DataType>["trainingInformation"]["scheme"]; // if undefined, fallback on task.trainingInformation.scheme
roundCutOff: number, // MeanAggregator
threshold: number, // MeanAggregator
thresholdType: 'relative' | 'absolute', // MeanAggregator
Expand All @@ -26,7 +26,10 @@ type AggregatorOptions = Partial<{
* @param options Options passed down to the aggregator's constructor
* @returns The aggregator
*/
export function getAggregator(task: Task, options: AggregatorOptions = {}): aggregator.Aggregator {
export function getAggregator(
task: Task<DataType>,
options: AggregatorOptions = {},
): aggregator.Aggregator {
const aggregationStrategy = task.trainingInformation.aggregationStrategy ?? 'mean'
const scheme = options.scheme ?? task.trainingInformation.scheme

Expand Down
14 changes: 10 additions & 4 deletions discojs/src/client/client.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import createDebug from "debug";

import type { Model, Task, WeightsContainer, RoundStatus } from '../index.js'
import type {
DataType,
Model,
RoundStatus,
Task,
WeightsContainer,
} from "../index.js";
import { serialization } from '../index.js'
import type { NodeID } from './types.js'
import type { EventConnection } from './event_connection.js'
Expand Down Expand Up @@ -38,7 +44,7 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{

constructor (
public readonly url: URL, // The network server's URL to connect to
public readonly task: Task, // The client's corresponding task
public readonly task: Task<DataType>, // The client's corresponding task
public readonly aggregator: Aggregator,
) {
super()
Expand All @@ -61,7 +67,7 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{
* This method is overriden by the federated and decentralized clients
* By default, it fetches and returns the server's base model
*/
async connect(): Promise<Model> {
async connect(): Promise<Model<DataType>> {
return this.getLatestModel()
}

Expand Down Expand Up @@ -164,7 +170,7 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{
* Fetches the latest model available on the network's server, for the adequate task.
* @returns The latest model
*/
async getLatestModel (): Promise<Model> {
async getLatestModel (): Promise<Model<DataType>> {
const url = new URL('', this.url.href)
if (!url.pathname.endsWith('/')) {
url.pathname += '/'
Expand Down
4 changes: 2 additions & 2 deletions discojs/src/client/decentralized/decentralized_client.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import createDebug from "debug";
import { Map, Set } from 'immutable'

import type { Model, WeightsContainer } from "../../index.js";
import type { DataType, Model, WeightsContainer } from "../../index.js";
import { serialization } from "../../index.js";
import { Client, shortenId } from '../client.js'
import { type NodeID } from '../index.js'
Expand Down Expand Up @@ -44,7 +44,7 @@ export class DecentralizedClient extends Client {
* create peer-to-peer WebRTC connections with peers. The server is used to exchange
* peers network information.
*/
override async connect(): Promise<Model> {
override async connect(): Promise<Model<DataType>> {
const model = await super.connect() // Get the server base model
const serverURL = new URL('', this.url.href)
switch (this.url.protocol) {
Expand Down
Loading