From 4a8619965d90ee6490b04e3536ac632930b7f901 Mon Sep 17 00:00:00 2001 From: tharvik Date: Tue, 2 Jul 2024 13:37:29 +0200 Subject: [PATCH] discojs/training/disco: add train helpers --- docs/examples/training.ts | 4 +--- docs/examples/wikitext.ts | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/docs/examples/training.ts b/docs/examples/training.ts index d2ab3936d..615039821 100644 --- a/docs/examples/training.ts +++ b/docs/examples/training.ts @@ -14,9 +14,7 @@ async function runUser (url: URL, task: Task, dataset: data.DataSplit): Promise< const disco = new Disco(task, { url, scheme: 'federated' }) // Run training on the dataset - for await (const round of disco.fit(dataset)) - for await (const epoch of round) - for await (const _ of epoch); + await disco.trainFully(dataset); // Disconnect from the remote server await disco.close() diff --git a/docs/examples/wikitext.ts b/docs/examples/wikitext.ts index b2361382b..0f3ae4145 100644 --- a/docs/examples/wikitext.ts +++ b/docs/examples/wikitext.ts @@ -29,9 +29,7 @@ async function main(): Promise { const aggregator = new aggregators.MeanAggregator() const client = new clients.federated.FederatedClient(url, task, aggregator) const disco = new Disco(task, { scheme: 'federated', client, aggregator }) - for await (const round of disco.fit(dataset)) - for await (const epoch of round) - for await (const _ of epoch); + await disco.trainFully(dataset); // Get the model and complete the prompt if (aggregator.model === undefined) {