diff --git a/config/default.toml b/config/default.toml deleted file mode 100644 index 429e4f2..0000000 --- a/config/default.toml +++ /dev/null @@ -1,19 +0,0 @@ -[base] -output_root = "output" -language = "da" - -[preprocessing] -enabled = true -doc_type = "text" -metadata_fields = ["*"] - -[preprocessing.extra] -basename_as_id = true - -[docprocessing] -enabled = true -continue_from_last = true -triplet_extraction_method = "multi2oie" - -[corpusprocessing] -enabled = true \ No newline at end of file diff --git a/config/template_csv.toml b/config/eschatology.toml similarity index 65% rename from config/template_csv.toml rename to config/eschatology.toml index b7b9e8e..2832c7c 100644 --- a/config/template_csv.toml +++ b/config/eschatology.toml @@ -1,11 +1,9 @@ [base] -output_root = "output" language = "en" [preprocessing] enabled = true doc_type = "csv" -metadata_fields = ["*"] [preprocessing.extra] id_column = "id" @@ -14,8 +12,6 @@ text_column = "body" [docprocessing] enabled = true batch_size = 5 -continue_from_last = true -triplet_extraction_method = "multi2oie" [corpusprocessing] enabled = true \ No newline at end of file diff --git a/config/template.toml b/config/template.toml index 2572fc8..cd76efd 100644 --- a/config/template.toml +++ b/config/template.toml @@ -1,20 +1,32 @@ [base] +project_name = "PROJECT_NAME" # also CLI arg output_root = "output" -language = "da/en" +language = "da/en" # also CLI arg [preprocessing] enabled = true +input_path = "PATH/TO/INPUT/*" # also CLI arg +n_docs = -1 # also CLI arg doc_type = "text/csv/tweets/infomedia" metadata_fields = ["*"] [preprocessing.extra] -# specific extra arguments for your preprocessor, e.g. context length for tweets. -# leave empty unless you have very specific needs +# specific extra arguments for your preprocessor, e.g. context length for tweets or +# field specification for CSVs [docprocessing] enabled = true +batch_size = 25 continue_from_last = true triplet_extraction_method = "multi2oie/prompting" [corpusprocessing] -enabled = true \ No newline at end of file +enabled = true +embedding_model = "PATH_OR_NAME" # leave out for default model choice by language +dimensions = 100 # leave out to skip dimensionality reduction +n_neighbors = 15 # used for dimensionality reduction + +[corpusprocessing.thresholds] # leave out for automatic estimation +min_cluster_size = 3 # unused if auto_thresholds is true +min_samples = 3 # unused if auto_thresholds is true +min_topic_size = 5 # unused if auto_thresholds is true \ No newline at end of file diff --git a/docs/faq.rst b/docs/faq.rst index c1a8249..41d8b1b 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -5,21 +5,20 @@ Frequently asked questions How do I run the pipeline? ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -If you have cloned this project from git, you can run the pipeline via :code:`run.py` -and, optionally, configurations from :code:`config`. +If you have installed the package, you can run the pipeline via :code:`conspiracies.run`. .. code-block:: bash - python3 run.py my_project_name my_input_path + python3 -m conspiracies.run my_project_name my_input_path -For a specific pipeline configuration, create one using :code:`config/template.toml` and +For fine-grained control of pipeline behavior, create a configuration file based on the `config template `__ and pass it with the :code:`-c` flag. ..
code-block:: bash - python3 run.py my_project_name my_input_path -c config/my-config.toml + python3 -m conspiracies.run my_project_name my_input_path -c my-config.toml Project name and input path can also be specified in the configuration instead in which @@ -29,7 +28,7 @@ case you can do python3 run.py -c config/my-config.toml -If you have installed the package via :code:`pip` and want to integrate (parts of) the +If you have installed the package and want to integrate (parts of) the pipeline into your own workflow, you can use individual components or integrate a :code:`Pipeline` object in your script. diff --git a/src/conspiracies/common/modelchoice.py b/src/conspiracies/common/modelchoice.py new file mode 100644 index 0000000..e02d154 --- /dev/null +++ b/src/conspiracies/common/modelchoice.py @@ -0,0 +1,39 @@ +import logging +from typing import Any + + +class ModelChoice: + """Helper class for selecting an appropriate model based on language codes. + + An error will be thrown on unsupported languages. To avoid that, set 'fallback' + model if appropriate. + + If choices are given as supplier functions, they will be called and then returned. + + Usage: + + >>> mc1 = ModelChoice(da="danish_model", fallback="fallback_model") + >>> mc1.get_model("da") # "danish_model" + >>> mc1.get_model("de") # "fallback_model" + >>> mc2 = ModelChoice(da="danish_model") + >>> mc2.get_model("de") # throws error + >>> mc3 = ModelChoice(da=lambda: "danish_model") + >>> mc3.get_model("da") # "danish_model" + """ + + def __init__(self, **choices: Any): + self.models = choices + + def get_model(self, language: str): + if language not in self.models: + error = f"Language '{language}' not supported!" + if "fallback" in self.models: + logging.warning(error + " Using fallback model.") + language = "fallback" + else: + raise ValueError(error) + model = self.models[language] + if callable(model): # if supplier function + model = model() + logging.debug("Using '%s' model: %s", language, model) + return model diff --git a/src/conspiracies/corpusprocessing/umap_hdb.py b/src/conspiracies/corpusprocessing/umap_hdb.py index 4a7c96e..7724df5 100644 --- a/src/conspiracies/corpusprocessing/umap_hdb.py +++ b/src/conspiracies/corpusprocessing/umap_hdb.py @@ -1,5 +1,5 @@ import json -from typing import Tuple, List, Dict, Optional, Union +from typing import Tuple, List, Dict, Optional, Union, Set import os import spacy from umap import UMAP @@ -12,6 +12,8 @@ import random import argparse +from conspiracies.common.modelchoice import ModelChoice + def read_txt(path: str): with open(path, mode="r", encoding="utf8") as f: @@ -39,7 +41,7 @@ def triplet_from_line(line: str) -> Union[Tuple[str, str, str], None]: def filter_triplets_with_stopwords( triplets: List[Tuple[str, str, str]], - stopwords: List[str], + stopwords: Set[str], soft: bool = True, ) -> List[Tuple[str, str, str]]: """Filters triplets that contain a stopword. 
@@ -69,6 +71,7 @@ def filter_triplets_with_stopwords( def load_triplets( file_path: str, + language: str = "danish", soft_filtering: bool = True, shuffle: bool = True, ) -> Tuple[list, list, list, list]: @@ -84,7 +87,6 @@ def load_triplets( objects: List of objects filtered_triplets: List of filtered triplets """ - triplets_list: List[Tuple[str, str, str]] = [] data = read_txt(file_path) triplets_list = [ triplet_from_line(line) @@ -93,7 +95,7 @@ def load_triplets( ] filtered_triplets = filter_triplets_with_stopwords( triplets_list, - get_stop_words("danish"), + set(get_stop_words(language)), soft=soft_filtering, ) @@ -290,8 +292,9 @@ def label_clusters( def embed_and_cluster( list_to_embed: List[str], - embedding_model: str = "vesteinn/DanskBERT", - n_dimensions: int = 40, + language: str, + embedding_model: str, + n_dimensions: int = None, n_neighbors: int = 15, min_cluster_size: int = 5, min_samples: int = 3, @@ -302,7 +305,10 @@ def embed_and_cluster( Args: list_to_embed: List of strings to embed and cluster - n_dimensions: Number of dimensions to reduce the embedding space to + language: language for SpaCy pipeline for cluster labeling + embedding_model: model name or path, refer to + https://www.sbert.net/docs/pretrained_models.html + n_dimensions: Number of dimensions to reduce the embedding space to, None to skip n_neighbors: Number of neighbors to use for UMAP min_cluster_size: Minimum cluster size for HDBscan min_samples: Minimum number of samples for HDBscan @@ -316,12 +322,16 @@ def embed_and_cluster( embedding_model = SentenceTransformer(embedding_model) - # Embed and reduce embdding space - print("Embedding and reducing embedding space") + print("Embedding") embeddings = embedding_model.encode(list_to_embed) # type: ignore scaled_embeddings = StandardScaler().fit_transform(embeddings) - reducer = UMAP(n_components=n_dimensions, n_neighbors=n_neighbors) - reduced_embeddings = reducer.fit_transform(scaled_embeddings) + + if n_dimensions is not None: + print("Reducing embedding space") + reducer = UMAP(n_components=n_dimensions, n_neighbors=n_neighbors) + reduced_embeddings = reducer.fit_transform(scaled_embeddings) + else: + reduced_embeddings = scaled_embeddings # Cluster with HDBscan print("Clustering") @@ -335,7 +345,8 @@ def embed_and_cluster( # Label and prune clusters print("Labeling clusters") - nlp = spacy.load("da_core_news_sm") + model = ModelChoice(da="da_core_news_sm", en="en_core_web_sm").get_model(language) + nlp = spacy.load(model) labeled_clusters = label_clusters( clusters, nlp, @@ -420,7 +431,8 @@ def create_nodes_and_edges( def main( path: str, - embedding_model: str, + language: str, + embedding_model: str = None, dim=40, n_neighbors=15, min_cluster_size=5, @@ -428,10 +440,19 @@ def main( min_topic_size=20, save: bool = False, ): + # figure out embedding model if not given explicitly + if embedding_model is None: + embedding_model = ModelChoice( + da="vesteinn/DanskBERT", + en="all-MiniLM-L6-v2", + fallback="sentence-transformers/paraphrase-multilingual-mpnet-base-v2", + ).get_model(language) + # Load triplets print("Loading triplets") subjects, predicates, objects, filtered_triplets = load_triplets( path, + language, soft_filtering=True, shuffle=True, ) @@ -443,12 +464,6 @@ def main( f"_clust={min_cluster_size}_samp={min_samples}_nodes_edges.json", ) # type: ignore - model = ( - "vesteinn/DanskBERT" - if embedding_model == "danskBERT" - else "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" - ) - print( f"Dimensions: {dim}, neighbors: 
{n_neighbors}, min cluster size: " f"{min_cluster_size}, samples: {min_samples}, min topic size: {min_topic_size}", @@ -458,7 +473,8 @@ def main( # For predicate, we wanna keep all clusters -> min_topic_size=1 predicate_clusters = embed_and_cluster( list_to_embed=predicates, - embedding_model=model, + language=language, + embedding_model=embedding_model, n_dimensions=dim, n_neighbors=n_neighbors, min_cluster_size=min_cluster_size, @@ -472,7 +488,8 @@ def main( subj_obj = subjects + objects subj_obj_clusters = embed_and_cluster( list_to_embed=subj_obj, - embedding_model=model, + language=language, + embedding_model=embedding_model, n_dimensions=dim, n_neighbors=n_neighbors, min_cluster_size=min_cluster_size, @@ -503,13 +520,21 @@ def main( help="Event to cluster. Must include name of source folder (newspapers or " "twitter) and event", ) + parser.add_argument( + "-lang", + "--language", + type=str, + default="paraphrase", + help="Choice of language for embedding model (if not specified) and stop " + "words filtering", + ) parser.add_argument( "-emb", "--embedding_model", type=str, - default="paraphrase", - help="""Which embedding model to use, default is paraphrase. - The other option is danskBERT""", + default=None, + help="Which embedding model to use. Automatically determined via language if " + "not given.", ) parser.add_argument( "-dim", @@ -552,6 +577,7 @@ def main( main( path, embedding_model=args.embedding_model, + language=args.language, dim=args.n_dimensions, n_neighbors=args.n_neighbors, min_cluster_size=args.min_cluster_size, diff --git a/src/conspiracies/docprocessing/coref/coref_component.py b/src/conspiracies/docprocessing/coref/coref_component.py index 0ef6106..c486913 100644 --- a/src/conspiracies/docprocessing/coref/coref_component.py +++ b/src/conspiracies/docprocessing/coref/coref_component.py @@ -9,6 +9,7 @@ from spacy.tokens import Doc, Span from spacy.util import minibatch +from conspiracies.common.modelchoice import ModelChoice from conspiracies.docprocessing.coref import CoreferenceModel @@ -18,16 +19,25 @@ def __init__( vocab: Vocab, name: str, model_path: Union[Path, str, None], + language: str, device: int, open_unverified_connection: bool, ): self.name = name self.vocab = vocab - self.model = CoreferenceModel( # type: ignore - model_path=model_path, - device=device, - open_unverified_connection=open_unverified_connection, - ) + if model_path is None: + self.model = ModelChoice( + da=lambda: CoreferenceModel.danish( + device=device, + open_unverified_connection=open_unverified_connection, + ), + en=lambda: CoreferenceModel.english(device=device), + ).get_model(language) + else: + self.model = CoreferenceModel( # type: ignore + model_path=model_path, + device=device, + ) # Register custom extension on the Doc and Span if not Doc.has_extension("resolve_coref"): @@ -190,8 +200,8 @@ def create_coref_component( Args: nlp (Language): A spacy language pipeline name (str): The name of the component - model_path (Union[Path, str, None]): Path to the model, if None, the - model will be downloaded to the default cache directory. + model_path(Union[Path, str, None], optional): Path to the model, if None, a + model will be downloaded according to the language of the pipeline. device (int, optional): Cuda device. If >= 0 will use the corresponding GPU, below 0 is CPU. Defaults to -1. 
open_unverified_connection (bool, optional): Should you download the model from @@ -200,11 +210,11 @@ def create_coref_component( Returns: CorefenceComponent: The coreference model component """ - return CoreferenceComponent( nlp.vocab, name=name, model_path=model_path, + language=nlp.lang, device=device, open_unverified_connection=open_unverified_connection, ) diff --git a/src/conspiracies/docprocessing/coref/coref_model.py b/src/conspiracies/docprocessing/coref/coref_model.py index 6240982..a8949b2 100644 --- a/src/conspiracies/docprocessing/coref/coref_model.py +++ b/src/conspiracies/docprocessing/coref/coref_model.py @@ -24,7 +24,10 @@ class CoreferenceModel(Predictor): Args: model_path(Union[Path, str, None], optional): Path to the model, if None, the - model will be downloaded to the default cache directory. + model will be downloaded according to default specifications of the language + parameter. + language: language of the (default) model to load. Has no effect if + model_path is given. Options: da, en device(int, optional): Cuda device. If >= 0 will use the corresponding GPU, below 0 is CPU. Defaults to -1. open_unverified_connection (bool, optional): Should you download from an @@ -38,16 +41,9 @@ def __init__( self, model_path: Union[Path, str, None] = None, device: int = -1, - open_unverified_connection: bool = False, - **kwargs, ) -> None: - if model_path is None: - model_path = download_model( - "da_coref_twitter_v1", - open_unverified_connection=open_unverified_connection, - ) - archive = load_archive(model_path, cuda_device=device, **kwargs) + archive = load_archive(model_path, cuda_device=device) config = archive.config prepare_environment(config) dataset_reader = archive.validation_dataset_reader @@ -62,3 +58,19 @@ def predict_batch_docs(self, docs: List[Doc]) -> List[Instance]: """Convert a list of docs to Instance and predict the batch.""" instances = [self._doc_to_instance(doc) for doc in docs] return self.predict_batch_instance(instances) + + @classmethod + def danish(cls, device: int = -1, open_unverified_connection: bool = False): + model_path = download_model( + "da_coref_twitter_v1", + open_unverified_connection=open_unverified_connection, + ) + return cls(model_path=model_path, device=device) + + @classmethod + def english(cls, device: int = -1): + model_path = ( + "https://storage.googleapis.com/allennlp-public-models/" + "coref-spanbert-large-2021.03.10.tar.gz" + ) + return cls(model_path=model_path, device=device) diff --git a/src/conspiracies/docprocessing/docprocessor.py b/src/conspiracies/docprocessing/docprocessor.py index 1862507..9bb1fa6 100644 --- a/src/conspiracies/docprocessing/docprocessor.py +++ b/src/conspiracies/docprocessing/docprocessor.py @@ -6,12 +6,13 @@ from tqdm import tqdm from conspiracies import docs_to_jsonl +from conspiracies.common.modelchoice import ModelChoice from conspiracies.document import Document, text_with_context, remove_context class DocProcessor: def _build_coref_pipeline(self): - nlp_coref = spacy.blank("da") + nlp_coref = spacy.blank(self.language) nlp_coref.add_pipe("sentencizer") nlp_coref.add_pipe("allennlp_coref") @@ -25,7 +26,10 @@ def warn_error(proc_name, proc, docs, e): return nlp_coref def _build_triplet_extraction_pipeline(self): - nlp = spacy.load("da_core_news_sm") + model = ModelChoice(da="da_core_news_sm", en="en_core_web_sm").get_model( + self.language, + ) + nlp = spacy.load(model) nlp.add_pipe( "heads_extraction", config={"normalize_to_entity": True, "normalize_to_noun_chunk": True}, @@ -67,9 +71,16 @@ 
def _build_triplet_extraction_pipeline(self): ) return nlp - def __init__(self, triplet_extraction="multi2oie"): + def __init__( + self, + language="da", + batch_size=25, + triplet_extraction_method="multi2oie", + ): + self.language = language + self.batch_size = batch_size self.coref_pipeline = self._build_coref_pipeline() - self.triplet_extraction_component = triplet_extraction + self.triplet_extraction_component = triplet_extraction_method self.triplet_extraction_pipeline = self._build_triplet_extraction_pipeline() def process_docs( @@ -77,7 +88,6 @@ def process_docs( docs: Iterable[Document], output_path: str, continue_from_last=False, - batch_size=25, ): if continue_from_last and os.path.exists(output_path): with jsonlines.open(output_path) as annotated_docs: @@ -91,7 +101,7 @@ def process_docs( # extreme memory pressure, hence the small batch size coref_resolved_docs = self.coref_pipeline.pipe( ((text_with_context(doc), doc["id"]) for doc in docs), - batch_size=batch_size, + batch_size=self.batch_size, as_tuples=True, ) @@ -100,7 +110,7 @@ def process_docs( (remove_context(doc._.resolve_coref), id_) for doc, id_ in coref_resolved_docs ), - batch_size=batch_size * 4, + batch_size=self.batch_size, as_tuples=True, ) diff --git a/src/conspiracies/pipeline/config.py b/src/conspiracies/pipeline/config.py index def888f..41fdd09 100644 --- a/src/conspiracies/pipeline/config.py +++ b/src/conspiracies/pipeline/config.py @@ -6,7 +6,7 @@ class BaseConfig(BaseModel): project_name: str - output_root: str + output_root: str = "output" language: str @@ -16,17 +16,39 @@ class StepConfig(BaseModel): class PreProcessingConfig(StepConfig): input_path: str - doc_type: str - metadata_fields: set[str] + n_docs: int = None + doc_type: str = "text" + metadata_fields: set[str] = ["*"] extra: dict = {} class DocProcessingConfig(StepConfig): - triplet_extraction_method: str + batch_size: int = 25 + continue_from_last: bool = True + triplet_extraction_method: str = "multi2oie" + + +class ClusteringThresholds(BaseModel): + min_cluster_size: int + min_samples: int + min_topic_size: int + + @classmethod + def estimate_from_n_triplets(cls, n_triplets: int): + factor = n_triplets / 1000 + thresholds = cls( + min_cluster_size=int(factor + 1), + min_samples=int(factor + 1), + min_topic_size=int(factor * 2 + 1), + ) + return thresholds class CorpusProcessingConfig(StepConfig): - pass + dimensions: int = None + n_neighbors: int = 15 + embedding_model: str = None + thresholds: ClusteringThresholds = None class PipelineConfig(BaseModel): @@ -53,13 +75,28 @@ def update_nested_dict(d: dict[str, Any], path: str, value: Any) -> None: d = d.setdefault(key, {}) d[keys[-1]] = value + @staticmethod + def full_update_nested_dict(d: dict[str, Any], values: dict[str, Any]): + for path, value in values.items(): + if value is None: + continue + PipelineConfig.update_nested_dict(d, path, value) + @classmethod def from_toml_file(cls, path: str, extra_config: dict = None): with open(path, "r") as file: config_data = toml.load(file) if extra_config: - for path, value in extra_config.items(): - if value is None: - continue - PipelineConfig.update_nested_dict(config_data, path, value) + PipelineConfig.full_update_nested_dict(config_data, extra_config) + return cls(**config_data) + + @classmethod + def default_with_extra_config(cls, extra_config: dict): + config_data = { + "base": {}, + "preprocessing": {}, + "docprocessing": {}, + "corpusprocessing": {}, + } + PipelineConfig.full_update_nested_dict(config_data, extra_config) return 
cls(**config_data) diff --git a/src/conspiracies/pipeline/pipeline.py b/src/conspiracies/pipeline/pipeline.py index 8bc6f6c..86536a2 100644 --- a/src/conspiracies/pipeline/pipeline.py +++ b/src/conspiracies/pipeline/pipeline.py @@ -5,7 +5,7 @@ from conspiracies.corpusprocessing import umap_hdb from conspiracies.docprocessing.docprocessor import DocProcessor from conspiracies.document import Document -from conspiracies.pipeline.config import PipelineConfig +from conspiracies.pipeline.config import PipelineConfig, ClusteringThresholds from conspiracies.preprocessing.csv import CsvPreprocessor from conspiracies.preprocessing.infomedia import InfoMediaPreprocessor from conspiracies.preprocessing.preprocessor import Preprocessor @@ -30,7 +30,9 @@ def run(self): self.preprocessing() if self.config.docprocessing.enabled: - self.docprocessing(continue_from_last=True) + self.docprocessing( + continue_from_last=self.config.docprocessing.continue_from_last, + ) if self.config.corpusprocessing.enabled: self.corpusprocessing() @@ -63,11 +65,14 @@ def preprocessing(self) -> None: preprocessor.preprocess_docs( self.input_path, f"{self.output_path}/preprocessed.ndjson", + n_docs=self.config.preprocessing.n_docs, ) def _get_docprocessor(self) -> DocProcessor: return DocProcessor( - triplet_extraction=self.config.docprocessing.triplet_extraction_method, + language=self.config.base.language, + batch_size=self.config.docprocessing.batch_size, + triplet_extraction_method=self.config.docprocessing.triplet_extraction_method, ) def docprocessing(self, continue_from_last=False): @@ -87,7 +92,7 @@ def corpusprocessing(self): # TODO: this process is kind of dumb with all the writing and reading of # files etc., but for now just make it work. It comes from individual scripts # and a lot of the logic of data structures happen in those read/writes. Also, - # a lot of data that we might want to save is thrown away, e.g. clsutered + # a lot of data that we might want to save is thrown away, e.g. clustered # entities, or calculated/fetched on the go, e.g. node weights for graphs. 
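The automatic threshold estimation used in the corpus processing step below comes from ClusteringThresholds.estimate_from_n_triplets in src/conspiracies/pipeline/config.py (factor = n_triplets / 1000). A minimal sketch of how the thresholds scale with corpus size; the triplet counts here are made up for illustration:

```python
# Illustration only: how the automatic clustering thresholds scale with corpus size.
# Mirrors ClusteringThresholds.estimate_from_n_triplets added in
# src/conspiracies/pipeline/config.py (factor = n_triplets / 1000).
from conspiracies.pipeline.config import ClusteringThresholds

for n_triplets in (500, 2500, 10_000):  # hypothetical corpus sizes
    t = ClusteringThresholds.estimate_from_n_triplets(n_triplets)
    # e.g. 2500 triplets -> factor 2.5 -> min_cluster_size=3, min_samples=3, min_topic_size=6
    print(n_triplets, t.min_cluster_size, t.min_samples, t.min_topic_size)
```

Each threshold grows linearly with the number of triplets, so small corpora keep permissive HDBSCAN settings while larger ones get stricter ones.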
docs = ( json.loads(line) @@ -95,6 +100,7 @@ def corpusprocessing(self): f"{self.output_path}/annotations.ndjson", ) ) + n_triplets = 0 with open( f"{self.output_path}/triplets.csv", "w+", @@ -106,11 +112,22 @@ def corpusprocessing(self): for field in ("subject", "predicate", "object") ] print(", ".join(triplet_fields), file=out) + n_triplets += 1 + + if self.config.corpusprocessing.thresholds is None: + thresholds = ClusteringThresholds.estimate_from_n_triplets(n_triplets) + else: + thresholds = self.config.corpusprocessing.thresholds + umap_hdb.main( f"{self.output_path}/triplets.csv", - "danskBERT", - dim=40, + self.config.base.language, + dim=self.config.corpusprocessing.dimensions, + n_neighbors=self.config.corpusprocessing.n_neighbors, save=f"{self.output_path}/nodes_edges.json", + min_cluster_size=thresholds.min_cluster_size, + min_topic_size=thresholds.min_topic_size, + min_samples=thresholds.min_samples, ) nodes, edges = get_nodes_edges( f"{self.output_path}/", diff --git a/src/conspiracies/preprocessing/preprocessor.py b/src/conspiracies/preprocessing/preprocessor.py index 86221c2..1a848c4 100644 --- a/src/conspiracies/preprocessing/preprocessor.py +++ b/src/conspiracies/preprocessing/preprocessor.py @@ -33,8 +33,8 @@ def _filter_metadata( del metadata[key] yield doc + @staticmethod def _validate_content( - self, preprocessed_docs: Iterator[Document], ) -> Iterator[Document]: for doc in preprocessed_docs: @@ -47,9 +47,11 @@ def _validate_content( else: yield doc - def preprocess_docs(self, input_path: str, output_path: str): + def preprocess_docs(self, input_path: str, output_path: str, n_docs: int = None): preprocessed_docs = self._do_preprocess_docs(input_path) validated = self._validate_content(preprocessed_docs) + if n_docs and n_docs > 0: + validated = (d for i, d in enumerate(validated) if i < n_docs) metadata_filtered = self._filter_metadata(validated) with open(output_path, "w+") as out_file: ndjson.dump(metadata_filtered, out_file) diff --git a/run.py b/src/conspiracies/run.py similarity index 58% rename from run.py rename to src/conspiracies/run.py index 4baf017..d859b7e 100644 --- a/run.py +++ b/src/conspiracies/run.py @@ -1,5 +1,5 @@ import argparse - +import logging from conspiracies.pipeline.config import PipelineConfig from conspiracies.pipeline.pipeline import Pipeline @@ -21,19 +21,44 @@ help="Input path for preprocessing of documents. Can be a glob path, e.g." "path/to/files/*.txt. Be mindful of quotes for glob paths.", ) + arg_parser.add_argument( + "--language", + "-l", + default=None, + help="Language of models and word lists.", + ) + arg_parser.add_argument( + "--n_docs", + "-n", + default=None, + help="Max number of documents to output from preprocessing.", + ) arg_parser.add_argument( "-c", "--config", - default="config/default.toml", + default=None, help="Path to configuration file. 
Refer to config/template.toml for contents.", ) + arg_parser.add_argument( + "--root-log-level", + default="WARN", + help="Level of root logger.", + ) args = arg_parser.parse_args() + logging.getLogger().setLevel(args.root_log_level) + cli_args = { "base.project_name": args.project_name, + "base.language": args.language, "preprocessing.input_path": args.input_path, + "preprocessing.n_docs": args.n_docs, } - config = PipelineConfig.from_toml_file(args.config, cli_args) + + if args.config: + config = PipelineConfig.from_toml_file(args.config, cli_args) + else: + config = PipelineConfig.default_with_extra_config(cli_args) pipeline = Pipeline(config) pipeline.run() diff --git a/tests/test_coref_model.py b/tests/test_coref_model.py index 6424b3b..8fcbf9c 100644 --- a/tests/test_coref_model.py +++ b/tests/test_coref_model.py @@ -1,10 +1,10 @@ -from .utils import nlp_da # noqa +from .utils import nlp_da, nlp_en # noqa from conspiracies.docprocessing.coref import CoreferenceModel -def test_CoreferenceModel(nlp_da): # noqa - model = CoreferenceModel() # check that the model loads as intended +def test_da_coref_model(nlp_da): # noqa + model = CoreferenceModel.danish() # check that the model loads as intended text = [ "Hej Kenneth, har du en fed teksts vi kan skrive om dig?", @@ -18,3 +18,20 @@ def test_CoreferenceModel(nlp_da): # noqa # test output format for output in outputs: assert isinstance(output, dict) + + +def test_en_coref_model(nlp_en): # noqa + model = CoreferenceModel.english() + + text = [ + "Luke Skywalker is from a moisture farm on Tatooine. " + "He flees when the Empire comes to find C-3PO and R2-D2.", + ] + docs = nlp_en.pipe(text) + + # test batches forward + outputs = model.predict_batch_docs(docs) + + # test output format + for output in outputs: + assert isinstance(output, dict) diff --git a/tests/test_pipelineconfig.py b/tests/test_pipelineconfig.py index d96bbe6..21b1d61 100644 --- a/tests/test_pipelineconfig.py +++ b/tests/test_pipelineconfig.py @@ -39,3 +39,9 @@ def test_config_loading(path: str): ), corpusprocessing=CorpusProcessingConfig(), ) + + +def test_update_nested_dict(): + some_dict = {"a": {"b": 1}} + PipelineConfig.update_nested_dict(some_dict, "a.b", 2) + assert some_dict["a"]["b"] == 2
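To see how the new CLI arguments and config defaults interact, here is an illustrative sketch of the configuration flow introduced above (the project name and input glob are hypothetical; it assumes the package is importable as conspiracies):

```python
# Sketch of the new config flow: dotted CLI-style keys are merged into nested
# config sections; None values are skipped by full_update_nested_dict, so
# unspecified CLI arguments fall back to the defaults in config.py.
from conspiracies.pipeline.config import PipelineConfig

cli_args = {
    "base.project_name": "my_project",         # hypothetical
    "base.language": "en",
    "preprocessing.input_path": "data/*.txt",  # hypothetical glob
    "preprocessing.n_docs": None,              # not given on the CLI -> ignored
}

config = PipelineConfig.default_with_extra_config(cli_args)
assert config.base.output_root == "output"
assert config.docprocessing.batch_size == 25
assert config.docprocessing.triplet_extraction_method == "multi2oie"
```

When a TOML file is supplied, PipelineConfig.from_toml_file(path, cli_args) applies the same merge on top of the file's contents, so CLI arguments still take precedence over values left unset.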