From 13ce71e52c9d575a389dd7730fc60fa835c0bb44 Mon Sep 17 00:00:00 2001 From: Ajay Lamba Date: Sat, 10 Feb 2024 00:00:16 +0530 Subject: [PATCH] Derived value of crossEncodingEnabled based on enableEmbeddingModelsViaSagemaker config. --- bin/config.ts | 3 ++- cli/magic-config.ts | 37 +++++++++++++++++++++++++----- lib/aws-genai-llm-chatbot-stack.ts | 3 +-- lib/shared/types.ts | 3 ++- 4 files changed, 36 insertions(+), 10 deletions(-) diff --git a/bin/config.ts b/bin/config.ts index 46f673338..d64d05eeb 100644 --- a/bin/config.ts +++ b/bin/config.ts @@ -25,6 +25,7 @@ export function getConfig(): SystemConfig { }, rag: { enabled: false, + enableEmbeddingModelsViaSagemaker: false, engines: { aurora: { enabled: false, @@ -61,6 +62,7 @@ export function getConfig(): SystemConfig { dimensions: 1536, }, ], + crossEncodingEnabled: false, crossEncoderModels: [ { provider: "sagemaker", @@ -68,7 +70,6 @@ export function getConfig(): SystemConfig { default: true, }, ], - crossEncodingEnabled: false, }, }; } diff --git a/cli/magic-config.ts b/cli/magic-config.ts index cadcb091c..8f0177cd8 100644 --- a/cli/magic-config.ts +++ b/cli/magic-config.ts @@ -203,6 +203,15 @@ async function processCreateOptions(options: any): Promise { message: "Do you want to enable RAG", initial: options.enableRag || false, }, + { + type: "confirm", + name: "enableEmbeddingModelsViaSagemaker", + message: "Do you want to enable embedding models via SageMaker?", + initial: options.enableEmbeddingModelsViaSagemaker || false, + skip(): boolean { + return !(this as any).state.answers.enableRag; + }, + }, { type: "multiselect", name: "ragsToEnable", @@ -349,10 +358,13 @@ async function processCreateOptions(options: any): Promise { } : undefined, llms: { + enableSagemakerModels: answers.enableSagemakerModels, sagemaker: answers.sagemakerModels, }, rag: { enabled: answers.enableRag, + enableEmbeddingModelsViaSagemaker: + answers.enableEmbeddingModelsViaSagemaker, engines: { aurora: { enabled: answers.ragsToEnable.includes("aurora"), @@ -367,6 +379,7 @@ async function processCreateOptions(options: any): Promise { enterprise: false, }, }, + crossEncodingEnabled: answers.enableEmbeddingModelsViaSagemaker, embeddingsModels: [{}], crossEncoderModels: [{}], }, @@ -377,12 +390,24 @@ async function processCreateOptions(options: any): Promise { models.defaultEmbedding = embeddingModels[0].name; } - config.rag.crossEncoderModels[0] = { - provider: "sagemaker", - name: "cross-encoder/ms-marco-MiniLM-L-12-v2", - default: true, - }; - config.rag.embeddingsModels = embeddingModels; + if (answers.enableEmbeddingModelsViaSagemaker && answers.enableSagemakerModels) { + config.rag.crossEncoderModels[0] = { + provider: "sagemaker", + name: "cross-encoder/ms-marco-MiniLM-L-12-v2", + default: true, + }; + } else { + config.rag.crossEncoderModels[0] = { + provider: "None", + name: "None", + default: true, + }; + } + if (!config.rag.enableEmbeddingModelsViaSagemaker) { + config.rag.embeddingsModels = embeddingModels.filter(model => model.provider !== "sagemaker"); + } else { + config.rag.embeddingsModels = embeddingModels; + } config.rag.embeddingsModels.forEach((m: any) => { if (m.name === models.defaultEmbedding) { m.default = true; diff --git a/lib/aws-genai-llm-chatbot-stack.ts b/lib/aws-genai-llm-chatbot-stack.ts index 219aaa015..01f3e70a4 100644 --- a/lib/aws-genai-llm-chatbot-stack.ts +++ b/lib/aws-genai-llm-chatbot-stack.ts @@ -152,8 +152,7 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack { api: chatBotApi, chatbotFilesBucket: chatBotApi.filesBucket, crossEncodersEnabled: props.config.rag.crossEncodingEnabled, - sagemakerEmbeddingsEnabled: - typeof ragEngines?.sageMakerRagModels?.model !== "undefined", + sagemakerEmbeddingsEnabled: props.config.rag.enableEmbeddingModelsViaSagemaker, }); /** diff --git a/lib/shared/types.ts b/lib/shared/types.ts index 780246dd8..231ee39f0 100644 --- a/lib/shared/types.ts +++ b/lib/shared/types.ts @@ -88,6 +88,7 @@ export interface SystemConfig { }; rag: { enabled: boolean; + enableEmbeddingModelsViaSagemaker: boolean; engines: { aurora: { enabled: boolean; @@ -113,12 +114,12 @@ export interface SystemConfig { dimensions: number; default?: boolean; }[]; + crossEncodingEnabled: boolean; crossEncoderModels: { provider: ModelProvider; name: string; default?: boolean; }[]; - crossEncodingEnabled: boolean; }; }