using azure-identity to support additional auth method (#400)
Jack-Q authored Aug 30, 2024
1 parent 18e8f8b commit 49cf816
Showing 10 changed files with 1,291 additions and 1,066 deletions.
4 changes: 2 additions & 2 deletions auto_eval/ds1000_scripts/requirements.txt
@@ -5,9 +5,9 @@ pandas==1.5.3
pytorch::cpuonly
pytorch::pytorch==2.2.0
seaborn==0.13.2
-scikit-learn==1.4.0
+scikit-learn==1.5.0
scipy==1.12.0
statsmodels==0.14.1
xgboost==2.0.3
-tensorflow==2.11.1
+tensorflow==2.12.1
yaml
28 changes: 23 additions & 5 deletions taskweaver/llm/openai.py
@@ -2,7 +2,7 @@

import os
import sys
-from typing import TYPE_CHECKING, Any, Generator, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional

from injector import inject

@@ -50,7 +50,7 @@ def _configure(self) -> None:
self.response_format = self.llm_module_config.response_format

# openai specific config
-        self.api_version = self._get_str("api_version", "2023-12-01-preview")
+        self.api_version = self._get_str("api_version", "2024-06-01")
self.api_auth_type = self._get_enum(
"api_auth_type",
["openai", "azure", "azure_ad"],
@@ -59,7 +59,7 @@ def _configure(self) -> None:
is_azure_ad_login = self.api_type == "azure_ad"
self.aad_auth_mode = self._get_enum(
"aad_auth_mode",
-            ["device_login", "aad_app"],
+            ["device_login", "aad_app", "default_azure_credential"],
None if is_azure_ad_login else "device_login",
)

@@ -147,7 +147,7 @@ def client(self):
client = AzureOpenAI(
api_version=self.config.api_version,
azure_endpoint=self.config.api_base,
-                azure_ad_token_provider=lambda: self._get_aad_token(),
+                azure_ad_token_provider=self._get_aad_token_provider(),
)
else:
raise Exception(f"Invalid API type: {self.api_type}")
@@ -297,7 +297,25 @@ def get_embeddings(self, strings: List[str]) -> List[List[float]]:
).data
return [r.embedding for r in embedding_results]

-    def _get_aad_token(self) -> str:
+    def _get_aad_token_provider(self) -> Callable[[], str]:
if self.config.aad_auth_mode == "default_azure_credential":
return self._get_aad_token_provider_azure_identity()
return lambda: self._get_aad_token_msal()

def _get_aad_token_provider_azure_identity(self) -> Callable[[], str]:
try:
from azure.identity import DefaultAzureCredential, get_bearer_token_provider # type: ignore
except ImportError:
raise Exception(
"AAD authentication requires azure-identity module to be installed, "
"please run `pip install azure-identity`",
)
credential = DefaultAzureCredential(exclude_interactive_browser_credential=False)
print("Using DefaultAzureCredential for AAD authentication")
scope = f"{self.config.aad_api_resource}/{self.config.aad_api_scope}"
return get_bearer_token_provider(credential, scope)

def _get_aad_token_msal(self) -> str:
try:
import msal # type: ignore
except ImportError:
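The dispatch introduced above — pick `azure-identity` when `aad_auth_mode` is `default_azure_credential`, otherwise fall back to the MSAL flow — can be sketched in isolation. This is a minimal stand-in with dummy token functions in place of the real `azure-identity` and `msal` calls:

```python
from typing import Callable


def make_token_provider(aad_auth_mode: str) -> Callable[[], str]:
    """Return a zero-argument callable that yields a fresh AAD token.

    Dummy stand-in for the commit's _get_aad_token_provider: the real code
    returns get_bearer_token_provider(...) for "default_azure_credential"
    and wraps the MSAL device/app flow in a lambda otherwise.
    """
    if aad_auth_mode == "default_azure_credential":
        return lambda: "token-from-azure-identity"  # placeholder token
    return lambda: "token-from-msal"  # placeholder token


provider = make_token_provider("default_azure_credential")
```

The key design point is that the OpenAI client receives a callable rather than a token string, so a fresh token can be fetched lazily on every request instead of being captured once at client construction.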
4 changes: 4 additions & 0 deletions website/blog/evaluation.md
@@ -32,9 +32,13 @@ However, both agents provide the correct answer to the question.
But if the evaluation method takes the agent as a function, it may not be able to handle the different behaviors of the agents
and may judge Agent 2 as incorrect (as its first response does not match the ground truth, e.g., "sunny").


## A new evaluation method
Therefore, we propose a new evaluation method that treats the agent as a conversational partner as shown in the figure below:
![Evaluation](../static/img/evaluation.png)

<!-- truncate -->

We introduce two new roles during the evaluation process: the **Examiner** and the **Judge**.
For each test case, the task description is first given to the Examiner.
The Examiner then asks questions to the agent and supervises the conversation.
4 changes: 3 additions & 1 deletion website/blog/local_llm.md
@@ -4,7 +4,7 @@
The feature introduced in this blog post can cause incompatibility issues with previous versions of TaskWeaver
if you have customized the examples for the planner and code interpreter.
The issue is easy to fix: update the examples to the new schema.
-Please refer to the [How we implemented the constrained generation in TaskWeaver](#how-we-implemented-the-constrained-generation-in-taskweaver) section for more details.
+Please refer to the [How we implemented the constrained generation in TaskWeaver](/blog/local_llm#how-we-implemented-the-constrained-generation-in-taskweaver) section for more details.
:::

## Motivation
@@ -21,6 +21,8 @@ was to ask the model to re-generate the response if it does not follow the format
We include the format error in the prompt to help the model understand the error and
correct it. However, this approach also did not work well.

<!-- truncate -->

## Constrained Generation

Recently, we discovered a new approach called "Constrained Generation" that can enforce
2 changes: 2 additions & 0 deletions website/blog/plugin.md
@@ -50,6 +50,8 @@ We did not check the discrepancy between the function signature in the Python im
So, it is important to keep them consistent.
The `examples` field is used to provide examples of how to use the plugin for the LLM.

<!-- truncate -->

## Configurations and States

Although the plugin is used as a function in the code snippets, it is more than a normal Python function.
2 changes: 2 additions & 0 deletions website/blog/role.md
@@ -29,6 +29,8 @@ flowchart TD
B --response--> A
```

<!-- truncate -->

However, we do find challenges for other tasks that are not naturally represented in code snippets.
Let's consider another example: _the agent is asked to read a manual and follow the instructions to process the data_.
We first assume there is a plugin that can read the manual and extract the instructions, called `read_manual`.
61 changes: 55 additions & 6 deletions website/docs/llms/aoai.md
@@ -3,15 +3,18 @@ description: Using LLMs from OpenAI/AOAI
---
# Azure OpenAI

## Using API Key

1. Create an account on [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service) and get your API key.
-2. Add the following to your `taskweaver_config.json` file:
-```json showLineNumbers
+2. Create a new deployment of the model and get the deployment name.
+3. Add the following to your `taskweaver_config.json` file:
+```jsonc showLineNumbers
{
-    "llm.api_base":"YOUR_AOAI_ENDPOINT",
+    "llm.api_base":"YOUR_AOAI_ENDPOINT", // in the format of https://<my-resource>.openai.azure.com
"llm.api_key":"YOUR_API_KEY",
"llm.api_type":"azure",
"llm.auth_mode":"api-key",
-    "llm.model":"gpt-4-1106-preview", # this is known as deployment_name in Azure OpenAI
+    "llm.model":"gpt-4-1106-preview", // this is known as deployment_name in Azure OpenAI
"llm.response_format": "json_object"
}
```
@@ -21,5 +24,51 @@ For model versions or after `1106`, `llm.response_format` can be set to `json_ob
However, for the earlier models, which do not support JSON response explicitly, `llm.response_format` should be set to `null`.
:::

-3. Start TaskWeaver and chat with TaskWeaver.
-You can refer to the [Quick Start](../quickstart.md) for more details.
+4. Start TaskWeaver and chat with TaskWeaver.
+   You can refer to the [Quick Start](../quickstart.md) for more details.

## Using Entra Authentication

1. Create an account on [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service) and
[assign the proper Azure RBAC Role](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/role-based-access-control) to your account (or service principal).
2. Create a new deployment of the model and get the deployment name.
3. Add the following to your `taskweaver_config.json` file:
```jsonc showLineNumbers
{
    "llm.api_base":"YOUR_AOAI_ENDPOINT", // in the format of https://<my-resource>.openai.azure.com
"llm.api_type":"azure_ad",
"llm.auth_mode":"default_azure_credential",
"llm.model":"gpt-4-1106-preview", // this is known as deployment_name in Azure OpenAI
"llm.response_format": "json_object"
}
```
4. Install extra dependencies:
```bash
pip install azure-identity
```
5. Optionally, configure additional environment variables or dependencies for the chosen authentication method:

Internally, authentication is handled by the `DefaultAzureCredential` class from the `azure-identity` package. It tries a series of authentication methods in turn, depending on what is available in the current running environment (such as environment variables, managed identity, etc.). You can refer to the [official documentation](https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.defaultazurecredential?view=azure-python) for more details.

For example, you can use any of the following authentication methods:
1. Authenticating with the Azure CLI (recommended for local development):

Install the Azure CLI and make sure `az` is available in your `PATH`. Then run the following command to log in:
```bash
az login
```

2. Authenticating with Managed Identity (recommended for Azure environment):

If you are running TaskWeaver on Azure, you can use a managed identity for authentication. Refer to the documentation of the specific Azure service for how to enable managed identity.

When using a user-assigned managed identity, set the following environment variable to specify the client ID of the managed identity:
```bash
export AZURE_CLIENT_ID="YOUR_CLIENT_ID"
```

3. Authenticating with a Service Principal:

Follow the [official documentation](https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.environmentcredential?view=azure-python) to set the environment variables for Service Principal authentication.
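A minimal sketch with placeholder values — these are the standard variables read by `EnvironmentCredential` for client-secret authentication:

```shell
# Service Principal credentials picked up by azure-identity's
# EnvironmentCredential (part of the DefaultAzureCredential chain)
export AZURE_TENANT_ID="YOUR_TENANT_ID"
export AZURE_CLIENT_ID="YOUR_CLIENT_ID"
export AZURE_CLIENT_SECRET="YOUR_CLIENT_SECRET"
```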

6. Start TaskWeaver and chat with TaskWeaver.
52 changes: 26 additions & 26 deletions website/docusaurus.config.js
@@ -4,7 +4,7 @@
// There are various equivalent ways to declare your Docusaurus config.
// See: https://docusaurus.io/docs/api/docusaurus-config

-import {themes as prismThemes} from 'prism-react-renderer';
+import { themes as prismThemes } from 'prism-react-renderer';

/** @type {import('@docusaurus/types').Config} */
const config = {
@@ -147,30 +147,30 @@ const config = {
additionalLanguages: ['bash', 'json', 'yaml'],
},
}),
-  themes: [
-    [
-      require.resolve("@easyops-cn/docusaurus-search-local"),
-      /** @type {import("@easyops-cn/docusaurus-search-local").PluginOptions} */
-      ({
-        hashed: true,
-        docsRouteBasePath: "docs",
-        blogRouteBasePath: "blog",
-        docsDir: "docs",
-        blogDir: "blog",
-        searchContextByPaths: [
-          {
-            label: "Documents",
-            path: "docs",
-          },
-          {
-            label: "Blog",
-            path: "blog",
-          },
-        ],
-        hideSearchBarWithNoSearchContext: true,
-      }),
-    ],
-    '@docusaurus/theme-mermaid'
-  ],
-};
+  themes: [
+    [
+      require.resolve("@easyops-cn/docusaurus-search-local"),
+      /** @type {import("@easyops-cn/docusaurus-search-local").PluginOptions} */
+      {
+        hashed: true,
+        docsRouteBasePath: "docs",
+        blogRouteBasePath: "blog",
+        docsDir: "docs",
+        blogDir: "blog",
+        searchContextByPaths: [
+          {
+            label: "Documents",
+            path: "docs",
+          },
+          {
+            label: "Blog",
+            path: "blog",
+          },
+        ],
+        hideSearchBarWithNoSearchContext: true,
+      },
+    ],
+    '@docusaurus/theme-mermaid'
+  ],
+};
export default config;