feat(community): add ByteDance Doubao Embeddings #7450

Merged
merged 12 commits into from
Jan 10, 2025
@@ -0,0 +1,33 @@
---
sidebar_class_name: node-only
---

# ByteDance Doubao

The `ByteDanceDoubaoEmbeddings` class uses the ByteDance Doubao API to generate embeddings for a given text.

## Setup

You'll need to sign up for a ByteDance API key and set it as an environment variable named `ARK_API_KEY`.

Then, you'll need to install the [`@langchain/community`](https://www.npmjs.com/package/@langchain/community) package:

import IntegrationInstallTooltip from "@mdx_components/integration_install_tooltip.mdx";

<IntegrationInstallTooltip></IntegrationInstallTooltip>

```bash npm2yarn
npm install @langchain/community @langchain/core
```
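
For the API key step above, the lookup order used by the class constructor is: an explicitly passed `apiKey` wins, otherwise `ARK_API_KEY` is read from the environment, and a missing key is an error. Below is a minimal sketch of that order. `resolveApiKey` and the `Env` type are hypothetical names used only for illustration, and the environment is passed in as a plain object so the snippet runs anywhere (the class itself reads it through `getEnvironmentVariable`).

```typescript
// Hypothetical helper, not part of @langchain/community: it only
// illustrates the key lookup order described above.
type Env = Record<string, string | undefined>;

function resolveApiKey(explicitKey: string | undefined, env: Env): string {
  // An explicitly supplied key takes precedence over the environment.
  const apiKey = explicitKey ?? env["ARK_API_KEY"];
  if (!apiKey) {
    // Same failure mode as the constructor: no key, no client.
    throw new Error("ByteDanceDoubao API key not found");
  }
  return apiKey;
}
```

In practice this means a key passed to the constructor can override whatever is set in the environment, which may be convenient in tests.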

## Usage

import CodeBlock from "@theme/CodeBlock";
import ByteDanceDoubaoExample from "@examples/embeddings/bytedance_doubao.ts";

<CodeBlock language="typescript">{ByteDanceDoubaoExample}</CodeBlock>

## Related

- Embedding model [conceptual guide](/docs/concepts/embedding_models)
- Embedding model [how-to guides](/docs/how_to/#embedding-models)
3 changes: 2 additions & 1 deletion examples/.env.example
@@ -84,7 +84,8 @@ HANA_HOST=HANA_DB_ADDRESS
HANA_PORT=HANA_DB_PORT
HANA_UID=HANA_DB_USER
HANA_PWD=HANA_DB_PASSWORD
ARK_API_KEY=ADD_YOURS_HERE # https://console.volcengine.com/
JIRA_HOST=ADD_YOURS_HERE
JIRA_USERNAME=ADD_YOURS_HERE
JIRA_ACCESS_TOKEN=ADD_YOURS_HERE
JIRA_PROJECT_KEY=ADD_YOURS_HERE
JIRA_PROJECT_KEY=ADD_YOURS_HERE
9 changes: 9 additions & 0 deletions examples/src/embeddings/bytedance_doubao.ts
@@ -0,0 +1,9 @@
import { ByteDanceDoubaoEmbeddings } from "@langchain/community/embeddings/bytedance_doubao";

const model = new ByteDanceDoubaoEmbeddings({
  modelName: 'ep-xxx-xxx'
});
const res = await model.embedQuery(
  "What would be a good company name for a company that makes colorful socks?"
);
console.log({ res });
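
To make the example above more concrete, here is a rough sketch of the request body that an `embedQuery` call ends up sending: the model (endpoint) name plus a one-element `input` array, with newlines replaced by spaces when `stripNewLines` is on. `buildRequestBody` and `EmbeddingRequestBody` are hypothetical standalone names for illustration; in the PR this logic lives in `embedQuery` and the private `getParams` method.

```typescript
// Shape of the request body used in this PR: a model name and an input array.
interface EmbeddingRequestBody {
  model: string;
  input: string[];
}

// Hypothetical standalone helper, for illustration only.
function buildRequestBody(
  modelName: string,
  texts: string[],
  stripNewLines = true
): EmbeddingRequestBody {
  // Newlines are replaced with spaces, as in the class, unless disabled.
  const input = stripNewLines
    ? texts.map((t) => t.replace(/\n/g, " "))
    : texts;
  return { model: modelName, input };
}

const body = buildRequestBody("ep-xxx-xxx", ["colorful\nsocks"]);
console.log(JSON.stringify(body));
// prints {"model":"ep-xxx-xxx","input":["colorful socks"]}
```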
4 changes: 4 additions & 0 deletions libs/langchain-community/.gitignore
@@ -146,6 +146,10 @@ embeddings/bedrock.cjs
embeddings/bedrock.js
embeddings/bedrock.d.ts
embeddings/bedrock.d.cts
embeddings/bytedance_doubao.cjs
embeddings/bytedance_doubao.js
embeddings/bytedance_doubao.d.ts
embeddings/bytedance_doubao.d.cts
embeddings/cloudflare_workersai.cjs
embeddings/cloudflare_workersai.js
embeddings/cloudflare_workersai.d.ts
1 change: 1 addition & 0 deletions libs/langchain-community/langchain.config.js
@@ -72,6 +72,7 @@ export const config = {
"embeddings/alibaba_tongyi": "embeddings/alibaba_tongyi",
"embeddings/baidu_qianfan": "embeddings/baidu_qianfan",
"embeddings/bedrock": "embeddings/bedrock",
"embeddings/bytedance_doubao": "embeddings/bytedance_doubao",
"embeddings/cloudflare_workersai": "embeddings/cloudflare_workersai",
"embeddings/cohere": "embeddings/cohere",
"embeddings/deepinfra": "embeddings/deepinfra",
13 changes: 13 additions & 0 deletions libs/langchain-community/package.json
@@ -1056,6 +1056,15 @@
"import": "./embeddings/bedrock.js",
"require": "./embeddings/bedrock.cjs"
},
"./embeddings/bytedance_doubao": {
"types": {
"import": "./embeddings/bytedance_doubao.d.ts",
"require": "./embeddings/bytedance_doubao.d.cts",
"default": "./embeddings/bytedance_doubao.d.ts"
},
"import": "./embeddings/bytedance_doubao.js",
"require": "./embeddings/bytedance_doubao.cjs"
},
"./embeddings/cloudflare_workersai": {
"types": {
"import": "./embeddings/cloudflare_workersai.d.ts",
@@ -3332,6 +3341,10 @@
"embeddings/bedrock.js",
"embeddings/bedrock.d.ts",
"embeddings/bedrock.d.cts",
"embeddings/bytedance_doubao.cjs",
"embeddings/bytedance_doubao.js",
"embeddings/bytedance_doubao.d.ts",
"embeddings/bytedance_doubao.d.cts",
"embeddings/cloudflare_workersai.cjs",
"embeddings/cloudflare_workersai.js",
"embeddings/cloudflare_workersai.d.ts",
175 changes: 175 additions & 0 deletions libs/langchain-community/src/embeddings/bytedance_doubao.ts
@@ -0,0 +1,175 @@
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
import { chunkArray } from "@langchain/core/utils/chunk_array";

export interface ByteDanceDoubaoEmbeddingsParams extends EmbeddingsParams {
/** Model name to use */
modelName: string;

/**
* Timeout to use when making requests to ByteDanceDoubao.
*/
timeout?: number;

/**
* The maximum number of documents to embed in a single request. This is
* limited by the ByteDanceDoubao API to a maximum of 2048.
*/
batchSize?: number;

/**
* Whether to strip new lines from the input text.
*/
stripNewLines?: boolean;
}

interface EmbeddingCreateParams {
model: ByteDanceDoubaoEmbeddingsParams["modelName"];
input: string[];
encoding_format?: "float";
}

interface EmbeddingResponse {
data: {
index: number;
embedding: number[];
}[];

usage: {
prompt_tokens: number;
total_tokens: number;
};

id: string;
}

interface EmbeddingErrorResponse {
type: string;
code: string;
param: string;
message: string;
}

export class ByteDanceDoubaoEmbeddings
extends Embeddings
implements ByteDanceDoubaoEmbeddingsParams {
modelName: ByteDanceDoubaoEmbeddingsParams["modelName"] = "";

batchSize = 24;

stripNewLines = true;

apiKey: string;

constructor(
fields?: Partial<ByteDanceDoubaoEmbeddingsParams> & {
verbose?: boolean;
apiKey?: string;
}
) {
const fieldsWithDefaults = { maxConcurrency: 2, ...fields };
super(fieldsWithDefaults);

const apiKey =
fieldsWithDefaults?.apiKey ?? getEnvironmentVariable("ARK_API_KEY");

if (!apiKey) throw new Error("ByteDanceDoubao API key not found");

this.apiKey = apiKey;

this.modelName = fieldsWithDefaults?.modelName ?? this.modelName;
this.batchSize = fieldsWithDefaults?.batchSize ?? this.batchSize;
this.stripNewLines =
fieldsWithDefaults?.stripNewLines ?? this.stripNewLines;
}

/**
* Method to generate embeddings for an array of documents. Splits the
* documents into batches and makes requests to the ByteDanceDoubao API to generate
* embeddings.
* @param texts Array of documents to generate embeddings for.
* @returns Promise that resolves to a 2D array of embeddings for each document.
*/
async embedDocuments(texts: string[]): Promise<number[][]> {
const batches = chunkArray(
this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts,
this.batchSize
);
const batchRequests = batches.map((batch) => {
const params = this.getParams(batch);

return this.embeddingWithRetry(params);
});

const batchResponses = await Promise.all(batchRequests);
const embeddings: number[][] = [];

for (let i = 0; i < batchResponses.length; i += 1) {
const batch = batches[i];
const batchResponse = batchResponses[i] || [];
for (let j = 0; j < batch.length; j += 1) {
embeddings.push(batchResponse[j]);
}
}

return embeddings;
}

/**
* Method to generate an embedding for a single document. Calls the
* embeddingWithRetry method with the document as the input.
* @param text Document to generate an embedding for.
* @returns Promise that resolves to an embedding for the document.
*/
async embedQuery(text: string): Promise<number[]> {
const params = this.getParams([
this.stripNewLines ? text.replace(/\n/g, " ") : text,
]);

const embeddings = (await this.embeddingWithRetry(params)) || [[]];
return embeddings[0];
}

/**
* Method to build the request parameters for an embedding call.
* @param texts Array of documents to generate embeddings for.
* @returns The request parameters: the model name and the input texts.
*/
private getParams(
texts: EmbeddingCreateParams["input"]
): EmbeddingCreateParams {
return {
model: this.modelName,
input: texts,
};
}

/**
* Private method to make a request to the ByteDanceDoubao API to generate
* embeddings. Sends a single POST request, throws if the API returns an
* error payload, and otherwise returns the embeddings from the response.
* @param body Request body to send to the ByteDanceDoubao API.
* @returns Promise that resolves to one embedding per input text.
*/
private async embeddingWithRetry(body: EmbeddingCreateParams) {
return fetch("https://ark.cn-beijing.volces.com/api/v3/embeddings", {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${this.apiKey}`,
},
body: JSON.stringify(body),
}).then(async (response) => {
const embeddingData: EmbeddingResponse | EmbeddingErrorResponse =
await response.json();

if ("code" in embeddingData && embeddingData.code) {
throw new Error(`${embeddingData.code}: ${embeddingData.message}`);
}

return (embeddingData as EmbeddingResponse).data.map(
({ embedding }) => embedding
);
});
}
}
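
The response handling in `embeddingWithRetry` can be read in isolation: a payload with a non-empty `code` field is treated as an error, anything else is treated as a success whose `data` entries carry the vectors. Below is a minimal sketch of that branch. `extractEmbeddings` is a hypothetical name, and the two interfaces are trimmed-down copies of the ones in this file (only the fields the branch actually reads).

```typescript
// Trimmed copies of the response types from this file, for illustration.
interface EmbeddingResponse {
  data: { index: number; embedding: number[] }[];
}

interface EmbeddingErrorResponse {
  code: string;
  message: string;
}

// Hypothetical helper: separates error payloads from successful ones.
function extractEmbeddings(
  payload: EmbeddingResponse | EmbeddingErrorResponse
): number[][] {
  // An error payload is recognised by a non-empty `code` field.
  if ("code" in payload && payload.code) {
    throw new Error(`${payload.code}: ${payload.message}`);
  }
  // Otherwise return the vectors in the order the API listed them.
  return (payload as EmbeddingResponse).data.map(
    ({ embedding }) => embedding
  );
}
```

Note that this check keys off the payload shape rather than the HTTP status, which matches what the method in this PR does.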
@@ -0,0 +1,40 @@
import { test, expect } from "@jest/globals";
import { ByteDanceDoubaoEmbeddings } from "../bytedance_doubao.js";

const modelName = 'ep-xxx-xxx';
test.skip("Test ByteDanceDoubaoEmbeddings.embedQuery", async () => {
const embeddings = new ByteDanceDoubaoEmbeddings({
modelName,
});
const res = await embeddings.embedQuery("Hello world");
expect(typeof res[0]).toBe("number");
});

test.skip("Test ByteDanceDoubaoEmbeddings.embedDocuments", async () => {
const embeddings = new ByteDanceDoubaoEmbeddings({
modelName,
});
const res = await embeddings.embedDocuments(["Hello world", "Bye bye"]);
expect(res).toHaveLength(2);
expect(typeof res[0][0]).toBe("number");
expect(typeof res[1][0]).toBe("number");
});

test.skip("Test ByteDanceDoubaoEmbeddings concurrency", async () => {
const embeddings = new ByteDanceDoubaoEmbeddings({
modelName,
batchSize: 1,
});
const res = await embeddings.embedDocuments([
"Hello world",
"Bye bye",
"Hello world",
"Bye bye",
"Hello world",
"Bye bye",
]);
expect(res).toHaveLength(6);
expect(res.find((embedding) => typeof embedding[0] !== "number")).toBe(
undefined
);
});
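
The concurrency test above relies on `embedDocuments` splitting its input into batches and stitching the per-batch results back together in input order. Here is a small offline sketch of that pattern. `chunk` is a local stand-in for `chunkArray` from `@langchain/core`, while `embedInBatches`, `BatchEmbedder`, and `fakeEmbed` are hypothetical names; the fake embedder returns each text's length instead of calling the API.

```typescript
// Local stand-in for chunkArray: split items into groups of at most `size`.
function chunk<T>(items: T[], size: number): T[][] {
  if (size < 1) throw new Error("size must be at least 1");
  const out: T[][] = [];
  for (let i = 0; i < items.length; i += size) {
    out.push(items.slice(i, i + size));
  }
  return out;
}

type BatchEmbedder = (batch: string[]) => Promise<number[][]>;

// Hypothetical helper: embed batch by batch and flatten in input order.
async function embedInBatches(
  texts: string[],
  batchSize: number,
  embedBatch: BatchEmbedder
): Promise<number[][]> {
  const batches = chunk(texts, batchSize);
  // Promise.all resolves to results in the same order as the input promises.
  const responses = await Promise.all(batches.map((b) => embedBatch(b)));
  const embeddings: number[][] = [];
  for (const response of responses) {
    embeddings.push(...response);
  }
  return embeddings;
}

// Fake embedder for offline use: one "vector" holding the text length.
const fakeEmbed: BatchEmbedder = async (batch) =>
  batch.map((text) => [text.length]);
```

With `batchSize: 1` and six inputs, as in the test above, this produces six single-item batches and six embeddings in the original order.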
1 change: 1 addition & 0 deletions libs/langchain-community/src/load/import_map.ts
@@ -27,6 +27,7 @@ export * as agents__toolkits__base from "../agents/toolkits/base.js";
export * as agents__toolkits__connery from "../agents/toolkits/connery/index.js";
export * as embeddings__alibaba_tongyi from "../embeddings/alibaba_tongyi.js";
export * as embeddings__baidu_qianfan from "../embeddings/baidu_qianfan.js";
export * as embeddings__bytedance_doubao from "../embeddings/bytedance_doubao.js";
export * as embeddings__deepinfra from "../embeddings/deepinfra.js";
export * as embeddings__fireworks from "../embeddings/fireworks.js";
export * as embeddings__minimax from "../embeddings/minimax.js";