diff --git a/app/backend/prepdocslib/blobmanager.py b/app/backend/prepdocslib/blobmanager.py index 78b9d5b0b9..eefbad5efe 100644 --- a/app/backend/prepdocslib/blobmanager.py +++ b/app/backend/prepdocslib/blobmanager.py @@ -3,15 +3,16 @@ import logging import os import re -from typing import List, Optional, Union, NamedTuple, Tuple +from enum import Enum +from typing import List, Optional, Union import fitz # type: ignore from azure.core.credentials_async import AsyncTokenCredential from azure.storage.blob import ( + BlobClient, BlobSasPermissions, UserDelegationKey, - generate_blob_sas, - BlobClient + generate_blob_sas, ) from azure.storage.blob.aio import BlobServiceClient, ContainerClient from PIL import Image, ImageDraw, ImageFont @@ -21,6 +22,7 @@ logger = logging.getLogger("scripts") + class BlobManager: """ Class to manage uploading and deleting blobs containing citation information from a blob storage account @@ -45,58 +47,60 @@ def __init__( self.subscriptionId = subscriptionId self.user_delegation_key: Optional[UserDelegationKey] = None - #async def upload_blob(self, file: File, container_client:ContainerClient) -> Optional[List[str]]: - - async def _create_new_blob(self, file: File, container_client:ContainerClient) -> BlobClient: + async def _create_new_blob(self, file: File, container_client: ContainerClient) -> BlobClient: with open(file.content.name, "rb") as reopened_file: - blob_name = BlobManager.blob_name_from_file_name(file.content.name) - logger.info("Uploading blob for whole file -> %s", blob_name) - return await container_client.upload_blob(blob_name, reopened_file, overwrite=True, metadata=file.metadata) + blob_name = BlobManager.blob_name_from_file_name(file.content.name) + logger.info("Uploading blob for whole file -> %s", blob_name) + return await container_client.upload_blob(blob_name, reopened_file, overwrite=True, metadata=file.metadata) - async def _file_blob_update_needed(self, blob_client: BlobClient, file : File) -> bool: - md5_check : int = 0 # 0= not done, 1, positive,. 2 negative + async def _file_blob_update_needed(self, blob_client: BlobClient, file: File) -> bool: # Get existing blob properties blob_properties = await blob_client.get_blob_properties() blob_metadata = blob_properties.metadata - + # Check if the md5 values are the same - file_md5 = file.metadata.get('md5') - blob_md5 = blob_metadata.get('md5') - - # Remove md5 from file metadata if it matches the blob metadata - if file_md5 and file_md5 != blob_md5: - return True - else: - return False - + file_md5 = file.metadata.get("md5") + blob_md5 = blob_metadata.get("md5") + + # If the file has an md5 value, check if it is different from the blob + return file_md5 and file_md5 != blob_md5 + async def upload_blob(self, file: File) -> Optional[List[str]]: async with BlobServiceClient( account_url=self.endpoint, credential=self.credential, max_single_put_size=4 * 1024 * 1024 ) as service_client, service_client.get_container_client(self.container) as container_client: if not await container_client.exists(): await container_client.create_container() - - # Re-open and upload the original file - md5_check : int = 0 # 0= not done, 1, positive,. 
2 negative - - # upload the file local storage zu azure storage + + # Re-open and upload the original file if the blob does not exist or the md5 values do not match + class MD5Check(Enum): + NOT_DONE = 0 + MATCH = 1 + NO_MATCH = 2 + + md5_check = MD5Check.NOT_DONE + + # Upload the file to Azure Storage # file.url is only None if files are not uploaded yet, for datalake it is set if file.url is None: blob_client = container_client.get_blob_client(file.url) if not await blob_client.exists(): + logger.info("Blob %s does not exist, uploading", file.url) blob_client = await self._create_new_blob(file, container_client) else: if self._blob_update_needed(blob_client, file): - md5_check = 2 + logger.info("Blob %s exists but md5 values do not match, updating", file.url) + md5_check = MD5Check.NO_MATCH # Upload the file with the updated metadata with open(file.content.name, "rb") as data: await blob_client.upload_blob(data, overwrite=True, metadata=file.metadata) else: - md5_check = 1 + logger.info("Blob %s exists and md5 values match, skipping upload", file.url) + md5_check = MD5Check.MATCH file.url = blob_client.url - - if md5_check!=1 and self.store_page_images: + + if md5_check != MD5Check.MATCH and self.store_page_images: if os.path.splitext(file.content.name)[1].lower() == ".pdf": return await self.upload_pdf_blob_images(service_client, container_client, file) else: @@ -127,20 +131,19 @@ async def upload_pdf_blob_images( for i in range(page_count): blob_name = BlobManager.blob_image_name_from_file_page(file.content.name, i) - + blob_client = container_client.get_blob_client(blob_name) - do_upload : bool = True if await blob_client.exists(): # Get existing blob properties blob_properties = await blob_client.get_blob_properties() blob_metadata = blob_properties.metadata - + # Check if the md5 values are the same - file_md5 = file.metadata.get('md5') - blob_md5 = blob_metadata.get('md5') + file_md5 = file.metadata.get("md5") + blob_md5 = blob_metadata.get("md5") if file_md5 == blob_md5: - continue # documemt already uploaded - + continue # documemt already uploaded + logger.debug("Converting page %s to image and uploading -> %s", i, blob_name) doc = fitz.open(file.content.name) @@ -167,7 +170,7 @@ async def upload_pdf_blob_images( output = io.BytesIO() new_img.save(output, format="PNG") output.seek(0) - + await blob_client.upload_blob(data=output, overwrite=True, metadata=file.metadata) if not self.user_delegation_key: self.user_delegation_key = await service_client.get_user_delegation_key(start_time, expiry_time) @@ -181,7 +184,7 @@ async def upload_pdf_blob_images( permission=BlobSasPermissions(read=True), expiry=expiry_time, start=start_time, - ) + ) sas_uris.append(f"{blob_client.url}?{sas_token}") return sas_uris diff --git a/app/backend/prepdocslib/filestrategy.py b/app/backend/prepdocslib/filestrategy.py index 69183f3220..fbef639ff2 100644 --- a/app/backend/prepdocslib/filestrategy.py +++ b/app/backend/prepdocslib/filestrategy.py @@ -1,10 +1,6 @@ import logging -import asyncio -from concurrent.futures import ThreadPoolExecutor from typing import List, Optional -from concurrent.futures import ThreadPoolExecutor -from typing import List, Optional -from tqdm.asyncio import tqdm + from .blobmanager import BlobManager from .embeddings import ImageEmbeddings, OpenAIEmbeddings from .fileprocessor import FileProcessor @@ -36,6 +32,7 @@ async def parse_file( ] return sections + class FileStrategy(Strategy): """ Strategy for ingesting documents into a search service from files stored either locally 
or in a data lake storage account @@ -96,7 +93,9 @@ async def run(self): blob_image_embeddings: Optional[List[List[float]]] = None if self.image_embeddings and blob_sas_uris: blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris) - await search_manager.update_content(sections=sections, file=file, image_embeddings=blob_image_embeddings) + await search_manager.update_content( + sections=sections, file=file, image_embeddings=blob_image_embeddings + ) finally: if file: file.close() @@ -128,7 +127,9 @@ async def process_file(self, file, search_manager): blob_image_embeddings: Optional[List[List[float]]] = None if self.image_embeddings and blob_sas_uris: blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris) - await search_manager.update_content(sections=sections, file=file, image_embeddings=blob_image_embeddings) + await search_manager.update_content( + sections=sections, file=file, image_embeddings=blob_image_embeddings + ) finally: if file: file.close() diff --git a/app/backend/prepdocslib/listfilestrategy.py b/app/backend/prepdocslib/listfilestrategy.py index 8d232e0b54..eed61d452e 100644 --- a/app/backend/prepdocslib/listfilestrategy.py +++ b/app/backend/prepdocslib/listfilestrategy.py @@ -1,5 +1,3 @@ -from azure.storage.filedatalake import DataLakeServiceClient -from azure.storage.blob import BlobServiceClient import base64 import hashlib import logging @@ -10,12 +8,10 @@ from glob import glob from typing import IO, AsyncGenerator, Dict, List, Optional, Union -from azure.identity import DefaultAzureCredential - from azure.core.credentials_async import AsyncTokenCredential -from azure.storage.filedatalake.aio import ( - DataLakeServiceClient, -) +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient +from azure.storage.filedatalake.aio import DataLakeServiceClient logger = logging.getLogger("scripts") @@ -26,11 +22,17 @@ class File: This file might contain access control information about which users or groups can access it """ - def __init__(self, content: IO, acls: Optional[dict[str, list]] = None, url: Optional[str] = None, metadata : Dict[str, str]= None): + def __init__( + self, + content: IO, + acls: Optional[dict[str, list]] = None, + url: Optional[str] = None, + metadata: Dict[str, str] = None, + ): self.content = content self.acls = acls or {} self.url = url - self.metadata = metadata + self.metadata = metadata def filename(self): return os.path.basename(self.content.name) @@ -63,11 +65,12 @@ async def list(self) -> AsyncGenerator[File, None]: async def list_paths(self) -> AsyncGenerator[str, None]: if False: # pragma: no cover - this is necessary for mypy to type check yield - + def count_docs(self) -> int: if False: # pragma: no cover - this is necessary for mypy to type check yield + class LocalListFileStrategy(ListFileStrategy): """ Concrete strategy for listing files that are located in a local filesystem @@ -117,7 +120,6 @@ def check_md5(self, path: str) -> bool: md5_f.write(existing_hash) return False - def count_docs(self) -> int: """ @@ -135,6 +137,7 @@ def _list_paths_sync(self, path_pattern: str): else: yield path + class ADLSGen2ListFileStrategy(ListFileStrategy): """ Concrete strategy for listing files that are located in a data lake storage account @@ -191,9 +194,11 @@ async def list(self) -> AsyncGenerator[File, None]: if acl_parts[0] == "user" and "r" in acl_parts[2]: acls["oids"].append(acl_parts[1]) if acl_parts[0] == "group" and "r" in acl_parts[2]: - 
acls["groups"].append(acl_parts[1]) + acls["groups"].append(acl_parts[1]) properties = await file_client.get_file_properties() - yield File(content=open(temp_file_path, "rb"), acls=acls, url=file_client.url, metadata=properties.metadata) + yield File( + content=open(temp_file_path, "rb"), acls=acls, url=file_client.url, metadata=properties.metadata + ) except Exception as data_lake_exception: logger.error(f"\tGot an error while reading {path} -> {data_lake_exception} --> skipping file") try: @@ -205,18 +210,18 @@ def count_docs(self) -> int: """ Return the number of blobs in the specified folder within the Azure Blob Storage container. """ - + # Create a BlobServiceClient using account URL and credentials service_client = BlobServiceClient( account_url=f"https://{self.data_lake_storage_account}.blob.core.windows.net", - credential=DefaultAzureCredential()) + credential=DefaultAzureCredential(), + ) # Get the container client container_client = service_client.get_container_client(self.data_lake_filesystem) # Count blobs within the specified folder if self.data_lake_path != "/": - return sum(1 for _ in container_client.list_blobs(name_starts_with= self.data_lake_path)) + return sum(1 for _ in container_client.list_blobs(name_starts_with=self.data_lake_path)) else: return sum(1 for _ in container_client.list_blobs()) - diff --git a/app/backend/prepdocslib/searchmanager.py b/app/backend/prepdocslib/searchmanager.py index 15f7d836d8..47945786a5 100644 --- a/app/backend/prepdocslib/searchmanager.py +++ b/app/backend/prepdocslib/searchmanager.py @@ -1,10 +1,10 @@ import asyncio import datetime -import dateutil.parser as parser import logging import os -from typing import Dict, List, Optional +from typing import List, Optional +import dateutil.parser as parser from azure.search.documents.indexes.models import ( AzureOpenAIVectorizer, AzureOpenAIVectorizerParameters, @@ -102,22 +102,10 @@ async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]] vector_search_dimensions=self.embedding_dimensions, vector_search_profile_name="embedding_config", ), - SimpleField(name="category", - type="Edm.String", - filterable=True, - facetable=True), - SimpleField(name="md5", - type="Edm.String", - filterable=True, - facetable=True), - SimpleField(name="deeplink", - type="Edm.String", - filterable=True, - facetable=False), - SimpleField(name="updated", - type="Edm.DateTimeOffset", - filterable=True, - facetable=True), + SimpleField(name="category", type="Edm.String", filterable=True, facetable=True), + SimpleField(name="md5", type="Edm.String", filterable=True, facetable=True), + SimpleField(name="deeplink", type="Edm.String", filterable=True, facetable=False), + SimpleField(name="updated", type="Edm.DateTimeOffset", filterable=True, facetable=True), SimpleField( name="sourcepage", type="Edm.String", @@ -239,7 +227,11 @@ async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]] # Check and add missing fields missing_fields = [field for field in fields if field.name not in existing_field_names] if missing_fields: - logger.info("Adding missing fields to index %s: %s", self.search_info.index_name, [field.name for field in missing_fields]) + logger.info( + "Adding missing fields to index %s: %s", + self.search_info.index_name, + [field.name for field in missing_fields], + ) existing_index.fields.extend(missing_fields) await search_index_client.create_or_update_index(existing_index) @@ -266,17 +258,17 @@ async def create_index(self, vectorizers: 
Optional[List[VectorSearchVectorizer]]
             self.search_info,
         )
 
-    async def file_exists(self, file : File ) -> bool:
+    async def file_exists(self, file: File) -> bool:
         async with self.search_info.create_search_client() as search_client:
             ## make sure that we don't update unchanged sections, by if sourcefile and md5 are the same
-            if file.metadata.get('md5')!= None:
+            if file.metadata.get("md5") is not None:
                 filter = None
-                assert file.filename() is not None    
+                assert file.filename() is not None
                 filter = f"sourcefile eq '{str(file.filename())}' and md5 eq '{file.metadata.get('md5')}'"
-                
+
                 # make sure (when applicable) that we don't skip if different categories have same file.filename()
-                #TODO: refactoring: check if using file.filename() as primary for blob is a good idea, or better use sha256(instead as md5) as reliable for blob and index primary key
-                if file.metadata.get('category') is not None:
+                # TODO: refactoring: check if using file.filename() as the primary key for the blob is a good idea, or better use sha256 (instead of md5) as a reliable primary key for blob and index
+                if file.metadata.get("category") is not None:
                     filter = filter + f" and category eq '{file.metadata.get('category')}'"
                 max_results = 1
                 result = await search_client.search(
@@ -285,32 +277,35 @@ async def file_exists(self, file : File ) -> bool:
                 result_count = await result.get_count()
                 if result_count > 0:
                     logger.debug("Skipping %s, no changes detected.", file.filename())
-                    return True    
+                    return True
                 else:
                     return False
     ## -- end of check
 
     async def update_content(
-        self, sections: List[Section], file : File ,image_embeddings: Optional[List[List[float]]] = None):
+        self, sections: List[Section], file: File, image_embeddings: Optional[List[List[float]]] = None
+    ):
         MAX_BATCH_SIZE = 1000
         section_batches = [sections[i : i + MAX_BATCH_SIZE] for i in range(0, len(sections), MAX_BATCH_SIZE)]
 
         async with self.search_info.create_search_client() as search_client:
-            
+
             ## caluclate a (default) updated timestamp in format of index
-            if file.metadata.get('updated') is None:
-                docdate = datetime.now(datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'
+            if file.metadata.get("updated") is None:
+                docdate = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
             else:
-                docdate = parser.isoparse(file.metadata.get('updated')).strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'
-            
+                docdate = parser.isoparse(file.metadata.get("updated")).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
+
             for batch_index, batch in enumerate(section_batches):
                 documents = [
                     {
                         "id": f"{section.content.filename_to_id()}-page-{section_index + batch_index * MAX_BATCH_SIZE}",
                         "content": section.split_page.text,
-                        "category": file.metadata.get('category'),
-                        "md5": file.metadata.get('md5'),
-                        "deeplink": file.metadata.get('deeplink'), # optional deel link original doc source for citiation,inline view
+                        "category": file.metadata.get("category"),
+                        "md5": file.metadata.get("md5"),
+                        "deeplink": file.metadata.get(
+                            "deeplink"
+                        ),  # optional deep link to the original doc source, for citation / inline view
                         "updated": docdate,
                         "sourcepage": (
                             BlobManager.blob_image_name_from_file_page(
diff --git a/scripts/adlsgen2setup.py b/scripts/adlsgen2setup.py
index 90a12eca86..10a5cd59a0 100644
--- a/scripts/adlsgen2setup.py
+++ b/scripts/adlsgen2setup.py
@@ -1,10 +1,10 @@
 import argparse
 import asyncio
-from datetime import datetime
+import hashlib
 import json
 import logging
 import os
-import hashlib
+from datetime import datetime
 from typing import Any, Optional
 
 import aiohttp
@@ 
-20,7 +20,8 @@ logger = logging.getLogger("scripts") # Set the logging level for the azure package to DEBUG logging.getLogger("azure").setLevel(logging.DEBUG) -logging.getLogger('azure.core.pipeline.policies.http_logging_policy').setLevel(logging.DEBUG) +logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.DEBUG) + class AdlsGen2Setup: """ @@ -94,7 +95,9 @@ async def run(self, scandirs: bool = False): if directory not in directories: logger.error(f"File {file} has unknown directory {directory}, exiting...") return - await self.upload_file(directory_client=directories[directory], file_path=os.path.join(self.data_directory, file)) + await self.upload_file( + directory_client=directories[directory], file_path=os.path.join(self.data_directory, file) + ) logger.info("Setting access control...") for directory, access_control in self.data_access_control_format["directories"].items(): @@ -106,7 +109,8 @@ async def run(self, scandirs: bool = False): f"Directory {directory} has unknown group {group_name} in access control list, exiting" ) return - await directory_client.update_access_control_recursive(acl=f"group:{groups[group_name]}:r-x" + await directory_client.update_access_control_recursive( + acl=f"group:{groups[group_name]}:r-x" ) if "oids" in access_control: for oid in access_control["oids"]: @@ -115,60 +119,60 @@ async def run(self, scandirs: bool = False): for directory_client in directories.values(): await directory_client.close() - async def walk_files(self, src_filepath = "."): + async def walk_files(self, src_filepath="."): filepath_list = [] - - #This for loop uses the os.walk() function to walk through the files and directories - #and records the filepaths of the files to a list + + # This for loop uses the os.walk() function to walk through the files and directories + # and records the filepaths of the files to a list for root, dirs, files in os.walk(src_filepath): - - #iterate through the files currently obtained by os.walk() and - #create the filepath string for that file and add it to the filepath_list list + + # iterate through the files currently obtained by os.walk() and + # create the filepath string for that file and add it to the filepath_list list root_found: bool = False for file in files: - #Checks to see if the root is '.' and changes it to the correct current - #working directory by calling os.getcwd(). Otherwise root_path will just be the root variable value. - - if not root_found and root == '.': - filepath =os.path.join(os.getcwd() + "/", file) + # Checks to see if the root is '.' and changes it to the correct current + # working directory by calling os.getcwd(). Otherwise root_path will just be the root variable value. 
+
+                if not root_found and root == ".":
+                    filepath = os.path.join(os.getcwd() + "/", file)
                     root_found = True
                 else:
                     filepath = os.path.join(root, file)
-                
-                #Appends filepath to filepath_list if filepath does not currently exist in filepath_list
+
+                # Appends filepath to filepath_list if filepath does not currently exist in filepath_list
                 if filepath not in filepath_list:
-                    filepath_list.append(filepath)
-        
-        #Return filepath_list
+                    filepath_list.append(filepath)
+
+        # Return filepath_list
         return filepath_list
 
     async def scan_and_upload_directories(self, directories: dict[str, DataLakeDirectoryClient], filesystem_client):
         logger.info("Scanning and uploading files from directories recursively...")
-        
+
         for directory, directory_client in directories.items():
             directory_path = os.path.join(self.data_directory, directory)
             if directory == "/":
                 continue
-            
-            # Check if 'scandir' exists and is set to False
+
+            # Check if 'scandir' exists and is set to False
             if not self.data_access_control_format["directories"][directory].get("scandir", True):
                 logger.info(f"Skipping directory {directory} as 'scandir' is set to False")
                 continue
-            
+
             # Check if the directory exists before walking it
             if not os.path.exists(directory_path):
                 logger.warning(f"Directory does not exist: {directory_path}")
                 continue
-            
+
             # Get all file paths using the walk_files function
             file_paths = await self.walk_files(directory_path)
 
             # Upload each file collected
-            count =0
+            count = 0
             num = len(file_paths)
             for file_path in file_paths:
                 await self.upload_file(directory_client, file_path, directory)
-                count=+1
+                count += 1
                 logger.info(f"Uploaded [{count}/{num}] {directory}/{file_path}")
 
     def create_service_client(self):
@@ -190,11 +194,11 @@ async def get_blob_md5(self, directory_client: DataLakeDirectoryClient, filename
         file_client = directory_client.get_file_client(filename)
         try:
             properties = await file_client.get_file_properties()
-            return properties.metadata.get('md5')
+            return properties.metadata.get("md5")
         except Exception as e:
             logger.error(f"Error getting blob properties for {filename}: {e}")
             return None
-        
+
     async def upload_file(self, directory_client: DataLakeDirectoryClient, file_path: str, category: str = ""):
         # Calculate MD5 hash once
         md5_hash = await self.calc_md5(file_path)
@@ -212,12 +216,7 @@ async def upload_file(self, directory_client: DataLakeDirectoryClient, file_path
             tmtime = os.path.getmtime(file_path)
             last_modified = datetime.fromtimestamp(tmtime).isoformat()
             title = os.path.splitext(filename)[0]
-            metadata = {
-                "md5": md5_hash,
-                "category": category,
-                "updated": last_modified,
-                "title": title
-            }
+            metadata = {"md5": md5_hash, "category": category, "updated": last_modified, "title": title}
             await file_client.upload_data(f, overwrite=True)
             await file_client.set_metadata(metadata)
             logger.info(f"Uploaded and updated metadata for {filename}")
@@ -248,7 +247,6 @@ async def create_or_get_group(self, group_name: str):
                 # If Unified does not work for you, then you may need the following settings instead:
                 # "mailEnabled": False,
                 # "mailNickname": group_name,
-
             }
             async with session.post("https://graph.microsoft.com/v1.0/groups", json=group) as response:
                 content = await response.json()
@@ -270,7 +268,7 @@ async def main(args: Any):
         data_access_control_format = json.load(f)
     command = AdlsGen2Setup(
         data_directory=args.data_directory,
-        storage_account_name=os.environ["AZURE_ADLS_GEN2_STORAGE_ACCOUNT"],    
+        storage_account_name=os.environ["AZURE_ADLS_GEN2_STORAGE_ACCOUNT"],
         filesystem_name=os.environ["AZURE_ADLS_GEN2_FILESYSTEM"],
security_enabled_groups=args.create_security_enabled_groups, credentials=credentials, @@ -295,10 +293,12 @@ async def main(args: Any): "--data-access-control", required=True, help="JSON file describing access control for the sample data" ) parser.add_argument("--verbose", "-v", required=False, action="store_true", help="Verbose output") - parser.add_argument("--scandirs", required=False, action="store_true", help="Scan and upload all files from directories recursively") + parser.add_argument( + "--scandirs", required=False, action="store_true", help="Scan and upload all files from directories recursively" + ) args = parser.parse_args() if args.verbose: logging.basicConfig() - logging.getLogger().setLevel(logging.INFO) + logging.getLogger().setLevel(logging.INFO) asyncio.run(main(args))
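
Reviewer note: the incremental re-ingestion in this change hinges on the "md5" metadata value that adlsgen2setup.py attaches to each uploaded file and that BlobManager._file_blob_update_needed and SearchManager.file_exists later compare against. calc_md5 is referenced but not shown in this diff, so the sketch below is an assumption about its behavior (a chunked hash of the file contents), not the repository's implementation; the helper names are illustrative only.

    import hashlib

    def calc_md5_sketch(path: str) -> str:
        # Hash the file contents in chunks so large documents do not have to fit in memory.
        hasher = hashlib.md5()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(8192), b""):
                hasher.update(chunk)
        return hasher.hexdigest()

    def blob_update_needed_sketch(blob_metadata: dict, file_md5: str) -> bool:
        # Mirrors the check in BlobManager._file_blob_update_needed: re-upload only
        # when the local digest differs from the digest stored on the existing blob.
        return bool(file_md5) and file_md5 != blob_metadata.get("md5")

When the digests match, upload_blob skips the re-upload (and the page-image regeneration guarded by md5_check != MD5Check.MATCH), and SearchManager.file_exists filters on sourcefile and md5 to skip re-indexing unchanged documents, so re-running prepdocs over an unchanged data lake stays cheap.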