Kf/timestamped triplets (#73)
* Passing on timestamp of documents to triplet output

* Also outputting alternative labels for graph nodes and edges

* Showing alternative labels of nodes on selection

* Filtering on dates now possible

* A better view of node information

* Map lookup of nodes from graph service in visualizer
KasperFyhn authored Jun 12, 2024
1 parent d9d5188 commit 25aa99d
Showing 21 changed files with 335 additions and 87 deletions.
1 change: 1 addition & 0 deletions config/eschatology.toml
@@ -8,6 +8,7 @@ doc_type = "csv"
[preprocessing.extra]
id_column = "id"
text_column = "body"
timestamp_column = "timestamp"

[docprocessing]
enabled = true
3 changes: 2 additions & 1 deletion src/conspiracies/common/fileutils.py
@@ -16,4 +16,5 @@ def iter_lines_of_files(glob_pattern: Union[str, Path]):
for file in files:
with open(file) as f:
for line in f:
yield line
if line.strip():
yield line
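
With the added guard, iter_lines_of_files now skips empty lines instead of yielding them. A minimal usage sketch (glob pattern and file contents are hypothetical):

    from conspiracies.common.fileutils import iter_lines_of_files

    # Blank lines in the matched files are no longer yielded, so downstream
    # json.loads calls will not choke on empty strings.
    for line in iter_lines_of_files("output/*.ndjson"):
        print(line.rstrip())
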
123 changes: 99 additions & 24 deletions src/conspiracies/corpusprocessing/aggregation.py
@@ -1,10 +1,22 @@
from collections import Counter
from typing import List, TypedDict, Dict, Union, Callable, Iterable, Tuple
from collections import Counter, defaultdict
from datetime import datetime
from typing import (
List,
TypedDict,
Dict,
Union,
Callable,
Iterable,
Tuple,
Optional,
Any,
Mapping,
)

from pydantic import BaseModel

from conspiracies.corpusprocessing.clustering import Mappings
from conspiracies.corpusprocessing.triplet import Triplet
from conspiracies.corpusprocessing.triplet import Triplet, TripletField


def min_max_normalizer(values: Iterable[Union[int, float]]) -> Callable[[float], float]:
@@ -21,20 +33,64 @@ class StatsEntry(TypedDict):
key: Union[str, Tuple[str]]
frequency: int
norm_frequency: float
docs: Optional[list[str]]
first_occurrence: Optional[datetime]
last_occurrence: Optional[datetime]
alt_labels: Optional[list[str]]


class StatsDict(Dict[str, StatsEntry]):
pass

@classmethod
def from_counter(cls, counter: Counter):
def from_iterable(
cls,
generator: Iterable[Tuple[Any, Optional[str], Optional[datetime]]],
alt_labels: Mapping[str, list[str]] = None,
):
counter = Counter()
docs = defaultdict(set)
first_occurrence = {}
last_occurrence = {}
for key, doc_id, timestamp in generator:
counter[key] += 1

if doc_id:
docs[key].add(doc_id)

if timestamp:
if key in first_occurrence:
first_occurrence[key] = min(timestamp, first_occurrence[key])
else:
first_occurrence[key] = timestamp

if key in last_occurrence:
last_occurrence[key] = max(timestamp, last_occurrence[key])
else:
last_occurrence[key] = timestamp

normalizer = min_max_normalizer(counter.values())

return cls(
{
key: StatsEntry(
key=key,
frequency=value,
norm_frequency=normalizer(value),
docs=list(docs[key]) if key in docs else None,
first_occurrence=(
first_occurrence[key].isoformat()
if key in first_occurrence
else None
),
last_occurrence=(
last_occurrence[key].isoformat()
if key in last_occurrence
else None
),
alt_labels=(
alt_labels[key] if alt_labels and key in alt_labels else None
),
)
for key, value in counter.items()
},
@@ -63,27 +119,46 @@ def aggregate(
triplets: List[Triplet],
remove_identical_subj_and_obj: bool = True,
):
mapped_triplets = [
(
self._mappings.map_entity(triplet.subject.text),
self._mappings.map_predicate(triplet.predicate.text),
self._mappings.map_entity(triplet.object.text),
)
for triplet in triplets
]
if self._mappings is not None:
triplets = [
Triplet(
subject=TripletField(
text=self._mappings.map_entity(t.subject.text),
),
predicate=TripletField(
text=self._mappings.map_predicate(t.predicate.text),
),
object=TripletField(text=self._mappings.map_entity(t.object.text)),
doc=t.doc,
timestamp=t.timestamp,
)
for t in triplets
]

if remove_identical_subj_and_obj:
mapped_triplets = [t for t in mapped_triplets if t[0] != t[2]]
triplet_counts = Counter(triplet for triplet in mapped_triplets)
entity_counts = Counter(
entity for triplet in mapped_triplets for entity in (triplet[0], triplet[2])
)
# entity_distinct_triplets = Counter(
# entity for triplet in set(mapped_triplets) for entity in (triplet[0], triplet[2])
# )
triplets = [t for t in triplets if t.subject.text != t.object.text]

predicate_counts = Counter(triplet[1] for triplet in mapped_triplets)
return TripletStats(
triplets=StatsDict.from_counter(triplet_counts),
entities=StatsDict.from_counter(entity_counts),
predicates=StatsDict.from_counter(predicate_counts),
triplets=StatsDict.from_iterable(
(
(
(t.subject.text, t.predicate.text, t.object.text),
t.doc,
t.timestamp,
)
for t in triplets
)
),
entities=StatsDict.from_iterable(
(
(entity, t.doc, t.timestamp)
for t in triplets
for entity in [t.subject.text, t.object.text]
),
self._mappings.entity_alt_labels(),
),
predicates=StatsDict.from_iterable(
((t.predicate.text, t.doc, t.timestamp) for t in triplets),
self._mappings.predicate_alt_labels(),
),
)
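
For illustration, a minimal sketch of calling the new StatsDict.from_iterable directly; the keys, document ids, timestamps, and alternative labels below are hypothetical. The aggregate method feeds it generators of (key, doc, timestamp) tuples built from the mapped triplets, as shown in the diff above.

    from datetime import datetime

    from conspiracies.corpusprocessing.aggregation import StatsDict

    occurrences = [
        ("earth", "doc-1", datetime(2024, 1, 3)),
        ("earth", "doc-2", datetime(2024, 2, 9)),
        ("nasa", "doc-2", datetime(2024, 2, 9)),
    ]

    stats = StatsDict.from_iterable(
        occurrences,
        alt_labels={"earth": ["the Earth", "planet earth"]},
    )

    # Each entry carries frequency, normalized frequency, source docs,
    # ISO-formatted first/last occurrence, and any alternative labels.
    print(stats["earth"]["frequency"])         # 2
    print(stats["earth"]["first_occurrence"])  # 2024-01-03T00:00:00
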
22 changes: 18 additions & 4 deletions src/conspiracies/corpusprocessing/clustering.py
@@ -20,9 +20,21 @@ class Mappings(BaseModel):
def map_entity(self, entity: str):
return self.entities[entity] if entity in self.entities else entity

def entity_alt_labels(self):
alt_labels = defaultdict(list)
for entity, label in self.entities.items():
alt_labels[label].append(entity)
return alt_labels

def map_predicate(self, predicate: str):
return self.predicates[predicate] if predicate in self.predicates else predicate

def predicate_alt_labels(self):
alt_labels = defaultdict(list)
for predicate, label in self.predicates.items():
alt_labels[label].append(predicate)
return alt_labels


class Clustering:
def __init__(
@@ -120,10 +132,12 @@ def _cluster(
list(clusters.values()),
get_combine_key=lambda t: t[0].text,
)
merged = self._combine_clusters(
merged,
get_combine_key=lambda t: t[0].head,
)

# too risky with false positives from this
# merged = self._combine_clusters(
# merged,
# get_combine_key=lambda t: t[0].head,
# )

# sort by how "prototypical" a member is in the cluster
for cluster in merged:
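
The new entity_alt_labels and predicate_alt_labels helpers simply invert the member-to-label mappings so that each canonical label lists the surface forms clustered under it. A small sketch, assuming entities and predicates are the only fields needed to construct Mappings (the names below are made up):

    from conspiracies.corpusprocessing.clustering import Mappings

    mappings = Mappings(
        entities={"the moon": "moon", "Moon": "moon"},
        predicates={"orbits around": "orbits"},
    )

    print(mappings.map_entity("Moon"))                # moon
    print(mappings.entity_alt_labels()["moon"])       # ['the moon', 'Moon']
    print(mappings.predicate_alt_labels()["orbits"])  # ['orbits around']
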
13 changes: 9 additions & 4 deletions src/conspiracies/corpusprocessing/triplet.py
@@ -1,8 +1,8 @@
import json
from datetime import datetime
from pathlib import Path
from typing import Optional, Set, Iterator, Iterable, List, Union

from jsonlines import jsonlines
from pydantic import BaseModel
from stop_words import get_stop_words

@@ -24,6 +24,7 @@ class Triplet(BaseModel):
predicate: TripletField
object: TripletField
doc: Optional[str]
timestamp: Optional[datetime]

def fields(self):
return self.subject, self.predicate, self.object
@@ -54,12 +55,16 @@ def filter_on_stopwords(
@classmethod
def from_annotated_docs(cls, path: Path) -> Iterator["Triplet"]:
return (
cls(**triplet_data, doc=json_data["id"])
cls(
**triplet_data,
doc=json_data.get("id", None),
timestamp=json_data.get("timestamp", None),
)
for json_data in (json.loads(line) for line in iter_lines_of_files(path))
for triplet_data in json_data["semantic_triplets"]
)

@staticmethod
def write_jsonl(path: Union[str, Path], triplets: Iterable["Triplet"]):
with jsonlines.open(path, "w") as out:
out.write_all(map(lambda triplet: triplet.dict(), triplets))
with open(path, "w") as out:
print(*(t.json() for t in triplets), file=out, sep="\n")
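
With the new timestamp field, write_jsonl now emits one pydantic-serialized triplet per line. A sketch of constructing and serializing a single triplet (all values hypothetical, and TripletField may carry more fields than shown here):

    from datetime import datetime

    from conspiracies.corpusprocessing.triplet import Triplet, TripletField

    triplet = Triplet(
        subject=TripletField(text="the government"),
        predicate=TripletField(text="hides"),
        object=TripletField(text="the truth"),
        doc="doc-42",
        timestamp=datetime(2024, 6, 12, 10, 30),
    )

    # One line of the JSONL output written by Triplet.write_jsonl
    print(triplet.json())
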
21 changes: 18 additions & 3 deletions src/conspiracies/docprocessing/doc_utils.py
@@ -12,20 +12,35 @@
DocTriplets,
SpanTriplet,
)
from conspiracies.document import Document


def _doc_to_json(doc: Union[Doc, Tuple[Doc, str]], include_span_heads=True):
def _doc_to_json(
doc: Union[Doc, Tuple[Doc, Union[str, Document]]],
include_span_heads=True,
):
if isinstance(doc, Tuple):
doc, id_ = doc
if isinstance(doc[1], str):
doc, id_ = doc
timestamp = None
elif isinstance(doc[1], Document):
doc, src_doc = doc
id_ = src_doc.id
timestamp = src_doc.timestamp.isoformat()
else:
raise TypeError(f"Unexpected input type {type(doc[1])}")
else:
id_ = None
timestamp = None
if Doc.has_extension("relation_triplets"):
triplets = doc._.relation_triplets
else:
triplets = []
json = doc.to_json()
if id_ is not None:
json["id"] = id_
if timestamp is not None:
json["timestamp"] = timestamp
json["semantic_triplets"] = [
triplet.to_dict(include_doc=False, include_span_heads=include_span_heads)
for triplet in triplets
Expand All @@ -46,7 +61,7 @@ def _doc_from_json(json: dict, nlp: Language) -> Doc:


def docs_to_jsonl(
docs: Iterable[Union[Doc, Tuple[Doc, str]]],
docs: Iterable[Union[Doc, Tuple[Doc, Union[str, Document]]]],
path: Union[Path, str],
append=False,
include_span_heads=True,
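
A rough sketch of the updated call pattern: spaCy Docs are now paired with their source Document objects so that _doc_to_json can pick up both id and timestamp. The pipeline setup and file name are hypothetical, and a real run would use the project's triplet-extraction pipeline rather than a blank model:

    import spacy

    from conspiracies.docprocessing.doc_utils import docs_to_jsonl
    from conspiracies.document import Document

    nlp = spacy.blank("en")
    src_docs = [
        Document(
            id="doc-1",
            metadata={},
            text="Example text.",
            context=None,
            timestamp="2024-06-12T00:00:00",
        ),
    ]

    # Each (Doc, Document) tuple is serialized with the source document's
    # id and ISO-formatted timestamp included in the JSON.
    docs_to_jsonl(((nlp(d.text), d) for d in src_docs), "annotated_docs.ndjson")
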
8 changes: 4 additions & 4 deletions src/conspiracies/docprocessing/docprocessor.py
@@ -106,20 +106,20 @@ def process_docs(
annotated_doc["id"] for annotated_doc in annotated_docs
}
print(f"Skipping {len(already_processed)} processed docs.")
docs = (doc for doc in docs if doc["id"] not in already_processed)
docs = (doc for doc in docs if doc.id not in already_processed)

# The coreference pipeline tends to choke on too large batches because of an
# extreme memory pressure, hence the small batch size
coref_resolved_docs = self.coref_pipeline.pipe(
((text_with_context(doc), doc["id"]) for doc in docs),
((text_with_context(src_doc), src_doc) for src_doc in docs),
batch_size=self.batch_size,
as_tuples=True,
)

with_triplets = self.triplet_extraction_pipeline.pipe(
(
(remove_context(doc._.resolve_coref), id_)
for doc, id_ in coref_resolved_docs
(remove_context(doc._.resolve_coref), src_doc)
for doc, src_doc in coref_resolved_docs
),
batch_size=self.batch_size,
as_tuples=True,
13 changes: 8 additions & 5 deletions src/conspiracies/document.py
@@ -1,20 +1,23 @@
from typing import TypedDict, Optional
from datetime import datetime
from typing import Optional
from pydantic import BaseModel


class Document(TypedDict):
class Document(BaseModel):
id: str
metadata: dict
text: str
context: Optional[str]
timestamp: Optional[datetime]


CONTEXT_END_MARKER = "[CONTEXT_END]"


def text_with_context(doc: Document) -> str:
if doc["context"] is None:
return doc["text"]
return "\n".join([doc["context"], CONTEXT_END_MARKER, doc["text"]])
if doc.context is None:
return doc.text
return "\n".join([doc.context, CONTEXT_END_MARKER, doc.text])


def remove_context(text: str) -> str:
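
Since Document is now a pydantic model, fields are accessed as attributes and the timestamp is parsed from ISO strings automatically. A minimal sketch (values hypothetical):

    from conspiracies.document import Document, text_with_context

    doc = Document(
        id="doc-7",
        metadata={"source": "forum"},
        text="The body text.",
        context="An earlier post.",
        timestamp="2024-06-12T13:37:00",  # coerced to a datetime by pydantic
    )

    print(doc.timestamp.year)      # 2024
    print(text_with_context(doc))  # context, [CONTEXT_END] marker, then text
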
2 changes: 1 addition & 1 deletion src/conspiracies/pipeline/config.py
@@ -37,7 +37,7 @@ class ClusteringThresholds(BaseModel):
def estimate_from_n_triplets(cls, n_triplets: int):
factor = n_triplets / 1000
thresholds = cls(
min_cluster_size=int(factor + 1),
min_cluster_size=max(int(factor + 1), 2),
min_samples=int(factor + 1),
)
return thresholds
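
The new floor matters for small corpora: with, say, 500 triplets the scaling factor is 0.5, so int(factor + 1) alone would yield a min_cluster_size of 1, which is not a meaningful cluster size; max(..., 2) keeps it at 2. A quick check (triplet count hypothetical):

    from conspiracies.pipeline.config import ClusteringThresholds

    thresholds = ClusteringThresholds.estimate_from_n_triplets(500)
    # factor = 500 / 1000 = 0.5
    # min_cluster_size = max(int(0.5 + 1), 2) = 2
    # min_samples = int(0.5 + 1) = 1
    print(thresholds.min_cluster_size, thresholds.min_samples)  # 2 1
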
2 changes: 1 addition & 1 deletion src/conspiracies/pipeline/pipeline.py
@@ -87,7 +87,7 @@ def docprocessing(self, continue_from_last=False):
docprocessor = self._get_docprocessor()
docprocessor.process_docs(
(
json.loads(line, object_hook=lambda d: Document(**d))
Document(**json.loads(line))
for line in iter_lines_of_files(
self.output_path / "preprocessed.ndjson",
)
(diffs for the remaining 11 changed files not shown)