Skip to content

Commit

Permalink
refactor: update TableExtraction class, sort import add checking cond…
Browse files Browse the repository at this point in the history
…itions and add example notebook
  • Loading branch information
trungtd-2436 committed Sep 21, 2021
1 parent 6999711 commit b789ff8
Show file tree
Hide file tree
Showing 22 changed files with 649 additions and 195 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,6 @@ cython_debug/

table_reconstruction/__version__.py

docs/source/_build
docs/source/_build

tmp/
1 change: 0 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@

html_theme = "sphinx_rtd_theme"
napoleon_include_init_with_doc = True
autoclass_content = "both"
autodoc_class_signature = "separated"
add_function_parentheses = False
add_module_names = False
Expand Down
494 changes: 432 additions & 62 deletions example/example.ipynb

Large diffs are not rendered by default.

Binary file modified example/table.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ torchvision==0.9.1
gdown
notebook
scikit-image==0.18.3
Shapely==1.7.1
Shapely==1.7.1
130 changes: 114 additions & 16 deletions table_reconstruction/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
from typing import List, Union

import numpy as np
import torch
from numpy import array, ndarray
from PIL.Image import Image
from pkg_resources import DistributionNotFound, get_distribution

from .line_segmentation.line_segment import LineSegmentation # noqa: F401
from .output.cell import Cell
from .output.table import Table
from .line_segmentation.line_segment import LineSegmentation # noqa: F401
from .table_detection.detector import TableDetector
from .utils.cell_utils import (
calculate_cell_coordinate,
get_intersection_points,
predict_relation,
sort_cell,
)
from .utils.lines_utils import get_coordinates
from .utils.mask_utils import normalize
from .utils.table_utils import DirectedGraph, convertSpanCell2DocxCoord

__version__ = None
try:
Expand All @@ -20,38 +33,123 @@ class TableExtraction:

def __init__(
self,
device: torch.device,
line_segment_weight_path: str = None,
table_detection_weight_path: str = None,
normalize_thresh: int = 15,
) -> None:
"""
"""[summary]
Args:
line_segment_weight_path (str, optional): Path to exported weight file of
Line segmentation model. Defaults to None.
table_detection_weight_path (str, optional): Path to exported weight file
of Table detection model. Defaults to None.
Raises:
NotImplementedError: [description]
device (torch.device)
line_segment_weight_path (str, optional): Defaults to None.
table_detection_weight_path (str, optional): Defaults to None.
normalize_thresh (int, optional): Normalize threshold used after receive
result from line segmentation model. Defaults to 15.
"""
raise NotImplementedError("Required models was not defined")
self.table_detection_model = TableDetector(
table_detection_weight_path, device=device.type
)
self.line_segmentation_model = LineSegmentation(
line_segment_weight_path, device=device
)
self.normalize_thresh = normalize_thresh

def extract(self, image: Union[ndarray, Image]) -> List[Table]:
"""Extract tables from image
"""Extract table from image
Args:
image (Union[ndarray, Image]): [description]
image (Union[ndarray, Image]): Input image
Raises:
ValueError: Will be raised when input image is not Numpy array or PIL Image
NotImplementedError: [description]
ValueError: will be raised if the input image is not np.ndarray or PIL.Image
Returns:
List[Table]: [description]
List[Table]: list of extracted tables
"""
if isinstance(image, Image):
image = array(image)
elif not isinstance(image, ndarray):
raise ValueError(("Input image must be Numpy array or PIL Image"))

raise NotImplementedError("Extracting methods were not defined")
table_regions = self.table_detection_model.predict([image])

tables = []
for region in table_regions[0]:
x_min, y_min, x_max, y_max = region

img = image[y_min:y_max, x_min:x_max]

h, w, _ = img.shape
padding_img = np.ones((h + 10, w + 10, 3), dtype=np.uint8) * 255
padding_img[5 : h + 5, 5 : w + 5, :] = img

mask = self.line_segmentation_model.predict(padding_img)
mask = normalize(img, mask_img=mask)
mask = np.array(mask[5 : h + 5, 5 : w + 5])
try:
(
tab_coord,
vertical_lines_coord,
horizontal_lines_coord,
) = get_coordinates(mask, ths=self.normalize_thresh)
except Exception as e:
print(str(e))
continue

intersect_points, fake_intersect_points = get_intersection_points(
horizontal_lines_coord, vertical_lines_coord, tab_coord
)

cells = calculate_cell_coordinate(
intersect_points.copy(),
False,
self.normalize_thresh,
[horizontal_lines_coord, vertical_lines_coord],
)

fake_cells = calculate_cell_coordinate(
fake_intersect_points.copy(), True, self.normalize_thresh
)

if len(cells) <= 1:
continue
cells = sort_cell(cells=np.array(cells))
fake_cells = sort_cell(cells=np.array(fake_cells))

hor_couple_ids, ver_couple_ids = predict_relation(cells)

H_Graph = DirectedGraph(len(cells))
H_Graph.add_edges(hor_couple_ids)
nb_col = H_Graph.findLongestPath()

V_Graph = DirectedGraph(len(cells))
V_Graph.add_edges(ver_couple_ids)
nb_row = V_Graph.findLongestPath()

span_list = convertSpanCell2DocxCoord(
cells, fake_cells, list(range(len(cells))), nb_col
)

cells_list = [
Cell(
(c_x_min, c_x_max, c_y_min, c_y_max),
col_index=span_info["y"][0],
row_index=span_info["x"][0],
col_span=span_info["y"][1],
row_span=span_info["x"][1],
)
for span_info, (c_x_min, c_x_max, c_y_min, c_y_max) in zip(
span_list, cells
)
]

tables.append(
Table(
coordinate=(x_min, x_max, y_min, y_max),
col_numb=nb_col,
row_numb=nb_row,
cells=cells_list,
)
)
return tables
13 changes: 6 additions & 7 deletions table_reconstruction/line_segmentation/line_segment.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import os
import logging
import os
from typing import Tuple

import cv2
import gdown
import torch
import numpy as np
import cv2

from torchvision import transforms
import torch
from PIL import Image
from torchvision import transforms

from .utils import load_model_unet

Expand Down Expand Up @@ -41,7 +40,7 @@ def __init__(
except Exception as e:
logging.info("Could not download weight, please try again!")
logging.info(f"Error code: {e}")
raise Exception('An error occured while downloading weight file')
raise Exception("An error occured while downloading weight file")
self.model = load_model_unet(MODEL_PATH, device)
else:
if os.path.exists(model_path):
Expand Down Expand Up @@ -116,7 +115,7 @@ def _preprocess(
h, w, _ = img.shape
assert pad >= 0, "Pad must great than 0"
padding_img = np.ones((h + pad * 2, w + pad * 2, 3), dtype=np.uint8) * 255
padding_img[pad: h + pad, pad: w + pad, :] = img
padding_img[pad : h + pad, pad : w + pad, :] = img
pil_img = Image.fromarray(padding_img)

# Resize
Expand Down
2 changes: 1 addition & 1 deletion table_reconstruction/line_segmentation/unet/resunet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from torch import nn
from torch.nn.modules.batchnorm import BatchNorm2d

from .unet_parts import Up, OutConv
from .unet_parts import OutConv, Up


def conv(ni, nf, ks=3, stride=1, act=True, bn=True):
Expand Down
6 changes: 3 additions & 3 deletions table_reconstruction/output/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, coordinate: Tuple[int, int, int, int]):
Args:
coordinate (List[int, int, int, int]): A list that contains 4 integer
values ​​defined as x_min, x_max, y_min, y_max respectively
values ​​defined as x_min, x_max, y_min, y_max respectively
Returns:
bool: The return value. True for success, False otherwise.
Expand Down Expand Up @@ -63,7 +63,7 @@ def y_max(self, value):
self.coordinate[3] = value

def __repr__(self):
return f"{self.__class__.__name__}({self.coord})"
return f"{self.__class__.__name__}({self.coordinate})"

def __str__(self):
return f"{self.__class__.__name__}({self.coord})"
return f"{self.__class__.__name__}({self.coordinate})"
2 changes: 1 addition & 1 deletion table_reconstruction/table_detection/postprocess.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import time
from typing import List, Union

import numpy as np
import torch
import torchvision
from typing import Union, List


def box_iou(box1: torch.Tensor, box2: torch.Tensor) -> torch.Tensor:
Expand Down
3 changes: 2 additions & 1 deletion table_reconstruction/table_detection/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Any, List, Tuple, Union

import cv2
import numpy as np
import torch
from typing import Any, List, Tuple, Union


def create_batch(
Expand Down
20 changes: 14 additions & 6 deletions table_reconstruction/table_detection/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import torch
from typing import Union
import gdown

import torch

from .yolov5 import YOLO_DIR


DETECTOR_WEIGHT_URL = "https://drive.google.com/uc?id=18eh4wpbeEGn0bNXyDFDxgQNWTh-wtnPJ"
DETECTOR_WEIGHT_URL = "https://drive.google.com/uc?id=12ttln8zPOWrFCPLr4hChmxr4rxuHRRoz"


def select_device(device: str = "") -> torch.device:
Expand All @@ -17,12 +16,14 @@ def select_device(device: str = "") -> torch.device:
Returns:
device (torch.device): selected device
"""
if not isinstance(device, str):
device = device.type
cpu = device.lower() == "cpu"
cuda = not cpu and torch.cuda.is_available()
return torch.device("cuda:0" if cuda else "cpu")


def load_yolo_model(weight_path: str, device: str):
def load_yolo_model(weight_path: str, device: Union[torch.device, str]):
"""load yolo model detect using torch hub
Args:
Expand All @@ -34,8 +35,15 @@ def load_yolo_model(weight_path: str, device: str):
model stride (torch.Tensor): stride of model
"""
model = torch.hub.load(
str(YOLO_DIR), "custom", path=weight_path, source="local", device=device
str(YOLO_DIR),
"custom",
path=weight_path,
source="local",
device=device,
force_reload=True,
)
if isinstance(device, str):
device = torch.device(device)
model.to(device)
return model, model.stride

Expand Down
9 changes: 5 additions & 4 deletions table_reconstruction/table_detection/yolov5/hubconf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from typing import Any, Union

import torch


def _create(
name,
Expand All @@ -25,10 +26,8 @@ def _create(
"""
from pathlib import Path

from models.experimental import attempt_load
from models.yolo import Model
from models.experimental import (
attempt_load,
)

file = Path(__file__).absolute()

Expand Down Expand Up @@ -73,6 +72,8 @@ def select_device(device: Union[str, Any] = "") -> torch.device:
Returns:
device (torch.device): selected device
"""
if not isinstance(device, str):
device = device.type
cpu = device.lower() == "cpu"
cuda = not cpu and torch.cuda.is_available()
return torch.device("cuda:0" if cuda else "cpu")
8 changes: 4 additions & 4 deletions table_reconstruction/table_detection/yolov5/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
from torch.cuda import amp

from .utils import (
letterbox,
color_list,
increment_path,
letterbox,
make_divisible,
non_max_suppression,
scale_coords,
xyxy2xywh,
color_list,
plot_one_box,
scale_coords,
time_synchronized,
xyxy2xywh,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
import torch
import torch.nn as nn

from .common import (
Conv,
DWConv,
)
from .common import Conv, DWConv


class CrossConv(nn.Module):
Expand Down
Loading

0 comments on commit b789ff8

Please sign in to comment.