diff --git a/src/layoutparser/__init__.py b/src/layoutparser/__init__.py
index 339f927..09ff240 100644
--- a/src/layoutparser/__init__.py
+++ b/src/layoutparser/__init__.py
@@ -23,6 +23,7 @@
     is_effdet_available,
     is_pytesseract_available,
     is_gcv_available,
+    is_paddleocr_available,
 )
 
 _import_structure = {
@@ -51,6 +52,7 @@
         "is_paddle_available",
         "is_pytesseract_available",
         "is_gcv_available",
+        "is_paddleocr_available",
         "requires_backends"
     ],
     "tools": [
@@ -80,6 +82,9 @@
 if is_gcv_available():
     _import_structure["ocr.gcv_agent"] = ["GCVAgent", "GCVFeatureType"]
 
+if is_paddleocr_available():
+    _import_structure["ocr.paddleocr_agent"] = ["PaddleOCRAgent"]
+
 sys.modules[__name__] = _LazyModule(
     __name__,
     globals()["__file__"],
diff --git a/src/layoutparser/file_utils.py b/src/layoutparser/file_utils.py
index b10a747..f74e83b 100644
--- a/src/layoutparser/file_utils.py
+++ b/src/layoutparser/file_utils.py
@@ -88,6 +88,16 @@
 except ModuleNotFoundError:
     _gcv_available = False
 
+try:
+    _paddleocr_available = importlib.util.find_spec("paddleocr") is not None
+    try:
+        _paddleocr_version = importlib_metadata.version("paddleocr")
+        logger.debug(f"PaddleOCR version {_paddleocr_version} available.")
+    except importlib_metadata.PackageNotFoundError:
+        _paddleocr_available = False
+except ModuleNotFoundError:
+    _paddleocr_available = False
+
 
 def is_torch_available():
     return _torch_available
@@ -121,6 +131,10 @@ def is_pytesseract_available():
 def is_gcv_available():
     return _gcv_available
 
+
+def is_paddleocr_available():
+    return _paddleocr_available
+
 
 PYTORCH_IMPORT_ERROR = """
 {0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
@@ -154,6 +168,11 @@ def is_gcv_available():
 `pip install google-cloud-vision==1`
 """
 
+PADDLEOCR_IMPORT_ERROR = """
+{0} requires the PaddleOCR library but it was not found in your environment. You can install it with pip:
+`pip install paddleocr`
+"""
+
 BACKENDS_MAPPING = dict(
     [
         ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
@@ -162,6 +181,7 @@ def is_gcv_available():
         ("effdet", (is_effdet_available, EFFDET_IMPORT_ERROR)),
         ("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
         ("google-cloud-vision", (is_gcv_available, GCV_IMPORT_ERROR)),
+        ("paddleocr", (is_paddleocr_available, PADDLEOCR_IMPORT_ERROR)),
     ]
 )
 
diff --git a/src/layoutparser/ocr/__init__.py b/src/layoutparser/ocr/__init__.py
index 66efd76..c296216 100644
--- a/src/layoutparser/ocr/__init__.py
+++ b/src/layoutparser/ocr/__init__.py
@@ -13,4 +13,5 @@
 # limitations under the License.
 
 from .gcv_agent import GCVAgent, GCVFeatureType
-from .tesseract_agent import TesseractAgent, TesseractFeatureType
\ No newline at end of file
+from .tesseract_agent import TesseractAgent, TesseractFeatureType
+from .paddleocr_agent import PaddleOCRAgent
diff --git a/src/layoutparser/ocr/paddleocr_agent.py b/src/layoutparser/ocr/paddleocr_agent.py
new file mode 100644
index 0000000..667023a
--- /dev/null
+++ b/src/layoutparser/ocr/paddleocr_agent.py
@@ -0,0 +1,199 @@
+# Copyright 2021 The Layout Parser team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pickle
+
+import cv2
+import numpy as np
+
+from .base import BaseOCRAgent
+from ..file_utils import PADDLEOCR_IMPORT_ERROR, is_paddleocr_available
+
+if is_paddleocr_available():
+    import paddleocr
+
+
+class PaddleOCRAgent(BaseOCRAgent):
+    """
+    A wrapper for the text detection and recognition APIs of PaddleOCR
+    (https://github.com/PaddlePaddle/PaddleOCR).
+    """
+
+    DEPENDENCIES = ["paddleocr"]
+
+    def __init__(self, languages="en", use_gpu=False, use_angle_cls=False):
+        """Create a PaddleOCR Agent.
+
+        Args:
+            languages (:obj:`str`, optional):
+                The language code of the documents to detect, forwarded to
+                PaddleOCR as its `lang` argument. Supported codes include
+                `ch`, `en`, `french`, `german`, `korean` and `japan`; see
+                https://github.com/PaddlePaddle/PaddleOCR for details.
+                Defaults to `'en'`.
+            use_gpu (:obj:`bool`, optional):
+                Forwarded to PaddleOCR as its `use_gpu` argument.
+                Defaults to `False`.
+            use_angle_cls (:obj:`bool`, optional):
+                Forwarded to PaddleOCR as its `use_angle_cls` argument.
+                Defaults to `False`.
+
+        Raises:
+            ImportError: If the `paddleocr` package is not installed.
+        """
+        # Fail with the install hint rather than a NameError on `paddleocr`.
+        if not is_paddleocr_available():
+            raise ImportError(PADDLEOCR_IMPORT_ERROR.format(type(self).__name__))
+
+        self.lang = languages
+        self.use_gpu = use_gpu
+        self.use_angle_cls = use_angle_cls
+        self.ocr = paddleocr.PaddleOCR(
+            use_gpu=self.use_gpu, use_angle_cls=self.use_angle_cls, lang=self.lang
+        )
+
+    def resized_long(self, image, target_size):
+        """Upscale `image` so that its longer side equals `target_size`.
+
+        The aspect ratio is preserved. Images whose longer side is already
+        at least `target_size` are returned unchanged.
+        """
+        height, width = image.shape[:2]
+        if max(height, width) >= target_size:
+            return image
+        ratio = 1.0 * target_size / max(height, width)
+        if height >= width:
+            out = (max(1, int(width * ratio)), target_size)
+        else:
+            out = (target_size, max(1, int(height * ratio)))
+        # cv2.resize takes the output size as (width, height).
+        return cv2.resize(image, out)
+
+    def pad_img_to_longer_edge(self, image, padding_value=127):
+        """Pad `image` on the bottom/right into a square 3-channel image.
+
+        Single-channel input is replicated across the three channels and
+        channels beyond the third (e.g. alpha) are dropped. The dtype of
+        the input is kept.
+        """
+        if image.ndim == 2:
+            image = image[:, :, np.newaxis]
+        if image.shape[2] == 1:
+            image = np.repeat(image, 3, axis=2)
+        image = image[:, :, :3]
+        height, width, channels = image.shape
+        max_shape = max(height, width)
+        out_img = np.full((max_shape, max_shape, 3), padding_value, dtype=image.dtype)
+        out_img[:height, :width, :channels] = image
+        return out_img
+
+    def _detect(self, img_content, target_size, padding_value,
+                det, rec, cls, threshold):
+        """Run PaddleOCR on the image array and collect the recognized text.
+
+        Returns a dict with the filtered text under `"text"` and the raw
+        PaddleOCR output under `"response"`.
+        """
+        res = {}
+        img_content = self.resized_long(img_content, target_size)
+        img_content = self.pad_img_to_longer_edge(img_content, padding_value)
+        result = self.ocr.ocr(img_content, det=det, rec=rec, cls=cls)
+
+        text = []
+        # Text is only available when recognition runs, and PaddleOCR may
+        # return None when nothing is found.
+        if rec and result:
+            for line in result:
+                # NOTE(review): layout as of PaddleOCR 2.0.x -- [box, (text, score)]
+                # with detection, a bare (text, score) pair without it. Confirm
+                # before upgrading PaddleOCR.
+                content, score = line[1] if det else line
+                if score > threshold:
+                    text.append(content)
+        res["text"] = "\n".join(text)
+        res["response"] = result
+        return res
+
+    def detect(
+        self, image, target_size=480, padding_value=127,
+        det=True, rec=True, cls=True, threshold=0.5,
+        return_response=False, return_only_text=True
+    ):
+        """Send the input image for OCR.
+
+        Args:
+            image (:obj:`np.ndarray` or :obj:`str`):
+                The input image array or the name of the image file.
+            target_size (:obj:`int`, optional):
+                Images whose longer side is smaller than this value are
+                upscaled so that the longer side matches it.
+                Defaults to `480`.
+            padding_value (:obj:`int`, optional):
+                The value used to pad the image into a square.
+                Defaults to `127`.
+            det (:obj:`bool`, optional):
+                Use text detection or not; if false, only rec will be executed.
+                Defaults to `True`.
+            rec (:obj:`bool`, optional):
+                Use text recognition or not; if false, only det will be executed.
+                Defaults to `True`.
+            cls (:obj:`bool`, optional):
+                Use 180 degree rotation text recognition or not.
+                Defaults to `True`.
+            threshold (:obj:`float`, optional):
+                Drop recognition results whose score is not greater than
+                this threshold.
+                Defaults to `0.5`.
+            return_response (:obj:`bool`, optional):
+                Whether to return the whole result dict (the text together
+                with the raw PaddleOCR output) instead of only the text.
+                Defaults to `False`.
+            return_only_text (:obj:`bool`, optional):
+                Whether to return only the texts in the OCR results. Text is
+                currently the only supported output, so both values give
+                the same result.
+                Defaults to `True`.
+
+        Raises:
+            ValueError: If `image` is a file name that cannot be read.
+        """
+        if isinstance(image, str):
+            image_path = image
+            image = cv2.imread(image_path)
+            # cv2.imread signals failure by returning None.
+            if image is None:
+                raise ValueError(f"Unable to read image file: {image_path}")
+
+        res = self._detect(image, target_size, padding_value, det, rec, cls, threshold)
+
+        if return_response:
+            return res
+
+        return res["text"]
+
+    @staticmethod
+    def load_response(filename):
+        """Load a response saved with `save_response`.
+
+        Unpickling can execute arbitrary code, so only load trusted files.
+        """
+        with open(filename, "rb") as fp:
+            res = pickle.load(fp)
+        return res
+
+    @staticmethod
+    def save_response(res, file_name):
+        """Pickle the response `res` to `file_name`."""
+        with open(file_name, "wb") as fp:
+            pickle.dump(res, fp, protocol=pickle.HIGHEST_PROTOCOL)
diff --git a/tests/test_ocr.py b/tests/test_ocr.py
index 0c42cfc..9099e79 100644
--- a/tests/test_ocr.py
+++ b/tests/test_ocr.py
@@ -17,6 +17,7 @@
     GCVFeatureType,
     TesseractAgent,
     TesseractFeatureType,
+    PaddleOCRAgent,
 )
 import json, cv2, os
 
@@ -76,4 +77,15 @@ def test_tesseract(test_detect=False):
     assert r2 == ocr_agent.gather_data(res, agg_level=TesseractFeatureType.BLOCK)
     assert r3 == ocr_agent.gather_data(res, agg_level=TesseractFeatureType.PARA)
     assert r4 == ocr_agent.gather_data(res, agg_level=TesseractFeatureType.LINE)
-    assert r5 == ocr_agent.gather_data(res, agg_level=TesseractFeatureType.WORD)
\ No newline at end of file
+    assert r5 == ocr_agent.gather_data(res, agg_level=TesseractFeatureType.WORD)
+
+
+def test_paddleocr(test_detect=False):
+
+    ocr_agent = PaddleOCRAgent(languages="en")
+
+    # The results could be different if using another version of PaddleOCR Engine.
+    # PaddleOCR 2.0.1 is used for generating the result.
+    if test_detect:
+        res = ocr_agent.detect(image)
+        print(res)