Skip to content

Commit

Permalink
[Enhancement] Implement pruning for neural sparse search (#988)
Browse files Browse the repository at this point in the history
* add impl

Signed-off-by: zhichao-aws <[email protected]>

* add UT

Signed-off-by: zhichao-aws <[email protected]>

* rename pruneType; UT

Signed-off-by: zhichao-aws <[email protected]>

* changelog

Signed-off-by: zhichao-aws <[email protected]>

* ut

Signed-off-by: zhichao-aws <[email protected]>

* add it

Signed-off-by: zhichao-aws <[email protected]>

* change on 2-phase

Signed-off-by: zhichao-aws <[email protected]>

* UT

Signed-off-by: zhichao-aws <[email protected]>

* it

Signed-off-by: zhichao-aws <[email protected]>

* rename

Signed-off-by: zhichao-aws <[email protected]>

* enhance: more detailed error message

Signed-off-by: zhichao-aws <[email protected]>

* refactor to prune and split

Signed-off-by: zhichao-aws <[email protected]>

* changelog

Signed-off-by: zhichao-aws <[email protected]>

* fix UT cov

Signed-off-by: zhichao-aws <[email protected]>

* address review comments

Signed-off-by: zhichao-aws <[email protected]>

* enlarge score diff range

Signed-off-by: zhichao-aws <[email protected]>

* address comments: check lowScores non null instead of flag

Signed-off-by: zhichao-aws <[email protected]>

---------

Signed-off-by: zhichao-aws <[email protected]>
(cherry picked from commit e8fe284)
  • Loading branch information
zhichao-aws committed Dec 18, 2024
1 parent 6481b60 commit dde124f
Show file tree
Hide file tree
Showing 18 changed files with 1,197 additions and 140 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988))
### Bug Fixes
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,22 @@
import lombok.Getter;
import lombok.Setter;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.collect.Tuple;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.neuralsearch.util.prune.PruneType;
import org.opensearch.neuralsearch.util.prune.PruneUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.rescore.QueryRescorerBuilder;
import org.opensearch.search.rescore.RescorerBuilder;

import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
* A SearchRequestProcessor to generate two-phase NeuralSparseQueryBuilder,
Expand All @@ -37,41 +36,37 @@ public class NeuralSparseTwoPhaseProcessor extends AbstractProcessor implements

public static final String TYPE = "neural_sparse_two_phase_processor";
private boolean enabled;
private float ratio;
private float pruneRatio;
private PruneType pruneType;
private float windowExpansion;
private int maxWindowSize;
private static final String PARAMETER_KEY = "two_phase_parameter";
private static final String RATIO_KEY = "prune_ratio";
private static final String ENABLE_KEY = "enabled";
private static final String EXPANSION_KEY = "expansion_rate";
private static final String MAX_WINDOW_SIZE_KEY = "max_window_size";
private static final boolean DEFAULT_ENABLED = true;
private static final float DEFAULT_RATIO = 0.4f;
private static final PruneType DEFAULT_PRUNE_TYPE = PruneType.MAX_RATIO;
private static final float DEFAULT_WINDOW_EXPANSION = 5.0f;
private static final int DEFAULT_MAX_WINDOW_SIZE = 10000;
private static final int DEFAULT_BASE_QUERY_SIZE = 10;
private static final int MAX_WINDOWS_SIZE_LOWER_BOUND = 50;
private static final float WINDOW_EXPANSION_LOWER_BOUND = 1.0f;
private static final float RATIO_LOWER_BOUND = 0f;
private static final float RATIO_UPPER_BOUND = 1f;

protected NeuralSparseTwoPhaseProcessor(
String tag,
String description,
boolean ignoreFailure,
boolean enabled,
float ratio,
float pruneRatio,
PruneType pruneType,
float windowExpansion,
int maxWindowSize
) {
super(tag, description, ignoreFailure);
this.enabled = enabled;
if (ratio < RATIO_LOWER_BOUND || ratio > RATIO_UPPER_BOUND) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "The two_phase_parameter.prune_ratio must be within [0, 1]. Received: %f", ratio)
);
}
this.ratio = ratio;
this.pruneRatio = pruneRatio;
this.pruneType = pruneType;
if (windowExpansion < WINDOW_EXPANSION_LOWER_BOUND) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "The two_phase_parameter.expansion_rate must >= 1.0. Received: %f", windowExpansion)
Expand All @@ -93,7 +88,7 @@ protected NeuralSparseTwoPhaseProcessor(
*/
@Override
public SearchRequest processRequest(final SearchRequest request) {
if (!enabled || ratio == 0f) {
if (!enabled || pruneRatio == 0f) {
return request;
}
QueryBuilder queryBuilder = request.source().query();
Expand All @@ -117,43 +112,6 @@ public String getType() {
return TYPE;
}

/**
* Based on ratio, split a Map into two map by the value.
*
* @param queryTokens the queryTokens map, key is the token String, value is the score.
* @param thresholdRatio The ratio that control how tokens map be split.
* @return A tuple has two element, { token map whose value above threshold, token map whose value below threshold }
*/
public static Tuple<Map<String, Float>, Map<String, Float>> splitQueryTokensByRatioedMaxScoreAsThreshold(
final Map<String, Float> queryTokens,
final float thresholdRatio
) {
if (Objects.isNull(queryTokens)) {
throw new IllegalArgumentException("Query tokens cannot be null or empty.");
}
float max = 0f;
for (Float value : queryTokens.values()) {
max = Math.max(value, max);
}
float threshold = max * thresholdRatio;

Map<Boolean, Map<String, Float>> queryTokensByScore = queryTokens.entrySet()
.stream()
.collect(
Collectors.partitioningBy(entry -> entry.getValue() >= threshold, Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))
);

Map<String, Float> highScoreTokens = queryTokensByScore.get(Boolean.TRUE);
Map<String, Float> lowScoreTokens = queryTokensByScore.get(Boolean.FALSE);
if (Objects.isNull(highScoreTokens)) {
highScoreTokens = Collections.emptyMap();
}
if (Objects.isNull(lowScoreTokens)) {
lowScoreTokens = Collections.emptyMap();
}
return Tuple.tuple(highScoreTokens, lowScoreTokens);
}

private QueryBuilder getNestedQueryBuilderFromNeuralSparseQueryBuilderMap(
final Multimap<NeuralSparseQueryBuilder, Float> queryBuilderFloatMap
) {
Expand Down Expand Up @@ -201,7 +159,10 @@ private Multimap<NeuralSparseQueryBuilder, Float> collectNeuralSparseQueryBuilde
* - Docs besides TopDocs: Score = HighScoreToken's score
* - Final TopDocs: Score = HighScoreToken's score + LowScoreToken's score
*/
NeuralSparseQueryBuilder modifiedQueryBuilder = neuralSparseQueryBuilder.getCopyNeuralSparseQueryBuilderForTwoPhase(ratio);
NeuralSparseQueryBuilder modifiedQueryBuilder = neuralSparseQueryBuilder.getCopyNeuralSparseQueryBuilderForTwoPhase(
pruneRatio,
pruneType
);
result.put(modifiedQueryBuilder, updatedBoost);
}
// We only support BoostQuery, BooleanQuery and NeuralSparseQuery now. For other compound query type which are not support now, will
Expand Down Expand Up @@ -248,16 +209,40 @@ public NeuralSparseTwoPhaseProcessor create(
boolean enabled = ConfigurationUtils.readBooleanProperty(TYPE, tag, config, ENABLE_KEY, DEFAULT_ENABLED);
Map<String, Object> twoPhaseConfigMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, PARAMETER_KEY);

float ratio = DEFAULT_RATIO;
float pruneRatio = DEFAULT_RATIO;
float windowExpansion = DEFAULT_WINDOW_EXPANSION;
int maxWindowSize = DEFAULT_MAX_WINDOW_SIZE;
PruneType pruneType = DEFAULT_PRUNE_TYPE;
if (Objects.nonNull(twoPhaseConfigMap)) {
ratio = ((Number) twoPhaseConfigMap.getOrDefault(RATIO_KEY, ratio)).floatValue();
pruneRatio = ((Number) twoPhaseConfigMap.getOrDefault(PruneUtils.PRUNE_RATIO_FIELD, pruneRatio)).floatValue();
windowExpansion = ((Number) twoPhaseConfigMap.getOrDefault(EXPANSION_KEY, windowExpansion)).floatValue();
maxWindowSize = ((Number) twoPhaseConfigMap.getOrDefault(MAX_WINDOW_SIZE_KEY, maxWindowSize)).intValue();
pruneType = PruneType.fromString(
twoPhaseConfigMap.getOrDefault(PruneUtils.PRUNE_TYPE_FIELD, pruneType.getValue()).toString()
);
}
if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Illegal prune_ratio %f for prune_type: %s. %s",
pruneRatio,
pruneType.getValue(),
PruneUtils.getValidPruneRatioDescription(pruneType)
)
);
}

return new NeuralSparseTwoPhaseProcessor(tag, description, ignoreFailure, enabled, ratio, windowExpansion, maxWindowSize);
return new NeuralSparseTwoPhaseProcessor(
tag,
description,
ignoreFailure,
enabled,
pruneRatio,
pruneType,
windowExpansion,
maxWindowSize
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@
import java.util.function.BiConsumer;
import java.util.function.Consumer;

import lombok.Getter;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.prune.PruneType;
import org.opensearch.neuralsearch.util.TokenWeightUtil;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.util.prune.PruneUtils;

/**
* This processor is used for user input data text sparse encoding processing, model_id can be used to indicate which model user use,
Expand All @@ -27,18 +30,26 @@ public final class SparseEncodingProcessor extends InferenceProcessor {

public static final String TYPE = "sparse_encoding";
public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding";
@Getter
private final PruneType pruneType;
@Getter
private final float pruneRatio;

public SparseEncodingProcessor(
String tag,
String description,
int batchSize,
String modelId,
Map<String, Object> fieldMap,
PruneType pruneType,
float pruneRatio,
MLCommonsClientAccessor clientAccessor,
Environment environment,
ClusterService clusterService
) {
super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
this.pruneType = pruneType;
this.pruneRatio = pruneRatio;
}

@Override
Expand All @@ -49,17 +60,23 @@ public void doExecute(
BiConsumer<IngestDocument, Exception> handler
) {
mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps));
List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
.stream()
.map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector))
.toList();
setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
}

@Override
public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
mlCommonsClientAccessor.inferenceSentencesWithMapResult(
this.modelId,
inferenceList,
ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException)
);
mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
.stream()
.map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector))
.toList();
handler.accept(sparseVectors);
}, onException));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@

import static org.opensearch.ingest.ConfigurationUtils.readMap;
import static org.opensearch.ingest.ConfigurationUtils.readStringProperty;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE;
import static org.opensearch.ingest.ConfigurationUtils.readOptionalStringProperty;
import static org.opensearch.ingest.ConfigurationUtils.readDoubleProperty;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD;
import static org.opensearch.neuralsearch.processor.SparseEncodingProcessor.TYPE;

import java.util.Locale;
import java.util.Map;

import org.opensearch.cluster.service.ClusterService;
Expand All @@ -19,6 +22,8 @@
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.util.prune.PruneUtils;
import org.opensearch.neuralsearch.util.prune.PruneType;

/**
* Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
Expand All @@ -40,7 +45,40 @@ public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, En
protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map<String, Object> config) {
String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD);
Map<String, Object> fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
// if the field is miss, will return PruneType.None
PruneType pruneType = PruneType.fromString(readOptionalStringProperty(TYPE, tag, config, PruneUtils.PRUNE_TYPE_FIELD));
float pruneRatio = 0;
if (pruneType != PruneType.NONE) {
// if we have prune type, then prune ratio field must have value
// readDoubleProperty will throw exception if value is not present
pruneRatio = readDoubleProperty(TYPE, tag, config, PruneUtils.PRUNE_RATIO_FIELD).floatValue();
if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Illegal prune_ratio %f for prune_type: %s. %s",
pruneRatio,
pruneType.getValue(),
PruneUtils.getValidPruneRatioDescription(pruneType)
)
);
}
} else if (config.containsKey(PruneUtils.PRUNE_RATIO_FIELD)) {
// if we don't have prune type, then prune ratio field must not have value
throw new IllegalArgumentException("prune_ratio field is not supported when prune_type is not provided");
}

return new SparseEncodingProcessor(tag, description, batchSize, modelId, fieldMap, clientAccessor, environment, clusterService);
return new SparseEncodingProcessor(
tag,
description,
batchSize,
modelId,
fieldMap,
pruneType,
pruneRatio,
clientAccessor,
environment,
clusterService
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
import lombok.NoArgsConstructor;
import lombok.Setter;
import lombok.experimental.Accessors;

import static org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor.splitQueryTokensByRatioedMaxScoreAsThreshold;
import org.opensearch.neuralsearch.util.prune.PruneType;
import org.opensearch.neuralsearch.util.prune.PruneUtils;

/**
* SparseEncodingQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML NEURAL_SPARSE model
Expand Down Expand Up @@ -90,6 +90,7 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQ
// 2. If it's the sub query only build for two-phase, the value will be set to -1 * ratio of processor.
// Then in the DoToQuery, we can use this to determine which type are this queryBuilder.
private float twoPhasePruneRatio = 0F;
private PruneType twoPhasePruneType = PruneType.NONE;

private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0;

Expand Down Expand Up @@ -129,22 +130,23 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {

/**
* Copy this QueryBuilder for two phase rescorer, set the copy one's twoPhasePruneRatio to -1.
* @param ratio the parameter of the NeuralSparseTwoPhaseProcessor, control how to split the queryTokens to two phase.
* @param pruneRatio the parameter of the NeuralSparseTwoPhaseProcessor, control how to split the queryTokens to two phase.
* @return A copy NeuralSparseQueryBuilder for twoPhase, it will be added to the rescorer.
*/
public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float ratio) {
this.twoPhasePruneRatio(ratio);
public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float pruneRatio, PruneType pruneType) {
this.twoPhasePruneRatio(pruneRatio);
this.twoPhasePruneType(pruneType);
NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder().fieldName(this.fieldName)
.queryName(this.queryName)
.queryText(this.queryText)
.modelId(this.modelId)
.maxTokenScore(this.maxTokenScore)
.twoPhasePruneRatio(-1f * ratio);
.twoPhasePruneRatio(-1f * pruneRatio);
if (Objects.nonNull(this.queryTokensSupplier)) {
Map<String, Float> tokens = queryTokensSupplier.get();
// Splitting tokens based on a threshold value: tokens greater than the threshold are stored in v1,
// while those less than or equal to the threshold are stored in v2.
Tuple<Map<String, Float>, Map<String, Float>> splitTokens = splitQueryTokensByRatioedMaxScoreAsThreshold(tokens, ratio);
Tuple<Map<String, Float>, Map<String, Float>> splitTokens = PruneUtils.splitSparseVector(pruneType, pruneRatio, tokens);
this.queryTokensSupplier(() -> splitTokens.v1());
copy.queryTokensSupplier(() -> splitTokens.v2());
} else {
Expand Down Expand Up @@ -346,9 +348,10 @@ private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map
ActionListener.wrap(mapResultList -> {
Map<String, Float> queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0);
if (Objects.nonNull(twoPhaseSharedQueryToken)) {
Tuple<Map<String, Float>, Map<String, Float>> splitQueryTokens = splitQueryTokensByRatioedMaxScoreAsThreshold(
queryTokens,
twoPhasePruneRatio
Tuple<Map<String, Float>, Map<String, Float>> splitQueryTokens = PruneUtils.splitSparseVector(
twoPhasePruneType,
twoPhasePruneRatio,
queryTokens
);
setOnce.set(splitQueryTokens.v1());
twoPhaseSharedQueryToken = splitQueryTokens.v2();
Expand Down
Loading

0 comments on commit dde124f

Please sign in to comment.