Skip to content

Commit

Permalink
Fixed document source and score field mismatch in sorted hybrid queries (#1043) (#1057)

Browse files Browse the repository at this point in the history

* Fixed mismatch between document source and score fields when sorting is enabled in hybrid query


(cherry picked from commit 030e3f4)

Signed-off-by: Martin Gaievski <[email protected]>
Co-authored-by: Martin Gaievski <[email protected]>
  • Loading branch information
1 parent 51ddbb2 commit c27fa94
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 19 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013))
- Support empty string for fields in text embedding processor ([#1041](https://github.com/opensearch-project/neural-search/pull/1041))
### Bug Fixes
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040))
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040))
- Fixed document source and score field mismatch in sorted hybrid queries ([#1043](https://github.com/opensearch-project/neural-search/pull/1043))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import java.util.Objects;
import java.util.Locale;
import java.util.ArrayList;

import com.google.common.annotations.VisibleForTesting;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
Expand Down Expand Up @@ -44,8 +47,10 @@ public abstract class HybridTopFieldDocSortCollector implements HybridSearchColl
@Nullable
private FieldDoc after;
private FieldComparator<?> firstComparator;
// bottom would be set to null per shard.
private FieldValueHitQueue.Entry bottom;
// the array stores bottom elements of the min heap of sorted hits for each sub query
@Getter(AccessLevel.PACKAGE)
@VisibleForTesting
private FieldValueHitQueue.Entry fieldValueLeafTrackers[];
@Getter
private int totalHits;
protected int docBase;
Expand All @@ -65,6 +70,7 @@ public abstract class HybridTopFieldDocSortCollector implements HybridSearchColl
@Getter
protected float maxScore = 0.0f;
protected int[] collectedHits;
private boolean needsInitialization = true;

// searchSortPartOfIndexSort is used to evaluate whether to perform index sort or not.
private Boolean searchSortPartOfIndexSort = null;
Expand Down Expand Up @@ -203,7 +209,7 @@ protected void collectHit(int doc, int hitsCollected, int subQueryNumber, float
comparators[subQueryNumber].copy(slot, doc);
add(slot, doc, compoundScores[subQueryNumber], subQueryNumber, score);
if (queueFull[subQueryNumber]) {
comparators[subQueryNumber].setBottom(bottom.slot);
comparators[subQueryNumber].setBottom(fieldValueLeafTrackers[subQueryNumber].slot);
}
} else {
queueFull[subQueryNumber] = true;
Expand All @@ -216,9 +222,9 @@ protected void collectHit(int doc, int hitsCollected, int subQueryNumber, float
protected void collectCompetitiveHit(int doc, int subQueryNumber) throws IOException {
// This hit is competitive - replace bottom element in queue & adjustTop
if (numHits > 0) {
comparators[subQueryNumber].copy(bottom.slot, doc);
updateBottom(doc, compoundScores[subQueryNumber]);
comparators[subQueryNumber].setBottom(bottom.slot);
comparators[subQueryNumber].copy(fieldValueLeafTrackers[subQueryNumber].slot, doc);
updateBottom(doc, compoundScores[subQueryNumber], subQueryNumber);
comparators[subQueryNumber].setBottom(fieldValueLeafTrackers[subQueryNumber].slot);
}
}

Expand All @@ -245,14 +251,16 @@ protected boolean thresholdCheck(int doc, int subQueryNumber) throws IOException
The method initializes once per search request.
*/
protected void initializePriorityQueuesWithComparators(LeafReaderContext context, int numberOfSubQueries) throws IOException {
if (compoundScores == null) {
if (needsInitialization) {
compoundScores = new FieldValueHitQueue[numberOfSubQueries];
comparators = new LeafFieldComparator[numberOfSubQueries];
queueFull = new boolean[numberOfSubQueries];
collectedHits = new int[numberOfSubQueries];
for (int i = 0; i < numberOfSubQueries; i++) {
initializeLeafFieldComparators(context, i);
}
fieldValueLeafTrackers = new FieldValueHitQueue.Entry[numberOfSubQueries];
needsInitialization = false;
}
if (initializeLeafComparatorsPerSegmentOnce) {
for (int i = 0; i < numberOfSubQueries; i++) {
Expand Down Expand Up @@ -369,7 +377,7 @@ private void populateResults(ScoreDoc[] results, int howMany, PriorityQueue<Fiel
private void add(int slot, int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> compoundScore, int subQueryNumber, float score) {
FieldValueHitQueue.Entry bottomEntry = new FieldValueHitQueue.Entry(slot, docBase + doc);
bottomEntry.score = score;
bottom = compoundScore.add(bottomEntry);
fieldValueLeafTrackers[subQueryNumber] = compoundScore.add(bottomEntry);
// The queue is full either when totalHits == numHits (in SimpleFieldCollector), in which case
// slot = totalHits - 1, or when hitsCollected == numHits (in PagingFieldCollector this is hits
// on the current page) and slot = hitsCollected - 1.
Expand All @@ -381,9 +389,9 @@ private void add(int slot, int doc, FieldValueHitQueue<FieldValueHitQueue.Entry>
queueFull[subQueryNumber] = isQueueFull;
}

private void updateBottom(int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> compoundScore) {
bottom.doc = docBase + doc;
bottom = compoundScore.updateTop();
// Replaces the weakest (top-of-min-heap) entry for the given sub-query after a competitive hit:
// rebases the entry's doc to the segment-global id, re-heapifies, and records the new weakest
// entry so subsequent competitiveness checks compare against it. Tracking one entry per
// sub-query (rather than a single shared 'bottom') is the core of this fix — it prevents
// cross-sub-query corruption of the sort bottom when sorting is enabled.
private void updateBottom(int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> compoundScore, int subQueryIndex) {
    // docBase converts the segment-local doc id to the index-global id.
    fieldValueLeafTrackers[subQueryIndex].doc = docBase + doc;
    // updateTop() re-sorts the heap and returns the new least-competitive entry.
    fieldValueLeafTrackers[subQueryIndex] = compoundScore.updateTop();
}

private boolean canEarlyTerminate(Sort searchSort, Sort indexSort) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search;
package org.opensearch.neuralsearch.search.collector;

import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -24,8 +24,10 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.FieldValueHitQueue;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopFieldDocs;
Expand All @@ -35,14 +37,14 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import org.opensearch.common.util.io.IOUtils;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.neuralsearch.query.HybridQueryScorer;
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
import org.opensearch.neuralsearch.search.collector.HybridTopFieldDocSortCollector;
import org.opensearch.neuralsearch.search.collector.PagingFieldCollector;
import org.opensearch.neuralsearch.search.collector.SimpleFieldCollector;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;

public class HybridTopFieldDocSortCollectorTests extends OpenSearchQueryTestCase {
static final String TEXT_FIELD_NAME = "field";
Expand Down Expand Up @@ -127,8 +129,13 @@ public void testSimpleFieldCollectorTopDocs_whenCreateNewAndGetTopDocs_thenSucce
DocIdSetIterator iterator = hybridQueryScorer.iterator();

int doc = iterator.nextDoc();
assertNull(hybridTopFieldDocSortCollector.getFieldValueLeafTrackers());
while (doc != DocIdSetIterator.NO_MORE_DOCS) {
leafCollector.collect(doc);
FieldValueHitQueue.Entry[] fieldValueLeafTrackers = hybridTopFieldDocSortCollector.getFieldValueLeafTrackers();
assertNotNull(fieldValueLeafTrackers);
assertEquals(1, fieldValueLeafTrackers.length);
assertEquals(doc, fieldValueLeafTrackers[0].doc);
doc = iterator.nextDoc();
}

Expand Down Expand Up @@ -243,4 +250,70 @@ public void testPagingFieldCollectorTopDocs_whenCreateNewAndGetTopDocs_thenSucce
reader.close();
directory.close();
}

/**
 * Verifies that when a hybrid query has two sub-queries, the collector allocates one
 * bottom-entry tracker per sub-query (array of length 2) and keeps the array non-null
 * for every collected doc. This guards the fix for the document-source/score mismatch
 * in sorted hybrid queries, where a single shared bottom entry was corrupted across
 * sub-queries.
 */
public void testMultipleSubQueriesFieldValueLeafTrackers() throws Exception {
    final Directory directory = newDirectory();
    final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));

    // Create test documents
    List<Document> documents = new ArrayList<>();
    Document document1 = new Document();
    document1.add(new NumericDocValuesField("_id", 0)); // Use 0-based doc IDs
    document1.add(new IntField(INT_FIELD_NAME, 100, Field.Store.YES));
    document1.add(new TextField(TEXT_FIELD_NAME, FIELD_1_VALUE, Field.Store.YES));
    documents.add(document1);

    Document document2 = new Document();
    document2.add(new NumericDocValuesField("_id", 1)); // Use 0-based doc IDs
    document2.add(new IntField(INT_FIELD_NAME, 200, Field.Store.YES));
    document2.add(new TextField(TEXT_FIELD_NAME, FIELD_2_VALUE, Field.Store.YES));
    documents.add(document2);

    // Single commit so both docs land in one segment; the test reads leaf 0 only.
    w.addDocuments(documents);
    w.commit();

    DirectoryReader reader = DirectoryReader.open(w);
    LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);

    QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
    TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
    when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);

    // Setup collector
    SortField sortField = new SortField(DOC_FIELD_NAME, SortField.Type.DOC);
    HybridTopFieldDocSortCollector collector = new SimpleFieldCollector(
        NUM_DOCS,
        new HitsThresholdChecker(TOTAL_HITS_UP_TO),
        new Sort(sortField)
    );

    Weight weight = mock(Weight.class);
    collector.setWeight(weight);
    LeafCollector leafCollector = collector.getLeafCollector(leafReaderContext);

    // Create scorers with proper doc IDs (0-based)
    int[] docIdsForQuery = new int[] { 0, 1 }; // Use 0-based doc IDs
    final List<Float> scores = Arrays.asList(1.0f, 2.0f); // Fixed scores for predictability

    // Create two scorers for two sub-queries
    Scorer scorer1 = scorer(docIdsForQuery, scores, fakeWeight(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)));
    Scorer scorer2 = scorer(docIdsForQuery, scores, fakeWeight(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)));

    HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(weight, Arrays.asList(scorer1, scorer2));

    leafCollector.setScorer(hybridQueryScorer);

    // Collect docs
    DocIdSetIterator iterator = hybridQueryScorer.iterator();
    int doc = iterator.nextDoc();
    while (doc != DocIdSetIterator.NO_MORE_DOCS) {
        leafCollector.collect(doc);
        // After the first collect, one tracker slot must exist per sub-query (2 here).
        FieldValueHitQueue.Entry[] fieldValueLeafTrackers = collector.getFieldValueLeafTrackers();
        assertNotNull(fieldValueLeafTrackers);
        assertEquals(2, fieldValueLeafTrackers.length);
        doc = iterator.nextDoc();
    }

    IOUtils.close(reader, w, directory);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search;
package org.opensearch.neuralsearch.search.collector;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.LeafCollector;
Expand Down Expand Up @@ -46,7 +46,7 @@
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;

import lombok.SneakyThrows;
import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;

public class HybridTopScoreDocCollectorTests extends OpenSearchQueryTestCase {

Expand Down

0 comments on commit c27fa94

Please sign in to comment.