From aa39c7a9e0874c3e822a981408704049667580d3 Mon Sep 17 00:00:00 2001
From: tharvik
Date: Thu, 7 Mar 2024 14:24:15 +0100
Subject: [PATCH] server/tests: add wikitext

Add an end-to-end federated test that trains the default wikitext task
and lower its epoch count from 10_000 to 10.

Per-test timeouts are set from `function` callbacks: arrow functions
bind `this` lexically and cannot reach the Mocha test context, so
`this.timeout()` called from an arrow callback does not apply to the
running test.
---
Review notes (text in this section is ignored by `git am`):
- Leading indentation of the hunk lines was reconstructed from a
  whitespace-collapsed copy of this patch; verify it against the tree
  (or apply with `git apply --ignore-whitespace`) before merging.
- The blob ids on the `index` lines predate the amendments described in
  the commit message.

 .../src/default_tasks/wikitext.ts  |  2 +-
 server/tests/e2e/federated.spec.ts | 45 ++++++++++++++++++---
 2 files changed, 40 insertions(+), 7 deletions(-)

diff --git a/discojs/discojs-core/src/default_tasks/wikitext.ts b/discojs/discojs-core/src/default_tasks/wikitext.ts
index 0b8800780..7abe83900 100644
--- a/discojs/discojs-core/src/default_tasks/wikitext.ts
+++ b/discojs/discojs-core/src/default_tasks/wikitext.ts
@@ -20,7 +20,7 @@ export const wikitext: TaskProvider = {
         dataType: 'text',
         modelID: 'wikitext-103-raw-model',
         validationSplit: 0.2, // TODO: is this used somewhere? because train, eval and test are already split in dataset
-        epochs: 10_000,
+        epochs: 10,
         // constructing a batch is taken care automatically in the dataset to make things faster
         // so we fake a batch size of 1
         batchSize: 1,
diff --git a/server/tests/e2e/federated.spec.ts b/server/tests/e2e/federated.spec.ts
index a2ae07415..bb6287e2a 100644
--- a/server/tests/e2e/federated.spec.ts
+++ b/server/tests/e2e/federated.spec.ts
@@ -2,21 +2,21 @@ import fs from 'node:fs/promises'
 import path from 'node:path'
 import type { Server } from 'node:http'
 import { Range } from 'immutable'
-import { assert } from 'chai'
+import { assert, expect } from 'chai'
 
 import type { WeightsContainer } from '@epfml/discojs-core'
 import {
-  Disco, TrainingSchemes, client as clients,
+  Disco, TrainingSchemes, client as clients, data,
   aggregator as aggregators, informant, defaultTasks
 } from '@epfml/discojs-core'
-import { NodeImageLoader, NodeTabularLoader } from '@epfml/discojs-node'
+import { NodeImageLoader, NodeTabularLoader, NodeTextLoader } from '@epfml/discojs-node'
 
 import { startServer } from '../../src'
 
 const SCHEME = TrainingSchemes.FEDERATED
 
 describe('end-to-end federated', function () {
-  this.timeout(120_000)
+  this.timeout(100_000)
 
   let server: Server
   let url: URL
@@ -81,13 +81,46 @@ describe('end-to-end federated', function () {
     return aggregator.model.weights
   }
 
+  /**
+   * Trains the default wikitext task as a single federated client and
+   * checks that the last recorded loss is lower than the first one.
+   */
+  async function wikitextUser (): Promise<void> {
+    const task = defaultTasks.wikitext.getTask()
+    const loader = new NodeTextLoader(task)
+    const dataSplit: data.DataSplit = {
+      train: await data.TextData.init(await loader.load('../datasets/wikitext/wiki.train.tokens'), task),
+      validation: await data.TextData.init(await loader.load('../datasets/wikitext/wiki.valid.tokens'), task)
+    }
+
+    const aggregator = new aggregators.MeanAggregator()
+    const client = new clients.federated.FederatedClient(url, task, aggregator)
+    const trainingInformant = new informant.FederatedInformant(task, 10)
+    const disco = new Disco(task, { scheme: SCHEME, client, aggregator, informant: trainingInformant })
+
+    await disco.fit(dataSplit)
+    await disco.close()
+
+    expect(trainingInformant.losses.first()).to.be.above(trainingInformant.losses.last())
+  }
+
-  it('two cifar10 users reach consensus', async () => {
+  it('two cifar10 users reach consensus', async function () {
+    this.timeout(90_000)
+
     const [m1, m2] = await Promise.all([cifar10user(), cifar10user()])
     assert.isTrue(m1.equals(m2))
   })
 
-  it('two titanic users reach consensus', async () => {
+  it('two titanic users reach consensus', async function () {
+    this.timeout(30_000)
+
     const [m1, m2] = await Promise.all([titanicUser(), titanicUser()])
     assert.isTrue(m1.equals(m2))
   })
+
+  it('trains wikitext', async function () {
+    this.timeout(120_000)
+
+    await wikitextUser()
+  })
 })