From 25aa99df9ea988d7ff148bf5c293d213f893d6a9 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn <42898679+KasperFyhn@users.noreply.github.com> Date: Wed, 12 Jun 2024 09:00:56 +0200 Subject: [PATCH] Kf/timestamped triplets (#73) * Passing on timestamp of documents to triplet output * Also outputting alternative labels for graph nodes and edges * Showing alternative labels of nodes on selection * Filtering on dates now possible * A better view of node information * Map lookup of nodes from graph service in visualizer --- config/eschatology.toml | 1 + src/conspiracies/common/fileutils.py | 3 +- .../corpusprocessing/aggregation.py | 123 ++++++++++++++---- .../corpusprocessing/clustering.py | 22 +++- src/conspiracies/corpusprocessing/triplet.py | 13 +- src/conspiracies/docprocessing/doc_utils.py | 21 ++- .../docprocessing/docprocessor.py | 8 +- src/conspiracies/document.py | 13 +- src/conspiracies/pipeline/config.py | 2 +- src/conspiracies/pipeline/pipeline.py | 2 +- src/conspiracies/preprocessing/csv.py | 4 + src/conspiracies/preprocessing/infomedia.py | 9 +- .../preprocessing/preprocessor.py | 11 +- src/conspiracies/preprocessing/text.py | 1 + src/conspiracies/preprocessing/tweets.py | 9 +- tests/test_preprocessing.py | 5 +- .../src/graph/GraphFilterControlPanel.tsx | 57 ++++---- .../{GraphService.tsx => GraphService.ts} | 46 ++++++- visualizer/src/graph/GraphViewer.tsx | 24 +++- visualizer/src/graph/NodeInfo.tsx | 34 +++++ visualizer/src/graph/graph.css | 14 +- 21 files changed, 335 insertions(+), 87 deletions(-) rename visualizer/src/graph/{GraphService.tsx => GraphService.ts} (71%) create mode 100644 visualizer/src/graph/NodeInfo.tsx diff --git a/config/eschatology.toml b/config/eschatology.toml index db913e6..7aa4c7b 100644 --- a/config/eschatology.toml +++ b/config/eschatology.toml @@ -8,6 +8,7 @@ doc_type = "csv" [preprocessing.extra] id_column = "id" text_column = "body" +timestamp_column = "timestamp" [docprocessing] enabled = true diff --git a/src/conspiracies/common/fileutils.py b/src/conspiracies/common/fileutils.py index e235b54..1661505 100644 --- a/src/conspiracies/common/fileutils.py +++ b/src/conspiracies/common/fileutils.py @@ -16,4 +16,5 @@ def iter_lines_of_files(glob_pattern: Union[str, Path]): for file in files: with open(file) as f: for line in f: - yield line + if line.strip(): + yield line diff --git a/src/conspiracies/corpusprocessing/aggregation.py b/src/conspiracies/corpusprocessing/aggregation.py index b529cad..0fe055c 100644 --- a/src/conspiracies/corpusprocessing/aggregation.py +++ b/src/conspiracies/corpusprocessing/aggregation.py @@ -1,10 +1,22 @@ -from collections import Counter -from typing import List, TypedDict, Dict, Union, Callable, Iterable, Tuple +from collections import Counter, defaultdict +from datetime import datetime +from typing import ( + List, + TypedDict, + Dict, + Union, + Callable, + Iterable, + Tuple, + Optional, + Any, + Mapping, +) from pydantic import BaseModel from conspiracies.corpusprocessing.clustering import Mappings -from conspiracies.corpusprocessing.triplet import Triplet +from conspiracies.corpusprocessing.triplet import Triplet, TripletField def min_max_normalizer(values: Iterable[Union[int, float]]) -> Callable[[float], float]: @@ -21,20 +33,64 @@ class StatsEntry(TypedDict): key: Union[str, Tuple[str]] frequency: int norm_frequency: float + docs: Optional[list[str]] + first_occurrence: Optional[datetime] + last_occurrence: Optional[datetime] + alt_labels: Optional[list[str]] class StatsDict(Dict[str, StatsEntry]): pass @classmethod - 
def from_counter(cls, counter: Counter): + def from_iterable( + cls, + generator: Iterable[Tuple[Any, Optional[str], Optional[datetime]]], + alt_labels: Mapping[str, list[str]] = None, + ): + counter = Counter() + docs = defaultdict(set) + first_occurrence = {} + last_occurrence = {} + for key, doc_id, timestamp in generator: + counter[key] += 1 + + if doc_id: + docs[key].add(doc_id) + + if timestamp: + if key in first_occurrence: + first_occurrence[key] = min(timestamp, first_occurrence[key]) + else: + first_occurrence[key] = timestamp + + if key in last_occurrence: + last_occurrence[key] = max(timestamp, last_occurrence[key]) + else: + last_occurrence[key] = timestamp + normalizer = min_max_normalizer(counter.values()) + return cls( { key: StatsEntry( key=key, frequency=value, norm_frequency=normalizer(value), + docs=list(docs[key]) if key in docs else None, + first_occurrence=( + first_occurrence[key].isoformat() + if key in first_occurrence + else None + ), + last_occurrence=( + last_occurrence[key].isoformat() + if key in last_occurrence + else None + ), + alt_labels=( + alt_labels[key] if alt_labels and key in alt_labels else None + ), ) for key, value in counter.items() }, @@ -63,27 +119,46 @@ def aggregate( triplets: List[Triplet], remove_identical_subj_and_obj: bool = True, ): - mapped_triplets = [ - ( - self._mappings.map_entity(triplet.subject.text), - self._mappings.map_predicate(triplet.predicate.text), - self._mappings.map_entity(triplet.object.text), - ) - for triplet in triplets - ] + if self._mappings is not None: + triplets = [ + Triplet( + subject=TripletField( + text=self._mappings.map_entity(t.subject.text), + ), + predicate=TripletField( + text=self._mappings.map_predicate(t.predicate.text), + ), + object=TripletField(text=self._mappings.map_entity(t.object.text)), + doc=t.doc, + timestamp=t.timestamp, + ) + for t in triplets + ] + if remove_identical_subj_and_obj: - mapped_triplets = [t for t in mapped_triplets if t[0] != t[2]] - triplet_counts = Counter(triplet for triplet in mapped_triplets) - entity_counts = Counter( - entity for triplet in mapped_triplets for entity in (triplet[0], triplet[2]) - ) - # entity_distinct_triplets = Counter( - # entity for triplet in set(mapped_triplets) for entity in (triplet[0], triplet[2]) - # ) + triplets = [t for t in triplets if t.subject.text != t.object.text] - predicate_counts = Counter(triplet[1] for triplet in mapped_triplets) return TripletStats( - triplets=StatsDict.from_counter(triplet_counts), - entities=StatsDict.from_counter(entity_counts), - predicates=StatsDict.from_counter(predicate_counts), + triplets=StatsDict.from_iterable( + ( + ( + (t.subject.text, t.predicate.text, t.object.text), + t.doc, + t.timestamp, + ) + for t in triplets + ) + ), + entities=StatsDict.from_iterable( + ( + (entity, t.doc, t.timestamp) + for t in triplets + for entity in [t.subject.text, t.object.text] + ), + self._mappings.entity_alt_labels(), + ), + predicates=StatsDict.from_iterable( + ((t.predicate.text, t.doc, t.timestamp) for t in triplets), + self._mappings.predicate_alt_labels(), + ), ) diff --git a/src/conspiracies/corpusprocessing/clustering.py b/src/conspiracies/corpusprocessing/clustering.py index ea666da..cd7fbcf 100644 --- a/src/conspiracies/corpusprocessing/clustering.py +++ b/src/conspiracies/corpusprocessing/clustering.py @@ -20,9 +20,21 @@ class Mappings(BaseModel): def map_entity(self, entity: str): return self.entities[entity] if entity in self.entities else entity + def entity_alt_labels(self): + alt_labels = 
defaultdict(list) + for entity, label in self.entities.items(): + alt_labels[label].append(entity) + return alt_labels + def map_predicate(self, predicate: str): return self.predicates[predicate] if predicate in self.predicates else predicate + def predicate_alt_labels(self): + alt_labels = defaultdict(list) + for predicate, label in self.predicates.items(): + alt_labels[label].append(predicate) + return alt_labels + class Clustering: def __init__( @@ -120,10 +132,12 @@ def _cluster( list(clusters.values()), get_combine_key=lambda t: t[0].text, ) - merged = self._combine_clusters( - merged, - get_combine_key=lambda t: t[0].head, - ) + + # too risky with false positives from this + # merged = self._combine_clusters( + # merged, + # get_combine_key=lambda t: t[0].head, + # ) # sort by how "prototypical" a member is in the cluster for cluster in merged: diff --git a/src/conspiracies/corpusprocessing/triplet.py b/src/conspiracies/corpusprocessing/triplet.py index d31bf2a..5d87b17 100644 --- a/src/conspiracies/corpusprocessing/triplet.py +++ b/src/conspiracies/corpusprocessing/triplet.py @@ -1,8 +1,8 @@ import json +from datetime import datetime from pathlib import Path from typing import Optional, Set, Iterator, Iterable, List, Union -from jsonlines import jsonlines from pydantic import BaseModel from stop_words import get_stop_words @@ -24,6 +24,7 @@ class Triplet(BaseModel): predicate: TripletField object: TripletField doc: Optional[str] + timestamp: Optional[datetime] def fields(self): return self.subject, self.predicate, self.object @@ -54,12 +55,16 @@ def filter_on_stopwords( @classmethod def from_annotated_docs(cls, path: Path) -> Iterator["Triplet"]: return ( - cls(**triplet_data, doc=json_data["id"]) + cls( + **triplet_data, + doc=json_data.get("id", None), + timestamp=json_data.get("timestamp", None), + ) for json_data in (json.loads(line) for line in iter_lines_of_files(path)) for triplet_data in json_data["semantic_triplets"] ) @staticmethod def write_jsonl(path: Union[str, Path], triplets: Iterable["Triplet"]): - with jsonlines.open(path, "w") as out: - out.write_all(map(lambda triplet: triplet.dict(), triplets)) + with open(path, "w") as out: + print(*(t.json() for t in triplets), file=out, sep="\n") diff --git a/src/conspiracies/docprocessing/doc_utils.py b/src/conspiracies/docprocessing/doc_utils.py index 3447616..48e1733 100644 --- a/src/conspiracies/docprocessing/doc_utils.py +++ b/src/conspiracies/docprocessing/doc_utils.py @@ -12,13 +12,26 @@ DocTriplets, SpanTriplet, ) +from conspiracies.document import Document -def _doc_to_json(doc: Union[Doc, Tuple[Doc, str]], include_span_heads=True): +def _doc_to_json( + doc: Union[Doc, Tuple[Doc, Union[str, Document]]], + include_span_heads=True, +): if isinstance(doc, Tuple): - doc, id_ = doc + if isinstance(doc[1], str): + doc, id_ = doc + timestamp = None + elif isinstance(doc[1], Document): + doc, src_doc = doc + id_ = src_doc.id + timestamp = src_doc.timestamp.isoformat() + else: + raise TypeError(f"Unexpected input type {type(doc[1])}") else: id_ = None + timestamp = None if Doc.has_extension("relation_triplets"): triplets = doc._.relation_triplets else: @@ -26,6 +39,8 @@ def _doc_to_json(doc: Union[Doc, Tuple[Doc, str]], include_span_heads=True): json = doc.to_json() if id_ is not None: json["id"] = id_ + if timestamp is not None: + json["timestamp"] = timestamp json["semantic_triplets"] = [ triplet.to_dict(include_doc=False, include_span_heads=include_span_heads) for triplet in triplets @@ -46,7 +61,7 @@ def 
_doc_from_json(json: dict, nlp: Language) -> Doc: def docs_to_jsonl( - docs: Iterable[Union[Doc, Tuple[Doc, str]]], + docs: Iterable[Union[Doc, Tuple[Doc, Union[str, Document]]]], path: Union[Path, str], append=False, include_span_heads=True, diff --git a/src/conspiracies/docprocessing/docprocessor.py b/src/conspiracies/docprocessing/docprocessor.py index e3cf0c9..1fa8e07 100644 --- a/src/conspiracies/docprocessing/docprocessor.py +++ b/src/conspiracies/docprocessing/docprocessor.py @@ -106,20 +106,20 @@ def process_docs( annotated_doc["id"] for annotated_doc in annotated_docs } print(f"Skipping {len(already_processed)} processed docs.") - docs = (doc for doc in docs if doc["id"] not in already_processed) + docs = (doc for doc in docs if doc.id not in already_processed) # The coreference pipeline tends to choke on too large batches because of an # extreme memory pressure, hence the small batch size coref_resolved_docs = self.coref_pipeline.pipe( - ((text_with_context(doc), doc["id"]) for doc in docs), + ((text_with_context(src_doc), src_doc) for src_doc in docs), batch_size=self.batch_size, as_tuples=True, ) with_triplets = self.triplet_extraction_pipeline.pipe( ( - (remove_context(doc._.resolve_coref), id_) - for doc, id_ in coref_resolved_docs + (remove_context(doc._.resolve_coref), src_doc) + for doc, src_doc in coref_resolved_docs ), batch_size=self.batch_size, as_tuples=True, diff --git a/src/conspiracies/document.py b/src/conspiracies/document.py index ddc2825..1f461c9 100644 --- a/src/conspiracies/document.py +++ b/src/conspiracies/document.py @@ -1,20 +1,23 @@ -from typing import TypedDict, Optional +from datetime import datetime +from typing import Optional +from pydantic import BaseModel -class Document(TypedDict): +class Document(BaseModel): id: str metadata: dict text: str context: Optional[str] + timestamp: Optional[datetime] CONTEXT_END_MARKER = "[CONTEXT_END]" def text_with_context(doc: Document) -> str: - if doc["context"] is None: - return doc["text"] - return "\n".join([doc["context"], CONTEXT_END_MARKER, doc["text"]]) + if doc.context is None: + return doc.text + return "\n".join([doc.context, CONTEXT_END_MARKER, doc.text]) def remove_context(text: str) -> str: diff --git a/src/conspiracies/pipeline/config.py b/src/conspiracies/pipeline/config.py index 7c3bcfc..6458eb8 100644 --- a/src/conspiracies/pipeline/config.py +++ b/src/conspiracies/pipeline/config.py @@ -37,7 +37,7 @@ class ClusteringThresholds(BaseModel): def estimate_from_n_triplets(cls, n_triplets: int): factor = n_triplets / 1000 thresholds = cls( - min_cluster_size=int(factor + 1), + min_cluster_size=max(int(factor + 1), 2), min_samples=int(factor + 1), ) return thresholds diff --git a/src/conspiracies/pipeline/pipeline.py b/src/conspiracies/pipeline/pipeline.py index 968e686..0da29cf 100644 --- a/src/conspiracies/pipeline/pipeline.py +++ b/src/conspiracies/pipeline/pipeline.py @@ -87,7 +87,7 @@ def docprocessing(self, continue_from_last=False): docprocessor = self._get_docprocessor() docprocessor.process_docs( ( - json.loads(line, object_hook=lambda d: Document(**d)) + Document(**json.loads(line)) for line in iter_lines_of_files( self.output_path / "preprocessed.ndjson", ) diff --git a/src/conspiracies/preprocessing/csv.py b/src/conspiracies/preprocessing/csv.py index 0e53fc9..5bc59cd 100644 --- a/src/conspiracies/preprocessing/csv.py +++ b/src/conspiracies/preprocessing/csv.py @@ -14,12 +14,14 @@ def __init__( id_column: str = None, text_column: str = None, context_column: str = None, + timestamp_column: 
str = None, delimiter=",", metadata_fields: Iterable[str] = ("*",), ): self.id_column = id_column self.text_column = text_column self.context_column = context_column + self.timestamp_column = timestamp_column self.non_metadata_columns = { self.id_column, self.text_column, @@ -34,6 +36,7 @@ def _read_lines(self, lines: Iterable[str]) -> Iterator[str]: id_ = row[self.id_column] text = row[self.text_column] context = row[self.context_column] if self.context_column else None + timestamp = row[self.timestamp_column] if self.timestamp_column else None metadata = { k: v for k, v in row.items() if k not in self.non_metadata_columns } @@ -42,6 +45,7 @@ def _read_lines(self, lines: Iterable[str]) -> Iterator[str]: text=text, context=context, metadata=metadata, + timestamp=timestamp, ) def _do_preprocess_docs(self, input_path: Union[str, Path]) -> Iterator[str]: diff --git a/src/conspiracies/preprocessing/infomedia.py b/src/conspiracies/preprocessing/infomedia.py index c9b91aa..8a07414 100644 --- a/src/conspiracies/preprocessing/infomedia.py +++ b/src/conspiracies/preprocessing/infomedia.py @@ -32,7 +32,14 @@ def process_line(self, line: str): metadata = {k: obj[k] for k in InfoMediaPreprocessor.METADATA_KEYS} text = self.create_text(obj) - return Document(id=doc_id, metadata=metadata, text=text, context=None) + # TODO: get timestamp from metadata + return Document( + id=doc_id, + metadata=metadata, + text=text, + context=None, + timestamp=None, + ) @staticmethod def create_text(doc_obj: dict): diff --git a/src/conspiracies/preprocessing/preprocessor.py b/src/conspiracies/preprocessing/preprocessor.py index 861b8ef..e23f30c 100644 --- a/src/conspiracies/preprocessing/preprocessor.py +++ b/src/conspiracies/preprocessing/preprocessor.py @@ -2,8 +2,6 @@ from pathlib import Path from typing import Iterator, Iterable -import ndjson - from conspiracies.document import Document @@ -28,7 +26,7 @@ def _filter_metadata( yield doc for doc in preprocessed_docs: - metadata = doc["metadata"] + metadata = doc.metadata for key in list(metadata.keys()): if key not in self.metadata_fields: del metadata[key] @@ -39,10 +37,10 @@ def _validate_content( preprocessed_docs: Iterator[Document], ) -> Iterator[Document]: for doc in preprocessed_docs: - if not doc["text"]: + if not doc.text: logging.warning( "Skipping doc with id '%s' because of empty text field", - doc["id"], + doc.id, ) continue else: @@ -54,5 +52,6 @@ def preprocess_docs(self, input_path: Path, output_path: Path, n_docs: int = Non if n_docs and n_docs > 0: validated = (d for i, d in enumerate(validated) if i < n_docs) metadata_filtered = self._filter_metadata(validated) + with output_path.open("w+") as out_file: - ndjson.dump(metadata_filtered, out_file) + print(*(d.json() for d in metadata_filtered), file=out_file, sep="\n") diff --git a/src/conspiracies/preprocessing/text.py b/src/conspiracies/preprocessing/text.py index 426e8aa..771edd9 100644 --- a/src/conspiracies/preprocessing/text.py +++ b/src/conspiracies/preprocessing/text.py @@ -22,4 +22,5 @@ def _do_preprocess_docs(self, glob_pattern: str): text=text, metadata={}, context=None, + timestamp=None, ) diff --git a/src/conspiracies/preprocessing/tweets.py b/src/conspiracies/preprocessing/tweets.py index 7644e31..7afc710 100644 --- a/src/conspiracies/preprocessing/tweets.py +++ b/src/conspiracies/preprocessing/tweets.py @@ -44,4 +44,11 @@ def _do_preprocess_docs(self, glob_pattern: Union[str, Path]) -> Iterable[str]: metadata = {k: v for k, v in tweet.items() if k != "text"} text = tweet["text"] context 
= "\n".join(t["text"] for t in context_tweets) - yield Document(id=doc_id, metadata=metadata, text=text, context=context) + # TODO: get timestamp from metadata + yield Document( + id=doc_id, + metadata=metadata, + text=text, + context=context, + timestamp=None, + ) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 9a3d705..a44d2db 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -20,6 +20,7 @@ def docs() -> Iterator[Document]: "two": "something else", }, context=None, + timestamp=None, ), ) return (d for d in docs) @@ -30,7 +31,7 @@ def test_metadata_filtering_removes_field(docs): filtered = list(preprocessor._filter_metadata(docs)) assert len(filtered) == 2 for doc in filtered: - assert "one" in doc["metadata"] and "two" not in doc["metadata"] + assert "one" in doc.metadata and "two" not in doc.metadata def test_metadata_filtering_retains_all(docs): @@ -38,4 +39,4 @@ def test_metadata_filtering_retains_all(docs): filtered = list(preprocessor._filter_metadata(docs)) assert len(filtered) == 2 for doc in filtered: - assert all(key in doc["metadata"] for key in ("one", "two")) + assert all(key in doc.metadata for key in ("one", "two")) diff --git a/visualizer/src/graph/GraphFilterControlPanel.tsx b/visualizer/src/graph/GraphFilterControlPanel.tsx index c975154..12210a1 100644 --- a/visualizer/src/graph/GraphFilterControlPanel.tsx +++ b/visualizer/src/graph/GraphFilterControlPanel.tsx @@ -12,43 +12,56 @@ export const GraphFilterControlPanel = ({graphFilter, setGraphFilter}: GraphFilt return
+      <p>Frequency: {stats.frequency}</p>
+      <p>Norm. frequency: {stats.norm_frequency?.toPrecision(3)}</p>
+      {stats.first_occurrence && <p>Earliest date: {stats.first_occurrence}</p>}
+      {stats.last_occurrence && <p>Latest date: {stats.last_occurrence}</p>}
+      {stats.alt_labels &&
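
A minimal usage sketch of the new `StatsDict.from_iterable` added in this patch, showing how doc ids and timestamps end up in the aggregated stats. The keys, document ids, dates, and alternative labels below are invented for illustration only.

```python
from datetime import datetime

from conspiracies.corpusprocessing.aggregation import StatsDict

# (key, doc_id, timestamp) observations, e.g. entities extracted from triplets.
# All values here are made up for the sake of the example.
observations = [
    ("the government", "doc-1", datetime(2024, 6, 1)),
    ("the government", "doc-2", datetime(2024, 6, 10)),
    ("the truth", "doc-1", datetime(2024, 6, 1)),
]

stats = StatsDict.from_iterable(
    observations,
    alt_labels={"the government": ["government", "the Government"]},
)

entry = stats["the government"]
print(entry["frequency"])         # 2
print(entry["first_occurrence"])  # "2024-06-01T00:00:00"
print(entry["last_occurrence"])   # "2024-06-10T00:00:00"
print(entry["alt_labels"])        # ["government", "the Government"]
```

The earliest/latest occurrence strings are what the visualizer's date filtering and the node info panel read back from the graph data.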