From bd77b6d67c0952fda44d61e9f3a942108002365c Mon Sep 17 00:00:00 2001 From: Rohan Juneja Date: Wed, 27 Mar 2024 17:33:22 -0700 Subject: [PATCH] implementation of creative limit for pathfinder --- src/index.ts | 21 ++++++++++++++------- src/inferred_mode/inferred_mode.ts | 23 +++++++++++++++++++---- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/src/index.ts b/src/index.ts index 7de0eb03..33e1da7f 100644 --- a/src/index.ts +++ b/src/index.ts @@ -557,9 +557,6 @@ export default class TRAPIQueryHandler { return; } - // test - console.log("recognized pathfinder"); - // run creative mode await this._handleInferredEdges(true); const creativeResponse = this.getResponse(); @@ -569,6 +566,14 @@ export default class TRAPIQueryHandler { // restore query graph this.queryGraph.nodes[unpinnedNodeId] = unpinnedNode; intermediateEdges.forEach(([edgeId, edge]) => this.queryGraph.edges[edgeId] = edge); + creativeResponse.message.query_graph = this.queryGraph; + + // if no results then we are done + if (creativeResponse.message.results.length === 0) { + this.getResponse = () => creativeResponse; + return; + } + // set up a graph structure const kgEdge = creativeResponse.message.results[0].analyses[0].edge_bindings[mainEdgeID][0].id; @@ -694,18 +699,20 @@ export default class TRAPIQueryHandler { } } - creativeResponse.message.results = Object.values(newResultObject); - + creativeResponse.message.results = Object.values(newResultObject).sort((a, b) => (b.analyses[0].score ?? 0) - (a.analyses[0].score ?? 0)).slice(0, process.env.CREATIVE_LIMIT ? parseInt(process.env.CREATIVE_LIMIT) : 500); + creativeResponse.description = `Query processed successfully, retrieved ${creativeResponse.message.results.length} results.` + const finalNewAuxGraphs: {[id: string]: {edges: string[]}} = newAuxGraphs as any; for (const auxGraph in finalNewAuxGraphs) { finalNewAuxGraphs[auxGraph].edges = Array.from(finalNewAuxGraphs[auxGraph].edges); } Object.assign(creativeResponse.message.auxiliary_graphs, finalNewAuxGraphs); - // TODO: combine scoring information - // TODO: Fix 500 cap impl + // TODO: Add logs/debug statements in this function // TODO: formatting // TODO: move to a seperate file if this gets too big? + // TODO: test other templates + // TODO: make unit tests this.getResponse = () => creativeResponse; } diff --git a/src/inferred_mode/inferred_mode.ts b/src/inferred_mode/inferred_mode.ts index 919ae2ed..8911990f 100644 --- a/src/inferred_mode/inferred_mode.ts +++ b/src/inferred_mode/inferred_mode.ts @@ -292,6 +292,9 @@ export default class InferredQueryHandler { } }); + // modified count used for pathfinder + const pfIntermediateSet = new Set(); + // add results newResponse.message.results.forEach((result) => { const translatedResult: TrapiResult = { @@ -309,6 +312,17 @@ export default class InferredQueryHandler { ], }; + if (this.pathfinder) { + for (let [nodeID, bindings] of Object.entries(result.node_bindings)) { + if (nodeID === "creativeQuerySubject" || nodeID === "creativeQueryObject") { + continue; + } + for (const binding of bindings) { + pfIntermediateSet.add(binding.id); + } + } + } + const resultCreativeSubjectID = translatedResult.node_bindings[qEdge.subject] .map((binding) => binding.id) .join(','); @@ -436,8 +450,9 @@ export default class InferredQueryHandler { } report.querySuccess = 1; - if (Object.keys(combinedResponse.message.results).length >= this.CREATIVE_LIMIT && !report.creativeLimitHit) { - report.creativeLimitHit = Object.keys(newResponse.message.results).length; + const resSize = this.pathfinder ? pfIntermediateSet.size : Object.keys(combinedResponse.message.results).length; + if (resSize >= this.CREATIVE_LIMIT && !report.creativeLimitHit) { + report.creativeLimitHit = resSize; } span.finish(); return report; @@ -568,9 +583,9 @@ export default class InferredQueryHandler { stop = true; const message = [ `Addition of ${creativeLimitHit} results from Template ${i + 1}`, - Object.keys(combinedResponse.message.results).length === this.CREATIVE_LIMIT ? ' meets ' : ' exceeds ', + creativeLimitHit === this.CREATIVE_LIMIT ? ' meets ' : ' exceeds ', `creative result maximum of ${this.CREATIVE_LIMIT} (reaching ${ - Object.keys(combinedResponse.message.results).length + creativeLimitHit } merged). `, `Response will be truncated to top-scoring ${this.CREATIVE_LIMIT} results. Skipping remaining ${ subQueries.length - (i + 1)