Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Remove KeyVault Dependency and Initialize clients using Managed identity directly. Also, update BlobTriggerFunction to read blobstream and do document analysis on read chunks. #2

Merged
merged 2 commits into from
Aug 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 59 additions & 23 deletions DocumentVectorPipelineFunctions/BlobTriggerFunction.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using System.ClientModel;
using System.Net;
using Azure;
using Azure.AI.FormRecognizer.DocumentAnalysis;
using Azure.Storage.Blobs;
Expand All @@ -7,7 +9,7 @@
using Microsoft.Extensions.Logging;
using OpenAI.Embeddings;

namespace BlobStorageTriggeredFunction;
namespace DocumentVectorPipelineFunctions;

public class BlobTriggerFunction(
IConfiguration configuration,
Expand All @@ -21,7 +23,11 @@ public class BlobTriggerFunction(
private const string AzureOpenAIModelDeploymentDimensionsName = "AzureOpenAIModelDimensions";
private static readonly int DefaultDimensions = 1536;

private const int MaxBatchSize = 2048;
private const int MaxRetryCount = 100;
private const int RetryDelay = 10 * 1000; // 10 seconds, expressed in milliseconds

private const int MaxBatchSize = 100;
private int embeddingDimensions = DefaultDimensions;

[Function("BlobTriggerFunction")]
public async Task Run([BlobTrigger("documents/{name}", Connection = "AzureBlobStorageAccConnectionString")] BlobClient blobClient)
Expand All @@ -41,20 +47,27 @@ public async Task Run([BlobTrigger("documents/{name}", Connection = "AzureBlobSt

private async Task HandleBlobCreateEventAsync(BlobClient blobClient)
{
var cosmosDBClientWrapper = await CosmosDBClientWrapper.CreateInstance(cosmosClient, this._logger);

this.embeddingDimensions = configuration.GetValue<int>(AzureOpenAIModelDeploymentDimensionsName, DefaultDimensions);
this._logger.LogInformation("Using OpenAI model dimensions: '{embeddingDimensions}'.", this.embeddingDimensions);

this._logger.LogInformation("Analyzing document using DocumentAnalyzerService from blobUri: '{blobUri}' using layout: {layout}", blobClient.Name, "prebuilt-read");
var operation = await documentAnalysisClient.AnalyzeDocumentFromUriAsync(

using MemoryStream memoryStream = new MemoryStream();
await blobClient.DownloadToAsync(memoryStream);
memoryStream.Seek(0, SeekOrigin.Begin);

var operation = await documentAnalysisClient.AnalyzeDocumentAsync(
WaitUntil.Completed,
"prebuilt-read",
blobClient.Uri);
memoryStream);

var result = operation.Value;

this._logger.LogInformation("Extracted content from '{name}', # pages {pageCount}", blobClient.Name, result.Pages.Count);

var cosmosDBClientWrapper = await CosmosDBClientWrapper.CreateInstance(cosmosClient, this._logger);

int totalChunksCount = 0;
var batchChunkTexts = new List<TextChunk>();
var batchChunkTexts = new List<TextChunk>(MaxBatchSize);
amisi01 marked this conversation as resolved.
Show resolved Hide resolved
foreach (var chunk in TextChunker.FixedSizeChunking(result))
{
batchChunkTexts.Add(chunk);
Expand All @@ -63,6 +76,7 @@ private async Task HandleBlobCreateEventAsync(BlobClient blobClient)
if (batchChunkTexts.Count >= MaxBatchSize)
{
await this.ProcessCurrentBatchAsync(blobClient, cosmosDBClientWrapper, batchChunkTexts);
batchChunkTexts.Clear();
}
}

Expand All @@ -72,30 +86,53 @@ private async Task HandleBlobCreateEventAsync(BlobClient blobClient)
await this.ProcessCurrentBatchAsync(blobClient, cosmosDBClientWrapper, batchChunkTexts);
}

this._logger.LogInformation("Created total chunks: '{documentCount}' of document.", totalChunksCount);
this._logger.LogInformation("Finished processing blob {name}, total chunks processed {count}.", blobClient.Name, totalChunksCount);
}

/// <summary>
/// Generates embeddings for one batch of text chunks and upserts the
/// chunk/embedding pairs into Cosmos DB, keyed by the blob's absolute URI.
/// </summary>
/// <param name="blobClient">Client for the source blob; only its URI is read here.</param>
/// <param name="cosmosDBClientWrapper">Wrapper used to upsert the resulting documents.</param>
/// <param name="batchChunkTexts">Chunks in the current batch. The list is not modified by this method.</param>
private async Task ProcessCurrentBatchAsync(BlobClient blobClient, CosmosDBClientWrapper cosmosDBClientWrapper, List<TextChunk> batchChunkTexts)
{
    // Use the List<T>.Count property rather than the LINQ Count() extension.
    this._logger.LogInformation("Generating embeddings for : '{count}'.", batchChunkTexts.Count);
    var embeddings = await this.GenerateEmbeddingsWithRetryAsync(batchChunkTexts);

    // The embedding dimension is held in the embeddingDimensions field and applied inside
    // GenerateEmbeddingsWithRetryAsync, so no per-batch configuration parsing is needed here.
    this._logger.LogInformation("Creating Cosmos DB documents for batch of size {count}", batchChunkTexts.Count);
    await cosmosDBClientWrapper.UpsertDocumentsAsync(blobClient.Uri.AbsoluteUri, batchChunkTexts, embeddings);
}

/// <summary>
/// Requests embeddings for the given chunks, retrying when the service answers
/// with HTTP 429 (Too Many Requests) or 401 (Unauthorized).
/// </summary>
/// <param name="batchChunkTexts">Chunks whose <c>Text</c> is sent to the embedding client.</param>
/// <returns>The embeddings returned by the embedding client for the batch.</returns>
/// <exception cref="Exception">
/// Thrown when the service fails with a non-retryable status, or when
/// <c>MaxRetryCount</c> attempts have all failed with a retryable status.
/// The triggering <see cref="ClientResultException"/> is attached as the inner exception.
/// </exception>
private async Task<EmbeddingCollection> GenerateEmbeddingsWithRetryAsync(IEnumerable<TextChunk> batchChunkTexts)
{
    EmbeddingGenerationOptions embeddingGenerationOptions = new EmbeddingGenerationOptions()
    {
        Dimensions = this.embeddingDimensions
    };

    // Materialize the inputs once so retries neither re-enumerate the source
    // sequence nor allocate a fresh list per attempt.
    List<string> inputs = batchChunkTexts.Select(p => p.Text).ToList();

    ClientResultException? lastError = null;
    for (int attempt = 1; attempt <= MaxRetryCount; attempt++)
    {
        try
        {
            return await embeddingClient.GenerateEmbeddingsAsync(inputs, embeddingGenerationOptions);
        }
        catch (ClientResultException ex) when (ex.Status is ((int)HttpStatusCode.TooManyRequests) or ((int)HttpStatusCode.Unauthorized))
        {
            lastError = ex;

            // No point sleeping after the final attempt; fall through to the throw below.
            if (attempt < MaxRetryCount)
            {
                await Task.Delay(RetryDelay);
            }
        }
        catch (ClientResultException ex)
        {
            // Any other status is treated as non-retryable; keep the original as inner exception.
            throw new Exception($"Failed to generate embeddings with error: {ex}.", ex);
        }
    }

    throw new Exception($"Failed to generate embeddings after retrying for {MaxRetryCount} times.", lastError);
}

private async Task HandleBlobDeleteEventAsync(BlobClient blobClient)
Expand All @@ -106,4 +143,3 @@ private async Task HandleBlobDeleteEventAsync(BlobClient blobClient)
await Task.Delay(1);
}
}

36 changes: 34 additions & 2 deletions DocumentVectorPipelineFunctions/CosmosDBClientWrapper.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
using System.Globalization;
using System.Net;
using System.Text.Json.Serialization;
using Microsoft.Azure.Cosmos;
using Microsoft.Extensions.Logging;
using OpenAI.Embeddings;
using Container = Microsoft.Azure.Cosmos.Container;

namespace BlobStorageTriggeredFunction;
namespace DocumentVectorPipelineFunctions;

internal class CosmosDBClientWrapper
{
Expand All @@ -15,6 +16,8 @@ internal class CosmosDBClientWrapper

private static CosmosDBClientWrapper? instance;

private const int MaxRetryCount = 100;

public static async ValueTask<CosmosDBClientWrapper> CreateInstance(CosmosClient client, ILogger logger)
{
if (instance != null)
Expand Down Expand Up @@ -48,7 +51,7 @@ public async Task UpsertDocumentsAsync(string fileUri, List<TextChunk> chunks, E
ChunkText = chunks[index].Text,
PageNumber = chunks[index].PageNumberIfKnown,
};
upsertTasks.Add(this.container.UpsertItemAsync(documentChunk));
upsertTasks.Add(this.UpsertDocumentWithRetryAsync(documentChunk, CosmosDBClientWrapper.MaxRetryCount));
}

try
Expand All @@ -69,6 +72,35 @@ public async Task UpsertDocumentsAsync(string fileUri, List<TextChunk> chunks, E
}
}

/// <summary>
/// Upserts a single document chunk, retrying while Cosmos DB responds with
/// HTTP 429 (Too Many Requests).
/// </summary>
/// <param name="document">The chunk document to upsert.</param>
/// <param name="maxRetryAttempts">Maximum number of upsert attempts before giving up.</param>
/// <returns>The item response from the successful upsert.</returns>
/// <exception cref="InvalidOperationException">The container has not been initialized.</exception>
/// <exception cref="Exception">All attempts were throttled; non-throttling failures are logged and rethrown unchanged.</exception>
private async Task<ItemResponse<DocumentChunk>> UpsertDocumentWithRetryAsync(DocumentChunk document, int maxRetryAttempts)
{
    if (this.container is null)
    {
        throw new InvalidOperationException("Container is not initialized.");
    }

    // Used when a throttled response carries no usable RetryAfter value; without it
    // the delay would be zero and every retry would fire back-to-back.
    TimeSpan fallbackDelay = TimeSpan.FromSeconds(1);

    int retryCount = 0;
    while (retryCount < maxRetryAttempts)
    {
        try
        {
            return await this.container.UpsertItemAsync(document);
        }
        catch (CosmosException ex) when (ex.StatusCode == HttpStatusCode.TooManyRequests)
        {
            retryCount++;

            TimeSpan delay = ex.RetryAfter is { } suggested && suggested > TimeSpan.Zero
                ? suggested
                : fallbackDelay;
            await Task.Delay(delay);
        }
        catch (Exception ex)
        {
            // Pass the exception object and use a message template so the stack trace
            // and the chunk id reach the logging provider as structured data.
            this.logger.LogError(ex, "An error occurred while upserting document with ID {ChunkId}: {Message}", document.ChunkId, ex.Message);
            throw;
        }
    }

    throw new Exception($"Max retry attempts reached for document with ID {document.ChunkId}. Operation failed.");
}

private CosmosDBClientWrapper(CosmosClient client, ILogger logger)
{
this.client = client;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@

// Taken from https://github.com/Azure/azure-cosmos-dotnet-v3/pull/4332

using System;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Serialization;
using Microsoft.Azure.Cosmos;

namespace Microsoft.Azure.Cosmos;
namespace DocumentVectorPipelineFunctions;

/// <summary>
/// This class provides a default implementation of System.Text.Json Cosmos Linq Serializer.
Expand Down Expand Up @@ -120,11 +119,8 @@ public override Stream ToStream<T>(T input)
}

JsonPropertyNameAttribute? jsonPropertyNameAttribute = memberInfo.GetCustomAttribute<JsonPropertyNameAttribute>(true);
if (jsonPropertyNameAttribute is { } && !string.IsNullOrEmpty(jsonPropertyNameAttribute.Name))
{
return jsonPropertyNameAttribute.Name;
}

return memberInfo.Name;
return jsonPropertyNameAttribute is { } && !string.IsNullOrEmpty(jsonPropertyNameAttribute.Name)
? jsonPropertyNameAttribute.Name
: memberInfo.Name;
}
}
52 changes: 22 additions & 30 deletions DocumentVectorPipelineFunctions/Program.cs
Original file line number Diff line number Diff line change
@@ -1,43 +1,32 @@
using System.ClientModel.Primitives;
using System.Text.Json;
using Azure;
using Azure.AI.FormRecognizer.DocumentAnalysis;
using Azure.AI.OpenAI;
using Azure.Core;
using Azure.Core.Pipeline;
using Azure.Identity;
using BlobStorageTriggeredFunction;
using DocumentVectorPipelineFunctions;
using Microsoft.Azure.Cosmos;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using OpenAI.Embeddings;

const string AzureDocumentIntelligenceEndpointConfigName = "AzureDocumentIntelligenceEndpoint";
const string AzureDocumentIntelligenceApiKeyConfigName = "AzureDocumentIntelligenceApiKey";
const string AzureCosmosDBConnectionStringConfigName = "AzureCosmosDBConnectionString";
const string AzureOpenAIEndpointConfigName = "AzureOpenAIEndpoint";
const string AzureOpenAIApiKeyConfigName = "AzureOpenAIApiKey";
const string AzureDocumentIntelligenceEndpointConfigName = "AzureDocumentIntelligenceConnectionString";
const string AzureCosmosDBConnectionString = "AzureCosmosDBConnectionString";
const string AzureOpenAIConnectionString = "AzureOpenAIConnectionString";
const string AzureOpenAIModelDeploymentConfigName = "AzureOpenAIModelDeployment";

string? keyVaultUri = Environment.GetEnvironmentVariable("AzureKeyVaultEndpoint");
if (string.IsNullOrWhiteSpace(keyVaultUri))
{
throw new InvalidOperationException("Set environment variable 'AzureKeyVaultEndpoint' to run.");
}

string? managedIdentityClientId = Environment.GetEnvironmentVariable("AzureManagedIdentityClientId");
bool local = Convert.ToBoolean(Environment.GetEnvironmentVariable("RunningLocally"));

TokenCredential credential = local
? new DefaultAzureCredential()
: new ManagedIdentityCredential(clientId: managedIdentityClientId);

var hostBuilder = new HostBuilder()
.ConfigureFunctionsWorkerDefaults()
.ConfigureAppConfiguration(config =>
{
TokenCredential credential = local
? new DefaultAzureCredential()
: new ManagedIdentityCredential(clientId: managedIdentityClientId);

config.AddAzureKeyVault(new Uri(keyVaultUri), credential);
config.AddUserSecrets<BlobTriggerFunction>(optional: true, reloadOnChange: false);
});

Expand All @@ -46,16 +35,19 @@
sc.AddSingleton<DocumentAnalysisClient>(sp =>
{
var config = sp.GetRequiredService<IConfiguration>();
var documentIntelligenceEndpoint = config[AzureDocumentIntelligenceEndpointConfigName] ?? throw new Exception($"Configure {AzureDocumentIntelligenceEndpointConfigName}");
var documentAnalysisClient = new DocumentAnalysisClient(
new Uri(config[AzureDocumentIntelligenceEndpointConfigName] ?? throw new Exception($"Configure {AzureDocumentIntelligenceEndpointConfigName}")),
new AzureKeyCredential(config[AzureDocumentIntelligenceApiKeyConfigName] ?? throw new Exception($"Configure {AzureDocumentIntelligenceApiKeyConfigName}")));
new Uri(documentIntelligenceEndpoint),
credential);
return documentAnalysisClient;
});
sc.AddSingleton<CosmosClient>(sp =>
{
var config = sp.GetRequiredService<IConfiguration>();
var cosmosdbEndpoint = config[AzureCosmosDBConnectionString] ?? throw new Exception($"Configure {AzureCosmosDBConnectionString}");
var cosmosClient = new CosmosClient(
config[AzureCosmosDBConnectionStringConfigName] ?? throw new Exception($"Configure {AzureCosmosDBConnectionStringConfigName}"),
cosmosdbEndpoint,
credential,
new CosmosClientOptions
{
ApplicationName = "document ingestion",
Expand All @@ -67,16 +59,16 @@
// Registers the embedding client. The Azure OpenAI endpoint and the model deployment
// name are read from configuration; authentication uses the shared token credential.
sc.AddSingleton<EmbeddingClient>(sp =>
{
    var config = sp.GetRequiredService<IConfiguration>();

    // The failure message must name the OpenAI setting that is actually missing.
    var openAIEndpoint = config[AzureOpenAIConnectionString] ?? throw new Exception($"Configure {AzureOpenAIConnectionString}");

    // TODO: Implement a custom retry policy that takes the retry-after header into account.
    var azureOpenAIClient = new AzureOpenAIClient(
        new Uri(openAIEndpoint),
        credential,
        new AzureOpenAIClientOptions()
        {
            ApplicationId = "DocumentIngestion",
            RetryPolicy = new ClientRetryPolicy(maxRetries: 10),
        });
    return azureOpenAIClient.GetEmbeddingClient(config[AzureOpenAIModelDeploymentConfigName] ?? throw new Exception($"Configure {AzureOpenAIModelDeploymentConfigName}"));
});
});
Expand Down
5 changes: 2 additions & 3 deletions DocumentVectorPipelineFunctions/TextChunker.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
using System.Text;
using System.Text.RegularExpressions;
using Azure.AI.FormRecognizer.DocumentAnalysis;

namespace BlobStorageTriggeredFunction;
namespace DocumentVectorPipelineFunctions;

record struct TextChunk(
internal record struct TextChunk(
string Text,
int PageNumberIfKnown,
int ChunkNumber);
Expand Down
Loading