-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support for English documents Kf/english pipeline (#65)
Adding English models where necessary as well CLI/config options to set language as "da" or "en"
- Loading branch information
1 parent
b816359
commit fa4f28b
Showing
15 changed files
with
293 additions
and
104 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,32 @@ | ||
[base] | ||
project_name = "PROJECT_NAME" # also CLI arg | ||
output_root = "output" | ||
language = "da/en" | ||
language = "da/en" # also CLI arg | ||
|
||
[preprocessing] | ||
enabled = true | ||
input_path = "PATH/TO/INPUT/*" # also CLI arg | ||
n_docs = -1 # also CLI arg | ||
doc_type = "text/csv/tweets/infomedia" | ||
metadata_fields = ["*"] | ||
|
||
[preprocessing.extra] | ||
# specific extra arguments for your preprocessor, e.g. context length for tweets. | ||
# leave empty unless you have very specific needs | ||
# specific extra arguments for your preprocessor, e.g. context length for tweets or | ||
# or field specification for CSVs | ||
|
||
[docprocessing] | ||
enabled = true | ||
batch_size = 25 | ||
continue_from_last = true | ||
triplet_extraction_method = "multi2oie/prompting" | ||
|
||
[corpusprocessing] | ||
enabled = true | ||
enabled = true | ||
embedding_model = "PATH_OR_NAME" # leave out for default model choice by language | ||
dimensions = 100 # leave out to skip dimensionality reduction | ||
n_neighbors = 15 # used for dimensionality reduction | ||
|
||
[corpusprocessing.thresholds] # leave out for automatic estimation | ||
min_cluster_size = 3 # unused if auto_thresholds is true | ||
min_samples = 3 # unused if auto_thresholds is true | ||
min_topic_size = 5 # unused if auto_thresholds is true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import logging | ||
from typing import Any | ||
|
||
|
||
class ModelChoice: | ||
"""Helper class for selecting an appropriate model based on language codes. | ||
An error will be thrown on unsupported languages. To avoid that, set 'fallback' | ||
model if appropriate. | ||
If choices are given as supplier functions, they will be called and then returned. | ||
Usage: | ||
>>> mc1 = ModelChoice(da="danish_model", fallback="fallback_model") | ||
>>> mc1.get_model("da") # "danish_model" | ||
>>> mc1.get_model("de") # "fallback_model" | ||
>>> mc2 = ModelChoice(da="danish_model") | ||
>>> mc2.get_model("de") # throws error | ||
>>> mc3 = ModelChoice(da=lambda: "danish_model") | ||
>>> mc3.get_model("da") # "danish_model" | ||
""" | ||
|
||
def __init__(self, **choices: Any): | ||
self.models = choices | ||
|
||
def get_model(self, language: str): | ||
if language not in self.models: | ||
error = f"Language '{language}' not supported!" | ||
if "fallback" in self.models: | ||
logging.warning(error + " Using fallback model.") | ||
language = "fallback" | ||
else: | ||
raise ValueError(error) | ||
model = self.models[language] | ||
if callable(model): # if supplier function | ||
model = model() | ||
logging.debug("Using '%s' model: %s", language, model) | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.