Skip to content

Commit

Permalink
discojs: simplify types
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Sep 24, 2024
1 parent 95a8210 commit 0ba102d
Show file tree
Hide file tree
Showing 42 changed files with 186 additions and 191 deletions.
4 changes: 2 additions & 2 deletions cli/src/args.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import { parse } from 'ts-command-line-args'
import { Map, Set } from 'immutable'

import type { TaskProvider } from '@epfml/discojs'
import type { DataType, TaskProvider } from "@epfml/discojs";
import { defaultTasks } from '@epfml/discojs'

interface BenchmarkArguments {
provider: TaskProvider
provider: TaskProvider<DataType>
numberOfUsers: number
epochs: number
roundDuration: number
Expand Down
2 changes: 1 addition & 1 deletion cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
}

async function runUser<D extends DataType>(
task: Task,
task: Task<D>,
url: URL,
data: Dataset<Raw[D]>,
): Promise<List<RoundLogs>> {
Expand Down
32 changes: 19 additions & 13 deletions discojs-node/src/model_loader.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import fs from 'node:fs/promises'
import { serialization, models } from '@epfml/discojs'
import fs from "node:fs/promises";

export async function saveModelToDisk(model: models.Model, modelFolder: string, modelFileName: string): Promise<void> {
try {
await fs.access(modelFolder)
} catch {
await fs.mkdir(modelFolder)
}
const encoded = await serialization.model.encode(model)
await fs.writeFile(`${modelFolder}/${modelFileName}`, encoded)
import type { models, DataType } from "@epfml/discojs";
import { serialization } from "@epfml/discojs";

export async function saveModelToDisk(
model: models.Model<DataType>,
modelFolder: string,
modelFileName: string,
): Promise<void> {
const encoded = await serialization.model.encode(model);

await fs.mkdir(modelFolder, { recursive: true });
await fs.writeFile(`${modelFolder}/${modelFileName}`, encoded);
}

export async function loadModelFromDisk(modelPath: string): Promise<models.Model> {
const content = await fs.readFile(modelPath)
return await serialization.model.decode(content) as models.GPT
export async function loadModelFromDisk(
modelPath: string,
): Promise<models.Model<DataType>> {
const content = await fs.readFile(modelPath);

return await serialization.model.decode(content);
}
9 changes: 6 additions & 3 deletions discojs/src/aggregator/get.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import type { Task } from '../index.js'
import type { DataType, Task } from '../index.js'
import { aggregator } from '../index.js'

type AggregatorOptions = Partial<{
scheme: Task['trainingInformation']['scheme'], // if undefined, fallback on task.trainingInformation.scheme
scheme: Task<DataType>["trainingInformation"]["scheme"]; // if undefined, fallback on task.trainingInformation.scheme
roundCutOff: number, // MeanAggregator
threshold: number, // MeanAggregator
thresholdType: 'relative' | 'absolute', // MeanAggregator
Expand All @@ -26,7 +26,10 @@ type AggregatorOptions = Partial<{
* @param options Options passed down to the aggregator's constructor
* @returns The aggregator
*/
export function getAggregator(task: Task, options: AggregatorOptions = {}): aggregator.Aggregator {
export function getAggregator(
task: Task<DataType>,
options: AggregatorOptions = {},
): aggregator.Aggregator {
const aggregatorType = task.trainingInformation.aggregator ?? 'mean'
const scheme = options.scheme ?? task.trainingInformation.scheme

Expand Down
14 changes: 10 additions & 4 deletions discojs/src/client/client.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import axios from 'axios'

import type { Model, Task, WeightsContainer, RoundStatus } from '../index.js'
import type {
DataType,
Model,
RoundStatus,
Task,
WeightsContainer,
} from "../index.js";
import { serialization } from '../index.js'
import type { NodeID } from './types.js'
import type { EventConnection } from './event_connection.js'
Expand All @@ -27,7 +33,7 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{

constructor (
public readonly url: URL, // The network server's URL to connect to
public readonly task: Task, // The client's corresponding task
public readonly task: Task<DataType>, // The client's corresponding task
public readonly aggregator: Aggregator,
) {
super()
Expand All @@ -38,7 +44,7 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{
* This method is overriden by the federated and decentralized clients
* By default, it fetches and returns the server's base model
*/
async connect(): Promise<Model> {
async connect(): Promise<Model<DataType>> {
return this.getLatestModel()
}

Expand All @@ -51,7 +57,7 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{
* Fetches the latest model available on the network's server, for the adequate task.
* @returns The latest model
*/
async getLatestModel (): Promise<Model> {
async getLatestModel (): Promise<Model<DataType>> {
const url = new URL('', this.url.href)
if (!url.pathname.endsWith('/')) {
url.pathname += '/'
Expand Down
4 changes: 2 additions & 2 deletions discojs/src/client/decentralized/decentralized_client.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import createDebug from "debug";
import { Map, Set } from 'immutable'

import type { Model, WeightsContainer } from "../../index.js";
import type { DataType, Model, WeightsContainer } from "../../index.js";
import { serialization } from "../../index.js";
import { Client, type NodeID } from '../index.js'
import { type, type ClientConnected } from '../messages.js'
Expand Down Expand Up @@ -38,7 +38,7 @@ export class DecentralizedClient extends Client {
* create peer-to-peer WebRTC connections with peers. The server is used to exchange
* peers network information.
*/
async connect(): Promise<Model> {
async connect(): Promise<Model<DataType>> {
const model = await super.connect() // Get the server base model
const serverURL = new URL('', this.url.href)
switch (this.url.protocol) {
Expand Down
4 changes: 2 additions & 2 deletions discojs/src/client/federated/federated_client.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import createDebug from "debug";

import { serialization } from "../../index.js";
import type { Model, RoundStatus, WeightsContainer } from "../../index.js";
import type { DataType, Model, RoundStatus, WeightsContainer } from "../../index.js";
import { Client } from "../client.js";
import { type, type ClientConnected } from "../messages.js";
import {
Expand Down Expand Up @@ -75,7 +75,7 @@ export class FederatedClient extends Client {
* as well as the latest training information: latest global model, current round and
* whether we are waiting for more participants.
*/
async connect(): Promise<Model> {
async connect(): Promise<Model<DataType>> {
const model = await super.connect() // Get the server base model

const serverURL = new URL("", this.url.href);
Expand Down
10 changes: 7 additions & 3 deletions discojs/src/client/utils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { Task } from '../index.js'
import type { DataType, Task } from '../index.js'
import { client as clients, type aggregator } from '../index.js'

// Time to wait for the others in milliseconds.
Expand All @@ -10,8 +10,12 @@ export async function timeout (ms = MAX_WAIT_PER_ROUND, errorMsg: string = 'time
})
}

export function getClient(trainingScheme: Required<Task['trainingInformation']['scheme']>,
serverURL: URL, task: Task, aggregator: aggregator.Aggregator): clients.Client {
export function getClient(
trainingScheme: Task<DataType>["trainingInformation"]["scheme"],
serverURL: URL,
task: Task<DataType>,
aggregator: aggregator.Aggregator,
): clients.Client {

switch (trainingScheme) {
case 'decentralized':
Expand Down
4 changes: 1 addition & 3 deletions discojs/src/models/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ import type { BatchLogs, EpochLogs } from "./logs.js";
* Allow for various implementation of models (various train function, tensor-library, ...)
**/
// TODO make it typesafe: same shape of data/input/weights
export abstract class Model<D extends DataType = DataType>
implements Disposable
{
export abstract class Model<D extends DataType> implements Disposable {
// TODO don't allow external access but upgrade train to return weights on every epoch
/** Return training state */
abstract get weights(): WeightsContainer;
Expand Down
6 changes: 3 additions & 3 deletions discojs/src/models/tfjs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ import { BatchLogs } from './index.js'
import { Model } from './index.js'
import { EpochLogs } from './logs.js'

type Serialized<D extends DataType = DataType> = [D, tf.io.ModelArtifacts]
type Serialized<D extends DataType> = [D, tf.io.ModelArtifacts];

/** TensorFlow JavaScript model with standard training */
export class TFJS<D extends DataType = DataType> extends Model<D> {
export class TFJS<D extends DataType> extends Model<D> {
/** Wrap the given trainable model */
constructor (
public readonly datatype: D,
Expand Down Expand Up @@ -168,7 +168,7 @@ export class TFJS<D extends DataType = DataType> extends Model<D> {
return ret
}

static async deserialize<D extends DataType = DataType>([
static async deserialize<D extends DataType>([
datatype,
artifacts,
]: Serialized<D>): Promise<TFJS<D>> {
Expand Down
8 changes: 5 additions & 3 deletions discojs/src/serialization/model.spec.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import { assert, expect } from 'chai'
import * as tf from '@tensorflow/tfjs'

import type { Model } from '../index.js'
import type { DataType, Model } from "../index.js";
import type { GPTConfig } from '../models/index.js'
import { serialization, models } from '../index.js'

async function getRawWeights (model: Model): Promise<Array<[number, Float32Array]>> {
async function getRawWeights(
model: Model<DataType>,
): Promise<Array<[number, Float32Array]>> {
return Array.from(
(await Promise.all(
model.weights.weights.map(async (w) => await w.data<'float32'>()))
Expand Down Expand Up @@ -33,7 +35,7 @@ describe('serialization', () => {
const decoded = await serialization.model.decode(encoded)

expect(decoded).to.be.an.instanceof(models.TFJS);
expect((decoded as models.TFJS).datatype).to.equal("image")
expect((decoded as models.TFJS<DataType>).datatype).to.equal("image")
assert.sameDeepOrderedMembers(
await getRawWeights(model),
await getRawWeights(decoded)
Expand Down
4 changes: 2 additions & 2 deletions discojs/src/serialization/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export function isEncoded (raw: unknown): raw is Encoded {
return raw instanceof Uint8Array
}

export async function encode(model: Model): Promise<Encoded> {
export async function encode(model: Model<DataType>): Promise<Encoded> {
let encoded;
switch (true) {
case model instanceof models.TFJS: {
Expand All @@ -39,7 +39,7 @@ export async function encode(model: Model): Promise<Encoded> {
return new Uint8Array(encoded);
}

export async function decode (encoded: unknown): Promise<Model> {
export async function decode(encoded: unknown): Promise<Model<DataType>> {
if (!isEncoded(encoded)) {
throw new Error("Invalid encoding, raw encoding isn't an instance of Uint8Array")
}
Expand Down
14 changes: 8 additions & 6 deletions discojs/src/task/task.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { isTrainingInformation, type TrainingInformation } from './training_info

export type TaskID = string

export interface Task<D extends DataType = DataType> {
export interface Task<D extends DataType> {
id: TaskID
displayInformation: DisplayInformation
trainingInformation: TrainingInformation<D>
Expand All @@ -14,13 +14,13 @@ export function isTaskID (obj: unknown): obj is TaskID {
return typeof obj === 'string'
}

export function isTask (raw: unknown): raw is Task {
export function isTask (raw: unknown): raw is Task<DataType> {
if (typeof raw !== 'object' || raw === null) {
return false
}

const { id, displayInformation, trainingInformation }:
Partial<Record<keyof Task, unknown>> = raw
Partial<Record<keyof Task<DataType>, unknown>> = raw

if (!isTaskID(id) ||
!isDisplayInformation(displayInformation) ||
Expand All @@ -29,9 +29,11 @@ export function isTask (raw: unknown): raw is Task {
return false
}

const repack = { id, displayInformation, trainingInformation }
const _correct: Task = repack
const _total: Record<keyof Task, unknown> = repack
const _: Task<DataType> = {
id,
displayInformation,
trainingInformation,
} satisfies Record<keyof Task<DataType>, unknown>;

return true
}
12 changes: 7 additions & 5 deletions discojs/src/task/task_handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import axios from 'axios'
import createDebug from "debug";
import { Map } from 'immutable'

import type { Model } from '../index.js'
import type { DataType, Model } from '../index.js'
import { serialization } from '../index.js'

import type { Task, TaskID } from './task.js'
Expand All @@ -12,10 +12,10 @@ const debug = createDebug("discojs:task:handlers");

const TASK_ENDPOINT = 'tasks'

export async function pushTask (
export async function pushTask<D extends DataType>(
url: URL,
task: Task,
model: Model
task: Task<D>,
model: Model<D>,
): Promise<void> {
await axios.post(
url.href + TASK_ENDPOINT,
Expand All @@ -27,7 +27,9 @@ export async function pushTask (
)
}

export async function fetchTasks (url: URL): Promise<Map<TaskID, Task>> {
export async function fetchTasks(
url: URL,
): Promise<Map<TaskID, Task<DataType>>> {
const response = await axios.get(new URL(TASK_ENDPOINT, url).href)
const tasks: unknown = response.data

Expand Down
2 changes: 1 addition & 1 deletion discojs/src/task/task_provider.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { DataType, Model, Task } from "../index.js";

export interface TaskProvider<D extends DataType = DataType> {
export interface TaskProvider<D extends DataType> {
getTask(): Task<D>;
// Create the corresponding model ready for training (compiled)
getModel(): Promise<Model<D>>;
Expand Down
16 changes: 9 additions & 7 deletions discojs/src/task/training_information.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ interface Privacy {
noiseScale?: number;
}

export type TrainingInformation<D extends DataType = DataType> = {
export type TrainingInformation<D extends DataType> = {
// epochs: number of epochs to run training for
epochs: number;
// roundDuration: number of epochs between each weight sharing round.
Expand Down Expand Up @@ -100,7 +100,7 @@ function isPrivacy(raw: unknown): raw is Privacy {

export function isTrainingInformation(
raw: unknown,
): raw is TrainingInformation {
): raw is TrainingInformation<DataType> {
if (typeof raw !== "object" || raw === null) {
return false;
}
Expand All @@ -118,7 +118,7 @@ export function isTrainingInformation(
scheme,
validationSplit,
tensorBackend,
}: Partial<Record<keyof TrainingInformation, unknown>> = raw;
}: Partial<Record<keyof TrainingInformation<DataType>, unknown>> = raw;

if (
typeof epochs !== "number" ||
Expand Down Expand Up @@ -178,7 +178,7 @@ export function isTrainingInformation(
case "image": {
type ImageOnly = Omit<
TrainingInformation<"image">,
keyof TrainingInformation
keyof TrainingInformation<DataType>
>;

const { LABEL_LIST, IMAGE_W, IMAGE_H }: Partial<ImageOnly> = raw;
Expand Down Expand Up @@ -206,7 +206,7 @@ export function isTrainingInformation(
case "tabular": {
type TabularOnly = Omit<
TrainingInformation<"tabular">,
keyof TrainingInformation
keyof TrainingInformation<DataType>
>;

const { inputColumns, outputColumn }: Partial<TabularOnly> = raw;
Expand All @@ -233,8 +233,10 @@ export function isTrainingInformation(
const {
maxSequenceLength,
tokenizer,
}: Partial<Omit<TrainingInformation<"text">, keyof TrainingInformation>> =
raw;
}: Partial<
Omit<TrainingInformation<"text">,
keyof TrainingInformation<DataType>>
> = raw;

if (
(typeof tokenizer !== "string" &&
Expand Down
Loading

0 comments on commit 0ba102d

Please sign in to comment.