diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 40d3d76..c2efc26 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,3 +21,9 @@ repos: rev: v0.5.7 hooks: - id: ruff + + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v3.1.0 + hooks: + - id: prettier + files: "visualizer/.*" diff --git a/config/eschatology.toml b/config/eschatology.toml index 7aa4c7b..d66f4f2 100644 --- a/config/eschatology.toml +++ b/config/eschatology.toml @@ -2,7 +2,7 @@ language = "en" [preprocessing] -enabled = true +enabled = false doc_type = "csv" [preprocessing.extra] @@ -11,9 +11,14 @@ text_column = "body" timestamp_column = "timestamp" [docprocessing] -enabled = true -batch_size = 5 -prefer_gpu_for_coref = false +enabled = false +batch_size = 50 +prefer_gpu_for_coref = true +n_process = 1 [corpusprocessing] -enabled = true \ No newline at end of file +enabled = false + +[databasepopulation] +enabled = true +clear_and_write = true \ No newline at end of file diff --git a/config/template.toml b/config/template.toml index cd76efd..9515e1d 100644 --- a/config/template.toml +++ b/config/template.toml @@ -12,13 +12,14 @@ metadata_fields = ["*"] [preprocessing.extra] # specific extra arguments for your preprocessor, e.g. context length for tweets or -# or field specification for CSVs +# field specification for CSVs [docprocessing] enabled = true batch_size = 25 continue_from_last = true triplet_extraction_method = "multi2oie/prompting" +n_process = 1 # can be set to 2 or more for multiprocess offloading to GPU; otherwise it might not make sense [corpusprocessing] enabled = true @@ -27,6 +28,6 @@ dimensions = 100 # leave out to skip dimensionality reduction n_neighbors = 15 # used for dimensionality reduction [corpusprocessing.thresholds] # leave out for automatic estimation -min_cluster_size = 3 # unused if auto_thresholds is true -min_samples = 3 # unused if auto_thresholds is true -min_topic_size = 5 # unused if auto_thresholds is true \ No newline at end of file +min_label_occurrence = 3 +min_cluster_size = 3 +min_samples = 3 \ No newline at end of file diff --git a/docs/tutorials/overview.ipynb b/docs/tutorials/overview.ipynb index affc822..29ff5d0 100644 --- a/docs/tutorials/overview.ipynb +++ b/docs/tutorials/overview.ipynb @@ -71,7 +71,7 @@ " assert isinstance(sent._.coref_clusters[0], tuple)\n", " assert isinstance(sent._.coref_clusters[0][0], int)\n", " assert isinstance(sent._.coref_clusters[0][1], Span)\n", - " sent._.resolve_coref # get resolved coref" + " sent._.resolved_text # get resolved coref" ] }, { diff --git a/paper/extract_triplets_newspapers.py b/paper/extract_triplets_newspapers.py index 95ded68..745d85f 100644 --- a/paper/extract_triplets_newspapers.py +++ b/paper/extract_triplets_newspapers.py @@ -88,7 +88,7 @@ def process_file( # Resolve coreference coref_docs = nlp_coref.pipe(normalized_article) - resolved_docs = (d._.resolve_coref for d in coref_docs) + resolved_docs = (d._.resolved_text for d in coref_docs) # Extract relations docs = nlp.pipe(resolved_docs) diff --git a/paper/extract_triplets_tweets.py b/paper/extract_triplets_tweets.py index 77bcc8a..503835d 100644 --- a/paper/extract_triplets_tweets.py +++ b/paper/extract_triplets_tweets.py @@ -99,7 +99,7 @@ def concat_resolve_unconcat_contexts(file_path: str): coref_nlp = build_coref_pipeline() coref_docs = coref_nlp.pipe(context_tweets) - resolved_docs = (d._.resolve_coref for d in coref_docs) + resolved_docs = (d._.resolved_text for d in coref_docs) resolved_tweets =
(tweet_from_context_text(tweet) for tweet in resolved_docs) return resolved_tweets @@ -240,7 +240,7 @@ def prompt_gpt3( for i, batch in enumerate(batch_generator(concatenated_tweets, batch_size)): start = time.time() coref_docs = coref_nlp.pipe(batch) - resolved_docs = (d._.resolve_coref for d in coref_docs) + resolved_docs = (d._.resolved_text for d in coref_docs) resolved_target_tweets = ( tweet_from_context_text(tweet) for tweet in resolved_docs ) diff --git a/pyproject.toml b/pyproject.toml index 9cb42b8..ad9e2b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,9 @@ dependencies = [ "sentence-transformers", "stop-words", "bs4", - "toml" + "toml", + "fastcoref", + "sqlalchemy" ] [project.license] @@ -94,6 +96,7 @@ content-type = "text/markdown" "prompt_relation_extraction" = "conspiracies.docprocessing.relationextraction.gptprompting:create_prompt_relation_extraction_component" "relation_extractor" = "conspiracies.docprocessing.relationextraction.multi2oie:make_relation_extractor" "allennlp_coref" = "conspiracies.docprocessing.coref:create_coref_component" +"safe_fastcoref" = "conspiracies.docprocessing.coref.safefastcoref:create_safe_fastcoref" "heads_extraction" = "conspiracies.docprocessing.headwordextraction:create_headwords_component" diff --git a/src/conspiracies/corpusprocessing/clustering.py b/src/conspiracies/corpusprocessing/clustering.py index 05deb3e..caa5b94 100644 --- a/src/conspiracies/corpusprocessing/clustering.py +++ b/src/conspiracies/corpusprocessing/clustering.py @@ -1,12 +1,15 @@ -from collections import defaultdict -from typing import List, Callable, Any, Hashable, Dict +import math +import os +from collections import defaultdict, Counter +from pathlib import Path +from typing import List, Callable, Any, Hashable, Dict, Union import networkx import numpy as np from hdbscan import HDBSCAN from pydantic import BaseModel from sentence_transformers import SentenceTransformer -from sklearn.preprocessing import StandardScaler +from tqdm import tqdm from umap import UMAP from conspiracies.common.modelchoice import ModelChoice @@ -45,6 +48,7 @@ def __init__( min_cluster_size: int = 5, min_samples: int = 3, embedding_model: str = None, + cache_location: Path = None, ): self.language = language self.n_dimensions = n_dimensions @@ -52,6 +56,9 @@ def __init__( self.min_cluster_size = min_cluster_size self.min_samples = min_samples self._embedding_model = embedding_model + self.cache_location = cache_location + if self.cache_location is not None: + os.makedirs(self.cache_location, exist_ok=True) def _get_embedding_model(self): # figure out embedding model if not given explicitly @@ -97,53 +104,88 @@ def _combine_clusters( return merged_clusters - def _cluster( + def _cluster_via_embeddings( self, - fields: List[TripletField], + labels: List[str], + cache_name: str = None, + show_progress: bool = True, ): - model = self._get_embedding_model() - print("Creating embeddings:") - embeddings = model.encode( - [field.text for field in fields], - show_progress_bar=True, + emb_cache = ( + Path(self.cache_location, f"embeddings-{cache_name}.npy") + if self.cache_location and cache_name + else None ) - embeddings = StandardScaler().fit_transform(embeddings) + if emb_cache and emb_cache.exists(): + print( + "Reusing cached embeddings! 
Delete cache if this is not supposed to happen.", + ) + embeddings = np.load(emb_cache) + else: + model = self._get_embedding_model() + + counter = Counter((field for field in labels)) + condensed = [ + field + for field, count in counter.items() + for _ in range(math.ceil(count / 1000)) + ] + embeddings = model.encode( + condensed, + normalize_embeddings=True, + show_progress_bar=show_progress, + ) + if emb_cache: + np.save(emb_cache, embeddings) if self.n_dimensions is not None: - print("Reducing embedding space") - reducer = UMAP(n_components=self.n_dimensions, n_neighbors=self.n_neighbors) - embeddings = reducer.fit_transform(embeddings) + reduced_emb_cache = ( + Path( + self.cache_location, + f"embeddings-{cache_name}-red{self.n_dimensions}.npy", + ) + if self.cache_location and cache_name + else None + ) + if reduced_emb_cache and reduced_emb_cache.exists(): + print( + "Reusing cached reduced embeddings! Delete cache if this is not supposed to happen.", + ) + embeddings = np.load(reduced_emb_cache) + else: + print("Reducing embedding space ...") + reducer = UMAP( + n_components=self.n_dimensions, + n_neighbors=self.n_neighbors, + ) + embeddings = reducer.fit_transform(embeddings) + if self.cache_location: + np.save(reduced_emb_cache, embeddings) - print("Clustering ...") hdbscan_model = HDBSCAN( min_cluster_size=self.min_cluster_size, + max_cluster_size=self.min_cluster_size + * 10, # somewhat arbitrary, mostly to avoid mega clusters that suck up everything min_samples=self.min_samples, ) hdbscan_model.fit(embeddings) clusters = defaultdict(list) for field, embedding, label, probability in zip( - fields, + labels, embeddings, hdbscan_model.labels_, hdbscan_model.probabilities_, ): # skip noise and low confidence - if label == -1 or probability < 0.1: + if label == -1 or probability < 0.5: continue clusters[label].append((field, embedding)) merged = self._combine_clusters( list(clusters.values()), - get_combine_key=lambda t: t[0].text, + get_combine_key=lambda t: t[0], ) - # too risky with false positives from this - # merged = self._combine_clusters( - # merged, - # get_combine_key=lambda t: t[0].head, - # ) - # sort by how "prototypical" a member is in the cluster for cluster in merged: mean = np.mean(np.stack([t[1] for t in cluster]), axis=0) @@ -153,11 +195,69 @@ def _cluster( return [[t[0] for t in cluster] for cluster in merged] @staticmethod - def _mapping_to_first_member(clusters: List[List[TripletField]]) -> Dict[str, str]: + def _cluster_via_normalization( + labels: List[str], + top: Union[int, float] = 1.0, + restrictive_labels=True, + ) -> List[List[str]]: + counter = Counter((label for label in labels)) + if isinstance(top, float): + top = int(top * len(counter)) + + norm_map = { + label: " " + + label.lower() + + " " # surrounding spaces avoids matches like evil <-> devil + for label in counter.keys() + } + cluster_map = { + label: [] + for label, count in counter.most_common(top) + # FIXME: hack due to lack of NER and lemmas at the time of writing + if not restrictive_labels + or len(label) >= 4 + and label[0].isupper() + or len(label.split()) > 1 + } + + for label in counter.keys(): + norm_label = norm_map[label] + matches = [ + substring + for substring in cluster_map.keys() + if norm_map[substring] in norm_label + ] + if not matches: + continue + + best_match = min( + matches, + key=lambda substring: len(norm_map[substring]), + ) + if best_match != label: + cluster_map[best_match].append(label) + + clusters = [ + [main_label] + alt_labels + for main_label, alt_labels 
in cluster_map.items() + if alt_labels + ] + return clusters + + @staticmethod + def _mapping_to_first_member( + clusters: List[List[Union[TripletField, str]]], + ) -> Dict[str, str]: + def get_text(member: Union[TripletField, str]): + if isinstance(member, TripletField): + return member.text + else: + return member + return { - member: cluster[0].text + member: get_text(cluster[0]) for cluster in clusters - for member in set(member.text for member in cluster) + for member in set(get_text(member) for member in cluster) } def create_mappings(self, triplets: List[Triplet]) -> Mappings: @@ -166,10 +266,33 @@ def create_mappings(self, triplets: List[Triplet]) -> Mappings: entities = subjects + objects predicates = [triplet.predicate for triplet in triplets] + # FIXME: clustering gets way to aggressive for many triplets + # print("Creating mappings for entities") + # entity_clusters = self._cluster(entities, "entities") + # print("Creating mappings for predicates") + # predicate_clusters = self._cluster(predicates, "predicates") + print("Creating mappings for entities") - entity_clusters = self._cluster(entities) + entity_clusters = self._cluster_via_normalization( + [e.text for e in entities], + 0.2, + ) + entity_clusters = [ + sub_cluster + for cluster in tqdm(entity_clusters, desc="Creating sub-clusters") + for sub_cluster in ( + self._cluster_via_embeddings(cluster, show_progress=False) + if len(cluster) > 10 + else [cluster] + ) + ] + print("Creating mappings for predicates") - predicate_clusters = self._cluster(predicates) + predicate_clusters = self._cluster_via_normalization( + [p.text for p in predicates], + top=0.2, + restrictive_labels=False, + ) mappings = Mappings( entities=self._mapping_to_first_member(entity_clusters), diff --git a/src/conspiracies/corpusprocessing/triplet.py b/src/conspiracies/corpusprocessing/triplet.py index 5d87b17..8d2d9b3 100644 --- a/src/conspiracies/corpusprocessing/triplet.py +++ b/src/conspiracies/corpusprocessing/triplet.py @@ -1,4 +1,5 @@ import json +from collections import Counter, defaultdict from datetime import datetime from pathlib import Path from typing import Optional, Set, Iterator, Iterable, List, Union @@ -11,6 +12,8 @@ class TripletField(BaseModel): text: str + start_char: Optional[int] + end_char: Optional[int] head: Optional[str] def clear_head_if_blacklist_match(self, blacklist: Set[str]): @@ -52,6 +55,30 @@ def filter_on_stopwords( if not triplet.has_blacklist_match(stopwords) ] + @staticmethod + def filter_on_entity_label_frequency( + triplets: Iterable["Triplet"], + min_frequency: int, + min_doc_frequency: int = 1, + ): + entity_label_counter = Counter( + f.text for triplet in triplets for f in (triplet.subject, triplet.object) + ) + docs = defaultdict(set) + for triplet in triplets: + for f in (triplet.subject, triplet.object): + docs[f.text].add(triplet.doc) + doc_frequency = {label: len(docs) for label, docs in docs.items()} + + filtered = [ + triplet + for triplet in triplets + if entity_label_counter[triplet.subject.text] >= min_frequency + and entity_label_counter[triplet.object.text] >= min_frequency + and doc_frequency[triplet.subject.text] >= min_doc_frequency + ] + return filtered + @classmethod def from_annotated_docs(cls, path: Path) -> Iterator["Triplet"]: return ( diff --git a/src/conspiracies/database/__init__.py b/src/conspiracies/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/conspiracies/database/engine.py b/src/conspiracies/database/engine.py new file mode 100644 index 
0000000..288bb8c --- /dev/null +++ b/src/conspiracies/database/engine.py @@ -0,0 +1,20 @@ +from pathlib import Path + +from sqlalchemy import create_engine, Engine +from sqlalchemy.orm import Session + +from conspiracies.database.models import Base + + +def get_engine(filepath: Path): + engine = create_engine("sqlite:///" + filepath.as_posix()) + return engine + + +def setup_database(engine: Engine): + Base.metadata.create_all(engine) + + +def get_session(engine: Engine = None) -> Session: + session = Session(bind=engine) + return session diff --git a/src/conspiracies/database/models.py b/src/conspiracies/database/models.py new file mode 100644 index 0000000..cf6b5d1 --- /dev/null +++ b/src/conspiracies/database/models.py @@ -0,0 +1,138 @@ +from sqlalchemy import ( + Column, + Integer, + String, + ForeignKey, + Text, + DateTime, +) +from sqlalchemy.orm import declarative_base, relationship, Session + +Base = declarative_base() + + +class EntityOrm(Base): + __tablename__ = "entities" + id = Column(Integer, primary_key=True, autoincrement=True) + label = Column(String, nullable=False, index=True) + supernode_id = Column(Integer, ForeignKey("entities.id"), nullable=True) + + # Relationships + supernode = relationship( + "EntityOrm", + back_populates="subnodes", + remote_side="EntityOrm.id", + foreign_keys="EntityOrm.supernode_id", + ) + subnodes = relationship("EntityOrm", back_populates="supernode") + + +class RelationOrm(Base): + __tablename__ = "relations" + id = Column(Integer, primary_key=True, autoincrement=True) + label = Column(String, nullable=False, index=True) + subject_id = Column(Integer, ForeignKey("entities.id"), nullable=True) + object_id = Column(Integer, ForeignKey("entities.id"), nullable=True) + + subject = relationship( + "EntityOrm", + foreign_keys="RelationOrm.subject_id", + ) + object = relationship( + "EntityOrm", + foreign_keys="RelationOrm.object_id", + ) + + +class TripletOrm(Base): + __tablename__ = "triplets" + id = Column(Integer, primary_key=True, autoincrement=True) + doc_id = Column(Integer, ForeignKey("docs.id"), nullable=False) + subject_id = Column(Integer, ForeignKey("entities.id"), nullable=False) + relation_id = Column(Integer, ForeignKey("relations.id"), nullable=False) + object_id = Column(Integer, ForeignKey("entities.id"), nullable=False) + subj_span_start = Column(Integer, nullable=True) + subj_span_end = Column(Integer, nullable=True) + pred_span_start = Column(Integer, nullable=True) + pred_span_end = Column(Integer, nullable=True) + obj_span_start = Column(Integer, nullable=True) + obj_span_end = Column(Integer, nullable=True) + + # Relationships + subject_entity = relationship( + "EntityOrm", + foreign_keys="TripletOrm.subject_id", + ) + predicate_relation = relationship( + "RelationOrm", + foreign_keys="TripletOrm.relation_id", + ) + object_entity = relationship( + "EntityOrm", + foreign_keys="TripletOrm.object_id", + ) + document = relationship( + "DocumentOrm", + foreign_keys="TripletOrm.doc_id", + back_populates="triplets", + ) + + # TODO: this should be here, but sometimes we see duplicates. Why? 
+ # __table_args__ = (UniqueConstraint( + # 'doc_id', + # 'subject_entity_id', + # 'predicate_relation_id', + # 'object_entity_id', + # name='unique_triplet_constraint' + # ),) + + +class DocumentOrm(Base): + __tablename__ = "docs" + id = Column(Integer, primary_key=True, autoincrement=True) + text = Column(Text, nullable=False) + orig_text = Column(Text, nullable=True) + timestamp = Column(DateTime) + + # Relationships + triplets = relationship("TripletOrm", back_populates="document") + + +class ModelLookupCache: + + def __init__(self, session: Session): + self._entities = {e.label: e for e in session.query(EntityOrm).all()} + self._relations = { + (int(r.subject_id), str(r.label), int(r.object_id)): r # noqa + for r in session.query(RelationOrm).all() + } + + def get_or_create_entity(self, label, session): + """Fetch an entity by label, or create it if it doesn't exist.""" + entity = self._entities.get(label, None) + if entity is None: + entity = EntityOrm(label=label) + session.add(entity) + session.flush() # Get the ID immediately + self._entities[label] = entity # noqa + return entity.id + + def get_or_create_relation( + self, + subject_id: int, + object_id: int, + label: str, + session: Session, + ): + """Fetch a relation by label, or create it if it doesn't exist.""" + relation = self._relations.get((subject_id, label, object_id), None) + if relation is None: + relation = RelationOrm( + label=label, + subject_id=subject_id, + object_id=object_id, + ) + session.add(relation) + session.flush() # Get the ID immediately + self._relations[(subject_id, label, object_id)] = relation # noqa + return relation.id diff --git a/src/conspiracies/docprocessing/coref/coref_component.py b/src/conspiracies/docprocessing/coref/coref_component.py index c486913..dd4ca1c 100644 --- a/src/conspiracies/docprocessing/coref/coref_component.py +++ b/src/conspiracies/docprocessing/coref/coref_component.py @@ -40,10 +40,10 @@ def __init__( ) # Register custom extension on the Doc and Span - if not Doc.has_extension("resolve_coref"): - Doc.set_extension("resolve_coref", getter=self.resolve_coref_doc) - if not Span.has_extension("resolve_coref"): - Span.set_extension("resolve_coref", getter=self.resolve_coref_span) + if not Doc.has_extension("resolved_text"): + Doc.set_extension("resolved_text", getter=self.resolved_text_doc) + if not Span.has_extension("resolved_text"): + Span.set_extension("resolved_text", getter=self.resolved_text_span) if not Doc.has_extension("coref_clusters"): Doc.set_extension("coref_clusters", default=list()) if not Span.has_extension("coref_clusters"): @@ -51,7 +51,7 @@ def __init__( if not Span.has_extension("antecedent"): Span.set_extension("antecedent", default=None) - def resolve_coref_doc(self, doc: Doc) -> str: + def resolved_text_doc(self, doc: Doc) -> str: """Resolve the coreference clusters by replacing each entity with the antecedent. The antecedent is the first entity that appears in the cluster. This is for the whole doc. @@ -73,7 +73,7 @@ def resolve_coref_doc(self, doc: Doc) -> str: resolved[i] = "" return "".join(resolved) - def resolve_coref_span(self, sent: Span) -> str: + def resolved_text_span(self, sent: Span) -> str: """Resolve the coreference clusters by replacing each entity with the antecedent. The antecedent is the first entity that appears in the cluster. This is for the the sent. 
diff --git a/src/conspiracies/docprocessing/coref/safefastcoref.py b/src/conspiracies/docprocessing/coref/safefastcoref.py new file mode 100644 index 0000000..318ca03 --- /dev/null +++ b/src/conspiracies/docprocessing/coref/safefastcoref.py @@ -0,0 +1,88 @@ +from fastcoref.spacy_component import FastCorefResolver +from spacy.language import Language +from spacy.pipeline import Pipe +from typing import Iterable + +import logging + +from spacy.util import minibatch +from datasets.utils.logging import disable_progress_bar + + +disable_progress_bar() # annoying progress bar per batch +logging.getLogger("fastcoref").setLevel(logging.WARNING) + + +class SafeFastCoref(Pipe): + def __init__(self, component: FastCorefResolver): + """Initialize the wrapper with the original component.""" + self.component = component + + def pipe(self, stream: Iterable, batch_size: int = 128): + """Wrap the pipe method of the component.""" + for mb in minibatch(stream, size=batch_size): + try: + # The pipe method can fail on a single document and thereby fail on all docs in that + # minibatch. Moreover, because it runs as a generator, such a failure may not surface until long + # after the first documents have passed through the whole pipeline. Therefore, the minibatch is + # processed fully and then yielded; if that fails, the documents are processed individually instead. + annotated = list( + self.component.pipe( + mb, + batch_size=batch_size, + resolve_text=True, + ), + ) + except Exception as e: + # Log the error and fall back to processing the documents individually + logging.error( + f"Error in SafeFastCoref pipe: {e}. Trying documents individually", + ) + annotated = [self(d) for d in mb] + yield from annotated + + def __call__(self, doc): + """Wrap the __call__ method of the component.""" + try: + return self.component(doc, resolve_text=True) + except Exception as e: + # Log the error and return the original document + logging.error(f"Error in SafeFastCoref __call__: {e}") + doc._.coref_clusters = [] + doc._.resolved_text = doc.text + return doc + + +@Language.factory( + "safe_fastcoref", + assigns=["doc._.resolved_text", "doc._.coref_clusters"], + default_config={ + "model_architecture": "FCoref", # FCoref or LingMessCoref + "model_path": "biu-nlp/f-coref", # You can specify your own trained model path + "device": None, # "cuda" or "cpu"; None defaults to cuda + "max_tokens_in_batch": 10000, + "enable_progress_bar": False, + }, +) +def create_safe_fastcoref( + nlp, + name, + model_architecture: str, + model_path: str, + device, + max_tokens_in_batch: int, + enable_progress_bar: bool, +): + """Factory method to create the SafeFastCoref component.""" + # Create the original FastCorefResolver with the given configuration + fastcoref_component = FastCorefResolver( + nlp=nlp, + name=name, + model_architecture=model_architecture, + model_path=model_path, + device=device, + max_tokens_in_batch=max_tokens_in_batch, + enable_progress_bar=enable_progress_bar, + ) + # Wrap it with SafeFastCoref + return SafeFastCoref(fastcoref_component) diff --git a/src/conspiracies/docprocessing/doc_utils.py b/src/conspiracies/docprocessing/doc_utils.py index 48e1733..f3b916c 100644 --- a/src/conspiracies/docprocessing/doc_utils.py +++ b/src/conspiracies/docprocessing/doc_utils.py @@ -29,6 +29,9 @@ def _doc_to_json( timestamp = src_doc.timestamp.isoformat() else: raise TypeError(f"Unexpected input type {type(doc[1])}") + elif "doc_metadata" in doc.user_data: + id_ = doc.user_data["doc_metadata"]["id"] + timestamp = doc.user_data["doc_metadata"]["timestamp"] else: id_ = None
timestamp = None diff --git a/src/conspiracies/docprocessing/docprocessor.py b/src/conspiracies/docprocessing/docprocessor.py index 1fa8e07..b8b5d39 100644 --- a/src/conspiracies/docprocessing/docprocessor.py +++ b/src/conspiracies/docprocessing/docprocessor.py @@ -1,10 +1,13 @@ +import json +import logging import os +from glob import glob from pathlib import Path -from typing import Iterable +from typing import Iterable, Tuple, Iterator import spacy import torch -from jsonlines import jsonlines +from spacy.tokens import DocBin, Doc from tqdm import tqdm from conspiracies import docs_to_jsonl @@ -16,14 +19,28 @@ class DocProcessor: def _build_coref_pipeline(self): nlp_coref = spacy.blank(self.language) nlp_coref.add_pipe("sentencizer") - nlp_coref.add_pipe( - "allennlp_coref", - config={ - "device": ( - 0 if self.prefer_gpu_for_coref and torch.cuda.is_available() else -1 - ), - }, - ) + if self.language == "en": + nlp_coref.add_pipe( + "safe_fastcoref", + config={ + "device": ( + "cuda" + if self.prefer_gpu_for_coref and torch.cuda.is_available() + else "cpu" + ), + }, + ) + elif self.language == "da": + nlp_coref.add_pipe( + "allennlp_coref", + config={ + "device": ( + 0 + if self.prefer_gpu_for_coref and torch.cuda.is_available() + else -1 + ), + }, + ) def warn_error(proc_name, proc, docs, e): print( @@ -86,13 +103,100 @@ def __init__( batch_size=25, triplet_extraction_method="multi2oie", prefer_gpu_for_coref: bool = False, + n_process: int = 1, + doc_bin_size: int = 100, ): self.language = language self.batch_size = batch_size self.prefer_gpu_for_coref = prefer_gpu_for_coref + self.n_process = n_process + if n_process > 1: + # multiprocessing and torch with multiple threads result in a deadlock, therefore: + torch.set_num_threads(1) + self.doc_bin_size = doc_bin_size self.coref_pipeline = self._build_coref_pipeline() self.triplet_extraction_component = triplet_extraction_method self.triplet_extraction_pipeline = self._build_triplet_extraction_pipeline() + self.deduplicate_processed_docs = False + + def _store_doc_bins(self, docs: Iterator[Tuple[Doc, Document]], output_path: Path): + # FIXME: paths should be given elsewhere and not be inferred like this + output_dir = Path(os.path.dirname(output_path)) / "spacy_docs" + output_dir.mkdir(parents=True, exist_ok=True) + + prev_doc_bins = glob( + (Path(os.path.dirname(output_path)) / "spacy_docs").as_posix() + "/*.bin", + ) + start_from = ( + max(int(os.path.basename(doc).replace(".bin", "")) for doc in prev_doc_bins) + if prev_doc_bins + else 0 + ) + + size = self.doc_bin_size + doc_bin = DocBin(store_user_data=True) + at_doc = start_from + for doc, src_doc in docs: + at_doc += 1 + + # FIXME: this conversion is kind of stupid, but with old pydantic this will have to work for now. 
+ doc.user_data["doc_metadata"] = json.loads(src_doc.json()) + + doc_bin.add(doc) + if at_doc % size == 0: + with open(output_dir / f"{at_doc}.bin", "wb") as f: + f.write(doc_bin.to_bytes()) + doc_bin = DocBin(store_user_data=True) + yield doc + + if len(doc_bin) > 0: + # write final doc bin if any docs are left + with open(output_dir / f"{at_doc}.bin", "wb") as f: + f.write(doc_bin.to_bytes()) + + def _read_doc_bins(self, output_path: Path): + # FIXME: paths should be given elsewhere and not be inferred like this + count = 0 + for bin_file in glob( + (Path(os.path.dirname(output_path)) / "spacy_docs").as_posix() + "/*.bin", + ): + with open(bin_file, "rb") as bytes_data: + doc_bin = DocBin().from_bytes(bytes_data.read()) + for doc in doc_bin.get_docs(self.triplet_extraction_pipeline.vocab): + count += 1 + src_doc = Document(**doc.user_data["doc_metadata"]) + yield doc, src_doc + + def _read_deduplicated_doc_bins( + self, + output_path: Path, + processed_ids: set[str] = None, + ): + if processed_ids is None: + processed_ids = set() + + for doc, src_doc in tqdm( + self._read_doc_bins(output_path), + desc="Reading previously processed docs", + ): + if src_doc.id in processed_ids: + logging.warning(f"Duplicate processed document: {src_doc.id}") + continue + processed_ids.add(src_doc.id) + yield doc, src_doc + + def deduplicate_doc_bins(self, output_path: Path): + spacy_docs = Path(os.path.dirname(output_path)) / "spacy_docs" + old_docs = Path(os.path.dirname(output_path)) / ".old" / "spacy_docs" + old_docs.mkdir(parents=True, exist_ok=True) + orig_dir = spacy_docs.rename(old_docs) + deduplicated = spacy_docs + deduplicated.mkdir() + for _ in self._store_doc_bins( + self._read_deduplicated_doc_bins(orig_dir), + deduplicated, + ): + pass def process_docs( self, @@ -100,13 +204,22 @@ def process_docs( output_path: Path, continue_from_last=False, ): - if continue_from_last and os.path.exists(output_path): - with jsonlines.open(output_path) as annotated_docs: - already_processed = { - annotated_doc["id"] for annotated_doc in annotated_docs - } - print(f"Skipping {len(already_processed)} processed docs.") - docs = (doc for doc in docs if doc.id not in already_processed) + if self.deduplicate_processed_docs: + self.deduplicate_doc_bins(output_path) + if continue_from_last: + print( + "Reading previously processed documents! 
Disable 'continue_from_last' to avoid this.", + ) + processed_ids = set() + docs_to_jsonl( + self._read_deduplicated_doc_bins( + output_path, + processed_ids=processed_ids, + ), + output_path, + ) + print(f"Read {len(processed_ids)} previously processed docs.") + docs = (d for d in docs if d.id not in processed_ids) # The coreference pipeline tends to choke on too large batches because of an # extreme memory pressure, hence the small batch size @@ -114,19 +227,26 @@ def process_docs( ((text_with_context(src_doc), src_doc) for src_doc in docs), batch_size=self.batch_size, as_tuples=True, + n_process=self.n_process, ) with_triplets = self.triplet_extraction_pipeline.pipe( ( - (remove_context(doc._.resolve_coref), src_doc) + (remove_context(doc._.resolved_text), src_doc) for doc, src_doc in coref_resolved_docs ), batch_size=self.batch_size, as_tuples=True, + n_process=self.n_process, + ) + + docs_to_output = tqdm( + self._store_doc_bins(with_triplets, output_path), + desc="Processing documents", ) docs_to_jsonl( - tqdm(d for d in with_triplets), + docs_to_output, output_path, append=continue_from_last, ) diff --git a/src/conspiracies/docprocessing/relationextraction/data_classes.py b/src/conspiracies/docprocessing/relationextraction/data_classes.py index 5370451..4ac5276 100644 --- a/src/conspiracies/docprocessing/relationextraction/data_classes.py +++ b/src/conspiracies/docprocessing/relationextraction/data_classes.py @@ -264,7 +264,9 @@ def span_to_json(span: Union[Span, Doc]) -> Dict[str, Any]: span = span[:] return { "text": span.text, + "start_char": span.start_char, "start": span.start, + "end_char": span.end_char, "end": span.end, } diff --git a/src/conspiracies/docprocessing/relationextraction/multi2oie/multi2oie_component.py b/src/conspiracies/docprocessing/relationextraction/multi2oie/multi2oie_component.py index 0574572..e3b4c57 100644 --- a/src/conspiracies/docprocessing/relationextraction/multi2oie/multi2oie_component.py +++ b/src/conspiracies/docprocessing/relationextraction/multi2oie/multi2oie_component.py @@ -89,7 +89,7 @@ def set_annotations(self, doc: Iterable[Doc], predictions: Dict) -> None: try: self.do_set_annotations(doc, predictions) except Exception as e: - self.logger.exception(e) + self.logger.error(e) def do_set_annotations(self, doc: Iterable[Doc], predictions: Dict) -> None: # get nested list of indices above confidence threshold diff --git a/src/conspiracies/pipeline/config.py b/src/conspiracies/pipeline/config.py index 6458eb8..9328e59 100644 --- a/src/conspiracies/pipeline/config.py +++ b/src/conspiracies/pipeline/config.py @@ -1,3 +1,4 @@ +import math from typing import Any import toml @@ -5,8 +6,7 @@ class BaseConfig(BaseModel): - project_name: str - output_root: str = "output" + output_path: str language: str @@ -27,27 +27,36 @@ class DocProcessingConfig(StepConfig): continue_from_last: bool = True triplet_extraction_method: str = "multi2oie" prefer_gpu_for_coref: bool = False + n_process: int = 1 + doc_bin_size: int = 100 -class ClusteringThresholds(BaseModel): +class Thresholds(BaseModel): + min_label_occurrence: int min_cluster_size: int min_samples: int @classmethod def estimate_from_n_triplets(cls, n_triplets: int): - factor = n_triplets / 1000 + # factor = n_triplets / 10_000 thresholds = cls( - min_cluster_size=max(int(factor + 1), 2), - min_samples=int(factor + 1), + min_label_occurrence=math.floor(math.log10(n_triplets)) - 1, + min_label_doc_freq=2, + min_cluster_size=2, + min_samples=2, ) return thresholds +class
DatabasePopulationConfig(StepConfig): + clear_and_write: bool = False + + class CorpusProcessingConfig(StepConfig): dimensions: int = None n_neighbors: int = 15 embedding_model: str = None - thresholds: ClusteringThresholds = None + thresholds: Thresholds = None class PipelineConfig(BaseModel): @@ -55,6 +64,7 @@ class PipelineConfig(BaseModel): preprocessing: PreProcessingConfig docprocessing: DocProcessingConfig corpusprocessing: CorpusProcessingConfig + databasepopulation: DatabasePopulationConfig @staticmethod def update_nested_dict(d: dict[str, Any], path: str, value: Any) -> None: diff --git a/src/conspiracies/pipeline/pipeline.py b/src/conspiracies/pipeline/pipeline.py index 0da29cf..1774605 100644 --- a/src/conspiracies/pipeline/pipeline.py +++ b/src/conspiracies/pipeline/pipeline.py @@ -1,15 +1,23 @@ import json import os +from datetime import datetime from pathlib import Path +from tqdm import tqdm from conspiracies.common.fileutils import iter_lines_of_files from conspiracies.corpusprocessing.aggregation import TripletAggregator -from conspiracies.corpusprocessing.clustering import Clustering +from conspiracies.corpusprocessing.clustering import Clustering, Mappings from conspiracies.corpusprocessing.triplet import Triplet +from conspiracies.database.engine import get_engine, setup_database, get_session +from conspiracies.database.models import ( + TripletOrm, + DocumentOrm, + ModelLookupCache, +) from conspiracies.docprocessing.docprocessor import DocProcessor from conspiracies.document import Document -from conspiracies.pipeline.config import PipelineConfig, ClusteringThresholds +from conspiracies.pipeline.config import PipelineConfig, Thresholds from conspiracies.preprocessing.csv import CsvPreprocessor from conspiracies.preprocessing.infomedia import InfoMediaPreprocessor from conspiracies.preprocessing.preprocessor import Preprocessor @@ -23,12 +31,11 @@ class Pipeline: def __init__(self, config: PipelineConfig): - self.project_name = config.base.project_name self.input_path = Path(config.preprocessing.input_path) + self.output_path = Path(config.base.output_path) + os.makedirs(self.output_path, exist_ok=True) self.config = config print("Initialized Pipeline with config:", config) - self.output_path = Path(self.config.base.output_root, self.project_name) - os.makedirs(self.output_path, exist_ok=True) def run(self): if self.config.preprocessing.enabled: @@ -44,6 +51,9 @@ def run(self): if self.config.corpusprocessing.enabled: self.corpusprocessing() + if self.config.databasepopulation.enabled: + self.databasepopulation() + def _get_preprocessor(self) -> Preprocessor: config = self.config.preprocessing doc_type = config.doc_type.lower() @@ -81,6 +91,8 @@ def _get_docprocessor(self) -> DocProcessor: batch_size=self.config.docprocessing.batch_size, triplet_extraction_method=self.config.docprocessing.triplet_extraction_method, prefer_gpu_for_coref=self.config.docprocessing.prefer_gpu_for_coref, + n_process=self.config.docprocessing.n_process, + doc_bin_size=self.config.docprocessing.doc_bin_size, ) def docprocessing(self, continue_from_last=False): @@ -101,12 +113,16 @@ def corpusprocessing(self): print("Collecting triplets.") triplets = Triplet.from_annotated_docs(self.output_path / "annotations.ndjson") triplets = Triplet.filter_on_stopwords(triplets, self.config.base.language) - Triplet.write_jsonl(self.output_path / "triplets.ndjson", triplets) - if self.config.corpusprocessing.thresholds is None: - thresholds = ClusteringThresholds.estimate_from_n_triplets(len(triplets)) + 
thresholds = Thresholds.estimate_from_n_triplets(len(triplets)) else: thresholds = self.config.corpusprocessing.thresholds + triplets = Triplet.filter_on_entity_label_frequency( + triplets, + thresholds.min_label_occurrence, + ) + Triplet.write_jsonl(self.output_path / "triplets.ndjson", triplets) + print("Clustering entities and predicates to create mappings.") clustering = Clustering( language=self.config.base.language, @@ -114,6 +130,7 @@ def corpusprocessing(self): n_neighbors=self.config.corpusprocessing.n_neighbors, min_cluster_size=thresholds.min_cluster_size, min_samples=thresholds.min_samples, + cache_location=self.output_path / "cache", ) mappings = clustering.create_mappings(triplets) with open(self.output_path / "mappings.json", "w") as out: @@ -145,3 +162,76 @@ def corpusprocessing(self): edges, save=self.output_path / "graph.png", ) + + def databasepopulation(self): + if self.config.databasepopulation.clear_and_write: + if os.path.exists(self.output_path / "database.db"): + print("Removing old database.") + os.remove(self.output_path / "database.db") + + print("Populating database.") + engine = get_engine(self.output_path / "database.db") + setup_database(engine) + session = get_session(engine) + + with open(self.output_path / "mappings.json") as mappings_file: + mappings = Mappings(**json.load(mappings_file)) + + with open(self.output_path / "triplets.ndjson") as triplets_file: + cache = ModelLookupCache(session) + bulk = [] + for line in tqdm(triplets_file, desc="Writing triplets to database"): + triplet = Triplet(**json.loads(line)) + subject_id = cache.get_or_create_entity( + mappings.map_entity(triplet.subject.text), + session, + ) + object_id = cache.get_or_create_entity( + mappings.map_entity(triplet.object.text), + session, + ) + relation_id = cache.get_or_create_relation( + subject_id, + object_id, + mappings.map_predicate(triplet.predicate.text), + session, + ) + + triplet_orm = TripletOrm( + doc_id=int(triplet.doc), + subject_id=subject_id, + relation_id=relation_id, + object_id=object_id, + subj_span_start=triplet.subject.start_char, + subj_span_end=triplet.subject.end_char, + pred_span_start=triplet.predicate.start_char, + pred_span_end=triplet.predicate.end_char, + obj_span_start=triplet.object.start_char, + obj_span_end=triplet.object.end_char, + ) + bulk.append(triplet_orm) + if len(bulk) >= 500: + session.bulk_save_objects(bulk) + bulk.clear() + session.bulk_save_objects(bulk) + bulk.clear() + session.commit() + + for doc in ( + json.loads(line) + for line in tqdm( + iter_lines_of_files(self.output_path / "annotations.ndjson"), + desc="Writing documents to database", + ) + ): + doc_orm = DocumentOrm( + id=doc["id"], + text=doc["text"], + timestamp=datetime.fromisoformat(doc["timestamp"]), + ) + bulk.append(doc_orm) + if len(bulk) >= 500: + session.bulk_save_objects(bulk) + bulk.clear() + session.bulk_save_objects(bulk) + session.commit() diff --git a/src/conspiracies/run.py b/src/conspiracies/run.py index d859b7e..83f02cc 100644 --- a/src/conspiracies/run.py +++ b/src/conspiracies/run.py @@ -8,7 +8,7 @@ if __name__ == "__main__": arg_parser = argparse.ArgumentParser() arg_parser.add_argument( - "project_name", + "output_path", nargs="?", default=None, help="Name of your project under which various output files will be output" @@ -49,7 +49,7 @@ logging.getLogger().setLevel(args.root_log_level) cli_args = { - "base.project_name": args.project_name, + "base.output_path": args.output_path, "base.language": args.language, "preprocessing.input_path": 
args.input_path, "preprocessing.n_docs": args.n_docs, @@ -61,4 +61,13 @@ config = PipelineConfig.default_with_extra_config(cli_args) pipeline = Pipeline(config) + + logging.basicConfig( + level=logging.DEBUG, + filename=config.base.output_path + "/logfile", + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + filemode="a+", + ) + logging.info("Running pipeline...") + pipeline.run() diff --git a/tests/test_clustering.py b/tests/test_clustering.py index 92cf144..2ba48d1 100644 --- a/tests/test_clustering.py +++ b/tests/test_clustering.py @@ -29,3 +29,20 @@ def test_tuples_with_second_element_as_combine_key(self): Clustering._combine_clusters(clusters, get_combine_key=lambda x: x[1]) == expected ) + + +def test_cluster_by_normalization(): + labels = [ + "popular label", + "popular label", + "popular label 2", + "another label", + "another label", + "yet another label", + "a third label", + ] + clusters = Clustering._cluster_via_normalization(labels, top=2) + assert clusters == [ + ["popular label", "popular label 2"], + ["another label", "yet another label"], + ] diff --git a/tests/test_coref_comp.py b/tests/test_coref_comp.py index 80c10d0..38b90b6 100644 --- a/tests/test_coref_comp.py +++ b/tests/test_coref_comp.py @@ -22,8 +22,8 @@ def test_coref_clusters(nlp_da_w_coref): # noqa F811 assert isinstance(sent._.coref_clusters[0][1], Span) -def test_resolve_coref(nlp_da_w_coref): # noqa F811 - resolve_coref_text = ( +def test_resolved_text(nlp_da_w_coref): # noqa F811 + resolved_text_text = ( "Aftalepartierne bag Rammeaftalen om plan for genåbning af Danmark blev i" + " foråret 2021 enige om at nedsætte en ekspertgruppe, en ekspertgruppe fik " + "til opgave at komme med input til den langsigtede strategi for håndtering " @@ -31,7 +31,7 @@ def test_resolve_coref(nlp_da_w_coref): # noqa F811 + "ekspertgruppe rapport." 
) - resolve_coref_spans = [ + resolved_text_spans = [ "Aftalepartierne bag Rammeaftalen om plan for genåbning af Danmark blev i " + "foråret 2021 enige om at nedsætte en ekspertgruppe, en ekspertgruppe fik " + "til opgave at komme med input til den langsigtede strategi for håndtering " @@ -39,11 +39,11 @@ def test_resolve_coref(nlp_da_w_coref): # noqa F811 "en ekspertgruppe er nu klar med en ekspertgruppe rapport.", ] - doc = nlp_da_w_coref(resolve_coref_text) + doc = nlp_da_w_coref(resolved_text_text) # test for doc - assert doc._.resolve_coref == resolve_coref_text + assert doc._.resolved_text == resolved_text_text # test for spans for i, sent in enumerate(doc.sents): if sent._.coref_clusters != []: - assert sent._.resolve_coref == resolve_coref_spans[i] + assert sent._.resolved_text == resolved_text_spans[i] diff --git a/tests/test_data/test_config.toml b/tests/test_data/test_config.toml index 6ddc9cf..7e9a244 100644 --- a/tests/test_data/test_config.toml +++ b/tests/test_data/test_config.toml @@ -1,6 +1,5 @@ [base] -project_name = "test" -output_root = "output" +output_path = "output/test" language = "en" [preprocessing] @@ -15,4 +14,6 @@ some_extra_value = 1234 [docprocessing] triplet_extraction_method = "test" -[corpusprocessing] \ No newline at end of file +[corpusprocessing] + +[databasepopulation] \ No newline at end of file diff --git a/tests/test_pipelineconfig.py b/tests/test_pipelineconfig.py index 21b1d61..c6d2a38 100644 --- a/tests/test_pipelineconfig.py +++ b/tests/test_pipelineconfig.py @@ -8,6 +8,7 @@ PreProcessingConfig, DocProcessingConfig, CorpusProcessingConfig, + DatabasePopulationConfig, ) @@ -21,8 +22,7 @@ def test_config_loading(path: str): config = PipelineConfig.from_toml_file(path) assert config == PipelineConfig( base=BaseConfig( - project_name="test", - output_root="output", + output_path="output/test", language="en", ), preprocessing=PreProcessingConfig( @@ -38,6 +38,7 @@ def test_config_loading(path: str): triplet_extraction_method="test", ), corpusprocessing=CorpusProcessingConfig(), + databasepopulation=DatabasePopulationConfig(), ) diff --git a/visualizer/README.md b/visualizer/README.md index 22bfb4f..940b5c2 100644 --- a/visualizer/README.md +++ b/visualizer/README.md @@ -7,4 +7,4 @@ It works as any React app in terms of building, running etc. If you just want to 1. Ensure that you have Node >=16 and npm installed. See https://docs.npmjs.com/downloading-and-installing-node-js-and-npm. 2. From the directory `conspiracies/visualizer`, run `npm install`. 3. Then run `npm start` which will open a development server on `localhost:3000`. -4. Load in the file `graph.json` from your output via the GUI. \ No newline at end of file +4. Load in the file `graph.json` from your output via the GUI. 
diff --git a/visualizer/electron-main.js b/visualizer/electron-main.js index e1d363d..38213e9 100644 --- a/visualizer/electron-main.js +++ b/visualizer/electron-main.js @@ -1,7 +1,6 @@ const { app, BrowserWindow } = require("electron"); const path = require("path"); - function createWindow() { const mainWindow = new BrowserWindow(); diff --git a/visualizer/package-lock.json b/visualizer/package-lock.json index 2e97d87..d3726cc 100644 --- a/visualizer/package-lock.json +++ b/visualizer/package-lock.json @@ -14,9 +14,12 @@ "@types/jest": "^27.5.2", "@types/node": "^16.18.96", "@types/react-dom": "^18.2.24", + "draft-js": "^0.11.7", + "multi-range-slider-react": "^2.0.7", "react": "^18.2.0", "react-dom": "^18.2.0", "react-graph-vis": "^1.0.7", + "react-highlight-within-textarea": "^3.2.2", "react-router-dom": "^6.22.3", "react-scripts": "5.0.1", "react-vis-graph-wrapper": "^0.1.3", @@ -26,6 +29,7 @@ "devDependencies": { "@electron/packager": "^18.3.5", "electron": "^33.0.2", + "prettier": "^3.3.3", "serve": "^14.2.4" } }, @@ -6514,6 +6518,14 @@ "integrity": "sha512-+R08/oI0nl3vfPcqftZRpytksBXDzOUveBq/NBVx0sUp1axwzPQrKinNx5yd5sxPu8j1wIy8AfnVQ+5eFdha6Q==", "dev": true }, + "node_modules/cross-fetch": { + "version": "3.1.8", + "resolved": "https://registry.npmjs.org/cross-fetch/-/cross-fetch-3.1.8.tgz", + "integrity": "sha512-cvA+JwZoU0Xq+h6WkMvAUqPEYy92Obet6UdKLfW60qn99ftItKjB5T+BkyWOFWe2pUyfQ+IJHmpOTznqk1M6Kg==", + "dependencies": { + "node-fetch": "^2.6.12" + } + }, "node_modules/cross-spawn": { "version": "7.0.3", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", @@ -7415,6 +7427,20 @@ "resolved": "https://registry.npmjs.org/dotenv-expand/-/dotenv-expand-5.1.0.tgz", "integrity": "sha512-YXQl1DSa4/PQyRfgrv6aoNjhasp/p4qs9FjJ4q4cQk+8m4r6k4ZSiEyytKG8f8W9gi8WsQtIObNmKd+tMzNTmA==" }, + "node_modules/draft-js": { + "version": "0.11.7", + "resolved": "https://registry.npmjs.org/draft-js/-/draft-js-0.11.7.tgz", + "integrity": "sha512-ne7yFfN4sEL82QPQEn80xnADR8/Q6ALVworbC5UOSzOvjffmYfFsr3xSZtxbIirti14R7Y33EZC5rivpLgIbsg==", + "dependencies": { + "fbjs": "^2.0.0", + "immutable": "~3.7.4", + "object-assign": "^4.1.1" + }, + "peerDependencies": { + "react": ">=0.14.0", + "react-dom": ">=0.14.0" + } + }, "node_modules/duplexer": { "version": "0.1.2", "resolved": "https://registry.npmjs.org/duplexer/-/duplexer-0.1.2.tgz", @@ -8607,6 +8633,34 @@ "bser": "2.1.1" } }, + "node_modules/fbjs": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/fbjs/-/fbjs-2.0.0.tgz", + "integrity": "sha512-8XA8ny9ifxrAWlyhAbexXcs3rRMtxWcs3M0lctLfB49jRDHiaxj+Mo0XxbwE7nKZYzgCFoq64FS+WFd4IycPPQ==", + "dependencies": { + "core-js": "^3.6.4", + "cross-fetch": "^3.0.4", + "fbjs-css-vars": "^1.0.0", + "loose-envify": "^1.0.0", + "object-assign": "^4.1.0", + "promise": "^7.1.1", + "setimmediate": "^1.0.5", + "ua-parser-js": "^0.7.18" + } + }, + "node_modules/fbjs-css-vars": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/fbjs-css-vars/-/fbjs-css-vars-1.0.2.tgz", + "integrity": "sha512-b2XGFAFdWZWg0phtAWLHCk836A1Xann+I+Dgd3Gk64MHKZO44FfoD1KxyvbSh0qZsIoXQGGlVztIY+oitJPpRQ==" + }, + "node_modules/fbjs/node_modules/promise": { + "version": "7.3.1", + "resolved": "https://registry.npmjs.org/promise/-/promise-7.3.1.tgz", + "integrity": "sha512-nolQXZ/4L+bP/UGlkfaIujX9BKxGwmQ9OT4mOt5yvy8iK1h3wqTEJCijzGANTCCl9nWjY41juyAn2K3Q1hLLTg==", + "dependencies": { + "asap": "~2.0.3" + } + }, "node_modules/fd-slicer": { "version": "1.1.0", "resolved": 
"https://registry.npmjs.org/fd-slicer/-/fd-slicer-1.1.0.tgz", @@ -9789,6 +9843,14 @@ "url": "https://opencollective.com/immer" } }, + "node_modules/immutable": { + "version": "3.7.6", + "resolved": "https://registry.npmjs.org/immutable/-/immutable-3.7.6.tgz", + "integrity": "sha512-AizQPcaofEtO11RZhPPHBOJRdo/20MKQF9mBLnVkBoyHi1/zXK8fzVdnEpSV9gxqtnh6Qomfp3F0xT5qP/vThw==", + "engines": { + "node": ">=0.8.0" + } + }, "node_modules/import-fresh": { "version": "3.3.0", "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.0.tgz", @@ -12082,6 +12144,11 @@ "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==" }, + "node_modules/multi-range-slider-react": { + "version": "2.0.7", + "resolved": "https://registry.npmjs.org/multi-range-slider-react/-/multi-range-slider-react-2.0.7.tgz", + "integrity": "sha512-KRYUkatXxxYceL5ZT8xvetIN+4yTCdWszxRC6Y6Jkua+oRrWVkmBR6v3R03kosYg/QtcETBf2L1Jt+4U66DFbg==" + }, "node_modules/multicast-dns": { "version": "7.2.5", "resolved": "https://registry.npmjs.org/multicast-dns/-/multicast-dns-7.2.5.tgz", @@ -12153,6 +12220,44 @@ "tslib": "^2.0.3" } }, + "node_modules/node-fetch": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.7.0.tgz", + "integrity": "sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A==", + "dependencies": { + "whatwg-url": "^5.0.0" + }, + "engines": { + "node": "4.x || >=6.0.0" + }, + "peerDependencies": { + "encoding": "^0.1.0" + }, + "peerDependenciesMeta": { + "encoding": { + "optional": true + } + } + }, + "node_modules/node-fetch/node_modules/tr46": { + "version": "0.0.3", + "resolved": "https://registry.npmjs.org/tr46/-/tr46-0.0.3.tgz", + "integrity": "sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==" + }, + "node_modules/node-fetch/node_modules/webidl-conversions": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/webidl-conversions/-/webidl-conversions-3.0.1.tgz", + "integrity": "sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ==" + }, + "node_modules/node-fetch/node_modules/whatwg-url": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/whatwg-url/-/whatwg-url-5.0.0.tgz", + "integrity": "sha512-saE57nupxk6v3HY35+jzBwYa0rKSy0XR8JSxZPwgLr7ys0IBzhGviA1/TUGJLmSVqs8pb9AnvICXEuOHLprYTw==", + "dependencies": { + "tr46": "~0.0.3", + "webidl-conversions": "^3.0.0" + } + }, "node_modules/node-forge": { "version": "1.3.1", "resolved": "https://registry.npmjs.org/node-forge/-/node-forge-1.3.1.tgz", @@ -14118,6 +14223,21 @@ "node": ">= 0.8.0" } }, + "node_modules/prettier": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.3.3.tgz", + "integrity": "sha512-i2tDNA0O5IrMO757lfrdQZCc2jPNDVntV0m/+4whiDfWaTKfMNgR7Qz0NAeGz/nRqF4m5/6CLzbP4/liHt12Ew==", + "dev": true, + "bin": { + "prettier": "bin/prettier.cjs" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/prettier/prettier?sponsor=1" + } + }, "node_modules/pretty-bytes": { "version": "5.6.0", "resolved": "https://registry.npmjs.org/pretty-bytes/-/pretty-bytes-5.6.0.tgz", @@ -14528,6 +14648,16 @@ "integrity": "sha512-FULf7fayPdpASncVy4DLh3xydlXEJJpvIELjYjNeQWYUZ9pclcpvCZSr2gkmN2FrrGcI7G/cJsIEwk5/8vfXpg==", "deprecated": "Please upgrade to version 7 or higher. 
Older versions may use Math.random() in certain circumstances, which is known to be problematic. See https://v8.dev/blog/math-random for details." }, + "node_modules/react-highlight-within-textarea": { + "version": "3.2.2", + "resolved": "https://registry.npmjs.org/react-highlight-within-textarea/-/react-highlight-within-textarea-3.2.2.tgz", + "integrity": "sha512-pS+tPi6//dM8V154/0SfSqkx+0i6lKpSKazLZa7+RQjNQg0wKeCZBVkOGtxAhsVJy5KWpfIfdcpE8JpZ2Giz/g==", + "peerDependencies": { + "draft-js": ">=0.11.7", + "react": ">=0.14.0", + "react-dom": ">=0.14.0" + } + }, "node_modules/react-is": { "version": "17.0.2", "resolved": "https://registry.npmjs.org/react-is/-/react-is-17.0.2.tgz", @@ -15771,6 +15901,11 @@ "node": ">= 0.4" } }, + "node_modules/setimmediate": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/setimmediate/-/setimmediate-1.0.5.tgz", + "integrity": "sha512-MATJdZp8sLqDl/68LfQmbP8zKPLQNV6BIZoIgrscFDQ+RsvK/BxeDQOgyxKKoh0y/8h3BqVFnCqQ/gd+reiIXA==" + }, "node_modules/setprototypeof": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", @@ -17143,6 +17278,31 @@ "node": ">=4.2.0" } }, + "node_modules/ua-parser-js": { + "version": "0.7.39", + "resolved": "https://registry.npmjs.org/ua-parser-js/-/ua-parser-js-0.7.39.tgz", + "integrity": "sha512-IZ6acm6RhQHNibSt7+c09hhvsKy9WUr4DVbeq9U8o71qxyYtJpQeDxQnMrVqnIFMLcQjHO0I9wgfO2vIahht4w==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/ua-parser-js" + }, + { + "type": "paypal", + "url": "https://paypal.me/faisalman" + }, + { + "type": "github", + "url": "https://github.com/sponsors/faisalman" + } + ], + "bin": { + "ua-parser-js": "script/cli.js" + }, + "engines": { + "node": "*" + } + }, "node_modules/unbox-primitive": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/unbox-primitive/-/unbox-primitive-1.0.2.tgz", diff --git a/visualizer/package.json b/visualizer/package.json index dfc6bab..745725b 100644 --- a/visualizer/package.json +++ b/visualizer/package.json @@ -9,9 +9,12 @@ "@types/jest": "^27.5.2", "@types/node": "^16.18.96", "@types/react-dom": "^18.2.24", + "draft-js": "^0.11.7", + "multi-range-slider-react": "^2.0.7", "react": "^18.2.0", "react-dom": "^18.2.0", "react-graph-vis": "^1.0.7", + "react-highlight-within-textarea": "^3.2.2", "react-router-dom": "^6.22.3", "react-scripts": "5.0.1", "react-vis-graph-wrapper": "^0.1.3", @@ -50,8 +53,9 @@ ] }, "devDependencies": { - "electron": "^33.0.2", "@electron/packager": "^18.3.5", + "electron": "^33.0.2", + "prettier": "^3.3.3", "serve": "^14.2.4" } } diff --git a/visualizer/public/index.html b/visualizer/public/index.html index 32055ae..5ddbf33 100644 --- a/visualizer/public/index.html +++ b/visualizer/public/index.html @@ -1,12 +1,9 @@ - + - + Visualizer diff --git a/visualizer/public/manifest.json b/visualizer/public/manifest.json index d0c7108..96b3f5a 100644 --- a/visualizer/public/manifest.json +++ b/visualizer/public/manifest.json @@ -1,8 +1,7 @@ { "short_name": "Visualizer", "name": "Narrative Graphs Visualizer", - "icons": [ - ], + "icons": [], "start_url": ".", "display": "standalone", "theme_color": "#000000", diff --git a/visualizer/src/App.tsx b/visualizer/src/App.tsx index fae17c9..bc4b8cd 100644 --- a/visualizer/src/App.tsx +++ b/visualizer/src/App.tsx @@ -1,27 +1,14 @@ import "./App.css"; -import {BrowserRouter, HashRouter, Link, Route, Routes} from "react-router-dom"; -import {GraphViewer} from "./graph/GraphViewer"; - -function NavBar() { - return ( -
- Graph Viewer -
- ); -} +import { GraphViewer } from "./graph/GraphViewer"; +import React from "react"; +import { ServiceContextProvider } from "./service/ServiceContextProvider"; export function App() { - // Create actual routes if/when more functionality is added to the application - return ( - - - } - /> - - - ); + return ( + + + + ); } export default App; diff --git a/visualizer/src/common/LogarithmicRangeSlider.tsx b/visualizer/src/common/LogarithmicRangeSlider.tsx new file mode 100644 index 0000000..2ca6226 --- /dev/null +++ b/visualizer/src/common/LogarithmicRangeSlider.tsx @@ -0,0 +1,107 @@ +import React, { useEffect, useMemo, useState } from "react"; +import MultiRangeSlider from "multi-range-slider-react"; + +interface LogarithmicRangeSliderProps { + min: number; // Real-world minimum value + max: number; // Real-world maximum value + minValue: number; // Current minimum value + maxValue: number; // Current maximum value + onChange: (values: { minValue: number; maxValue: number }) => void; // Callback for value changes + style?: React.CSSProperties; // Optional style prop + ruler?: boolean; // Optional ruler prop +} + +const linearToLog = ( + value: number, + minLinear: number, + maxLinear: number, + minLog: number, + maxLog: number, +): number => { + const clampedValue = Math.max(minLinear, Math.min(value, maxLinear)); + const linearRange = maxLinear - minLinear; + const logRange = Math.log(maxLog) - Math.log(minLog); + const logValue = + Math.log(minLog) + ((clampedValue - minLinear) / linearRange) * logRange; + return Math.exp(logValue); +}; + +const logToLinear = ( + value: number, + minLinear: number, + maxLinear: number, + minLog: number, + maxLog: number, +): number => { + const clampedValue = Math.max(minLog, Math.min(value, maxLog)); + const linearRange = maxLinear - minLinear; + const logRange = Math.log(maxLog) - Math.log(minLog); + const logValue = Math.log(clampedValue); + return minLinear + ((logValue - Math.log(minLog)) / logRange) * linearRange; +}; + +const LogarithmicRangeSlider: React.FC = ({ + min, + max, + minValue, + maxValue, + onChange, + ruler, + ...rest +}) => { + const [minCaption, setMinCaption] = useState(Math.round(minValue)); + const [maxCaption, setMaxCaption] = useState(Math.round(maxValue)); + useEffect(() => { + setMinCaption(minValue); + setMaxCaption(maxValue); + }, [minValue, maxValue]); + + // Slider operates on a linear scale + const linearMin = 0; + const linearMax = 100; + const realValueToLinearScale = useMemo(() => { + return (realValue: number) => + Math.round(logToLinear(realValue, linearMin, linearMax, min, max)); + }, [min, max]); + const linearScaleToRealValue = useMemo(() => { + return (linearScaleValue: number) => + Math.round(linearToLog(linearScaleValue, linearMin, linearMax, min, max)); + }, [min, max]); + + const handleSliderInput = (e: { minValue: number; maxValue: number }) => { + const realMinValue = linearScaleToRealValue(e.minValue); + const realMaxValue = linearScaleToRealValue(e.maxValue); + setMinCaption(realMinValue); + setMaxCaption(realMaxValue); + }; + const handleSliderChange = (e: { minValue: number; maxValue: number }) => { + const realMinValue = linearScaleToRealValue(e.minValue); + const realMaxValue = linearScaleToRealValue(e.maxValue); + onChange({ minValue: realMinValue, maxValue: realMaxValue }); + }; + + const labels = [ + String(min), + String(linearScaleToRealValue(25)), + String(linearScaleToRealValue(50)), + String(linearScaleToRealValue(75)), + String(max), + ]; + + return ( + + ); +}; +export default 
LogarithmicRangeSlider; diff --git a/visualizer/src/datasources/FileUploadComp.tsx b/visualizer/src/datasources/JsonFileUploadComp.tsx similarity index 70% rename from visualizer/src/datasources/FileUploadComp.tsx rename to visualizer/src/datasources/JsonFileUploadComp.tsx index ce52eb9..c879243 100644 --- a/visualizer/src/datasources/FileUploadComp.tsx +++ b/visualizer/src/datasources/JsonFileUploadComp.tsx @@ -1,12 +1,12 @@ import React from "react"; -interface FileUploadComponentProps { - onFileLoaded: (data: any ) => void; +interface JsonFileUploadComponentProps { + onFileLoaded: (data: any) => void; } -const FileUploadComponent: React.FC = ({ +const JsonFileUploadComponent: React.FC = ({ onFileLoaded, -}: FileUploadComponentProps) => { +}: JsonFileUploadComponentProps) => { const handleFileChange = (event: React.ChangeEvent) => { const file = event.target.files?.[0]; if (file) { @@ -29,4 +29,4 @@ const FileUploadComponent: React.FC = ({ ); }; -export default FileUploadComponent; +export default JsonFileUploadComponent; diff --git a/visualizer/src/datasources/NdjsonFileUploadComp.tsx b/visualizer/src/datasources/NdjsonFileUploadComp.tsx new file mode 100644 index 0000000..bbf9b90 --- /dev/null +++ b/visualizer/src/datasources/NdjsonFileUploadComp.tsx @@ -0,0 +1,54 @@ +import React from "react"; + +interface NdjsonFileUploadComponentProps { + onFileLoaded: (data: Generator) => void; +} + +const NdjsonFileUploadComponent: React.FC = ({ + onFileLoaded, +}: NdjsonFileUploadComponentProps) => { + const handleFileChange = (event: React.ChangeEvent) => { + const file = event.target.files?.[0]; + if (file) { + const reader = new FileReader(); + reader.onload = (e) => { + const text = e.target?.result; + if (typeof text === "string") { + // Create a generator for NDJSON parsing + const parseNDJSON = function* ( + input: string, + ): Generator { + const lines = input.split("\n"); + for (const line of lines) { + if (line.trim()) { + try { + yield JSON.parse(line); + } catch (error) { + console.error("Invalid JSON line:", line, error); + } + } + } + }; + + try { + // Test if the file is a single JSON object + JSON.parse(text); // Throws if it's not a single JSON + alert("This file is a standard JSON file. NDJSON is expected."); + } catch { + // If not, assume it's NDJSON and pass the generator + onFileLoaded(parseNDJSON(text)); + } + } + }; + reader.readAsText(file); + } + }; + + return ( +
+ +
+ ); +}; + +export default NdjsonFileUploadComponent; diff --git a/visualizer/src/docs/DocService.ts b/visualizer/src/docs/DocService.ts new file mode 100644 index 0000000..a017fe1 --- /dev/null +++ b/visualizer/src/docs/DocService.ts @@ -0,0 +1,93 @@ +export abstract class DocService { + abstract getDocData(): Map; + + getDoc(id: string): Doc | undefined { + return this.getDocData().get(id); + } + + getDocs(ids: string[]): Doc[] { + return ids + .map((id) => this.getDoc(id)) + .filter((v): v is Doc => v !== undefined); + } +} + +export interface TripletField { + text: string; + start_char: number; + start: number; + end_char: number; + end: number; +} + +export interface Triplet { + subject: TripletField; + predicate: TripletField; + object: TripletField; +} + +export interface Doc { + id: string; + text: string; + timestamp: string; + semantic_triplets: Triplet[]; +} + +export class SampleDocService extends DocService { + readonly docData: Map = new Map( + [ + { + id: "1", + text: "sample text 1", + timestamp: "", + semantic_triplets: [], + }, + { + id: "2", + text: "sample text 1", + timestamp: "", + semantic_triplets: [], + }, + { + id: "3", + text: "sample text 1", + timestamp: "", + semantic_triplets: [], + }, + ].map((d) => [d.id, d]), + ); + + getDocData(): Map { + return this.docData; + } + + getDoc(id: string): Doc | undefined { + return this.docData.get(id); + } +} + +export class FileDocService extends DocService { + readonly docData: Map; + + getDocData(): Map { + return this.docData; + } + + constructor(docData: Doc[]) { + super(); + this.docData = new Map( + docData + .filter((d) => d.semantic_triplets !== undefined) + .map((d) => [ + d.id, + { + id: d.id, + text: d.text, + timestamp: d.timestamp, + semantic_triplets: d.semantic_triplets, + }, + ]), + ); + console.log(this.docData.size); + } +} diff --git a/visualizer/src/graph/GraphFilterControlPanel.tsx b/visualizer/src/graph/GraphFilterControlPanel.tsx index 12210a1..d01eab5 100644 --- a/visualizer/src/graph/GraphFilterControlPanel.tsx +++ b/visualizer/src/graph/GraphFilterControlPanel.tsx @@ -1,67 +1,118 @@ import React from "react"; -import {GraphFilter} from "./GraphService"; -import './graph.css' - +import { GraphFilter } from "./GraphService"; +import "./graph.css"; +import LogarithmicRangeSlider from "../common/LogarithmicRangeSlider"; interface GraphFilterControlPanelProps { - graphFilter: GraphFilter; - setGraphFilter: React.Dispatch>; + graphFilter: GraphFilter; + setGraphFilter: React.Dispatch>; } -export const GraphFilterControlPanel = ({graphFilter, setGraphFilter}: GraphFilterControlPanelProps) => { +export const GraphFilterControlPanel = ({ + graphFilter, + setGraphFilter, +}: GraphFilterControlPanelProps) => { + const setMinAndMaxNodeFrequency = (min: number, max: number) => { + setGraphFilter({ + ...graphFilter, + minimumNodeFrequency: min, + maximumNodeFrequency: max, + }); + }; + const setMinAndMaxEdgeFrequency = (min: number, max: number) => { + setGraphFilter({ + ...graphFilter, + minimumEdgeFrequency: min, + maximumEdgeFrequency: max, + }); + }; - return
-
- Minimum Node Frequency: {graphFilter.minimumNodeFrequency} - - -
-
- Minimum Edge Frequency: {graphFilter.minimumEdgeFrequency} - - -
-
- Show unconnected nodes: - setGraphFilter({ - ...graphFilter, - showUnconnectedNodes: event.target.checked - })}/> -
-
- From: setGraphFilter({ - ...graphFilter, - earliestDate: event.target.valueAsDate ?? undefined - })}/> - To: setGraphFilter({ - ...graphFilter, - latestDate: event.target.valueAsDate ?? undefined - })}/> + return ( +
+
+ + Node Frequency: + +
+
+ { + setMinAndMaxNodeFrequency(e.minValue, e.maxValue); + }} + min={1} + minValue={graphFilter.minimumNodeFrequency} + maxValue={graphFilter.maximumNodeFrequency} + max={graphFilter.maximumPossibleNodeFrequency} + style={{ border: "none", boxShadow: "none", padding: "15px 10px" }} + > +
+ +
+ Edge Frequency: +
+ { + setMinAndMaxEdgeFrequency(e.minValue, e.maxValue); + }} + min={graphFilter.minimumPossibleEdgeFrequency} + minValue={graphFilter.minimumEdgeFrequency} + maxValue={graphFilter.maximumEdgeFrequency} + max={graphFilter.maximumPossibleEdgeFrequency} + style={{ border: "none", boxShadow: "none", padding: "15px 10px" }} + >
+
+
+ Show unconnected nodes: + + setGraphFilter({ + ...graphFilter, + showUnconnectedNodes: event.target.checked, + }) + } + /> +
+
+ Search nodes: + { + let value = event.target.value; + setGraphFilter({ + ...graphFilter, + labelSearch: value, + }); + }} + /> +
+
+ From: + + setGraphFilter({ + ...graphFilter, + earliestDate: event.target.valueAsDate ?? undefined, + }) + } + /> + To: + + setGraphFilter({ + ...graphFilter, + latestDate: event.target.valueAsDate ?? undefined, + }) + } + /> +
-} \ No newline at end of file + ); +}; diff --git a/visualizer/src/graph/GraphOptionsControlPanel.tsx b/visualizer/src/graph/GraphOptionsControlPanel.tsx index 3b374a4..afc9554 100644 --- a/visualizer/src/graph/GraphOptionsControlPanel.tsx +++ b/visualizer/src/graph/GraphOptionsControlPanel.tsx @@ -1,69 +1,84 @@ import React from "react"; -import './graph.css' -import {Options} from "react-vis-graph-wrapper"; - +import "./graph.css"; +import { Options } from "react-vis-graph-wrapper"; interface GraphOptionsControlPanelProps { - options: Options; - setOptions: React.Dispatch>; - + options: Options; + setOptions: React.Dispatch>; } function getSmoothEnabled(options: Options): boolean { - if (typeof options.edges?.smooth === 'boolean') { - return options.edges.smooth; - } else if (typeof options.edges?.smooth === 'object' && 'enabled' in options.edges.smooth) { - return options.edges.smooth.enabled; - } else { - return false; - } + if (typeof options.edges?.smooth === "boolean") { + return options.edges.smooth; + } else if ( + typeof options.edges?.smooth === "object" && + "enabled" in options.edges.smooth + ) { + return options.edges.smooth.enabled; + } else { + return false; + } } -export const GraphOptionsControlPanel = ({options, setOptions}: GraphOptionsControlPanelProps) => { - - - return
-
- Physics enabled: - setOptions( - { - ...options, - physics: { - ...options.physics, - enabled: event.target.checked - } - }) - }/> -
-
- Rounded edges: - setOptions( - { - ...options, - edges: { - ...options.edges, - smooth: !options.edges?.smooth - } - }) - }/> -
-
- Edge length: - setOptions( - { - ...options, - physics: { - ...options.physics, - barnesHut: { - springLength: Number(event.target.value) - } - } - }) - } - step="1"/> -
+export const GraphOptionsControlPanel = ({ + options, + setOptions, +}: GraphOptionsControlPanelProps) => { + return ( +
+
+ Physics enabled: + + setOptions({ + ...options, + physics: { + ...options.physics, + enabled: event.target.checked, + }, + }) + } + /> +
+
+ Rounded edges: + + setOptions({ + ...options, + edges: { + ...options.edges, + smooth: !options.edges?.smooth, + }, + }) + } + /> +
+
+ Edge length: + + setOptions({ + ...options, + physics: { + ...options.physics, + barnesHut: { + springLength: Number(event.target.value), + }, + }, + }) + } + step="1" + /> +
-} \ No newline at end of file + ); +}; diff --git a/visualizer/src/graph/GraphService.ts b/visualizer/src/graph/GraphService.ts index 04ba4a7..ef23b7b 100644 --- a/visualizer/src/graph/GraphService.ts +++ b/visualizer/src/graph/GraphService.ts @@ -1,189 +1,266 @@ -import {Edge, GraphData, Node} from "react-vis-graph-wrapper"; +import { Edge, GraphData, Node } from "react-vis-graph-wrapper"; export interface Stats { - frequency: number; - norm_frequency?: number; - docs?: string[]; - first_occurrence?: string; - last_occurrence?: string; - alt_labels?: string[]; + frequency: number; + norm_frequency?: number; + docs?: string[]; + first_occurrence?: string; + last_occurrence?: string; + alt_labels?: string[]; } export interface EnrichedNode extends Node { - stats: Stats; + stats: Stats; } export interface EnrichedEdge extends Edge { - stats: Stats; + stats: Stats; +} + +export interface EdgeGroup extends EnrichedEdge { + group?: EnrichedEdge[]; } export interface EnrichedGraphData extends GraphData { - nodes: EnrichedNode[]; - edges: EnrichedEdge[]; + nodes: EnrichedNode[]; + edges: EdgeGroup[]; } export class GraphFilter { - minimumNodeFrequency: number; - minimumEdgeFrequency: number; - earliestDate?: Date; - latestDate?: Date; - showUnconnectedNodes: boolean = false; - - constructor(minimumNodeFrequency: number = 1, minimumEdgeFrequency: number = 1) { - this.minimumNodeFrequency = minimumNodeFrequency; - this.minimumEdgeFrequency = minimumEdgeFrequency; - } + minimumPossibleNodeFrequency: number; + minimumNodeFrequency: number; + maximumNodeFrequency: number; + maximumPossibleNodeFrequency: number; + minimumPossibleEdgeFrequency: number; + minimumEdgeFrequency: number; + maximumEdgeFrequency: number; + maximumPossibleEdgeFrequency: number; + labelSearch: string = ""; + earliestDate?: Date; + latestDate?: Date; + showUnconnectedNodes: boolean = false; + + constructor( + minimumPossibleNodeFrequency: number, + minimumNodeFrequency: number, + maximumPossibleNodeFrequency: number, + minimumPossibleEdgeFrequency: number, + minimumEdgeFrequency: number, + maximumPossibleEdgeFrequency: number, + ) { + this.minimumPossibleNodeFrequency = minimumPossibleNodeFrequency; + this.minimumNodeFrequency = minimumNodeFrequency; + this.maximumNodeFrequency = maximumPossibleNodeFrequency; + this.maximumPossibleNodeFrequency = maximumPossibleNodeFrequency; + this.minimumPossibleEdgeFrequency = minimumPossibleEdgeFrequency; + this.minimumEdgeFrequency = minimumEdgeFrequency; + this.maximumEdgeFrequency = maximumPossibleEdgeFrequency; + this.maximumPossibleEdgeFrequency = maximumPossibleEdgeFrequency; + } } -function hasDateOverlap(node: EnrichedNode, filter: GraphFilter): boolean { - if (!node.stats.first_occurrence || !node.stats.last_occurrence) { - return true; - } - const first = new Date(node.stats.first_occurrence); - const last = new Date(node.stats.last_occurrence); - const afterEarliestDate = !filter.earliestDate - || filter.earliestDate < first - || filter.earliestDate < last; +function hasDateOverlap( + nodeOrEdge: EnrichedNode | EnrichedEdge, + filter: GraphFilter, +): boolean { + if (!nodeOrEdge.stats.first_occurrence || !nodeOrEdge.stats.last_occurrence) { + return true; + } + const first = new Date(nodeOrEdge.stats.first_occurrence); + const last = new Date(nodeOrEdge.stats.last_occurrence); + const afterEarliestDate = + !filter.earliestDate || + filter.earliestDate < first || + filter.earliestDate < last; - const beforeLatestDate = !filter.latestDate - || filter.latestDate > first || 
filter.latestDate > last;
+  const beforeLatestDate =
+    !filter.latestDate || filter.latestDate > first || filter.latestDate > last;

-    return afterEarliestDate && beforeLatestDate;
+  return afterEarliestDate && beforeLatestDate;
 }

-
-export function filter(filter: GraphFilter, graphData: EnrichedGraphData): EnrichedGraphData {
-    let nodes = graphData.nodes.filter((node: EnrichedNode) =>
-        node.stats.frequency >= filter.minimumNodeFrequency
-        && hasDateOverlap(node, filter)
-    );
-    let filteredNodes = new Set(nodes.map(node => node.id));
-    let edges = graphData.edges.filter((edge: EnrichedEdge) =>
+export function filter(
+  filter: GraphFilter,
+  graphData: EnrichedGraphData,
+): EnrichedGraphData {
+  let nodes = graphData.nodes.filter(
+    (node: EnrichedNode) =>
+      node.stats.frequency >= filter.minimumNodeFrequency &&
+      node.stats.frequency < filter.maximumNodeFrequency &&
+      hasDateOverlap(node, filter),
+  );
+  let filteredNodes = new Set(nodes.map((node) => node.id));
+  let groupedEdges = graphData.edges
+    .filter(
+      (edge: EnrichedEdge) =>
         edge.stats.frequency >= filter.minimumEdgeFrequency &&
-        filteredNodes.has(edge.from) && filteredNodes.has(edge.to)
+        edge.stats.frequency < filter.maximumEdgeFrequency &&
+        hasDateOverlap(edge, filter) &&
+        filteredNodes.has(edge.from) &&
+        filteredNodes.has(edge.to),
+    )
+    .reduce(
+      (acc, curr) => {
+        const key = curr.from + "->" + curr.to;
+        if (!acc[key]) {
+          acc[key] = [];
+        }
+        acc[key].push(curr);
+        return acc;
+      },
+      {} as Record<string, EnrichedEdge[]>,
     );
-    let connectedNodes = new Set(edges.flatMap(edge => [edge.from, edge.to]));
-    if (!filter.showUnconnectedNodes) {
-        nodes = nodes.filter(node => connectedNodes.has(node.id));
-    }
-    return {nodes, edges}
+  let edges = Object.values(groupedEdges).map((group) => {
+    group.sort((edge1, edge2) => edge2.stats.frequency - edge1.stats.frequency);
+    const representative: EnrichedEdge = group.at(0)!;
+    return {
+      ...representative,
+      id: representative.from + "->" + representative.to,
+      label: group
+        .slice(0, 3)
+        .map((e) => e.label)
+        .join(", "),
+      width: Math.log(
+        group.map((e) => e.stats.frequency).reduce((a, b) => a + b),
+      ),
+      group: group,
+    };
+  });
+
+  let connectedNodes = new Set(edges.flatMap((edge) => [edge.from, edge.to]));
+  if (!filter.showUnconnectedNodes) {
+    nodes = nodes.filter((node) => connectedNodes.has(node.id));
+  }
+  nodes = nodes.map((node) => ({
+    ...node,
+    opacity: node.label?.toLowerCase().includes(filter.labelSearch.toLowerCase()) ?
1 : 0.2, + font: { + size: 14 + node.stats.frequency / 100, + }, + })); + + return { nodes, edges }; } -export abstract class GraphService { - private nodesMap: Map | null = null; +export interface DataBounds { + minNodeFrequency: number; + maxNodeFrequency: number; + maxEdgeFrequency: number; +} - abstract getGraph(): EnrichedGraphData; +export abstract class GraphService { + abstract getGraph(): EnrichedGraphData; - getSubGraph(nodeIds: Set): EnrichedGraphData { - return { - nodes: this.getGraph().nodes.filter((n: EnrichedNode) => nodeIds.has(n.id!.toString())), - edges: this.getGraph().edges - } - } - - getConnectedNodes(nodeId: string): Set { - return new Set(this.getGraph().edges.filter(edge => edge.from === nodeId || edge.to === nodeId) - .flatMap(edge => [edge.from!.toString(), edge.to!.toString()])) - } - - getNode(nodeId: string): EnrichedNode | undefined { - if (this.nodesMap === null) { - this.nodesMap = new Map( - this.getGraph().nodes.map(node => [node.id!.toString(), node]) - ) - } + getBounds(): DataBounds { + return { + minNodeFrequency: Math.min( + ...this.getGraph().nodes.map((n) => n.stats.frequency), + ), + maxNodeFrequency: Math.max( + ...this.getGraph().nodes.map((n) => n.stats.frequency), + ), + maxEdgeFrequency: Math.max( + ...this.getGraph().edges.map((n) => n.stats.frequency), + ), + }; + } - // highly inefficient linear search; overwrite for actual use - for (let node of this.getGraph().nodes) { - if (node.id === nodeId) { - return node; - } - } - return undefined; - } + getSubGraph(nodeIds: Set): EnrichedGraphData { + return { + nodes: this.getGraph().nodes.filter((n: EnrichedNode) => + nodeIds.has(n.id!.toString()), + ), + edges: this.getGraph().edges, + }; + } + getConnectedNodes(nodeId: string): Set { + return new Set( + this.getGraph() + .edges.filter((edge) => edge.from === nodeId || edge.to === nodeId) + .flatMap((edge) => [edge.from!.toString(), edge.to!.toString()]), + ); + } } - export class SampleGraphService extends GraphService { - readonly sampleGraphData: EnrichedGraphData = { - nodes: [ - { - id: "1", - label: "node 1", - stats: { - frequency: 3, - }, - }, - { - id: "2", - label: "node 2", - stats: { - frequency: 2, - }, - }, - { - id: "3", - label: "node 3", - stats: { - frequency: 2, - }, - }, - { - id: "4", - label: "node 4", - stats: { - frequency: 1, - }, - }, - ], - edges: [ - { - from: "1", - to: "2", - stats: { - frequency: 2, - }, - }, - { - from: "1", - to: "3", - stats: { - frequency: 2, - }, - }, - { - from: "1", - to: "4", - stats: { - frequency: 1, - }, - }, - { - from: "2", - to: "3", - stats: { - frequency: 2, - }, - }, - ], - }; + readonly sampleGraphData: EnrichedGraphData = { + nodes: [ + { + id: "1", + label: "node 1", + stats: { + frequency: 3, + }, + }, + { + id: "2", + label: "node 2", + stats: { + frequency: 2, + }, + }, + { + id: "3", + label: "node 3", + stats: { + frequency: 2, + }, + }, + { + id: "4", + label: "node 4", + stats: { + frequency: 1, + }, + }, + ], + edges: [ + { + from: "1", + to: "2", + stats: { + frequency: 2, + }, + }, + { + from: "1", + to: "3", + stats: { + frequency: 2, + }, + }, + { + from: "1", + to: "4", + stats: { + frequency: 1, + }, + }, + { + from: "2", + to: "3", + stats: { + frequency: 2, + }, + }, + ], + }; - getGraph(): EnrichedGraphData { - return this.sampleGraphData; - } + getGraph(): EnrichedGraphData { + return this.sampleGraphData; + } } export class FileGraphService extends GraphService { - private readonly data: EnrichedGraphData = {nodes: [], edges: []}; - - constructor(data: 
EnrichedGraphData) { - super(); - this.data = data; - } + private readonly data: EnrichedGraphData = { nodes: [], edges: [] }; - getGraph(): EnrichedGraphData { - return this.data; - } + constructor(data: EnrichedGraphData) { + super(); + this.data = data; + } + getGraph(): EnrichedGraphData { + return this.data; + } } diff --git a/visualizer/src/graph/GraphViewer.tsx b/visualizer/src/graph/GraphViewer.tsx index 478d0aa..333d4d0 100644 --- a/visualizer/src/graph/GraphViewer.tsx +++ b/visualizer/src/graph/GraphViewer.tsx @@ -1,132 +1,164 @@ -import React, {useEffect, useRef, useState} from "react"; -import { - EnrichedGraphData, - EnrichedNode, - FileGraphService, - filter, - GraphFilter, - GraphService, - SampleGraphService, -} from "./GraphService"; -import FileUploadComponent from "../datasources/FileUploadComp"; -import Graph, {GraphEvents, Options} from "react-vis-graph-wrapper"; -import {GraphFilterControlPanel} from "./GraphFilterControlPanel"; -import {GraphOptionsControlPanel} from "./GraphOptionsControlPanel"; -import {NodeInfo} from "./NodeInfo"; +import React, { useMemo, useState } from "react"; +import { EdgeGroup, EnrichedNode, filter, GraphFilter } from "./GraphService"; +import Graph, { GraphEvents, Options } from "react-vis-graph-wrapper"; +import { GraphFilterControlPanel } from "./GraphFilterControlPanel"; +import { GraphOptionsControlPanel } from "./GraphOptionsControlPanel"; +import { NodeInfo } from "../inspector/NodeInfo"; +import { EdgeInfo } from "../inspector/EdgeInfo"; +import { useServiceContext } from "../service/ServiceContextProvider"; +export interface GraphViewerProps {} export const GraphViewer: React.FC = () => { - let graphServiceRef = useRef(new SampleGraphService()); - const [graphData, setGraphData] = useState( - graphServiceRef.current.getGraph() - ); + const { getGraphService } = useServiceContext(); - const handleFileLoaded = (data: any) => { - graphServiceRef.current = new FileGraphService(data); - setGraphData(filter(graphFilter, graphServiceRef.current.getGraph())); - }; + const top50 = + getGraphService() + .getGraph() + .nodes.map((n) => n.stats.frequency) + .sort((a, b) => b - a) + .at(100) || 1; + let { minNodeFrequency, maxNodeFrequency, maxEdgeFrequency } = + getGraphService().getBounds(); - const [graphFilter, setGraphFilter] = useState(new GraphFilter(5, 3)) - const [selected, setSelected] = useState(new Set()) - const [selectedNode, setSelectedNode] = useState(undefined) + const [graphFilter, setGraphFilter] = useState( + new GraphFilter( + minNodeFrequency, + top50, + maxNodeFrequency, + 1, + Math.floor(top50 / 10), + maxEdgeFrequency, + ), + ); + const [subgraphNodes, setSubgraphNodes] = useState(new Set()); + const [selectedNode, setSelectedNode] = useState( + undefined, + ); + const [selectedEdge, setSelectedEdge] = useState( + undefined, + ); - useEffect( - () => { - let newGraphData: EnrichedGraphData; - if (selected.size > 0) { - newGraphData = graphServiceRef.current.getSubGraph(selected); - } else { - newGraphData = graphServiceRef.current.getGraph(); - } - setGraphData(filter(graphFilter, newGraphData)) - }, - [graphFilter, selected] - ) + const filteredGraphData = useMemo(() => { + const baseGraphData = + subgraphNodes.size > 0 + ? 
getGraphService().getSubGraph(subgraphNodes) + : getGraphService().getGraph(); + return filter(graphFilter, baseGraphData); + }, [getGraphService, graphFilter, subgraphNodes]); - let events: GraphEvents = { - doubleClick: ({nodes}) => { - const newSelected = new Set(selected); - nodes.forEach((element: string) => { - newSelected.delete(element); - }); - setSelected(newSelected); - }, - select: ({nodes}) => { - let newSelected: Set; - if (nodes.length > 1) { - newSelected = new Set(); - nodes.forEach((element: string) => { - newSelected.add(element); - }); - setSelected(newSelected); - } - }, - hold: ({nodes}) => { - const newSelected = new Set(selected); - nodes.forEach((element: string) => { - Array.from(graphServiceRef.current.getConnectedNodes(element)).forEach(c => newSelected.add(c)) - }); - setSelected(newSelected); - }, - selectNode: ({nodes}) => { - setSelectedNode(graphServiceRef.current.getNode(nodes[0])); - }, - deselectNode: () => { - setSelectedNode(undefined); - } + const graphDataMaps = useMemo(() => { + return { + nodesMap: new Map( + filteredGraphData.nodes.map((node) => [node.id!.toString(), node]), + ), + edgeGroupMap: new Map( + filteredGraphData.edges.map((edgeGroup) => [edgeGroup.id, edgeGroup]), + ), }; + }, [filteredGraphData]); - let [options, setOptions] = useState({ - physics: { - enabled: true, - barnesHut: { - springLength: 200 - } - }, - edges: { - smooth: false, - font: { - align: 'top' - } - } - }) - - return ( -
- -
- -
-
- - -
-
-
-
- Shift+select to show subgraph. -
-
- Double-click node to remove it. -
-
- Hold to expand from node. -
- -
- + let events: GraphEvents = { + hold: ({ nodes }) => { + const newSubgraphNodes = new Set(subgraphNodes); + nodes.forEach((element: string) => { + newSubgraphNodes.delete(element); + }); + setSubgraphNodes(newSubgraphNodes); + }, + select: ({ nodes }) => { + let newSelected: Set; + if (nodes.length > 1) { + newSelected = new Set(); + nodes.forEach((element: string) => { + newSelected.add(element); + }); + setSubgraphNodes(newSelected); + } + }, + doubleClick: ({ nodes }) => { + const newSelected = new Set(subgraphNodes); + nodes.forEach((element: string) => { + Array.from(getGraphService().getConnectedNodes(element)).forEach((c) => + newSelected.add(c), + ); + }); + setSubgraphNodes(newSelected); + }, + selectNode: ({ nodes, edges }) => { + setSelectedEdge(undefined); + setSelectedNode(graphDataMaps.nodesMap.get(nodes[0])); + }, + selectEdge: ({ nodes, edges }) => { + if (nodes.length < 1) { + setSelectedNode(undefined); + setSelectedEdge(graphDataMaps.edgeGroupMap.get(edges[0])); + } + }, + deselectNode: () => { + setSelectedNode(undefined); + }, + deselectEdge: () => { + setSelectedEdge(undefined); + }, + }; -
-
- {selectedNode && } + let [options, setOptions] = useState({ + physics: { + enabled: true, + barnesHut: { + springLength: 200, + }, + }, + edges: { + smooth: false, + font: { + align: "top", + }, + }, + }); - -
+ return ( +
+
+ + +
+
+
+
+ + Shift+mark multiple to make subgraph. + +
+
+ + Hold node to remove it. + +
+
+ + Double-click to expand from node. + +
+
- ); +
+
+ {selectedNode && } + {selectedEdge && } + +
+
+ ); }; diff --git a/visualizer/src/graph/NodeInfo.tsx b/visualizer/src/graph/NodeInfo.tsx deleted file mode 100644 index 03189c4..0000000 --- a/visualizer/src/graph/NodeInfo.tsx +++ /dev/null @@ -1,34 +0,0 @@ -import {EnrichedNode} from "./GraphService"; -import React from "react"; - -export interface NodeInfoProps { - node: EnrichedNode - className?: string; -} - -export const NodeInfo: React.FC = ({node, className}: NodeInfoProps) => { - const stats = node.stats; - return
- {node.label} -
-
-

Frequency: {stats.frequency}

-

Norm. frequency: {stats.norm_frequency?.toPrecision(3)}

- {stats.first_occurrence &&

Earliest date: {stats.first_occurrence}

} - {stats.last_occurrence &&

Latest date: {stats.last_occurrence}

} - {stats.alt_labels && -
- Labels: -
    {stats.alt_labels.map(l =>
  • {l}
  • )}
-
- } - {stats.docs && -
- Documents -
    {stats.docs.map(d =>
  • {d}
  • )}
-
- } -
- -
-} \ No newline at end of file diff --git a/visualizer/src/graph/graph.css b/visualizer/src/graph/graph.css index 3f64b8d..568600d 100644 --- a/visualizer/src/graph/graph.css +++ b/visualizer/src/graph/graph.css @@ -1,42 +1,44 @@ .padded { - padding: 5px; + padding: 5px; } .flex-container { - display: flex; - align-items: center; - margin-top: 5px; - margin-bottom: 5px; + display: flex; + align-items: center; + margin-top: 5px; + margin-bottom: 5px; } .flex-container__element { - display: flex; - align-items: center; - margin-left: 20px; - + display: flex; + align-items: center; + margin-left: 20px; } .flex-container__element:first-child { - margin-left: 0; + margin-left: 0; } .flex-container__element__sub-element { - margin: 2px; + margin: 2px; } .node-info { - position: absolute; - z-index: 3; - background: white; - border: solid 1px gray; - font-size: small; - max-width: 250px; - padding: 5px; - margin: 2px; + position: absolute; + z-index: 3; + background: white; + border: solid 1px gray; + font-size: small; + width: 20%; + max-width: 500px; + max-height: 80%; + overflow-y: scroll; + padding: 5px; + margin: 2px; } - .graph-container { - height: 80vh; - border: 1px inset; -} \ No newline at end of file + height: 80vh; + max-height: 80vh; + border: 1px inset; +} diff --git a/visualizer/src/index.css b/visualizer/src/index.css index ec2585e..4a1df4d 100644 --- a/visualizer/src/index.css +++ b/visualizer/src/index.css @@ -1,13 +1,13 @@ body { margin: 0; - font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', - 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", "Roboto", "Oxygen", + "Ubuntu", "Cantarell", "Fira Sans", "Droid Sans", "Helvetica Neue", sans-serif; -webkit-font-smoothing: antialiased; -moz-osx-font-smoothing: grayscale; } code { - font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', + font-family: source-code-pro, Menlo, Monaco, Consolas, "Courier New", monospace; } diff --git a/visualizer/src/index.tsx b/visualizer/src/index.tsx index cfd7784..008f05d 100644 --- a/visualizer/src/index.tsx +++ b/visualizer/src/index.tsx @@ -4,10 +4,6 @@ import "./index.css"; import App from "./App"; const root = ReactDOM.createRoot( - document.getElementById("root") as HTMLElement -); -root.render( - - - + document.getElementById("root") as HTMLElement, ); +root.render(); diff --git a/visualizer/src/inspector/DocInfo.tsx b/visualizer/src/inspector/DocInfo.tsx new file mode 100644 index 0000000..98b325d --- /dev/null +++ b/visualizer/src/inspector/DocInfo.tsx @@ -0,0 +1,127 @@ +import React, { PropsWithChildren } from "react"; +import { Doc, Triplet } from "../docs/DocService"; +import HighlightWithinTextarea from "react-highlight-within-textarea"; +import "./docinfo.css"; + +const BlueHighlight: React.FC = (props) => { + return ( + + {props.children} + + ); +}; + +const GreenHighlight: React.FC = (props) => { + return ( + + {props.children} + + ); +}; + +const RedHighlight: React.FC = (props) => { + return ( + {props.children} + ); +}; + +interface HighlightedTextProps { + text: string; + triplets: Triplet[]; + highlightLabels: string[]; +} + +const HighlightedText: React.FC = ({ + text, + triplets, + highlightLabels, +}) => { + const subjects = []; + const highlightSubjects = []; + const predicates = []; + const highlightPredicates = []; + const objects = []; + const highlightObjects = []; + + for (let triplet of triplets) { + const subject = triplet.subject; + const 
subjectSpan = [subject.start_char, subject.end_char]; + if (highlightLabels.indexOf(subject.text) > -1) { + highlightSubjects.push(subjectSpan); + } else { + subjects.push(subjectSpan); + } + const predicate = triplet.predicate; + const predicateSpan = [predicate.start_char, predicate.end_char]; + if (highlightLabels.indexOf(predicate.text) > -1) { + highlightPredicates.push(predicateSpan); + } else { + predicates.push(predicateSpan); + } + const object = triplet.object; + const objectSpan = [object.start_char, object.end_char]; + if (highlightLabels.indexOf(object.text) > -1) { + highlightObjects.push(objectSpan); + } else { + objects.push(objectSpan); + } + } + + return ( + + ); +}; + +export interface DocInfoProps { + document: Doc; + highlightLabels: string[]; +} + +export const DocInfo: React.FC = ({ + document, + highlightLabels, +}) => { + return ( +
+

+ {document.id} {document.timestamp} +

+ +
+ ); +}; diff --git a/visualizer/src/inspector/EdgeInfo.tsx b/visualizer/src/inspector/EdgeInfo.tsx new file mode 100644 index 0000000..f73425c --- /dev/null +++ b/visualizer/src/inspector/EdgeInfo.tsx @@ -0,0 +1,21 @@ +import { EdgeGroup } from "../graph/GraphService"; +import React from "react"; +import { StatsInfo } from "./StatsInfo"; + +export interface EdgeInfoProps { + edges: EdgeGroup; + className?: string; +} + +export const EdgeInfo: React.FC = ({ edges }: EdgeInfoProps) => { + return ( +
+ {edges.group!.map((e, i) => ( +
+ {e.label} + +
+ ))} +
+ ); +}; diff --git a/visualizer/src/inspector/NodeInfo.tsx b/visualizer/src/inspector/NodeInfo.tsx new file mode 100644 index 0000000..3785a5e --- /dev/null +++ b/visualizer/src/inspector/NodeInfo.tsx @@ -0,0 +1,21 @@ +import { EnrichedNode } from "../graph/GraphService"; +import React from "react"; +import { StatsInfo } from "./StatsInfo"; + +export interface NodeInfoProps { + node: EnrichedNode; + className?: string; +} + +export const NodeInfo: React.FC = ({ + node, + className, +}: NodeInfoProps) => { + return ( +
+ {node.label} +
+ +
+ ); +}; diff --git a/visualizer/src/inspector/StatsInfo.tsx b/visualizer/src/inspector/StatsInfo.tsx new file mode 100644 index 0000000..8676b8d --- /dev/null +++ b/visualizer/src/inspector/StatsInfo.tsx @@ -0,0 +1,58 @@ +import { DocInfo } from "./DocInfo"; +import React from "react"; +import { Stats } from "../graph/GraphService"; +import { useServiceContext } from "../service/ServiceContextProvider"; + +export interface StatsInfoProps { + label: string; + stats: Stats; +} + +export const StatsInfo: React.FC = ({ label, stats }) => { + const { getDocService } = useServiceContext(); + + return ( +
+

Frequency: {stats.frequency}

+ {/*

Norm. frequency: {stats.norm_frequency?.toPrecision(3)}

*/} + {stats.first_occurrence &&

Earliest date: {stats.first_occurrence}

} + {stats.last_occurrence &&

Latest date: {stats.last_occurrence}

} + {stats.alt_labels && ( +
+ Alternative Labels: +
    + {stats.alt_labels.map((l) => ( +
  • {l}
  • + ))} +
+
+ )} + {stats.docs && ( +
+ Documents + {!getDocService() && ( +
    + {stats.docs.map((d) => ( +
  • {d}
  • + ))} +
+ )} + + {getDocService() && ( +
+ {getDocService() + .getDocs(stats.docs) + .map((d) => ( + + ))} +
+ )} +
+ )} +
+ ); +}; diff --git a/visualizer/src/inspector/docinfo.css b/visualizer/src/inspector/docinfo.css new file mode 100644 index 0000000..f3ba462 --- /dev/null +++ b/visualizer/src/inspector/docinfo.css @@ -0,0 +1,29 @@ +.highlight-subject { + background: cyan; + opacity: 1; +} + +.subject { + background: cyan; + opacity: 0.3; +} + +.highlight-predicate { + background: lightgreen; + opacity: 1; +} + +.predicate { + background: lightgreen; + opacity: 0.3; +} + +.highlight-object { + background: yellow; + opacity: 1; +} + +.object { + background: yellow; + opacity: 0.3; +} diff --git a/visualizer/src/service/ServiceContextProvider.tsx b/visualizer/src/service/ServiceContextProvider.tsx new file mode 100644 index 0000000..b8dd35d --- /dev/null +++ b/visualizer/src/service/ServiceContextProvider.tsx @@ -0,0 +1,78 @@ +import React, { + createContext, + PropsWithChildren, + useContext, + useState, +} from "react"; +import { DocService, FileDocService } from "../docs/DocService"; +import { FileGraphService, GraphService } from "../graph/GraphService"; +import JsonFileUploadComponent from "../datasources/JsonFileUploadComp"; +import NdjsonFileUploadComponent from "../datasources/NdjsonFileUploadComp"; + +interface Services { + getGraphService: () => GraphService; + setGraphService: (service: GraphService) => void; + getDocService: () => DocService; + setDocService: (service: DocService) => void; +} + +const ServiceContext = createContext(undefined); + +export const ServiceContextProvider: React.FC = ({ + children, +}) => { + const [graphService, setGraphService] = useState( + undefined, + ); + const [docService, setDocService] = useState( + undefined, + ); + + const value: Services = { + getGraphService: () => { + if (!graphService) { + throw new Error("DocService has not been initialized!"); + } + return graphService; + }, + setGraphService: (service: GraphService) => setGraphService(service), + getDocService: () => { + if (!docService) { + throw new Error("DocService has not been initialized!"); + } + return docService; + }, + setDocService: (service: DocService) => setDocService(service), + }; + + if (!graphService || !docService) { + const handleGraphFileLoaded = (data: any) => { + setGraphService(new FileGraphService(data)); + }; + + const handleDocsFileLoaded = (data: any) => { + setDocService(new FileDocService(data)); + }; + + return ( +
+ Load graph:  + + Load documents:  + +
+ ); + } + + return ( + {children} + ); +}; + +export const useServiceContext = (): Services => { + const context = useContext(ServiceContext); + if (!context) { + throw new Error("useServiceContext must be used within a ServiceProvider"); + } + return context; +}; diff --git a/visualizer/tsconfig.json b/visualizer/tsconfig.json index a273b0c..9d379a3 100644 --- a/visualizer/tsconfig.json +++ b/visualizer/tsconfig.json @@ -1,11 +1,7 @@ { "compilerOptions": { "target": "es5", - "lib": [ - "dom", - "dom.iterable", - "esnext" - ], + "lib": ["dom", "dom.iterable", "esnext"], "allowJs": true, "skipLibCheck": true, "esModuleInterop": true, @@ -20,7 +16,5 @@ "noEmit": true, "jsx": "react-jsx" }, - "include": [ - "src" - ] + "include": ["src"] }
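
A note on the new frequency sliders: LogarithmicRangeSlider keeps the multi-range-slider-react handles on a fixed linear 0-100 scale and converts to real frequencies with linearToLog/logToLinear. The sketch below is not part of the patch; it restates the two formulas to illustrate the round trip, and it assumes the real-world minimum is at least 1, since the mapping takes Math.log of it.

// Standalone restatements of the component's mapping helpers, for illustration only.
const linearToLog = (
  value: number,
  minLinear: number,
  maxLinear: number,
  minLog: number,
  maxLog: number,
): number => {
  const clamped = Math.max(minLinear, Math.min(value, maxLinear));
  const logRange = Math.log(maxLog) - Math.log(minLog);
  return Math.exp(
    Math.log(minLog) +
      ((clamped - minLinear) / (maxLinear - minLinear)) * logRange,
  );
};

const logToLinear = (
  value: number,
  minLinear: number,
  maxLinear: number,
  minLog: number,
  maxLog: number,
): number => {
  const clamped = Math.max(minLog, Math.min(value, maxLog));
  const logRange = Math.log(maxLog) - Math.log(minLog);
  return (
    minLinear +
    ((Math.log(clamped) - Math.log(minLog)) / logRange) * (maxLinear - minLinear)
  );
};

// Endpoints map to endpoints, the midpoint lands on the geometric mean, and the
// two functions invert each other (up to rounding), which is what the slider relies on.
console.assert(Math.round(linearToLog(0, 0, 100, 1, 1000)) === 1);
console.assert(Math.round(linearToLog(100, 0, 100, 1, 1000)) === 1000);
console.assert(Math.round(linearToLog(50, 0, 100, 1, 1000)) === 32); // ~sqrt(1000)
console.assert(
  Math.round(logToLinear(linearToLog(25, 0, 100, 1, 1000), 0, 100, 1, 1000)) === 25,
);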
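
The reworked filter() in visualizer/src/graph/GraphService.ts collapses parallel edges between the same node pair into a single EdgeGroup: edges are keyed by from + "->" + to, sorted by frequency, labelled with the three most frequent predicates, and given a width equal to the log of the summed frequency. Below is a minimal sketch of that grouping step on made-up data; the real code operates on EnrichedEdge/EdgeGroup rather than this stripped-down type.

// Stripped-down edge type, standing in for EnrichedEdge.
type MiniEdge = {
  from: string;
  to: string;
  label: string;
  stats: { frequency: number };
};

const sample: MiniEdge[] = [
  { from: "a", to: "b", label: "supports", stats: { frequency: 5 } },
  { from: "a", to: "b", label: "funds", stats: { frequency: 2 } },
  { from: "b", to: "c", label: "opposes", stats: { frequency: 1 } },
];

// Group by node pair, like the reduce in filter().
const grouped = sample.reduce(
  (acc, curr) => {
    const key = curr.from + "->" + curr.to;
    if (!acc[key]) {
      acc[key] = [];
    }
    acc[key].push(curr);
    return acc;
  },
  {} as Record<string, MiniEdge[]>,
);

// One representative edge per pair: top-3 labels joined, width = log of summed frequency.
const merged = Object.values(grouped).map((group) => {
  group.sort((e1, e2) => e2.stats.frequency - e1.stats.frequency);
  return {
    id: group[0].from + "->" + group[0].to,
    label: group.slice(0, 3).map((e) => e.label).join(", "),
    width: Math.log(group.map((e) => e.stats.frequency).reduce((a, b) => a + b)),
    group,
  };
});

console.log(merged.map((m) => m.label)); // ["supports, funds", "opposes"]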
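
One integration detail worth noting: NdjsonFileUploadComponent hands its callback a Generator of parsed documents, while FileDocService's constructor filters and maps its argument like an array. If the generator is passed straight through, as wired up in ServiceContextProvider, those array methods are not available on it, so an adapter that materializes the stream first may be needed. The handler below is a hypothetical variant for illustration, not the one in the patch, and the import path is assumed relative to visualizer/src.

import { Doc, FileDocService } from "./docs/DocService";

// Hypothetical upload handler: accept either an already-parsed array or the
// NDJSON generator, and turn the generator into an array before building the service.
const handleDocsFileLoaded = (data: Doc[] | Generator<Doc>): FileDocService => {
  const docs: Doc[] = Array.isArray(data) ? data : Array.from(data);
  return new FileDocService(docs);
};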