From 2c52cb535d138f37a0542140cc5c8dc3d1630ef2 Mon Sep 17 00:00:00 2001
From: Joanna Grycz
Date: Wed, 25 Sep 2024 18:47:22 +0200
Subject: [PATCH] feat: batch prediction samples

---
 ai-platform/snippets/batch-code-predict.js   | 123 +++++++++++++++++
 ai-platform/snippets/batch-text-predict.js   | 125 ++++++++++++++++++
 .../snippets/test/batch-code-predict.test.js |  80 +++++++++++
 .../snippets/test/batch-text-predict.test.js |  80 +++++++++++
 4 files changed, 408 insertions(+)
 create mode 100644 ai-platform/snippets/batch-code-predict.js
 create mode 100644 ai-platform/snippets/batch-text-predict.js
 create mode 100644 ai-platform/snippets/test/batch-code-predict.test.js
 create mode 100644 ai-platform/snippets/test/batch-text-predict.test.js

diff --git a/ai-platform/snippets/batch-code-predict.js b/ai-platform/snippets/batch-code-predict.js
new file mode 100644
index 00000000000..1c25ac84d3d
--- /dev/null
+++ b/ai-platform/snippets/batch-code-predict.js
@@ -0,0 +1,123 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+'use strict';
+
+/**
+ * Creates a Vertex AI batch prediction job for the `code-bison` model, reading
+ * prompts from Cloud Storage and writing predictions back to Cloud Storage,
+ * then logs the created job as JSON.
+ *
+ * @param {string} projectId Google Cloud project ID.
+ * @param {string} inputUri Cloud Storage URI of the JSONL input file.
+ * @param {string} outputUri Cloud Storage URI prefix for the output.
+ * @param {string} jobDisplayName Display name of the batch prediction job.
+ */
+async function main(projectId, inputUri, outputUri, jobDisplayName) {
+  // [START generativeaionvertexai_batch_code_predict]
+  // Imports the aiplatform library
+  const aiplatformLib = require('@google-cloud/aiplatform');
+  const aiplatform = aiplatformLib.protos.google.cloud.aiplatform.v1;
+
+  /**
+   * TODO(developer): Uncomment/update these variables before running the sample.
+   */
+  // projectId = 'YOUR_PROJECT_ID';
+
+  // URI of the input dataset: a JSONL file in Google Cloud Storage.
+  // E.g. "gs://[BUCKET]/[DATASET].jsonl"
+  // (The sample builds a GcsSource, so "bq://" URIs do not apply.)
+  // inputUri =
+  //   'gs://cloud-samples-data/batch/prompt_for_batch_code_predict.jsonl';
+
+  // URI prefix in Google Cloud Storage where the output will be stored.
+  // E.g. "gs://[BUCKET]/[OUTPUT_PREFIX]"
+  // (The sample builds a GcsDestination, so "bq://" URIs do not apply.)
+  // outputUri = 'gs://batch-bucket-testing/batch_code_predict_output';
+
+  // The name of batch prediction job
+  // jobDisplayName = `Batch code prediction job: ${Date.now()}`;
+
+  // The name of pre-trained model
+  const codeModel = 'code-bison';
+  const location = 'us-central1';
+
+  // Construct your modelParameters
+  const parameters = {
+    maxOutputTokens: '200',
+    temperature: '0.2',
+  };
+  const parametersValue = aiplatformLib.helpers.toValue(parameters);
+  // Configure the parent resource
+  const parent = `projects/${projectId}/locations/${location}`;
+  const modelName = `projects/${projectId}/locations/${location}/publishers/google/models/${codeModel}`;
+
+  // Specifies the location of the api endpoint
+  const clientOptions = {
+    apiEndpoint: `${location}-aiplatform.googleapis.com`,
+  };
+
+  // Instantiates a client
+  const jobServiceClient = new aiplatformLib.JobServiceClient(clientOptions);
+
+  // Perform batch code prediction using a pre-trained code generation model.
+  // Example of using Google Cloud Storage bucket as the input and output data source
+  async function callBatchCodePrediction() {
+    const gcsSource = new aiplatform.GcsSource({
+      uris: [inputUri],
+    });
+
+    const inputConfig = new aiplatform.BatchPredictionJob.InputConfig({
+      gcsSource,
+      instancesFormat: 'jsonl',
+    });
+
+    const gcsDestination = new aiplatform.GcsDestination({
+      outputUriPrefix: outputUri,
+    });
+
+    const outputConfig = new aiplatform.BatchPredictionJob.OutputConfig({
+      gcsDestination,
+      predictionsFormat: 'jsonl',
+    });
+
+    const batchPredictionJob = new aiplatform.BatchPredictionJob({
+      displayName: jobDisplayName,
+      model: modelName,
+      inputConfig,
+      outputConfig,
+      modelParameters: parametersValue,
+    });
+
+    const request = {
+      parent,
+      batchPredictionJob,
+    };
+
+    // Create batch prediction job request
+    const [response] = await jobServiceClient.createBatchPredictionJob(request);
+
+    console.log('Raw response: ', JSON.stringify(response, null, 2));
+  }
+
+  await callBatchCodePrediction();
+  // [END generativeaionvertexai_batch_code_predict]
+}
+
+main(...process.argv.slice(2)).catch(err => {
+  console.error(err.message);
+  process.exitCode = 1;
+});
diff --git a/ai-platform/snippets/batch-text-predict.js b/ai-platform/snippets/batch-text-predict.js
new file mode 100644
index 00000000000..a2e59eca2f5
--- /dev/null
+++ b/ai-platform/snippets/batch-text-predict.js
@@ -0,0 +1,125 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+'use strict';
+
+/**
+ * Creates a Vertex AI batch prediction job for the `text-bison` model, reading
+ * prompts from Cloud Storage and writing predictions back to Cloud Storage,
+ * then logs the created job as JSON.
+ *
+ * @param {string} projectId Google Cloud project ID.
+ * @param {string} inputUri Cloud Storage URI of the JSONL input file.
+ * @param {string} outputUri Cloud Storage URI prefix for the output.
+ * @param {string} jobDisplayName Display name of the batch prediction job.
+ */
+async function main(projectId, inputUri, outputUri, jobDisplayName) {
+  // [START generativeaionvertexai_batch_text_predict]
+  // Imports the aiplatform library
+  const aiplatformLib = require('@google-cloud/aiplatform');
+  const aiplatform = aiplatformLib.protos.google.cloud.aiplatform.v1;
+
+  /**
+   * TODO(developer): Uncomment/update these variables before running the sample.
+   */
+  // projectId = 'YOUR_PROJECT_ID';
+
+  // URI of the input dataset: a JSONL file in Google Cloud Storage.
+  // E.g. "gs://[BUCKET]/[DATASET].jsonl"
+  // (The sample builds a GcsSource, so "bq://" URIs do not apply.)
+  // inputUri =
+  //   'gs://cloud-samples-data/batch/prompt_for_batch_text_predict.jsonl';
+
+  // URI prefix in Google Cloud Storage where the output will be stored.
+  // E.g. "gs://[BUCKET]/[OUTPUT_PREFIX]"
+  // (The sample builds a GcsDestination, so "bq://" URIs do not apply.)
+  // outputUri = 'gs://batch-bucket-testing/batch_text_predict_output';
+
+  // The name of batch prediction job
+  // jobDisplayName = `Batch text prediction job: ${Date.now()}`;
+
+  // The name of pre-trained model
+  const textModel = 'text-bison';
+  const location = 'us-central1';
+
+  // Construct your modelParameters
+  const parameters = {
+    maxOutputTokens: '200',
+    temperature: '0.2',
+    topP: '0.95',
+    topK: '40',
+  };
+  const parametersValue = aiplatformLib.helpers.toValue(parameters);
+  // Configure the parent resource
+  const parent = `projects/${projectId}/locations/${location}`;
+  const modelName = `projects/${projectId}/locations/${location}/publishers/google/models/${textModel}`;
+
+  // Specifies the location of the api endpoint
+  const clientOptions = {
+    apiEndpoint: `${location}-aiplatform.googleapis.com`,
+  };
+
+  // Instantiates a client
+  const jobServiceClient = new aiplatformLib.JobServiceClient(clientOptions);
+
+  // Perform batch text prediction using a pre-trained text generation model.
+  // Example of using Google Cloud Storage bucket as the input and output data source
+  async function callBatchTextPrediction() {
+    const gcsSource = new aiplatform.GcsSource({
+      uris: [inputUri],
+    });
+
+    const inputConfig = new aiplatform.BatchPredictionJob.InputConfig({
+      gcsSource,
+      instancesFormat: 'jsonl',
+    });
+
+    const gcsDestination = new aiplatform.GcsDestination({
+      outputUriPrefix: outputUri,
+    });
+
+    const outputConfig = new aiplatform.BatchPredictionJob.OutputConfig({
+      gcsDestination,
+      predictionsFormat: 'jsonl',
+    });
+
+    const batchPredictionJob = new aiplatform.BatchPredictionJob({
+      displayName: jobDisplayName,
+      model: modelName,
+      inputConfig,
+      outputConfig,
+      modelParameters: parametersValue,
+    });
+
+    const request = {
+      parent,
+      batchPredictionJob,
+    };
+
+    // Create batch prediction job request
+    const [response] = await jobServiceClient.createBatchPredictionJob(request);
+
+    console.log('Raw response: ', JSON.stringify(response, null, 2));
+  }
+
+  await callBatchTextPrediction();
+  // [END generativeaionvertexai_batch_text_predict]
+}
+
+main(...process.argv.slice(2)).catch(err => {
+  console.error(err.message);
+  process.exitCode = 1;
+});
diff --git a/ai-platform/snippets/test/batch-code-predict.test.js b/ai-platform/snippets/test/batch-code-predict.test.js
new file mode 100644
index 00000000000..3cc27712bec
--- /dev/null
+++ b/ai-platform/snippets/test/batch-code-predict.test.js
@@ -0,0 +1,80 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+'use strict';
+
+const {assert} = require('chai');
+const {after, describe, it} = require('mocha');
+const uuid = require('uuid').v4;
+const cp = require('child_process');
+const {JobServiceClient} = require('@google-cloud/aiplatform');
+
+const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});
+
+// Mocha ignores promises returned from `describe`, so the suite callback is
+// deliberately synchronous.
+describe('Batch code predict', () => {
+  const displayName = `batch-code-predict-job-${uuid()}`;
+  const location = 'us-central1';
+  const inputUri =
+    'gs://ucaip-samples-test-output/inputs/batch_predict_TCN/tcn_inputs.jsonl';
+  const outputUri = 'gs://ucaip-samples-test-output/';
+  const jobServiceClient = new JobServiceClient({
+    apiEndpoint: `${location}-aiplatform.googleapis.com`,
+  });
+  const projectId = process.env.CAIP_PROJECT_ID;
+  let batchPredictionJobId;
+
+  // Best-effort cleanup: cancel, then delete, the job created by the test.
+  after(async () => {
+    // Nothing to clean up if the sample failed before a job ID was captured.
+    if (!batchPredictionJobId) {
+      return;
+    }
+
+    const name = jobServiceClient.batchPredictionJobPath(
+      projectId,
+      location,
+      batchPredictionJobId
+    );
+
+    try {
+      // Awaited so the requests settle before the hook returns.
+      await jobServiceClient.cancelBatchPredictionJob({name});
+      await jobServiceClient.deleteBatchPredictionJob({name});
+    } catch (err) {
+      // Cleanup must not fail the suite, but should not be silent either.
+      console.warn(`Failed to clean up ${name}: ${err.message}`);
+    }
+  });
+
+  it('should create job with code prediction', () => {
+    const response = execSync(
+      `node ./batch-code-predict.js ${projectId} ${inputUri} ${outputUri} ${displayName}`
+    );
+
+    // Capture the job ID first so `after` can clean up even if the assertion
+    // below fails. The sample prints pretty-printed JSON, so the ID is
+    // followed by `",` and must be matched rather than split on newlines.
+    const jobIdMatch = response.match(
+      /\/locations\/us-central1\/batchPredictionJobs\/([^"/\s]+)/
+    );
+    assert.isNotNull(jobIdMatch, 'batch prediction job name not found');
+    batchPredictionJobId = jobIdMatch[1];
+
+    assert.match(response, new RegExp(displayName));
+  });
+});
diff --git a/ai-platform/snippets/test/batch-text-predict.test.js b/ai-platform/snippets/test/batch-text-predict.test.js
new file mode 100644
index 00000000000..005bea4a889
--- /dev/null
+++ b/ai-platform/snippets/test/batch-text-predict.test.js
@@ -0,0 +1,80 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+'use strict';
+
+const {assert} = require('chai');
+const {after, describe, it} = require('mocha');
+const uuid = require('uuid').v4;
+const cp = require('child_process');
+const {JobServiceClient} = require('@google-cloud/aiplatform');
+
+const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});
+
+// Mocha ignores promises returned from `describe`, so the suite callback is
+// deliberately synchronous.
+describe('Batch text predict', () => {
+  const displayName = `batch-text-predict-job-${uuid()}`;
+  const location = 'us-central1';
+  const inputUri =
+    'gs://ucaip-samples-test-output/inputs/batch_predict_TCN/tcn_inputs.jsonl';
+  const outputUri = 'gs://ucaip-samples-test-output/';
+  const jobServiceClient = new JobServiceClient({
+    apiEndpoint: `${location}-aiplatform.googleapis.com`,
+  });
+  const projectId = process.env.CAIP_PROJECT_ID;
+  let batchPredictionJobId;
+
+  // Best-effort cleanup: cancel, then delete, the job created by the test.
+  after(async () => {
+    // Nothing to clean up if the sample failed before a job ID was captured.
+    if (!batchPredictionJobId) {
+      return;
+    }
+
+    const name = jobServiceClient.batchPredictionJobPath(
+      projectId,
+      location,
+      batchPredictionJobId
+    );
+
+    try {
+      // Awaited so the requests settle before the hook returns.
+      await jobServiceClient.cancelBatchPredictionJob({name});
+      await jobServiceClient.deleteBatchPredictionJob({name});
+    } catch (err) {
+      // Cleanup must not fail the suite, but should not be silent either.
+      console.warn(`Failed to clean up ${name}: ${err.message}`);
+    }
+  });
+
+  it('should create job with text prediction', () => {
+    const response = execSync(
+      `node ./batch-text-predict.js ${projectId} ${inputUri} ${outputUri} ${displayName}`
+    );
+
+    // Capture the job ID first so `after` can clean up even if the assertion
+    // below fails. The sample prints pretty-printed JSON, so the ID is
+    // followed by `",` and must be matched rather than split on newlines.
+    const jobIdMatch = response.match(
+      /\/locations\/us-central1\/batchPredictionJobs\/([^"/\s]+)/
+    );
+    assert.isNotNull(jobIdMatch, 'batch prediction job name not found');
+    batchPredictionJobId = jobIdMatch[1];
+
+    assert.match(response, new RegExp(displayName));
+  });
+});