Skip to content

Commit

Permalink
Handle Azure long-term memory + session tag
Browse files Browse the repository at this point in the history
  • Loading branch information
Czechh committed Nov 14, 2023
1 parent ce4857f commit 4ffc493
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 17 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,10 @@ Searches are segmented (filtered) by the session id provided automatically.

### Azure deployment

*NOTE: `MOTORHEAD_LONG_TERM_MEMORY=true` won't work with Azure*. Additional Environment Variables are required for Azure deployments:
Additional environment variables are required for Azure deployments:

- `AZURE_DEPLOYMENT_ID`
- `AZURE_DEPLOYMENT_ID_ADA`
- `AZURE_API_BASE`
- `AZURE_API_KEY`

Expand Down
2 changes: 1 addition & 1 deletion src/long_term_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub async fn search_messages(
let response = openai_client.create_embedding(vec![query]).await?;
let embeddings = response[0].clone();
let vector = encode(embeddings);
let query = format!("@session:{}=>[KNN 10 @vector $V AS dist]", session_id);
let query = format!("@session:{{{}}}=>[KNN 10 @vector $V AS dist]", session_id);

let values: Vec<Value> = redis::cmd("FT.SEARCH")
.arg("motorhead")
Expand Down
59 changes: 46 additions & 13 deletions src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,36 @@ impl Manager for OpenAIClientManager {
let openai_client = match (
env::var("AZURE_API_KEY"),
env::var("AZURE_DEPLOYMENT_ID"),
env::var("AZURE_DEPLOYMENT_ID_ADA"),
env::var("AZURE_API_BASE"),
) {
(Ok(azure_api_key), Ok(azure_deployment_id), Ok(azure_api_base)) => {
(
Ok(azure_api_key),
Ok(azure_deployment_id),
Ok(azure_deployment_id_ada),
Ok(azure_api_base),
) => {
let config = AzureConfig::new()
.with_api_base(azure_api_base)
.with_api_key(azure_api_key)
.with_api_base(&azure_api_base)
.with_api_key(&azure_api_key)
.with_deployment_id(azure_deployment_id)
.with_api_version("2023-05-15");
AnyOpenAIClient::Azure(Client::with_config(config))

let config_ada = AzureConfig::new()
.with_api_base(&azure_api_base)
.with_api_key(&azure_api_key)
.with_deployment_id(azure_deployment_id_ada)
.with_api_version("2023-05-15");

AnyOpenAIClient::Azure {
embedding_client: Client::with_config(config_ada),
completion_client: Client::with_config(config),
}
}
_ => AnyOpenAIClient::OpenAI(Client::new()),
_ => AnyOpenAIClient::OpenAI {
embedding_client: Client::new(),
completion_client: Client::new(),
},
};
Ok(openai_client)
}
Expand All @@ -50,8 +69,14 @@ impl Manager for OpenAIClientManager {
}

pub enum AnyOpenAIClient {
Azure(Client<AzureConfig>),
OpenAI(Client<OpenAIConfig>),
Azure {
embedding_client: Client<AzureConfig>,
completion_client: Client<AzureConfig>,
},
OpenAI {
embedding_client: Client<OpenAIConfig>,
completion_client: Client<OpenAIConfig>,
},
}

impl AnyOpenAIClient {
Expand All @@ -70,8 +95,12 @@ impl AnyOpenAIClient {
.build()?;

match self {
AnyOpenAIClient::Azure(client) => client.chat().create(request).await,
AnyOpenAIClient::OpenAI(client) => client.chat().create(request).await,
AnyOpenAIClient::Azure {
completion_client, ..
} => completion_client.chat().create(request).await,
AnyOpenAIClient::OpenAI {
completion_client, ..
} => completion_client.chat().create(request).await,
}
}

Expand All @@ -80,13 +109,15 @@ impl AnyOpenAIClient {
query_vec: Vec<String>,
) -> Result<Vec<Vec<f32>>, OpenAIError> {
match self {
AnyOpenAIClient::OpenAI(client) => {
AnyOpenAIClient::OpenAI {
embedding_client, ..
} => {
let request = CreateEmbeddingRequestArgs::default()
.model("text-embedding-ada-002")
.input(query_vec)
.build()?;

let response = client.embeddings().create(request).await?;
let response = embedding_client.embeddings().create(request).await?;
let embeddings: Vec<_> = response
.data
.iter()
Expand All @@ -95,7 +126,9 @@ impl AnyOpenAIClient {

Ok(embeddings)
}
AnyOpenAIClient::Azure(client) => {
AnyOpenAIClient::Azure {
embedding_client, ..
} => {
let tasks: Vec<_> = query_vec
.into_iter()
.map(|query| async {
Expand All @@ -104,7 +137,7 @@ impl AnyOpenAIClient {
.input(vec![query])
.build()?;

client.embeddings().create(request).await
embedding_client.embeddings().create(request).await
})
.collect();

Expand Down
8 changes: 6 additions & 2 deletions src/redis_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ pub fn ensure_redisearch_index(
let index_info: Result<redis::Value, _> = redis::cmd("FT.INFO").arg(index_name).query(&mut con);

if let Err(err) = index_info {
if err.to_string().to_lowercase().contains("unknown: index name") {
if err
.to_string()
.to_lowercase()
.contains("unknown: index name")
{
redis::cmd("FT.CREATE")
.arg(index_name)
.arg("ON")
Expand All @@ -21,7 +25,7 @@ pub fn ensure_redisearch_index(
.arg("motorhead:")
.arg("SCHEMA")
.arg("session")
.arg("TEXT")
.arg("TAG")
.arg("content")
.arg("TEXT")
.arg("role")
Expand Down

0 comments on commit 4ffc493

Please sign in to comment.