Skip to content

Commit

Permalink
feat(js): OpenAI embeddings Instrumentation (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeldking authored Jan 5, 2024
1 parent b6fb536 commit 18a7df2
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 6 deletions.
6 changes: 6 additions & 0 deletions js/.changeset/old-apples-attack.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@arizeai/openinference-instrumentation-openai": patch
"@arizeai/openinference-semantic-conventions": patch
---

Add OpenAI Embeddings semantic attributes and instrumentation
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
// Jest configuration for this package
module.exports = {
  // Use ts-jest so tests written in TypeScript are compiled on the fly
  preset: "ts-jest",
  testEnvironment: "node",
  // Disable prettier-based formatting of inline snapshots — presumably the
  // installed prettier version is incompatible with Jest's snapshot
  // formatter; TODO confirm
  prettierPath: null,
};
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ import {
ChatCompletionCreateParamsBase,
} from "openai/resources/chat/completions";
import { Stream } from "openai/streaming";
import {
CreateEmbeddingResponse,
EmbeddingCreateParams,
} from "openai/resources";

const MODULE_NAME = "openai";

Expand Down Expand Up @@ -80,7 +84,7 @@ export class OpenAIInstrumentation extends InstrumentationBase<typeof openai> {
const span = instrumentation.tracer.startSpan(
`OpenAI Chat Completions`,
{
kind: SpanKind.CLIENT,
kind: SpanKind.INTERNAL,
attributes: {
[SemanticConventions.OPENINFERENCE_SPAN_KIND]:
OpenInferenceSpanKind.LLM,
Expand All @@ -106,6 +110,11 @@ export class OpenAIInstrumentation extends InstrumentationBase<typeof openai> {
// Push the error to the span
if (error) {
span.recordException(error);
span.setStatus({
code: SpanStatusCode.ERROR,
message: error.message,
});
span.end();
}
},
);
Expand All @@ -115,6 +124,12 @@ export class OpenAIInstrumentation extends InstrumentationBase<typeof openai> {
span.setAttributes({
[SemanticConventions.OUTPUT_VALUE]: JSON.stringify(result),
[SemanticConventions.OUTPUT_MIME_TYPE]: MimeType.JSON,
// Override the model from the value sent by the server
[SemanticConventions.LLM_MODEL_NAME]: isChatCompletionResponse(
result,
)
? result.model
: body.model,
...getLLMOutputMessagesAttributes(result),
...getUsageAttributes(result),
});
Expand All @@ -127,6 +142,75 @@ export class OpenAIInstrumentation extends InstrumentationBase<typeof openai> {
};
},
);

// Patch embeddings
type EmbeddingsCreateType =
typeof module.OpenAI.Embeddings.prototype.create;
this._wrap(
module.OpenAI.Embeddings.prototype,
"create",
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(original: EmbeddingsCreateType): any => {
return function patchedEmbeddingCreate(
this: unknown,
...args: Parameters<typeof module.OpenAI.Embeddings.prototype.create>
) {
const body = args[0];
const { input } = body;
const isStringInput = typeof input == "string";
const span = instrumentation.tracer.startSpan(`OpenAI Embeddings`, {
kind: SpanKind.INTERNAL,
attributes: {
[SemanticConventions.OPENINFERENCE_SPAN_KIND]:
OpenInferenceSpanKind.EMBEDDING,
[SemanticConventions.EMBEDDING_MODEL_NAME]: body.model,
[SemanticConventions.INPUT_VALUE]: isStringInput
? input
: JSON.stringify(input),
[SemanticConventions.INPUT_MIME_TYPE]: isStringInput
? MimeType.TEXT
: MimeType.JSON,
...getEmbeddingTextAttributes(body),
},
});
const execContext = trace.setSpan(context.active(), span);
const execPromise = safeExecuteInTheMiddle<
ReturnType<EmbeddingsCreateType>
>(
() => {
return context.with(execContext, () => {
return original.apply(this, args);
});
},
(error) => {
// Push the error to the span
if (error) {
span.recordException(error);
span.setStatus({
code: SpanStatusCode.ERROR,
message: error.message,
});
span.end();
}
},
);
const wrappedPromise = execPromise.then((result) => {
if (result) {
// Record the results
span.setAttributes({
// Do not record the output data as it can be large
...getEmbeddingEmbeddingsAttributes(result),
});
}
span.setStatus({ code: SpanStatusCode.OK });
span.end();
return result;
});
return context.bind(execContext, wrappedPromise);
};
},
);

module.openInferencePatched = true;
return module;
}
Expand All @@ -136,9 +220,19 @@ export class OpenAIInstrumentation extends InstrumentationBase<typeof openai> {
/**
 * Removes the instrumentation patches from the OpenAI module by unwrapping
 * the `create` methods on both the Chat Completions and Embeddings
 * endpoints.
 *
 * NOTE(review): the patch path sets `module.openInferencePatched = true`,
 * but this flag is not reset here — confirm whether re-enabling after a
 * disable relies on that flag.
 *
 * @param moduleExports the openai module whose prototypes were patched
 * @param moduleVersion the version of the module, used only for debug logging
 */
private unpatch(moduleExports: typeof openai, moduleVersion?: string) {
  diag.debug(`Removing patch for ${MODULE_NAME}@${moduleVersion}`);
  this._unwrap(moduleExports.OpenAI.Chat.Completions.prototype, "create");
  this._unwrap(moduleExports.OpenAI.Embeddings.prototype, "create");
}
}

/**
 * Type guard distinguishing a non-streaming chat completion from a stream.
 * A completed {@link ChatCompletion} carries a `choices` array, whereas a
 * {@link Stream} of chunks does not expose that property — use its presence
 * as the discriminator.
 */
function isChatCompletionResponse(
  response: Stream<ChatCompletionChunk> | ChatCompletion,
): response is ChatCompletion {
  const hasChoices = "choices" in response;
  return hasChoices;
}

/**
* Converts the body of the request to LLM input messages
*/
Expand Down Expand Up @@ -204,3 +298,43 @@ function getLLMOutputMessagesAttributes(
}
return {};
}

/**
 * Converts the embedding request payload into per-input embedding text
 * attributes.
 *
 * A single string input is recorded under index 0; an array of strings is
 * recorded with one attribute per element, keyed as
 * `embedding.embeddings.<index>.embedding.text`. Token inputs (a number or
 * an array of numbers) are ignored, since they cannot be recorded as text.
 *
 * @param request the embedding creation request parameters
 * @returns the attributes describing the embedding input text(s)
 */
function getEmbeddingTextAttributes(
  request: EmbeddingCreateParams,
): Attributes {
  if (typeof request.input === "string") {
    return {
      [`${SemanticConventions.EMBEDDING_EMBEDDINGS}.0.${SemanticConventions.EMBEDDING_TEXT}`]:
        request.input,
    };
  } else if (
    Array.isArray(request.input) &&
    request.input.length > 0 &&
    typeof request.input[0] === "string"
  ) {
    return request.input.reduce((acc, input, index) => {
      const indexPrefix = `${SemanticConventions.EMBEDDING_EMBEDDINGS}.${index}`;
      acc[`${indexPrefix}.${SemanticConventions.EMBEDDING_TEXT}`] = input;
      return acc;
    }, {} as Attributes);
  }
  // Ignore other cases where input is a number or an array of numbers
  return {};
}

/**
 * Converts the embedding response payload into per-result embedding vector
 * attributes, one per returned embedding, keyed as
 * `embedding.embeddings.<index>.embedding.vector`.
 *
 * @param response the embedding creation response from the server
 * @returns the attributes holding each embedding's vector
 */
function getEmbeddingEmbeddingsAttributes(
  response: CreateEmbeddingResponse,
): Attributes {
  return response.data.reduce((acc, embedding, index) => {
    const indexPrefix = `${SemanticConventions.EMBEDDING_EMBEDDINGS}.${index}`;
    acc[`${indexPrefix}.${SemanticConventions.EMBEDDING_VECTOR}`] =
      embedding.embedding;
    return acc;
  }, {} as Attributes);
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,22 @@ describe("OpenAIInstrumentation", () => {

instrumentation.setTracerProvider(tracerProvider);
tracerProvider.addSpanProcessor(new SimpleSpanProcessor(memoryExporter));
// @ts-expect-error the moduleExports property is private. This is needed to make the test work with auto-mocking
instrumentation._modules[0].moduleExports = OpenAI;

beforeEach(() => {
// @ts-expect-error the moduleExports property is private. This is needed to make the test work with auto-mocking
instrumentation._modules[0].moduleExports = OpenAI;
beforeAll(() => {
instrumentation.enable();
openai = new OpenAI.OpenAI({
apiKey: `fake-api-key`,
});
});
afterAll(() => {
instrumentation.disable();
});
beforeEach(() => {
memoryExporter.reset();
});
afterEach(() => {
instrumentation.disable();
jest.clearAllMocks();
});
it("is patched", () => {
Expand Down Expand Up @@ -85,7 +89,7 @@ describe("OpenAIInstrumentation", () => {
"llm.input_messages.0.message.content": "Say this is a test",
"llm.input_messages.0.message.role": "user",
"llm.invocation_parameters": "{"model":"gpt-3.5-turbo"}",
"llm.model_name": "gpt-3.5-turbo",
"llm.model_name": "gpt-3.5-turbo-0613",
"llm.output_messages.0.message.content": "This is a test.",
"llm.output_messages.0.message.role": "assistant",
"llm.token_count.completion": 5,
Expand All @@ -97,4 +101,39 @@ describe("OpenAIInstrumentation", () => {
}
`);
});
it("creates a span for embedding create", async () => {
const response = {
object: "list",
data: [{ object: "embedding", index: 0, embedding: [1, 2, 3] }],
};
// Mock out the embedding create endpoint
jest.spyOn(openai, "post").mockImplementation(
// @ts-expect-error the response type is not correct - this is just for testing
async (): Promise<unknown> => {
return response;
},
);
await openai.embeddings.create({
input: "A happy moment",
model: "text-embedding-ada-002",
});
const spans = memoryExporter.getFinishedSpans();
expect(spans.length).toBe(1);
const span = spans[0];
expect(span.name).toBe("OpenAI Embeddings");
expect(span.attributes).toMatchInlineSnapshot(`
{
"embedding.embeddings.0.embedding.text": "A happy moment",
"embedding.embeddings.0.embedding.vector": [
1,
2,
3,
],
"embedding.model_name": "text-embedding-ada-002",
"input.mime_type": "text/plain",
"input.value": "A happy moment",
"openinference.span.kind": "embedding",
}
`);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,12 @@ export const EMBEDDING_MODEL_NAME =
/**
 * The embedding vector for a single embedded text; an array of numbers
 */
export const EMBEDDING_VECTOR =
  `${SemanticAttributePrefixes.embedding}.${EmbeddingAttributePostfixes.vector}` as const;

/**
 * The prefix for the list of embedding objects; individual entries are
 * addressed as `embedding.embeddings.<index>.<attribute>`
 */
export const EMBEDDING_EMBEDDINGS =
  `${SemanticAttributePrefixes.embedding}.${EmbeddingAttributePostfixes.embeddings}` as const;

export const SemanticConventions = {
INPUT_VALUE,
INPUT_MIME_TYPE,
Expand All @@ -234,6 +240,7 @@ export const SemanticConventions = {
DOCUMENT_CONTENT,
DOCUMENT_SCORE,
DOCUMENT_METADATA,
EMBEDDING_EMBEDDINGS,
EMBEDDING_TEXT,
EMBEDDING_MODEL_NAME,
EMBEDDING_VECTOR,
Expand Down

0 comments on commit 18a7df2

Please sign in to comment.