Skip to content

Commit

Permalink
#3 and #4 Working on clickthrough rates.
Browse files Browse the repository at this point in the history
Signed-off-by: jzonthemtn <[email protected]>
  • Loading branch information
jzonthemtn committed Sep 19, 2024
1 parent df58c5a commit dab5944
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@ public static void main(String[] args) throws Exception {

// Calculate the rank-aggregated click-through.
final Map<Integer, Double> rankAggregatedClickThrough = openSearchEvaluationFramework.getRankAggregatedClickThrough();
// TODO: Index the <k,v> pairs in rankAggregatedClickThrough.

// Calculate the click-through rate for query/doc pairs.
final Collection<ClickthroughRate> clickthroughRates = openSearchEvaluationFramework.getClickthroughRate();
// TODO: Index the properties in each ClickthroughRate object.

}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,9 @@
package org.opensearch.searchevaluationframework;

import org.apache.commons.lang3.StringUtils;
import org.apache.http.HttpHost;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkRequestBuilder;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.searchevaluationframework.model.ClickthroughRate;
import org.opensearch.searchevaluationframework.model.UbiEvent;

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.apache.http.HttpHost;
import org.opensearch.action.search.ClearScrollRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
Expand All @@ -32,20 +19,25 @@
import org.opensearch.search.Scroll;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.searchevaluationframework.model.ClickthroughRate;
import org.opensearch.searchevaluationframework.model.UbiEvent;
import org.opensearch.searchevaluationframework.model.UbiSearch;

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.UUID;

public class OpenSearchEvaluationFramework {

public static final String UBI_EVENTS_INDEX = "ubi_events";
public static final String UBI_QUERIES_INDEX = "ubi_queries";
public static final String INDEX_UBI_EVENTS = "ubi_events";
public static final String INDEX_UBI_QUERIES = "ubi_queries";
public static final String INDEX_RANK_AGGREGATED_CTR = "rank_aggregated_ctr";
public static final String INDEX_QUERY_DOC_CTR = "click_through_rates";

public static final String EVENT_CLICK = "click";

public static final String CLICK_EVENT = "click";

private final RestHighLevelClient client;

Expand All @@ -68,7 +60,7 @@ public Collection<ClickthroughRate> getClickthroughRate() throws IOException {
final Scroll scroll = new Scroll(TimeValue.timeValueMinutes(10L));

final SearchRequest searchRequest = Requests
.searchRequest(UBI_EVENTS_INDEX)
.searchRequest(INDEX_UBI_EVENTS)
.source(searchSourceBuilder)
.scroll(scroll);

Expand All @@ -85,7 +77,7 @@ public Collection<ClickthroughRate> getClickthroughRate() throws IOException {
final UbiEvent ubiEvent = new UbiEvent(hit);
final ClickthroughRate clickthroughRate = new ClickthroughRate(ubiEvent.getQueryId());

if (StringUtils.equalsIgnoreCase(ubiEvent.getActionName(), CLICK_EVENT)) {
if (StringUtils.equalsIgnoreCase(ubiEvent.getActionName(), EVENT_CLICK)) {
clickthroughRate.logClick();
} else {
clickthroughRate.logEvent();
Expand Down Expand Up @@ -127,7 +119,7 @@ public Map<Integer, Double> getRankAggregatedClickThrough() throws IOException {
final Scroll scroll = new Scroll(TimeValue.timeValueMinutes(10L));

final SearchRequest searchRequest = Requests
.searchRequest(UBI_EVENTS_INDEX)
.searchRequest(INDEX_UBI_EVENTS)
.source(searchSourceBuilder)
.scroll(scroll);

Expand All @@ -144,7 +136,7 @@ public Map<Integer, Double> getRankAggregatedClickThrough() throws IOException {
final UbiEvent ubiEvent = new UbiEvent(hit);

// Increment the number of clicks for the position.
if (StringUtils.equalsIgnoreCase(ubiEvent.getActionName(), CLICK_EVENT)) {
if (StringUtils.equalsIgnoreCase(ubiEvent.getActionName(), EVENT_CLICK)) {
rankAggregatedClickThrough.merge(ubiEvent.getPosition(), 1.0, Double::sum);
}

Expand Down Expand Up @@ -176,29 +168,59 @@ public Map<Integer, Double> getRankAggregatedClickThrough() throws IOException {
System.out.println("Rank-aggregated click through: " + rankAggregatedClickThrough);
System.out.println("Number of total events: " + totalEvents);

index(rankAggregatedClickThrough);

return rankAggregatedClickThrough;

}

private void index(final Collection<ClickthroughRate> clickthroughRates) throws IOException {
private void index(final Map<Integer, Double> rankAggregatedClickThrough) throws IOException {

final BulkRequest request = new BulkRequest();
if(!rankAggregatedClickThrough.isEmpty()) {

for(final ClickthroughRate clickthroughRate : clickthroughRates) {
final BulkRequest request = new BulkRequest();

final Map<String, Object> jsonMap = new HashMap<>();
jsonMap.put("query_id", clickthroughRate.getQueryId());
jsonMap.put("clicks", clickthroughRate.getClicks());
jsonMap.put("events", clickthroughRate.getEvents());
jsonMap.put("ctr", clickthroughRate.getClickthroughRate());
for (final int position : rankAggregatedClickThrough.keySet()) {

final IndexRequest indexRequest = new IndexRequest("click_through_rates").id(UUID.randomUUID().toString()).source(jsonMap);
final Map<String, Object> jsonMap = new HashMap<>();
jsonMap.put("position", position);
jsonMap.put("ctr", rankAggregatedClickThrough.get(position));

request.add(indexRequest);
final IndexRequest indexRequest = new IndexRequest(INDEX_RANK_AGGREGATED_CTR).id(UUID.randomUUID().toString()).source(jsonMap);

request.add(indexRequest);

}

client.bulk(request, RequestOptions.DEFAULT);

}

client.bulk(request, RequestOptions.DEFAULT);
}

private void index(final Collection<ClickthroughRate> clickthroughRates) throws IOException {

if(!clickthroughRates.isEmpty()) {

final BulkRequest request = new BulkRequest();

for (final ClickthroughRate clickthroughRate : clickthroughRates) {

final Map<String, Object> jsonMap = new HashMap<>();
jsonMap.put("query_id", clickthroughRate.getQueryId());
jsonMap.put("clicks", clickthroughRate.getClicks());
jsonMap.put("events", clickthroughRate.getEvents());
jsonMap.put("ctr", clickthroughRate.getClickthroughRate());

final IndexRequest indexRequest = new IndexRequest(INDEX_QUERY_DOC_CTR).id(UUID.randomUUID().toString()).source(jsonMap);

request.add(indexRequest);

}

client.bulk(request, RequestOptions.DEFAULT);

}

}

Expand Down

0 comments on commit dab5944

Please sign in to comment.