diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 664a4de3e9..06a239033c 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -24,6 +24,7 @@ import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.IndexUtil; @@ -485,7 +486,7 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio } @Override - protected Query doToQuery(QueryShardContext context) { + protected Query doToQuery(QueryShardContext context) throws IOException { MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName); if (mappedFieldType == null && ignoreUnmapped) { @@ -600,6 +601,11 @@ protected Query doToQuery(QueryShardContext context) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support filters", knnEngine)); } + // rewrite filter query if it exists + if (Objects.nonNull(filter)) { + filter = filter.rewrite(context); + } + String indexName = context.index().getName(); if (k != 0) { diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index b287396558..070ea0fbb3 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.query; import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; import org.apache.lucene.search.FloatVectorSimilarityQuery; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.MatchNoDocsQuery; @@ -485,6 +486,7 @@ public void testDoToQuery_Normal() throws Exception { assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } + @SneakyThrows public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() @@ -518,6 +520,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th ); } + @SneakyThrows public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; @@ -540,6 +543,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS assertTrue(query.toString().contains("resultSimilarity=" + 0.5f)); } + @SneakyThrows public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; @@ -602,6 +606,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupp expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } + @SneakyThrows public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float score = 5f; @@ -655,6 +660,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupp expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } + @SneakyThrows public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; @@ -774,6 +780,7 @@ public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception { assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); } + @SneakyThrows public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; @@ -802,6 +809,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_th assertTrue(query.getClass().isAssignableFrom(FloatVectorSimilarityQuery.class)); } + @SneakyThrows public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() @@ -828,6 +836,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenS assertTrue(query.getClass().isAssignableFrom(FloatVectorSimilarityQuery.class)); } + @SneakyThrows public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { // Given float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; @@ -904,6 +913,7 @@ public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException() expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } + @SneakyThrows public void testDoToQuery_FromModel() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); @@ -938,6 +948,7 @@ public void testDoToQuery_FromModel() { assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } + @SneakyThrows public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; @@ -979,6 +990,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } + @SneakyThrows public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; @@ -1233,6 +1245,7 @@ public void testRadialSearch_whenEfSearchIsSet_whenLuceneEngine_thenThrowExcepti expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } + @SneakyThrows public void testRadialSearch_whenEfSearchIsSet_whenFaissEngine_thenSuccess() { KNNMethodContext knnMethodContext = new KNNMethodContext( KNNEngine.FAISS,