Skip to content

Commit

Permalink
Merge pull request #332 from HSLdevcom/stop_query_filters
Browse files Browse the repository at this point in the history
Filtering stop queries by feed ID
  • Loading branch information
optionsome authored Apr 1, 2020
2 parents f81dc82 + cbbda17 commit 4325c2e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 14 deletions.
50 changes: 40 additions & 10 deletions src/main/java/org/opentripplanner/common/LuceneIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.Version;
import org.opentripplanner.model.FeedScopedId;
import org.opentripplanner.model.Stop;
import org.opentripplanner.gtfs.GtfsLibrary;
import org.opentripplanner.profile.StopCluster;
Expand All @@ -32,6 +33,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand Down Expand Up @@ -77,11 +79,11 @@ private void index() {
//directory = new RAMDirectory(); // only a little faster
IndexWriterConfig config = new IndexWriterConfig(Version.LUCENE_47, analyzer).setOpenMode(OpenMode.CREATE);
final IndexWriter writer = new IndexWriter(directory, config);
for (Stop station : graphIndex.stationForId.values()) {
addStation(writer, station);
for (Map.Entry<FeedScopedId, Stop> station : graphIndex.stationForId.entrySet()) {
addStation(writer, station.getKey().getAgencyId(), station.getValue());
}
for (Stop stop : graphIndex.stopForId.values()) {
addStop(writer, stop);
for (Map.Entry<FeedScopedId, Stop> stop : graphIndex.stopForId.entrySet()) {
addStop(writer, stop.getKey().getAgencyId(), stop.getValue());
}
graphIndex.clusterStopsAsNeeded();
for (StopCluster stopCluster : graphIndex.stopClusterForId.values()) {
Expand All @@ -100,19 +102,21 @@ private void index() {
}
}

private void addStation(IndexWriter iwriter, Stop station) throws IOException {
private void addStation(IndexWriter iwriter, String feedId, Stop station) throws IOException {
Document doc = new Document();
doc.add(new TextField("name", station.getName(), Field.Store.YES));
doc.add(new StringField("feed", feedId, Field.Store.YES));
doc.add(new DoubleField("lat", station.getLat(), Field.Store.YES));
doc.add(new DoubleField("lon", station.getLon(), Field.Store.YES));
doc.add(new StringField("id", GtfsLibrary.convertIdToString(station.getId()), Field.Store.YES));
doc.add(new StringField("category", Category.STATION.name(), Field.Store.YES));
iwriter.addDocument(doc);
}

private void addStop(IndexWriter iwriter, Stop stop) throws IOException {
private void addStop(IndexWriter iwriter, String feedId, Stop stop) throws IOException {
Document doc = new Document();
doc.add(new TextField("name", stop.getName(), Field.Store.YES));
doc.add(new StringField("feed", feedId, Field.Store.YES));
if (stop.getCode() != null) {
doc.add(new StringField("code", stop.getCode(), Field.Store.YES));
}
Expand Down Expand Up @@ -167,11 +171,13 @@ public void run() {
* @param stations Search for stations by their name
* @param clusters Search for clusters by their name
* @param corners Search for street corners using at least one of the street names
* @param maxResults Maximum amount of results to return
* @param feeds Return results only from specified feeds or from all feeds if null
* @return list of results in in the format expected by GeocoderBuiltin.js in the OTP Leaflet client
*/
public List<LuceneResult> query(String queryString, boolean autocomplete,
boolean stops, boolean stations,
boolean clusters, boolean corners) {
boolean clusters, boolean corners, int maxResults, List<String> feeds) {
/* Turn the query string into a Lucene query.*/
BooleanQuery query = new BooleanQuery();
BooleanQuery termQuery = new BooleanQuery();
Expand Down Expand Up @@ -226,9 +232,18 @@ public List<LuceneResult> query(String queryString, boolean autocomplete,
}
query.add(typeQuery, BooleanClause.Occur.MUST);
}

if (feeds != null) {
BooleanQuery feedQuery = new BooleanQuery();
for (String feedId : feeds) {
feedQuery.add(new TermQuery(new Term("feed", feedId)), BooleanClause.Occur.SHOULD);
}
query.add(feedQuery, BooleanClause.Occur.MUST);
}

List<LuceneResult> result = Lists.newArrayList();
try {
TopScoreDocCollector collector = TopScoreDocCollector.create(10, true);
TopScoreDocCollector collector = TopScoreDocCollector.create(maxResults, true);
searcher.search(query, collector);
ScoreDoc[] docs = collector.topDocs().scoreDocs;
for (int i = 0; i < docs.length; i++) {
Expand Down Expand Up @@ -259,6 +274,21 @@ public List<LuceneResult> query(String queryString, boolean autocomplete,
}
}

/** Fetch results for the geocoder using the OTP graph for stops, clusters and street names
*
* @param queryString
* @param autocomplete Whether we should use the query string to do a prefix match
* @param stops Search for stops, either by name or stop code
* @param stations Search for stations by their name
* @param clusters Search for clusters by their name
* @param corners Search for street corners using at least one of the street names
* @return list of results in in the format expected by GeocoderBuiltin.js in the OTP Leaflet client
*/
public List<LuceneResult> query (String queryString, boolean autocomplete, boolean stations,
boolean stops, boolean clusters, boolean corners) {
return query(queryString, autocomplete, stops, stations, clusters, corners, 10, null);
}

/** Fetch results for the geocoder using the OTP graph for stops, clusters and street names
*
* @param queryString
Expand All @@ -268,11 +298,11 @@ public List<LuceneResult> query(String queryString, boolean autocomplete,
* @param corners Search for street corners using at least one of the street names
* @return list of results in in the format expected by GeocoderBuiltin.js in the OTP Leaflet client
*
* @deprecated Use {@link #query(String, boolean, boolean, boolean, boolean, boolean)} instead
* @deprecated Use {@link #query(String, boolean, boolean, boolean, boolean, boolean, int, List)} instead
*/
public List<LuceneResult> query (String queryString, boolean autocomplete,
boolean stops, boolean clusters, boolean corners) {
return query(queryString, autocomplete, stops, false, clusters, corners);
return query(queryString, autocomplete, stops, false, clusters, corners, 10, null);
}

/** This class matches the structure of the Geocoder responses expected by the OTP client. */
Expand Down
44 changes: 40 additions & 4 deletions src/main/java/org/opentripplanner/index/IndexGraphQLSchema.java
Original file line number Diff line number Diff line change
Expand Up @@ -2708,11 +2708,21 @@ private Object getObject(String idString) {
.description("Return stops with these ids")
.type(new GraphQLList(Scalars.GraphQLString))
.build())
.argument(GraphQLArgument.newArgument()
.name("feeds")
.description("List of feeds from which stops are returned. Defaults to all feeds")
.type(GraphQLList.list(GraphQLNonNull.nonNull(Scalars.GraphQLString)))
.build())
.argument(GraphQLArgument.newArgument()
.name("name")
.description("Query stops by this name")
.type(Scalars.GraphQLString)
.build())
.argument(GraphQLArgument.newArgument()
.name("maxResults")
.description("Number of results to return when using `name` argument. Defaults to 10")
.type(Scalars.GraphQLInt)
.build())
.dataFetcher(environment -> {
if ((environment.getArgument("ids") instanceof List)) {
if (environment.getArguments().entrySet()
Expand All @@ -2729,9 +2739,17 @@ private Object getObject(String idString) {
}
Stream<Stop> stream;
if (environment.getArgument("name") == null) {
stream = index.stopForId.values().stream();
if (environment.getArgument("feeds") == null) {
stream = index.stopForId.values().stream();
} else {
List<String> feedIds = environment.getArgument("feeds");
stream = index.stopForId.values().stream().filter(stop -> feedIds.contains(stop.getId().getAgencyId()));
}
} else {
stream = index.getLuceneIndex().query(environment.getArgument("name"), true, true, false, false)
int maxResults = environment.getArgument("maxResults") != null ? environment.getArgument("maxResults") : 10;
List<String> feeds = environment.getArgument("feeds");

stream = index.getLuceneIndex().query(environment.getArgument("name"), true, true, false, false, false, maxResults, feeds)
.stream()
.map(result -> index.stopForId.get(FeedScopedId.convertFromString(result.id)));
}
Expand Down Expand Up @@ -2964,11 +2982,21 @@ private Object getObject(String idString) {
.description("Only return stations that match one of the ids in this list")
.type(new GraphQLList(Scalars.GraphQLString))
.build())
.argument(GraphQLArgument.newArgument()
.name("feeds")
.description("List of feeds from which stations are returned. Defaults to all feeds")
.type(GraphQLList.list(GraphQLNonNull.nonNull(Scalars.GraphQLString)))
.build())
.argument(GraphQLArgument.newArgument()
.name("name")
.description("Query stations by name")
.type(Scalars.GraphQLString)
.build())
.argument(GraphQLArgument.newArgument()
.name("maxResults")
.description("Number of results to return when using `name` argument. Defaults to 10")
.type(Scalars.GraphQLInt)
.build())
.dataFetcher(environment -> {
if ((environment.getArgument("ids") instanceof List)) {
if (environment.getArguments().entrySet()
Expand All @@ -2986,9 +3014,17 @@ private Object getObject(String idString) {

Stream<Stop> stream;
if (environment.getArgument("name") == null) {
stream = index.stationForId.values().stream();
if (environment.getArgument("feeds") == null) {
stream = index.stationForId.values().stream();
} else {
List<String> feedIds = environment.getArgument("feeds");
stream = index.stationForId.values().stream().filter(station -> feedIds.contains(station.getId().getAgencyId()));
}
} else {
stream = index.getLuceneIndex().query(environment.getArgument("name"), true, false, true, false, false)
int maxResults = environment.getArgument("maxResults") != null ? environment.getArgument("maxResults") : 10;
List<String> feeds = environment.getArgument("feeds");

stream = index.getLuceneIndex().query(environment.getArgument("name"), true, false, true, false, false, maxResults, feeds)
.stream()
.map(result -> index.stationForId.get(FeedScopedId.convertFromString(result.id)));
}
Expand Down

0 comments on commit 4325c2e

Please sign in to comment.