Skip to content

Commit

Permalink
server/tests: add wikitext
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Mar 11, 2024
1 parent ec5d1df commit aa39c7a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
2 changes: 1 addition & 1 deletion discojs/discojs-core/src/default_tasks/wikitext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export const wikitext: TaskProvider = {
dataType: 'text',
modelID: 'wikitext-103-raw-model',
validationSplit: 0.2, // TODO: is this used somewhere? because train, eval and test are already split in dataset
epochs: 10_000,
epochs: 10,
// constructing a batch is taken care of automatically in the dataset to make things faster
// so we fake a batch size of 1
batchSize: 1,
Expand Down
37 changes: 33 additions & 4 deletions server/tests/e2e/federated.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@ import fs from 'node:fs/promises'
import path from 'node:path'
import type { Server } from 'node:http'
import { Range } from 'immutable'
import { assert } from 'chai'
import { assert, expect } from 'chai'

import type { WeightsContainer } from '@epfml/discojs-core'
import {
Disco, TrainingSchemes, client as clients,
Disco, TrainingSchemes, client as clients, data,
aggregator as aggregators, informant, defaultTasks
} from '@epfml/discojs-core'
import { NodeImageLoader, NodeTabularLoader } from '@epfml/discojs-node'
import { NodeImageLoader, NodeTabularLoader, NodeTextLoader } from '@epfml/discojs-node'

import { startServer } from '../../src'

const SCHEME = TrainingSchemes.FEDERATED

describe('end-to-end federated', function () {
this.timeout(120_000)
this.timeout(100_000)

let server: Server
let url: URL
Expand Down Expand Up @@ -81,13 +81,42 @@ describe('end-to-end federated', function () {
return aggregator.model.weights
}

/**
 * Runs one federated client through training on the default wikitext task
 * and asserts that the first recorded loss is above the last recorded one.
 *
 * The train and validation token files are read from `../datasets/wikitext/`
 * (a relative path; presumably resolved against the working directory the
 * tests are launched from, to be confirmed).
 *
 * @throws if loading the data or training fails, or if the loss did not decrease
 */
async function wikitextUser (): Promise<void> {
  const task = defaultTasks.wikitext.getTask()
  const loader = new NodeTextLoader(task)
  const dataSplit: data.DataSplit = {
    train: await data.TextData.init(await loader.load('../datasets/wikitext/wiki.train.tokens'), task),
    validation: await data.TextData.init(await loader.load('../datasets/wikitext/wiki.valid.tokens'), task)
  }

  const aggregator = new aggregators.MeanAggregator()
  const client = new clients.federated.FederatedClient(url, task, aggregator)
  const trainingInformant = new informant.FederatedInformant(task, 10)
  const disco = new Disco(task, { scheme: SCHEME, client, aggregator, informant: trainingInformant })

  try {
    await disco.fit(dataSplit)
  } finally {
    // Close even when training rejects, so a failed run does not leave
    // the Disco instance open for the rest of the test suite.
    await disco.close()
  }

  // Training is considered successful when the loss went down overall.
  expect(trainingInformant.losses.first()).to.be.above(trainingInformant.losses.last())
}

// Two clients training concurrently on cifar10 must end up with identical weights.
// A regular `function` (not an arrow) is required here: Mocha binds the test
// context to `this`, and an arrow function would instead capture the enclosing
// suite, so `this.timeout` would not apply to this test.
it('two cifar10 users reach consensus', async function () {
  this.timeout(90_000)

  const [m1, m2] = await Promise.all([cifar10user(), cifar10user()])
  assert.isTrue(m1.equals(m2))
})

// Two clients training concurrently on titanic must end up with identical weights.
// A regular `function` (not an arrow) is required here: Mocha binds the test
// context to `this`, and an arrow function would instead capture the enclosing
// suite, so `this.timeout` would not apply to this test.
it('two titanic users reach consensus', async function () {
  this.timeout(30_000)

  const [m1, m2] = await Promise.all([titanicUser(), titanicUser()])
  assert.isTrue(m1.equals(m2))
})

// A single client trains on wikitext; the loss check lives in `wikitextUser`.
// A regular `function` (not an arrow) is required here: Mocha binds the test
// context to `this`, and an arrow function would instead capture the enclosing
// suite, so the per-test timeout below would not take effect.
it('trains wikitext', async function () {
  this.timeout(120_000)

  await wikitextUser()
})
})

0 comments on commit aa39c7a

Please sign in to comment.