
Addressed some of my comments
pamelafox committed Nov 20, 2024
1 parent 9d2dbf1 commit 4ef5622
Showing 5 changed files with 143 additions and 139 deletions.
81 changes: 42 additions & 39 deletions app/backend/prepdocslib/blobmanager.py
@@ -3,15 +3,16 @@
import logging
import os
import re
from typing import List, Optional, Union, NamedTuple, Tuple
from enum import Enum
from typing import List, Optional, Union

import fitz # type: ignore
from azure.core.credentials_async import AsyncTokenCredential
from azure.storage.blob import (
BlobClient,
BlobSasPermissions,
UserDelegationKey,
generate_blob_sas,
BlobClient
generate_blob_sas,
)
from azure.storage.blob.aio import BlobServiceClient, ContainerClient
from PIL import Image, ImageDraw, ImageFont
@@ -21,6 +22,7 @@

logger = logging.getLogger("scripts")


class BlobManager:
"""
Class to manage uploading and deleting blobs containing citation information from a blob storage account
@@ -45,58 +47,60 @@ def __init__(
self.subscriptionId = subscriptionId
self.user_delegation_key: Optional[UserDelegationKey] = None

#async def upload_blob(self, file: File, container_client:ContainerClient) -> Optional[List[str]]:

async def _create_new_blob(self, file: File, container_client:ContainerClient) -> BlobClient:
async def _create_new_blob(self, file: File, container_client: ContainerClient) -> BlobClient:
with open(file.content.name, "rb") as reopened_file:
blob_name = BlobManager.blob_name_from_file_name(file.content.name)
logger.info("Uploading blob for whole file -> %s", blob_name)
return await container_client.upload_blob(blob_name, reopened_file, overwrite=True, metadata=file.metadata)
blob_name = BlobManager.blob_name_from_file_name(file.content.name)
logger.info("Uploading blob for whole file -> %s", blob_name)
return await container_client.upload_blob(blob_name, reopened_file, overwrite=True, metadata=file.metadata)

async def _file_blob_update_needed(self, blob_client: BlobClient, file : File) -> bool:
md5_check : int = 0 # 0 = not done, 1 = positive, 2 = negative
async def _file_blob_update_needed(self, blob_client: BlobClient, file: File) -> bool:
# Get existing blob properties
blob_properties = await blob_client.get_blob_properties()
blob_metadata = blob_properties.metadata

# Check if the md5 values are the same
file_md5 = file.metadata.get('md5')
blob_md5 = blob_metadata.get('md5')

# Remove md5 from file metadata if it matches the blob metadata
if file_md5 and file_md5 != blob_md5:
return True
else:
return False

file_md5 = file.metadata.get("md5")
blob_md5 = blob_metadata.get("md5")

# If the file has an md5 value, check if it is different from the blob
return file_md5 and file_md5 != blob_md5

async def upload_blob(self, file: File) -> Optional[List[str]]:
async with BlobServiceClient(
account_url=self.endpoint, credential=self.credential, max_single_put_size=4 * 1024 * 1024
) as service_client, service_client.get_container_client(self.container) as container_client:
if not await container_client.exists():
await container_client.create_container()

# Re-open and upload the original file
md5_check : int = 0 # 0 = not done, 1 = positive, 2 = negative

# upload the file from local storage to Azure storage

# Re-open and upload the original file if the blob does not exist or the md5 values do not match
class MD5Check(Enum):
NOT_DONE = 0
MATCH = 1
NO_MATCH = 2

md5_check = MD5Check.NOT_DONE

# Upload the file to Azure Storage
# file.url is only None if the file has not been uploaded yet; for data lake files it is already set
if file.url is None:
blob_client = container_client.get_blob_client(file.url)

if not await blob_client.exists():
logger.info("Blob %s does not exist, uploading", file.url)
blob_client = await self._create_new_blob(file, container_client)
else:
if self._blob_update_needed(blob_client, file):
md5_check = 2
logger.info("Blob %s exists but md5 values do not match, updating", file.url)
md5_check = MD5Check.NO_MATCH
# Upload the file with the updated metadata
with open(file.content.name, "rb") as data:
await blob_client.upload_blob(data, overwrite=True, metadata=file.metadata)
else:
md5_check = 1
logger.info("Blob %s exists and md5 values match, skipping upload", file.url)
md5_check = MD5Check.MATCH
file.url = blob_client.url
if md5_check!=1 and self.store_page_images:

if md5_check != MD5Check.MATCH and self.store_page_images:
if os.path.splitext(file.content.name)[1].lower() == ".pdf":
return await self.upload_pdf_blob_images(service_client, container_client, file)
else:
@@ -127,20 +131,19 @@ async def upload_pdf_blob_images(

for i in range(page_count):
blob_name = BlobManager.blob_image_name_from_file_page(file.content.name, i)

blob_client = container_client.get_blob_client(blob_name)
do_upload : bool = True
if await blob_client.exists():
# Get existing blob properties
blob_properties = await blob_client.get_blob_properties()
blob_metadata = blob_properties.metadata

# Check if the md5 values are the same
file_md5 = file.metadata.get('md5')
blob_md5 = blob_metadata.get('md5')
file_md5 = file.metadata.get("md5")
blob_md5 = blob_metadata.get("md5")
if file_md5 == blob_md5:
continue # document already uploaded
continue  # document already uploaded

logger.debug("Converting page %s to image and uploading -> %s", i, blob_name)

doc = fitz.open(file.content.name)
@@ -167,7 +170,7 @@
output = io.BytesIO()
new_img.save(output, format="PNG")
output.seek(0)

await blob_client.upload_blob(data=output, overwrite=True, metadata=file.metadata)
if not self.user_delegation_key:
self.user_delegation_key = await service_client.get_user_delegation_key(start_time, expiry_time)
@@ -181,7 +184,7 @@ async def upload_pdf_blob_images(
permission=BlobSasPermissions(read=True),
expiry=expiry_time,
start=start_time,
)
)
sas_uris.append(f"{blob_client.url}?{sas_token}")

return sas_uris
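
The skip-upload logic in this file keys off an md5 value stored in blob metadata. As a minimal standalone sketch of that comparison (illustrative names only, assuming both sides expose their metadata as a plain dict with an optional "md5" key; this is not the repo's actual API):

from typing import Dict, Optional

def blob_update_needed(file_metadata: Dict[str, str], blob_metadata: Dict[str, str]) -> bool:
    # Re-upload only when the local file carries an md5 hash and it differs
    # from the md5 recorded on the existing blob.
    file_md5: Optional[str] = file_metadata.get("md5")
    blob_md5: Optional[str] = blob_metadata.get("md5")
    return bool(file_md5 and file_md5 != blob_md5)

# Matching hashes mean the upload can be skipped:
blob_update_needed({"md5": "abc"}, {"md5": "abc"})  # False
blob_update_needed({"md5": "abc"}, {"md5": "def"})  # True
blob_update_needed({}, {"md5": "def"})              # False
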
15 changes: 8 additions & 7 deletions app/backend/prepdocslib/filestrategy.py
@@ -1,10 +1,6 @@
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional
from tqdm.asyncio import tqdm

from .blobmanager import BlobManager
from .embeddings import ImageEmbeddings, OpenAIEmbeddings
from .fileprocessor import FileProcessor
@@ -36,6 +32,7 @@ async def parse_file(
]
return sections


class FileStrategy(Strategy):
"""
Strategy for ingesting documents into a search service from files stored either locally or in a data lake storage account
@@ -96,7 +93,9 @@ async def run(self):
blob_image_embeddings: Optional[List[List[float]]] = None
if self.image_embeddings and blob_sas_uris:
blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris)
await search_manager.update_content(sections=sections, file=file, image_embeddings=blob_image_embeddings)
await search_manager.update_content(
sections=sections, file=file, image_embeddings=blob_image_embeddings
)
finally:
if file:
file.close()
@@ -128,7 +127,9 @@ async def process_file(self, file, search_manager):
blob_image_embeddings: Optional[List[List[float]]] = None
if self.image_embeddings and blob_sas_uris:
blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris)
await search_manager.update_content(sections=sections, file=file, image_embeddings=blob_image_embeddings)
await search_manager.update_content(
sections=sections, file=file, image_embeddings=blob_image_embeddings
)
finally:
if file:
file.close()
39 changes: 22 additions & 17 deletions app/backend/prepdocslib/listfilestrategy.py
@@ -1,5 +1,3 @@
from azure.storage.filedatalake import DataLakeServiceClient
from azure.storage.blob import BlobServiceClient
import base64
import hashlib
import logging
@@ -10,12 +8,10 @@
from glob import glob
from typing import IO, AsyncGenerator, Dict, List, Optional, Union

from azure.identity import DefaultAzureCredential

from azure.core.credentials_async import AsyncTokenCredential
from azure.storage.filedatalake.aio import (
DataLakeServiceClient,
)
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from azure.storage.filedatalake.aio import DataLakeServiceClient

logger = logging.getLogger("scripts")

@@ -26,11 +22,17 @@ class File:
This file might contain access control information about which users or groups can access it
"""

def __init__(self, content: IO, acls: Optional[dict[str, list]] = None, url: Optional[str] = None, metadata : Dict[str, str]= None):
def __init__(
self,
content: IO,
acls: Optional[dict[str, list]] = None,
url: Optional[str] = None,
metadata: Dict[str, str] = None,
):
self.content = content
self.acls = acls or {}
self.url = url
self.metadata = metadata
self.metadata = metadata

def filename(self):
return os.path.basename(self.content.name)
@@ -63,11 +65,12 @@ async def list(self) -> AsyncGenerator[File, None]:
async def list_paths(self) -> AsyncGenerator[str, None]:
if False: # pragma: no cover - this is necessary for mypy to type check
yield

def count_docs(self) -> int:
if False: # pragma: no cover - this is necessary for mypy to type check
yield


class LocalListFileStrategy(ListFileStrategy):
"""
Concrete strategy for listing files that are located in a local filesystem
@@ -117,7 +120,6 @@ def check_md5(self, path: str) -> bool:
md5_f.write(existing_hash)

return False


def count_docs(self) -> int:
"""
@@ -135,6 +137,7 @@ def _list_paths_sync(self, path_pattern: str):
else:
yield path


class ADLSGen2ListFileStrategy(ListFileStrategy):
"""
Concrete strategy for listing files that are located in a data lake storage account
@@ -191,9 +194,11 @@ async def list(self) -> AsyncGenerator[File, None]:
if acl_parts[0] == "user" and "r" in acl_parts[2]:
acls["oids"].append(acl_parts[1])
if acl_parts[0] == "group" and "r" in acl_parts[2]:
acls["groups"].append(acl_parts[1])
acls["groups"].append(acl_parts[1])
properties = await file_client.get_file_properties()
yield File(content=open(temp_file_path, "rb"), acls=acls, url=file_client.url, metadata=properties.metadata)
yield File(
content=open(temp_file_path, "rb"), acls=acls, url=file_client.url, metadata=properties.metadata
)
except Exception as data_lake_exception:
logger.error(f"\tGot an error while reading {path} -> {data_lake_exception} --> skipping file")
try:
@@ -205,18 +210,18 @@ def count_docs(self) -> int:
"""
Return the number of blobs in the specified folder within the Azure Blob Storage container.
"""

# Create a BlobServiceClient using account URL and credentials
service_client = BlobServiceClient(
account_url=f"https://{self.data_lake_storage_account}.blob.core.windows.net",
credential=DefaultAzureCredential())
credential=DefaultAzureCredential(),
)

# Get the container client
container_client = service_client.get_container_client(self.data_lake_filesystem)

# Count blobs within the specified folder
if self.data_lake_path != "/":
return sum(1 for _ in container_client.list_blobs(name_starts_with= self.data_lake_path))
return sum(1 for _ in container_client.list_blobs(name_starts_with=self.data_lake_path))
else:
return sum(1 for _ in container_client.list_blobs())

