From 3561092828300136af20629c8815be3fb9e8b26e Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Thu, 21 Nov 2024 08:16:39 +0100 Subject: [PATCH 01/29] Caching embeddings during clustering + fewer clusters by default --- config/eschatology.toml | 3 +- config/template.toml | 5 +- .../corpusprocessing/clustering.py | 48 ++++++++++++------- src/conspiracies/pipeline/config.py | 2 +- src/conspiracies/pipeline/pipeline.py | 1 + 5 files changed, 38 insertions(+), 21 deletions(-) diff --git a/config/eschatology.toml b/config/eschatology.toml index 7aa4c7b..bf7d81b 100644 --- a/config/eschatology.toml +++ b/config/eschatology.toml @@ -16,4 +16,5 @@ batch_size = 5 prefer_gpu_for_coref = false [corpusprocessing] -enabled = true \ No newline at end of file +enabled = true +dimensions = 50 \ No newline at end of file diff --git a/config/template.toml b/config/template.toml index cd76efd..6d7ae49 100644 --- a/config/template.toml +++ b/config/template.toml @@ -27,6 +27,5 @@ dimensions = 100 # leave out to skip dimensionality reduction n_neighbors = 15 # used for dimensionality reduction [corpusprocessing.thresholds] # leave out for automatic estimation -min_cluster_size = 3 # unused if auto_thresholds is true -min_samples = 3 # unused if auto_thresholds is true -min_topic_size = 5 # unused if auto_thresholds is true \ No newline at end of file +min_cluster_size = 3 +min_samples = 3 \ No newline at end of file diff --git a/src/conspiracies/corpusprocessing/clustering.py b/src/conspiracies/corpusprocessing/clustering.py index 05deb3e..46b7dd9 100644 --- a/src/conspiracies/corpusprocessing/clustering.py +++ b/src/conspiracies/corpusprocessing/clustering.py @@ -1,4 +1,6 @@ +import os from collections import defaultdict +from pathlib import Path from typing import List, Callable, Any, Hashable, Dict import networkx @@ -6,7 +8,6 @@ from hdbscan import HDBSCAN from pydantic import BaseModel from sentence_transformers import SentenceTransformer -from sklearn.preprocessing import StandardScaler from umap import UMAP from conspiracies.common.modelchoice import ModelChoice @@ -45,6 +46,7 @@ def __init__( min_cluster_size: int = 5, min_samples: int = 3, embedding_model: str = None, + cache_location: Path = None, ): self.language = language self.n_dimensions = n_dimensions @@ -52,6 +54,9 @@ def __init__( self.min_cluster_size = min_cluster_size self.min_samples = min_samples self._embedding_model = embedding_model + self.cache_location = cache_location + if self.cache_location is not None: + os.makedirs(self.cache_location, exist_ok=True) def _get_embedding_model(self): # figure out embedding model if not given explicitly @@ -100,14 +105,31 @@ def _combine_clusters( def _cluster( self, fields: List[TripletField], + cache_filename: str, ): - model = self._get_embedding_model() - print("Creating embeddings:") - embeddings = model.encode( - [field.text for field in fields], - show_progress_bar=True, - ) - embeddings = StandardScaler().fit_transform(embeddings) + if ( + self.cache_location + and Path(self.cache_location, f"embeddings-{cache_filename}.npy").exists() + ): + print( + "Reusing cached embeddings! 
Delete cache if this is not supposed to happen.", + ) + embeddings = np.load( + Path(self.cache_location, f"embeddings-{cache_filename}.npy"), + ) + else: + model = self._get_embedding_model() + print("Creating embeddings:") + embeddings = model.encode( + [field.text for field in fields], + normalize_embeddings=True, + show_progress_bar=True, + ) + if self.cache_location: + np.save( + Path(self.cache_location, f"embeddings-{cache_filename}.npy"), + embeddings, + ) if self.n_dimensions is not None: print("Reducing embedding space") @@ -138,12 +160,6 @@ def _cluster( get_combine_key=lambda t: t[0].text, ) - # too risky with false positives from this - # merged = self._combine_clusters( - # merged, - # get_combine_key=lambda t: t[0].head, - # ) - # sort by how "prototypical" a member is in the cluster for cluster in merged: mean = np.mean(np.stack([t[1] for t in cluster]), axis=0) @@ -167,9 +183,9 @@ def create_mappings(self, triplets: List[Triplet]) -> Mappings: predicates = [triplet.predicate for triplet in triplets] print("Creating mappings for entities") - entity_clusters = self._cluster(entities) + entity_clusters = self._cluster(entities, "entities") print("Creating mappings for predicates") - predicate_clusters = self._cluster(predicates) + predicate_clusters = self._cluster(predicates, "predicates") mappings = Mappings( entities=self._mapping_to_first_member(entity_clusters), diff --git a/src/conspiracies/pipeline/config.py b/src/conspiracies/pipeline/config.py index 6458eb8..ad96f23 100644 --- a/src/conspiracies/pipeline/config.py +++ b/src/conspiracies/pipeline/config.py @@ -35,7 +35,7 @@ class ClusteringThresholds(BaseModel): @classmethod def estimate_from_n_triplets(cls, n_triplets: int): - factor = n_triplets / 1000 + factor = n_triplets / 500 thresholds = cls( min_cluster_size=max(int(factor + 1), 2), min_samples=int(factor + 1), diff --git a/src/conspiracies/pipeline/pipeline.py b/src/conspiracies/pipeline/pipeline.py index 0da29cf..b3e4bae 100644 --- a/src/conspiracies/pipeline/pipeline.py +++ b/src/conspiracies/pipeline/pipeline.py @@ -114,6 +114,7 @@ def corpusprocessing(self): n_neighbors=self.config.corpusprocessing.n_neighbors, min_cluster_size=thresholds.min_cluster_size, min_samples=thresholds.min_samples, + cache_location=self.output_path / "cache", ) mappings = clustering.create_mappings(triplets) with open(self.output_path / "mappings.json", "w") as out: From 4763c00b5deeaad8c9d9d8131800f4c3827b3caf Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Thu, 21 Nov 2024 09:33:08 +0100 Subject: [PATCH 02/29] visualizer: prettier and switching double-click and hold ops --- visualizer/package-lock.json | 16 + visualizer/package.json | 3 +- .../src/graph/GraphFilterControlPanel.tsx | 163 +++++---- visualizer/src/graph/GraphService.ts | 315 +++++++++--------- visualizer/src/graph/GraphViewer.tsx | 264 ++++++++------- visualizer/src/graph/NodeInfo.tsx | 70 ++-- 6 files changed, 470 insertions(+), 361 deletions(-) diff --git a/visualizer/package-lock.json b/visualizer/package-lock.json index 2e97d87..2b5f041 100644 --- a/visualizer/package-lock.json +++ b/visualizer/package-lock.json @@ -26,6 +26,7 @@ "devDependencies": { "@electron/packager": "^18.3.5", "electron": "^33.0.2", + "prettier": "^3.3.3", "serve": "^14.2.4" } }, @@ -14118,6 +14119,21 @@ "node": ">= 0.8.0" } }, + "node_modules/prettier": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.3.3.tgz", + "integrity": 
"sha512-i2tDNA0O5IrMO757lfrdQZCc2jPNDVntV0m/+4whiDfWaTKfMNgR7Qz0NAeGz/nRqF4m5/6CLzbP4/liHt12Ew==", + "dev": true, + "bin": { + "prettier": "bin/prettier.cjs" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/prettier/prettier?sponsor=1" + } + }, "node_modules/pretty-bytes": { "version": "5.6.0", "resolved": "https://registry.npmjs.org/pretty-bytes/-/pretty-bytes-5.6.0.tgz", diff --git a/visualizer/package.json b/visualizer/package.json index dfc6bab..e6e9fa3 100644 --- a/visualizer/package.json +++ b/visualizer/package.json @@ -50,8 +50,9 @@ ] }, "devDependencies": { - "electron": "^33.0.2", "@electron/packager": "^18.3.5", + "electron": "^33.0.2", + "prettier": "^3.3.3", "serve": "^14.2.4" } } diff --git a/visualizer/src/graph/GraphFilterControlPanel.tsx b/visualizer/src/graph/GraphFilterControlPanel.tsx index 12210a1..bb93162 100644 --- a/visualizer/src/graph/GraphFilterControlPanel.tsx +++ b/visualizer/src/graph/GraphFilterControlPanel.tsx @@ -1,67 +1,108 @@ import React from "react"; -import {GraphFilter} from "./GraphService"; -import './graph.css' - +import { GraphFilter } from "./GraphService"; +import "./graph.css"; interface GraphFilterControlPanelProps { - graphFilter: GraphFilter; - setGraphFilter: React.Dispatch>; + graphFilter: GraphFilter; + setGraphFilter: React.Dispatch>; } -export const GraphFilterControlPanel = ({graphFilter, setGraphFilter}: GraphFilterControlPanelProps) => { - - return
-
- Minimum Node Frequency: {graphFilter.minimumNodeFrequency} - - -
-
- Minimum Edge Frequency: {graphFilter.minimumEdgeFrequency} - - -
-
- Show unconnected nodes: - setGraphFilter({ - ...graphFilter, - showUnconnectedNodes: event.target.checked - })}/> -
-
- From: setGraphFilter({ - ...graphFilter, - earliestDate: event.target.valueAsDate ?? undefined - })}/> - To: setGraphFilter({ - ...graphFilter, - latestDate: event.target.valueAsDate ?? undefined - })}/> -
+export const GraphFilterControlPanel = ({ + graphFilter, + setGraphFilter, +}: GraphFilterControlPanelProps) => { + return ( +
+
+ + Minimum Node Frequency: {graphFilter.minimumNodeFrequency} + + + +
+
+ Minimum Edge Frequency: {graphFilter.minimumEdgeFrequency} + + +
+
+ Show unconnected nodes: + + setGraphFilter({ + ...graphFilter, + showUnconnectedNodes: event.target.checked, + }) + } + /> +
+
+ From:{" "} + + setGraphFilter({ + ...graphFilter, + earliestDate: event.target.valueAsDate ?? undefined, + }) + } + /> + To:{" "} + + setGraphFilter({ + ...graphFilter, + latestDate: event.target.valueAsDate ?? undefined, + }) + } + /> +
-} \ No newline at end of file + ); +}; diff --git a/visualizer/src/graph/GraphService.ts b/visualizer/src/graph/GraphService.ts index 04ba4a7..7cbb905 100644 --- a/visualizer/src/graph/GraphService.ts +++ b/visualizer/src/graph/GraphService.ts @@ -1,189 +1,200 @@ -import {Edge, GraphData, Node} from "react-vis-graph-wrapper"; +import { Edge, GraphData, Node } from "react-vis-graph-wrapper"; export interface Stats { - frequency: number; - norm_frequency?: number; - docs?: string[]; - first_occurrence?: string; - last_occurrence?: string; - alt_labels?: string[]; + frequency: number; + norm_frequency?: number; + docs?: string[]; + first_occurrence?: string; + last_occurrence?: string; + alt_labels?: string[]; } export interface EnrichedNode extends Node { - stats: Stats; + stats: Stats; } export interface EnrichedEdge extends Edge { - stats: Stats; + stats: Stats; } export interface EnrichedGraphData extends GraphData { - nodes: EnrichedNode[]; - edges: EnrichedEdge[]; + nodes: EnrichedNode[]; + edges: EnrichedEdge[]; } export class GraphFilter { - minimumNodeFrequency: number; - minimumEdgeFrequency: number; - earliestDate?: Date; - latestDate?: Date; - showUnconnectedNodes: boolean = false; - - constructor(minimumNodeFrequency: number = 1, minimumEdgeFrequency: number = 1) { - this.minimumNodeFrequency = minimumNodeFrequency; - this.minimumEdgeFrequency = minimumEdgeFrequency; - } + minimumNodeFrequency: number; + minimumEdgeFrequency: number; + earliestDate?: Date; + latestDate?: Date; + showUnconnectedNodes: boolean = false; + + constructor( + minimumNodeFrequency: number = 1, + minimumEdgeFrequency: number = 1, + ) { + this.minimumNodeFrequency = minimumNodeFrequency; + this.minimumEdgeFrequency = minimumEdgeFrequency; + } } function hasDateOverlap(node: EnrichedNode, filter: GraphFilter): boolean { - if (!node.stats.first_occurrence || !node.stats.last_occurrence) { - return true; - } - const first = new Date(node.stats.first_occurrence); - const last = new Date(node.stats.last_occurrence); - const afterEarliestDate = !filter.earliestDate - || filter.earliestDate < first - || filter.earliestDate < last; - - const beforeLatestDate = !filter.latestDate - || filter.latestDate > first || filter.latestDate > last; - - return afterEarliestDate && beforeLatestDate; + if (!node.stats.first_occurrence || !node.stats.last_occurrence) { + return true; + } + const first = new Date(node.stats.first_occurrence); + const last = new Date(node.stats.last_occurrence); + const afterEarliestDate = + !filter.earliestDate || + filter.earliestDate < first || + filter.earliestDate < last; + + const beforeLatestDate = + !filter.latestDate || filter.latestDate > first || filter.latestDate > last; + + return afterEarliestDate && beforeLatestDate; } - -export function filter(filter: GraphFilter, graphData: EnrichedGraphData): EnrichedGraphData { - let nodes = graphData.nodes.filter((node: EnrichedNode) => - node.stats.frequency >= filter.minimumNodeFrequency - && hasDateOverlap(node, filter) - ); - let filteredNodes = new Set(nodes.map(node => node.id)); - let edges = graphData.edges.filter((edge: EnrichedEdge) => - edge.stats.frequency >= filter.minimumEdgeFrequency && - filteredNodes.has(edge.from) && filteredNodes.has(edge.to) - ); - let connectedNodes = new Set(edges.flatMap(edge => [edge.from, edge.to])); - if (!filter.showUnconnectedNodes) { - nodes = nodes.filter(node => connectedNodes.has(node.id)); - } - return {nodes, edges} +export function filter( + filter: GraphFilter, + graphData: 
EnrichedGraphData, +): EnrichedGraphData { + let nodes = graphData.nodes.filter( + (node: EnrichedNode) => + node.stats.frequency >= filter.minimumNodeFrequency && + hasDateOverlap(node, filter), + ); + let filteredNodes = new Set(nodes.map((node) => node.id)); + let edges = graphData.edges.filter( + (edge: EnrichedEdge) => + edge.stats.frequency >= filter.minimumEdgeFrequency && + filteredNodes.has(edge.from) && + filteredNodes.has(edge.to), + ); + let connectedNodes = new Set(edges.flatMap((edge) => [edge.from, edge.to])); + if (!filter.showUnconnectedNodes) { + nodes = nodes.filter((node) => connectedNodes.has(node.id)); + } + return { nodes, edges }; } export abstract class GraphService { - private nodesMap: Map | null = null; + private nodesMap: Map | null = null; - abstract getGraph(): EnrichedGraphData; + abstract getGraph(): EnrichedGraphData; - getSubGraph(nodeIds: Set): EnrichedGraphData { - return { - nodes: this.getGraph().nodes.filter((n: EnrichedNode) => nodeIds.has(n.id!.toString())), - edges: this.getGraph().edges - } - } + getSubGraph(nodeIds: Set): EnrichedGraphData { + return { + nodes: this.getGraph().nodes.filter((n: EnrichedNode) => + nodeIds.has(n.id!.toString()), + ), + edges: this.getGraph().edges, + }; + } - getConnectedNodes(nodeId: string): Set { - return new Set(this.getGraph().edges.filter(edge => edge.from === nodeId || edge.to === nodeId) - .flatMap(edge => [edge.from!.toString(), edge.to!.toString()])) - } + getConnectedNodes(nodeId: string): Set { + return new Set( + this.getGraph() + .edges.filter((edge) => edge.from === nodeId || edge.to === nodeId) + .flatMap((edge) => [edge.from!.toString(), edge.to!.toString()]), + ); + } - getNode(nodeId: string): EnrichedNode | undefined { - if (this.nodesMap === null) { - this.nodesMap = new Map( - this.getGraph().nodes.map(node => [node.id!.toString(), node]) - ) - } - - // highly inefficient linear search; overwrite for actual use - for (let node of this.getGraph().nodes) { - if (node.id === nodeId) { - return node; - } - } - return undefined; + getNode(nodeId: string): EnrichedNode | undefined { + if (this.nodesMap === null) { + this.nodesMap = new Map( + this.getGraph().nodes.map((node) => [node.id!.toString(), node]), + ); } + // highly inefficient linear search; overwrite for actual use + for (let node of this.getGraph().nodes) { + if (node.id === nodeId) { + return node; + } + } + return undefined; + } } - export class SampleGraphService extends GraphService { - readonly sampleGraphData: EnrichedGraphData = { - nodes: [ - { - id: "1", - label: "node 1", - stats: { - frequency: 3, - }, - }, - { - id: "2", - label: "node 2", - stats: { - frequency: 2, - }, - }, - { - id: "3", - label: "node 3", - stats: { - frequency: 2, - }, - }, - { - id: "4", - label: "node 4", - stats: { - frequency: 1, - }, - }, - ], - edges: [ - { - from: "1", - to: "2", - stats: { - frequency: 2, - }, - }, - { - from: "1", - to: "3", - stats: { - frequency: 2, - }, - }, - { - from: "1", - to: "4", - stats: { - frequency: 1, - }, - }, - { - from: "2", - to: "3", - stats: { - frequency: 2, - }, - }, - ], - }; - - getGraph(): EnrichedGraphData { - return this.sampleGraphData; - } + readonly sampleGraphData: EnrichedGraphData = { + nodes: [ + { + id: "1", + label: "node 1", + stats: { + frequency: 3, + }, + }, + { + id: "2", + label: "node 2", + stats: { + frequency: 2, + }, + }, + { + id: "3", + label: "node 3", + stats: { + frequency: 2, + }, + }, + { + id: "4", + label: "node 4", + stats: { + frequency: 1, + }, + }, + ], + edges: [ + { + 
from: "1", + to: "2", + stats: { + frequency: 2, + }, + }, + { + from: "1", + to: "3", + stats: { + frequency: 2, + }, + }, + { + from: "1", + to: "4", + stats: { + frequency: 1, + }, + }, + { + from: "2", + to: "3", + stats: { + frequency: 2, + }, + }, + ], + }; + + getGraph(): EnrichedGraphData { + return this.sampleGraphData; + } } export class FileGraphService extends GraphService { - private readonly data: EnrichedGraphData = {nodes: [], edges: []}; + private readonly data: EnrichedGraphData = { nodes: [], edges: [] }; - constructor(data: EnrichedGraphData) { - super(); - this.data = data; - } - - getGraph(): EnrichedGraphData { - return this.data; - } + constructor(data: EnrichedGraphData) { + super(); + this.data = data; + } + getGraph(): EnrichedGraphData { + return this.data; + } } diff --git a/visualizer/src/graph/GraphViewer.tsx b/visualizer/src/graph/GraphViewer.tsx index 478d0aa..9c13026 100644 --- a/visualizer/src/graph/GraphViewer.tsx +++ b/visualizer/src/graph/GraphViewer.tsx @@ -1,132 +1,158 @@ -import React, {useEffect, useRef, useState} from "react"; +import React, { useEffect, useRef, useState } from "react"; import { - EnrichedGraphData, - EnrichedNode, - FileGraphService, - filter, - GraphFilter, - GraphService, - SampleGraphService, + EnrichedGraphData, + EnrichedNode, + FileGraphService, + filter, + GraphFilter, + GraphService, + SampleGraphService, } from "./GraphService"; import FileUploadComponent from "../datasources/FileUploadComp"; -import Graph, {GraphEvents, Options} from "react-vis-graph-wrapper"; -import {GraphFilterControlPanel} from "./GraphFilterControlPanel"; -import {GraphOptionsControlPanel} from "./GraphOptionsControlPanel"; -import {NodeInfo} from "./NodeInfo"; - +import Graph, { GraphEvents, Options } from "react-vis-graph-wrapper"; +import { GraphFilterControlPanel } from "./GraphFilterControlPanel"; +import { GraphOptionsControlPanel } from "./GraphOptionsControlPanel"; +import { NodeInfo } from "./NodeInfo"; export const GraphViewer: React.FC = () => { - let graphServiceRef = useRef(new SampleGraphService()); - const [graphData, setGraphData] = useState( - graphServiceRef.current.getGraph() - ); - - const handleFileLoaded = (data: any) => { - graphServiceRef.current = new FileGraphService(data); - setGraphData(filter(graphFilter, graphServiceRef.current.getGraph())); - }; - - const [graphFilter, setGraphFilter] = useState(new GraphFilter(5, 3)) - const [selected, setSelected] = useState(new Set()) - const [selectedNode, setSelectedNode] = useState(undefined) + let graphServiceRef = useRef(new SampleGraphService()); + const [graphData, setGraphData] = useState( + graphServiceRef.current.getGraph(), + ); - useEffect( - () => { - let newGraphData: EnrichedGraphData; - if (selected.size > 0) { - newGraphData = graphServiceRef.current.getSubGraph(selected); - } else { - newGraphData = graphServiceRef.current.getGraph(); - } - setGraphData(filter(graphFilter, newGraphData)) - }, - [graphFilter, selected] - ) + const handleFileLoaded = (data: any) => { + if (data === "SAMPLE") { + graphServiceRef.current = new SampleGraphService(); + setGraphFilter(new GraphFilter(1, 0)); + } else { + graphServiceRef.current = new FileGraphService(data); + const top50 = + graphServiceRef.current + .getGraph() + .nodes.map((n) => n.stats.frequency) + .sort() + .reverse() + .at(100) || 1; + setGraphFilter(new GraphFilter(top50, Math.floor(top50 / 2))); + } + setGraphData(filter(graphFilter, graphServiceRef.current.getGraph())); + }; - let events: GraphEvents = { - 
doubleClick: ({nodes}) => { - const newSelected = new Set(selected); - nodes.forEach((element: string) => { - newSelected.delete(element); - }); - setSelected(newSelected); - }, - select: ({nodes}) => { - let newSelected: Set; - if (nodes.length > 1) { - newSelected = new Set(); - nodes.forEach((element: string) => { - newSelected.add(element); - }); - setSelected(newSelected); - } - }, - hold: ({nodes}) => { - const newSelected = new Set(selected); - nodes.forEach((element: string) => { - Array.from(graphServiceRef.current.getConnectedNodes(element)).forEach(c => newSelected.add(c)) - }); - setSelected(newSelected); - }, - selectNode: ({nodes}) => { - setSelectedNode(graphServiceRef.current.getNode(nodes[0])); - }, - deselectNode: () => { - setSelectedNode(undefined); - } - }; + const [graphFilter, setGraphFilter] = useState(new GraphFilter(5, 3)); + const [selected, setSelected] = useState(new Set()); + const [selectedNode, setSelectedNode] = useState( + undefined, + ); - let [options, setOptions] = useState({ - physics: { - enabled: true, - barnesHut: { - springLength: 200 - } - }, - edges: { - smooth: false, - font: { - align: 'top' - } - } - }) + useEffect(() => { + let newGraphData: EnrichedGraphData; + if (selected.size > 0) { + newGraphData = graphServiceRef.current.getSubGraph(selected); + } else { + newGraphData = graphServiceRef.current.getGraph(); + } + setGraphData(filter(graphFilter, newGraphData)); + }, [graphFilter, selected]); - return ( -
+ let events: GraphEvents = { + hold: ({ nodes }) => { + const newSelected = new Set(selected); + nodes.forEach((element: string) => { + newSelected.delete(element); + }); + setSelected(newSelected); + }, + select: ({ nodes }) => { + let newSelected: Set; + if (nodes.length > 1) { + newSelected = new Set(); + nodes.forEach((element: string) => { + newSelected.add(element); + }); + setSelected(newSelected); + } + }, + doubleClick: ({ nodes }) => { + const newSelected = new Set(selected); + nodes.forEach((element: string) => { + Array.from(graphServiceRef.current.getConnectedNodes(element)).forEach( + (c) => newSelected.add(c), + ); + }); + setSelected(newSelected); + }, + selectNode: ({ nodes }) => { + setSelectedNode(graphServiceRef.current.getNode(nodes[0])); + }, + deselectNode: () => { + setSelectedNode(undefined); + }, + }; -
- -
-
- - -
-
-
-
- Shift+select to show subgraph. -
-
- Double-click node to remove it. -
-
- Hold to expand from node. -
- -
+ let [options, setOptions] = useState({ + physics: { + enabled: true, + barnesHut: { + springLength: 200, + }, + }, + edges: { + smooth: false, + font: { + align: "top", + }, + }, + }); - -
-
- {selectedNode && } - - -
+ return ( +
+
+ + +
+
+
+ + +
+
+
+
+ + Shift+select multiple nodes to show their subgraph. + +
+
+ + Hold a node to remove it. + +
+
+ + Double-click a node to expand from it. + +
+
- ); +
+
+ {selectedNode && } + + +
+
+ ); }; diff --git a/visualizer/src/graph/NodeInfo.tsx b/visualizer/src/graph/NodeInfo.tsx index 03189c4..eaccbdb 100644 --- a/visualizer/src/graph/NodeInfo.tsx +++ b/visualizer/src/graph/NodeInfo.tsx @@ -1,34 +1,48 @@ -import {EnrichedNode} from "./GraphService"; +import { EnrichedNode } from "./GraphService"; import React from "react"; export interface NodeInfoProps { - node: EnrichedNode - className?: string; + node: EnrichedNode; + className?: string; } -export const NodeInfo: React.FC = ({node, className}: NodeInfoProps) => { - const stats = node.stats; - return
- {node.label} -
-
-

Frequency: {stats.frequency}

-

Norm. frequency: {stats.norm_frequency?.toPrecision(3)}

- {stats.first_occurrence &&

Earliest date: {stats.first_occurrence}

} - {stats.last_occurrence &&

Latest date: {stats.last_occurrence}

} - {stats.alt_labels && -
- Labels: -
    {stats.alt_labels.map(l =>
  • {l}
  • )}
-
- } - {stats.docs && -
- Documents -
    {stats.docs.map(d =>
  • {d}
  • )}
-
- } -
- +export const NodeInfo: React.FC = ({ + node, + className, +}: NodeInfoProps) => { + const stats = node.stats; + return ( +
+ {node.label} +
+
+

Frequency: {stats.frequency}

+

Norm. frequency: {stats.norm_frequency?.toPrecision(3)}

+ {stats.first_occurrence && ( +

Earliest date: {stats.first_occurrence}

+ )} + {stats.last_occurrence &&

Latest date: {stats.last_occurrence}

} + {stats.alt_labels && ( +
+ Labels: +
    + {stats.alt_labels.map((l) => ( +
  • {l}
  • + ))} +
+
+ )} + {stats.docs && ( +
+ Documents +
    + {stats.docs.map((d) => ( +
  • {d}
  • + ))} +
+
+ )} +
-} \ No newline at end of file + ); +}; From 0f862e755509640069c8d2a3341c5af0774c119f Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Thu, 21 Nov 2024 09:40:38 +0100 Subject: [PATCH 03/29] More caching during clustering --- .../corpusprocessing/clustering.py | 41 ++++++++++++++----- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/src/conspiracies/corpusprocessing/clustering.py b/src/conspiracies/corpusprocessing/clustering.py index 46b7dd9..3c7a230 100644 --- a/src/conspiracies/corpusprocessing/clustering.py +++ b/src/conspiracies/corpusprocessing/clustering.py @@ -107,15 +107,13 @@ def _cluster( fields: List[TripletField], cache_filename: str, ): - if ( - self.cache_location - and Path(self.cache_location, f"embeddings-{cache_filename}.npy").exists() - ): + emb_cache = f"embeddings-{cache_filename}.npy" + if self.cache_location and Path(self.cache_location, emb_cache).exists(): print( "Reusing cached embeddings! Delete cache if this is not supposed to happen.", ) embeddings = np.load( - Path(self.cache_location, f"embeddings-{cache_filename}.npy"), + Path(self.cache_location, emb_cache), ) else: model = self._get_embedding_model() @@ -127,19 +125,40 @@ def _cluster( ) if self.cache_location: np.save( - Path(self.cache_location, f"embeddings-{cache_filename}.npy"), + Path(self.cache_location, emb_cache), embeddings, ) if self.n_dimensions is not None: - print("Reducing embedding space") - reducer = UMAP(n_components=self.n_dimensions, n_neighbors=self.n_neighbors) - embeddings = reducer.fit_transform(embeddings) - - print("Clustering ...") + reduced_emb_cache = ( + f"embeddings-{cache_filename}-red{self.n_dimensions}.npy" + ) + if ( + self.cache_location + and Path(self.cache_location, reduced_emb_cache).exists() + ): + print( + "Reusing cached reduced embeddings! Delete cache if this is not supposed to happen.", + ) + embeddings = np.load(Path(self.cache_location, reduced_emb_cache)) + else: + print("Reducing embedding space ...") + reducer = UMAP( + n_components=self.n_dimensions, + n_neighbors=self.n_neighbors, + ) + embeddings = reducer.fit_transform(embeddings) + if self.cache_location: + np.save( + Path(self.cache_location, reduced_emb_cache), + embeddings, + ) + + print("Clustering ... (Delete cache to ensure recalculation)") hdbscan_model = HDBSCAN( min_cluster_size=self.min_cluster_size, min_samples=self.min_samples, + memory=str(self.cache_location), ) hdbscan_model.fit(embeddings) From 6dd1d9591899f58ebe443e34e67c123304358b57 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Thu, 21 Nov 2024 12:53:34 +0100 Subject: [PATCH 04/29] Removing the long tail of entities upfront and reducing cluster sizes --- src/conspiracies/corpusprocessing/clustering.py | 1 + src/conspiracies/corpusprocessing/triplet.py | 17 +++++++++++++++++ src/conspiracies/pipeline/config.py | 9 ++++++--- src/conspiracies/pipeline/pipeline.py | 12 ++++++++---- 4 files changed, 32 insertions(+), 7 deletions(-) diff --git a/src/conspiracies/corpusprocessing/clustering.py b/src/conspiracies/corpusprocessing/clustering.py index 3c7a230..10fd98b 100644 --- a/src/conspiracies/corpusprocessing/clustering.py +++ b/src/conspiracies/corpusprocessing/clustering.py @@ -157,6 +157,7 @@ def _cluster( print("Clustering ... 
(Delete cache to ensure recalculation)") hdbscan_model = HDBSCAN( min_cluster_size=self.min_cluster_size, + max_cluster_size=20, # somewhat arbitrary number, mostly to avoid mega clusters that suck up everything min_samples=self.min_samples, memory=str(self.cache_location), ) diff --git a/src/conspiracies/corpusprocessing/triplet.py b/src/conspiracies/corpusprocessing/triplet.py index 5d87b17..fdc065b 100644 --- a/src/conspiracies/corpusprocessing/triplet.py +++ b/src/conspiracies/corpusprocessing/triplet.py @@ -1,4 +1,5 @@ import json +from collections import Counter from datetime import datetime from pathlib import Path from typing import Optional, Set, Iterator, Iterable, List, Union @@ -52,6 +53,22 @@ def filter_on_stopwords( if not triplet.has_blacklist_match(stopwords) ] + @staticmethod + def filter_on_entity_label_frequency( + triplets: Iterable["Triplet"], + min_frequency: int, + ): + entity_label_counter = Counter( + f.text for triplet in triplets for f in (triplet.subject, triplet.object) + ) + filtered = [ + triplet + for triplet in triplets + if entity_label_counter[triplet.subject.text] >= min_frequency + and entity_label_counter[triplet.subject.text] >= min_frequency + ] + return filtered + @classmethod def from_annotated_docs(cls, path: Path) -> Iterator["Triplet"]: return ( diff --git a/src/conspiracies/pipeline/config.py b/src/conspiracies/pipeline/config.py index ad96f23..b06d12c 100644 --- a/src/conspiracies/pipeline/config.py +++ b/src/conspiracies/pipeline/config.py @@ -1,3 +1,4 @@ +import math from typing import Any import toml @@ -29,14 +30,16 @@ class DocProcessingConfig(StepConfig): prefer_gpu_for_coref: bool = False -class ClusteringThresholds(BaseModel): +class Thresholds(BaseModel): + min_label_occurrence: int min_cluster_size: int min_samples: int @classmethod def estimate_from_n_triplets(cls, n_triplets: int): - factor = n_triplets / 500 + factor = n_triplets / 1000 thresholds = cls( + min_label_occurrence=math.floor(math.log10(n_triplets)) - 1, min_cluster_size=max(int(factor + 1), 2), min_samples=int(factor + 1), ) @@ -47,7 +50,7 @@ class CorpusProcessingConfig(StepConfig): dimensions: int = None n_neighbors: int = 15 embedding_model: str = None - thresholds: ClusteringThresholds = None + thresholds: Thresholds = None class PipelineConfig(BaseModel): diff --git a/src/conspiracies/pipeline/pipeline.py b/src/conspiracies/pipeline/pipeline.py index b3e4bae..eabccaf 100644 --- a/src/conspiracies/pipeline/pipeline.py +++ b/src/conspiracies/pipeline/pipeline.py @@ -9,7 +9,7 @@ from conspiracies.corpusprocessing.triplet import Triplet from conspiracies.docprocessing.docprocessor import DocProcessor from conspiracies.document import Document -from conspiracies.pipeline.config import PipelineConfig, ClusteringThresholds +from conspiracies.pipeline.config import PipelineConfig, Thresholds from conspiracies.preprocessing.csv import CsvPreprocessor from conspiracies.preprocessing.infomedia import InfoMediaPreprocessor from conspiracies.preprocessing.preprocessor import Preprocessor @@ -101,12 +101,16 @@ def corpusprocessing(self): print("Collecting triplets.") triplets = Triplet.from_annotated_docs(self.output_path / "annotations.ndjson") triplets = Triplet.filter_on_stopwords(triplets, self.config.base.language) - Triplet.write_jsonl(self.output_path / "triplets.ndjson", triplets) - if self.config.corpusprocessing.thresholds is None: - thresholds = ClusteringThresholds.estimate_from_n_triplets(len(triplets)) + thresholds = 
Thresholds.estimate_from_n_triplets(len(triplets)) else: thresholds = self.config.corpusprocessing.thresholds + triplets = Triplet.filter_on_entity_label_frequency( + triplets, + thresholds.min_label_occurrence, + ) + Triplet.write_jsonl(self.output_path / "triplets.ndjson", triplets) + print("Clustering entities and predicates to create mappings.") clustering = Clustering( language=self.config.base.language, From f0578d8bd54b78c36d166cb02be5af09b879f3b3 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Mon, 25 Nov 2024 11:38:56 +0100 Subject: [PATCH 05/29] Without embedding-based clustering, only label overlap --- .../corpusprocessing/clustering.py | 97 +++++++++++++++++-- src/conspiracies/corpusprocessing/triplet.py | 12 ++- src/conspiracies/pipeline/config.py | 7 +- tests/test_clustering.py | 17 ++++ 4 files changed, 119 insertions(+), 14 deletions(-) diff --git a/src/conspiracies/corpusprocessing/clustering.py b/src/conspiracies/corpusprocessing/clustering.py index 10fd98b..991598a 100644 --- a/src/conspiracies/corpusprocessing/clustering.py +++ b/src/conspiracies/corpusprocessing/clustering.py @@ -1,7 +1,8 @@ +import math import os -from collections import defaultdict +from collections import defaultdict, Counter from pathlib import Path -from typing import List, Callable, Any, Hashable, Dict +from typing import List, Callable, Any, Hashable, Dict, Union import networkx import numpy as np @@ -118,8 +119,15 @@ def _cluster( else: model = self._get_embedding_model() print("Creating embeddings:") + + counter = Counter((field.text for field in fields)) + condensed = [ + field + for field, count in counter.items() + for _ in range(math.ceil(count / 1000)) + ] embeddings = model.encode( - [field.text for field in fields], + condensed, normalize_embeddings=True, show_progress_bar=True, ) @@ -171,7 +179,7 @@ def _cluster( hdbscan_model.probabilities_, ): # skip noise and low confidence - if label == -1 or probability < 0.1: + if label == -1 or probability < 0.5: continue clusters[label].append((field, embedding)) @@ -189,11 +197,69 @@ def _cluster( return [[t[0] for t in cluster] for cluster in merged] @staticmethod - def _mapping_to_first_member(clusters: List[List[TripletField]]) -> Dict[str, str]: + def _cluster_via_normalization( + labels: List[str], + top: Union[int, float] = 1.0, + restrictive_labels=True, + ) -> List[List[str]]: + counter = Counter((label for label in labels)) + if isinstance(top, float): + top = int(top * len(counter)) + + norm_map = { + label: " " + + label.lower() + + " " # surrounding spaces avoids matches like evil <-> devil + for label in counter.keys() + } + cluster_map = { + label: [] + for label, count in counter.most_common(top) + # FIXME: hack due to lack of NER and lemmas at the time of writing + if not restrictive_labels + or len(label) >= 4 + and label[0].isupper() + or len(label.split()) > 1 + } + + for label in counter.keys(): + norm_label = norm_map[label] + matches = [ + substring + for substring in cluster_map.keys() + if norm_map[substring] in norm_label + ] + if not matches: + continue + + best_match = min( + matches, + key=lambda substring: len(norm_map[substring]), + ) + if best_match != label: + cluster_map[best_match].append(label) + + clusters = [ + [main_label] + alt_labels + for main_label, alt_labels in cluster_map.items() + if alt_labels + ] + return clusters + + @staticmethod + def _mapping_to_first_member( + clusters: List[List[TripletField | str]], + ) -> Dict[str, str]: + def get_text(member: TripletField | str): + if isinstance(member, 
TripletField): + return member.text + else: + return member + return { - member: cluster[0].text + member: get_text(cluster[0]) for cluster in clusters - for member in set(member.text for member in cluster) + for member in set(get_text(member) for member in cluster) } def create_mappings(self, triplets: List[Triplet]) -> Mappings: @@ -202,10 +268,23 @@ def create_mappings(self, triplets: List[Triplet]) -> Mappings: entities = subjects + objects predicates = [triplet.predicate for triplet in triplets] + # FIXME: clustering gets way to aggressive for many triplets + # print("Creating mappings for entities") + # entity_clusters = self._cluster(entities, "entities") + # print("Creating mappings for predicates") + # predicate_clusters = self._cluster(predicates, "predicates") + print("Creating mappings for entities") - entity_clusters = self._cluster(entities, "entities") + entity_clusters = self._cluster_via_normalization( + [e.text for e in entities], + 0.2, + ) print("Creating mappings for predicates") - predicate_clusters = self._cluster(predicates, "predicates") + predicate_clusters = self._cluster_via_normalization( + [p.text for p in predicates], + top=0.2, + restrictive_labels=False, + ) mappings = Mappings( entities=self._mapping_to_first_member(entity_clusters), diff --git a/src/conspiracies/corpusprocessing/triplet.py b/src/conspiracies/corpusprocessing/triplet.py index fdc065b..af809ba 100644 --- a/src/conspiracies/corpusprocessing/triplet.py +++ b/src/conspiracies/corpusprocessing/triplet.py @@ -1,5 +1,5 @@ import json -from collections import Counter +from collections import Counter, defaultdict from datetime import datetime from pathlib import Path from typing import Optional, Set, Iterator, Iterable, List, Union @@ -57,15 +57,23 @@ def filter_on_stopwords( def filter_on_entity_label_frequency( triplets: Iterable["Triplet"], min_frequency: int, + min_doc_frequency: int = 1, ): entity_label_counter = Counter( f.text for triplet in triplets for f in (triplet.subject, triplet.object) ) + docs = defaultdict(set) + for triplet in triplets: + for f in (triplet.subject, triplet.object): + docs[f.text].add(triplet.doc) + doc_frequency = {label: len(docs) for label, docs in docs.items()} + filtered = [ triplet for triplet in triplets if entity_label_counter[triplet.subject.text] >= min_frequency - and entity_label_counter[triplet.subject.text] >= min_frequency + and entity_label_counter[triplet.object.text] >= min_frequency + and doc_frequency[triplet.subject.text] >= min_doc_frequency ] return filtered diff --git a/src/conspiracies/pipeline/config.py b/src/conspiracies/pipeline/config.py index b06d12c..f0425a4 100644 --- a/src/conspiracies/pipeline/config.py +++ b/src/conspiracies/pipeline/config.py @@ -37,11 +37,12 @@ class Thresholds(BaseModel): @classmethod def estimate_from_n_triplets(cls, n_triplets: int): - factor = n_triplets / 1000 + # factor = n_triplets / 10_000 thresholds = cls( min_label_occurrence=math.floor(math.log10(n_triplets)) - 1, - min_cluster_size=max(int(factor + 1), 2), - min_samples=int(factor + 1), + min_label_doc_freq=2, + min_cluster_size=2, + min_samples=2, ) return thresholds diff --git a/tests/test_clustering.py b/tests/test_clustering.py index 92cf144..2ba48d1 100644 --- a/tests/test_clustering.py +++ b/tests/test_clustering.py @@ -29,3 +29,20 @@ def test_tuples_with_second_element_as_combine_key(self): Clustering._combine_clusters(clusters, get_combine_key=lambda x: x[1]) == expected ) + + +def test_cluster_by_normalization(): + labels = [ + "popular 
label", + "popular label", + "popular label 2", + "another label", + "another label", + "yet another label", + "a third label", + ] + clusters = Clustering._cluster_via_normalization(labels, top=2) + assert clusters == [ + ["popular label", "popular label 2"], + ["another label", "yet another label"], + ] From 0a916daeeeec2bb96e5f522a4ff8276e9e52fd40 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Tue, 26 Nov 2024 09:25:46 +0100 Subject: [PATCH 06/29] Better frequency slider for nodes and edges on visualizer --- visualizer/package-lock.json | 6 + visualizer/package.json | 1 + .../src/common/LogarithmicRangeSlider.tsx | 107 ++++++++++++++++++ .../src/graph/GraphFilterControlPanel.tsx | 106 +++++++++-------- visualizer/src/graph/GraphService.ts | 50 +++++++- visualizer/src/graph/GraphViewer.tsx | 65 +++++------ visualizer/src/graph/NodeInfo.tsx | 4 +- visualizer/src/index.tsx | 2 - 8 files changed, 252 insertions(+), 89 deletions(-) create mode 100644 visualizer/src/common/LogarithmicRangeSlider.tsx diff --git a/visualizer/package-lock.json b/visualizer/package-lock.json index 2b5f041..16425f6 100644 --- a/visualizer/package-lock.json +++ b/visualizer/package-lock.json @@ -14,6 +14,7 @@ "@types/jest": "^27.5.2", "@types/node": "^16.18.96", "@types/react-dom": "^18.2.24", + "multi-range-slider-react": "^2.0.7", "react": "^18.2.0", "react-dom": "^18.2.0", "react-graph-vis": "^1.0.7", @@ -12083,6 +12084,11 @@ "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==" }, + "node_modules/multi-range-slider-react": { + "version": "2.0.7", + "resolved": "https://registry.npmjs.org/multi-range-slider-react/-/multi-range-slider-react-2.0.7.tgz", + "integrity": "sha512-KRYUkatXxxYceL5ZT8xvetIN+4yTCdWszxRC6Y6Jkua+oRrWVkmBR6v3R03kosYg/QtcETBf2L1Jt+4U66DFbg==" + }, "node_modules/multicast-dns": { "version": "7.2.5", "resolved": "https://registry.npmjs.org/multicast-dns/-/multicast-dns-7.2.5.tgz", diff --git a/visualizer/package.json b/visualizer/package.json index e6e9fa3..c64ff2e 100644 --- a/visualizer/package.json +++ b/visualizer/package.json @@ -9,6 +9,7 @@ "@types/jest": "^27.5.2", "@types/node": "^16.18.96", "@types/react-dom": "^18.2.24", + "multi-range-slider-react": "^2.0.7", "react": "^18.2.0", "react-dom": "^18.2.0", "react-graph-vis": "^1.0.7", diff --git a/visualizer/src/common/LogarithmicRangeSlider.tsx b/visualizer/src/common/LogarithmicRangeSlider.tsx new file mode 100644 index 0000000..b766506 --- /dev/null +++ b/visualizer/src/common/LogarithmicRangeSlider.tsx @@ -0,0 +1,107 @@ +import React, { useEffect, useMemo, useState } from "react"; +import MultiRangeSlider from "multi-range-slider-react"; + +interface LogarithmicRangeSliderProps { + min: number; // Real-world minimum value + max: number; // Real-world maximum value + minValue: number; // Current minimum value + maxValue: number; // Current maximum value + onChange: (values: { minValue: number; maxValue: number }) => void; // Callback for value changes + style?: React.CSSProperties; // Optional style prop + ruler?: boolean; // Optional ruler prop +} + +const linearToLog = ( + value: number, + minLinear: number, + maxLinear: number, + minLog: number, + maxLog: number, +): number => { + const clampedValue = Math.max(minLinear, Math.min(value, maxLinear)); + const linearRange = maxLinear - minLinear; + const logRange = Math.log(maxLog) - Math.log(minLog); + const logValue = + Math.log(minLog) + ((clampedValue - minLinear) 
/ linearRange) * logRange; + return Math.exp(logValue); +}; + +const logToLinear = ( + value: number, + minLinear: number, + maxLinear: number, + minLog: number, + maxLog: number, +): number => { + const clampedValue = Math.max(minLog, Math.min(value, maxLog)); + const linearRange = maxLinear - minLinear; + const logRange = Math.log(maxLog) - Math.log(minLog); + const logValue = Math.log(clampedValue); + return minLinear + ((logValue - Math.log(minLog)) / logRange) * linearRange; +}; + +const LogarithmicRangeSlider: React.FC = ({ + min, + max, + minValue, + maxValue, + onChange, + ruler, + ...rest +}) => { + const [minCaption, setMinCaption] = useState(Math.round(minValue)); + const [maxCaption, setMaxCaption] = useState(Math.round(maxValue)); + useEffect(() => { + setMinCaption(minValue); + setMaxCaption(maxValue); + }, [minValue, maxValue]); + + // Slider operates on a linear scale + const linearMin = 0; + const linearMax = 100; + const realValueToLinearScale = useMemo(() => { + return (realValue: number) => + Math.round(logToLinear(realValue, linearMin, linearMax, min, max)); + }, [min, max]); + const linearScaleToRealValue = useMemo(() => { + return (linearScaleValue: number) => + Math.round(linearToLog(linearScaleValue, linearMin, linearMax, min, max)); + }, [min, max]); + + const handleSliderInput = (e: { minValue: number; maxValue: number }) => { + const realMinValue = linearScaleToRealValue(e.minValue); + const realMaxValue = linearScaleToRealValue(e.maxValue); + setMinCaption(realMinValue); + setMaxCaption(realMaxValue); + }; + const handleSliderChange = (e: { minValue: number; maxValue: number }) => { + const realMinValue = linearScaleToRealValue(e.minValue); + const realMaxValue = linearScaleToRealValue(e.maxValue); + onChange({ minValue: realMinValue, maxValue: realMaxValue }); + }; + + const labels = [ + String(min), + String(linearScaleToRealValue(25)), + String(linearScaleToRealValue(50)), + String(linearScaleToRealValue(75)), + String(max), + ]; + + return ( + + ); +}; +export default LogarithmicRangeSlider; diff --git a/visualizer/src/graph/GraphFilterControlPanel.tsx b/visualizer/src/graph/GraphFilterControlPanel.tsx index bb93162..d01eab5 100644 --- a/visualizer/src/graph/GraphFilterControlPanel.tsx +++ b/visualizer/src/graph/GraphFilterControlPanel.tsx @@ -1,6 +1,7 @@ import React from "react"; import { GraphFilter } from "./GraphService"; import "./graph.css"; +import LogarithmicRangeSlider from "../common/LogarithmicRangeSlider"; interface GraphFilterControlPanelProps { graphFilter: GraphFilter; @@ -11,59 +12,55 @@ export const GraphFilterControlPanel = ({ graphFilter, setGraphFilter, }: GraphFilterControlPanelProps) => { + const setMinAndMaxNodeFrequency = (min: number, max: number) => { + setGraphFilter({ + ...graphFilter, + minimumNodeFrequency: min, + maximumNodeFrequency: max, + }); + }; + const setMinAndMaxEdgeFrequency = (min: number, max: number) => { + setGraphFilter({ + ...graphFilter, + minimumEdgeFrequency: min, + maximumEdgeFrequency: max, + }); + }; + return (
- Minimum Node Frequency: {graphFilter.minimumNodeFrequency} + Node Frequency: - -
+
+ { + setMinAndMaxNodeFrequency(e.minValue, e.maxValue); + }} + min={1} + minValue={graphFilter.minimumNodeFrequency} + maxValue={graphFilter.maximumNodeFrequency} + max={graphFilter.maximumPossibleNodeFrequency} + style={{ border: "none", boxShadow: "none", padding: "15px 10px" }} + > +
+
- Minimum Edge Frequency: {graphFilter.minimumEdgeFrequency} - - + Edge Frequency: +
+ { + setMinAndMaxEdgeFrequency(e.minValue, e.maxValue); + }} + min={graphFilter.minimumPossibleEdgeFrequency} + minValue={graphFilter.minimumEdgeFrequency} + maxValue={graphFilter.maximumEdgeFrequency} + max={graphFilter.maximumPossibleEdgeFrequency} + style={{ border: "none", boxShadow: "none", padding: "15px 10px" }} + > +
Show unconnected nodes: @@ -80,7 +77,20 @@ export const GraphFilterControlPanel = ({ />
- From:{" "} + Search nodes: + { + let value = event.target.value; + setGraphFilter({ + ...graphFilter, + labelSearch: value, + }); + }} + /> +
+
+ From: - To:{" "} + To: node.stats.frequency >= filter.minimumNodeFrequency && + node.stats.frequency < filter.maximumNodeFrequency && hasDateOverlap(node, filter), ); let filteredNodes = new Set(nodes.map((node) => node.id)); let edges = graphData.edges.filter( (edge: EnrichedEdge) => edge.stats.frequency >= filter.minimumEdgeFrequency && + edge.stats.frequency < filter.maximumEdgeFrequency && filteredNodes.has(edge.from) && filteredNodes.has(edge.to), ); @@ -75,14 +94,41 @@ export function filter( if (!filter.showUnconnectedNodes) { nodes = nodes.filter((node) => connectedNodes.has(node.id)); } + nodes = nodes.map((node) => ({ + ...node, + ...(node.label?.toLowerCase().includes(filter.labelSearch) + ? { opacity: 1 } + : { opacity: 0.2 }), + })); + return { nodes, edges }; } +export interface DataBounds { + minNodeFrequency: number; + maxNodeFrequency: number; + maxEdgeFrequency: number; +} + export abstract class GraphService { private nodesMap: Map | null = null; abstract getGraph(): EnrichedGraphData; + getBounds(): DataBounds { + return { + minNodeFrequency: Math.min( + ...this.getGraph().nodes.map((n) => n.stats.frequency), + ), + maxNodeFrequency: Math.max( + ...this.getGraph().nodes.map((n) => n.stats.frequency), + ), + maxEdgeFrequency: Math.max( + ...this.getGraph().edges.map((n) => n.stats.frequency), + ), + }; + } + getSubGraph(nodeIds: Set): EnrichedGraphData { return { nodes: this.getGraph().nodes.filter((n: EnrichedNode) => diff --git a/visualizer/src/graph/GraphViewer.tsx b/visualizer/src/graph/GraphViewer.tsx index 9c13026..dfcd07f 100644 --- a/visualizer/src/graph/GraphViewer.tsx +++ b/visualizer/src/graph/GraphViewer.tsx @@ -1,6 +1,5 @@ -import React, { useEffect, useRef, useState } from "react"; +import React, { useMemo, useRef, useState } from "react"; import { - EnrichedGraphData, EnrichedNode, FileGraphService, filter, @@ -16,42 +15,43 @@ import { NodeInfo } from "./NodeInfo"; export const GraphViewer: React.FC = () => { let graphServiceRef = useRef(new SampleGraphService()); - const [graphData, setGraphData] = useState( - graphServiceRef.current.getGraph(), - ); const handleFileLoaded = (data: any) => { - if (data === "SAMPLE") { - graphServiceRef.current = new SampleGraphService(); - setGraphFilter(new GraphFilter(1, 0)); - } else { - graphServiceRef.current = new FileGraphService(data); - const top50 = - graphServiceRef.current - .getGraph() - .nodes.map((n) => n.stats.frequency) - .sort() - .reverse() - .at(100) || 1; - setGraphFilter(new GraphFilter(top50, Math.floor(top50 / 2))); - } - setGraphData(filter(graphFilter, graphServiceRef.current.getGraph())); + graphServiceRef.current = new FileGraphService(data); + const top50 = + graphServiceRef.current + .getGraph() + .nodes.map((n) => n.stats.frequency) + .sort((a, b) => b - a) + .at(100) || 1; + let { minNodeFrequency, maxNodeFrequency, maxEdgeFrequency } = + graphServiceRef.current.getBounds(); + setGraphFilter( + new GraphFilter( + minNodeFrequency, + top50, + maxNodeFrequency, + 1, + Math.floor(top50 / 10), + maxEdgeFrequency, + ), + ); }; - const [graphFilter, setGraphFilter] = useState(new GraphFilter(5, 3)); + const [graphFilter, setGraphFilter] = useState( + new GraphFilter(1, 1, 10, 1, 1, 10), + ); const [selected, setSelected] = useState(new Set()); const [selectedNode, setSelectedNode] = useState( undefined, ); - useEffect(() => { - let newGraphData: EnrichedGraphData; - if (selected.size > 0) { - newGraphData = graphServiceRef.current.getSubGraph(selected); - } else { - newGraphData = 
graphServiceRef.current.getGraph(); - } - setGraphData(filter(graphFilter, newGraphData)); + const filteredGraphData = useMemo(() => { + const baseGraphData = + selected.size > 0 + ? graphServiceRef.current.getSubGraph(selected) + : graphServiceRef.current.getGraph(); + return filter(graphFilter, baseGraphData); }, [graphFilter, selected]); let events: GraphEvents = { @@ -108,7 +108,6 @@ export const GraphViewer: React.FC = () => {
-

@@ -138,9 +137,6 @@ export const GraphViewer: React.FC = () => {
{selectedNode && } - - +
); diff --git a/visualizer/src/graph/NodeInfo.tsx b/visualizer/src/graph/NodeInfo.tsx index eaccbdb..1e3e71f 100644 --- a/visualizer/src/graph/NodeInfo.tsx +++ b/visualizer/src/graph/NodeInfo.tsx @@ -17,14 +17,14 @@ export const NodeInfo: React.FC = ({

Frequency: {stats.frequency}

-

Norm. frequency: {stats.norm_frequency?.toPrecision(3)}

+ {/*

Norm. frequency: {stats.norm_frequency?.toPrecision(3)}

*/} {stats.first_occurrence && (

Earliest date: {stats.first_occurrence}

)} {stats.last_occurrence &&

Latest date: {stats.last_occurrence}

} {stats.alt_labels && (
- Labels: + Alternative Labels:
    {stats.alt_labels.map((l) => (
  • {l}
  • diff --git a/visualizer/src/index.tsx b/visualizer/src/index.tsx index cfd7784..77c175d 100644 --- a/visualizer/src/index.tsx +++ b/visualizer/src/index.tsx @@ -7,7 +7,5 @@ const root = ReactDOM.createRoot( document.getElementById("root") as HTMLElement ); root.render( - - ); From 7a85b493dd3281604ecb12bc9cce332e48eec09a Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Tue, 26 Nov 2024 09:34:32 +0100 Subject: [PATCH 07/29] Also filtering edges on dates in visualizer --- visualizer/src/graph/GraphService.ts | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/visualizer/src/graph/GraphService.ts b/visualizer/src/graph/GraphService.ts index b1e1a95..9cede8a 100644 --- a/visualizer/src/graph/GraphService.ts +++ b/visualizer/src/graph/GraphService.ts @@ -55,12 +55,15 @@ export class GraphFilter { } } -function hasDateOverlap(node: EnrichedNode, filter: GraphFilter): boolean { - if (!node.stats.first_occurrence || !node.stats.last_occurrence) { +function hasDateOverlap( + nodeOrEdge: EnrichedNode | EnrichedEdge, + filter: GraphFilter, +): boolean { + if (!nodeOrEdge.stats.first_occurrence || !nodeOrEdge.stats.last_occurrence) { return true; } - const first = new Date(node.stats.first_occurrence); - const last = new Date(node.stats.last_occurrence); + const first = new Date(nodeOrEdge.stats.first_occurrence); + const last = new Date(nodeOrEdge.stats.last_occurrence); const afterEarliestDate = !filter.earliestDate || filter.earliestDate < first || @@ -87,6 +90,7 @@ export function filter( (edge: EnrichedEdge) => edge.stats.frequency >= filter.minimumEdgeFrequency && edge.stats.frequency < filter.maximumEdgeFrequency && + hasDateOverlap(edge, filter) && filteredNodes.has(edge.from) && filteredNodes.has(edge.to), ); From 791422645a4c09dfd3706b3ba0b45050babbe9b6 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Tue, 26 Nov 2024 10:43:58 +0100 Subject: [PATCH 08/29] Trying multiprocess offloading to GPU --- config/eschatology.toml | 6 +++--- config/template.toml | 4 +++- src/conspiracies/docprocessing/docprocessor.py | 7 +++++++ src/conspiracies/pipeline/config.py | 1 + src/conspiracies/pipeline/pipeline.py | 1 + 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/config/eschatology.toml b/config/eschatology.toml index bf7d81b..a31d8d8 100644 --- a/config/eschatology.toml +++ b/config/eschatology.toml @@ -13,8 +13,8 @@ timestamp_column = "timestamp" [docprocessing] enabled = true batch_size = 5 -prefer_gpu_for_coref = false +prefer_gpu_for_coref = true +n_process = 2 # might make sense for GPU offloading [corpusprocessing] -enabled = true -dimensions = 50 \ No newline at end of file +enabled = true \ No newline at end of file diff --git a/config/template.toml b/config/template.toml index 6d7ae49..9515e1d 100644 --- a/config/template.toml +++ b/config/template.toml @@ -12,13 +12,14 @@ metadata_fields = ["*"] [preprocessing.extra] # specific extra arguments for your preprocessor, e.g. 
context length for tweets or -# or field specification for CSVs +# field specification for CSVs [docprocessing] enabled = true batch_size = 25 continue_from_last = true triplet_extraction_method = "multi2oie/prompting" +n_process = 1 # can be set to 2 or more for multiprocess ofloading to GPU; otherwise might not make sense [corpusprocessing] enabled = true @@ -27,5 +28,6 @@ dimensions = 100 # leave out to skip dimensionality reduction n_neighbors = 15 # used for dimensionality reduction [corpusprocessing.thresholds] # leave out for automatic estimation +min_label_occurrence = 3 min_cluster_size = 3 min_samples = 3 \ No newline at end of file diff --git a/src/conspiracies/docprocessing/docprocessor.py b/src/conspiracies/docprocessing/docprocessor.py index 1fa8e07..baffa81 100644 --- a/src/conspiracies/docprocessing/docprocessor.py +++ b/src/conspiracies/docprocessing/docprocessor.py @@ -86,10 +86,15 @@ def __init__( batch_size=25, triplet_extraction_method="multi2oie", prefer_gpu_for_coref: bool = False, + n_process: int = 1, ): self.language = language self.batch_size = batch_size self.prefer_gpu_for_coref = prefer_gpu_for_coref + self.n_process = n_process + if n_process > 1: + # multiprocessing and torch with multiple threads result in a deadlock, therefore: + torch.set_num_threads(1) self.coref_pipeline = self._build_coref_pipeline() self.triplet_extraction_component = triplet_extraction_method self.triplet_extraction_pipeline = self._build_triplet_extraction_pipeline() @@ -114,6 +119,7 @@ def process_docs( ((text_with_context(src_doc), src_doc) for src_doc in docs), batch_size=self.batch_size, as_tuples=True, + n_process=self.n_process, ) with_triplets = self.triplet_extraction_pipeline.pipe( @@ -123,6 +129,7 @@ def process_docs( ), batch_size=self.batch_size, as_tuples=True, + n_process=self.n_process, ) docs_to_jsonl( diff --git a/src/conspiracies/pipeline/config.py b/src/conspiracies/pipeline/config.py index f0425a4..dba39a1 100644 --- a/src/conspiracies/pipeline/config.py +++ b/src/conspiracies/pipeline/config.py @@ -28,6 +28,7 @@ class DocProcessingConfig(StepConfig): continue_from_last: bool = True triplet_extraction_method: str = "multi2oie" prefer_gpu_for_coref: bool = False + n_process: int = 1 class Thresholds(BaseModel): diff --git a/src/conspiracies/pipeline/pipeline.py b/src/conspiracies/pipeline/pipeline.py index eabccaf..c8d9f73 100644 --- a/src/conspiracies/pipeline/pipeline.py +++ b/src/conspiracies/pipeline/pipeline.py @@ -81,6 +81,7 @@ def _get_docprocessor(self) -> DocProcessor: batch_size=self.config.docprocessing.batch_size, triplet_extraction_method=self.config.docprocessing.triplet_extraction_method, prefer_gpu_for_coref=self.config.docprocessing.prefer_gpu_for_coref, + n_process=self.config.docprocessing.n_process, ) def docprocessing(self, continue_from_last=False): From 9588527714f805b7992a49a596319897f852207c Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Tue, 26 Nov 2024 11:40:55 +0100 Subject: [PATCH 09/29] First steps towards storing binary files of SpaCy docs --- .../docprocessing/docprocessor.py | 52 ++++++++++++++++--- 1 file changed, 45 insertions(+), 7 deletions(-) diff --git a/src/conspiracies/docprocessing/docprocessor.py b/src/conspiracies/docprocessing/docprocessor.py index baffa81..f26ffbf 100644 --- a/src/conspiracies/docprocessing/docprocessor.py +++ b/src/conspiracies/docprocessing/docprocessor.py @@ -1,10 +1,12 @@ +import json import os +from glob import glob from pathlib import Path -from typing import Iterable +from typing import 
Iterable, Tuple, Iterator import spacy import torch -from jsonlines import jsonlines +from spacy.tokens import DocBin, Doc from tqdm import tqdm from conspiracies import docs_to_jsonl @@ -87,6 +89,7 @@ def __init__( triplet_extraction_method="multi2oie", prefer_gpu_for_coref: bool = False, n_process: int = 1, + doc_bin_size: int = 100, ): self.language = language self.batch_size = batch_size @@ -95,10 +98,32 @@ def __init__( if n_process > 1: # multiprocessing and torch with multiple threads result in a deadlock, therefore: torch.set_num_threads(1) + self.doc_bin_size = doc_bin_size self.coref_pipeline = self._build_coref_pipeline() self.triplet_extraction_component = triplet_extraction_method self.triplet_extraction_pipeline = self._build_triplet_extraction_pipeline() + @staticmethod + def _set_user_data_on_docs(docs: Iterator[Tuple[Doc, Document]]) -> Iterator[Doc]: + for doc, src_doc in docs: + # FIXME: this is kind of stupid, but with old pydantic this will have to work for now. + doc.user_data = json.loads(src_doc.json()) + yield doc + + def _store_doc_bins(self, docs: Iterator[Doc], output_path: Path): + output_dir = Path(os.path.dirname(output_path)) / "spacy_docs" + output_dir.mkdir(parents=True, exist_ok=True) + + size = self.doc_bin_size + doc_bin = DocBin(store_user_data=True) + for i, doc in enumerate(docs, start=1): + doc_bin.add(doc) + if i % size == 0: + with open(output_dir / f"{i//size}.bin", "wb") as f: + f.write(doc_bin.to_bytes()) + doc_bin = DocBin(store_user_data=True) + yield doc + def process_docs( self, docs: Iterable[Document], @@ -106,10 +131,19 @@ def process_docs( continue_from_last=False, ): if continue_from_last and os.path.exists(output_path): - with jsonlines.open(output_path) as annotated_docs: - already_processed = { - annotated_doc["id"] for annotated_doc in annotated_docs - } + already_processed = set() + + # FIXME: paths should be given elsewhere and not be inferred like this + for bin_file in glob( + (Path(os.path.dirname(output_path)) / "spacy_docs").as_posix() + + "/*.bin", + ): + with open(bin_file, "rb") as bytes_data: + doc_bin = DocBin().from_bytes(bytes_data.read()) + for doc in doc_bin.get_docs(self.triplet_extraction_pipeline.vocab): + id_ = doc.user_data["id"] + already_processed.add(id_) + print(f"Skipping {len(already_processed)} processed docs.") docs = (doc for doc in docs if doc.id not in already_processed) @@ -132,8 +166,12 @@ def process_docs( n_process=self.n_process, ) + with_user_data = self._set_user_data_on_docs(with_triplets) + + stored = self._store_doc_bins(with_user_data, output_path) + docs_to_jsonl( - tqdm(d for d in with_triplets), + tqdm(stored), output_path, append=continue_from_last, ) From 1a0f4e7a41b8d03a0287e9d23e47565d6871bf6d Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Tue, 26 Nov 2024 13:37:38 +0100 Subject: [PATCH 10/29] Trying fastcoref's LingMessCoref instead of AllenNLP which is sluggishly slow --- config/eschatology.toml | 6 +-- docs/tutorials/overview.ipynb | 2 +- paper/extract_triplets_newspapers.py | 2 +- paper/extract_triplets_tweets.py | 4 +- .../docprocessing/coref/coref_component.py | 12 +++--- .../docprocessing/docprocessor.py | 39 +++++++++++++------ tests/test_coref_comp.py | 12 +++--- 7 files changed, 47 insertions(+), 30 deletions(-) diff --git a/config/eschatology.toml b/config/eschatology.toml index a31d8d8..6ac362f 100644 --- a/config/eschatology.toml +++ b/config/eschatology.toml @@ -2,7 +2,7 @@ language = "en" [preprocessing] -enabled = true +enabled = false doc_type = "csv" 
[preprocessing.extra] @@ -12,9 +12,9 @@ timestamp_column = "timestamp" [docprocessing] enabled = true -batch_size = 5 +batch_size = 50 prefer_gpu_for_coref = true -n_process = 2 # might make sense for GPU offloading +n_process = 1 [corpusprocessing] enabled = true \ No newline at end of file diff --git a/docs/tutorials/overview.ipynb b/docs/tutorials/overview.ipynb index affc822..29ff5d0 100644 --- a/docs/tutorials/overview.ipynb +++ b/docs/tutorials/overview.ipynb @@ -71,7 +71,7 @@ " assert isinstance(sent._.coref_clusters[0], tuple)\n", " assert isinstance(sent._.coref_clusters[0][0], int)\n", " assert isinstance(sent._.coref_clusters[0][1], Span)\n", - " sent._.resolve_coref # get resolved coref" + " sent._.resolved_text # get resolved coref" ] }, { diff --git a/paper/extract_triplets_newspapers.py b/paper/extract_triplets_newspapers.py index 95ded68..745d85f 100644 --- a/paper/extract_triplets_newspapers.py +++ b/paper/extract_triplets_newspapers.py @@ -88,7 +88,7 @@ def process_file( # Resolve coreference coref_docs = nlp_coref.pipe(normalized_article) - resolved_docs = (d._.resolve_coref for d in coref_docs) + resolved_docs = (d._.resolved_text for d in coref_docs) # Extract relations docs = nlp.pipe(resolved_docs) diff --git a/paper/extract_triplets_tweets.py b/paper/extract_triplets_tweets.py index 77bcc8a..503835d 100644 --- a/paper/extract_triplets_tweets.py +++ b/paper/extract_triplets_tweets.py @@ -99,7 +99,7 @@ def concat_resolve_unconcat_contexts(file_path: str): coref_nlp = build_coref_pipeline() coref_docs = coref_nlp.pipe(context_tweets) - resolved_docs = (d._.resolve_coref for d in coref_docs) + resolved_docs = (d._.resolved_text for d in coref_docs) resolved_tweets = (tweet_from_context_text(tweet) for tweet in resolved_docs) return resolved_tweets @@ -240,7 +240,7 @@ def prompt_gpt3( for i, batch in enumerate(batch_generator(concatenated_tweets, batch_size)): start = time.time() coref_docs = coref_nlp.pipe(batch) - resolved_docs = (d._.resolve_coref for d in coref_docs) + resolved_docs = (d._.resolved_text for d in coref_docs) resolved_target_tweets = ( tweet_from_context_text(tweet) for tweet in resolved_docs ) diff --git a/src/conspiracies/docprocessing/coref/coref_component.py b/src/conspiracies/docprocessing/coref/coref_component.py index c486913..dd4ca1c 100644 --- a/src/conspiracies/docprocessing/coref/coref_component.py +++ b/src/conspiracies/docprocessing/coref/coref_component.py @@ -40,10 +40,10 @@ def __init__( ) # Register custom extension on the Doc and Span - if not Doc.has_extension("resolve_coref"): - Doc.set_extension("resolve_coref", getter=self.resolve_coref_doc) - if not Span.has_extension("resolve_coref"): - Span.set_extension("resolve_coref", getter=self.resolve_coref_span) + if not Doc.has_extension("resolved_text"): + Doc.set_extension("resolved_text", getter=self.resolved_text_doc) + if not Span.has_extension("resolved_text"): + Span.set_extension("resolved_text", getter=self.resolved_text_span) if not Doc.has_extension("coref_clusters"): Doc.set_extension("coref_clusters", default=list()) if not Span.has_extension("coref_clusters"): @@ -51,7 +51,7 @@ def __init__( if not Span.has_extension("antecedent"): Span.set_extension("antecedent", default=None) - def resolve_coref_doc(self, doc: Doc) -> str: + def resolved_text_doc(self, doc: Doc) -> str: """Resolve the coreference clusters by replacing each entity with the antecedent. The antecedent is the first entity that appears in the cluster. This is for the whole doc. 
@@ -73,7 +73,7 @@ def resolve_coref_doc(self, doc: Doc) -> str: resolved[i] = "" return "".join(resolved) - def resolve_coref_span(self, sent: Span) -> str: + def resolved_text_span(self, sent: Span) -> str: """Resolve the coreference clusters by replacing each entity with the antecedent. The antecedent is the first entity that appears in the cluster. This is for the the sent. diff --git a/src/conspiracies/docprocessing/docprocessor.py b/src/conspiracies/docprocessing/docprocessor.py index f26ffbf..b44809a 100644 --- a/src/conspiracies/docprocessing/docprocessor.py +++ b/src/conspiracies/docprocessing/docprocessor.py @@ -5,6 +5,7 @@ from typing import Iterable, Tuple, Iterator import spacy +from fastcoref.spacy_component import spacy_component # noqa import torch from spacy.tokens import DocBin, Doc from tqdm import tqdm @@ -18,14 +19,30 @@ class DocProcessor: def _build_coref_pipeline(self): nlp_coref = spacy.blank(self.language) nlp_coref.add_pipe("sentencizer") - nlp_coref.add_pipe( - "allennlp_coref", - config={ - "device": ( - 0 if self.prefer_gpu_for_coref and torch.cuda.is_available() else -1 - ), - }, - ) + if self.language == "en": + nlp_coref.add_pipe( + "fastcoref", + config={ + "enable_progress_bar": False, + "model_architecture": "LingMessCoref", + "device": ( + "cuda:0" + if self.prefer_gpu_for_coref and torch.cuda.is_available() + else "cpu" + ), + }, + ) + elif self.language == "da": + nlp_coref.add_pipe( + "allennlp_coref", + config={ + "device": ( + 0 + if self.prefer_gpu_for_coref and torch.cuda.is_available() + else -1 + ), + }, + ) def warn_error(proc_name, proc, docs, e): print( @@ -119,7 +136,7 @@ def _store_doc_bins(self, docs: Iterator[Doc], output_path: Path): for i, doc in enumerate(docs, start=1): doc_bin.add(doc) if i % size == 0: - with open(output_dir / f"{i//size}.bin", "wb") as f: + with open(output_dir / f"{i // size}.bin", "wb") as f: f.write(doc_bin.to_bytes()) doc_bin = DocBin(store_user_data=True) yield doc @@ -151,14 +168,14 @@ def process_docs( # extreme memory pressure, hence the small batch size coref_resolved_docs = self.coref_pipeline.pipe( ((text_with_context(src_doc), src_doc) for src_doc in docs), - batch_size=self.batch_size, + batch_size=5, as_tuples=True, n_process=self.n_process, ) with_triplets = self.triplet_extraction_pipeline.pipe( ( - (remove_context(doc._.resolve_coref), src_doc) + (remove_context(doc._.resolved_text), src_doc) for doc, src_doc in coref_resolved_docs ), batch_size=self.batch_size, diff --git a/tests/test_coref_comp.py b/tests/test_coref_comp.py index 80c10d0..38b90b6 100644 --- a/tests/test_coref_comp.py +++ b/tests/test_coref_comp.py @@ -22,8 +22,8 @@ def test_coref_clusters(nlp_da_w_coref): # noqa F811 assert isinstance(sent._.coref_clusters[0][1], Span) -def test_resolve_coref(nlp_da_w_coref): # noqa F811 - resolve_coref_text = ( +def test_resolved_text(nlp_da_w_coref): # noqa F811 + resolved_text_text = ( "Aftalepartierne bag Rammeaftalen om plan for genÃ¥bning af Danmark blev i" + " forÃ¥ret 2021 enige om at nedsætte en ekspertgruppe, en ekspertgruppe fik " + "til opgave at komme med input til den langsigtede strategi for hÃ¥ndtering " @@ -31,7 +31,7 @@ def test_resolve_coref(nlp_da_w_coref): # noqa F811 + "ekspertgruppe rapport." 
) - resolve_coref_spans = [ + resolved_text_spans = [ "Aftalepartierne bag Rammeaftalen om plan for genÃ¥bning af Danmark blev i " + "forÃ¥ret 2021 enige om at nedsætte en ekspertgruppe, en ekspertgruppe fik " + "til opgave at komme med input til den langsigtede strategi for hÃ¥ndtering " @@ -39,11 +39,11 @@ def test_resolve_coref(nlp_da_w_coref): # noqa F811 "en ekspertgruppe er nu klar med en ekspertgruppe rapport.", ] - doc = nlp_da_w_coref(resolve_coref_text) + doc = nlp_da_w_coref(resolved_text_text) # test for doc - assert doc._.resolve_coref == resolve_coref_text + assert doc._.resolved_text == resolved_text_text # test for spans for i, sent in enumerate(doc.sents): if sent._.coref_clusters != []: - assert sent._.resolve_coref == resolve_coref_spans[i] + assert sent._.resolved_text == resolved_text_spans[i] From 43c4786a9749e6b79f21099f12cd6a807e354909 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Tue, 26 Nov 2024 13:39:20 +0100 Subject: [PATCH 11/29] Adding fastcoref to dependencies --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9cb42b8..e6e5ab4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,8 @@ dependencies = [ "sentence-transformers", "stop-words", "bs4", - "toml" + "toml", + "fastcoref" ] [project.license] From 3d9e4b488ecd0de0924596bc406117dce261e252 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Tue, 26 Nov 2024 20:19:47 +0100 Subject: [PATCH 12/29] Adding a error recovering wrapper of the fastcoref component --- pyproject.toml | 1 + .../docprocessing/coref/safefastcoref.py | 69 +++++++++++++++++++ .../docprocessing/docprocessor.py | 15 ++-- 3 files changed, 75 insertions(+), 10 deletions(-) create mode 100644 src/conspiracies/docprocessing/coref/safefastcoref.py diff --git a/pyproject.toml b/pyproject.toml index e6e5ab4..20540db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,7 @@ content-type = "text/markdown" "prompt_relation_extraction" = "conspiracies.docprocessing.relationextraction.gptprompting:create_prompt_relation_extraction_component" "relation_extractor" = "conspiracies.docprocessing.relationextraction.multi2oie:make_relation_extractor" "allennlp_coref" = "conspiracies.docprocessing.coref:create_coref_component" +"safe_fastcoref" = "conspiracies.docprocessing.coref.safefastcoref:create_safe_fastcoref" "heads_extraction" = "conspiracies.docprocessing.headwordextraction:create_headwords_component" diff --git a/src/conspiracies/docprocessing/coref/safefastcoref.py b/src/conspiracies/docprocessing/coref/safefastcoref.py new file mode 100644 index 0000000..47f13e0 --- /dev/null +++ b/src/conspiracies/docprocessing/coref/safefastcoref.py @@ -0,0 +1,69 @@ +from fastcoref.spacy_component import FastCorefResolver +from spacy.language import Language +from spacy.pipeline import Pipe +from typing import Iterable + +import logging + +logging.getLogger("fastcoref").setLevel(logging.WARNING) + + +class SafeFastCoref(Pipe): + def __init__(self, component): + """Initialize the wrapper with the original component.""" + self.component = component + + def pipe(self, stream: Iterable, batch_size: int = 128): + """Wrap the pipe method of the component.""" + as_list = list(stream) + try: + yield from self.component.pipe(as_list, batch_size=batch_size) + except Exception as e: + # Log the error and return the unprocessed documents + logging.error(f"Error in SafeFastCoref pipe: {e}") + for doc in stream: + yield doc # Return the original document + + def __call__(self, doc): + 
"""Wrap the __call__ method of the component.""" + try: + return self.component(doc) + except Exception as e: + # Log the error and return the original document + logging.error(f"Error in SafeFastCoref __call__: {e}") + return doc + + +@Language.factory( + "safe_fastcoref", + assigns=["doc._.resolved_text", "doc._.coref_clusters"], + default_config={ + "model_architecture": "FCoref", # FCoref or LingMessCoref + "model_path": "biu-nlp/f-coref", # You can specify your own trained model path + "device": None, # "cuda" or "cpu" None defaults to cuda + "max_tokens_in_batch": 10000, + "enable_progress_bar": True, + }, +) +def create_safe_fastcoref( + nlp, + name, + model_architecture: str, + model_path: str, + device, + max_tokens_in_batch: int, + enable_progress_bar: bool, +): + """Factory method to create the SafeFastCoref component.""" + # Create the original FastCorefResolver with the given configuration + fastcoref_component = FastCorefResolver( + nlp=nlp, + name=name, + model_architecture=model_architecture, + model_path=model_path, + device=device, + max_tokens_in_batch=max_tokens_in_batch, + enable_progress_bar=enable_progress_bar, + ) + # Wrap it with SafeFastCoref + return SafeFastCoref(fastcoref_component) diff --git a/src/conspiracies/docprocessing/docprocessor.py b/src/conspiracies/docprocessing/docprocessor.py index b44809a..f1ba1dc 100644 --- a/src/conspiracies/docprocessing/docprocessor.py +++ b/src/conspiracies/docprocessing/docprocessor.py @@ -5,7 +5,6 @@ from typing import Iterable, Tuple, Iterator import spacy -from fastcoref.spacy_component import spacy_component # noqa import torch from spacy.tokens import DocBin, Doc from tqdm import tqdm @@ -21,15 +20,10 @@ def _build_coref_pipeline(self): nlp_coref.add_pipe("sentencizer") if self.language == "en": nlp_coref.add_pipe( - "fastcoref", + "safe_fastcoref", config={ "enable_progress_bar": False, "model_architecture": "LingMessCoref", - "device": ( - "cuda:0" - if self.prefer_gpu_for_coref and torch.cuda.is_available() - else "cpu" - ), }, ) elif self.language == "da": @@ -124,7 +118,7 @@ def __init__( def _set_user_data_on_docs(docs: Iterator[Tuple[Doc, Document]]) -> Iterator[Doc]: for doc, src_doc in docs: # FIXME: this is kind of stupid, but with old pydantic this will have to work for now. 
- doc.user_data = json.loads(src_doc.json()) + doc.user_data["doc_metadata"] = json.loads(src_doc.json()) yield doc def _store_doc_bins(self, docs: Iterator[Doc], output_path: Path): @@ -158,7 +152,7 @@ def process_docs( with open(bin_file, "rb") as bytes_data: doc_bin = DocBin().from_bytes(bytes_data.read()) for doc in doc_bin.get_docs(self.triplet_extraction_pipeline.vocab): - id_ = doc.user_data["id"] + id_ = doc.user_data["doc_metadata"]["id"] already_processed.add(id_) print(f"Skipping {len(already_processed)} processed docs.") @@ -168,9 +162,10 @@ def process_docs( # extreme memory pressure, hence the small batch size coref_resolved_docs = self.coref_pipeline.pipe( ((text_with_context(src_doc), src_doc) for src_doc in docs), - batch_size=5, + batch_size=self.batch_size, as_tuples=True, n_process=self.n_process, + component_cfg={"fastcoref": {"resolve_text": True}}, ) with_triplets = self.triplet_extraction_pipeline.pipe( From 0f2a0d615b208ad3f914069f58b1007faeb0e8e9 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Tue, 26 Nov 2024 20:47:31 +0100 Subject: [PATCH 13/29] Configurable DocBin size and specified cuda for fastcoref component --- src/conspiracies/docprocessing/docprocessor.py | 5 +++++ src/conspiracies/pipeline/config.py | 1 + src/conspiracies/pipeline/pipeline.py | 1 + 3 files changed, 7 insertions(+) diff --git a/src/conspiracies/docprocessing/docprocessor.py b/src/conspiracies/docprocessing/docprocessor.py index f1ba1dc..4c5bf1c 100644 --- a/src/conspiracies/docprocessing/docprocessor.py +++ b/src/conspiracies/docprocessing/docprocessor.py @@ -24,6 +24,11 @@ def _build_coref_pipeline(self): config={ "enable_progress_bar": False, "model_architecture": "LingMessCoref", + "device": ( + "cuda" + if self.prefer_gpu_for_coref and torch.cuda.is_available() + else "cpu" + ), }, ) elif self.language == "da": diff --git a/src/conspiracies/pipeline/config.py b/src/conspiracies/pipeline/config.py index dba39a1..f167791 100644 --- a/src/conspiracies/pipeline/config.py +++ b/src/conspiracies/pipeline/config.py @@ -29,6 +29,7 @@ class DocProcessingConfig(StepConfig): triplet_extraction_method: str = "multi2oie" prefer_gpu_for_coref: bool = False n_process: int = 1 + doc_bin_size: int = 100 class Thresholds(BaseModel): diff --git a/src/conspiracies/pipeline/pipeline.py b/src/conspiracies/pipeline/pipeline.py index c8d9f73..e7b0265 100644 --- a/src/conspiracies/pipeline/pipeline.py +++ b/src/conspiracies/pipeline/pipeline.py @@ -82,6 +82,7 @@ def _get_docprocessor(self) -> DocProcessor: triplet_extraction_method=self.config.docprocessing.triplet_extraction_method, prefer_gpu_for_coref=self.config.docprocessing.prefer_gpu_for_coref, n_process=self.config.docprocessing.n_process, + doc_bin_size=self.config.docprocessing.doc_bin_size, ) def docprocessing(self, continue_from_last=False): From 65b1d9a99474fc494fc80cd12719cb693400c6bb Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Tue, 26 Nov 2024 21:00:04 +0100 Subject: [PATCH 14/29] Fixing stream collected to list from debugging --- src/conspiracies/docprocessing/coref/safefastcoref.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/conspiracies/docprocessing/coref/safefastcoref.py b/src/conspiracies/docprocessing/coref/safefastcoref.py index 47f13e0..f8dce09 100644 --- a/src/conspiracies/docprocessing/coref/safefastcoref.py +++ b/src/conspiracies/docprocessing/coref/safefastcoref.py @@ -15,9 +15,8 @@ def __init__(self, component): def pipe(self, stream: Iterable, batch_size: int = 128): """Wrap the pipe method of 
the component.""" - as_list = list(stream) try: - yield from self.component.pipe(as_list, batch_size=batch_size) + yield from self.component.pipe(stream, batch_size=batch_size) except Exception as e: # Log the error and return the unprocessed documents logging.error(f"Error in SafeFastCoref pipe: {e}") From babfc4b9fc51386e6df742f6c780b84a8d5e94b5 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Wed, 27 Nov 2024 09:15:37 +0100 Subject: [PATCH 15/29] Output path for pipeline instead of project name --- src/conspiracies/pipeline/config.py | 3 +-- src/conspiracies/pipeline/pipeline.py | 5 ++--- src/conspiracies/run.py | 4 ++-- tests/test_data/test_config.toml | 3 +-- tests/test_pipelineconfig.py | 3 +-- 5 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/conspiracies/pipeline/config.py b/src/conspiracies/pipeline/config.py index f167791..e5dfd05 100644 --- a/src/conspiracies/pipeline/config.py +++ b/src/conspiracies/pipeline/config.py @@ -6,8 +6,7 @@ class BaseConfig(BaseModel): - project_name: str - output_root: str = "output" + output_path: str language: str diff --git a/src/conspiracies/pipeline/pipeline.py b/src/conspiracies/pipeline/pipeline.py index e7b0265..2162ccc 100644 --- a/src/conspiracies/pipeline/pipeline.py +++ b/src/conspiracies/pipeline/pipeline.py @@ -23,12 +23,11 @@ class Pipeline: def __init__(self, config: PipelineConfig): - self.project_name = config.base.project_name self.input_path = Path(config.preprocessing.input_path) + self.output_path = Path(config.base.output_path) + os.makedirs(self.output_path, exist_ok=True) self.config = config print("Initialized Pipeline with config:", config) - self.output_path = Path(self.config.base.output_root, self.project_name) - os.makedirs(self.output_path, exist_ok=True) def run(self): if self.config.preprocessing.enabled: diff --git a/src/conspiracies/run.py b/src/conspiracies/run.py index d859b7e..b2d0c15 100644 --- a/src/conspiracies/run.py +++ b/src/conspiracies/run.py @@ -8,7 +8,7 @@ if __name__ == "__main__": arg_parser = argparse.ArgumentParser() arg_parser.add_argument( - "project_name", + "output_path", nargs="?", default=None, help="Name of your project under which various output files will be output" @@ -49,7 +49,7 @@ logging.getLogger().setLevel(args.root_log_level) cli_args = { - "base.project_name": args.project_name, + "base.output_path": args.output_path, "base.language": args.language, "preprocessing.input_path": args.input_path, "preprocessing.n_docs": args.n_docs, diff --git a/tests/test_data/test_config.toml b/tests/test_data/test_config.toml index 6ddc9cf..133c7ff 100644 --- a/tests/test_data/test_config.toml +++ b/tests/test_data/test_config.toml @@ -1,6 +1,5 @@ [base] -project_name = "test" -output_root = "output" +output_path = "output/test" language = "en" [preprocessing] diff --git a/tests/test_pipelineconfig.py b/tests/test_pipelineconfig.py index 21b1d61..d79f4b7 100644 --- a/tests/test_pipelineconfig.py +++ b/tests/test_pipelineconfig.py @@ -21,8 +21,7 @@ def test_config_loading(path: str): config = PipelineConfig.from_toml_file(path) assert config == PipelineConfig( base=BaseConfig( - project_name="test", - output_root="output", + output_path="output/test", language="en", ), preprocessing=PreProcessingConfig( From 47b4f478342b5c31e57138cb062b6a24005ec6c0 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Wed, 27 Nov 2024 09:27:56 +0100 Subject: [PATCH 16/29] Resolving text in safe fastcoref component --- src/conspiracies/docprocessing/coref/safefastcoref.py | 10 +++++++--- 
src/conspiracies/docprocessing/docprocessor.py | 1 - 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/conspiracies/docprocessing/coref/safefastcoref.py b/src/conspiracies/docprocessing/coref/safefastcoref.py index f8dce09..27348e7 100644 --- a/src/conspiracies/docprocessing/coref/safefastcoref.py +++ b/src/conspiracies/docprocessing/coref/safefastcoref.py @@ -9,14 +9,18 @@ class SafeFastCoref(Pipe): - def __init__(self, component): + def __init__(self, component: FastCorefResolver): """Initialize the wrapper with the original component.""" self.component = component def pipe(self, stream: Iterable, batch_size: int = 128): """Wrap the pipe method of the component.""" try: - yield from self.component.pipe(stream, batch_size=batch_size) + yield from self.component.pipe( + stream, + batch_size=batch_size, + resolve_text=True, + ) except Exception as e: # Log the error and return the unprocessed documents logging.error(f"Error in SafeFastCoref pipe: {e}") @@ -26,7 +30,7 @@ def pipe(self, stream: Iterable, batch_size: int = 128): def __call__(self, doc): """Wrap the __call__ method of the component.""" try: - return self.component(doc) + return self.component(doc, resolve_text=True) except Exception as e: # Log the error and return the original document logging.error(f"Error in SafeFastCoref __call__: {e}") diff --git a/src/conspiracies/docprocessing/docprocessor.py b/src/conspiracies/docprocessing/docprocessor.py index 4c5bf1c..8573a8d 100644 --- a/src/conspiracies/docprocessing/docprocessor.py +++ b/src/conspiracies/docprocessing/docprocessor.py @@ -170,7 +170,6 @@ def process_docs( batch_size=self.batch_size, as_tuples=True, n_process=self.n_process, - component_cfg={"fastcoref": {"resolve_text": True}}, ) with_triplets = self.triplet_extraction_pipeline.pipe( From 28782f7335f36f7c63c240a06a5f237f560b27d4 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Wed, 27 Nov 2024 10:48:12 +0100 Subject: [PATCH 17/29] More proper continuation of already processed docs --- .../docprocessing/docprocessor.py | 57 ++++++++++++------- src/conspiracies/run.py | 7 +++ 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/src/conspiracies/docprocessing/docprocessor.py b/src/conspiracies/docprocessing/docprocessor.py index 8573a8d..8873f18 100644 --- a/src/conspiracies/docprocessing/docprocessor.py +++ b/src/conspiracies/docprocessing/docprocessor.py @@ -23,7 +23,6 @@ def _build_coref_pipeline(self): "safe_fastcoref", config={ "enable_progress_bar": False, - "model_architecture": "LingMessCoref", "device": ( "cuda" if self.prefer_gpu_for_coref and torch.cuda.is_available() @@ -127,41 +126,54 @@ def _set_user_data_on_docs(docs: Iterator[Tuple[Doc, Document]]) -> Iterator[Doc yield doc def _store_doc_bins(self, docs: Iterator[Doc], output_path: Path): + # FIXME: paths should be given elsewhere and not be inferred like this output_dir = Path(os.path.dirname(output_path)) / "spacy_docs" output_dir.mkdir(parents=True, exist_ok=True) + prev_doc_bins = glob( + (Path(os.path.dirname(output_path)) / "spacy_docs").as_posix() + "/*.bin", + ) + start_from = ( + max(int(os.path.basename(doc).replace(".bin", "")) for doc in prev_doc_bins) + if prev_doc_bins + else 0 + ) + size = self.doc_bin_size doc_bin = DocBin(store_user_data=True) - for i, doc in enumerate(docs, start=1): + for i, doc in enumerate(docs, start=start_from + 1): doc_bin.add(doc) if i % size == 0: - with open(output_dir / f"{i // size}.bin", "wb") as f: + with open(output_dir / f"{i}.bin", "wb") as f: f.write(doc_bin.to_bytes()) 
doc_bin = DocBin(store_user_data=True) yield doc + def _read_doc_bins(self, output_path: Path): + # FIXME: paths should be given elsewhere and not be inferred like this + count = 0 + for bin_file in glob( + (Path(os.path.dirname(output_path)) / "spacy_docs").as_posix() + "/*.bin", + ): + with open(bin_file, "rb") as bytes_data: + doc_bin = DocBin().from_bytes(bytes_data.read()) + for doc in doc_bin.get_docs(self.triplet_extraction_pipeline.vocab): + count += 1 + yield doc + print(f"Read {count} previously processed docs.") + def process_docs( self, docs: Iterable[Document], output_path: Path, continue_from_last=False, ): - if continue_from_last and os.path.exists(output_path): - already_processed = set() - - # FIXME: paths should be given elsewhere and not be inferred like this - for bin_file in glob( - (Path(os.path.dirname(output_path)) / "spacy_docs").as_posix() - + "/*.bin", - ): - with open(bin_file, "rb") as bytes_data: - doc_bin = DocBin().from_bytes(bytes_data.read()) - for doc in doc_bin.get_docs(self.triplet_extraction_pipeline.vocab): - id_ = doc.user_data["doc_metadata"]["id"] - already_processed.add(id_) - - print(f"Skipping {len(already_processed)} processed docs.") - docs = (doc for doc in docs if doc.id not in already_processed) + + if continue_from_last: + print( + "Reading previously processed documents! Disable 'continue_from_last' to avoid this.'", + ) + docs_to_jsonl(self._read_doc_bins(output_path), output_path) # The coreference pipeline tends to choke on too large batches because of an # extreme memory pressure, hence the small batch size @@ -184,10 +196,13 @@ def process_docs( with_user_data = self._set_user_data_on_docs(with_triplets) - stored = self._store_doc_bins(with_user_data, output_path) + docs_to_output = tqdm( + self._store_doc_bins(with_user_data, output_path), + desc="Processing documents", + ) docs_to_jsonl( - tqdm(stored), + docs_to_output, output_path, append=continue_from_last, ) diff --git a/src/conspiracies/run.py b/src/conspiracies/run.py index b2d0c15..a747f50 100644 --- a/src/conspiracies/run.py +++ b/src/conspiracies/run.py @@ -61,4 +61,11 @@ config = PipelineConfig.default_with_extra_config(cli_args) pipeline = Pipeline(config) + + logging.basicConfig( + level=logging.DEBUG, + filename=config.base.output_path + "/logfile", + filemode="w+", + ) + pipeline.run() From 244e18f8b2f57ce44d2bfd59f2d25487e61db455 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Wed, 27 Nov 2024 13:59:45 +0100 Subject: [PATCH 18/29] fixing bug in SafeFastCoref which led to empty texts --- .../docprocessing/coref/safefastcoref.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/conspiracies/docprocessing/coref/safefastcoref.py b/src/conspiracies/docprocessing/coref/safefastcoref.py index 27348e7..993aada 100644 --- a/src/conspiracies/docprocessing/coref/safefastcoref.py +++ b/src/conspiracies/docprocessing/coref/safefastcoref.py @@ -5,6 +5,8 @@ import logging +from spacy.util import minibatch + logging.getLogger("fastcoref").setLevel(logging.WARNING) @@ -15,17 +17,19 @@ def __init__(self, component: FastCorefResolver): def pipe(self, stream: Iterable, batch_size: int = 128): """Wrap the pipe method of the component.""" - try: - yield from self.component.pipe( - stream, - batch_size=batch_size, - resolve_text=True, - ) - except Exception as e: - # Log the error and return the unprocessed documents - logging.error(f"Error in SafeFastCoref pipe: {e}") - for doc in stream: - yield doc # Return the original document + for mb in 
minibatch(stream, size=batch_size): + try: + yield from self.component.pipe( + mb, + batch_size=batch_size, + resolve_text=True, + ) + except Exception as e: + # Log the error and return the unprocessed documents + logging.error(f"Error in SafeFastCoref pipe: {e}") + for doc in mb: + doc._.resolved_text = doc.text + yield doc # Return the original document def __call__(self, doc): """Wrap the __call__ method of the component.""" From 1f1238203a1628000e5c88fe069b831d20420223 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Wed, 27 Nov 2024 18:14:36 +0100 Subject: [PATCH 19/29] Union notation instead of | for compatibility with Python 3.9 --- src/conspiracies/corpusprocessing/clustering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/conspiracies/corpusprocessing/clustering.py b/src/conspiracies/corpusprocessing/clustering.py index 991598a..587e126 100644 --- a/src/conspiracies/corpusprocessing/clustering.py +++ b/src/conspiracies/corpusprocessing/clustering.py @@ -248,9 +248,9 @@ def _cluster_via_normalization( @staticmethod def _mapping_to_first_member( - clusters: List[List[TripletField | str]], + clusters: List[List[Union[TripletField, str]]], ) -> Dict[str, str]: - def get_text(member: TripletField | str): + def get_text(member: Union[TripletField, str]): if isinstance(member, TripletField): return member.text else: From c3804e3d5dd92e7f20d83e402c340a27a16168c4 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Wed, 4 Dec 2024 10:13:09 +0100 Subject: [PATCH 20/29] Fixing reading/writing DocBins, fixing safe fastcoref pipe and introducing subclustering --- config/eschatology.toml | 2 +- .../corpusprocessing/clustering.py | 65 +++++++++++-------- .../docprocessing/coref/safefastcoref.py | 30 ++++++--- src/conspiracies/docprocessing/doc_utils.py | 3 + .../docprocessing/docprocessor.py | 54 ++++++++++----- .../multi2oie/multi2oie_component.py | 2 +- src/conspiracies/run.py | 4 +- 7 files changed, 102 insertions(+), 58 deletions(-) diff --git a/config/eschatology.toml b/config/eschatology.toml index 6ac362f..7cc9800 100644 --- a/config/eschatology.toml +++ b/config/eschatology.toml @@ -2,7 +2,7 @@ language = "en" [preprocessing] -enabled = false +enabled = true doc_type = "csv" [preprocessing.extra] diff --git a/src/conspiracies/corpusprocessing/clustering.py b/src/conspiracies/corpusprocessing/clustering.py index 587e126..2d585d3 100644 --- a/src/conspiracies/corpusprocessing/clustering.py +++ b/src/conspiracies/corpusprocessing/clustering.py @@ -9,6 +9,7 @@ from hdbscan import HDBSCAN from pydantic import BaseModel from sentence_transformers import SentenceTransformer +from tqdm import tqdm from umap import UMAP from conspiracies.common.modelchoice import ModelChoice @@ -103,24 +104,27 @@ def _combine_clusters( return merged_clusters - def _cluster( + def _cluster_via_embeddings( self, - fields: List[TripletField], - cache_filename: str, + labels: List[str], + cache_name: str = None, + show_progress: bool = True, ): - emb_cache = f"embeddings-{cache_filename}.npy" - if self.cache_location and Path(self.cache_location, emb_cache).exists(): + emb_cache = ( + Path(self.cache_location, f"embeddings-{cache_name}.npy") + if self.cache_location and cache_name + else None + ) + if emb_cache and emb_cache.exists(): print( "Reusing cached embeddings! 
Delete cache if this is not supposed to happen.", ) - embeddings = np.load( - Path(self.cache_location, emb_cache), - ) + embeddings = np.load(emb_cache) else: model = self._get_embedding_model() print("Creating embeddings:") - counter = Counter((field.text for field in fields)) + counter = Counter((field for field in labels)) condensed = [ field for field, count in counter.items() @@ -129,26 +133,25 @@ def _cluster( embeddings = model.encode( condensed, normalize_embeddings=True, - show_progress_bar=True, + show_progress_bar=show_progress, ) - if self.cache_location: - np.save( - Path(self.cache_location, emb_cache), - embeddings, - ) + if emb_cache: + np.save(emb_cache, embeddings) if self.n_dimensions is not None: reduced_emb_cache = ( - f"embeddings-{cache_filename}-red{self.n_dimensions}.npy" + Path( + self.cache_location, + f"embeddings-{cache_name}-red{self.n_dimensions}.npy", + ) + if self.cache_location and cache_name + else None ) - if ( - self.cache_location - and Path(self.cache_location, reduced_emb_cache).exists() - ): + if reduced_emb_cache and reduced_emb_cache.exists(): print( "Reusing cached reduced embeddings! Delete cache if this is not supposed to happen.", ) - embeddings = np.load(Path(self.cache_location, reduced_emb_cache)) + embeddings = np.load(reduced_emb_cache) else: print("Reducing embedding space ...") reducer = UMAP( @@ -157,23 +160,19 @@ def _cluster( ) embeddings = reducer.fit_transform(embeddings) if self.cache_location: - np.save( - Path(self.cache_location, reduced_emb_cache), - embeddings, - ) + np.save(reduced_emb_cache, embeddings) print("Clustering ... (Delete cache to ensure recalculation)") hdbscan_model = HDBSCAN( min_cluster_size=self.min_cluster_size, max_cluster_size=20, # somewhat arbitrary number, mostly to avoid mega clusters that suck up everything min_samples=self.min_samples, - memory=str(self.cache_location), ) hdbscan_model.fit(embeddings) clusters = defaultdict(list) for field, embedding, label, probability in zip( - fields, + labels, embeddings, hdbscan_model.labels_, hdbscan_model.probabilities_, @@ -185,7 +184,7 @@ def _cluster( merged = self._combine_clusters( list(clusters.values()), - get_combine_key=lambda t: t[0].text, + get_combine_key=lambda t: t[0], ) # sort by how "prototypical" a member is in the cluster @@ -279,6 +278,16 @@ def create_mappings(self, triplets: List[Triplet]) -> Mappings: [e.text for e in entities], 0.2, ) + entity_clusters = [ + sub_cluster + for cluster in tqdm(entity_clusters, desc="Creating sub-clusters") + for sub_cluster in ( + self._cluster_via_embeddings(cluster, show_progress=False) + if len(cluster) > 10 + else [cluster] + ) + ] + print("Creating mappings for predicates") predicate_clusters = self._cluster_via_normalization( [p.text for p in predicates], diff --git a/src/conspiracies/docprocessing/coref/safefastcoref.py b/src/conspiracies/docprocessing/coref/safefastcoref.py index 993aada..318ca03 100644 --- a/src/conspiracies/docprocessing/coref/safefastcoref.py +++ b/src/conspiracies/docprocessing/coref/safefastcoref.py @@ -6,7 +6,10 @@ import logging from spacy.util import minibatch +from datasets.utils.logging import disable_progress_bar + +disable_progress_bar() # annoying progress bar per batch logging.getLogger("fastcoref").setLevel(logging.WARNING) @@ -19,17 +22,24 @@ def pipe(self, stream: Iterable, batch_size: int = 128): """Wrap the pipe method of the component.""" for mb in minibatch(stream, size=batch_size): try: - yield from self.component.pipe( - mb, - batch_size=batch_size, - 
resolve_text=True, + # The pipe method can fail on one document in a loop and thereby fail on all docs in that + # minibatch. However, it is made as a generator and may not show before long after the first + # documents have passed through the whole pipeline. Therefore, the minibatch is processed fully + # and then yielded. If it fails, they will be processed individually. + annotated = list( + self.component.pipe( + mb, + batch_size=batch_size, + resolve_text=True, + ), ) except Exception as e: # Log the error and return the unprocessed documents - logging.error(f"Error in SafeFastCoref pipe: {e}") - for doc in mb: - doc._.resolved_text = doc.text - yield doc # Return the original document + logging.error( + f"Error in SafeFastCoref pipe: {e}. Trying documents individually", + ) + annotated = [self(d) for d in mb] + yield from annotated def __call__(self, doc): """Wrap the __call__ method of the component.""" @@ -38,6 +48,8 @@ def __call__(self, doc): except Exception as e: # Log the error and return the original document logging.error(f"Error in SafeFastCoref __call__: {e}") + doc._.coref_clusters = [] + doc._.resolved_text = doc.text return doc @@ -49,7 +61,7 @@ def __call__(self, doc): "model_path": "biu-nlp/f-coref", # You can specify your own trained model path "device": None, # "cuda" or "cpu" None defaults to cuda "max_tokens_in_batch": 10000, - "enable_progress_bar": True, + "enable_progress_bar": False, }, ) def create_safe_fastcoref( diff --git a/src/conspiracies/docprocessing/doc_utils.py b/src/conspiracies/docprocessing/doc_utils.py index 48e1733..f3b916c 100644 --- a/src/conspiracies/docprocessing/doc_utils.py +++ b/src/conspiracies/docprocessing/doc_utils.py @@ -29,6 +29,9 @@ def _doc_to_json( timestamp = src_doc.timestamp.isoformat() else: raise TypeError(f"Unexpected input type {type(doc[1])}") + elif "doc_metadata" in doc.user_data: + id_ = doc.user_data["doc_metadata"]["id"] + timestamp = doc.user_data["doc_metadata"]["timestamp"] else: id_ = None timestamp = None diff --git a/src/conspiracies/docprocessing/docprocessor.py b/src/conspiracies/docprocessing/docprocessor.py index 8873f18..0d7c7cb 100644 --- a/src/conspiracies/docprocessing/docprocessor.py +++ b/src/conspiracies/docprocessing/docprocessor.py @@ -1,4 +1,5 @@ import json +import logging import os from glob import glob from pathlib import Path @@ -22,7 +23,6 @@ def _build_coref_pipeline(self): nlp_coref.add_pipe( "safe_fastcoref", config={ - "enable_progress_bar": False, "device": ( "cuda" if self.prefer_gpu_for_coref and torch.cuda.is_available() @@ -118,14 +118,7 @@ def __init__( self.triplet_extraction_component = triplet_extraction_method self.triplet_extraction_pipeline = self._build_triplet_extraction_pipeline() - @staticmethod - def _set_user_data_on_docs(docs: Iterator[Tuple[Doc, Document]]) -> Iterator[Doc]: - for doc, src_doc in docs: - # FIXME: this is kind of stupid, but with old pydantic this will have to work for now. 
- doc.user_data["doc_metadata"] = json.loads(src_doc.json()) - yield doc - - def _store_doc_bins(self, docs: Iterator[Doc], output_path: Path): + def _store_doc_bins(self, docs: Iterator[Tuple[Doc, Document]], output_path: Path): # FIXME: paths should be given elsewhere and not be inferred like this output_dir = Path(os.path.dirname(output_path)) / "spacy_docs" output_dir.mkdir(parents=True, exist_ok=True) @@ -141,14 +134,25 @@ def _store_doc_bins(self, docs: Iterator[Doc], output_path: Path): size = self.doc_bin_size doc_bin = DocBin(store_user_data=True) - for i, doc in enumerate(docs, start=start_from + 1): + at_doc = start_from + for doc, src_doc in docs: + at_doc += 1 + + # FIXME: this conversion is kind of stupid, but with old pydantic this will have to work for now. + doc.user_data["doc_metadata"] = json.loads(src_doc.json()) + doc_bin.add(doc) - if i % size == 0: - with open(output_dir / f"{i}.bin", "wb") as f: + if at_doc % size == 0: + with open(output_dir / f"{at_doc}.bin", "wb") as f: f.write(doc_bin.to_bytes()) doc_bin = DocBin(store_user_data=True) yield doc + if len(doc_bin) > 0: + # write final doc bin if any docs are left + with open(output_dir / f"{at_doc}.bin", "wb") as f: + f.write(doc_bin.to_bytes()) + def _read_doc_bins(self, output_path: Path): # FIXME: paths should be given elsewhere and not be inferred like this count = 0 @@ -159,8 +163,8 @@ def _read_doc_bins(self, output_path: Path): doc_bin = DocBin().from_bytes(bytes_data.read()) for doc in doc_bin.get_docs(self.triplet_extraction_pipeline.vocab): count += 1 - yield doc - print(f"Read {count} previously processed docs.") + src_doc = Document(**doc.user_data["doc_metadata"]) + yield doc, src_doc def process_docs( self, @@ -173,7 +177,23 @@ def process_docs( print( "Reading previously processed documents! 
Disable 'continue_from_last' to avoid this.'", ) - docs_to_jsonl(self._read_doc_bins(output_path), output_path) + processed_ids = set() + + def check_processed_ids_and_pass_on(): + for doc, src_doc in tqdm( + self._read_doc_bins(output_path), + desc="Reading previously processed docs", + ): + if src_doc.id in processed_ids: + logging.warning(f"Duplicate processed document: {src_doc.id}") + continue + processed_ids.add(src_doc.id) + yield doc, src_doc + + docs_to_jsonl(check_processed_ids_and_pass_on(), output_path) + + print(f"Read {len(processed_ids)} previously processed docs.") + docs = (d for d in docs if d.id not in processed_ids) # The coreference pipeline tends to choke on too large batches because of an # extreme memory pressure, hence the small batch size @@ -194,10 +214,8 @@ def process_docs( n_process=self.n_process, ) - with_user_data = self._set_user_data_on_docs(with_triplets) - docs_to_output = tqdm( - self._store_doc_bins(with_user_data, output_path), + self._store_doc_bins(with_triplets, output_path), desc="Processing documents", ) diff --git a/src/conspiracies/docprocessing/relationextraction/multi2oie/multi2oie_component.py b/src/conspiracies/docprocessing/relationextraction/multi2oie/multi2oie_component.py index 0574572..e3b4c57 100644 --- a/src/conspiracies/docprocessing/relationextraction/multi2oie/multi2oie_component.py +++ b/src/conspiracies/docprocessing/relationextraction/multi2oie/multi2oie_component.py @@ -89,7 +89,7 @@ def set_annotations(self, doc: Iterable[Doc], predictions: Dict) -> None: try: self.do_set_annotations(doc, predictions) except Exception as e: - self.logger.exception(e) + self.logger.error(e) def do_set_annotations(self, doc: Iterable[Doc], predictions: Dict) -> None: # get nested list of indices above confidence threshold diff --git a/src/conspiracies/run.py b/src/conspiracies/run.py index a747f50..83f02cc 100644 --- a/src/conspiracies/run.py +++ b/src/conspiracies/run.py @@ -65,7 +65,9 @@ logging.basicConfig( level=logging.DEBUG, filename=config.base.output_path + "/logfile", - filemode="w+", + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + filemode="a+", ) + logging.info("Running pipeline...") pipeline.run() From 004a8ed9c6088b0a6bfc5ceea32ed70941cbfcbc Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Wed, 4 Dec 2024 11:19:12 +0100 Subject: [PATCH 21/29] Deduplicating docs in doc bins --- config/eschatology.toml | 2 +- .../corpusprocessing/clustering.py | 4 +- .../docprocessing/docprocessor.py | 56 ++++++++++++++----- 3 files changed, 44 insertions(+), 18 deletions(-) diff --git a/config/eschatology.toml b/config/eschatology.toml index 7cc9800..6ac362f 100644 --- a/config/eschatology.toml +++ b/config/eschatology.toml @@ -2,7 +2,7 @@ language = "en" [preprocessing] -enabled = true +enabled = false doc_type = "csv" [preprocessing.extra] diff --git a/src/conspiracies/corpusprocessing/clustering.py b/src/conspiracies/corpusprocessing/clustering.py index 2d585d3..9863182 100644 --- a/src/conspiracies/corpusprocessing/clustering.py +++ b/src/conspiracies/corpusprocessing/clustering.py @@ -162,10 +162,10 @@ def _cluster_via_embeddings( if self.cache_location: np.save(reduced_emb_cache, embeddings) - print("Clustering ... 
(Delete cache to ensure recalculation)") hdbscan_model = HDBSCAN( min_cluster_size=self.min_cluster_size, - max_cluster_size=20, # somewhat arbitrary number, mostly to avoid mega clusters that suck up everything + max_cluster_size=self.min_cluster_size + * 10, # somewhat arbitrary, mostly to avoid mega clusters that suck up everything min_samples=self.min_samples, ) hdbscan_model.fit(embeddings) diff --git a/src/conspiracies/docprocessing/docprocessor.py b/src/conspiracies/docprocessing/docprocessor.py index 0d7c7cb..b8b5d39 100644 --- a/src/conspiracies/docprocessing/docprocessor.py +++ b/src/conspiracies/docprocessing/docprocessor.py @@ -117,6 +117,7 @@ def __init__( self.coref_pipeline = self._build_coref_pipeline() self.triplet_extraction_component = triplet_extraction_method self.triplet_extraction_pipeline = self._build_triplet_extraction_pipeline() + self.deduplicate_processed_docs = False def _store_doc_bins(self, docs: Iterator[Tuple[Doc, Document]], output_path: Path): # FIXME: paths should be given elsewhere and not be inferred like this @@ -166,32 +167,57 @@ def _read_doc_bins(self, output_path: Path): src_doc = Document(**doc.user_data["doc_metadata"]) yield doc, src_doc + def _read_deduplicated_doc_bins( + self, + output_path: Path, + processed_ids: set[str] = None, + ): + if processed_ids is None: + processed_ids = set() + + for doc, src_doc in tqdm( + self._read_doc_bins(output_path), + desc="Reading previously processed docs", + ): + if src_doc.id in processed_ids: + logging.warning(f"Duplicate processed document: {src_doc.id}") + continue + processed_ids.add(src_doc.id) + yield doc, src_doc + + def deduplicate_doc_bins(self, output_path: Path): + spacy_docs = Path(os.path.dirname(output_path)) / "spacy_docs" + old_docs = Path(os.path.dirname(output_path)) / ".old" / "spacy_docs" + old_docs.mkdir(parents=True, exist_ok=True) + orig_dir = spacy_docs.rename(old_docs) + deduplicated = spacy_docs + deduplicated.mkdir() + for _ in self._store_doc_bins( + self._read_deduplicated_doc_bins(orig_dir), + deduplicated, + ): + pass + def process_docs( self, docs: Iterable[Document], output_path: Path, continue_from_last=False, ): - + if self.deduplicate_processed_docs: + self.deduplicate_doc_bins(output_path) if continue_from_last: print( "Reading previously processed documents! 
Disable 'continue_from_last' to avoid this.'", ) processed_ids = set() - - def check_processed_ids_and_pass_on(): - for doc, src_doc in tqdm( - self._read_doc_bins(output_path), - desc="Reading previously processed docs", - ): - if src_doc.id in processed_ids: - logging.warning(f"Duplicate processed document: {src_doc.id}") - continue - processed_ids.add(src_doc.id) - yield doc, src_doc - - docs_to_jsonl(check_processed_ids_and_pass_on(), output_path) - + docs_to_jsonl( + self._read_deduplicated_doc_bins( + output_path, + processed_ids=processed_ids, + ), + output_path, + ) print(f"Read {len(processed_ids)} previously processed docs.") docs = (d for d in docs if d.id not in processed_ids) From c655792692e6938660eaa9ff4936627a49c52fba Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Wed, 4 Dec 2024 12:51:06 +0100 Subject: [PATCH 22/29] Big hackathon for the visualizer commenced --- visualizer/src/graph/EdgeInfo.tsx | 53 +++++++++++++++++++ visualizer/src/graph/GraphService.ts | 77 ++++++++++++++++++++-------- visualizer/src/graph/GraphViewer.tsx | 49 +++++++++++++----- 3 files changed, 147 insertions(+), 32 deletions(-) create mode 100644 visualizer/src/graph/EdgeInfo.tsx diff --git a/visualizer/src/graph/EdgeInfo.tsx b/visualizer/src/graph/EdgeInfo.tsx new file mode 100644 index 0000000..2b5a430 --- /dev/null +++ b/visualizer/src/graph/EdgeInfo.tsx @@ -0,0 +1,53 @@ +import {EdgeGroup} from "./GraphService"; +import React from "react"; + +export interface EdgeInfoProps { + edges: EdgeGroup; + className?: string; +} + +export const EdgeInfo: React.FC = ({ edges }: EdgeInfoProps) => { + return ( +
+    <div>
+      {edges.group!.map((e, i) => (
+        <div key={i}>
+          <h3>{e.label}</h3>
+          <p>
+            Frequency: {e.stats.frequency}
+          </p>
+          {/*
+            <p>
+              Norm. frequency: {e.stats.norm_frequency?.toPrecision(3)}
+            </p>
+          */}
+          {e.stats.first_occurrence && (
+            <p>
+              Earliest date: {e.stats.first_occurrence}
+            </p>
+          )}
+          {e.stats.last_occurrence && (
+            <p>
+              Latest date: {e.stats.last_occurrence}
+            </p>
+          )}
+          {e.stats.alt_labels && (
+            <div>
+              Alternative Labels:
+              <ul>
+                {e.stats.alt_labels.map((l) => (
+                  <li key={l}>{l}</li>
+                ))}
+              </ul>
+            </div>
+          )}
+          {e.stats.docs && (
+            <div>
+              Documents
+              <ul>
+                {e.stats.docs.map((d) => (
+                  <li key={d}>{d}</li>
+                ))}
+              </ul>
+            </div>
+          )}
+          {i < edges.group!.length - 1 && <hr />}
+        </div>
+      ))}
+    </div>
    + ); +}; diff --git a/visualizer/src/graph/GraphService.ts b/visualizer/src/graph/GraphService.ts index 9cede8a..a4d3e16 100644 --- a/visualizer/src/graph/GraphService.ts +++ b/visualizer/src/graph/GraphService.ts @@ -17,9 +17,13 @@ export interface EnrichedEdge extends Edge { stats: Stats; } +export interface EdgeGroup extends EnrichedEdge { + group?: EnrichedEdge[]; +} + export interface EnrichedGraphData extends GraphData { nodes: EnrichedNode[]; - edges: EnrichedEdge[]; + edges: EdgeGroup[]; } export class GraphFilter { @@ -86,23 +90,52 @@ export function filter( hasDateOverlap(node, filter), ); let filteredNodes = new Set(nodes.map((node) => node.id)); - let edges = graphData.edges.filter( - (edge: EnrichedEdge) => - edge.stats.frequency >= filter.minimumEdgeFrequency && - edge.stats.frequency < filter.maximumEdgeFrequency && - hasDateOverlap(edge, filter) && - filteredNodes.has(edge.from) && - filteredNodes.has(edge.to), - ); + let groupedEdges = graphData.edges + .filter( + (edge: EnrichedEdge) => + edge.stats.frequency >= filter.minimumEdgeFrequency && + edge.stats.frequency < filter.maximumEdgeFrequency && + hasDateOverlap(edge, filter) && + filteredNodes.has(edge.from) && + filteredNodes.has(edge.to), + ) + .reduce( + (acc, curr) => { + const key = curr.from + "->" + curr.to; + if (!acc[key]) { + acc[key] = []; + } + acc[key].push(curr); + return acc; + }, + {} as Record, + ); + let edges = Object.values(groupedEdges).map((group) => { + group.sort((edge1, edge2) => edge2.stats.frequency - edge1.stats.frequency); + const representative: EnrichedEdge = group.at(0)!; + return { + ...representative, + id: representative.from + '->' + representative.to, + label: group + .slice(0, 3) + .map((e) => e.label) + .join(", "), + width: + Math.log(group.map((e) => e.stats.frequency).reduce((a, b) => a + b)), + group: group + }; + }); + let connectedNodes = new Set(edges.flatMap((edge) => [edge.from, edge.to])); if (!filter.showUnconnectedNodes) { nodes = nodes.filter((node) => connectedNodes.has(node.id)); } nodes = nodes.map((node) => ({ ...node, - ...(node.label?.toLowerCase().includes(filter.labelSearch) - ? { opacity: 1 } - : { opacity: 0.2 }), + opacity: node.label?.toLowerCase().includes(filter.labelSearch) ? 
1 : 0.2, + font: { + size: 14 + node.stats.frequency / 100 + } })); return { nodes, edges }; @@ -116,6 +149,7 @@ export interface DataBounds { export abstract class GraphService { private nodesMap: Map | null = null; + private edgesMap: Map | null = null; abstract getGraph(): EnrichedGraphData; @@ -156,15 +190,18 @@ export abstract class GraphService { this.getGraph().nodes.map((node) => [node.id!.toString(), node]), ); } - - // highly inefficient linear search; overwrite for actual use - for (let node of this.getGraph().nodes) { - if (node.id === nodeId) { - return node; - } - } - return undefined; + return this.nodesMap.get(nodeId); } + + // getEdges(edgeFromAndTo: string): EnrichedEdge[] | undefined { + // if (this.edgesMap === null) { + // this.edgesMap = new Map( + // this.getGraph().edges.map((node) => [node.id!.toString(), node]), + // ); + // } + // + // return undefined; + // } } export class SampleGraphService extends GraphService { diff --git a/visualizer/src/graph/GraphViewer.tsx b/visualizer/src/graph/GraphViewer.tsx index dfcd07f..819f96e 100644 --- a/visualizer/src/graph/GraphViewer.tsx +++ b/visualizer/src/graph/GraphViewer.tsx @@ -1,5 +1,6 @@ import React, { useMemo, useRef, useState } from "react"; import { + EdgeGroup, EnrichedNode, FileGraphService, filter, @@ -12,6 +13,7 @@ import Graph, { GraphEvents, Options } from "react-vis-graph-wrapper"; import { GraphFilterControlPanel } from "./GraphFilterControlPanel"; import { GraphOptionsControlPanel } from "./GraphOptionsControlPanel"; import { NodeInfo } from "./NodeInfo"; +import { EdgeInfo } from "./EdgeInfo"; export const GraphViewer: React.FC = () => { let graphServiceRef = useRef(new SampleGraphService()); @@ -41,26 +43,40 @@ export const GraphViewer: React.FC = () => { const [graphFilter, setGraphFilter] = useState( new GraphFilter(1, 1, 10, 1, 1, 10), ); - const [selected, setSelected] = useState(new Set()); + const [subgraphNodes, setSubgraphNodes] = useState(new Set()); const [selectedNode, setSelectedNode] = useState( undefined, ); + const [selectedEdge, setSelectedEdge] = useState( + undefined, + ); const filteredGraphData = useMemo(() => { const baseGraphData = - selected.size > 0 - ? graphServiceRef.current.getSubGraph(selected) + subgraphNodes.size > 0 + ? 
graphServiceRef.current.getSubGraph(subgraphNodes) : graphServiceRef.current.getGraph(); return filter(graphFilter, baseGraphData); - }, [graphFilter, selected]); + }, [graphFilter, subgraphNodes]); + + const graphDataMaps = useMemo(() => { + return { + nodesMap: new Map( + filteredGraphData.nodes.map((node) => [node.id!.toString(), node]), + ), + edgeGroupMap: new Map( + filteredGraphData.edges.map((edgeGroup) => [edgeGroup.id, edgeGroup]), + ), + }; + }, [filteredGraphData]); let events: GraphEvents = { hold: ({ nodes }) => { - const newSelected = new Set(selected); + const newSubgraphNodes = new Set(subgraphNodes); nodes.forEach((element: string) => { - newSelected.delete(element); + newSubgraphNodes.delete(element); }); - setSelected(newSelected); + setSubgraphNodes(newSubgraphNodes); }, select: ({ nodes }) => { let newSelected: Set; @@ -69,24 +85,32 @@ export const GraphViewer: React.FC = () => { nodes.forEach((element: string) => { newSelected.add(element); }); - setSelected(newSelected); + setSubgraphNodes(newSelected); } }, doubleClick: ({ nodes }) => { - const newSelected = new Set(selected); + const newSelected = new Set(subgraphNodes); nodes.forEach((element: string) => { Array.from(graphServiceRef.current.getConnectedNodes(element)).forEach( (c) => newSelected.add(c), ); }); - setSelected(newSelected); + setSubgraphNodes(newSelected); }, selectNode: ({ nodes }) => { - setSelectedNode(graphServiceRef.current.getNode(nodes[0])); + setSelectedEdge(undefined); + setSelectedNode(graphDataMaps.nodesMap.get(nodes[0])); + }, + selectEdge: ({ edges }) => { + setSelectedNode(undefined); + setSelectedEdge(graphDataMaps.edgeGroupMap.get(edges[0])); }, deselectNode: () => { setSelectedNode(undefined); }, + deselectEdge: () => { + setSelectedEdge(undefined); + }, }; let [options, setOptions] = useState({ @@ -137,7 +161,7 @@ export const GraphViewer: React.FC = () => {
          {selectedNode && <NodeInfo node={selectedNode} />}
+          {selectedEdge && <EdgeInfo edges={selectedEdge} />}
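Aside on the GraphService change in this patch: the reworked filter collapses parallel edges between the same pair of nodes into one displayed edge. The following is a minimal standalone TypeScript sketch of that grouping idea; the SimpleEdge type and groupEdges function are assumed names for illustration only and are not code from the repository.

// Illustrative sketch only; names and types are assumptions, not repository code.
interface SimpleEdge {
  from: string;
  to: string;
  label: string;
  frequency: number;
}

function groupEdges(edges: SimpleEdge[]): SimpleEdge[] {
  const groups = new Map<string, SimpleEdge[]>();
  for (const edge of edges) {
    const key = `${edge.from}->${edge.to}`; // same key format as the patch: "from->to"
    const group = groups.get(key) ?? [];
    group.push(edge);
    groups.set(key, group);
  }
  return Array.from(groups.values()).map((group) => {
    // Highest-frequency labels first; the top three form the display label.
    group.sort((a, b) => b.frequency - a.frequency);
    const total = group.reduce((sum, e) => sum + e.frequency, 0);
    return {
      from: group[0].from,
      to: group[0].to,
      label: group.slice(0, 3).map((e) => e.label).join(", "),
      frequency: total, // the visualizer derives a log-scaled edge width from the summed frequency
    };
  });
}

Grouping like this keeps the rendered graph readable when many predicates connect the same two nodes: one edge is drawn per node pair, its label shows the most frequent predicates, and the combined frequency drives the edge width.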
From 7bb4b7c81b0221a01c1cd3bb62d4583b0155861e Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Thu, 5 Dec 2024 11:44:58 +0100 Subject: [PATCH 23/29] Hacking away for visualizer: now showing docs in a very ugly way --- .pre-commit-config.yaml | 5 ++ visualizer/src/App.tsx | 29 ++----- ...eUploadComp.tsx => JsonFileUploadComp.tsx} | 10 +-- .../src/datasources/NdjsonFileUploadComp.tsx | 54 ++++++++++++ visualizer/src/docs/DocService.ts | 71 ++++++++++++++++ visualizer/src/graph/GraphService.ts | 35 ++------ visualizer/src/graph/GraphViewer.tsx | 82 ++++++++----------- visualizer/src/graph/NodeInfo.tsx | 48 ----------- visualizer/src/graph/graph.css | 48 +++++------ visualizer/src/inspector/DocInfo.tsx | 18 ++++ .../src/{graph => inspector}/EdgeInfo.tsx | 7 +- visualizer/src/inspector/NodeInfo.tsx | 21 +++++ visualizer/src/inspector/StatsInfo.tsx | 53 ++++++++++++ .../src/service/ServiceContextProvider.tsx | 78 ++++++++++++++++++ 14 files changed, 381 insertions(+), 178 deletions(-) rename visualizer/src/datasources/{FileUploadComp.tsx => JsonFileUploadComp.tsx} (70%) create mode 100644 visualizer/src/datasources/NdjsonFileUploadComp.tsx create mode 100644 visualizer/src/docs/DocService.ts delete mode 100644 visualizer/src/graph/NodeInfo.tsx create mode 100644 visualizer/src/inspector/DocInfo.tsx rename visualizer/src/{graph => inspector}/EdgeInfo.tsx (90%) create mode 100644 visualizer/src/inspector/NodeInfo.tsx create mode 100644 visualizer/src/inspector/StatsInfo.tsx create mode 100644 visualizer/src/service/ServiceContextProvider.tsx diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 40d3d76..5960b93 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,3 +21,8 @@ repos: rev: v0.5.7 hooks: - id: ruff + + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v3.1.0 # Use the sha or tag you want to point at + hooks: + - id: prettier diff --git a/visualizer/src/App.tsx b/visualizer/src/App.tsx index fae17c9..bc4b8cd 100644 --- a/visualizer/src/App.tsx +++ b/visualizer/src/App.tsx @@ -1,27 +1,14 @@ import "./App.css"; -import {BrowserRouter, HashRouter, Link, Route, Routes} from "react-router-dom"; -import {GraphViewer} from "./graph/GraphViewer"; - -function NavBar() { - return ( -
- Graph Viewer -
- ); -} +import { GraphViewer } from "./graph/GraphViewer"; +import React from "react"; +import { ServiceContextProvider } from "./service/ServiceContextProvider"; export function App() { - // Create actual routes if/when more functionality is added to the application - return ( - - - } - /> - - - ); + return ( + + + + ); } export default App; diff --git a/visualizer/src/datasources/FileUploadComp.tsx b/visualizer/src/datasources/JsonFileUploadComp.tsx similarity index 70% rename from visualizer/src/datasources/FileUploadComp.tsx rename to visualizer/src/datasources/JsonFileUploadComp.tsx index ce52eb9..c879243 100644 --- a/visualizer/src/datasources/FileUploadComp.tsx +++ b/visualizer/src/datasources/JsonFileUploadComp.tsx @@ -1,12 +1,12 @@ import React from "react"; -interface FileUploadComponentProps { - onFileLoaded: (data: any ) => void; +interface JsonFileUploadComponentProps { + onFileLoaded: (data: any) => void; } -const FileUploadComponent: React.FC = ({ +const JsonFileUploadComponent: React.FC = ({ onFileLoaded, -}: FileUploadComponentProps) => { +}: JsonFileUploadComponentProps) => { const handleFileChange = (event: React.ChangeEvent) => { const file = event.target.files?.[0]; if (file) { @@ -29,4 +29,4 @@ const FileUploadComponent: React.FC = ({ ); }; -export default FileUploadComponent; +export default JsonFileUploadComponent; diff --git a/visualizer/src/datasources/NdjsonFileUploadComp.tsx b/visualizer/src/datasources/NdjsonFileUploadComp.tsx new file mode 100644 index 0000000..bbf9b90 --- /dev/null +++ b/visualizer/src/datasources/NdjsonFileUploadComp.tsx @@ -0,0 +1,54 @@ +import React from "react"; + +interface NdjsonFileUploadComponentProps { + onFileLoaded: (data: Generator) => void; +} + +const NdjsonFileUploadComponent: React.FC = ({ + onFileLoaded, +}: NdjsonFileUploadComponentProps) => { + const handleFileChange = (event: React.ChangeEvent) => { + const file = event.target.files?.[0]; + if (file) { + const reader = new FileReader(); + reader.onload = (e) => { + const text = e.target?.result; + if (typeof text === "string") { + // Create a generator for NDJSON parsing + const parseNDJSON = function* ( + input: string, + ): Generator { + const lines = input.split("\n"); + for (const line of lines) { + if (line.trim()) { + try { + yield JSON.parse(line); + } catch (error) { + console.error("Invalid JSON line:", line, error); + } + } + } + }; + + try { + // Test if the file is a single JSON object + JSON.parse(text); // Throws if it's not a single JSON + alert("This file is a standard JSON file. NDJSON is expected."); + } catch { + // If not, assume it's NDJSON and pass the generator + onFileLoaded(parseNDJSON(text)); + } + } + }; + reader.readAsText(file); + } + }; + + return ( +
+ +
+ ); +}; + +export default NdjsonFileUploadComponent; diff --git a/visualizer/src/docs/DocService.ts b/visualizer/src/docs/DocService.ts new file mode 100644 index 0000000..1d3731a --- /dev/null +++ b/visualizer/src/docs/DocService.ts @@ -0,0 +1,71 @@ +export abstract class DocService { + abstract getDocData(): Map; + + getDoc(id: string): Doc | undefined { + return this.getDocData().get(id); + } + + getDocs(ids: string[]): Doc[] { + return ids + .map((id) => this.getDoc(id)) + .filter((v): v is Doc => v !== undefined); + } +} + +export interface Doc { + id: string; + text: string; + timestamp: string; +} + +export class SampleDocService extends DocService { + readonly docData: Map = new Map( + [ + { + id: "1", + text: "sample text 1", + timestamp: "", + }, + { + id: "2", + text: "sample text 1", + timestamp: "", + }, + { + id: "3", + text: "sample text 1", + timestamp: "", + }, + ].map((d) => [d.id, d]), + ); + + getDocData(): Map { + return this.docData; + } + + getDoc(id: string): Doc | undefined { + return this.docData.get(id); + } +} + +export class FileDocService extends DocService { + readonly docData: Map; + + getDocData(): Map { + return this.docData; + } + + constructor(docData: Doc[]) { + super(); + this.docData = new Map( + docData.map((d) => [ + d.id, + { + id: d.id, + text: d.text, + timestamp: d.timestamp, + }, + ]), + ); + } +} diff --git a/visualizer/src/graph/GraphService.ts b/visualizer/src/graph/GraphService.ts index a4d3e16..ef23b7b 100644 --- a/visualizer/src/graph/GraphService.ts +++ b/visualizer/src/graph/GraphService.ts @@ -115,14 +115,15 @@ export function filter( const representative: EnrichedEdge = group.at(0)!; return { ...representative, - id: representative.from + '->' + representative.to, + id: representative.from + "->" + representative.to, label: group .slice(0, 3) .map((e) => e.label) .join(", "), - width: - Math.log(group.map((e) => e.stats.frequency).reduce((a, b) => a + b)), - group: group + width: Math.log( + group.map((e) => e.stats.frequency).reduce((a, b) => a + b), + ), + group: group, }; }); @@ -134,8 +135,8 @@ export function filter( ...node, opacity: node.label?.toLowerCase().includes(filter.labelSearch) ? 
1 : 0.2, font: { - size: 14 + node.stats.frequency / 100 - } + size: 14 + node.stats.frequency / 100, + }, })); return { nodes, edges }; @@ -148,9 +149,6 @@ export interface DataBounds { } export abstract class GraphService { - private nodesMap: Map | null = null; - private edgesMap: Map | null = null; - abstract getGraph(): EnrichedGraphData; getBounds(): DataBounds { @@ -183,25 +181,6 @@ export abstract class GraphService { .flatMap((edge) => [edge.from!.toString(), edge.to!.toString()]), ); } - - getNode(nodeId: string): EnrichedNode | undefined { - if (this.nodesMap === null) { - this.nodesMap = new Map( - this.getGraph().nodes.map((node) => [node.id!.toString(), node]), - ); - } - return this.nodesMap.get(nodeId); - } - - // getEdges(edgeFromAndTo: string): EnrichedEdge[] | undefined { - // if (this.edgesMap === null) { - // this.edgesMap = new Map( - // this.getGraph().edges.map((node) => [node.id!.toString(), node]), - // ); - // } - // - // return undefined; - // } } export class SampleGraphService extends GraphService { diff --git a/visualizer/src/graph/GraphViewer.tsx b/visualizer/src/graph/GraphViewer.tsx index 819f96e..043ee87 100644 --- a/visualizer/src/graph/GraphViewer.tsx +++ b/visualizer/src/graph/GraphViewer.tsx @@ -1,47 +1,35 @@ -import React, { useMemo, useRef, useState } from "react"; -import { - EdgeGroup, - EnrichedNode, - FileGraphService, - filter, - GraphFilter, - GraphService, - SampleGraphService, -} from "./GraphService"; -import FileUploadComponent from "../datasources/FileUploadComp"; +import React, { useMemo, useState } from "react"; +import { EdgeGroup, EnrichedNode, filter, GraphFilter } from "./GraphService"; import Graph, { GraphEvents, Options } from "react-vis-graph-wrapper"; import { GraphFilterControlPanel } from "./GraphFilterControlPanel"; import { GraphOptionsControlPanel } from "./GraphOptionsControlPanel"; -import { NodeInfo } from "./NodeInfo"; -import { EdgeInfo } from "./EdgeInfo"; +import { NodeInfo } from "../inspector/NodeInfo"; +import { EdgeInfo } from "../inspector/EdgeInfo"; +import { useServiceContext } from "../service/ServiceContextProvider"; + +export interface GraphViewerProps {} export const GraphViewer: React.FC = () => { - let graphServiceRef = useRef(new SampleGraphService()); + const { getGraphService, getDocService } = useServiceContext(); - const handleFileLoaded = (data: any) => { - graphServiceRef.current = new FileGraphService(data); - const top50 = - graphServiceRef.current - .getGraph() - .nodes.map((n) => n.stats.frequency) - .sort((a, b) => b - a) - .at(100) || 1; - let { minNodeFrequency, maxNodeFrequency, maxEdgeFrequency } = - graphServiceRef.current.getBounds(); - setGraphFilter( - new GraphFilter( - minNodeFrequency, - top50, - maxNodeFrequency, - 1, - Math.floor(top50 / 10), - maxEdgeFrequency, - ), - ); - }; + const top50 = + getGraphService() + .getGraph() + .nodes.map((n) => n.stats.frequency) + .sort((a, b) => b - a) + .at(100) || 1; + let { minNodeFrequency, maxNodeFrequency, maxEdgeFrequency } = + getGraphService().getBounds(); const [graphFilter, setGraphFilter] = useState( - new GraphFilter(1, 1, 10, 1, 1, 10), + new GraphFilter( + minNodeFrequency, + top50, + maxNodeFrequency, + 1, + Math.floor(top50 / 10), + maxEdgeFrequency, + ), ); const [subgraphNodes, setSubgraphNodes] = useState(new Set()); const [selectedNode, setSelectedNode] = useState( @@ -54,8 +42,8 @@ export const GraphViewer: React.FC = () => { const filteredGraphData = useMemo(() => { const baseGraphData = subgraphNodes.size > 0 - ? 
graphServiceRef.current.getSubGraph(subgraphNodes) - : graphServiceRef.current.getGraph(); + ? getGraphService().getSubGraph(subgraphNodes) + : getGraphService().getGraph(); return filter(graphFilter, baseGraphData); }, [graphFilter, subgraphNodes]); @@ -91,19 +79,21 @@ export const GraphViewer: React.FC = () => { doubleClick: ({ nodes }) => { const newSelected = new Set(subgraphNodes); nodes.forEach((element: string) => { - Array.from(graphServiceRef.current.getConnectedNodes(element)).forEach( - (c) => newSelected.add(c), + Array.from(getGraphService().getConnectedNodes(element)).forEach((c) => + newSelected.add(c), ); }); setSubgraphNodes(newSelected); }, - selectNode: ({ nodes }) => { + selectNode: ({ nodes, edges }) => { setSelectedEdge(undefined); setSelectedNode(graphDataMaps.nodesMap.get(nodes[0])); }, - selectEdge: ({ edges }) => { - setSelectedNode(undefined); - setSelectedEdge(graphDataMaps.edgeGroupMap.get(edges[0])); + selectEdge: ({ nodes, edges }) => { + if (nodes.length < 1) { + setSelectedNode(undefined); + setSelectedEdge(graphDataMaps.edgeGroupMap.get(edges[0])); + } }, deselectNode: () => { setSelectedNode(undefined); @@ -130,10 +120,6 @@ export const GraphViewer: React.FC = () => { return (
-
- -
-
= ({ - node, - className, -}: NodeInfoProps) => { - const stats = node.stats; - return ( -
- {node.label} -
-
-

Frequency: {stats.frequency}

- {/*

Norm. frequency: {stats.norm_frequency?.toPrecision(3)}

*/} - {stats.first_occurrence && ( -

Earliest date: {stats.first_occurrence}

- )} - {stats.last_occurrence &&

Latest date: {stats.last_occurrence}

} - {stats.alt_labels && ( -
- Alternative Labels: -
    - {stats.alt_labels.map((l) => ( -
  • {l}
  • - ))} -
-
- )} - {stats.docs && ( -
- Documents -
    - {stats.docs.map((d) => ( -
  • {d}
  • - ))} -
-
- )} -
-
- ); -}; diff --git a/visualizer/src/graph/graph.css b/visualizer/src/graph/graph.css index 3f64b8d..568600d 100644 --- a/visualizer/src/graph/graph.css +++ b/visualizer/src/graph/graph.css @@ -1,42 +1,44 @@ .padded { - padding: 5px; + padding: 5px; } .flex-container { - display: flex; - align-items: center; - margin-top: 5px; - margin-bottom: 5px; + display: flex; + align-items: center; + margin-top: 5px; + margin-bottom: 5px; } .flex-container__element { - display: flex; - align-items: center; - margin-left: 20px; - + display: flex; + align-items: center; + margin-left: 20px; } .flex-container__element:first-child { - margin-left: 0; + margin-left: 0; } .flex-container__element__sub-element { - margin: 2px; + margin: 2px; } .node-info { - position: absolute; - z-index: 3; - background: white; - border: solid 1px gray; - font-size: small; - max-width: 250px; - padding: 5px; - margin: 2px; + position: absolute; + z-index: 3; + background: white; + border: solid 1px gray; + font-size: small; + width: 20%; + max-width: 500px; + max-height: 80%; + overflow-y: scroll; + padding: 5px; + margin: 2px; } - .graph-container { - height: 80vh; - border: 1px inset; -} \ No newline at end of file + height: 80vh; + max-height: 80vh; + border: 1px inset; +} diff --git a/visualizer/src/inspector/DocInfo.tsx b/visualizer/src/inspector/DocInfo.tsx new file mode 100644 index 0000000..401414c --- /dev/null +++ b/visualizer/src/inspector/DocInfo.tsx @@ -0,0 +1,18 @@ +import React from "react"; +import { Doc } from "../docs/DocService"; + +export interface DocInfoProps { + document: Doc; +} + +export const DocInfo: React.FC = ({ document }) => { + return ( +
+

{document.id}

+ {document.timestamp} +

{document.text}

+
+ ); +}; diff --git a/visualizer/src/graph/EdgeInfo.tsx b/visualizer/src/inspector/EdgeInfo.tsx similarity index 90% rename from visualizer/src/graph/EdgeInfo.tsx rename to visualizer/src/inspector/EdgeInfo.tsx index 2b5a430..bd06bdb 100644 --- a/visualizer/src/graph/EdgeInfo.tsx +++ b/visualizer/src/inspector/EdgeInfo.tsx @@ -1,4 +1,4 @@ -import {EdgeGroup} from "./GraphService"; +import { EdgeGroup } from "../graph/GraphService"; import React from "react"; export interface EdgeInfoProps { @@ -8,10 +8,7 @@ export interface EdgeInfoProps { export const EdgeInfo: React.FC = ({ edges }: EdgeInfoProps) => { return ( -
+
{edges.group!.map((e, i) => (
{e.label} diff --git a/visualizer/src/inspector/NodeInfo.tsx b/visualizer/src/inspector/NodeInfo.tsx new file mode 100644 index 0000000..631fc12 --- /dev/null +++ b/visualizer/src/inspector/NodeInfo.tsx @@ -0,0 +1,21 @@ +import { EnrichedNode } from "../graph/GraphService"; +import React from "react"; +import { StatsInfo } from "./StatsInfo"; + +export interface NodeInfoProps { + node: EnrichedNode; + className?: string; +} + +export const NodeInfo: React.FC = ({ + node, + className, +}: NodeInfoProps) => { + return ( +
+ {node.label} +
+ +
+ ); +}; diff --git a/visualizer/src/inspector/StatsInfo.tsx b/visualizer/src/inspector/StatsInfo.tsx new file mode 100644 index 0000000..e9b75e9 --- /dev/null +++ b/visualizer/src/inspector/StatsInfo.tsx @@ -0,0 +1,53 @@ +import { DocInfo } from "./DocInfo"; +import React from "react"; +import { Stats } from "../graph/GraphService"; +import { useServiceContext } from "../service/ServiceContextProvider"; + +export interface StatsInfoProps { + stats: Stats; +} + +export const StatsInfo: React.FC = ({ stats }) => { + const { getDocService } = useServiceContext(); + + return ( +
+

Frequency: {stats.frequency}

+ {/*

Norm. frequency: {stats.norm_frequency?.toPrecision(3)}

*/} + {stats.first_occurrence &&

Earliest date: {stats.first_occurrence}

} + {stats.last_occurrence &&

Latest date: {stats.last_occurrence}

} + {stats.alt_labels && ( +
+ Alternative Labels: +
    + {stats.alt_labels.map((l) => ( +
  • {l}
  • + ))} +
+
+ )} + {stats.docs && ( +
+ Documents + {!getDocService() && ( +
    + {stats.docs.map((d) => ( +
  • {d}
  • + ))} +
+ )} + + {getDocService() && ( +
+ {getDocService() + .getDocs(stats.docs) + .map((d) => ( + + ))} +
+ )} +
+ )} +
+ ); +}; diff --git a/visualizer/src/service/ServiceContextProvider.tsx b/visualizer/src/service/ServiceContextProvider.tsx new file mode 100644 index 0000000..b8dd35d --- /dev/null +++ b/visualizer/src/service/ServiceContextProvider.tsx @@ -0,0 +1,78 @@ +import React, { + createContext, + PropsWithChildren, + useContext, + useState, +} from "react"; +import { DocService, FileDocService } from "../docs/DocService"; +import { FileGraphService, GraphService } from "../graph/GraphService"; +import JsonFileUploadComponent from "../datasources/JsonFileUploadComp"; +import NdjsonFileUploadComponent from "../datasources/NdjsonFileUploadComp"; + +interface Services { + getGraphService: () => GraphService; + setGraphService: (service: GraphService) => void; + getDocService: () => DocService; + setDocService: (service: DocService) => void; +} + +const ServiceContext = createContext(undefined); + +export const ServiceContextProvider: React.FC = ({ + children, +}) => { + const [graphService, setGraphService] = useState( + undefined, + ); + const [docService, setDocService] = useState( + undefined, + ); + + const value: Services = { + getGraphService: () => { + if (!graphService) { + throw new Error("DocService has not been initialized!"); + } + return graphService; + }, + setGraphService: (service: GraphService) => setGraphService(service), + getDocService: () => { + if (!docService) { + throw new Error("DocService has not been initialized!"); + } + return docService; + }, + setDocService: (service: DocService) => setDocService(service), + }; + + if (!graphService || !docService) { + const handleGraphFileLoaded = (data: any) => { + setGraphService(new FileGraphService(data)); + }; + + const handleDocsFileLoaded = (data: any) => { + setDocService(new FileDocService(data)); + }; + + return ( +
+ Load graph:  + + Load documents:  + +
+ ); + } + + return ( + {children} + ); +}; + +export const useServiceContext = (): Services => { + const context = useContext(ServiceContext); + if (!context) { + throw new Error("useServiceContext must be used within a ServiceProvider"); + } + return context; +}; From 0ea45b5cd27143239382a94838729fab3476c0b3 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Fri, 6 Dec 2024 11:15:05 +0100 Subject: [PATCH 24/29] Only targeting visualizer files with prettier --- .pre-commit-config.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5960b93..c2efc26 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,6 +23,7 @@ repos: - id: ruff - repo: https://github.com/pre-commit/mirrors-prettier - rev: v3.1.0 # Use the sha or tag you want to point at + rev: v3.1.0 hooks: - id: prettier + files: "visualizer/.*" From 4a86104f402138b4a084821455bf82c61a2f9cb8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Dec 2024 10:21:42 +0000 Subject: [PATCH 25/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- visualizer/README.md | 2 +- visualizer/electron-main.js | 1 - visualizer/public/index.html | 7 +- visualizer/public/manifest.json | 3 +- .../src/common/LogarithmicRangeSlider.tsx | 26 ++-- .../src/graph/GraphOptionsControlPanel.tsx | 135 ++++++++++-------- visualizer/src/index.css | 6 +- visualizer/src/index.tsx | 6 +- visualizer/tsconfig.json | 10 +- 9 files changed, 99 insertions(+), 97 deletions(-) diff --git a/visualizer/README.md b/visualizer/README.md index 22bfb4f..940b5c2 100644 --- a/visualizer/README.md +++ b/visualizer/README.md @@ -7,4 +7,4 @@ It works as any React app in terms of building, running etc. If you just want to 1. Ensure that you have Node >=16 and npm installed. See https://docs.npmjs.com/downloading-and-installing-node-js-and-npm. 2. From the directory `conspiracies/visualizer`, run `npm install`. 3. Then run `npm start` which will open a development server on `localhost:3000`. -4. Load in the file `graph.json` from your output via the GUI. \ No newline at end of file +4. Load in the file `graph.json` from your output via the GUI. 
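The ServiceContextProvider added above gates the whole visualizer on its two services: until both a graph file and a documents file have been loaded it renders only the upload components, and any component that calls useServiceContext outside the provider throws. A minimal sketch of a consuming component follows; the DocCount component, its placement next to App.tsx and its markup are assumptions for illustration, not code from these patches.

    import React from "react";
    import { useServiceContext } from "./service/ServiceContextProvider";

    // Hypothetical consumer for illustration: shows how many documents were
    // loaded. getDocService() throws if no documents file has been provided,
    // so this component must be rendered inside ServiceContextProvider.
    export const DocCount: React.FC = () => {
      const { getDocService } = useServiceContext();
      const count = getDocService().getDocData().size;
      return <span>{count} documents loaded</span>;
    };

Because the provider exposes getters that throw instead of returning undefined services, consumers like this one can stay free of null checks once they are mounted below the provider.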
diff --git a/visualizer/electron-main.js b/visualizer/electron-main.js index e1d363d..38213e9 100644 --- a/visualizer/electron-main.js +++ b/visualizer/electron-main.js @@ -1,7 +1,6 @@ const { app, BrowserWindow } = require("electron"); const path = require("path"); - function createWindow() { const mainWindow = new BrowserWindow(); diff --git a/visualizer/public/index.html b/visualizer/public/index.html index 32055ae..5ddbf33 100644 --- a/visualizer/public/index.html +++ b/visualizer/public/index.html @@ -1,12 +1,9 @@ - + - + Visualizer diff --git a/visualizer/public/manifest.json b/visualizer/public/manifest.json index d0c7108..96b3f5a 100644 --- a/visualizer/public/manifest.json +++ b/visualizer/public/manifest.json @@ -1,8 +1,7 @@ { "short_name": "Visualizer", "name": "Narrative Graphs Visualizer", - "icons": [ - ], + "icons": [], "start_url": ".", "display": "standalone", "theme_color": "#000000", diff --git a/visualizer/src/common/LogarithmicRangeSlider.tsx b/visualizer/src/common/LogarithmicRangeSlider.tsx index b766506..2ca6226 100644 --- a/visualizer/src/common/LogarithmicRangeSlider.tsx +++ b/visualizer/src/common/LogarithmicRangeSlider.tsx @@ -89,19 +89,19 @@ const LogarithmicRangeSlider: React.FC = ({ ]; return ( - + ); }; export default LogarithmicRangeSlider; diff --git a/visualizer/src/graph/GraphOptionsControlPanel.tsx b/visualizer/src/graph/GraphOptionsControlPanel.tsx index 3b374a4..afc9554 100644 --- a/visualizer/src/graph/GraphOptionsControlPanel.tsx +++ b/visualizer/src/graph/GraphOptionsControlPanel.tsx @@ -1,69 +1,84 @@ import React from "react"; -import './graph.css' -import {Options} from "react-vis-graph-wrapper"; - +import "./graph.css"; +import { Options } from "react-vis-graph-wrapper"; interface GraphOptionsControlPanelProps { - options: Options; - setOptions: React.Dispatch>; - + options: Options; + setOptions: React.Dispatch>; } function getSmoothEnabled(options: Options): boolean { - if (typeof options.edges?.smooth === 'boolean') { - return options.edges.smooth; - } else if (typeof options.edges?.smooth === 'object' && 'enabled' in options.edges.smooth) { - return options.edges.smooth.enabled; - } else { - return false; - } + if (typeof options.edges?.smooth === "boolean") { + return options.edges.smooth; + } else if ( + typeof options.edges?.smooth === "object" && + "enabled" in options.edges.smooth + ) { + return options.edges.smooth.enabled; + } else { + return false; + } } -export const GraphOptionsControlPanel = ({options, setOptions}: GraphOptionsControlPanelProps) => { - - - return
-
- Physics enabled: - setOptions( - { - ...options, - physics: { - ...options.physics, - enabled: event.target.checked - } - }) - }/> -
-
- Rounded edges: - setOptions( - { - ...options, - edges: { - ...options.edges, - smooth: !options.edges?.smooth - } - }) - }/> -
-
- Edge length: - setOptions( - { - ...options, - physics: { - ...options.physics, - barnesHut: { - springLength: Number(event.target.value) - } - } - }) - } - step="1"/> -
+export const GraphOptionsControlPanel = ({ + options, + setOptions, +}: GraphOptionsControlPanelProps) => { + return ( +
+
+ Physics enabled: + + setOptions({ + ...options, + physics: { + ...options.physics, + enabled: event.target.checked, + }, + }) + } + /> +
+
+ Rounded edges: + + setOptions({ + ...options, + edges: { + ...options.edges, + smooth: !options.edges?.smooth, + }, + }) + } + /> +
+
+ Edge length: + + setOptions({ + ...options, + physics: { + ...options.physics, + barnesHut: { + springLength: Number(event.target.value), + }, + }, + }) + } + step="1" + /> +
-} \ No newline at end of file + ); +}; diff --git a/visualizer/src/index.css b/visualizer/src/index.css index ec2585e..4a1df4d 100644 --- a/visualizer/src/index.css +++ b/visualizer/src/index.css @@ -1,13 +1,13 @@ body { margin: 0; - font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', - 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", "Roboto", "Oxygen", + "Ubuntu", "Cantarell", "Fira Sans", "Droid Sans", "Helvetica Neue", sans-serif; -webkit-font-smoothing: antialiased; -moz-osx-font-smoothing: grayscale; } code { - font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', + font-family: source-code-pro, Menlo, Monaco, Consolas, "Courier New", monospace; } diff --git a/visualizer/src/index.tsx b/visualizer/src/index.tsx index 77c175d..008f05d 100644 --- a/visualizer/src/index.tsx +++ b/visualizer/src/index.tsx @@ -4,8 +4,6 @@ import "./index.css"; import App from "./App"; const root = ReactDOM.createRoot( - document.getElementById("root") as HTMLElement -); -root.render( - + document.getElementById("root") as HTMLElement, ); +root.render(); diff --git a/visualizer/tsconfig.json b/visualizer/tsconfig.json index a273b0c..9d379a3 100644 --- a/visualizer/tsconfig.json +++ b/visualizer/tsconfig.json @@ -1,11 +1,7 @@ { "compilerOptions": { "target": "es5", - "lib": [ - "dom", - "dom.iterable", - "esnext" - ], + "lib": ["dom", "dom.iterable", "esnext"], "allowJs": true, "skipLibCheck": true, "esModuleInterop": true, @@ -20,7 +16,5 @@ "noEmit": true, "jsx": "react-jsx" }, - "include": [ - "src" - ] + "include": ["src"] } From a07c2cc4ac851ba1907fdfd8a3e7a857857651c9 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Fri, 6 Dec 2024 11:23:42 +0100 Subject: [PATCH 26/29] Showing documents with triplets marked up --- .../relationextraction/data_classes.py | 2 + visualizer/package-lock.json | 138 ++++++++++++++++++ visualizer/package.json | 2 + visualizer/src/docs/DocService.ts | 38 ++++- visualizer/src/graph/GraphViewer.tsx | 4 +- visualizer/src/inspector/DocInfo.tsx | 121 ++++++++++++++- visualizer/src/inspector/EdgeInfo.tsx | 33 +---- visualizer/src/inspector/NodeInfo.tsx | 2 +- visualizer/src/inspector/StatsInfo.tsx | 9 +- visualizer/src/inspector/docinfo.css | 29 ++++ 10 files changed, 328 insertions(+), 50 deletions(-) create mode 100644 visualizer/src/inspector/docinfo.css diff --git a/src/conspiracies/docprocessing/relationextraction/data_classes.py b/src/conspiracies/docprocessing/relationextraction/data_classes.py index 5370451..4ac5276 100644 --- a/src/conspiracies/docprocessing/relationextraction/data_classes.py +++ b/src/conspiracies/docprocessing/relationextraction/data_classes.py @@ -264,7 +264,9 @@ def span_to_json(span: Union[Span, Doc]) -> Dict[str, Any]: span = span[:] return { "text": span.text, + "start_char": span.start_char, "start": span.start, + "end_char": span.end_char, "end": span.end, } diff --git a/visualizer/package-lock.json b/visualizer/package-lock.json index 16425f6..d3726cc 100644 --- a/visualizer/package-lock.json +++ b/visualizer/package-lock.json @@ -14,10 +14,12 @@ "@types/jest": "^27.5.2", "@types/node": "^16.18.96", "@types/react-dom": "^18.2.24", + "draft-js": "^0.11.7", "multi-range-slider-react": "^2.0.7", "react": "^18.2.0", "react-dom": "^18.2.0", "react-graph-vis": "^1.0.7", + "react-highlight-within-textarea": "^3.2.2", "react-router-dom": "^6.22.3", "react-scripts": "5.0.1", "react-vis-graph-wrapper": "^0.1.3", @@ -6516,6 +6518,14 @@ 
"integrity": "sha512-+R08/oI0nl3vfPcqftZRpytksBXDzOUveBq/NBVx0sUp1axwzPQrKinNx5yd5sxPu8j1wIy8AfnVQ+5eFdha6Q==", "dev": true }, + "node_modules/cross-fetch": { + "version": "3.1.8", + "resolved": "https://registry.npmjs.org/cross-fetch/-/cross-fetch-3.1.8.tgz", + "integrity": "sha512-cvA+JwZoU0Xq+h6WkMvAUqPEYy92Obet6UdKLfW60qn99ftItKjB5T+BkyWOFWe2pUyfQ+IJHmpOTznqk1M6Kg==", + "dependencies": { + "node-fetch": "^2.6.12" + } + }, "node_modules/cross-spawn": { "version": "7.0.3", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", @@ -7417,6 +7427,20 @@ "resolved": "https://registry.npmjs.org/dotenv-expand/-/dotenv-expand-5.1.0.tgz", "integrity": "sha512-YXQl1DSa4/PQyRfgrv6aoNjhasp/p4qs9FjJ4q4cQk+8m4r6k4ZSiEyytKG8f8W9gi8WsQtIObNmKd+tMzNTmA==" }, + "node_modules/draft-js": { + "version": "0.11.7", + "resolved": "https://registry.npmjs.org/draft-js/-/draft-js-0.11.7.tgz", + "integrity": "sha512-ne7yFfN4sEL82QPQEn80xnADR8/Q6ALVworbC5UOSzOvjffmYfFsr3xSZtxbIirti14R7Y33EZC5rivpLgIbsg==", + "dependencies": { + "fbjs": "^2.0.0", + "immutable": "~3.7.4", + "object-assign": "^4.1.1" + }, + "peerDependencies": { + "react": ">=0.14.0", + "react-dom": ">=0.14.0" + } + }, "node_modules/duplexer": { "version": "0.1.2", "resolved": "https://registry.npmjs.org/duplexer/-/duplexer-0.1.2.tgz", @@ -8609,6 +8633,34 @@ "bser": "2.1.1" } }, + "node_modules/fbjs": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/fbjs/-/fbjs-2.0.0.tgz", + "integrity": "sha512-8XA8ny9ifxrAWlyhAbexXcs3rRMtxWcs3M0lctLfB49jRDHiaxj+Mo0XxbwE7nKZYzgCFoq64FS+WFd4IycPPQ==", + "dependencies": { + "core-js": "^3.6.4", + "cross-fetch": "^3.0.4", + "fbjs-css-vars": "^1.0.0", + "loose-envify": "^1.0.0", + "object-assign": "^4.1.0", + "promise": "^7.1.1", + "setimmediate": "^1.0.5", + "ua-parser-js": "^0.7.18" + } + }, + "node_modules/fbjs-css-vars": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/fbjs-css-vars/-/fbjs-css-vars-1.0.2.tgz", + "integrity": "sha512-b2XGFAFdWZWg0phtAWLHCk836A1Xann+I+Dgd3Gk64MHKZO44FfoD1KxyvbSh0qZsIoXQGGlVztIY+oitJPpRQ==" + }, + "node_modules/fbjs/node_modules/promise": { + "version": "7.3.1", + "resolved": "https://registry.npmjs.org/promise/-/promise-7.3.1.tgz", + "integrity": "sha512-nolQXZ/4L+bP/UGlkfaIujX9BKxGwmQ9OT4mOt5yvy8iK1h3wqTEJCijzGANTCCl9nWjY41juyAn2K3Q1hLLTg==", + "dependencies": { + "asap": "~2.0.3" + } + }, "node_modules/fd-slicer": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/fd-slicer/-/fd-slicer-1.1.0.tgz", @@ -9791,6 +9843,14 @@ "url": "https://opencollective.com/immer" } }, + "node_modules/immutable": { + "version": "3.7.6", + "resolved": "https://registry.npmjs.org/immutable/-/immutable-3.7.6.tgz", + "integrity": "sha512-AizQPcaofEtO11RZhPPHBOJRdo/20MKQF9mBLnVkBoyHi1/zXK8fzVdnEpSV9gxqtnh6Qomfp3F0xT5qP/vThw==", + "engines": { + "node": ">=0.8.0" + } + }, "node_modules/import-fresh": { "version": "3.3.0", "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.0.tgz", @@ -12160,6 +12220,44 @@ "tslib": "^2.0.3" } }, + "node_modules/node-fetch": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.7.0.tgz", + "integrity": "sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A==", + "dependencies": { + "whatwg-url": "^5.0.0" + }, + "engines": { + "node": "4.x || >=6.0.0" + }, + "peerDependencies": { + "encoding": "^0.1.0" + }, + "peerDependenciesMeta": { + "encoding": { + "optional": true + } + } + }, + 
"node_modules/node-fetch/node_modules/tr46": { + "version": "0.0.3", + "resolved": "https://registry.npmjs.org/tr46/-/tr46-0.0.3.tgz", + "integrity": "sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==" + }, + "node_modules/node-fetch/node_modules/webidl-conversions": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/webidl-conversions/-/webidl-conversions-3.0.1.tgz", + "integrity": "sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ==" + }, + "node_modules/node-fetch/node_modules/whatwg-url": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/whatwg-url/-/whatwg-url-5.0.0.tgz", + "integrity": "sha512-saE57nupxk6v3HY35+jzBwYa0rKSy0XR8JSxZPwgLr7ys0IBzhGviA1/TUGJLmSVqs8pb9AnvICXEuOHLprYTw==", + "dependencies": { + "tr46": "~0.0.3", + "webidl-conversions": "^3.0.0" + } + }, "node_modules/node-forge": { "version": "1.3.1", "resolved": "https://registry.npmjs.org/node-forge/-/node-forge-1.3.1.tgz", @@ -14550,6 +14648,16 @@ "integrity": "sha512-FULf7fayPdpASncVy4DLh3xydlXEJJpvIELjYjNeQWYUZ9pclcpvCZSr2gkmN2FrrGcI7G/cJsIEwk5/8vfXpg==", "deprecated": "Please upgrade to version 7 or higher. Older versions may use Math.random() in certain circumstances, which is known to be problematic. See https://v8.dev/blog/math-random for details." }, + "node_modules/react-highlight-within-textarea": { + "version": "3.2.2", + "resolved": "https://registry.npmjs.org/react-highlight-within-textarea/-/react-highlight-within-textarea-3.2.2.tgz", + "integrity": "sha512-pS+tPi6//dM8V154/0SfSqkx+0i6lKpSKazLZa7+RQjNQg0wKeCZBVkOGtxAhsVJy5KWpfIfdcpE8JpZ2Giz/g==", + "peerDependencies": { + "draft-js": ">=0.11.7", + "react": ">=0.14.0", + "react-dom": ">=0.14.0" + } + }, "node_modules/react-is": { "version": "17.0.2", "resolved": "https://registry.npmjs.org/react-is/-/react-is-17.0.2.tgz", @@ -15793,6 +15901,11 @@ "node": ">= 0.4" } }, + "node_modules/setimmediate": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/setimmediate/-/setimmediate-1.0.5.tgz", + "integrity": "sha512-MATJdZp8sLqDl/68LfQmbP8zKPLQNV6BIZoIgrscFDQ+RsvK/BxeDQOgyxKKoh0y/8h3BqVFnCqQ/gd+reiIXA==" + }, "node_modules/setprototypeof": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", @@ -17165,6 +17278,31 @@ "node": ">=4.2.0" } }, + "node_modules/ua-parser-js": { + "version": "0.7.39", + "resolved": "https://registry.npmjs.org/ua-parser-js/-/ua-parser-js-0.7.39.tgz", + "integrity": "sha512-IZ6acm6RhQHNibSt7+c09hhvsKy9WUr4DVbeq9U8o71qxyYtJpQeDxQnMrVqnIFMLcQjHO0I9wgfO2vIahht4w==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/ua-parser-js" + }, + { + "type": "paypal", + "url": "https://paypal.me/faisalman" + }, + { + "type": "github", + "url": "https://github.com/sponsors/faisalman" + } + ], + "bin": { + "ua-parser-js": "script/cli.js" + }, + "engines": { + "node": "*" + } + }, "node_modules/unbox-primitive": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/unbox-primitive/-/unbox-primitive-1.0.2.tgz", diff --git a/visualizer/package.json b/visualizer/package.json index c64ff2e..745725b 100644 --- a/visualizer/package.json +++ b/visualizer/package.json @@ -9,10 +9,12 @@ "@types/jest": "^27.5.2", "@types/node": "^16.18.96", "@types/react-dom": "^18.2.24", + "draft-js": "^0.11.7", "multi-range-slider-react": "^2.0.7", "react": "^18.2.0", "react-dom": "^18.2.0", "react-graph-vis": "^1.0.7", + "react-highlight-within-textarea": 
"^3.2.2", "react-router-dom": "^6.22.3", "react-scripts": "5.0.1", "react-vis-graph-wrapper": "^0.1.3", diff --git a/visualizer/src/docs/DocService.ts b/visualizer/src/docs/DocService.ts index 1d3731a..a017fe1 100644 --- a/visualizer/src/docs/DocService.ts +++ b/visualizer/src/docs/DocService.ts @@ -12,10 +12,25 @@ export abstract class DocService { } } +export interface TripletField { + text: string; + start_char: number; + start: number; + end_char: number; + end: number; +} + +export interface Triplet { + subject: TripletField; + predicate: TripletField; + object: TripletField; +} + export interface Doc { id: string; text: string; timestamp: string; + semantic_triplets: Triplet[]; } export class SampleDocService extends DocService { @@ -25,16 +40,19 @@ export class SampleDocService extends DocService { id: "1", text: "sample text 1", timestamp: "", + semantic_triplets: [], }, { id: "2", text: "sample text 1", timestamp: "", + semantic_triplets: [], }, { id: "3", text: "sample text 1", timestamp: "", + semantic_triplets: [], }, ].map((d) => [d.id, d]), ); @@ -58,14 +76,18 @@ export class FileDocService extends DocService { constructor(docData: Doc[]) { super(); this.docData = new Map( - docData.map((d) => [ - d.id, - { - id: d.id, - text: d.text, - timestamp: d.timestamp, - }, - ]), + docData + .filter((d) => d.semantic_triplets !== undefined) + .map((d) => [ + d.id, + { + id: d.id, + text: d.text, + timestamp: d.timestamp, + semantic_triplets: d.semantic_triplets, + }, + ]), ); + console.log(this.docData.size); } } diff --git a/visualizer/src/graph/GraphViewer.tsx b/visualizer/src/graph/GraphViewer.tsx index 043ee87..333d4d0 100644 --- a/visualizer/src/graph/GraphViewer.tsx +++ b/visualizer/src/graph/GraphViewer.tsx @@ -10,7 +10,7 @@ import { useServiceContext } from "../service/ServiceContextProvider"; export interface GraphViewerProps {} export const GraphViewer: React.FC = () => { - const { getGraphService, getDocService } = useServiceContext(); + const { getGraphService } = useServiceContext(); const top50 = getGraphService() @@ -45,7 +45,7 @@ export const GraphViewer: React.FC = () => { ? 
getGraphService().getSubGraph(subgraphNodes) : getGraphService().getGraph(); return filter(graphFilter, baseGraphData); - }, [graphFilter, subgraphNodes]); + }, [getGraphService, graphFilter, subgraphNodes]); const graphDataMaps = useMemo(() => { return { diff --git a/visualizer/src/inspector/DocInfo.tsx b/visualizer/src/inspector/DocInfo.tsx index 401414c..98b325d 100644 --- a/visualizer/src/inspector/DocInfo.tsx +++ b/visualizer/src/inspector/DocInfo.tsx @@ -1,18 +1,127 @@ -import React from "react"; -import { Doc } from "../docs/DocService"; +import React, { PropsWithChildren } from "react"; +import { Doc, Triplet } from "../docs/DocService"; +import HighlightWithinTextarea from "react-highlight-within-textarea"; +import "./docinfo.css"; + +const BlueHighlight: React.FC = (props) => { + return ( + + {props.children} + + ); +}; + +const GreenHighlight: React.FC = (props) => { + return ( + + {props.children} + + ); +}; + +const RedHighlight: React.FC = (props) => { + return ( + {props.children} + ); +}; + +interface HighlightedTextProps { + text: string; + triplets: Triplet[]; + highlightLabels: string[]; +} + +const HighlightedText: React.FC = ({ + text, + triplets, + highlightLabels, +}) => { + const subjects = []; + const highlightSubjects = []; + const predicates = []; + const highlightPredicates = []; + const objects = []; + const highlightObjects = []; + + for (let triplet of triplets) { + const subject = triplet.subject; + const subjectSpan = [subject.start_char, subject.end_char]; + if (highlightLabels.indexOf(subject.text) > -1) { + highlightSubjects.push(subjectSpan); + } else { + subjects.push(subjectSpan); + } + const predicate = triplet.predicate; + const predicateSpan = [predicate.start_char, predicate.end_char]; + if (highlightLabels.indexOf(predicate.text) > -1) { + highlightPredicates.push(predicateSpan); + } else { + predicates.push(predicateSpan); + } + const object = triplet.object; + const objectSpan = [object.start_char, object.end_char]; + if (highlightLabels.indexOf(object.text) > -1) { + highlightObjects.push(objectSpan); + } else { + objects.push(objectSpan); + } + } + + return ( + + ); +}; export interface DocInfoProps { document: Doc; + highlightLabels: string[]; } -export const DocInfo: React.FC = ({ document }) => { +export const DocInfo: React.FC = ({ + document, + highlightLabels, +}) => { return (
-

{document.id}

- {document.timestamp} -

{document.text}

+

+ {document.id} {document.timestamp} +

+
); }; diff --git a/visualizer/src/inspector/EdgeInfo.tsx b/visualizer/src/inspector/EdgeInfo.tsx index bd06bdb..f73425c 100644 --- a/visualizer/src/inspector/EdgeInfo.tsx +++ b/visualizer/src/inspector/EdgeInfo.tsx @@ -1,5 +1,6 @@ import { EdgeGroup } from "../graph/GraphService"; import React from "react"; +import { StatsInfo } from "./StatsInfo"; export interface EdgeInfoProps { edges: EdgeGroup; @@ -12,37 +13,7 @@ export const EdgeInfo: React.FC = ({ edges }: EdgeInfoProps) => { {edges.group!.map((e, i) => (
{e.label} -
-

Frequency: {e.stats.frequency}

- {/*

Norm. frequency: {e.stats.norm_frequency?.toPrecision(3)}

*/} - {e.stats.first_occurrence && ( -

Earliest date: {e.stats.first_occurrence}

- )} - {e.stats.last_occurrence && ( -

Latest date: {e.stats.last_occurrence}

- )} - {e.stats.alt_labels && ( -
- Alternative Labels: -
    - {e.stats.alt_labels.map((l) => ( -
  • {l}
  • - ))} -
-
- )} - {e.stats.docs && ( -
- Documents -
    - {e.stats.docs.map((d) => ( -
  • {d}
  • - ))} -
-
- )} -
- {i < edges.group!.length - 1 &&
} +
))}
diff --git a/visualizer/src/inspector/NodeInfo.tsx b/visualizer/src/inspector/NodeInfo.tsx index 631fc12..3785a5e 100644 --- a/visualizer/src/inspector/NodeInfo.tsx +++ b/visualizer/src/inspector/NodeInfo.tsx @@ -15,7 +15,7 @@ export const NodeInfo: React.FC = ({
{node.label}
- +
); }; diff --git a/visualizer/src/inspector/StatsInfo.tsx b/visualizer/src/inspector/StatsInfo.tsx index e9b75e9..8676b8d 100644 --- a/visualizer/src/inspector/StatsInfo.tsx +++ b/visualizer/src/inspector/StatsInfo.tsx @@ -4,10 +4,11 @@ import { Stats } from "../graph/GraphService"; import { useServiceContext } from "../service/ServiceContextProvider"; export interface StatsInfoProps { + label: string; stats: Stats; } -export const StatsInfo: React.FC = ({ stats }) => { +export const StatsInfo: React.FC = ({ label, stats }) => { const { getDocService } = useServiceContext(); return ( @@ -42,7 +43,11 @@ export const StatsInfo: React.FC = ({ stats }) => { {getDocService() .getDocs(stats.docs) .map((d) => ( - + ))}
)} diff --git a/visualizer/src/inspector/docinfo.css b/visualizer/src/inspector/docinfo.css new file mode 100644 index 0000000..f3ba462 --- /dev/null +++ b/visualizer/src/inspector/docinfo.css @@ -0,0 +1,29 @@ +.highlight-subject { + background: cyan; + opacity: 1; +} + +.subject { + background: cyan; + opacity: 0.3; +} + +.highlight-predicate { + background: lightgreen; + opacity: 1; +} + +.predicate { + background: lightgreen; + opacity: 0.3; +} + +.highlight-object { + background: yellow; + opacity: 1; +} + +.object { + background: yellow; + opacity: 0.3; +} From 585437387d0897fa27177d880c6b86d207ba7b00 Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Mon, 9 Dec 2024 12:55:29 +0100 Subject: [PATCH 27/29] Introducing database for entities, relations, triplets and documents --- config/eschatology.toml | 8 +- pyproject.toml | 3 +- .../corpusprocessing/clustering.py | 1 - src/conspiracies/corpusprocessing/triplet.py | 2 + src/conspiracies/database/__init__.py | 0 src/conspiracies/database/engine.py | 23 ++++ src/conspiracies/database/models.py | 103 ++++++++++++++++++ src/conspiracies/pipeline/config.py | 5 + src/conspiracies/pipeline/pipeline.py | 72 +++++++++++- 9 files changed, 211 insertions(+), 6 deletions(-) create mode 100644 src/conspiracies/database/__init__.py create mode 100644 src/conspiracies/database/engine.py create mode 100644 src/conspiracies/database/models.py diff --git a/config/eschatology.toml b/config/eschatology.toml index 6ac362f..d66f4f2 100644 --- a/config/eschatology.toml +++ b/config/eschatology.toml @@ -11,10 +11,14 @@ text_column = "body" timestamp_column = "timestamp" [docprocessing] -enabled = true +enabled = false batch_size = 50 prefer_gpu_for_coref = true n_process = 1 [corpusprocessing] -enabled = true \ No newline at end of file +enabled = false + +[databasepopulation] +enabled = true +clear_and_write = true \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 20540db..ad9e2b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,8 @@ dependencies = [ "stop-words", "bs4", "toml", - "fastcoref" + "fastcoref", + "sqlalchemy" ] [project.license] diff --git a/src/conspiracies/corpusprocessing/clustering.py b/src/conspiracies/corpusprocessing/clustering.py index 9863182..caa5b94 100644 --- a/src/conspiracies/corpusprocessing/clustering.py +++ b/src/conspiracies/corpusprocessing/clustering.py @@ -122,7 +122,6 @@ def _cluster_via_embeddings( embeddings = np.load(emb_cache) else: model = self._get_embedding_model() - print("Creating embeddings:") counter = Counter((field for field in labels)) condensed = [ diff --git a/src/conspiracies/corpusprocessing/triplet.py b/src/conspiracies/corpusprocessing/triplet.py index af809ba..8d2d9b3 100644 --- a/src/conspiracies/corpusprocessing/triplet.py +++ b/src/conspiracies/corpusprocessing/triplet.py @@ -12,6 +12,8 @@ class TripletField(BaseModel): text: str + start_char: Optional[int] + end_char: Optional[int] head: Optional[str] def clear_head_if_blacklist_match(self, blacklist: Set[str]): diff --git a/src/conspiracies/database/__init__.py b/src/conspiracies/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/conspiracies/database/engine.py b/src/conspiracies/database/engine.py new file mode 100644 index 0000000..d027632 --- /dev/null +++ b/src/conspiracies/database/engine.py @@ -0,0 +1,23 @@ +import logging +from pathlib import Path + +from sqlalchemy import create_engine, Engine +from sqlalchemy.orm import Session + +from conspiracies.database.models import 
Base + +logging.getLogger("sqlalchemy").setLevel("WARNING") + + +def get_engine(filepath: Path): + engine = create_engine("sqlite:///" + filepath.as_posix(), echo=True) + return engine + + +def setup_database(engine: Engine): + Base.metadata.create_all(engine) + + +def get_session(engine: Engine = None) -> Session: + session = Session(bind=engine) + return session diff --git a/src/conspiracies/database/models.py b/src/conspiracies/database/models.py new file mode 100644 index 0000000..81db72a --- /dev/null +++ b/src/conspiracies/database/models.py @@ -0,0 +1,103 @@ +from sqlalchemy import ( + Column, + Integer, + String, + ForeignKey, + Text, + DateTime, +) +from sqlalchemy.orm import declarative_base, relationship + +Base = declarative_base() + + +class EntityOrm(Base): + __tablename__ = "entities" + id = Column(Integer, primary_key=True, autoincrement=True) + label = Column(String, nullable=False, index=True) + supernode_id = Column(Integer, ForeignKey("entities.id"), nullable=True) + + # Relationships + supernode = relationship( + "EntityOrm", + back_populates="subnodes", + remote_side="EntityOrm.id", + foreign_keys="EntityOrm.supernode_id", + ) + subnodes = relationship("EntityOrm", back_populates="supernode") + + +class RelationOrm(Base): + __tablename__ = "relations" + id = Column(Integer, primary_key=True, autoincrement=True) + label = Column(String, nullable=False, index=True) + + +class TripletOrm(Base): + __tablename__ = "triplets" + id = Column(Integer, primary_key=True, autoincrement=True) + doc_id = Column(Integer, ForeignKey("docs.id"), nullable=False) + subject_entity_id = Column(Integer, ForeignKey("entities.id"), nullable=False) + predicate_relation_id = Column(Integer, ForeignKey("relations.id"), nullable=False) + object_entity_id = Column(Integer, ForeignKey("entities.id"), nullable=False) + subj_span_start = Column(Integer, nullable=True) + subj_span_end = Column(Integer, nullable=True) + pred_span_start = Column(Integer, nullable=True) + pred_span_end = Column(Integer, nullable=True) + obj_span_start = Column(Integer, nullable=True) + obj_span_end = Column(Integer, nullable=True) + + # Relationships + subject_entity = relationship( + "EntityOrm", + foreign_keys="TripletOrm.subject_entity_id", + ) + predicate_relation = relationship( + "RelationOrm", + foreign_keys="TripletOrm.predicate_relation_id", + ) + object_entity = relationship( + "EntityOrm", + foreign_keys="TripletOrm.object_entity_id", + ) + document = relationship("DocumentOrm", back_populates="triplets") + + # TODO: this should be here, but sometimes we see duplicates. Why? 
+ # __table_args__ = (UniqueConstraint( + # 'doc_id', + # 'subject_entity_id', + # 'predicate_relation_id', + # 'object_entity_id', + # name='unique_triplet_constraint' + # ),) + + +class DocumentOrm(Base): + __tablename__ = "docs" + id = Column(Integer, primary_key=True, autoincrement=True) + text = Column(Text, nullable=False) + orig_text = Column(Text, nullable=True) + timestamp = Column(DateTime) + + # Relationships + triplets = relationship("TripletOrm", back_populates="document") + + +def get_or_create_entity(label, session): + """Fetch an entity by label, or create it if it doesn't exist.""" + entity = session.query(EntityOrm).filter_by(label=label).first() + if not entity: + entity = EntityOrm(label=label) + session.add(entity) + session.flush() # Get the ID immediately + return entity.id + + +def get_or_create_relation(label, session): + """Fetch a relation by label, or create it if it doesn't exist.""" + relation = session.query(RelationOrm).filter_by(label=label).first() + if not relation: + relation = RelationOrm(label=label) + session.add(relation) + session.flush() # Get the ID immediately + return relation.id diff --git a/src/conspiracies/pipeline/config.py b/src/conspiracies/pipeline/config.py index e5dfd05..9328e59 100644 --- a/src/conspiracies/pipeline/config.py +++ b/src/conspiracies/pipeline/config.py @@ -48,6 +48,10 @@ def estimate_from_n_triplets(cls, n_triplets: int): return thresholds +class DatabasePopulationConfig(StepConfig): + clear_and_write: bool = False + + class CorpusProcessingConfig(StepConfig): dimensions: int = None n_neighbors: int = 15 @@ -60,6 +64,7 @@ class PipelineConfig(BaseModel): preprocessing: PreProcessingConfig docprocessing: DocProcessingConfig corpusprocessing: CorpusProcessingConfig + databasepopulation: DatabasePopulationConfig @staticmethod def update_nested_dict(d: dict[str, Any], path: str, value: Any) -> None: diff --git a/src/conspiracies/pipeline/pipeline.py b/src/conspiracies/pipeline/pipeline.py index 2162ccc..17713de 100644 --- a/src/conspiracies/pipeline/pipeline.py +++ b/src/conspiracies/pipeline/pipeline.py @@ -1,12 +1,19 @@ import json import os +from datetime import datetime from pathlib import Path - from conspiracies.common.fileutils import iter_lines_of_files from conspiracies.corpusprocessing.aggregation import TripletAggregator -from conspiracies.corpusprocessing.clustering import Clustering +from conspiracies.corpusprocessing.clustering import Clustering, Mappings from conspiracies.corpusprocessing.triplet import Triplet +from conspiracies.database.engine import get_engine, setup_database, get_session +from conspiracies.database.models import ( + get_or_create_entity, + get_or_create_relation, + TripletOrm, + DocumentOrm, +) from conspiracies.docprocessing.docprocessor import DocProcessor from conspiracies.document import Document from conspiracies.pipeline.config import PipelineConfig, Thresholds @@ -43,6 +50,9 @@ def run(self): if self.config.corpusprocessing.enabled: self.corpusprocessing() + if self.config.databasepopulation.enabled: + self.databasepopulation() + def _get_preprocessor(self) -> Preprocessor: config = self.config.preprocessing doc_type = config.doc_type.lower() @@ -151,3 +161,61 @@ def corpusprocessing(self): edges, save=self.output_path / "graph.png", ) + + def databasepopulation(self): + if self.config.databasepopulation.clear_and_write: + if os.path.exists(self.output_path / "database.db"): + print("Removing old database.") + os.remove(self.output_path / "database.db") + + print("Populating 
database.") + engine = get_engine(self.output_path / "database.db") + setup_database(engine) + session = get_session(engine) + + with open(self.output_path / "mappings.json") as mappings_file: + mappings = Mappings(**json.load(mappings_file)) + + with open(self.output_path / "triplets.ndjson") as triplets_file: + for line in triplets_file: + triplet = Triplet(**json.loads(line)) + subject_id = get_or_create_entity( + mappings.map_entity(triplet.subject.text), + session, + ) + relation_id = get_or_create_relation( + mappings.map_predicate(triplet.predicate.text), + session, + ) + object_id = get_or_create_entity( + mappings.map_entity(triplet.object.text), + session, + ) + + triplet_orm = TripletOrm( + doc_id=int(triplet.doc), + subject_entity_id=subject_id, + predicate_relation_id=relation_id, + object_entity_id=object_id, + subj_span_start=triplet.subject.start_char, + subj_span_end=triplet.subject.end_char, + pred_span_start=triplet.predicate.start_char, + pred_span_end=triplet.predicate.end_char, + obj_span_start=triplet.object.start_char, + obj_span_end=triplet.object.end_char, + ) + session.add(triplet_orm) + + session.commit() + + for doc in ( + json.loads(line) + for line in iter_lines_of_files(self.output_path / "annotations.ndjson") + ): + doc_orm = DocumentOrm( + id=doc["id"], + text=doc["text"], + timestamp=datetime.fromisoformat(doc["timestamp"]), + ) + session.add(doc_orm) + session.commit() From 5c44abf7c610b1dbd968b9ec99501acae731438c Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Mon, 9 Dec 2024 19:21:21 +0100 Subject: [PATCH 28/29] Fixing failing test after last commit --- tests/test_data/test_config.toml | 4 +++- tests/test_pipelineconfig.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_data/test_config.toml b/tests/test_data/test_config.toml index 133c7ff..7e9a244 100644 --- a/tests/test_data/test_config.toml +++ b/tests/test_data/test_config.toml @@ -14,4 +14,6 @@ some_extra_value = 1234 [docprocessing] triplet_extraction_method = "test" -[corpusprocessing] \ No newline at end of file +[corpusprocessing] + +[databasepopulation] \ No newline at end of file diff --git a/tests/test_pipelineconfig.py b/tests/test_pipelineconfig.py index d79f4b7..c6d2a38 100644 --- a/tests/test_pipelineconfig.py +++ b/tests/test_pipelineconfig.py @@ -8,6 +8,7 @@ PreProcessingConfig, DocProcessingConfig, CorpusProcessingConfig, + DatabasePopulationConfig, ) @@ -37,6 +38,7 @@ def test_config_loading(path: str): triplet_extraction_method="test", ), corpusprocessing=CorpusProcessingConfig(), + databasepopulation=DatabasePopulationConfig(), ) From 3156970dbd859136415bcabaa4aa3c97abc5934d Mon Sep 17 00:00:00 2001 From: Kasper Fyhn Date: Mon, 9 Dec 2024 20:00:39 +0100 Subject: [PATCH 29/29] Changing DB models slightly and optimizing DB population --- src/conspiracies/database/engine.py | 5 +- src/conspiracies/database/models.py | 87 +++++++++++++++++++-------- src/conspiracies/pipeline/pipeline.py | 46 +++++++++----- 3 files changed, 93 insertions(+), 45 deletions(-) diff --git a/src/conspiracies/database/engine.py b/src/conspiracies/database/engine.py index d027632..288bb8c 100644 --- a/src/conspiracies/database/engine.py +++ b/src/conspiracies/database/engine.py @@ -1,4 +1,3 @@ -import logging from pathlib import Path from sqlalchemy import create_engine, Engine @@ -6,11 +5,9 @@ from conspiracies.database.models import Base -logging.getLogger("sqlalchemy").setLevel("WARNING") - def get_engine(filepath: Path): - engine = create_engine("sqlite:///" + 
filepath.as_posix(), echo=True) + engine = create_engine("sqlite:///" + filepath.as_posix()) return engine diff --git a/src/conspiracies/database/models.py b/src/conspiracies/database/models.py index 81db72a..cf6b5d1 100644 --- a/src/conspiracies/database/models.py +++ b/src/conspiracies/database/models.py @@ -6,7 +6,7 @@ Text, DateTime, ) -from sqlalchemy.orm import declarative_base, relationship +from sqlalchemy.orm import declarative_base, relationship, Session Base = declarative_base() @@ -31,15 +31,26 @@ class RelationOrm(Base): __tablename__ = "relations" id = Column(Integer, primary_key=True, autoincrement=True) label = Column(String, nullable=False, index=True) + subject_id = Column(Integer, ForeignKey("entities.id"), nullable=True) + object_id = Column(Integer, ForeignKey("entities.id"), nullable=True) + + subject = relationship( + "EntityOrm", + foreign_keys="RelationOrm.subject_id", + ) + object = relationship( + "EntityOrm", + foreign_keys="RelationOrm.object_id", + ) class TripletOrm(Base): __tablename__ = "triplets" id = Column(Integer, primary_key=True, autoincrement=True) doc_id = Column(Integer, ForeignKey("docs.id"), nullable=False) - subject_entity_id = Column(Integer, ForeignKey("entities.id"), nullable=False) - predicate_relation_id = Column(Integer, ForeignKey("relations.id"), nullable=False) - object_entity_id = Column(Integer, ForeignKey("entities.id"), nullable=False) + subject_id = Column(Integer, ForeignKey("entities.id"), nullable=False) + relation_id = Column(Integer, ForeignKey("relations.id"), nullable=False) + object_id = Column(Integer, ForeignKey("entities.id"), nullable=False) subj_span_start = Column(Integer, nullable=True) subj_span_end = Column(Integer, nullable=True) pred_span_start = Column(Integer, nullable=True) @@ -50,17 +61,21 @@ class TripletOrm(Base): # Relationships subject_entity = relationship( "EntityOrm", - foreign_keys="TripletOrm.subject_entity_id", + foreign_keys="TripletOrm.subject_id", ) predicate_relation = relationship( "RelationOrm", - foreign_keys="TripletOrm.predicate_relation_id", + foreign_keys="TripletOrm.relation_id", ) object_entity = relationship( "EntityOrm", - foreign_keys="TripletOrm.object_entity_id", + foreign_keys="TripletOrm.object_id", + ) + document = relationship( + "DocumentOrm", + foreign_keys="TripletOrm.doc_id", + back_populates="triplets", ) - document = relationship("DocumentOrm", back_populates="triplets") # TODO: this should be here, but sometimes we see duplicates. Why? 
# __table_args__ = (UniqueConstraint( @@ -83,21 +98,41 @@ class DocumentOrm(Base): triplets = relationship("TripletOrm", back_populates="document") -def get_or_create_entity(label, session): - """Fetch an entity by label, or create it if it doesn't exist.""" - entity = session.query(EntityOrm).filter_by(label=label).first() - if not entity: - entity = EntityOrm(label=label) - session.add(entity) - session.flush() # Get the ID immediately - return entity.id - - -def get_or_create_relation(label, session): - """Fetch a relation by label, or create it if it doesn't exist.""" - relation = session.query(RelationOrm).filter_by(label=label).first() - if not relation: - relation = RelationOrm(label=label) - session.add(relation) - session.flush() # Get the ID immediately - return relation.id +class ModelLookupCache: + + def __init__(self, session: Session): + self._entities = {e.label: e for e in session.query(EntityOrm).all()} + self._relations = { + (int(r.subject_id), str(r.label), int(r.object_id)): r # noqa + for r in session.query(RelationOrm).all() + } + + def get_or_create_entity(self, label, session): + """Fetch an entity by label, or create it if it doesn't exist.""" + entity = self._entities.get(label, None) + if entity is None: + entity = EntityOrm(label=label) + session.add(entity) + session.flush() # Get the ID immediately + self._entities[label] = entity # noqa + return entity.id + + def get_or_create_relation( + self, + subject_id: int, + object_id: int, + label: str, + session: Session, + ): + """Fetch a relation by label, or create it if it doesn't exist.""" + relation = self._relations.get((subject_id, label, object_id), None) + if relation is None: + relation = RelationOrm( + label=label, + subject_id=subject_id, + object_id=object_id, + ) + session.add(relation) + session.flush() # Get the ID immediately + self._relations[(subject_id, label, object_id)] = relation # noqa + return relation.id diff --git a/src/conspiracies/pipeline/pipeline.py b/src/conspiracies/pipeline/pipeline.py index 17713de..1774605 100644 --- a/src/conspiracies/pipeline/pipeline.py +++ b/src/conspiracies/pipeline/pipeline.py @@ -3,16 +3,17 @@ from datetime import datetime from pathlib import Path +from tqdm import tqdm + from conspiracies.common.fileutils import iter_lines_of_files from conspiracies.corpusprocessing.aggregation import TripletAggregator from conspiracies.corpusprocessing.clustering import Clustering, Mappings from conspiracies.corpusprocessing.triplet import Triplet from conspiracies.database.engine import get_engine, setup_database, get_session from conspiracies.database.models import ( - get_or_create_entity, - get_or_create_relation, TripletOrm, DocumentOrm, + ModelLookupCache, ) from conspiracies.docprocessing.docprocessor import DocProcessor from conspiracies.document import Document @@ -177,26 +178,30 @@ def databasepopulation(self): mappings = Mappings(**json.load(mappings_file)) with open(self.output_path / "triplets.ndjson") as triplets_file: - for line in triplets_file: + cache = ModelLookupCache(session) + bulk = [] + for line in tqdm(triplets_file, desc="Writing triplets to database"): triplet = Triplet(**json.loads(line)) - subject_id = get_or_create_entity( + subject_id = cache.get_or_create_entity( mappings.map_entity(triplet.subject.text), session, ) - relation_id = get_or_create_relation( - mappings.map_predicate(triplet.predicate.text), + object_id = cache.get_or_create_entity( + mappings.map_entity(triplet.object.text), session, ) - object_id = get_or_create_entity( - 
mappings.map_entity(triplet.object.text), + relation_id = cache.get_or_create_relation( + subject_id, + object_id, + mappings.map_predicate(triplet.predicate.text), session, ) triplet_orm = TripletOrm( doc_id=int(triplet.doc), - subject_entity_id=subject_id, - predicate_relation_id=relation_id, - object_entity_id=object_id, + subject_id=subject_id, + relation_id=relation_id, + object_id=object_id, subj_span_start=triplet.subject.start_char, subj_span_end=triplet.subject.end_char, pred_span_start=triplet.predicate.start_char, @@ -204,18 +209,29 @@ def databasepopulation(self): obj_span_start=triplet.object.start_char, obj_span_end=triplet.object.end_char, ) - session.add(triplet_orm) - + bulk.append(triplet_orm) + if len(bulk) >= 500: + session.bulk_save_objects(bulk) + bulk.clear() + session.bulk_save_objects(bulk) + bulk.clear() session.commit() for doc in ( json.loads(line) - for line in iter_lines_of_files(self.output_path / "annotations.ndjson") + for line in tqdm( + iter_lines_of_files(self.output_path / "annotations.ndjson"), + desc="Writing documents to database", + ) ): doc_orm = DocumentOrm( id=doc["id"], text=doc["text"], timestamp=datetime.fromisoformat(doc["timestamp"]), ) - session.add(doc_orm) + bulk.append(doc_orm) + if len(bulk) >= 500: + session.bulk_save_objects(bulk) + bulk.clear() + session.bulk_save_objects(bulk) session.commit()
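Taken together, the last three patches move the pipeline output from flat JSON/NDJSON files into a SQLite database with entities, relations, triplets and documents as separate tables. A minimal sketch of how that schema can be read back with the ORM models defined in models.py follows; the output path and the query itself are illustrative assumptions, not code from the patch series.

    from pathlib import Path

    from sqlalchemy.orm import aliased

    from conspiracies.database.engine import get_engine, get_session
    from conspiracies.database.models import EntityOrm, RelationOrm, TripletOrm

    # Illustrative only: read a few triplets back from the populated database.
    # The path below is an assumption; point it at your own pipeline output.
    engine = get_engine(Path("output/eschatology/database.db"))
    session = get_session(engine)

    # Two aliases of EntityOrm are needed because a triplet joins the entity
    # table twice, once for the subject and once for the object.
    subj, obj = aliased(EntityOrm), aliased(EntityOrm)
    rows = (
        session.query(subj.label, RelationOrm.label, obj.label)
        .select_from(TripletOrm)
        .join(subj, TripletOrm.subject_id == subj.id)
        .join(RelationOrm, TripletOrm.relation_id == RelationOrm.id)
        .join(obj, TripletOrm.object_id == obj.id)
        .limit(10)
        .all()
    )
    for subject_label, predicate_label, object_label in rows:
        print(subject_label, predicate_label, object_label)

Because ModelLookupCache keys relations by (subject_id, label, object_id), each TripletOrm row still represents a single occurrence in a single document, so a join like the one sketched here returns one row per extracted triplet.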