diff --git a/ai-platform/snippets/create-batch-embedding.js b/ai-platform/snippets/create-batch-embedding.js new file mode 100644 index 00000000000..258e7f7f0f8 --- /dev/null +++ b/ai-platform/snippets/create-batch-embedding.js @@ -0,0 +1,105 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +'use strict'; + +async function main(projectId, inputUri, outputUri, jobName) { + // [START generativeaionvertexai_embedding_batch] + // Imports the aiplatform library + const aiplatformLib = require('@google-cloud/aiplatform'); + const aiplatform = aiplatformLib.protos.google.cloud.aiplatform.v1; + + /** + * TODO(developer): Uncomment/update these variables before running the sample. + */ + // projectId = 'YOUR_PROJECT_ID'; + + // Optional: URI of the input dataset. + // Could be a BigQuery table or a Google Cloud Storage file. + // E.g. "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]" + // inputUri = + // 'gs://cloud-samples-data/generative-ai/embeddings/embeddings_input.jsonl'; + + // Optional: URI where the output will be stored. + // Could be a BigQuery table or a Google Cloud Storage file. + // E.g. "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]" + // outputUri = 'gs://your_backet/embedding_batch_output'; + + // The name of the job + // jobName = `Batch embedding job: ${new Date().getMilliseconds()}`; + + const textEmbeddingModel = 'text-embedding-005'; + const location = 'us-central1'; + + // Configure the parent resource + const parent = `projects/${projectId}/locations/${location}`; + const modelName = `projects/${projectId}/locations/${location}/publishers/google/models/${textEmbeddingModel}`; + + // Specifies the location of the api endpoint + const clientOptions = { + apiEndpoint: `${location}-aiplatform.googleapis.com`, + }; + + // Instantiates a client + const jobServiceClient = new aiplatformLib.JobServiceClient(clientOptions); + + // Generates embeddings from text using batch processing. + // Read more: https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/batch-prediction-genai-embeddings + async function callBatchEmbedding() { + const gcsSource = new aiplatform.GcsSource({ + uris: [inputUri], + }); + + const inputConfig = new aiplatform.BatchPredictionJob.InputConfig({ + gcsSource, + instancesFormat: 'jsonl', + }); + + const gcsDestination = new aiplatform.GcsDestination({ + outputUriPrefix: outputUri, + }); + + const outputConfig = new aiplatform.BatchPredictionJob.OutputConfig({ + gcsDestination, + predictionsFormat: 'jsonl', + }); + + const batchPredictionJob = new aiplatform.BatchPredictionJob({ + displayName: jobName, + model: modelName, + inputConfig, + outputConfig, + }); + + const request = { + parent, + batchPredictionJob, + }; + + // Create batch prediction job request + const [response] = await jobServiceClient.createBatchPredictionJob(request); + + console.log('Raw response: ', JSON.stringify(response, null, 2)); + } + + await callBatchEmbedding(); + // [END generativeaionvertexai_embedding_batch] +} + +main(...process.argv.slice(2)).catch(err => { + console.error(err.message); + process.exitCode = 1; +}); diff --git a/ai-platform/snippets/test/create-batch-embedding.test.js b/ai-platform/snippets/test/create-batch-embedding.test.js new file mode 100644 index 00000000000..dd482b4e091 --- /dev/null +++ b/ai-platform/snippets/test/create-batch-embedding.test.js @@ -0,0 +1,84 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +'use strict'; + +const {assert} = require('chai'); +const {after, before, describe, it} = require('mocha'); +const uuid = require('uuid').v4; +const cp = require('child_process'); +const {JobServiceClient} = require('@google-cloud/aiplatform'); +const {Storage} = require('@google-cloud/storage'); + +const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'}); + +describe('Batch embedding', async () => { + const displayName = `batch-embedding-job-${uuid()}`; + const location = 'us-central1'; + const inputUri = + 'gs://cloud-samples-data/generative-ai/embeddings/embeddings_input.jsonl'; + let outputUri = 'gs://ucaip-samples-test-output/'; + const jobServiceClient = new JobServiceClient({ + apiEndpoint: `${location}-aiplatform.googleapis.com`, + }); + const projectId = process.env.CAIP_PROJECT_ID; + const storage = new Storage({ + projectId, + }); + let batchPredictionJobId; + let bucket; + + before(async () => { + const bucketName = `test-bucket-${uuid()}`; + // Create a Google Cloud Storage bucket for UsageReports + [bucket] = await storage.createBucket(bucketName); + outputUri = `gs://${bucketName}/embedding_batch_output`; + }); + + after(async () => { + // Delete job + const name = jobServiceClient.batchPredictionJobPath( + projectId, + location, + batchPredictionJobId + ); + + const cancelRequest = { + name, + }; + + jobServiceClient.cancelBatchPredictionJob(cancelRequest).then(() => { + const deleteRequest = { + name, + }; + + return jobServiceClient.deleteBatchPredictionJob(deleteRequest); + }); + // Delete the Google Cloud Storage bucket created for usage reports. + await bucket.delete(); + }); + + it('should create batch prediction job', async () => { + const response = execSync( + `node ./create-batch-embedding.js ${projectId} ${inputUri} ${outputUri} ${displayName}` + ); + + assert.match(response, new RegExp(displayName)); + batchPredictionJobId = response + .split(`/locations/${location}/batchPredictionJobs/`)[1] + .split('\n')[0]; + }); +});