Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Pagination in hybrid query #963

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
### Features
- Pagination in Hybrid query ([#963](https://github.com/opensearch-project/neural-search/pull/963))
### Enhancements
- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
- Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@

import org.opensearch.index.query.MatchQueryBuilder;

import static org.opensearch.knn.index.query.KNNQueryBuilder.EXPAND_NESTED_FIELD;
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion;
import static org.opensearch.neuralsearch.util.TestUtils.getModelId;
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public final class MinClusterVersionUtil {

private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0;
private static final Version MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY = Version.V_2_19_0;

// Note this minimal version will act as a override
private static final Map<String, Version> MINIMAL_VERSION_NEURAL = ImmutableMap.<String, Version>builder()
Expand All @@ -38,6 +39,10 @@ public static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH);
}

public static boolean isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY);
}

public static boolean isClusterOnOrAfterMinReqVersion(String key) {
Version version;
if (MINIMAL_VERSION_NEURAL.containsKey(key)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,23 @@ private <Result extends SearchPhaseResult> void prepareAndExecuteNormalizationWo
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);

boolean explain = Objects.nonNull(searchPhaseContext.getRequest().source().explain())
&& searchPhaseContext.getRequest().source().explain();

int fromValueForSingleShard = -1;
if (searchPhaseContext.getNumShards() == 1 && fetchSearchResult.isPresent()) {
fromValueForSingleShard = searchPhaseContext.getRequest().source().from();
}

NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder()
.querySearchResults(querySearchResults)
.fetchSearchResultOptional(fetchSearchResult)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.explain(explain)
.pipelineProcessingContext(requestContextOptional.orElse(null))
.fromValueForSingleShard(fromValueForSingleShard)
.build();
normalizationWorkflow.execute(request);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.FieldDoc;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
import org.opensearch.neuralsearch.processor.dto.CombineScoresDto;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails;
Expand Down Expand Up @@ -56,22 +56,24 @@ public class NormalizationProcessorWorkflow {

/**
* Start execution of this workflow
* @param querySearchResults input data with QuerySearchResult from multiple shards
* @param normalizationTechnique technique for score normalization
* @param combinationTechnique technique for score combination
* @param normalizationExecuteDto contains querySearchResults input data with QuerySearchResult
* from multiple shards, fetchSearchResultOptional, normalizationTechnique technique for score normalization
* combinationTechnique technique for score combination, searchPhaseContext.
*/
public void execute(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
final ScoreCombinationTechnique combinationTechnique,
final int fromValueForSingleShard
) {
NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder()
.querySearchResults(querySearchResults)
.fetchSearchResultOptional(fetchSearchResultOptional)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.explain(false)
.fromValueForSingleShard(fromValueForSingleShard)
.build();
execute(request);
}
Expand All @@ -95,6 +97,8 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request)
.scoreCombinationTechnique(request.getCombinationTechnique())
.querySearchResults(request.getQuerySearchResults())
.sort(evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs))
.fromValueForSingleShard(request.getFromValueForSingleShard())
.isFetchResultsPresent(request.getFetchSearchResultOptional().isPresent())
.build();

// combine
Expand All @@ -104,7 +108,12 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request)
// post-process data
log.debug("Post-process query results after score normalization and combination");
updateOriginalQueryResults(combineScoresDTO);
updateOriginalFetchResults(request.getQuerySearchResults(), request.getFetchSearchResultOptional(), unprocessedDocIds);
updateOriginalFetchResults(
request.getQuerySearchResults(),
request.getFetchSearchResultOptional(),
unprocessedDocIds,
request.getFromValueForSingleShard()
);
}

/**
Expand Down Expand Up @@ -177,15 +186,29 @@ private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO)
final List<QuerySearchResult> querySearchResults = combineScoresDTO.getQuerySearchResults();
final List<CompoundTopDocs> queryTopDocs = getCompoundTopDocs(combineScoresDTO, querySearchResults);
final Sort sort = combineScoresDTO.getSort();
int totalScoreDocsCount = 0;
for (int index = 0; index < querySearchResults.size(); index++) {
QuerySearchResult querySearchResult = querySearchResults.get(index);
CompoundTopDocs updatedTopDocs = queryTopDocs.get(index);
totalScoreDocsCount += updatedTopDocs.getScoreDocs().size();
TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(
buildTopDocs(updatedTopDocs, sort),
maxScoreForShard(updatedTopDocs, sort != null)
);
// Fetch Phase had ran before the normalization phase, therefore update the from value in result of each shard.
// This will ensure the trimming of the results.
if (combineScoresDTO.isFetchResultsPresent()) {
querySearchResult.from(combineScoresDTO.getFromValueForSingleShard());
}
querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats());
}

final int from = querySearchResults.get(0).from();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need final?

if (from > 0 && from > totalScoreDocsCount) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

first check looks redundant, can't we rely only on from > totalScoreDocsCount?

throw new IllegalArgumentException(
String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results")
);
}
}

private List<CompoundTopDocs> getCompoundTopDocs(CombineScoresDto combineScoresDTO, List<QuerySearchResult> querySearchResults) {
Expand Down Expand Up @@ -244,7 +267,8 @@ private TopDocs buildTopDocs(CompoundTopDocs updatedTopDocs, Sort sort) {
private void updateOriginalFetchResults(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final List<Integer> docIds
final List<Integer> docIds,
final int fromValueForSingleShard
) {
if (fetchSearchResultOptional.isEmpty()) {
return;
Expand Down Expand Up @@ -276,14 +300,21 @@ private void updateOriginalFetchResults(

QuerySearchResult querySearchResult = querySearchResults.get(0);
TopDocs topDocs = querySearchResult.topDocs().topDocs;

// When normalization process will execute before the fetch phase, then from =0 is applicable.
// When normalization process runs after fetch phase, then search hits already fetched. Therefore, use the from value sent in the
// search request.
// iterate over the normalized/combined scores, that solves (1) and (3)
SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> {
SearchHit[] updatedSearchHitArray = new SearchHit[topDocs.scoreDocs.length - fromValueForSingleShard];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please pull topDocs.scoreDocs.length - fromValueForSingleShard expression to a variable and give it meaningful name

for (int i = fromValueForSingleShard; i < topDocs.scoreDocs.length; i++) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please change the semantic here, start from 0 and do (i + offset) when you're reading from topDocs

ScoreDoc scoreDoc = topDocs.scoreDocs[i];
// get fetched hit content by doc_id
SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
// update score to normalized/combined value (3)
searchHit.score(scoreDoc.score);
return searchHit;
}).toArray(SearchHit[]::new);
updatedSearchHitArray[i - fromValueForSingleShard] = searchHit;
}

SearchHits updatedSearchHits = new SearchHits(
updatedSearchHitArray,
querySearchResult.getTotalHits(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ public class NormalizationProcessorWorkflowExecuteRequest {
final ScoreCombinationTechnique combinationTechnique;
boolean explain;
final PipelineProcessingContext pipelineProcessingContext;
int fromValueForSingleShard;
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.neuralsearch.processor.SearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.dto.CombineScoresDto;

/**
* Abstracts combination of scores in query search results.
Expand Down Expand Up @@ -70,14 +71,10 @@ public class ScoreCombiner {
public void combineScores(final CombineScoresDto combineScoresDTO) {
// iterate over results from each shard. Every CompoundTopDocs object has results from
// multiple sub queries, doc ids may repeat for each sub query results
ScoreCombinationTechnique scoreCombinationTechnique = combineScoresDTO.getScoreCombinationTechnique();
Sort sort = combineScoresDTO.getSort();
combineScoresDTO.getQueryTopDocs()
.forEach(
compoundQueryTopDocs -> combineShardScores(
combineScoresDTO.getScoreCombinationTechnique(),
compoundQueryTopDocs,
combineScoresDTO.getSort()
)
);
.forEach(compoundQueryTopDocs -> combineShardScores(scoreCombinationTechnique, compoundQueryTopDocs, sort));
}

private void combineShardScores(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.combination;
package org.opensearch.neuralsearch.processor.dto;

import java.util.List;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.apache.lucene.search.Sort;
import org.opensearch.common.Nullable;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.search.query.QuerySearchResult;

/**
Expand All @@ -29,4 +31,6 @@ public class CombineScoresDto {
private List<QuerySearchResult> querySearchResults;
@Nullable
private Sort sort;
private int fromValueForSingleShard;
private boolean isFetchResultsPresent;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it doesn't look right to put this field here, it's not related to combination. please find alternative solution

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.dto;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;

import java.util.List;
import java.util.Optional;

/**
* DTO object to hold data in NormalizationProcessorWorkflow class
* in NormalizationProcessorWorkflow.
*/
@AllArgsConstructor
@Builder
@Getter
public class NormalizationExecuteDto {
@NonNull
private List<QuerySearchResult> querySearchResults;
@NonNull
private Optional<FetchSearchResult> fetchSearchResultOptional;
@NonNull
private ScoreNormalizationTechnique normalizationTechnique;
@NonNull
private ScoreCombinationTechnique combinationTechnique;
@NonNull
private SearchPhaseContext searchPhaseContext;
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.Objects;
import java.util.concurrent.Callable;

import lombok.Getter;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
Expand All @@ -31,20 +32,25 @@
* Implementation of Query interface for type "hybrid". It allows execution of multiple sub-queries and collect individual
* scores for each sub-query.
*/
@Getter
public final class HybridQuery extends Query implements Iterable<Query> {

private final List<Query> subQueries;
private Integer paginationDepth;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this is not primitive int? Operating with wrapper class is potentially error prone when boxing/unboxing a null value.


/**
* Create new instance of hybrid query object based on collection of sub queries and filter query
* @param subQueries collection of queries that are executed individually and contribute to a final list of combined scores
* @param filterQueries list of filters that will be applied to each sub query. Each filter from the list is added as bool "filter" clause. If this is null sub queries will be executed as is
*/
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries) {
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries, final Integer paginationDepth) {
Objects.requireNonNull(subQueries, "collection of queries must not be null");
if (subQueries.isEmpty()) {
throw new IllegalArgumentException("collection of queries must not be empty");
}
if (paginationDepth != null && paginationDepth == 0) {
throw new IllegalArgumentException("pagination depth must not be zero");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change to pagination_depth in error message to signify that this is a parameter

}
if (Objects.isNull(filterQueries) || filterQueries.isEmpty()) {
this.subQueries = new ArrayList<>(subQueries);
} else {
Expand All @@ -57,10 +63,11 @@ public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQ
}
this.subQueries = modifiedSubQueries;
}
this.paginationDepth = paginationDepth;
}

public HybridQuery(final Collection<Query> subQueries) {
this(subQueries, List.of());
public HybridQuery(final Collection<Query> subQueries, final Integer paginationDepth) {
this(subQueries, List.of(), paginationDepth);
}

/**
Expand Down Expand Up @@ -128,7 +135,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
return super.rewrite(indexSearcher);
}
final List<Query> rewrittenSubQueries = manager.getQueriesAfterRewrite(collectors);
return new HybridQuery(rewrittenSubQueries);
return new HybridQuery(rewrittenSubQueries, paginationDepth);
}

private Void rewriteQuery(Query query, HybridQueryExecutorCollector<IndexSearcher, Map.Entry<Query, Boolean>> collector) {
Expand Down
Loading
Loading