From 54d509ee3b2d0b4853b2f79c730c04573fd2146d Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Wed, 9 Oct 2024 17:10:42 +0300 Subject: [PATCH 1/3] code formatting --- deepmultilingualpunctuation/__init__.py | 2 +- .../punctuationmodel.py | 73 ++++++++++--------- setup.py | 8 +- 3 files changed, 45 insertions(+), 38 deletions(-) diff --git a/deepmultilingualpunctuation/__init__.py b/deepmultilingualpunctuation/__init__.py index 5919050..a285e41 100644 --- a/deepmultilingualpunctuation/__init__.py +++ b/deepmultilingualpunctuation/__init__.py @@ -1 +1 @@ -from .punctuationmodel import PunctuationModel \ No newline at end of file +from .punctuationmodel import PunctuationModel diff --git a/deepmultilingualpunctuation/punctuationmodel.py b/deepmultilingualpunctuation/punctuationmodel.py index 14d11fa..95e7e24 100644 --- a/deepmultilingualpunctuation/punctuationmodel.py +++ b/deepmultilingualpunctuation/punctuationmodel.py @@ -1,79 +1,86 @@ -from concurrent.futures import process -from transformers import pipeline import re + import torch +from transformers import pipeline + -class PunctuationModel(): - def __init__(self, model = "oliverguhr/fullstop-punctuation-multilang-large") -> None: +class PunctuationModel: + def __init__(self, model="oliverguhr/fullstop-punctuation-multilang-large") -> None: if torch.cuda.is_available(): - self.pipe = pipeline("ner",model, aggregation_strategy="none", device=0) + self.pipe = pipeline("ner", model, aggregation_strategy="none", device=0) else: - self.pipe = pipeline("ner",model, aggregation_strategy="none") + self.pipe = pipeline("ner", model, aggregation_strategy="none") - def preprocess(self,text): - #remove markers except for markers in numbers - text = re.sub(r"(? result[result_index]["end"] : - label = result[result_index]['entity'] - score = result[result_index]['score'] - result_index += 1 - tagged_words.append([word,label, score]) - + while ( + result_index < len(result) + and char_index > result[result_index]["end"] + ): + label = result[result_index]["entity"] + score = result[result_index]["score"] + result_index += 1 + tagged_words.append([word, label, score]) + assert len(tagged_words) == len(words) return tagged_words - def prediction_to_text(self,prediction): + def prediction_to_text(self, prediction): result = "" for word, label, _ in prediction: result += word if label == "0": result += " " if label in ".,?-:": - result += label+" " + result += label + " " return result.strip() -if __name__ == "__main__": + +if __name__ == "__main__": model = PunctuationModel() text = "das , ist fies " diff --git a/setup.py b/setup.py index f5e01a0..fc0be03 100644 --- a/setup.py +++ b/setup.py @@ -12,15 +12,15 @@ long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/oliverguhr/deepmultilingualpunctuation", - packages=setuptools.find_packages(), + packages=setuptools.find_packages(), classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], install_requires=[ - "transformers", - "torch>=1.8.1", + "transformers", + "torch>=1.8.1", ], - python_requires='>=3.6', + python_requires=">=3.6", ) From 5097bd7be608b05cfca7956e751ae2805b5788d4 Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Wed, 9 Oct 2024 17:11:16 +0300 Subject: [PATCH 2/3] add acronyms --- deepmultilingualpunctuation/punctuationmodel.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/deepmultilingualpunctuation/punctuationmodel.py b/deepmultilingualpunctuation/punctuationmodel.py index 95e7e24..63220c7 100644 --- a/deepmultilingualpunctuation/punctuationmodel.py +++ b/deepmultilingualpunctuation/punctuationmodel.py @@ -12,9 +12,10 @@ def __init__(self, model="oliverguhr/fullstop-punctuation-multilang-large") -> N self.pipe = pipeline("ner", model, aggregation_strategy="none") def preprocess(self, text): - # remove markers except for markers in numbers - text = re.sub(r"(? Date: Wed, 9 Oct 2024 17:14:14 +0300 Subject: [PATCH 3/3] expose kwargs in init --- deepmultilingualpunctuation/punctuationmodel.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/deepmultilingualpunctuation/punctuationmodel.py b/deepmultilingualpunctuation/punctuationmodel.py index 63220c7..e2b811b 100644 --- a/deepmultilingualpunctuation/punctuationmodel.py +++ b/deepmultilingualpunctuation/punctuationmodel.py @@ -5,11 +5,14 @@ class PunctuationModel: - def __init__(self, model="oliverguhr/fullstop-punctuation-multilang-large") -> None: - if torch.cuda.is_available(): - self.pipe = pipeline("ner", model, aggregation_strategy="none", device=0) - else: - self.pipe = pipeline("ner", model, aggregation_strategy="none") + def __init__( + self, + model="oliverguhr/fullstop-punctuation-multilang-large", + **kwargs, + ) -> None: + if "aggregation_strategy" not in kwargs: + kwargs["aggregation_strategy"] = "none" + self.pipe = pipeline("ner", model, **kwargs) def preprocess(self, text): # remove punctuation except dots