Alexandra ai #15

Merged
merged 11 commits into from
Jul 5, 2024
2 changes: 2 additions & 0 deletions .gitignore
@@ -116,3 +116,5 @@ models/*

# Hydra
.hydra/

*notes.md
4 changes: 2 additions & 2 deletions config/config.yaml
@@ -14,8 +14,8 @@ domsdatabasen:
paths:
hf_hub: alexandrainst/domsdatabasen
data_raw_dir: data/raw/
data_processed_dir: data/processed
data_final_dir: data/final
data_processed_dir: data/processed/
data_final_dir: data/final/

file_names:
tabular_data: tabular_data.json
3 changes: 1 addition & 2 deletions config/process/process.yaml
@@ -8,14 +8,13 @@ force: False
case_id: "1"
all: False
start_case_id: "2732"
blacklist_flag: False

# Constants
test_case_id: "1"

page_number: False # Debug a specific page

gpu: False

max_y_difference: 25

neighbor_distance_max: 1
6 changes: 4 additions & 2 deletions src/doms_databasen/_text_extraction.py
@@ -10,6 +10,7 @@
import easyocr
import numpy as np
import skimage
import torch
from img2table.document import Image as TableImage
from img2table.tables.objects.extraction import ExtractedTable, TableCell
from omegaconf import DictConfig
@@ -19,6 +20,7 @@
from skimage.filters import rank
from skimage.measure._regionprops import RegionProperties
from tika import parser
from tqdm import tqdm

from ._constants import (
BOX_HEIGHT_LOWER_BOUND,
@@ -48,7 +50,7 @@ class PDFTextReader:
def __init__(self, config: DictConfig):
"""Initialize PDFTextReader."""
self.config = config
self.reader = easyocr.Reader(["da"], gpu=config.process.gpu)
self.reader = easyocr.Reader(["da"], gpu=torch.cuda.is_available())

def extract_text(self, pdf_path: Path) -> dict[Any, Any]:
"""Extracts text from a PDF using easyocr or pypdf.
@@ -82,7 +84,7 @@ def extract_text(self, pdf_path: Path) -> dict[Any, Any]:
box_anonymization = True
underline_anonymization = True

for i, image in enumerate(images):
for i, image in tqdm(enumerate(images), desc="Reading PDF", total=len(images)):
page_num = str(i + 1)
logger.info(f"Reading page {page_num}")

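For reference, the two changes in this file amount to automatic device selection for the OCR reader and a progress bar over the page loop. A minimal standalone sketch of the same behaviour (the dummy page list and the readtext call are illustrative, not part of this diff):

import easyocr
import numpy as np
import torch
from tqdm import tqdm

# The device is chosen at runtime instead of via the `process.gpu` config flag;
# easyocr falls back to CPU when no CUDA device is available.
reader = easyocr.Reader(["da"], gpu=torch.cuda.is_available())

# Stand-in for the list of page images rendered from a PDF.
images = [np.zeros((100, 100), dtype=np.uint8)]

for i, image in tqdm(enumerate(images), desc="Reading PDF", total=len(images)):
    page_num = str(i + 1)
    result = reader.readtext(image)  # per-page OCR, as in extract_text()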
4 changes: 2 additions & 2 deletions src/doms_databasen/_utils.py
@@ -1,7 +1,7 @@
"""Utility function for the doms_databasen package."""

import json
from typing import Dict, List
from typing import List

import jsonlines

@@ -19,7 +19,7 @@ def save_dict_to_json(dict_, file_path) -> None:
json.dump(dict_, f, indent=4)


def read_json(file_path) -> Dict[str, str]:
def read_json(file_path) -> dict:
"""Reads a json file.

Args:
162 changes: 162 additions & 0 deletions src/doms_databasen/dataset_builder.py
@@ -0,0 +1,162 @@
"""DatasetBuilder to build the final dataset."""


import re
from logging import getLogger
from pathlib import Path
from typing import Tuple

from omegaconf import DictConfig

from doms_databasen._utils import append_jsonl, init_jsonl, read_json

logger = getLogger(__name__)


class DatasetBuilder:
"""DatasetBuilder to build the final dataset.

Args:
config (DictConfig):
Configuration object.

Attributes:
config (DictConfig):
Configuration object.
data_processed_dir (Path):
Path to processed data directory.
data_final_dir (Path):
Path to final data directory.
dataset_path (Path):
Path to the dataset file.
"""

def __init__(self, config: DictConfig) -> None:
"""Initializes the DatasetBuilder."""
self.config = config
self.data_processed_dir = Path(config.paths.data_processed_dir)
self.data_final_dir = Path(config.paths.data_final_dir)
self.dataset_path = self.data_final_dir / config.file_names.dataset

def build_dataset(self) -> None:
"""Build the final dataset."""
if self.dataset_path.exists() and not self.config.finalize.force:
logger.info(
f"Dataset already exists at {self.dataset_path}."
"Use 'finalize.force=True' to overwrite."
)
return

logger.info("Initializing dataset with path: {dataset_path}")
init_jsonl(file_name=self.dataset_path)

processed_case_paths = [
case_path
for case_path in self.data_processed_dir.iterdir()
if case_path.is_dir()
]
logger.info(
f"Found {len(processed_case_paths)} cases in {self.data_processed_dir}"
)

# Process cases in ascending order
processed_case_paths = sorted(processed_case_paths, key=lambda p: int(p.stem))
for path in processed_case_paths:
logger.info(f"Processing case {path.stem}...")
processed_data = read_json(path / self.config.file_names.processed_data)
dataset_sample = self.make_dataset_sample(processed_data=processed_data)
append_jsonl(data=dataset_sample, file_name=self.dataset_path)

logger.info(f"Dataset saved at {self.dataset_path}")

def make_dataset_sample(self, processed_data: dict) -> dict:
"""Make a dataset sample from processed data.

Args:
processed_data (dict):
Processed data for a case.

Returns:
dataset_sample (dict):
Dataset sample.
"""
dataset_sample = {}
dataset_sample["case_id"] = processed_data["case_id"]
dataset_sample.update(processed_data["tabular_data"])

text, text_anon = self._get_text(
processed_data=processed_data, config=self.config
)
dataset_sample["text"] = text
dataset_sample["text_anonymized"] = text_anon

dataset_sample["text_len"] = len(text)
dataset_sample["text_anon_len"] = len(text_anon)
return dataset_sample

def _get_text(self, processed_data: dict, config: DictConfig) -> Tuple[str, str]:
"""Get `text` and `text_anon` from processed data.

Args:
processed_data (dict):
Processed data for a case.
config (DictConfig):
Configuration object.

Returns:
text (str):
Text extracted from the PDF.
text_anon (str):
Anonymized text.
"""
pdf_data = processed_data["pdf_data"]
if pdf_data["anonymization_method"] == config.anon_method.none:
# PDF has no anonymization.
# Make `text_anon` empty.
# For main `text` use text extracted with Tika.
# If Tika hasn't been able to read any text,
# then use text extracted from each page with easyocr.
if pdf_data["text_tika"]:
text = pdf_data["text_tika"]
else:
text = self._get_text_from_pages(pages=pdf_data["pages"])

text_anon = ""

elif pdf_data["anonymization_method"] == config.anon_method.underline:
# PDF uses underline anonymization.
# Make `text_anon` text extracted from each page.
# If text is extracted with Tika, then
# use that for the `text`,
# else remove anon tags from the anonymized text,
# and use that for `text`.
text_anon = self._get_text_from_pages(pdf_data["pages"])
if pdf_data["text_tika"]:
text = pdf_data["text_tika"]
else:
text = re.sub(r"<anonym.*</anonym>", "", text_anon)

elif pdf_data["anonymization_method"] == config.anon_method.box:
# PDF uses box anonymization
# Make `text_anon` text extracted from each page.
# Remove anon tags from the anonymized text,
# and use that for `text`.
text_anon = self._get_text_from_pages(pdf_data["pages"])
text = re.sub(r"<anonym.*</anonym>", "", text_anon)

return text, text_anon

@staticmethod
def _get_text_from_pages(pages: dict) -> str:
"""Get text from pages.

Args:
pages (dict):
Pages with text and extraction method.

Returns:
pdf_text (str):
Text from pages.
"""
pdf_text = "\n\n".join(page["text"] for page in pages.values())
return pdf_text
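The new DatasetBuilder is driven entirely by the Hydra config. A hypothetical entry point is sketched below; the config path, config name and script location are assumptions and not part of this pull request:

import hydra
from omegaconf import DictConfig

from doms_databasen.dataset_builder import DatasetBuilder


@hydra.main(config_path="../config", config_name="config", version_base=None)
def main(config: DictConfig) -> None:
    # Reads processed cases from config.paths.data_processed_dir and appends
    # one JSONL sample per case under config.paths.data_final_dir.
    DatasetBuilder(config=config).build_dataset()


if __name__ == "__main__":
    main()

An existing dataset file is only overwritten when the run passes finalize.force=True.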
17 changes: 11 additions & 6 deletions src/doms_databasen/processor.py
@@ -6,6 +6,7 @@
from pathlib import Path
from typing import Dict, List, Union

import torch
from omegaconf import DictConfig

from ._constants import N_FILES_PROCESSED_CASE_DIR, N_FILES_RAW_CASE_DIR
Expand Down Expand Up @@ -51,7 +52,7 @@ def __init__(self, config: DictConfig) -> None:
)

self.force = self.config.process.force
self.blacklist = self._read_blacklist()
self.blacklist = self._read_blacklist() if config.process.blacklist_flag else []

def process(self, case_id: str) -> Dict[str, Union[str, Dict[str, str]]]:
"""Processes a single case.
@@ -67,6 +68,8 @@ def process(self, case_id: str) -> Dict[str, Union[str, Dict[str, str]]]:
processed_data (dict):
Processed data (only returned for testing purposes)
"""
processed_data: Dict[str, Union[str, Dict[str, str]]] = {}

case_id = str(case_id)
if case_id in self.blacklist:
logger.info(f"{case_id} is blacklisted.")
@@ -78,16 +81,19 @@ def process(self, case_id: str) -> Dict[str, Union[str, Dict[str, str]]]:
case_dir_processed = self.data_processed_dir / case_id

# Check if raw data for case ID exists.
if not self._raw_data_exists(case_dir_raw):
if not self._raw_data_exists(case_dir=case_dir_raw):
logger.info(f"Case {case_id} does not exist in raw data directory.")
return {}

# If case has already been processed, skip, unless force=True.
if self._already_processed(case_dir_processed) and not self.force:
if self._already_processed(case_dir=case_dir_processed) and not self.force:
logger.info(
f"Case {case_id} has already been processed. Use --force to overwrite."
)
return {}
processed_data = read_json(
file_path=case_dir_processed / self.config.file_names.processed_data
)
return processed_data

# Process data for the case.
logger.info(f"Processing case {case_id}...")
@@ -98,7 +104,6 @@ def process(self, case_id: str) -> Dict[str, Union[str, Dict[str, str]]]:
case_dir_raw / self.config.file_names.tabular_data
)

processed_data: Dict[str, Union[str, Dict[str, str]]] = {}
processed_data["case_id"] = case_id
processed_data["tabular_data"] = tabular_data

@@ -109,7 +114,7 @@ def process(self, case_id: str) -> Dict[str, Union[str, Dict[str, str]]]:
processed_data["pdf_data"] = pdf_data
processed_data["process_info"] = {
"process_time": str(time.time() - start),
"hardware_used": "gpu" if self.config.process.gpu else "cpu",
"hardware_used": "gpu" if torch.cuda.is_available() else "cpu",
}

if not self.config.testing:
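Taken together, these changes make process() idempotent: a case that was already processed now returns its cached JSON instead of an empty dict, and the blacklist is only read when process.blacklist_flag is set. An illustrative call pattern, assuming the class in processor.py is named Processor and that config is the composed Hydra DictConfig (neither is shown in this diff):

from doms_databasen.processor import Processor


def run(config) -> None:
    processor = Processor(config=config)

    # First call extracts text from the raw case and writes processed_data.json.
    processor.process(case_id="123")

    # A repeated call with process.force=False returns the cached JSON instead
    # of an empty dict, so callers can rely on the return value.
    cached = processor.process(case_id="123")
    assert cached.get("case_id") == "123"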
6 changes: 1 addition & 5 deletions src/doms_databasen/scraper.py
@@ -11,11 +11,9 @@
from selenium import webdriver
from selenium.common.exceptions import NoSuchElementException
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service as ChromeService
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.wait import WebDriverWait
from webdriver_manager.chrome import ChromeDriverManager

from ._constants import N_FILES_RAW_CASE_DIR
from ._exceptions import PDFDownloadException
@@ -165,9 +163,7 @@ def _start_driver(self) -> webdriver.Chrome:
options.add_argument("--disable-dev-shm-usage")
options.add_argument("--headless")

driver = webdriver.Chrome(
service=ChromeService(ChromeDriverManager().install()), options=options
)
driver = webdriver.Chrome(options=options)
return driver

def _intialize_downloader_folder(self) -> None:
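Dropping webdriver_manager works because Selenium Manager, bundled with Selenium 4.6 and later, resolves and caches a matching chromedriver automatically. A minimal sketch of the equivalent standalone setup, assuming Selenium >= 4.6 and reusing the headless options from _start_driver (the target URL is only an example):

from selenium import webdriver
from selenium.webdriver.chrome.options import Options

options = Options()
options.add_argument("--disable-dev-shm-usage")
options.add_argument("--headless")

# No explicit Service or driver download step: Selenium Manager finds a
# chromedriver matching the installed Chrome and caches it locally.
driver = webdriver.Chrome(options=options)
try:
    driver.get("https://domsdatabasen.dk")
finally:
    driver.quit()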