diff --git a/src/main/java/org/opentripplanner/common/LuceneIndex.java b/src/main/java/org/opentripplanner/common/LuceneIndex.java index 8ecefac5a11..def636a198f 100644 --- a/src/main/java/org/opentripplanner/common/LuceneIndex.java +++ b/src/main/java/org/opentripplanner/common/LuceneIndex.java @@ -19,6 +19,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.util.Version; +import org.opentripplanner.model.FeedScopedId; import org.opentripplanner.model.Stop; import org.opentripplanner.gtfs.GtfsLibrary; import org.opentripplanner.profile.StopCluster; @@ -32,6 +33,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -77,11 +79,11 @@ private void index() { //directory = new RAMDirectory(); // only a little faster IndexWriterConfig config = new IndexWriterConfig(Version.LUCENE_47, analyzer).setOpenMode(OpenMode.CREATE); final IndexWriter writer = new IndexWriter(directory, config); - for (Stop station : graphIndex.stationForId.values()) { - addStation(writer, station); + for (Map.Entry station : graphIndex.stationForId.entrySet()) { + addStation(writer, station.getKey().getAgencyId(), station.getValue()); } - for (Stop stop : graphIndex.stopForId.values()) { - addStop(writer, stop); + for (Map.Entry stop : graphIndex.stopForId.entrySet()) { + addStop(writer, stop.getKey().getAgencyId(), stop.getValue()); } graphIndex.clusterStopsAsNeeded(); for (StopCluster stopCluster : graphIndex.stopClusterForId.values()) { @@ -100,9 +102,10 @@ private void index() { } } - private void addStation(IndexWriter iwriter, Stop station) throws IOException { + private void addStation(IndexWriter iwriter, String feedId, Stop station) throws IOException { Document doc = new Document(); doc.add(new TextField("name", station.getName(), Field.Store.YES)); + doc.add(new StringField("feed", feedId, Field.Store.YES)); doc.add(new DoubleField("lat", station.getLat(), Field.Store.YES)); doc.add(new DoubleField("lon", station.getLon(), Field.Store.YES)); doc.add(new StringField("id", GtfsLibrary.convertIdToString(station.getId()), Field.Store.YES)); @@ -110,9 +113,10 @@ private void addStation(IndexWriter iwriter, Stop station) throws IOException { iwriter.addDocument(doc); } - private void addStop(IndexWriter iwriter, Stop stop) throws IOException { + private void addStop(IndexWriter iwriter, String feedId, Stop stop) throws IOException { Document doc = new Document(); doc.add(new TextField("name", stop.getName(), Field.Store.YES)); + doc.add(new StringField("feed", feedId, Field.Store.YES)); if (stop.getCode() != null) { doc.add(new StringField("code", stop.getCode(), Field.Store.YES)); } @@ -167,11 +171,13 @@ public void run() { * @param stations Search for stations by their name * @param clusters Search for clusters by their name * @param corners Search for street corners using at least one of the street names + * @param maxResults Maximum amount of results to return + * @param feeds Return results only from specified feeds or from all feeds if null * @return list of results in in the format expected by GeocoderBuiltin.js in the OTP Leaflet client */ public List query(String queryString, boolean autocomplete, boolean stops, boolean stations, - boolean clusters, boolean corners) { + boolean clusters, boolean corners, int maxResults, List feeds) { /* Turn the query string into a Lucene query.*/ BooleanQuery query = new BooleanQuery(); BooleanQuery termQuery = new BooleanQuery(); @@ -226,9 +232,18 @@ public List query(String queryString, boolean autocomplete, } query.add(typeQuery, BooleanClause.Occur.MUST); } + + if (feeds != null) { + BooleanQuery feedQuery = new BooleanQuery(); + for (String feedId : feeds) { + feedQuery.add(new TermQuery(new Term("feed", feedId)), BooleanClause.Occur.SHOULD); + } + query.add(feedQuery, BooleanClause.Occur.MUST); + } + List result = Lists.newArrayList(); try { - TopScoreDocCollector collector = TopScoreDocCollector.create(10, true); + TopScoreDocCollector collector = TopScoreDocCollector.create(maxResults, true); searcher.search(query, collector); ScoreDoc[] docs = collector.topDocs().scoreDocs; for (int i = 0; i < docs.length; i++) { @@ -259,6 +274,21 @@ public List query(String queryString, boolean autocomplete, } } + /** Fetch results for the geocoder using the OTP graph for stops, clusters and street names + * + * @param queryString + * @param autocomplete Whether we should use the query string to do a prefix match + * @param stops Search for stops, either by name or stop code + * @param stations Search for stations by their name + * @param clusters Search for clusters by their name + * @param corners Search for street corners using at least one of the street names + * @return list of results in in the format expected by GeocoderBuiltin.js in the OTP Leaflet client + */ + public List query (String queryString, boolean autocomplete, boolean stations, + boolean stops, boolean clusters, boolean corners) { + return query(queryString, autocomplete, stops, stations, clusters, corners, 10, null); + } + /** Fetch results for the geocoder using the OTP graph for stops, clusters and street names * * @param queryString @@ -268,11 +298,11 @@ public List query(String queryString, boolean autocomplete, * @param corners Search for street corners using at least one of the street names * @return list of results in in the format expected by GeocoderBuiltin.js in the OTP Leaflet client * - * @deprecated Use {@link #query(String, boolean, boolean, boolean, boolean, boolean)} instead + * @deprecated Use {@link #query(String, boolean, boolean, boolean, boolean, boolean, int, List)} instead */ public List query (String queryString, boolean autocomplete, boolean stops, boolean clusters, boolean corners) { - return query(queryString, autocomplete, stops, false, clusters, corners); + return query(queryString, autocomplete, stops, false, clusters, corners, 10, null); } /** This class matches the structure of the Geocoder responses expected by the OTP client. */ diff --git a/src/main/java/org/opentripplanner/index/IndexGraphQLSchema.java b/src/main/java/org/opentripplanner/index/IndexGraphQLSchema.java index 0f8fb24c4c4..c4a0d8319b1 100644 --- a/src/main/java/org/opentripplanner/index/IndexGraphQLSchema.java +++ b/src/main/java/org/opentripplanner/index/IndexGraphQLSchema.java @@ -2708,11 +2708,21 @@ private Object getObject(String idString) { .description("Return stops with these ids") .type(new GraphQLList(Scalars.GraphQLString)) .build()) + .argument(GraphQLArgument.newArgument() + .name("feeds") + .description("List of feeds from which stops are returned. Defaults to all feeds") + .type(GraphQLList.list(GraphQLNonNull.nonNull(Scalars.GraphQLString))) + .build()) .argument(GraphQLArgument.newArgument() .name("name") .description("Query stops by this name") .type(Scalars.GraphQLString) .build()) + .argument(GraphQLArgument.newArgument() + .name("maxResults") + .description("Number of results to return when using `name` argument. Defaults to 10") + .type(Scalars.GraphQLInt) + .build()) .dataFetcher(environment -> { if ((environment.getArgument("ids") instanceof List)) { if (environment.getArguments().entrySet() @@ -2729,9 +2739,17 @@ private Object getObject(String idString) { } Stream stream; if (environment.getArgument("name") == null) { - stream = index.stopForId.values().stream(); + if (environment.getArgument("feeds") == null) { + stream = index.stopForId.values().stream(); + } else { + List feedIds = environment.getArgument("feeds"); + stream = index.stopForId.values().stream().filter(stop -> feedIds.contains(stop.getId().getAgencyId())); + } } else { - stream = index.getLuceneIndex().query(environment.getArgument("name"), true, true, false, false) + int maxResults = environment.getArgument("maxResults") != null ? environment.getArgument("maxResults") : 10; + List feeds = environment.getArgument("feeds"); + + stream = index.getLuceneIndex().query(environment.getArgument("name"), true, true, false, false, false, maxResults, feeds) .stream() .map(result -> index.stopForId.get(FeedScopedId.convertFromString(result.id))); } @@ -2964,11 +2982,21 @@ private Object getObject(String idString) { .description("Only return stations that match one of the ids in this list") .type(new GraphQLList(Scalars.GraphQLString)) .build()) + .argument(GraphQLArgument.newArgument() + .name("feeds") + .description("List of feeds from which stations are returned. Defaults to all feeds") + .type(GraphQLList.list(GraphQLNonNull.nonNull(Scalars.GraphQLString))) + .build()) .argument(GraphQLArgument.newArgument() .name("name") .description("Query stations by name") .type(Scalars.GraphQLString) .build()) + .argument(GraphQLArgument.newArgument() + .name("maxResults") + .description("Number of results to return when using `name` argument. Defaults to 10") + .type(Scalars.GraphQLInt) + .build()) .dataFetcher(environment -> { if ((environment.getArgument("ids") instanceof List)) { if (environment.getArguments().entrySet() @@ -2986,9 +3014,17 @@ private Object getObject(String idString) { Stream stream; if (environment.getArgument("name") == null) { - stream = index.stationForId.values().stream(); + if (environment.getArgument("feeds") == null) { + stream = index.stationForId.values().stream(); + } else { + List feedIds = environment.getArgument("feeds"); + stream = index.stationForId.values().stream().filter(station -> feedIds.contains(station.getId().getAgencyId())); + } } else { - stream = index.getLuceneIndex().query(environment.getArgument("name"), true, false, true, false, false) + int maxResults = environment.getArgument("maxResults") != null ? environment.getArgument("maxResults") : 10; + List feeds = environment.getArgument("feeds"); + + stream = index.getLuceneIndex().query(environment.getArgument("name"), true, false, true, false, false, maxResults, feeds) .stream() .map(result -> index.stationForId.get(FeedScopedId.convertFromString(result.id))); }