Skip to content

Commit

Permalink
feat(genai): add Gemini batch prediction samples
Browse files Browse the repository at this point in the history
  • Loading branch information
Sita04 committed Oct 21, 2024
1 parent e587ab9 commit ade584c
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 0 deletions.
96 changes: 96 additions & 0 deletions ai-platform/snippets/batch-prediction/batch-predict-bq.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

'use strict';

async function main(projectId, outputUri) {
// [START generativeaionvertexai_batch_predict_gemini_createjob_bigquery]
// Import the aiplatform library
const aiplatformLib = require('@google-cloud/aiplatform');
const aiplatform = aiplatformLib.protos.google.cloud.aiplatform.v1;

/**
* TODO(developer): Uncomment/update these variables before running the sample.
*/
// projectId = 'YOUR_PROJECT_ID';
// URI of the output BigQuery table.
// E.g. "bq://[PROJECT].[DATASET].[TABLE]"
// outputUri = 'bq://projectid.dataset.table';

// URI of the multimodal input BigQuery table.
// E.g. "bq://[PROJECT].[DATASET].[TABLE]"
const inputUri =
'bq://storage-samples.generative_ai.batch_requests_for_multimodal_input';
const location = 'us-central1';
const parent = `projects/${projectId}/locations/${location}`;
const modelName = `${parent}/publishers/google/models/gemini-1.5-flash-002`;

// Specify the location of the api endpoint.
const clientOptions = {
apiEndpoint: `${location}-aiplatform.googleapis.com`,
};

// Instantiate the client.
const jobServiceClient = new aiplatformLib.JobServiceClient(clientOptions);

// Create a Gemini batch prediction job using BigQuery input and output datasets.
async function create_batch_prediction_gemini_bq() {
const bqSource = new aiplatform.BigQuerySource({
uris: [inputUri],
});

const inputConfig = new aiplatform.BatchPredictionJob.InputConfig({
bqSource,
instancesFormat: 'bigquery',
});

const bqDestination = new aiplatform.BigQueryDestination({
outputUriPrefix: outputUri,
});

const outputConfig = new aiplatform.BatchPredictionJob.OutputConfig({
bqDestination,
predictionsFormat: 'bigquery',
});

const batchPredictionJob = new aiplatform.BatchPredictionJob({
displayName: 'Batch predict with Gemini - BigQuery',
model: modelName, // Add model parameters per request in the input BigQuery table.
inputConfig,
outputConfig,
});

const request = {
parent: parent,
batchPredictionJob,
};

// Create batch prediction job request
const [response] = await jobServiceClient.createBatchPredictionJob(request);
console.log('Response name: ', JSON.stringify(response.name, null, 2));
// Example response:
// Response name: projects/<project>/locations/us-central1/batchPredictionJobs/<job-id>
return response.name;
}

await create_batch_prediction_gemini_bq();
// [END generativeaionvertexai_batch_predict_gemini_createjob_bigquery]
}

main(...process.argv.slice(2)).catch(err => {
console.error(err.message);
process.exitCode = 1;
});
99 changes: 99 additions & 0 deletions ai-platform/snippets/batch-prediction/batch-predict-gcs.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

'use strict';

async function main(projectId, outputUri) {
// [START generativeaionvertexai_batch_predict_gemini_createjob_gcs]
// Import the aiplatform library
const aiplatformLib = require('@google-cloud/aiplatform');
const aiplatform = aiplatformLib.protos.google.cloud.aiplatform.v1;

/**
* TODO(developer): Uncomment/update these variables before running the sample.
*/
// projectId = 'YOUR_PROJECT_ID';
// URI of the output folder in Google Cloud Storage.
// E.g. "gs://[BUCKET]/[OUTPUT]"
// outputUri = 'gs://my-bucket';

// URI of the input file in Google Cloud Storage.
// E.g. "gs://[BUCKET]/[DATASET].jsonl"
// Or try:
// "gs://cloud-samples-data/generative-ai/batch/gemini_multimodal_batch_predict.jsonl"
// for a batch prediction that uses audio, video, and an image.
const inputUri =
'gs://cloud-samples-data/generative-ai/batch/batch_requests_for_multimodal_input.jsonl';
const location = 'us-central1';
const parent = `projects/${projectId}/locations/${location}`;
const modelName = `${parent}/publishers/google/models/gemini-1.5-flash-002`;

// Specify the location of the api endpoint.
const clientOptions = {
apiEndpoint: `${location}-aiplatform.googleapis.com`,
};

// Instantiate the client.
const jobServiceClient = new aiplatformLib.JobServiceClient(clientOptions);

// Create a Gemini batch prediction job using Google Cloud Storage input and output buckets.
async function create_batch_prediction_gemini_gcs() {
const gcsSource = new aiplatform.GcsSource({
uris: [inputUri],
});

const inputConfig = new aiplatform.BatchPredictionJob.InputConfig({
gcsSource,
instancesFormat: 'jsonl',
});

const gcsDestination = new aiplatform.GcsDestination({
outputUriPrefix: outputUri,
});

const outputConfig = new aiplatform.BatchPredictionJob.OutputConfig({
gcsDestination,
predictionsFormat: 'jsonl',
});

const batchPredictionJob = new aiplatform.BatchPredictionJob({
displayName: 'Batch predict with Gemini - GCS',
model: modelName,
inputConfig,
outputConfig,
});

const request = {
parent: parent,
batchPredictionJob,
};

// Create batch prediction job request
const [response] = await jobServiceClient.createBatchPredictionJob(request);
console.log('Response name: ', JSON.stringify(response.name, null, 2));
// Example response:
// Response name: projects/<project>/locations/us-central1/batchPredictionJobs/<job-id>
return response.name;
}

await create_batch_prediction_gemini_gcs();
// [END generativeaionvertexai_batch_predict_gemini_createjob_gcs]
}

main(...process.argv.slice(2)).catch(err => {
console.error(err.message);
process.exitCode = 1;
});
89 changes: 89 additions & 0 deletions ai-platform/snippets/test/batch-prediction-gemini.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

'use strict';

const {assert} = require('chai');
const {after, describe, it} = require('mocha');
const cp = require('child_process');
const {JobServiceClient} = require('@google-cloud/aiplatform');

const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});

describe('Batch predict with Gemini', async () => {
const projectId = process.env.CAIP_PROJECT_ID;
const outputGCSUri = 'gs://ucaip-samples-test-output/';
const outputBqUri = 'bq://ucaip-sample-tests';
const location = 'us-central1';

const jobServiceClient = new JobServiceClient({
apiEndpoint: `${location}-aiplatform.googleapis.com`,
});
let batchPredictionGcsJobId;
let batchPredictionBqJobId;

after(async () => {
let name = jobServiceClient.batchPredictionJobPath(
projectId,
location,
batchPredictionGcsJobId
);
cancelAndDeleteJob(name);

name = jobServiceClient.batchPredictionJobPath(
projectId,
location,
batchPredictionBqJobId
);
cancelAndDeleteJob(name);

function cancelAndDeleteJob(name) {
const cancelRequest = {
name,
};

jobServiceClient.cancelBatchPredictionJob(cancelRequest).then(() => {
const deleteRequest = {
name,
};

return jobServiceClient.deleteBatchPredictionJob(deleteRequest);
});
}
});

it('should create Batch prediction Gemini job with GCS ', async () => {
const response = execSync(
`node ./batch-prediction/batch-predict-gcs.js ${projectId} ${outputGCSUri}`
);

assert.match(response, new RegExp('Batch predict with Gemini - GCS'));
batchPredictionGcsJobId = response
.split('/locations/us-central1/batchPredictionJobs/')[1]
.split('\n')[0];
});

it('should create Batch prediction Gemini job with BigQuery', async () => {
const response = execSync(
`node ./batch-prediction/batch-predict-bq.js ${projectId} ${outputBqUri}`
);

assert.match(response, new RegExp('Batch predict with Gemini - BigQuery'));
batchPredictionBqJobId = response
.split('/locations/us-central1/batchPredictionJobs/')[1]
.split('\n')[0];
});
});

0 comments on commit ade584c

Please sign in to comment.