
Integrate explainability for hybrid query into RRF processor #1037

Merged
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;

import java.util.Optional;

/**
* Base class for all score hybridization processors. This class is responsible for executing the score hybridization process.
* It is a pipeline processor that is executed after the query phase and before the fetch phase.
*/
public abstract class AbstractScoreHybridizationProcessor implements SearchPhaseResultsProcessor {
/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of the class constructor. This method is called when there is no pipeline context
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
) {
hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.empty());
}

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of the class constructor. This method is called when there is a pipeline context
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
* @param requestContext {@link PipelineProcessingContext} processing context of search pipeline
* @param <Result> type of the query phase result
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext,
final PipelineProcessingContext requestContext
) {
hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
}

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of the class constructor
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results; results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchPhaseContext} context of the current search phase
* @param requestContextOptional optional {@link PipelineProcessingContext} processing context of the search pipeline, empty when no pipeline context is available
* @param <Result> type of the query phase result
*/
abstract <Result extends SearchPhaseResult> void hybridizeScores(
SearchPhaseResults<Result> searchPhaseResult,
SearchPhaseContext searchPhaseContext,
Optional<PipelineProcessingContext> requestContextOptional
);
}
@@ -111,8 +111,9 @@ public SearchResponse processResponse(
);
}
// Create and set final explanation combining all components
Float finalScore = Float.isNaN(searchHit.getScore()) ? 0.0f : searchHit.getScore();
Member Author commented:

When sorting by a field other than score:

  1. searchHit.score will be Float.NaN
  2. score in the search hit and in actual response will be null

We can't pass null as a value for the explanation object. Therefore, I've set it to 0.0 in these cases. This ensures that we always have a valid numeric value for the explanation, even when the score isn't the primary sorting factor.

Explanation finalExplanation = Explanation.match(
searchHit.getScore(),
finalScore,
// combination level explanation is always a single detail
combinationExplanation.getScoreDetails().get(0).getValue(),
normalizedExplanation
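To make the reasoning in the comment above concrete, the following is a minimal, hypothetical sketch (the class name and wiring are illustrative, not part of the PR): when a request sorts on a field other than _score, Lucene reports the hit score as Float.NaN, and Explanation.match needs a concrete numeric value, so the processor falls back to 0.0f.

import org.apache.lucene.search.Explanation;

public class NaNScoreGuardSketch {
    public static void main(String[] args) {
        // What searchHit.getScore() returns when the request sorts by a field other than _score
        float rawScore = Float.NaN;

        // Same fallback as in the processor: substitute 0.0f so the explanation always has a valid number
        float finalScore = Float.isNaN(rawScore) ? 0.0f : rawScore;

        Explanation explanation = Explanation.match(finalScore, "combined score of hybrid query");
        System.out.println(explanation); // prints: 0.0 = combined score of hybrid query
    }
}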
@@ -19,9 +19,7 @@
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.query.QuerySearchResult;

import lombok.AllArgsConstructor;
@@ -33,7 +31,7 @@
*/
@Log4j2
@AllArgsConstructor
public class NormalizationProcessor implements SearchPhaseResultsProcessor {
public class NormalizationProcessor extends AbstractScoreHybridizationProcessor {
public static final String TYPE = "normalization-processor";

private final String tag;
@@ -42,38 +40,8 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor {
private final ScoreCombinationTechnique combinationTechnique;
private final NormalizationProcessorWorkflow normalizationWorkflow;

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor. This method is called when there is no pipeline context
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
) {
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.empty());
}

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
* @param requestContext {@link PipelineProcessingContext} processing context of search pipeline
* @param <Result>
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext,
final PipelineProcessingContext requestContext
) {
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
}

private <Result extends SearchPhaseResult> void prepareAndExecuteNormalizationWorkflow(
<Result extends SearchPhaseResult> void hybridizeScores(
SearchPhaseResults<Result> searchPhaseResult,
SearchPhaseContext searchPhaseContext,
Optional<PipelineProcessingContext> requestContextOptional
@@ -12,11 +12,12 @@
import java.util.Objects;
import java.util.Optional;

import com.google.common.annotations.VisibleForTesting;
import lombok.Getter;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.query.QuerySearchResult;

import org.opensearch.action.search.QueryPhaseResultConsumer;
@@ -25,7 +26,6 @@
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
@@ -39,7 +39,7 @@
*/
@Log4j2
@AllArgsConstructor
public class RRFProcessor implements SearchPhaseResultsProcessor {
public class RRFProcessor extends AbstractScoreHybridizationProcessor {
public static final String TYPE = "score-ranker-processor";

@Getter
@@ -57,25 +57,28 @@ public class RRFProcessor implements SearchPhaseResultsProcessor {
* @param searchPhaseContext {@link SearchContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
<Result extends SearchPhaseResult> void hybridizeScores(
SearchPhaseResults<Result> searchPhaseResult,
SearchPhaseContext searchPhaseContext,
Optional<PipelineProcessingContext> requestContextOptional
) {
if (shouldSkipProcessor(searchPhaseResult)) {
log.debug("Query results are not compatible with RRF processor");
return;
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);

boolean explain = Objects.nonNull(searchPhaseContext.getRequest().source().explain())
&& searchPhaseContext.getRequest().source().explain();
// build the data transfer object to pass in; execute() will receive an object with 4 or 5 fields,
// depending on whether the call comes from NormalizationProcessor or RRFProcessor
NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder()
.querySearchResults(querySearchResults)
.fetchSearchResultOptional(fetchSearchResult)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.explain(false)
.explain(explain)
.pipelineProcessingContext(requestContextOptional.orElse(null))
.build();
normalizationWorkflow.execute(normalizationExecuteDTO);
}
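One subtle point in the change above: the explain flag is now derived from the search request instead of being hard-coded to false. The Boolean returned by source().explain() is null when the request never sets it, which is why the Objects.nonNull check precedes the unboxing. A small standalone illustration of that null-safe pattern (hypothetical, using a plain Boolean rather than the real request object):

import java.util.Objects;

public class ExplainFlagSketch {
    public static void main(String[] args) {
        // Simulates searchPhaseContext.getRequest().source().explain() when "explain" was not set
        Boolean requestedExplain = null;

        // Short-circuits on the null check, so there is no NullPointerException from unboxing
        boolean explain = Objects.nonNull(requestedExplain) && requestedExplain;
        System.out.println(explain); // false: explanations stay off unless explicitly requested
    }
}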
@@ -6,27 +6,39 @@

import lombok.ToString;
import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;

import java.util.Map;
import java.util.List;
import java.util.Objects;

import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique;

@Log4j2
/**
* Abstracts combination of scores based on reciprocal rank fusion algorithm
*/
@ToString(onlyExplicitlyIncluded = true)
public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique {
public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique {
@ToString.Include
public static final String TECHNIQUE_NAME = "rrf";

// Not currently using weights for RRF, no need to modify or verify these params
public RRFScoreCombinationTechnique(final Map<String, Object> params, final ScoreCombinationUtil combinationUtil) {}
Member Author commented:

those parameters were never used

public RRFScoreCombinationTechnique() {}

@Override
public float combine(final float[] scores) {
if (Objects.isNull(scores)) {
throw new IllegalArgumentException("scores array cannot be null");
}
float sumScores = 0.0f;
for (float score : scores) {
sumScores += score;
}
return sumScores;
}

@Override
public String describe() {
return describeCombinationTechnique(TECHNIQUE_NAME, List.of());
}
}
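As a worked illustration of what combine receives in the RRF flow (assuming the default rank constant of 60 used by RRFNormalizationTechnique): a document ranked 1st in one subquery and 3rd in another arrives with the reciprocal-rank scores 1/61 and 1/63, and the combined score is simply their sum, roughly 0.0323. A hypothetical standalone sketch:

public class RRFCombineSketch {
    public static void main(String[] args) {
        // Reciprocal-rank scores produced upstream by the normalization step for one document:
        // rank 1 in the first subquery, rank 3 in the second, rank_constant = 60
        float[] scores = { 1.0f / (60 + 1), 1.0f / (60 + 3) };

        // Same summation as RRFScoreCombinationTechnique#combine
        float combined = 0.0f;
        for (float score : scores) {
            combined += score;
        }
        System.out.printf("combined RRF score = %.4f%n", combined); // prints approximately 0.0323
    }
}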
@@ -27,7 +27,7 @@ public class ScoreCombinationFactory {
GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil),
RRFScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new RRFScoreCombinationTechnique(params, scoreCombinationUtil)
params -> new RRFScoreCombinationTechnique()
);

/**
@@ -6,27 +6,34 @@

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Locale;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.stream.IntStream;

import org.apache.commons.lang3.Range;
import org.apache.commons.lang3.math.NumberUtils;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;

import lombok.ToString;
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;

import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization;

/**
* Abstracts calculation of rank scores for each document returned as part of
* reciprocal rank fusion. Rank scores are summed across subqueries in combination classes.
*/
@ToString(onlyExplicitlyIncluded = true)
public class RRFNormalizationTechnique implements ScoreNormalizationTechnique {
public class RRFNormalizationTechnique implements ScoreNormalizationTechnique, ExplainableTechnique {
@ToString.Include
public static final String TECHNIQUE_NAME = "rrf";
public static final int DEFAULT_RANK_CONSTANT = 60;
@@ -58,21 +65,49 @@ public RRFNormalizationTechnique(final Map<String, Object> params, final ScoreNo
public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
final List<CompoundTopDocs> queryTopDocs = normalizeScoresDTO.getQueryTopDocs();
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
for (TopDocs topDocs : topDocsPerSubQuery) {
int docsCountPerSubQuery = topDocs.scoreDocs.length;
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
for (int j = 0; j < docsCountPerSubQuery; j++) {
// using big decimal approach to minimize error caused by floating point ops
// score = 1.f / (float) (rankConstant + j + 1))
scoreDocs[j].score = BigDecimal.ONE.divide(BigDecimal.valueOf(rankConstant + j + 1), 10, RoundingMode.HALF_UP)
.floatValue();
}
}
processTopDocs(compoundQueryTopDocs, (docId, score) -> {});
}
}

@Override
public String describe() {
return String.format(Locale.ROOT, "%s, rank_constant [%s]", TECHNIQUE_NAME, rankConstant);
}

@Override
public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs> queryTopDocs) {
Map<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();

for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
processTopDocs(
compoundQueryTopDocs,
(docId, score) -> normalizedScores.computeIfAbsent(docId, k -> new ArrayList<>()).add(score)
);
}

return getDocIdAtQueryForNormalization(normalizedScores, this);
}

private void processTopDocs(CompoundTopDocs compoundQueryTopDocs, BiConsumer<DocIdAtSearchShard, Float> scoreProcessor) {
if (Objects.isNull(compoundQueryTopDocs)) {
return;
}

compoundQueryTopDocs.getTopDocs().forEach(topDocs -> {
IntStream.range(0, topDocs.scoreDocs.length).forEach(position -> {
float normalizedScore = calculateNormalizedScore(position);
DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(
topDocs.scoreDocs[position].doc,
compoundQueryTopDocs.getSearchShard()
);
scoreProcessor.accept(docIdAtSearchShard, normalizedScore);
topDocs.scoreDocs[position].score = normalizedScore;
});
});
}

private float calculateNormalizedScore(int position) {
return BigDecimal.ONE.divide(BigDecimal.valueOf(rankConstant + position + 1), 10, RoundingMode.HALF_UP).floatValue();
}

private int getRankConstant(final Map<String, Object> params) {
@@ -96,7 +131,7 @@ private void validateRankConstant(final int rankConstant) {
}
}

public static int getParamAsInteger(final Map<String, Object> parameters, final String fieldName) {
private static int getParamAsInteger(final Map<String, Object> parameters, final String fieldName) {
try {
return NumberUtils.createInteger(String.valueOf(parameters.get(fieldName)));
} catch (NumberFormatException e) {
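For completeness, a minimal standalone sketch (hypothetical class, not part of the PR) reproducing the per-position score computed by calculateNormalizedScore above, with the default rank_constant of 60. The BigDecimal division at ten decimal places mirrors the original intent of limiting floating-point error before narrowing to float.

import java.math.BigDecimal;
import java.math.RoundingMode;

public class RRFNormalizationSketch {
    public static void main(String[] args) {
        int rankConstant = 60; // DEFAULT_RANK_CONSTANT

        // score(position) = 1 / (rankConstant + position + 1)
        for (int position = 0; position < 3; position++) {
            float score = BigDecimal.ONE
                .divide(BigDecimal.valueOf(rankConstant + position + 1), 10, RoundingMode.HALF_UP)
                .floatValue();
            System.out.printf("position %d -> %.6f%n", position, score);
        }
        // Output: position 0 -> 0.016393, position 1 -> 0.016129, position 2 -> 0.015873
    }
}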