Skip to content

Commit

Permalink
Derived value of crossEncodingEnabled based on enableEmbeddingModelsV…
Browse files Browse the repository at this point in the history
…iaSagemaker config.
  • Loading branch information
azaylamba committed Feb 9, 2024
1 parent cf0dfc1 commit 13ce71e
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 10 deletions.
3 changes: 2 additions & 1 deletion bin/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ export function getConfig(): SystemConfig {
},
rag: {
enabled: false,
enableEmbeddingModelsViaSagemaker: false,
engines: {
aurora: {
enabled: false,
Expand Down Expand Up @@ -61,14 +62,14 @@ export function getConfig(): SystemConfig {
dimensions: 1536,
},
],
crossEncodingEnabled: false,
crossEncoderModels: [
{
provider: "sagemaker",
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
default: true,
},
],
crossEncodingEnabled: false,
},
};
}
Expand Down
37 changes: 31 additions & 6 deletions cli/magic-config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,15 @@ async function processCreateOptions(options: any): Promise<void> {
message: "Do you want to enable RAG",
initial: options.enableRag || false,
},
{
type: "confirm",
name: "enableEmbeddingModelsViaSagemaker",
message: "Do you want to enable embedding models via SageMaker?",
initial: options.enableEmbeddingModelsViaSagemaker || false,
skip(): boolean {
return !(this as any).state.answers.enableRag;
},
},
{
type: "multiselect",
name: "ragsToEnable",
Expand Down Expand Up @@ -349,10 +358,13 @@ async function processCreateOptions(options: any): Promise<void> {
}
: undefined,
llms: {
enableSagemakerModels: answers.enableSagemakerModels,
sagemaker: answers.sagemakerModels,
},
rag: {
enabled: answers.enableRag,
enableEmbeddingModelsViaSagemaker:
answers.enableEmbeddingModelsViaSagemaker,
engines: {
aurora: {
enabled: answers.ragsToEnable.includes("aurora"),
Expand All @@ -367,6 +379,7 @@ async function processCreateOptions(options: any): Promise<void> {
enterprise: false,
},
},
crossEncodingEnabled: answers.enableEmbeddingModelsViaSagemaker,
embeddingsModels: [{}],
crossEncoderModels: [{}],
},
Expand All @@ -377,12 +390,24 @@ async function processCreateOptions(options: any): Promise<void> {
models.defaultEmbedding = embeddingModels[0].name;
}

config.rag.crossEncoderModels[0] = {
provider: "sagemaker",
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
default: true,
};
config.rag.embeddingsModels = embeddingModels;
if (answers.enableEmbeddingModelsViaSagemaker && answers.enableSagemakerModels) {
config.rag.crossEncoderModels[0] = {
provider: "sagemaker",
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
default: true,
};
} else {
config.rag.crossEncoderModels[0] = {
provider: "None",
name: "None",
default: true,
};
}
if (!config.rag.enableEmbeddingModelsViaSagemaker) {
config.rag.embeddingsModels = embeddingModels.filter(model => model.provider !== "sagemaker");
} else {
config.rag.embeddingsModels = embeddingModels;
}
config.rag.embeddingsModels.forEach((m: any) => {
if (m.name === models.defaultEmbedding) {
m.default = true;
Expand Down
3 changes: 1 addition & 2 deletions lib/aws-genai-llm-chatbot-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,7 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
api: chatBotApi,
chatbotFilesBucket: chatBotApi.filesBucket,
crossEncodersEnabled: props.config.rag.crossEncodingEnabled,
sagemakerEmbeddingsEnabled:
typeof ragEngines?.sageMakerRagModels?.model !== "undefined",
sagemakerEmbeddingsEnabled: props.config.rag.enableEmbeddingModelsViaSagemaker,
});

/**
Expand Down
3 changes: 2 additions & 1 deletion lib/shared/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ export interface SystemConfig {
};
rag: {
enabled: boolean;
enableEmbeddingModelsViaSagemaker: boolean;
engines: {
aurora: {
enabled: boolean;
Expand All @@ -113,12 +114,12 @@ export interface SystemConfig {
dimensions: number;
default?: boolean;
}[];
crossEncodingEnabled: boolean;
crossEncoderModels: {
provider: ModelProvider;
name: string;
default?: boolean;
}[];
crossEncodingEnabled: boolean;
};
}

Expand Down

0 comments on commit 13ce71e

Please sign in to comment.