Skip to content

Commit

Permalink
server/tests: add wikitext
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Mar 7, 2024
1 parent 239f911 commit 79c9878
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion server/tests/e2e/federated.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
Disco, TrainingSchemes, client as clients,
aggregator as aggregators, informant, defaultTasks
} from '@epfml/discojs-core'
import { NodeImageLoader, NodeTabularLoader } from '@epfml/discojs-node'
import { NodeImageLoader, NodeTabularLoader, NodeTextLoader } from '@epfml/discojs-node'

import { startServer } from '../../src'

Expand Down Expand Up @@ -81,6 +81,35 @@ describe('end-to-end federated', function () {
return aggregator.model.weights
}

async function wikitextUser (): Promise<WeightsContainer> {
const task = defaultTasks.wikitext.getTask()
console.log('>>', { task })
const data = await (new NodeTextLoader(task).loadAll(['../datasets/wikitext/wiki.train.tokens']))
console.log('~~', { data })

const aggregator = new aggregators.MeanAggregator()
const client = new clients.federated.FederatedClient(url, task, aggregator)
const trainingInformant = new informant.FederatedInformant(task, 10)
const disco = new Disco(task, { scheme: SCHEME, client, aggregator, informant: trainingInformant })

await disco.fit(data)
await disco.close()

assert(
trainingInformant.trainingAccuracy() > 0.6,
`expected training accuracy greater than 0.6 but got ${trainingInformant.trainingAccuracy()}`
)
assert(
trainingInformant.validationAccuracy() > 0.6,
`expected validation accuracy greater than 0.6 but got ${trainingInformant.validationAccuracy()}`
)

if (aggregator.model === undefined) {
throw new Error('model was not set')
}
return aggregator.model.weights
}

it('two cifar10 users reach consensus', async () => {
const [m1, m2] = await Promise.all([cifar10user(), cifar10user()])
assert.isTrue(m1.equals(m2))
Expand All @@ -90,4 +119,9 @@ describe('end-to-end federated', function () {
const [m1, m2] = await Promise.all([titanicUser(), titanicUser()])
assert.isTrue(m1.equals(m2))
})

it('two wikitext users reach consensus', async () => {
const [m1, m2] = await Promise.all([wikitextUser(), wikitextUser()])
assert.isTrue(m1.equals(m2))
})
})

0 comments on commit 79c9878

Please sign in to comment.