Skip to content

Commit

Permalink
feat: generativeaionvertexai_embedding_batch
Browse files Browse the repository at this point in the history
  • Loading branch information
gryczj committed Nov 28, 2024
1 parent 0469e1f commit 6fdd9f8
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 0 deletions.
105 changes: 105 additions & 0 deletions ai-platform/snippets/create-batch-embedding.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

'use strict';

async function main(projectId, inputUri, outputUri, jobName) {
// [START generativeaionvertexai_embedding_batch]
// Imports the aiplatform library
const aiplatformLib = require('@google-cloud/aiplatform');
const aiplatform = aiplatformLib.protos.google.cloud.aiplatform.v1;

/**
* TODO(developer): Uncomment/update these variables before running the sample.
*/
// projectId = 'YOUR_PROJECT_ID';

// Optional: URI of the input dataset.
// Could be a BigQuery table or a Google Cloud Storage file.
// E.g. "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
// inputUri =
// 'gs://cloud-samples-data/generative-ai/embeddings/embeddings_input.jsonl';

// Optional: URI where the output will be stored.
// Could be a BigQuery table or a Google Cloud Storage file.
// E.g. "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
// outputUri = 'gs://your_backet/embedding_batch_output';

// The name of the job
// jobName = `Batch embedding job: ${new Date().getMilliseconds()}`;

const textEmbeddingModel = 'text-embedding-005';
const location = 'us-central1';

// Configure the parent resource
const parent = `projects/${projectId}/locations/${location}`;
const modelName = `projects/${projectId}/locations/${location}/publishers/google/models/${textEmbeddingModel}`;

// Specifies the location of the api endpoint
const clientOptions = {
apiEndpoint: `${location}-aiplatform.googleapis.com`,
};

// Instantiates a client
const jobServiceClient = new aiplatformLib.JobServiceClient(clientOptions);

// Generates embeddings from text using batch processing.
// Read more: https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/batch-prediction-genai-embeddings
async function callBatchEmbedding() {
const gcsSource = new aiplatform.GcsSource({
uris: [inputUri],
});

const inputConfig = new aiplatform.BatchPredictionJob.InputConfig({
gcsSource,
instancesFormat: 'jsonl',
});

const gcsDestination = new aiplatform.GcsDestination({
outputUriPrefix: outputUri,
});

const outputConfig = new aiplatform.BatchPredictionJob.OutputConfig({
gcsDestination,
predictionsFormat: 'jsonl',
});

const batchPredictionJob = new aiplatform.BatchPredictionJob({
displayName: jobName,
model: modelName,
inputConfig,
outputConfig,
});

const request = {
parent,
batchPredictionJob,
};

// Create batch prediction job request
const [response] = await jobServiceClient.createBatchPredictionJob(request);

console.log('Raw response: ', JSON.stringify(response, null, 2));
}

await callBatchEmbedding();
// [END generativeaionvertexai_embedding_batch]
}

main(...process.argv.slice(2)).catch(err => {
console.error(err.message);
process.exitCode = 1;
});
84 changes: 84 additions & 0 deletions ai-platform/snippets/test/create-batch-embedding.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

'use strict';

const {assert} = require('chai');
const {after, before, describe, it} = require('mocha');
const uuid = require('uuid').v4;
const cp = require('child_process');
const {JobServiceClient} = require('@google-cloud/aiplatform');
const {Storage} = require('@google-cloud/storage');

const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});

describe('Batch embedding', async () => {
const displayName = `batch-embedding-job-${uuid()}`;
const location = 'us-central1';
const inputUri =
'gs://cloud-samples-data/generative-ai/embeddings/embeddings_input.jsonl';
let outputUri = 'gs://ucaip-samples-test-output/';
const jobServiceClient = new JobServiceClient({
apiEndpoint: `${location}-aiplatform.googleapis.com`,
});
const projectId = process.env.CAIP_PROJECT_ID;
const storage = new Storage({
projectId,
});
let batchPredictionJobId;
let bucket;

before(async () => {
const bucketName = `test-bucket-${uuid()}`;
// Create a Google Cloud Storage bucket for UsageReports
[bucket] = await storage.createBucket(bucketName);
outputUri = `gs://${bucketName}/embedding_batch_output`;
});

after(async () => {
// Delete job
const name = jobServiceClient.batchPredictionJobPath(
projectId,
location,
batchPredictionJobId
);

const cancelRequest = {
name,
};

jobServiceClient.cancelBatchPredictionJob(cancelRequest).then(() => {
const deleteRequest = {
name,
};

return jobServiceClient.deleteBatchPredictionJob(deleteRequest);
});
// Delete the Google Cloud Storage bucket created for usage reports.
await bucket.delete();
});

it('should create batch prediction job', async () => {
const response = execSync(
`node ./create-batch-embedding.js ${projectId} ${inputUri} ${outputUri} ${displayName}`
);

assert.match(response, new RegExp(displayName));
batchPredictionJobId = response
.split(`/locations/${location}/batchPredictionJobs/`)[1]
.split('\n')[0];
});
});

0 comments on commit 6fdd9f8

Please sign in to comment.