Skip to content

Commit

Permalink
Move rewrite logic for filter to doRewrite method
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Jul 24, 2024
1 parent b670b11 commit 7770ae3
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 7 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
* Return empty results for non-existent filter fields [#1874](https://github.com/opensearch-project/k-NN/pull/1874)
* Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874)
### Infrastructure
### Documentation
### Maintenance
Expand Down
17 changes: 11 additions & 6 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.IndexUtil;
Expand Down Expand Up @@ -485,7 +486,7 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
}

@Override
protected Query doToQuery(QueryShardContext context) throws IOException {
protected Query doToQuery(QueryShardContext context) {
MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName);

if (mappedFieldType == null && ignoreUnmapped) {
Expand Down Expand Up @@ -600,11 +601,6 @@ protected Query doToQuery(QueryShardContext context) throws IOException {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support filters", knnEngine));
}

// rewrite filter query if it exists to avoid runtime errors in next steps of query phase
if (Objects.nonNull(filter)) {
filter = filter.rewrite(context);
}

String indexName = context.index().getName();

if (k != 0) {
Expand Down Expand Up @@ -715,4 +711,13 @@ protected int doHashCode() {
public String getWriteableName() {
return NAME;
}

@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryShardContext) throws IOException {
// rewrite filter query if it exists to avoid runtime errors in next steps of query phase
if (Objects.nonNull(filter)) {
filter = filter.rewrite(queryShardContext);
}
return super.doRewrite(queryShardContext);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.KNNTestCase;
Expand Down Expand Up @@ -70,6 +71,8 @@ public class KNNQueryBuilderTests extends KNNTestCase {
private static final Float MIN_SCORE = 0.5f;
private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value");
private static final float[] QUERY_VECTOR = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
protected static final String TEXT_FIELD_NAME = "some_field";
protected static final String TEXT_VALUE = "some_value";

public void testInvalidK() {
float[] queryVector = { 1.0f, 1.0f };
Expand Down Expand Up @@ -1306,4 +1309,19 @@ public void testDoToQuery_whenBinaryWithInvalidDimension_thenException() throws
Exception ex = expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
assertTrue(ex.getMessage(), ex.getMessage().contains("invalid dimension"));
}

@SneakyThrows
public void testDoRewrite_whenNoFilter_thenSuccessful() {
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K);
QueryBuilder rewritten = knnQueryBuilder.doRewrite(mock(QueryRewriteContext.class));
assertEquals(knnQueryBuilder, rewritten);
}

@SneakyThrows
public void testDoRewrite_whenFilterSet_thenSuccessful() {
QueryBuilder filter = QueryBuilders.termQuery(TEXT_FIELD_NAME, TEXT_VALUE);
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K, filter);
QueryBuilder rewritten = knnQueryBuilder.doRewrite(mock(QueryRewriteContext.class));
assertEquals(knnQueryBuilder, rewritten);
}
}

0 comments on commit 7770ae3

Please sign in to comment.