From 4ffc4931be90ddbb344ee38d148dbc653e35c32d Mon Sep 17 00:00:00 2001 From: Sergio Prada Date: Tue, 14 Nov 2023 09:00:16 -0500 Subject: [PATCH] Handle azure long term memory + session tag --- README.md | 3 ++- src/long_term_memory.rs | 2 +- src/models.rs | 59 ++++++++++++++++++++++++++++++++--------- src/redis_utils.rs | 8 ++++-- 4 files changed, 55 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 9976283..cbbc972 100644 --- a/README.md +++ b/README.md @@ -101,9 +101,10 @@ Searches are segmented (filtered) by the session id provided automatically. ### Azure deployment -*NOTE: `MOTORHEAD_LONG_TERM_MEMORY=true` won't work with Azure*. Additional Environment Variables are required for Azure deployments: +Additional Environment Variables are required for Azure deployments: - `AZURE_DEPLOYMENT_ID` +- `AZURE_DEPLOYMENT_ID_ADA` - `AZURE_API_BASE` - `AZURE_API_KEY` diff --git a/src/long_term_memory.rs b/src/long_term_memory.rs index 1a52443..a3f1c9b 100644 --- a/src/long_term_memory.rs +++ b/src/long_term_memory.rs @@ -53,7 +53,7 @@ pub async fn search_messages( let response = openai_client.create_embedding(vec![query]).await?; let embeddings = response[0].clone(); let vector = encode(embeddings); - let query = format!("@session:{}=>[KNN 10 @vector $V AS dist]", session_id); + let query = format!("@session:{{{}}}=>[KNN 10 @vector $V AS dist]", session_id); let values: Vec = redis::cmd("FT.SEARCH") .arg("motorhead") diff --git a/src/models.rs b/src/models.rs index 2b1eed8..130276d 100644 --- a/src/models.rs +++ b/src/models.rs @@ -29,17 +29,36 @@ impl Manager for OpenAIClientManager { let openai_client = match ( env::var("AZURE_API_KEY"), env::var("AZURE_DEPLOYMENT_ID"), + env::var("AZURE_DEPLOYMENT_ID_ADA"), env::var("AZURE_API_BASE"), ) { - (Ok(azure_api_key), Ok(azure_deployment_id), Ok(azure_api_base)) => { + ( + Ok(azure_api_key), + Ok(azure_deployment_id), + Ok(azure_deployment_id_ada), + Ok(azure_api_base), + ) => { let config = AzureConfig::new() - .with_api_base(azure_api_base) - .with_api_key(azure_api_key) + .with_api_base(&azure_api_base) + .with_api_key(&azure_api_key) .with_deployment_id(azure_deployment_id) .with_api_version("2023-05-15"); - AnyOpenAIClient::Azure(Client::with_config(config)) + + let config_ada = AzureConfig::new() + .with_api_base(&azure_api_base) + .with_api_key(&azure_api_key) + .with_deployment_id(azure_deployment_id_ada) + .with_api_version("2023-05-15"); + + AnyOpenAIClient::Azure { + embedding_client: Client::with_config(config_ada), + completion_client: Client::with_config(config), + } } - _ => AnyOpenAIClient::OpenAI(Client::new()), + _ => AnyOpenAIClient::OpenAI { + embedding_client: Client::new(), + completion_client: Client::new(), + }, }; Ok(openai_client) } @@ -50,8 +69,14 @@ impl Manager for OpenAIClientManager { } pub enum AnyOpenAIClient { - Azure(Client), - OpenAI(Client), + Azure { + embedding_client: Client, + completion_client: Client, + }, + OpenAI { + embedding_client: Client, + completion_client: Client, + }, } impl AnyOpenAIClient { @@ -70,8 +95,12 @@ impl AnyOpenAIClient { .build()?; match self { - AnyOpenAIClient::Azure(client) => client.chat().create(request).await, - AnyOpenAIClient::OpenAI(client) => client.chat().create(request).await, + AnyOpenAIClient::Azure { + completion_client, .. + } => completion_client.chat().create(request).await, + AnyOpenAIClient::OpenAI { + completion_client, .. + } => completion_client.chat().create(request).await, } } @@ -80,13 +109,15 @@ impl AnyOpenAIClient { query_vec: Vec, ) -> Result>, OpenAIError> { match self { - AnyOpenAIClient::OpenAI(client) => { + AnyOpenAIClient::OpenAI { + embedding_client, .. + } => { let request = CreateEmbeddingRequestArgs::default() .model("text-embedding-ada-002") .input(query_vec) .build()?; - let response = client.embeddings().create(request).await?; + let response = embedding_client.embeddings().create(request).await?; let embeddings: Vec<_> = response .data .iter() @@ -95,7 +126,9 @@ impl AnyOpenAIClient { Ok(embeddings) } - AnyOpenAIClient::Azure(client) => { + AnyOpenAIClient::Azure { + embedding_client, .. + } => { let tasks: Vec<_> = query_vec .into_iter() .map(|query| async { @@ -104,7 +137,7 @@ impl AnyOpenAIClient { .input(vec![query]) .build()?; - client.embeddings().create(request).await + embedding_client.embeddings().create(request).await }) .collect(); diff --git a/src/redis_utils.rs b/src/redis_utils.rs index cf3b727..2f0b71d 100644 --- a/src/redis_utils.rs +++ b/src/redis_utils.rs @@ -11,7 +11,11 @@ pub fn ensure_redisearch_index( let index_info: Result = redis::cmd("FT.INFO").arg(index_name).query(&mut con); if let Err(err) = index_info { - if err.to_string().to_lowercase().contains("unknown: index name") { + if err + .to_string() + .to_lowercase() + .contains("unknown: index name") + { redis::cmd("FT.CREATE") .arg(index_name) .arg("ON") @@ -21,7 +25,7 @@ pub fn ensure_redisearch_index( .arg("motorhead:") .arg("SCHEMA") .arg("session") - .arg("TEXT") + .arg("TAG") .arg("content") .arg("TEXT") .arg("role")