Skip to content

Commit

Permalink
[ALS-7118] - Fix concept query when search eliminates results
Browse files Browse the repository at this point in the history
  • Loading branch information
Luke Sikina committed Aug 20, 2024
1 parent 9d10bfa commit c8bba41
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 66 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package edu.harvard.dbmi.avillach.dictionary.filter;
package edu.harvard.dbmi.avillach.dictionary.concept;

import edu.harvard.dbmi.avillach.dictionary.facet.Facet;
import edu.harvard.dbmi.avillach.dictionary.filter.Filter;
import edu.harvard.dbmi.avillach.dictionary.filter.QueryParamPair;
import org.springframework.data.domain.Pageable;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.stereotype.Component;
Expand All @@ -11,7 +13,7 @@
import java.util.stream.Collectors;

@Component
public class FilterQueryGenerator {
public class ConceptFilterQueryGenerator {

/**
* This generates a query that will return a list of concept_node IDs for the given filter.
Expand All @@ -29,27 +31,25 @@ public QueryParamPair generateFilterQuery(Filter filter, Pageable pageable) {
MapSqlParameterSource params = new MapSqlParameterSource();
List<String> clauses = new java.util.ArrayList<>(List.of());
if (!CollectionUtils.isEmpty(filter.facets())) {
clauses.addAll(createFacetFilter(filter.facets(), params));
clauses.addAll(createFacetFilter(filter.facets(), params, filter.search()));
}
clauses.add(createValuelessNodeFilter());
if (StringUtils.hasLength(filter.search())) {
params.addValue("search", filter.search().trim());
}
clauses.add(createValuelessNodeFilter(filter.search()));


String query = "(\n" + String.join("\n\tINTERSECT\n", clauses) + "\n)";
String havingClause = "";
if (StringUtils.hasText(filter.search())) {
String searchQuery = createSearchFilter(filter.search(), params);
query = "(" + query + "\n\tUNION \n\t" + searchQuery + ")";
havingClause = "HAVING max(rank) > 0\n";
}
String superQuery = """
WITH q AS (
%s
)
SELECT concept_node_id
FROM q
GROUP BY concept_node_id %s
GROUP BY concept_node_id
ORDER BY max(rank) DESC
""".formatted(query, havingClause);
""".formatted(query);

if (pageable.isPaged()) {
superQuery = superQuery + """
LIMIT :limit
Expand All @@ -63,47 +63,34 @@ ORDER BY max(rank) DESC
return new QueryParamPair(superQuery, params);
}

private String createValuelessNodeFilter() {
private String createValuelessNodeFilter(String search) {
String rankQuery = "0 as rank";
String rankWhere = "";
if (StringUtils.hasLength(search)) {
rankQuery = "ts_rank(searchable_fields, (phraseto_tsquery(:search)::text || ':*')::tsquery) as rank";
rankWhere = "concept_node.searchable_fields @@ (phraseto_tsquery(:search)::text || ':*')::tsquery AND";
}
// concept nodes that have no values and no min/max should not get returned by search
return """
SELECT
concept_node.concept_node_id, 0 as rank
concept_node.concept_node_id,
%s
FROM
concept_node
LEFT JOIN concept_node_meta AS continuous_min ON concept_node.concept_node_id = continuous_min.concept_node_id AND continuous_min.KEY = 'min'
LEFT JOIN concept_node_meta AS continuous_max ON concept_node.concept_node_id = continuous_max.concept_node_id AND continuous_max.KEY = 'max'
LEFT JOIN concept_node_meta AS categorical_values ON concept_node.concept_node_id = categorical_values.concept_node_id AND categorical_values.KEY = 'values'
WHERE
continuous_min.value <> '' OR
continuous_max.value <> '' OR
categorical_values.value <> ''
""";
}

private String createSearchFilter(String search, MapSqlParameterSource params) {
params.addValue("search", search);
return """
(
SELECT
concept_node.concept_node_id AS concept_node_id,
ts_rank(searchable_fields, (phraseto_tsquery(:search)::text || ':*')::tsquery) as rank
FROM
concept_node
LEFT JOIN concept_node_meta AS continuous_min ON concept_node.concept_node_id = continuous_min.concept_node_id AND continuous_min.KEY = 'min'
LEFT JOIN concept_node_meta AS continuous_max ON concept_node.concept_node_id = continuous_max.concept_node_id AND continuous_max.KEY = 'max'
LEFT JOIN concept_node_meta AS categorical_values ON concept_node.concept_node_id = categorical_values.concept_node_id AND categorical_values.KEY = 'values'
WHERE
concept_node.searchable_fields @@ (phraseto_tsquery(:search)::text || ':*')::tsquery AND
(
continuous_min.value <> '' OR
continuous_max.value <> '' OR
categorical_values.value <> ''
)
)
""";
%s
(
continuous_min.value <> '' OR
continuous_max.value <> '' OR
categorical_values.value <> ''
)
""".formatted(rankQuery, rankWhere);
}

private List<String> createFacetFilter(List<Facet> facets, MapSqlParameterSource params) {
private List<String> createFacetFilter(List<Facet> facets, MapSqlParameterSource params, String search) {
return facets.stream()
.collect(Collectors.groupingBy(Facet::category))
.entrySet().stream()
Expand All @@ -112,17 +99,28 @@ private List<String> createFacetFilter(List<Facet> facets, MapSqlParameterSource
// The templating here is to namespace the params for each facet category in the query
.addValue("facets_for_category_%s".formatted(facetsForCategory.getKey()), facetsForCategory.getValue().stream().map(Facet::name).toList())
.addValue("category_%s".formatted(facetsForCategory.getKey()), facetsForCategory.getKey());
String rankQuery = "0";
String rankWhere = "";
if (StringUtils.hasLength(search)) {
rankQuery = "ts_rank(searchable_fields, (phraseto_tsquery(:search)::text || ':*')::tsquery)";
rankWhere = "concept_node.searchable_fields @@ (phraseto_tsquery(:search)::text || ':*')::tsquery AND";
}
return """
(
SELECT
facet__concept_node.concept_node_id AS concept_node_id , 0 as rank
facet__concept_node.concept_node_id AS concept_node_id,
max(%s) as rank
FROM facet
LEFT JOIN facet__concept_node ON facet__concept_node.facet_id = facet.facet_id
LEFT JOIN facet_category ON facet_category.facet_category_id = facet.facet_category_id
JOIN facet_category ON facet_category.facet_category_id = facet.facet_category_id
JOIN concept_node ON concept_node.concept_node_id = facet__concept_node.concept_node_id
WHERE
%s
facet.name IN (:facets_for_category_%s ) AND facet_category.name = :category_%s
GROUP BY
facet__concept_node.concept_node_id
)
""".formatted(facetsForCategory.getKey(), facetsForCategory.getKey());
""".formatted(rankQuery, rankWhere, facetsForCategory.getKey(), facetsForCategory.getKey());
})
.toList();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import edu.harvard.dbmi.avillach.dictionary.concept.model.Concept;
import edu.harvard.dbmi.avillach.dictionary.filter.Filter;
import edu.harvard.dbmi.avillach.dictionary.filter.FilterQueryGenerator;
import edu.harvard.dbmi.avillach.dictionary.filter.QueryParamPair;
import edu.harvard.dbmi.avillach.dictionary.util.MapExtractor;
import org.springframework.beans.factory.annotation.Autowired;
Expand All @@ -22,11 +21,11 @@ public class ConceptRepository {

private final ConceptRowMapper mapper;

private final FilterQueryGenerator filterGen;
private final ConceptFilterQueryGenerator filterGen;

@Autowired
public ConceptRepository(
NamedParameterJdbcTemplate template, ConceptRowMapper mapper, FilterQueryGenerator filterGen
NamedParameterJdbcTemplate template, ConceptRowMapper mapper, ConceptFilterQueryGenerator filterGen
) {
this.template = template;
this.mapper = mapper;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package edu.harvard.dbmi.avillach.dictionary.filter;

import edu.harvard.dbmi.avillach.dictionary.concept.ConceptFilterQueryGenerator;
import edu.harvard.dbmi.avillach.dictionary.facet.Facet;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.data.domain.Pageable;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.test.context.DynamicPropertyRegistry;
import org.springframework.test.context.DynamicPropertySource;
import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.MountableFile;

import java.util.List;

@Testcontainers
@SpringBootTest
class ConceptFilterQueryGeneratorTest {

@Container
static final PostgreSQLContainer<?> databaseContainer =
new PostgreSQLContainer<>("postgres:16")
.withReuse(true)
.withCopyFileToContainer(
MountableFile.forClasspathResource("seed.sql"), "/docker-entrypoint-initdb.d/seed.sql"
);

@DynamicPropertySource
static void mySQLProperties(DynamicPropertyRegistry registry) {
registry.add("spring.datasource.url", databaseContainer::getJdbcUrl);
registry.add("spring.datasource.username", databaseContainer::getUsername);
registry.add("spring.datasource.password", databaseContainer::getPassword);
registry.add("spring.datasource.db", databaseContainer::getDatabaseName);
}

@Autowired
ConceptFilterQueryGenerator subject;

@Autowired
NamedParameterJdbcTemplate template;

@Test
void shouldGenerateForFacetAndSearchNoMatch() {
Filter f = new Filter(List.of(new Facet("phs000007", "FHS", "", null, null, "study_ids_dataset_ids", null)), "smoke");
QueryParamPair pair = subject.generateFilterQuery(f, Pageable.unpaged());

List<Integer> actual = template.queryForList(pair.query(), pair.params(), Integer.class);
List<Integer> expected = List.of();

Assertions.assertEquals(expected, actual);
}

@Test
void shouldGenerateForFHSFacet() {
Filter f = new Filter(List.of(new Facet("phs000007", "FHS", "", null, null, "study_ids_dataset_ids", null)), "");
QueryParamPair pair = subject.generateFilterQuery(f, Pageable.unpaged());

List<Integer> actual = template.queryForList(pair.query(), pair.params(), Integer.class);
List<Integer> expected = List.of(229, 232, 235);

Assertions.assertEquals(expected, actual);
}

@Test
void shouldGenerateForFacetAndSearchMatch() {
Filter f = new Filter(List.of(new Facet("phs002715", "NSRR", "", null, null, "study_ids_dataset_ids", null)), "smoke");
QueryParamPair pair = subject.generateFilterQuery(f, Pageable.unpaged());

List<Integer> actual = template.queryForList(pair.query(), pair.params(), Integer.class);
List<Integer> expected = List.of(249);

Assertions.assertEquals(expected, actual);
}

@Test
void shouldGenerateForNSRRFacet() {
Filter f = new Filter(List.of(new Facet("phs002715", "NSRR", "", null, null, "study_ids_dataset_ids", null)), "");
QueryParamPair pair = subject.generateFilterQuery(f, Pageable.unpaged());

List<Integer> actual = template.queryForList(pair.query(), pair.params(), Integer.class);
List<Integer> expected = List.of(248, 249);

Assertions.assertEquals(expected, actual);
}
}

This file was deleted.

0 comments on commit c8bba41

Please sign in to comment.