Preprocess Text: Add Spacy POS tagger #1070

Draft · wants to merge 3 commits into base: master
4 changes: 4 additions & 0 deletions orangecontrib/text/corpus.py
@@ -414,6 +414,10 @@ def has_tokens(self):
""" Return whether corpus is preprocessed or not. """
return self._tokens is not None

def has_tags(self):
""" Return whether corpus is POS tagged or not. """
return self._pos_tags is not None

def _base_tokens(self):
from orangecontrib.text.preprocess import BASE_TRANSFORMER, \
BASE_TOKENIZER, PreprocessorList
2 changes: 2 additions & 0 deletions orangecontrib/text/language.py
@@ -99,6 +99,8 @@
"uk": "Ukrainian",
"ur": "Urdu",
"vi": "Vietnamese",
# Spacy code for multi-language model
"xx": "Multi-language",
"zh": "Chinese",
"zh_char": "Chinese - Chinese Characters",
None: None,
87 changes: 85 additions & 2 deletions orangecontrib/text/tag/pos.py
@@ -1,6 +1,9 @@
from typing import List, Callable
from typing import List, Callable, Tuple

import nltk
import spacy
from spacy.cli import info, download
from spacy.tokens import Doc
import numpy as np
from Orange.util import wrap_callback, dummy_callback

@@ -10,7 +13,38 @@
from orangecontrib.text.util import chunkable


__all__ = ["POSTagger", "AveragedPerceptronTagger", "MaxEntTagger"]
__all__ = ["POSTagger", "AveragedPerceptronTagger", "MaxEntTagger",
"SpacyPOSTagger"]


SPACY_MODELS = {
"ca": {"language": "Catalan", "package": "ca_core_news_sm", "dependency": "None"},
"zh": {"language": "Chinese", "package": "zh_core_web_sm", "dependency": "Jieba"},
"hr": {"language": "Croatian", "package": "hr_core_news_sm", "dependency": "None"},
"da": {"language": "Danish", "package": "da_core_news_sm", "dependency": "None"},
"nl": {"language": "Dutch", "package": "nl_core_news_sm", "dependency": "None"},
"en": {"language": "English", "package": "en_core_web_sm", "dependency": "None"},
"fi": {"language": "Finnish", "package": "fi_core_news_sm", "dependency": "None"},
"fr": {"language": "French", "package": "fr_core_news_sm", "dependency": "None"},
"de": {"language": "German", "package": "de_core_news_sm", "dependency": "None"},
"el": {"language": "Greek", "package": "el_core_news_sm", "dependency": "None"},
"it": {"language": "Italian", "package": "it_core_news_sm", "dependency": "None"},
"ja": {"language": "Japanese", "package": "ja_core_news_sm", "dependency": "SudachiPy"},
"ko": {"language": "Korean", "package": "ko_core_news_sm", "dependency": "None"},
"lt": {"language": "Lithuanian", "package": "lt_core_news_sm", "dependency": "None"},
"mk": {"language": "Macedonian", "package": "mk_core_news_sm", "dependency": "None"},
"xx": {"language": "Multi-language", "package": "xx_ent_wiki_sm", "dependency": "None"},
"nb": {"language": "Norwegian Bokmål", "package": "nb_core_news_sm", "dependency": "None"},
"pl": {"language": "Polish", "package": "pl_core_news_sm", "dependency": "None"},
"pt": {"language": "Portuguese", "package": "pt_core_news_sm", "dependency": "None"},
"ro": {"language": "Romanian", "package": "ro_core_news_sm", "dependency": "None"},
"ru": {"language": "Russian", "package": "ru_core_news_sm", "dependency": "pymorphy3"},
"sl": {"language": "Slovenian", "package": "sl_core_news_sm", "dependency": "None"},
"es": {"language": "Spanish", "package": "es_core_news_sm", "dependency": "None"},
"sv": {"language": "Swedish", "package": "sv_core_news_sm", "dependency": "None"},
"uk": {"language": "Ukrainian", "package": "uk_core_news_sm", "dependency":
"pymorphy3, pymorphy3-dicts-uk"}
}


class POSTagger(TokenizedPreprocessor):
@@ -52,3 +86,52 @@ class MaxEntTagger(POSTagger):
def __init__(self):
tagger = nltk.data.load('taggers/maxent_treebank_pos_tagger/english.pickle')
super().__init__(tagger)


def find_model(language: str) -> str:
return SPACY_MODELS[language]["package"]


class SpacyModels:
installed_models_info = info()

def __init__(self):
self.installed_models = self.installed_models_info['pipelines']

def __getitem__(self, language: str) -> str:
model = find_model(language)
if model not in self.installed_models:
download(model)
return model
Review comment (Contributor): `self.installed_models` should be updated at this point; otherwise the package keeps getting downloaded.
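A minimal sketch of one way to address this, reusing the PR's `find_model` helper and its `spacy.cli.info`/`download` imports. Re-querying `info()` after the download, and moving the `info()` call from a class attribute into `__init__`, are assumptions for illustration, not something the PR does:

```python
from spacy.cli import download, info


class SpacyModels:
    def __init__(self):
        # query spaCy per instance instead of once at class-definition time
        self.installed_models = info()["pipelines"]

    def __getitem__(self, language: str) -> str:
        model = find_model(language)  # PR helper: SPACY_MODELS[language]["package"]
        if model not in self.installed_models:
            download(model)
            # refresh the cache so repeated lookups in the same session
            # do not trigger another download
            self.installed_models = info()["pipelines"]
        return model
```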



class SpacyPOSTagger(TokenizedPreprocessor):
name = 'Spacy POS Tagger'
supported_languages = set(SPACY_MODELS.keys())

def __init__(self, language: str = "en"):
self.__language = language
self.models = SpacyModels()
self.__model = None

def __call__(self, corpus: Corpus, callback: Callable = None,
**kw) -> Corpus:
""" Marks tokens of a corpus with POS tags. """
if callback is None:
callback = dummy_callback
corpus = super().__call__(corpus, wrap_callback(callback, end=0.2))

assert corpus.has_tokens()
callback(0.2, "POS Tagging...")
self.__model = spacy.load(self.models[self.__language])
tags = np.array(self.tag(corpus.tokens), dtype=object)
corpus.pos_tags = tags
return corpus

def tag(self, tokens):
out_tokens = []
for token_list in tokens:
# required for Spacy to work with pre-tokenized texts
doc = Doc(self.__model.vocab, words=token_list)
out_tokens.append([token.pos_ for token in self.__model(doc)])
return out_tokens
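Background on the `Doc(vocab, words=...)` pattern used in `tag` above: building a `Doc` from an existing word list and passing it back into the pipeline makes spaCy skip its own tokenizer, so the POS tags stay aligned with the corpus tokens. A standalone sketch, assuming `en_core_web_sm` is already installed and with illustrative output:

```python
import spacy
from spacy.tokens import Doc

nlp = spacy.load("en_core_web_sm")

words = ["Graph", "minors", "are", "widely", "used", "."]
doc = Doc(nlp.vocab, words=words)  # pre-tokenized input; spaCy's tokenizer is bypassed
doc = nlp(doc)                     # the pipeline (tagger, parser, ...) runs on the given tokens

print([(token.text, token.pos_) for token in doc])
# e.g. [('Graph', 'PROPN'), ('minors', 'NOUN'), ('are', 'AUX'), ...]
```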
29 changes: 29 additions & 0 deletions orangecontrib/text/tests/test_preprocess.py
@@ -721,5 +721,34 @@ def test_can_pickle(self):
self.assertEqual(loaded._NGrams__range, self.pp._NGrams__range)


class TestPOSTagging(unittest.TestCase):
def setUp(self):
self.corpus = Corpus.from_file("deerwester")
self.pp = [preprocess.WordPunctTokenizer(),
tag.SpacyPOSTagger()]

def test_no_tokens(self):
self.assertFalse(self.corpus.has_tokens())
tagger = tag.SpacyPOSTagger()
corpus = tagger(self.corpus)
self.assertEqual(len(corpus.used_preprocessor.preprocessors), 2)
self.assertTrue(corpus.has_tags())

def test_pos_tagger(self):
corpus = self.corpus
for pp in self.pp:
corpus = pp(corpus)
self.assertTrue(corpus.has_tokens())
self.assertTrue(corpus.has_tags())
self.assertEqual(len(corpus.pos_tags), len(corpus.tokens))
spacy_tags = corpus.pos_tags
tagger = tag.AveragedPerceptronTagger()
corpus = tagger(self.corpus)
self.assertEqual(len(corpus.pos_tags), len(corpus.tokens))
self.assertEqual(len(corpus.used_preprocessor.preprocessors), 2)
apt_tags = corpus.pos_tags
self.assertFalse(bool(np.array_equal(spacy_tags, apt_tags)))


if __name__ == "__main__":
unittest.main()
58 changes: 53 additions & 5 deletions orangecontrib/text/widgets/owpreprocess.py
@@ -27,8 +27,8 @@
from orangecontrib.text.misc import nltk_data_dir
from orangecontrib.text.preprocess import *
from orangecontrib.text.preprocess.normalize import UDPipeStopIteration, UDPipeModels
from orangecontrib.text.tag import AveragedPerceptronTagger, MaxEntTagger, \
POSTagger
from orangecontrib.text.tag import (AveragedPerceptronTagger, MaxEntTagger,
SpacyPOSTagger, POSTagger)

_DEFAULT_NONE = "(none)"

@@ -1033,15 +1033,63 @@ def createinstance(params: Dict) -> NGrams:


class POSTaggingModule(SingleMethodModule):
Averaged, MaxEnt = range(2)
Averaged, MaxEnt, Spacy = range(3)
Methods = {Averaged: AveragedPerceptronTagger,
MaxEnt: MaxEntTagger}
MaxEnt: MaxEntTagger,
Spacy: SpacyPOSTagger}
DEFAULT_METHOD = Averaged
DEFAULT_LANGUAGE = "en"

def __init__(self, parent=None, **kwargs):
super().__init__(parent, **kwargs)
self.__method = self.DEFAULT_METHOD
self.spacy_lang = self.DEFAULT_LANGUAGE

self.__combo_scy = LanguageComboBox(
self,
SpacyPOSTagger.supported_languages,
self.DEFAULT_LANGUAGE,
False,
self.__set_spacy_lang
)

label = QLabel("Language:")
label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
self.layout().addWidget(label, self.Spacy, 1)
self.layout().addWidget(self.__combo_scy, self.Spacy, 2)

def __set_spacy_lang(self, language: str):
if self.spacy_lang != language:
self.spacy_lang = language
self.__combo_scy.set_current_language(language)
self.changed.emit()
if self.method == self.Spacy:
self.edited.emit()

def setParameters(self, params: Dict):
super().setParameters(params)
spacy_lang = params.get("spacy_language", self.DEFAULT_LANGUAGE)
self.__set_spacy_lang(spacy_lang)

def parameters(self) -> Dict:
params = super().parameters()
params.update({"spacy_language": self.spacy_lang})
return params

@staticmethod
def createinstance(params: Dict) -> POSTagger:
method = params.get("method", POSTaggingModule.DEFAULT_METHOD)
return POSTaggingModule.Methods[method]()
args = {}
if method == POSTaggingModule.Spacy:
args = {"language": params.get("spacy_language",
POSTaggingModule.DEFAULT_LANGUAGE)}
return POSTaggingModule.Methods[method](**args)

def __repr__(self):
text = super().__repr__()
if self.method == self.Spacy:
text = f"{text} ({self.spacy_lang})"
return text


PREPROCESS_ACTIONS = [
23 changes: 17 additions & 6 deletions orangecontrib/text/widgets/tests/test_owpreprocess.py
@@ -11,7 +11,8 @@
from orangecontrib.text.preprocess import RegexpTokenizer, WhitespaceTokenizer, \
LowercaseTransformer, HtmlTransformer, PorterStemmer, SnowballStemmer, \
UDPipeLemmatizer, StopwordsFilter, MostFrequentTokensFilter, NGrams
from orangecontrib.text.tag import AveragedPerceptronTagger, MaxEntTagger
from orangecontrib.text.tag import (AveragedPerceptronTagger, MaxEntTagger,
SpacyPOSTagger)
from orangecontrib.text.tests.test_preprocess import SF_LIST, SERVER_FILES
from orangecontrib.text.widgets.owpreprocess import (
OWPreprocess,
@@ -884,20 +885,21 @@ def buttons(self):

def test_init(self):
self.assertTrue(self.buttons[0].isChecked())
for i in range(1, 2):
for i in range(1, 3):
self.assertFalse(self.buttons[i].isChecked())

def test_parameters(self):
params = {"method": POSTaggingModule.Averaged}
params = {"method": POSTaggingModule.Averaged, "spacy_language":
POSTaggingModule.DEFAULT_LANGUAGE}
self.assertDictEqual(self.editor.parameters(), params)

def test_set_parameters(self):
params = {"method": POSTaggingModule.MaxEnt}
params = {"method": POSTaggingModule.Spacy, "spacy_language": "sl"}
self.editor.setParameters(params)
self.assertDictEqual(self.editor.parameters(), params)

self.assertTrue(self.buttons[1].isChecked())
for i in range(1):
self.assertTrue(self.buttons[2].isChecked())
for i in range(0, 2):
self.assertFalse(self.buttons[i].isChecked())

def test_createinstance(self):
@@ -907,9 +909,18 @@ def test_createinstance(self):
pp = self.editor.createinstance({"method": POSTaggingModule.MaxEnt})
self.assertIsInstance(pp, MaxEntTagger)

pp = self.editor.createinstance({"method": POSTaggingModule.Spacy})
self.assertIsInstance(pp, SpacyPOSTagger)

def test_repr(self):
self.assertEqual(str(self.editor), "Averaged Perceptron Tagger")

params = {"method": POSTaggingModule.Spacy, "spacy_language":
POSTaggingModule.DEFAULT_LANGUAGE}
self.editor.setParameters(params)
self.assertEqual(str(self.editor),
f"Spacy POS Tagger ({params['spacy_language']})")


class TestLanguageComboBox(WidgetTest):
def test_basic_setup(self):
1 change: 1 addition & 0 deletions requirements.txt
@@ -25,6 +25,7 @@ serverfiles
simhash >=1.11
shapely >=2.0
six
spacy
tweepy >=4.0.0
ufal.udpipe >=1.2.0.3
trimesh >=3.9.8 # required by alphashape