Skip to content

Commit

Permalink
Add check for request structure
Browse files Browse the repository at this point in the history
  • Loading branch information
Joanna Grycz committed Sep 24, 2024
1 parent e5cc8c6 commit 3026609
Showing 1 changed file with 57 additions and 1 deletion.
58 changes: 57 additions & 1 deletion generative-ai/snippets/test/gemma2Prediction.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,47 @@ const tpuResponse =
'The sky appears blue due to a phenomenon called **Rayleigh scattering**.';

describe('Gemma2 predictions', async () => {
const gemma2Endpoint =
'projects/your-project-id/locations/your-vertex-endpoint-region/endpoints/your-vertex-endpoint-id';
const configValues = {
maxOutputTokens: {kind: 'numberValue', numberValue: 1024},
temperature: {kind: 'numberValue', numberValue: 0.9},
topP: {kind: 'numberValue', numberValue: 1},
topK: {kind: 'numberValue', numberValue: 1},
};
const prompt = 'Why is the sky blue?';
const predictionServiceClientMock = {
predict: sinon.stub().resolves([]),
};

afterEach(() => {
sinon.restore();
sinon.reset();
});

it('should run interference with GPU', async () => {
const expectedGpuRequest = {
endpoint: gemma2Endpoint,
instances: [
{
kind: 'structValue',
structValue: {
fields: {
inputs: {
kind: 'stringValue',
stringValue: prompt,
},
parameters: {
kind: 'structValue',
structValue: {
fields: configValues,
},
},
},
},
},
],
};

predictionServiceClientMock.predict.resolves([
{
predictions: [
Expand All @@ -59,9 +91,30 @@ describe('Gemma2 predictions', async () => {
const output = await gemma2PredictGpu(predictionServiceClientMock);

expect(output).include('Rayleigh scattering');
expect(predictionServiceClientMock.predict.calledOnce).to.be.true;
expect(predictionServiceClientMock.predict.calledWith(expectedGpuRequest))
.to.be.true;
});

it('should run interference with TPU', async () => {
const expectedTpuRequest = {
endpoint: gemma2Endpoint,
instances: [
{
kind: 'structValue',
structValue: {
fields: {
...configValues,
prompt: {
kind: 'stringValue',
stringValue: prompt,
},
},
},
},
],
};

predictionServiceClientMock.predict.resolves([
{
predictions: [
Expand All @@ -75,5 +128,8 @@ describe('Gemma2 predictions', async () => {
const output = await gemma2PredictTpu(predictionServiceClientMock);

expect(output).include('Rayleigh scattering');
expect(predictionServiceClientMock.predict.calledOnce).to.be.true;
expect(predictionServiceClientMock.predict.calledWith(expectedTpuRequest))
.to.be.true;
});
});

0 comments on commit 3026609

Please sign in to comment.