-
Notifications
You must be signed in to change notification settings - Fork 73
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Pagination in hybrid query #963
Changes from all commits
bc73a50
7b20824
7dd0841
e0f8f4c
a8bf87b
323b8c5
9e12976
dc9020c
c3cacd1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,7 @@ | |
import org.apache.lucene.search.TopFieldDocs; | ||
import org.apache.lucene.search.FieldDoc; | ||
import org.opensearch.common.lucene.search.TopDocsAndMaxScore; | ||
import org.opensearch.neuralsearch.processor.combination.CombineScoresDto; | ||
import org.opensearch.neuralsearch.processor.dto.CombineScoresDto; | ||
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; | ||
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; | ||
import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails; | ||
|
@@ -56,22 +56,24 @@ public class NormalizationProcessorWorkflow { | |
|
||
/** | ||
* Start execution of this workflow | ||
* @param querySearchResults input data with QuerySearchResult from multiple shards | ||
* @param normalizationTechnique technique for score normalization | ||
* @param combinationTechnique technique for score combination | ||
* @param normalizationExecuteDto contains querySearchResults (input data with QuerySearchResult | ||
* from multiple shards), fetchSearchResultOptional, normalizationTechnique (technique for score normalization), | ||
* combinationTechnique (technique for score combination) and searchPhaseContext. | ||
*/ | ||
public void execute( | ||
final List<QuerySearchResult> querySearchResults, | ||
final Optional<FetchSearchResult> fetchSearchResultOptional, | ||
final ScoreNormalizationTechnique normalizationTechnique, | ||
final ScoreCombinationTechnique combinationTechnique | ||
final ScoreCombinationTechnique combinationTechnique, | ||
final int fromValueForSingleShard | ||
) { | ||
NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder() | ||
.querySearchResults(querySearchResults) | ||
.fetchSearchResultOptional(fetchSearchResultOptional) | ||
.normalizationTechnique(normalizationTechnique) | ||
.combinationTechnique(combinationTechnique) | ||
.explain(false) | ||
.fromValueForSingleShard(fromValueForSingleShard) | ||
.build(); | ||
execute(request); | ||
} | ||
|
@@ -95,6 +97,8 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request) | |
.scoreCombinationTechnique(request.getCombinationTechnique()) | ||
.querySearchResults(request.getQuerySearchResults()) | ||
.sort(evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs)) | ||
.fromValueForSingleShard(request.getFromValueForSingleShard()) | ||
.isFetchResultsPresent(request.getFetchSearchResultOptional().isPresent()) | ||
.build(); | ||
|
||
// combine | ||
|
@@ -104,7 +108,12 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request) | |
// post-process data | ||
log.debug("Post-process query results after score normalization and combination"); | ||
updateOriginalQueryResults(combineScoresDTO); | ||
updateOriginalFetchResults(request.getQuerySearchResults(), request.getFetchSearchResultOptional(), unprocessedDocIds); | ||
updateOriginalFetchResults( | ||
request.getQuerySearchResults(), | ||
request.getFetchSearchResultOptional(), | ||
unprocessedDocIds, | ||
request.getFromValueForSingleShard() | ||
); | ||
} | ||
|
||
/** | ||
|
@@ -177,15 +186,29 @@ private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO) | |
final List<QuerySearchResult> querySearchResults = combineScoresDTO.getQuerySearchResults(); | ||
final List<CompoundTopDocs> queryTopDocs = getCompoundTopDocs(combineScoresDTO, querySearchResults); | ||
final Sort sort = combineScoresDTO.getSort(); | ||
int totalScoreDocsCount = 0; | ||
for (int index = 0; index < querySearchResults.size(); index++) { | ||
QuerySearchResult querySearchResult = querySearchResults.get(index); | ||
CompoundTopDocs updatedTopDocs = queryTopDocs.get(index); | ||
totalScoreDocsCount += updatedTopDocs.getScoreDocs().size(); | ||
TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore( | ||
buildTopDocs(updatedTopDocs, sort), | ||
maxScoreForShard(updatedTopDocs, sort != null) | ||
); | ||
// The fetch phase has already run before the normalization phase, so update the from value in the result of each shard. | ||
// This ensures that the results are trimmed. | ||
if (combineScoresDTO.isFetchResultsPresent()) { | ||
querySearchResult.from(combineScoresDTO.getFromValueForSingleShard()); | ||
} | ||
querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats()); | ||
} | ||
|
||
final int from = querySearchResults.get(0).from(); | ||
if (from > 0 && from > totalScoreDocsCount) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The first check looks redundant; can't we rely only on `from > totalScoreDocsCount`? |
||
throw new IllegalArgumentException( | ||
String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results") | ||
); | ||
} | ||
} | ||
|
||
private List<CompoundTopDocs> getCompoundTopDocs(CombineScoresDto combineScoresDTO, List<QuerySearchResult> querySearchResults) { | ||
|
@@ -244,7 +267,8 @@ private TopDocs buildTopDocs(CompoundTopDocs updatedTopDocs, Sort sort) { | |
private void updateOriginalFetchResults( | ||
final List<QuerySearchResult> querySearchResults, | ||
final Optional<FetchSearchResult> fetchSearchResultOptional, | ||
final List<Integer> docIds | ||
final List<Integer> docIds, | ||
final int fromValueForSingleShard | ||
) { | ||
if (fetchSearchResultOptional.isEmpty()) { | ||
return; | ||
|
@@ -276,14 +300,21 @@ private void updateOriginalFetchResults( | |
|
||
QuerySearchResult querySearchResult = querySearchResults.get(0); | ||
TopDocs topDocs = querySearchResult.topDocs().topDocs; | ||
|
||
// When the normalization process executes before the fetch phase, from = 0 is applicable. | ||
// When the normalization process runs after the fetch phase, the search hits have already been fetched. Therefore, use the from value sent in the | ||
// search request. | ||
// iterate over the normalized/combined scores, that solves (1) and (3) | ||
SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> { | ||
SearchHit[] updatedSearchHitArray = new SearchHit[topDocs.scoreDocs.length - fromValueForSingleShard]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please pull |
||
for (int i = fromValueForSingleShard; i < topDocs.scoreDocs.length; i++) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. please change the semantics here: start from 0 and use (i + offset) when reading from topDocs |
||
ScoreDoc scoreDoc = topDocs.scoreDocs[i]; | ||
// get fetched hit content by doc_id | ||
SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc); | ||
// update score to normalized/combined value (3) | ||
searchHit.score(scoreDoc.score); | ||
return searchHit; | ||
}).toArray(SearchHit[]::new); | ||
updatedSearchHitArray[i - fromValueForSingleShard] = searchHit; | ||
} | ||
|
||
SearchHits updatedSearchHits = new SearchHits( | ||
updatedSearchHitArray, | ||
querySearchResult.getTotalHits(), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,16 +2,18 @@ | |
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.processor.combination; | ||
package org.opensearch.neuralsearch.processor.dto; | ||
|
||
import java.util.List; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.NonNull; | ||
import org.apache.lucene.search.Sort; | ||
import org.opensearch.common.Nullable; | ||
import org.opensearch.neuralsearch.processor.CompoundTopDocs; | ||
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; | ||
import org.opensearch.search.query.QuerySearchResult; | ||
|
||
/** | ||
|
@@ -29,4 +31,6 @@ public class CombineScoresDto { | |
private List<QuerySearchResult> querySearchResults; | ||
@Nullable | ||
private Sort sort; | ||
private int fromValueForSingleShard; | ||
private boolean isFetchResultsPresent; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it doesn't look right to put this field here, it's not related to combination. please find alternative solution |
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.processor.dto; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.NonNull; | ||
import org.opensearch.action.search.SearchPhaseContext; | ||
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; | ||
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; | ||
import org.opensearch.search.fetch.FetchSearchResult; | ||
import org.opensearch.search.query.QuerySearchResult; | ||
|
||
import java.util.List; | ||
import java.util.Optional; | ||
|
||
/** | ||
* DTO object that holds the data used | ||
* in the NormalizationProcessorWorkflow class. | ||
*/ | ||
@AllArgsConstructor | ||
@Builder | ||
@Getter | ||
public class NormalizationExecuteDto { | ||
@NonNull | ||
private List<QuerySearchResult> querySearchResults; | ||
@NonNull | ||
private Optional<FetchSearchResult> fetchSearchResultOptional; | ||
@NonNull | ||
private ScoreNormalizationTechnique normalizationTechnique; | ||
@NonNull | ||
private ScoreCombinationTechnique combinationTechnique; | ||
@NonNull | ||
private SearchPhaseContext searchPhaseContext; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
import java.util.Objects; | ||
import java.util.concurrent.Callable; | ||
|
||
import lombok.Getter; | ||
import org.apache.lucene.search.BooleanClause; | ||
import org.apache.lucene.search.BooleanQuery; | ||
import org.apache.lucene.search.IndexSearcher; | ||
|
@@ -31,20 +32,25 @@ | |
* Implementation of Query interface for type "hybrid". It allows execution of multiple sub-queries and collect individual | ||
* scores for each sub-query. | ||
*/ | ||
@Getter | ||
public final class HybridQuery extends Query implements Iterable<Query> { | ||
|
||
private final List<Query> subQueries; | ||
private Integer paginationDepth; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is this not a primitive int? Operating on a wrapper class is potentially error-prone when boxing/unboxing a null value. |
||
|
||
/** | ||
* Create new instance of hybrid query object based on collection of sub queries and filter query | ||
* @param subQueries collection of queries that are executed individually and contribute to a final list of combined scores | ||
* @param filterQueries list of filters that will be applied to each sub query. Each filter from the list is added as bool "filter" clause. If this is null sub queries will be executed as is | ||
*/ | ||
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries) { | ||
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries, final Integer paginationDepth) { | ||
Objects.requireNonNull(subQueries, "collection of queries must not be null"); | ||
if (subQueries.isEmpty()) { | ||
throw new IllegalArgumentException("collection of queries must not be empty"); | ||
} | ||
if (paginationDepth != null && paginationDepth == 0) { | ||
throw new IllegalArgumentException("pagination depth must not be zero"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. change to |
||
} | ||
if (Objects.isNull(filterQueries) || filterQueries.isEmpty()) { | ||
this.subQueries = new ArrayList<>(subQueries); | ||
} else { | ||
|
@@ -57,10 +63,11 @@ public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQ | |
} | ||
this.subQueries = modifiedSubQueries; | ||
} | ||
this.paginationDepth = paginationDepth; | ||
} | ||
|
||
public HybridQuery(final Collection<Query> subQueries) { | ||
this(subQueries, List.of()); | ||
public HybridQuery(final Collection<Query> subQueries, final Integer paginationDepth) { | ||
this(subQueries, List.of(), paginationDepth); | ||
} | ||
|
||
/** | ||
|
@@ -128,7 +135,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { | |
return super.rewrite(indexSearcher); | ||
} | ||
final List<Query> rewrittenSubQueries = manager.getQueriesAfterRewrite(collectors); | ||
return new HybridQuery(rewrittenSubQueries); | ||
return new HybridQuery(rewrittenSubQueries, paginationDepth); | ||
} | ||
|
||
private Void rewriteQuery(Query query, HybridQueryExecutorCollector<IndexSearcher, Map.Entry<Query, Boolean>> collector) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need `final`?