From 18a7df2b4d5c2d5ff19b158bf580226b872b7663 Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Fri, 5 Jan 2024 16:05:13 -0700 Subject: [PATCH] feat(js): OpenAI embeddings Instrumentation (#34) --- js/.changeset/old-apples-attack.md | 6 + .../jest.config.js | 1 + .../src/instrumentation.ts | 136 +++++++++++++++++- .../test/openai.test.ts | 49 ++++++- .../src/trace/SemanticConventions.ts | 7 + 5 files changed, 193 insertions(+), 6 deletions(-) create mode 100644 js/.changeset/old-apples-attack.md diff --git a/js/.changeset/old-apples-attack.md b/js/.changeset/old-apples-attack.md new file mode 100644 index 000000000..f32de5cce --- /dev/null +++ b/js/.changeset/old-apples-attack.md @@ -0,0 +1,6 @@ +--- +"@arizeai/openinference-instrumentation-openai": patch +"@arizeai/openinference-semantic-conventions": patch +--- + +Add OpenAI Embeddings sementic attributes and instrumentation diff --git a/js/packages/openinference-instrumentation-openai/jest.config.js b/js/packages/openinference-instrumentation-openai/jest.config.js index 3abcbd946..2fd4f78cf 100644 --- a/js/packages/openinference-instrumentation-openai/jest.config.js +++ b/js/packages/openinference-instrumentation-openai/jest.config.js @@ -2,4 +2,5 @@ module.exports = { preset: "ts-jest", testEnvironment: "node", + prettierPath: null, }; diff --git a/js/packages/openinference-instrumentation-openai/src/instrumentation.ts b/js/packages/openinference-instrumentation-openai/src/instrumentation.ts index 1c7ace0fa..088693da0 100644 --- a/js/packages/openinference-instrumentation-openai/src/instrumentation.ts +++ b/js/packages/openinference-instrumentation-openai/src/instrumentation.ts @@ -26,6 +26,10 @@ import { ChatCompletionCreateParamsBase, } from "openai/resources/chat/completions"; import { Stream } from "openai/streaming"; +import { + CreateEmbeddingResponse, + EmbeddingCreateParams, +} from "openai/resources"; const MODULE_NAME = "openai"; @@ -80,7 +84,7 @@ export class OpenAIInstrumentation extends InstrumentationBase { const span = instrumentation.tracer.startSpan( `OpenAI Chat Completions`, { - kind: SpanKind.CLIENT, + kind: SpanKind.INTERNAL, attributes: { [SemanticConventions.OPENINFERENCE_SPAN_KIND]: OpenInferenceSpanKind.LLM, @@ -106,6 +110,11 @@ export class OpenAIInstrumentation extends InstrumentationBase { // Push the error to the span if (error) { span.recordException(error); + span.setStatus({ + code: SpanStatusCode.ERROR, + message: error.message, + }); + span.end(); } }, ); @@ -115,6 +124,12 @@ export class OpenAIInstrumentation extends InstrumentationBase { span.setAttributes({ [SemanticConventions.OUTPUT_VALUE]: JSON.stringify(result), [SemanticConventions.OUTPUT_MIME_TYPE]: MimeType.JSON, + // Override the model from the value sent by the server + [SemanticConventions.LLM_MODEL_NAME]: isChatCompletionResponse( + result, + ) + ? result.model + : body.model, ...getLLMOutputMessagesAttributes(result), ...getUsageAttributes(result), }); @@ -127,6 +142,75 @@ export class OpenAIInstrumentation extends InstrumentationBase { }; }, ); + + // Patch embeddings + type EmbeddingsCreateType = + typeof module.OpenAI.Embeddings.prototype.create; + this._wrap( + module.OpenAI.Embeddings.prototype, + "create", + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (original: EmbeddingsCreateType): any => { + return function patchedEmbeddingCreate( + this: unknown, + ...args: Parameters + ) { + const body = args[0]; + const { input } = body; + const isStringInput = typeof input == "string"; + const span = instrumentation.tracer.startSpan(`OpenAI Embeddings`, { + kind: SpanKind.INTERNAL, + attributes: { + [SemanticConventions.OPENINFERENCE_SPAN_KIND]: + OpenInferenceSpanKind.EMBEDDING, + [SemanticConventions.EMBEDDING_MODEL_NAME]: body.model, + [SemanticConventions.INPUT_VALUE]: isStringInput + ? input + : JSON.stringify(input), + [SemanticConventions.INPUT_MIME_TYPE]: isStringInput + ? MimeType.TEXT + : MimeType.JSON, + ...getEmbeddingTextAttributes(body), + }, + }); + const execContext = trace.setSpan(context.active(), span); + const execPromise = safeExecuteInTheMiddle< + ReturnType + >( + () => { + return context.with(execContext, () => { + return original.apply(this, args); + }); + }, + (error) => { + // Push the error to the span + if (error) { + span.recordException(error); + span.setStatus({ + code: SpanStatusCode.ERROR, + message: error.message, + }); + span.end(); + } + }, + ); + const wrappedPromise = execPromise.then((result) => { + if (result) { + // Record the results + span.setAttributes({ + // Do not record the output data as it can be large + ...getEmbeddingEmbeddingsAttributes(result), + }); + } + span.setStatus({ code: SpanStatusCode.OK }); + span.end(); + return result; + }); + return context.bind(execContext, wrappedPromise); + }; + }, + ); + module.openInferencePatched = true; return module; } @@ -136,9 +220,19 @@ export class OpenAIInstrumentation extends InstrumentationBase { private unpatch(moduleExports: typeof openai, moduleVersion?: string) { diag.debug(`Removing patch for ${MODULE_NAME}@${moduleVersion}`); this._unwrap(moduleExports.OpenAI.Chat.Completions.prototype, "create"); + this._unwrap(moduleExports.OpenAI.Embeddings.prototype, "create"); } } +/** + * type-guard that checks if the response is a chat completion response + */ +function isChatCompletionResponse( + response: Stream | ChatCompletion, +): response is ChatCompletion { + return "choices" in response; +} + /** * Converts the body of the request to LLM input messages */ @@ -204,3 +298,43 @@ function getLLMOutputMessagesAttributes( } return {}; } + +/** + * Converts the embedding result payload to embedding attributes + */ +function getEmbeddingTextAttributes( + request: EmbeddingCreateParams, +): Attributes { + if (typeof request.input == "string") { + return { + [`${SemanticConventions.EMBEDDING_EMBEDDINGS}.0.${SemanticConventions.EMBEDDING_TEXT}`]: + request.input, + }; + } else if ( + Array.isArray(request.input) && + request.input.length > 0 && + typeof request.input[0] == "string" + ) { + return request.input.reduce((acc, input, index) => { + const index_prefix = `${SemanticConventions.EMBEDDING_EMBEDDINGS}.${index}`; + acc[`${index_prefix}.${SemanticConventions.EMBEDDING_TEXT}`] = input; + return acc; + }, {} as Attributes); + } + // Ignore other cases where input is a number or an array of numbers + return {}; +} + +/** + * Converts the embedding result payload to embedding attributes + */ +function getEmbeddingEmbeddingsAttributes( + response: CreateEmbeddingResponse, +): Attributes { + return response.data.reduce((acc, embedding, index) => { + const index_prefix = `${SemanticConventions.EMBEDDING_EMBEDDINGS}.${index}`; + acc[`${index_prefix}.${SemanticConventions.EMBEDDING_VECTOR}`] = + embedding.embedding; + return acc; + }, {} as Attributes); +} diff --git a/js/packages/openinference-instrumentation-openai/test/openai.test.ts b/js/packages/openinference-instrumentation-openai/test/openai.test.ts index 7b73c423e..a768fee77 100644 --- a/js/packages/openinference-instrumentation-openai/test/openai.test.ts +++ b/js/packages/openinference-instrumentation-openai/test/openai.test.ts @@ -21,18 +21,22 @@ describe("OpenAIInstrumentation", () => { instrumentation.setTracerProvider(tracerProvider); tracerProvider.addSpanProcessor(new SimpleSpanProcessor(memoryExporter)); + // @ts-expect-error the moduleExports property is private. This is needed to make the test work with auto-mocking + instrumentation._modules[0].moduleExports = OpenAI; - beforeEach(() => { - // @ts-expect-error the moduleExports property is private. This is needed to make the test work with auto-mocking - instrumentation._modules[0].moduleExports = OpenAI; + beforeAll(() => { instrumentation.enable(); openai = new OpenAI.OpenAI({ apiKey: `fake-api-key`, }); + }); + afterAll(() => { + instrumentation.disable(); + }); + beforeEach(() => { memoryExporter.reset(); }); afterEach(() => { - instrumentation.disable(); jest.clearAllMocks(); }); it("is patched", () => { @@ -85,7 +89,7 @@ describe("OpenAIInstrumentation", () => { "llm.input_messages.0.message.content": "Say this is a test", "llm.input_messages.0.message.role": "user", "llm.invocation_parameters": "{"model":"gpt-3.5-turbo"}", - "llm.model_name": "gpt-3.5-turbo", + "llm.model_name": "gpt-3.5-turbo-0613", "llm.output_messages.0.message.content": "This is a test.", "llm.output_messages.0.message.role": "assistant", "llm.token_count.completion": 5, @@ -97,4 +101,39 @@ describe("OpenAIInstrumentation", () => { } `); }); + it("creates a span for embedding create", async () => { + const response = { + object: "list", + data: [{ object: "embedding", index: 0, embedding: [1, 2, 3] }], + }; + // Mock out the embedding create endpoint + jest.spyOn(openai, "post").mockImplementation( + // @ts-expect-error the response type is not correct - this is just for testing + async (): Promise => { + return response; + }, + ); + await openai.embeddings.create({ + input: "A happy moment", + model: "text-embedding-ada-002", + }); + const spans = memoryExporter.getFinishedSpans(); + expect(spans.length).toBe(1); + const span = spans[0]; + expect(span.name).toBe("OpenAI Embeddings"); + expect(span.attributes).toMatchInlineSnapshot(` + { + "embedding.embeddings.0.embedding.text": "A happy moment", + "embedding.embeddings.0.embedding.vector": [ + 1, + 2, + 3, + ], + "embedding.model_name": "text-embedding-ada-002", + "input.mime_type": "text/plain", + "input.value": "A happy moment", + "openinference.span.kind": "embedding", + } + `); + }); }); diff --git a/js/packages/openinference-semantic-conventions/src/trace/SemanticConventions.ts b/js/packages/openinference-semantic-conventions/src/trace/SemanticConventions.ts index 26e0ea77c..97ee1b62c 100644 --- a/js/packages/openinference-semantic-conventions/src/trace/SemanticConventions.ts +++ b/js/packages/openinference-semantic-conventions/src/trace/SemanticConventions.ts @@ -210,6 +210,12 @@ export const EMBEDDING_MODEL_NAME = export const EMBEDDING_VECTOR = `${SemanticAttributePrefixes.embedding}.${EmbeddingAttributePostfixes.vector}` as const; +/** + * The embedding list root + */ +export const EMBEDDING_EMBEDDINGS = + `${SemanticAttributePrefixes.embedding}.${EmbeddingAttributePostfixes.embeddings}` as const; + export const SemanticConventions = { INPUT_VALUE, INPUT_MIME_TYPE, @@ -234,6 +240,7 @@ export const SemanticConventions = { DOCUMENT_CONTENT, DOCUMENT_SCORE, DOCUMENT_METADATA, + EMBEDDING_EMBEDDINGS, EMBEDDING_TEXT, EMBEDDING_MODEL_NAME, EMBEDDING_VECTOR,