Skip to content

Commit

Permalink
Removing redundant type conversions for script scoring for hamming sp…
Browse files Browse the repository at this point in the history
…ace with binary vectors

Signed-off-by: Bansi Kasundra <[email protected]>
  • Loading branch information
kasundra07 committed Dec 23, 2024
1 parent dc369e6 commit f58ac9a
Show file tree
Hide file tree
Showing 7 changed files with 234 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,28 @@ public float[] getValue() {
}
}

public byte[] getByteValue() {
if (!docExists) {
String errorMessage = String.format(
"One of the document doesn't have a value for field '%s'. "
+ "This can be avoided by checking if a document has a value for the field or not "
+ "by doc['%s'].size() == 0 ? 0 : {your script}",
fieldName,
fieldName
);
throw new IllegalStateException(errorMessage);
}
try {
return doGetByteValue();
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}

protected byte[] doGetByteValue() throws IOException {
throw new UnsupportedOperationException();
}

protected abstract float[] doGetValue() throws IOException;

@Override
Expand Down Expand Up @@ -111,6 +133,15 @@ protected float[] doGetValue() throws IOException {
}
return value;
}

@Override
public byte[] doGetByteValue() {
try {
return values.vectorValue();
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}
}

private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues {
Expand Down Expand Up @@ -139,6 +170,15 @@ private static final class KNNNativeVectorScriptDocValues extends KNNVectorScrip
protected float[] doGetValue() throws IOException {
return getVectorDataType().getVectorFromBytesRef(values.binaryValue());
}

@Override
public byte[] doGetByteValue() {
try {
return values.binaryValue().bytes;
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}
}

/**
Expand Down
35 changes: 35 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,39 @@ public double execute(ScoreScript.ExplanationHolder explanationHolder) {
return this.scoringMethod.apply(this.queryValue, scriptDocValues.getValue());
}
}

/**
* KNNVectors with byte[] type. The query value passed in is expected to be byte[]. The fieldType of the docs
* being searched over are expected to be KNNVector type.
*/
public static class KNNByteVectorType extends KNNScoreScript<byte[]> {

public KNNByteVectorType(
Map<String, Object> params,
byte[] queryValue,
String field,
BiFunction<byte[], byte[], Float> scoringMethod,
SearchLookup lookup,
LeafReaderContext leafContext,
IndexSearcher searcher
) throws IOException {
super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher);
}

/**
* This function called for each doc in the segment. We evaluate the score of the vector in the doc
*
* @param explanationHolder A helper to take in an explanation from a script and turn
* it into an {@link org.apache.lucene.search.Explanation}
* @return score of the vector to the query vector
*/
@Override
public double execute(ScoreScript.ExplanationHolder explanationHolder) {
KNNVectorScriptDocValues scriptDocValues = (KNNVectorScriptDocValues) getDoc().get(this.field);
if (scriptDocValues.isEmpty()) {
return 0.0;
}
return this.scoringMethod.apply(this.queryValue, scriptDocValues.getByteValue());
}
}
}
50 changes: 37 additions & 13 deletions src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@

import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.getVectorMagnitudeSquared;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isBinaryFieldType;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isBinaryVectorDataType;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isKNNVectorFieldType;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isLongFieldType;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToBigInteger;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToByteArray;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToFloatArray;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToLong;

public interface KNNScoringSpace {

/**
* Return the correct scoring script for a given query. The scoring script
*
Expand Down Expand Up @@ -181,25 +184,46 @@ protected BiFunction<float[], float[], Float> getScoringMethod(final float[] pro
}
}

class Hamming extends KNNFieldSpace {
private static final Set<VectorDataType> DATA_TYPES_HAMMING = Set.of(VectorDataType.BINARY);
class Hamming implements KNNScoringSpace {
private byte[] processedQuery;
BiFunction<byte[], byte[], Float> scoringMethod;

public Hamming(Object query, MappedFieldType fieldType) {
super(query, fieldType, "hamming", DATA_TYPES_HAMMING);
if (!isKNNVectorFieldType(fieldType)) {
throw new IllegalArgumentException("Incompatible field_type for hamming space. The field type must be knn_vector.");
}
KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldType) fieldType;
if (!isBinaryVectorDataType(knnVectorFieldType)) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Incompatible field_type for hamming space. The data type should be [BINARY] but got %s",
knnVectorFieldType.getVectorDataType()
)
);
}

this.processedQuery = parseToByteArray(
query,
KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldType),
knnVectorFieldType.getVectorDataType()
);
this.scoringMethod = getHammingScoringMethod();
}

@Override
protected BiFunction<float[], float[], Float> getScoringMethod(final float[] processedQuery) {
// TODO we want to avoid converting back and forth between byte and float
return (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.calculateHammingBit(toByte(q), toByte(v)));
public BiFunction<byte[], byte[], Float> getHammingScoringMethod() {
return (byte[] q, byte[] v) -> 1 / (1 + KNNScoringUtil.calculateHammingBit(q, v));
}

private byte[] toByte(final float[] vector) {
byte[] bytes = new byte[vector.length];
for (int i = 0; i < vector.length; i++) {
bytes[i] = (byte) vector[i];
}
return bytes;
@Override
public ScoreScript getScoreScript(
Map<String, Object> params,
String field,
SearchLookup lookup,
LeafReaderContext ctx,
IndexSearcher searcher
) throws IOException {
return new KNNScoreScript.KNNByteVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,24 @@ public static float[] parseToFloatArray(Object object, int expectedVectorLength,
return floatArray;
}

/**
* Convert an Object to a byte array.
*
* @param object Object to be converted to a byte array
* @param expectedVectorLength int representing the expected vector length of this array.
* @return byte[] of the object
*/
public static byte[] parseToByteArray(Object object, int expectedVectorLength, VectorDataType vectorDataType) {
byte[] byteArray = convertVectorToByteArray(object, vectorDataType);
if (expectedVectorLength != byteArray.length) {
KNNCounter.SCRIPT_QUERY_ERRORS.increment();
throw new IllegalStateException(
"Object's length=" + byteArray.length + " does not match the " + "expected length=" + expectedVectorLength + "."
);
}
return byteArray;
}

/**
* Converts Object vector to primitive float[]
*
Expand All @@ -134,6 +152,29 @@ public static float[] convertVectorToPrimitive(Object vector, VectorDataType vec
return primitiveVector;
}

/**
* Converts Object vector to byte[]
*
* @param vector input vector
* @return Byte array representing the vector
*/
@SuppressWarnings("unchecked")
public static byte[] convertVectorToByteArray(Object vector, VectorDataType vectorDataType) {
byte[] byteVector = null;
if (vector != null) {
final List<Number> tmp = (List<Number>) vector;
byteVector = new byte[tmp.size()];
for (int i = 0; i < byteVector.length; i++) {
float value = tmp.get(i).floatValue();
if (VectorDataType.BYTE == vectorDataType || VectorDataType.BINARY == vectorDataType) {
validateByteVectorValue(value, vectorDataType);
}
byteVector[i] = tmp.get(i).byteValue();
}
}
return byteVector;
}

/**
* Calculates the magnitude of given vector
*
Expand Down
Loading

0 comments on commit f58ac9a

Please sign in to comment.