Skip to content

Commit

Permalink
Working on converting to a standalone app.
Browse files Browse the repository at this point in the history
  • Loading branch information
jzonthemtn committed Jan 1, 2025
1 parent 7a2d577 commit 2899bcf
Show file tree
Hide file tree
Showing 8 changed files with 211 additions and 116 deletions.
8 changes: 7 additions & 1 deletion opensearch-search-quality-evaluation-framework/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
<url>https://www.ubisearch.dev</url>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.release>17</maven.compiler.release>
<maven.compiler.release>21</maven.compiler.release>
</properties>
<dependencies>
<dependency>
Expand Down Expand Up @@ -52,5 +52,11 @@
<artifactId>gson</artifactId>
<version>2.11.0</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<version>5.11.4</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,24 @@
import org.apache.hc.core5.http.HttpHost;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.json.JsonData;
import org.opensearch.client.json.jackson.JacksonJsonpMapper;
import org.opensearch.client.opensearch.OpenSearchClient;
import org.opensearch.client.opensearch._types.FieldValue;
import org.opensearch.client.opensearch._types.Refresh;
import org.opensearch.client.opensearch._types.SortOrder;
import org.opensearch.client.opensearch._types.Time;
import org.opensearch.client.opensearch._types.aggregations.Aggregate;
import org.opensearch.client.opensearch._types.aggregations.Aggregation;
import org.opensearch.client.opensearch._types.aggregations.StringTermsAggregate;
import org.opensearch.client.opensearch._types.aggregations.StringTermsBucket;
import org.opensearch.client.opensearch._types.mapping.IntegerNumberProperty;
import org.opensearch.client.opensearch._types.mapping.Property;
import org.opensearch.client.opensearch._types.mapping.TypeMapping;
import org.opensearch.client.opensearch._types.query_dsl.BoolQuery;
import org.opensearch.client.opensearch._types.query_dsl.MatchQuery;
import org.opensearch.client.opensearch._types.query_dsl.Query;
import org.opensearch.client.opensearch._types.query_dsl.RangeQuery;
import org.opensearch.client.opensearch._types.query_dsl.WrapperQuery;
import org.opensearch.client.opensearch.core.BulkRequest;
import org.opensearch.client.opensearch.core.BulkResponse;
Expand Down Expand Up @@ -493,6 +500,161 @@ public Map<String, Set<ClickthroughRate>> getClickthroughRate(final int maxRank)

}

/**
 * Calculates the rank-aggregated click-through rate from the UBI events.
 *
 * <p>Queries the UBI events index for events whose
 * {@code event_attributes.position.ordinal} is at most {@code maxRank}, aggregated by
 * action name and, within each action, by position. The resulting map is indexed via
 * {@code indexRankAggregatedClickthrough} before it is returned.
 *
 * <p>NOTE: extracting the per-position impression and click counts from the response has
 * not been ported to this client yet (see the TODO below), so until then every rank maps
 * to a click-through rate of {@code 0.0}.
 *
 * @param maxRank The maximum rank (position) to include in the calculation.
 * @return A map of positions to clickthrough rates.
 * @throws Exception Thrown when there is a problem accessing OpenSearch.
 */
@Override
public Map<Integer, Double> getRankAggregatedClickThrough(final int maxRank) throws Exception {

    final Map<Integer, Double> rankAggregatedClickThrough = new HashMap<>();

    final RangeQuery rangeQuery = RangeQuery.of(r -> r
            .field("event_attributes.position.ordinal")
            .lte(JsonData.of(maxRank))
    );

    // Order the buckets by key and not by document count.
    // TODO: Confirm this is the same as the previous: BucketOrder.key(true)
    final List<Map<String, SortOrder>> sort = new ArrayList<>();
    sort.add(Map.of("_key", SortOrder.Asc));

    final Aggregation positionsAggregator = Aggregation.of(a -> a
            .terms(t -> t
                    .field("event_attributes.position.ordinal")
                    .name("By_Position")
                    .size(maxRank)
                    .order(sort)
            )
    );

    // By_Position must be a sub-aggregation of By_Action (as it was in the previous
    // implementation) so that each action bucket carries its own per-position counts.
    final Aggregation actionNameAggregation = Aggregation.of(a -> a
            .terms(t -> t
                    .field("action_name")
                    .name("By_Action")
                    .size(maxRank)
                    .order(sort)
            )
            .aggregations("By_Position", positionsAggregator)
    );

    final Map<String, Aggregation> aggregations = new HashMap<>();
    aggregations.put("By_Action", actionNameAggregation);

    // TODO: Allow for a time period and for a specific application.
    final SearchRequest searchRequest = new SearchRequest.Builder()
            .index(Constants.UBI_EVENTS_INDEX_NAME)
            .aggregations(aggregations)
            .query(q -> q.range(rangeQuery))
            .from(0)
            .size(0)
            .build();

    final SearchResponse<Void> searchResponse = client.search(searchRequest, Void.class);

    final Map<String, Aggregate> aggs = searchResponse.aggregations();
    final StringTermsAggregate byAction = aggs.get("By_Action").sterms();
    final List<StringTermsBucket> byActionBuckets = byAction.buckets().array();

    final Map<Integer, Double> clickCounts = new HashMap<>();
    final Map<Integer, Double> impressionCounts = new HashMap<>();

    for (final StringTermsBucket bucket : byActionBuckets) {

        LOGGER.debug("Key: {}, Doc Count: {}", bucket.key(), bucket.docCount());

        // TODO: Port the bucket handling from the previous implementation:
        //  - For the "impression" action bucket (EVENT_IMPRESSION), read the "By_Position"
        //    sub-aggregation from bucket.aggregations() and put each position bucket's
        //    doc count into impressionCounts, keyed by the position.
        //  - For the "click" action bucket (EVENT_CLICK), do the same into clickCounts.

    }

    for (int rank = 0; rank < maxRank; rank++) {

        final Double impressions = impressionCounts.get(rank);
        final Double clicks = clickCounts.get(rank);

        if (impressions == null) {

            // No impressions so the clickthrough rate is 0.
            LOGGER.info("No impressions for rank {}, so using CTR of 0", rank);
            rankAggregatedClickThrough.put(rank, 0.0);

        } else if (clicks == null) {

            // This position has impressions but no clicks, so its CTR is zero.
            // Log the impression count here; the click count is absent in this branch.
            LOGGER.info("Position = {}, Impression Counts = {}, Impressions but no clicks so CTR is 0", rank, impressions);
            rankAggregatedClickThrough.put(rank, 0.0);

        } else {

            // Calculate the CTR by dividing the number of clicks by the number of impressions.
            LOGGER.info("Position = {}, Impression Counts = {}, Click Count = {}", rank, impressions, clicks);
            rankAggregatedClickThrough.put(rank, clicks / impressions);

        }

    }

    indexRankAggregatedClickthrough(rankAggregatedClickThrough);

    return rankAggregatedClickThrough;

}

private Collection<String> getQueryIdsHavingUserQuery(final String userQuery) throws Exception {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,11 @@ public abstract class SearchEngine {
*/
public abstract Map<String, Set<ClickthroughRate>> getClickthroughRate(final int maxRank) throws Exception;

/**
 * Calculate the rank-aggregated click through from the UBI events.
 * @param maxRank The maximum rank (result position) to consider in the calculation.
 * @return A map of positions to clickthrough rates.
 * @throws Exception Thrown when there is a problem accessing OpenSearch.
 */
public abstract Map<Integer, Double> getRankAggregatedClickThrough(int maxRank) throws Exception;

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,14 @@
import com.google.gson.Gson;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.eval.Constants;
import org.opensearch.eval.engine.SearchEngine;
import org.opensearch.eval.judgments.clickmodel.ClickModel;
import org.opensearch.eval.judgments.queryhash.IncrementalUserQueryHash;
import org.opensearch.eval.model.ClickthroughRate;
import org.opensearch.eval.model.data.Judgment;
import org.opensearch.eval.model.ubi.event.UbiEvent;
import org.opensearch.eval.judgments.queryhash.IncrementalUserQueryHash;
import org.opensearch.eval.utils.MathUtils;

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -62,7 +57,7 @@ public String calculateJudgments() throws Exception {

// Calculate and index the rank-aggregated click-through.
LOGGER.info("Beginning calculation of rank-aggregated click-through.");
final Map<Integer, Double> rankAggregatedClickThrough = getRankAggregatedClickThrough();
final Map<Integer, Double> rankAggregatedClickThrough = searchEngine.getRankAggregatedClickThrough(maxRank);
LOGGER.info("Rank-aggregated clickthrough positions: {}", rankAggregatedClickThrough.size());
showRankAggregatedClickThrough(rankAggregatedClickThrough);

Expand Down Expand Up @@ -151,108 +146,6 @@ public String calculateCoec(final Map<Integer, Double> rankAggregatedClickThroug

}


/**
 * Calculate the rank-aggregated click through from the UBI events.
 *
 * <p>Runs a terms aggregation by action name with a nested terms aggregation by position
 * over events whose position is at most the configured max rank, derives the CTR for each
 * rank as clicks divided by impressions, indexes the result through the search engine,
 * and returns it.
 *
 * @return A map of positions to clickthrough rates.
 * @throws Exception Thrown when there is a problem accessing OpenSearch.
 */
public Map<Integer, Double> getRankAggregatedClickThrough() throws Exception {

    final Map<Integer, Double> rankAggregatedClickThrough = new HashMap<>();

    // TODO: Allow for a time period and for a specific application.

    final QueryBuilder findRangeNumber = QueryBuilders.rangeQuery("event_attributes.position.ordinal").lte(parameters.getMaxRank());
    final QueryBuilder queryBuilder = new BoolQueryBuilder().must(findRangeNumber);

    // Order the aggregations by key and not by value.
    final BucketOrder bucketOrder = BucketOrder.key(true);

    final TermsAggregationBuilder positionsAggregator = AggregationBuilders.terms("By_Position").field("event_attributes.position.ordinal").order(bucketOrder).size(parameters.getMaxRank());
    final TermsAggregationBuilder actionNameAggregation = AggregationBuilders.terms("By_Action").field("action_name").subAggregation(positionsAggregator).order(bucketOrder).size(parameters.getMaxRank());

    final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
            .query(queryBuilder)
            .aggregation(actionNameAggregation)
            .from(0)
            .size(0);

    final SearchRequest searchRequest = new SearchRequest(Constants.UBI_EVENTS_INDEX_NAME).source(searchSourceBuilder);
    final SearchResponse searchResponse = client.search(searchRequest).get();

    final Map<Integer, Double> clickCounts = new HashMap<>();
    final Map<Integer, Double> impressionCounts = new HashMap<>();

    final Terms actionTerms = searchResponse.getAggregations().get("By_Action");
    final Collection<? extends Terms.Bucket> actionBuckets = actionTerms.getBuckets();

    LOGGER.debug("Aggregation query: {}", searchSourceBuilder.toString());

    for (final Terms.Bucket actionBucket : actionBuckets) {

        // Handle the "impression" bucket.
        if (EVENT_IMPRESSION.equalsIgnoreCase(actionBucket.getKey().toString())) {

            final Terms positionTerms = actionBucket.getAggregations().get("By_Position");
            final Collection<? extends Terms.Bucket> positionBuckets = positionTerms.getBuckets();

            for (final Terms.Bucket positionBucket : positionBuckets) {
                LOGGER.debug("Inserting impression event from position {} with click count {}", positionBucket.getKey(), (double) positionBucket.getDocCount());
                impressionCounts.put(Integer.valueOf(positionBucket.getKey().toString()), (double) positionBucket.getDocCount());
            }

        }

        // Handle the "click" bucket.
        if (EVENT_CLICK.equalsIgnoreCase(actionBucket.getKey().toString())) {

            final Terms positionTerms = actionBucket.getAggregations().get("By_Position");
            final Collection<? extends Terms.Bucket> positionBuckets = positionTerms.getBuckets();

            for (final Terms.Bucket positionBucket : positionBuckets) {
                LOGGER.debug("Inserting client event from position {} with click count {}", positionBucket.getKey(), (double) positionBucket.getDocCount());
                clickCounts.put(Integer.valueOf(positionBucket.getKey().toString()), (double) positionBucket.getDocCount());
            }

        }

    }

    for (int rank = 0; rank < parameters.getMaxRank(); rank++) {

        final Double impressions = impressionCounts.get(rank);
        final Double clicks = clickCounts.get(rank);

        if (impressions == null) {

            // No impressions so the clickthrough rate is 0.
            LOGGER.info("No impressions for rank {}, so using CTR of 0", rank);
            rankAggregatedClickThrough.put(rank, 0.0);

        } else if (clicks == null) {

            // This position has impressions but no clicks, so its CTR is zero.
            // Log the impression count here; the click count is absent in this branch.
            LOGGER.info("Position = {}, Impression Counts = {}, Impressions but no clicks so CTR is 0", rank, impressions);
            rankAggregatedClickThrough.put(rank, 0.0);

        } else {

            // Calculate the CTR by dividing the number of clicks by the number of impressions.
            LOGGER.info("Position = {}, Impression Counts = {}, Click Count = {}", rank, impressions, clicks);
            rankAggregatedClickThrough.put(rank, clicks / impressions);

        }

    }

    searchEngine.indexRankAggregatedClickthrough(rankAggregatedClickThrough);

    return rankAggregatedClickThrough;

}

private void showJudgments(final Collection<Judgment> judgments) {

for(final Judgment judgment : judgments) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

/**
* A judgment of a search result's quality for a given query.
Expand Down Expand Up @@ -45,6 +46,21 @@ public Judgment(final String id, final String queryId, final String query, final
this.judgment = judgment;
}

/**
 * Creates a new judgment whose ID is a freshly generated random UUID.
 * @param queryId The ID of the query being judged.
 * @param query The text of the query being judged.
 * @param document The document the judgment applies to.
 * @param judgment The value of the judgment.
 */
public Judgment(final String queryId, final String query, final String document, final double judgment) {
    super(UUID.randomUUID().toString());
    this.judgment = judgment;
    this.document = document;
    this.query = query;
    this.queryId = queryId;
}

/**
 * Renders this judgment as a single comma-separated line.
 * @return The query ID, query, document, and rounded judgment value, separated by ", ".
 */
public String toJudgmentString() {
    final String roundedJudgment = String.valueOf(MathUtils.round(judgment));
    return String.join(", ", queryId, query, document, roundedJudgment);
}
Expand Down
Loading

0 comments on commit 2899bcf

Please sign in to comment.