From c27fa94a1912769077631f1978b5eba127b9bc78 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Fri, 3 Jan 2025 15:47:00 -0800 Subject: [PATCH] Fixed document source and score field mismatch in sorted hybrid queries (#1043) (#1057) * Fixed mismatch between document source and score fields when sorting is enabled in hybrid query (cherry picked from commit 030e3f4d1712f2f085123546a1261be508fa35d7) Signed-off-by: Martin Gaievski Co-authored-by: Martin Gaievski --- CHANGELOG.md | 5 +- .../HybridTopFieldDocSortCollector.java | 30 ++++--- .../HybridTopFieldDocSortCollectorTests.java | 81 ++++++++++++++++++- .../HybridTopScoreDocCollectorTests.java | 4 +- 4 files changed, 101 insertions(+), 19 deletions(-) rename src/test/java/org/opensearch/neuralsearch/search/{ => collector}/HybridTopFieldDocSortCollectorTests.java (76%) rename src/test/java/org/opensearch/neuralsearch/search/{ => collector}/HybridTopScoreDocCollectorTests.java (99%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6462b9a4f..d10dae25d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,8 +21,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013)) - Support empty string for fields in text embedding processor ([#1041](https://github.com/opensearch-project/neural-search/pull/1041)) ### Bug Fixes -- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998)) -- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040)) +- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998)) +- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040)) +- Fixed document source and score field mismatch in sorted hybrid queries ([#1043](https://github.com/opensearch-project/neural-search/pull/1043)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollector.java b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollector.java index 2e268d37b..60a82ee33 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollector.java +++ b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollector.java @@ -9,6 +9,9 @@ import java.util.Objects; import java.util.Locale; import java.util.ArrayList; + +import com.google.common.annotations.VisibleForTesting; +import lombok.AccessLevel; import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; @@ -44,8 +47,10 @@ public abstract class HybridTopFieldDocSortCollector implements HybridSearchColl @Nullable private FieldDoc after; private FieldComparator firstComparator; - // bottom would be set to null per shard. - private FieldValueHitQueue.Entry bottom; + // the array stores bottom elements of the min heap of sorted hits for each sub query + @Getter(AccessLevel.PACKAGE) + @VisibleForTesting + private FieldValueHitQueue.Entry fieldValueLeafTrackers[]; @Getter private int totalHits; protected int docBase; @@ -65,6 +70,7 @@ public abstract class HybridTopFieldDocSortCollector implements HybridSearchColl @Getter protected float maxScore = 0.0f; protected int[] collectedHits; + private boolean needsInitialization = true; // searchSortPartOfIndexSort is used to evaluate whether to perform index sort or not. private Boolean searchSortPartOfIndexSort = null; @@ -203,7 +209,7 @@ protected void collectHit(int doc, int hitsCollected, int subQueryNumber, float comparators[subQueryNumber].copy(slot, doc); add(slot, doc, compoundScores[subQueryNumber], subQueryNumber, score); if (queueFull[subQueryNumber]) { - comparators[subQueryNumber].setBottom(bottom.slot); + comparators[subQueryNumber].setBottom(fieldValueLeafTrackers[subQueryNumber].slot); } } else { queueFull[subQueryNumber] = true; @@ -216,9 +222,9 @@ protected void collectHit(int doc, int hitsCollected, int subQueryNumber, float protected void collectCompetitiveHit(int doc, int subQueryNumber) throws IOException { // This hit is competitive - replace bottom element in queue & adjustTop if (numHits > 0) { - comparators[subQueryNumber].copy(bottom.slot, doc); - updateBottom(doc, compoundScores[subQueryNumber]); - comparators[subQueryNumber].setBottom(bottom.slot); + comparators[subQueryNumber].copy(fieldValueLeafTrackers[subQueryNumber].slot, doc); + updateBottom(doc, compoundScores[subQueryNumber], subQueryNumber); + comparators[subQueryNumber].setBottom(fieldValueLeafTrackers[subQueryNumber].slot); } } @@ -245,7 +251,7 @@ protected boolean thresholdCheck(int doc, int subQueryNumber) throws IOException The method initializes once per search request. */ protected void initializePriorityQueuesWithComparators(LeafReaderContext context, int numberOfSubQueries) throws IOException { - if (compoundScores == null) { + if (needsInitialization) { compoundScores = new FieldValueHitQueue[numberOfSubQueries]; comparators = new LeafFieldComparator[numberOfSubQueries]; queueFull = new boolean[numberOfSubQueries]; @@ -253,6 +259,8 @@ protected void initializePriorityQueuesWithComparators(LeafReaderContext context for (int i = 0; i < numberOfSubQueries; i++) { initializeLeafFieldComparators(context, i); } + fieldValueLeafTrackers = new FieldValueHitQueue.Entry[numberOfSubQueries]; + needsInitialization = false; } if (initializeLeafComparatorsPerSegmentOnce) { for (int i = 0; i < numberOfSubQueries; i++) { @@ -369,7 +377,7 @@ private void populateResults(ScoreDoc[] results, int howMany, PriorityQueue compoundScore, int subQueryNumber, float score) { FieldValueHitQueue.Entry bottomEntry = new FieldValueHitQueue.Entry(slot, docBase + doc); bottomEntry.score = score; - bottom = compoundScore.add(bottomEntry); + fieldValueLeafTrackers[subQueryNumber] = compoundScore.add(bottomEntry); // The queue is full either when totalHits == numHits (in SimpleFieldCollector), in which case // slot = totalHits - 1, or when hitsCollected == numHits (in PagingFieldCollector this is hits // on the current page) and slot = hitsCollected - 1. @@ -381,9 +389,9 @@ private void add(int slot, int doc, FieldValueHitQueue queueFull[subQueryNumber] = isQueueFull; } - private void updateBottom(int doc, FieldValueHitQueue compoundScore) { - bottom.doc = docBase + doc; - bottom = compoundScore.updateTop(); + private void updateBottom(int doc, FieldValueHitQueue compoundScore, int subQueryIndex) { + fieldValueLeafTrackers[subQueryIndex].doc = docBase + doc; + fieldValueLeafTrackers[subQueryIndex] = compoundScore.updateTop(); } private boolean canEarlyTerminate(Sort searchSort, Sort indexSort) { diff --git a/src/test/java/org/opensearch/neuralsearch/search/HybridTopFieldDocSortCollectorTests.java b/src/test/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollectorTests.java similarity index 76% rename from src/test/java/org/opensearch/neuralsearch/search/HybridTopFieldDocSortCollectorTests.java rename to src/test/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollectorTests.java index 3bb0e6bcd..02fbb2673 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/HybridTopFieldDocSortCollectorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollectorTests.java @@ -2,7 +2,7 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.search; +package org.opensearch.neuralsearch.search.collector; import java.util.ArrayList; import java.util.Arrays; @@ -24,8 +24,10 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.FieldValueHitQueue; import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopFieldDocs; @@ -35,14 +37,14 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; + +import org.opensearch.common.util.io.IOUtils; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; import org.opensearch.neuralsearch.query.HybridQueryScorer; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; -import org.opensearch.neuralsearch.search.collector.HybridTopFieldDocSortCollector; -import org.opensearch.neuralsearch.search.collector.PagingFieldCollector; -import org.opensearch.neuralsearch.search.collector.SimpleFieldCollector; +import org.opensearch.neuralsearch.search.HitsThresholdChecker; public class HybridTopFieldDocSortCollectorTests extends OpenSearchQueryTestCase { static final String TEXT_FIELD_NAME = "field"; @@ -127,8 +129,13 @@ public void testSimpleFieldCollectorTopDocs_whenCreateNewAndGetTopDocs_thenSucce DocIdSetIterator iterator = hybridQueryScorer.iterator(); int doc = iterator.nextDoc(); + assertNull(hybridTopFieldDocSortCollector.getFieldValueLeafTrackers()); while (doc != DocIdSetIterator.NO_MORE_DOCS) { leafCollector.collect(doc); + FieldValueHitQueue.Entry[] fieldValueLeafTrackers = hybridTopFieldDocSortCollector.getFieldValueLeafTrackers(); + assertNotNull(fieldValueLeafTrackers); + assertEquals(1, fieldValueLeafTrackers.length); + assertEquals(doc, fieldValueLeafTrackers[0].doc); doc = iterator.nextDoc(); } @@ -243,4 +250,70 @@ public void testPagingFieldCollectorTopDocs_whenCreateNewAndGetTopDocs_thenSucce reader.close(); directory.close(); } + + public void testMultipleSubQueriesFieldValueLeafTrackers() throws Exception { + final Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + + // Create test documents + List documents = new ArrayList<>(); + Document document1 = new Document(); + document1.add(new NumericDocValuesField("_id", 0)); // Use 0-based doc IDs + document1.add(new IntField(INT_FIELD_NAME, 100, Field.Store.YES)); + document1.add(new TextField(TEXT_FIELD_NAME, FIELD_1_VALUE, Field.Store.YES)); + documents.add(document1); + + Document document2 = new Document(); + document2.add(new NumericDocValuesField("_id", 1)); // Use 0-based doc IDs + document2.add(new IntField(INT_FIELD_NAME, 200, Field.Store.YES)); + document2.add(new TextField(TEXT_FIELD_NAME, FIELD_2_VALUE, Field.Store.YES)); + documents.add(document2); + + w.addDocuments(documents); + w.commit(); + + DirectoryReader reader = DirectoryReader.open(w); + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + // Setup collector + SortField sortField = new SortField(DOC_FIELD_NAME, SortField.Type.DOC); + HybridTopFieldDocSortCollector collector = new SimpleFieldCollector( + NUM_DOCS, + new HitsThresholdChecker(TOTAL_HITS_UP_TO), + new Sort(sortField) + ); + + Weight weight = mock(Weight.class); + collector.setWeight(weight); + LeafCollector leafCollector = collector.getLeafCollector(leafReaderContext); + + // Create scorers with proper doc IDs (0-based) + int[] docIdsForQuery = new int[] { 0, 1 }; // Use 0-based doc IDs + final List scores = Arrays.asList(1.0f, 2.0f); // Fixed scores for predictability + + // Create two scorers for two sub-queries + Scorer scorer1 = scorer(docIdsForQuery, scores, fakeWeight(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext))); + Scorer scorer2 = scorer(docIdsForQuery, scores, fakeWeight(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext))); + + HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(weight, Arrays.asList(scorer1, scorer2)); + + leafCollector.setScorer(hybridQueryScorer); + + // Collect docs + DocIdSetIterator iterator = hybridQueryScorer.iterator(); + int doc = iterator.nextDoc(); + while (doc != DocIdSetIterator.NO_MORE_DOCS) { + leafCollector.collect(doc); + FieldValueHitQueue.Entry[] fieldValueLeafTrackers = collector.getFieldValueLeafTrackers(); + assertNotNull(fieldValueLeafTrackers); + assertEquals(2, fieldValueLeafTrackers.length); + doc = iterator.nextDoc(); + } + + IOUtils.close(reader, w, directory); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java b/src/test/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollectorTests.java similarity index 99% rename from src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java rename to src/test/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollectorTests.java index 1fb66d5b7..4de32f6f2 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollectorTests.java @@ -2,7 +2,7 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.search; +package org.opensearch.neuralsearch.search.collector; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.LeafCollector; @@ -46,7 +46,7 @@ import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import lombok.SneakyThrows; -import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector; +import org.opensearch.neuralsearch.search.HitsThresholdChecker; public class HybridTopScoreDocCollectorTests extends OpenSearchQueryTestCase {