Skip to content

Commit

Permalink
Python: Adding Vector Search to the In Memory collection (#9574)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->
Adds vectorized search and text search to the In Memory connector.

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

---------

Co-authored-by: Tao Chen <[email protected]>
  • Loading branch information
eavanvalkenburg and TaoChenOSU authored Nov 7, 2024
1 parent daafde4 commit 1db12c4
Show file tree
Hide file tree
Showing 12 changed files with 267 additions and 76 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/python-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ jobs:
- name: Install the project
run: uv sync --all-extras --dev
- uses: pre-commit/[email protected]
name: Run Pre-Commit Hooks
with:
extra_args: --config python/.pre-commit-config.yaml --all-files
- name: Run Mypy
run: uv run mypy -p semantic_kernel --config-file mypy.ini
- name: Minimize uv cache
run: uv cache prune --ci
11 changes: 1 addition & 10 deletions python/.pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,10 @@ repos:
- id: ruff-format
- repo: https://github.com/astral-sh/uv-pre-commit
# uv version.
rev: 0.4.29
rev: 0.4.30
hooks:
# Update the uv lockfile
- id: uv-lock
- repo: local
hooks:
- id: mypy
files: ^python/semantic_kernel/
name: mypy
entry: bash -c 'cd python && uv run mypy -p semantic_kernel --config-file mypy.ini'
language: system
types: [python]
pass_filenames: true
- repo: https://github.com/PyCQA/bandit
rev: 1.7.8
hooks:
Expand Down
9 changes: 5 additions & 4 deletions python/.vscode/tasks.json
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,11 @@
"command": "uv",
"args": [
"run",
"pre-commit",
"run",
"-a",
"mypy"
"mypy",
"-p",
"semantic_kernel",
"--config-file",
"mypy.ini"
],
"problemMatcher": {
"owner": "python",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from semantic_kernel.data.vector_search.vector_text_search import VectorTextSearchMixin
from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin
from semantic_kernel.exceptions import MemoryConnectorException, MemoryConnectorInitializationError
from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata
from semantic_kernel.utils.experimental_decorator import experimental_class

logger: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -304,39 +303,6 @@ def _build_filter_string(self, search_filter: VectorSearchFilter) -> str:
filter_string = filter_string[:-5]
return filter_string

@staticmethod
def _default_parameter_metadata() -> list[KernelParameterMetadata]:
"""Default parameter metadata for text search functions.
This function should be overridden when necessary.
"""
return [
KernelParameterMetadata(
name="query",
description="What to search for.",
type="str",
is_required=False,
default_value="*",
type_object=str,
),
KernelParameterMetadata(
name="count",
description="Number of results to return.",
type="int",
is_required=False,
default_value=2,
type_object=int,
),
KernelParameterMetadata(
name="skip",
description="Number of results to skip.",
type="int",
is_required=False,
default_value=0,
type_object=int,
),
]

@override
def _get_record_from_result(self, result: dict[str, Any]) -> dict[str, Any]:
    """Return the search result dict unchanged: results are already plain record dicts."""
    return result
Expand Down
21 changes: 21 additions & 0 deletions python/semantic_kernel/connectors/memory/in_memory/const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Microsoft. All rights reserved.


from collections.abc import Callable
from typing import Any

from numpy import dot
from scipy.spatial.distance import cityblock, cosine, euclidean, hamming, sqeuclidean

from semantic_kernel.data.const import DistanceFunction

# Maps each supported DistanceFunction to the scipy/numpy callable that computes it.
# NOTE: COSINE_SIMILARITY intentionally maps to scipy's cosine *distance*; the
# in-memory collection converts it back to a similarity (1 - distance) when searching.
# DOT_PROD maps to numpy.dot, a similarity (higher is better) — the caller sorts
# those results in descending order.
DISTANCE_FUNCTION_MAP: dict[DistanceFunction | str, Callable[..., Any]] = {
    DistanceFunction.COSINE_DISTANCE: cosine,
    DistanceFunction.COSINE_SIMILARITY: cosine,
    DistanceFunction.EUCLIDEAN_DISTANCE: euclidean,
    DistanceFunction.EUCLIDEAN_SQUARED_DISTANCE: sqeuclidean,
    DistanceFunction.MANHATTAN: cityblock,
    DistanceFunction.HAMMING: hamming,
    DistanceFunction.DOT_PROD: dot,
    # Fallback used when a vector field does not specify a distance function.
    "default": cosine,
}
Original file line number Diff line number Diff line change
@@ -1,26 +1,45 @@
# Copyright (c) Microsoft. All rights reserved.

import sys
from collections.abc import Mapping, Sequence
from collections.abc import AsyncIterable, Callable, Mapping, Sequence
from typing import Any, ClassVar, TypeVar

from pydantic import Field

from semantic_kernel.data.filter_clauses.any_tags_equal_to_filter_clause import AnyTagsEqualTo
from semantic_kernel.data.filter_clauses.equal_to_filter_clause import EqualTo

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
else:
from typing_extensions import override # pragma: no cover

from pydantic import Field

from semantic_kernel.connectors.memory.in_memory.const import DISTANCE_FUNCTION_MAP
from semantic_kernel.data.const import DistanceFunction
from semantic_kernel.data.filter_clauses.filter_clause_base import FilterClauseBase
from semantic_kernel.data.kernel_search_results import KernelSearchResults
from semantic_kernel.data.record_definition.vector_store_model_definition import VectorStoreRecordDefinition
from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection
from semantic_kernel.data.record_definition.vector_store_record_fields import (
VectorStoreRecordVectorField,
)
from semantic_kernel.data.vector_search.vector_search import VectorSearchBase
from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult
from semantic_kernel.data.vector_search.vector_text_search import VectorTextSearchMixin
from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin
from semantic_kernel.exceptions import VectorSearchExecutionException, VectorStoreModelValidationError
from semantic_kernel.kernel_types import OneOrMany

KEY_TYPES = str | int | float

TModel = TypeVar("TModel")

IN_MEMORY_SCORE_KEY = "in_memory_search_score"


class InMemoryVectorCollection(VectorStoreRecordCollection[KEY_TYPES, TModel]):
class InMemoryVectorCollection(
VectorSearchBase[KEY_TYPES, TModel], VectorTextSearchMixin[TModel], VectorizedSearchMixin[TModel]
):
"""In Memory Collection."""

inner_storage: dict[KEY_TYPES, dict] = Field(default_factory=dict)
Expand All @@ -39,6 +58,12 @@ def __init__(
collection_name=collection_name,
)

def _validate_data_model(self):
    """Run the base validation, then ensure the reserved internal score field is unused.

    Raises:
        VectorStoreModelValidationError: if the data model declares a field named
            IN_MEMORY_SCORE_KEY, which this collection reserves for search scores.
    """
    super()._validate_data_model()
    reserved = IN_MEMORY_SCORE_KEY
    if reserved in self.data_model_definition.field_names:
        raise VectorStoreModelValidationError(f"Field name '{reserved}' is reserved for internal use.")

@override
async def _inner_delete(self, keys: Sequence[KEY_TYPES], **kwargs: Any) -> None:
for key in keys:
Expand Down Expand Up @@ -74,3 +99,139 @@ async def delete_collection(self, **kwargs: Any) -> None:
@override
async def does_collection_exist(self, **kwargs: Any) -> bool:
return True

@override
async def _inner_search(
    self,
    options: VectorSearchOptions | None = None,
    search_text: str | None = None,
    vectorizable_text: str | None = None,
    vector: list[float | int] | None = None,
    **kwargs: Any,
) -> KernelSearchResults[VectorSearchResult[TModel]]:
    """Dispatch to the text or vectorized search implementation.

    Text search takes precedence when both search_text and vector are supplied.
    NOTE(review): vectorizable_text is accepted for interface compatibility but
    is not handled here — confirm against the mixin contract.
    """
    if search_text:
        return await self._inner_search_text(search_text, options, **kwargs)
    if not vector:
        raise VectorSearchExecutionException("Search text or vector must be provided.")
    if not options:
        raise VectorSearchExecutionException("Options must be provided for vector search.")
    return await self._inner_search_vectorized(vector, options, **kwargs)

async def _inner_search_text(
    self,
    search_text: str,
    options: VectorSearchOptions | None = None,
    **kwargs: Any,
) -> KernelSearchResults[VectorSearchResult[TModel]]:
    """Search for records whose non-vector fields contain search_text.

    Every matching record receives a fixed score of 1.0; filtering from the
    options is applied before matching.
    """
    hits: dict[KEY_TYPES, float] = {
        key: 1.0
        for key, record in self._get_filtered_records(options).items()
        if self._should_add_text_search(search_text, record)
    }
    if not hits:
        return KernelSearchResults(results=None)
    want_count = bool(options and options.include_total_count)
    return KernelSearchResults(
        results=self._get_vector_search_results_from_results(self._generate_return_list(hits, options)),
        total_count=len(hits) if want_count else None,
    )

async def _inner_search_vectorized(
    self,
    vector: list[float | int],
    options: VectorSearchOptions,
    **kwargs: Any,
) -> KernelSearchResults[VectorSearchResult[TModel]]:
    """Score every (filtered) record against the query vector and return sorted results.

    The distance function is looked up from the vector field's distance_function
    (falling back to "default", i.e. cosine). Similarity metrics sort descending,
    distance metrics ascending.
    """
    return_records: dict[KEY_TYPES, float] = {}
    if not options.vector_field_name:
        raise ValueError("Vector field name must be provided in options for vector search.")
    field = options.vector_field_name
    # Sanity check only — stripped under -O; validation proper happens in the data model.
    assert isinstance(self.data_model_definition.fields.get(field), VectorStoreRecordVectorField)  # nosec
    distance_metric = self.data_model_definition.fields.get(field).distance_function or "default"  # type: ignore
    distance_func = DISTANCE_FUNCTION_MAP[distance_metric]

    for key, record in self._get_filtered_records(options).items():
        if vector and field is not None:
            # COSINE_SIMILARITY uses the cosine *distance* function, so the score
            # is inverted (1 - distance) to turn it back into a similarity.
            return_records[key] = self._calculate_vector_similarity(
                vector,
                record[field],
                distance_func,
                invert_score=distance_metric == DistanceFunction.COSINE_SIMILARITY,
            )
    # Similarity metrics: higher is better (descending); distances: lower is better.
    if distance_metric in [DistanceFunction.COSINE_SIMILARITY, DistanceFunction.DOT_PROD]:
        sorted_records = dict(sorted(return_records.items(), key=lambda item: item[1], reverse=True))
    else:
        sorted_records = dict(sorted(return_records.items(), key=lambda item: item[1]))
    if sorted_records:
        return KernelSearchResults(
            results=self._get_vector_search_results_from_results(
                self._generate_return_list(sorted_records, options)
            ),
            total_count=len(return_records) if options and options.include_total_count else None,
        )
    return KernelSearchResults(results=None)

async def _generate_return_list(
    self, return_records: dict[KEY_TYPES, float], options: VectorSearchOptions | None
) -> AsyncIterable[dict]:
    """Yield up to ``top`` scored records, skipping the first ``skip`` entries.

    Each yielded dict is a shallow copy of the stored record with the score added
    under IN_MEMORY_SCORE_KEY. Copying is the fix here: writing the score directly
    into ``self.inner_storage[key]`` would permanently pollute the stored record
    with the reserved internal key, leaking it into later gets and searches.
    """
    top = 3 if not options else options.top
    skip = 0 if not options else options.skip
    returned = 0
    for idx, key in enumerate(return_records.keys()):
        if idx >= skip:
            returned += 1
            # Shallow copy so the internal score key never touches inner_storage.
            rec = dict(self.inner_storage[key])
            rec[IN_MEMORY_SCORE_KEY] = return_records[key]
            yield rec
            if returned >= top:
                break

def _get_filtered_records(self, options: VectorSearchOptions | None) -> dict[KEY_TYPES, dict]:
    """Return the records that satisfy EVERY filter clause in the options.

    All clauses are AND-ed together (previously only the first clause was
    evaluated, because the loop returned on its first iteration). With no
    options, no filter, or an empty clause list, the full storage is returned.
    """
    if options and options.filter and options.filter.filters:
        return {
            key: record
            for key, record in self.inner_storage.items()
            if all(self._apply_filter(record, clause) for clause in options.filter.filters)
        }
    return self.inner_storage

def _should_add_text_search(self, search_text: str, record: dict) -> bool:
    """Return True when search_text occurs in any non-vector field of the record."""
    non_vector_fields = (
        field
        for field in self.data_model_definition.fields.values()
        if not isinstance(field, VectorStoreRecordVectorField)
    )
    return any(search_text in record.get(field.name, "") for field in non_vector_fields)

def _calculate_vector_similarity(
self,
search_vector: list[float | int],
record_vector: list[float | int],
distance_func: Callable,
invert_score: bool = False,
) -> float:
calc = distance_func(record_vector, search_vector)
if invert_score:
return 1.0 - float(calc)
return float(calc)

@staticmethod
def _apply_filter(record: dict[str, Any], filter: FilterClauseBase) -> bool:
match filter:
case EqualTo():
value = record.get(filter.field_name)
if not value:
return False
return value.lower() == filter.value.lower()
case AnyTagsEqualTo():
tag_list = record.get(filter.field_name)
if not tag_list:
return False
if not isinstance(tag_list, list):
tag_list = [tag_list]
return filter.value in tag_list
case _:
return True

def _get_record_from_result(self, result: Any) -> Any:
    """Return the result as-is: in-memory search results are already record dicts."""
    return result

def _get_score_from_result(self, result: Any) -> float | None:
    """Extract the score that the search attached under IN_MEMORY_SCORE_KEY, if any."""
    return result.get(IN_MEMORY_SCORE_KEY)
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
from pydantic import Field

from semantic_kernel.data.filter_clauses.filter_clause_base import FilterClauseBase
from semantic_kernel.kernel_pydantic import KernelBaseModel
from semantic_kernel.utils.experimental_decorator import experimental_class


@experimental_class
class AnyTagsEqualTo(FilterClauseBase, KernelBaseModel):
"""A filter clause for a any tags equals comparison."""
class AnyTagsEqualTo(FilterClauseBase):
"""A filter clause for a any tags equals comparison.
filter_clause_type: str = Field("any_tags_equal_to", init=False) # type: ignore
Args:
field_name: The name of the field containing the list of tags.
value: The value to compare against the list of tags.
"""

field_name: str
value: str
filter_clause_type: str = Field("any_tags_equal_to", init=False) # type: ignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
from pydantic import Field

from semantic_kernel.data.filter_clauses.filter_clause_base import FilterClauseBase
from semantic_kernel.kernel_pydantic import KernelBaseModel
from semantic_kernel.utils.experimental_decorator import experimental_class


@experimental_class
class EqualTo(FilterClauseBase, KernelBaseModel):
"""A filter clause for an equals comparison."""
class EqualTo(FilterClauseBase):
"""A filter clause for an equals comparison.
filter_clause_type: str = Field("equal_to", init=False) # type: ignore
Args:
field_name: The name of the field to compare.
value: The value to compare against the field.
"""

field_name: str
value: str
filter_clause_type: str = Field("equal_to", init=False) # type: ignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@


from abc import ABC
from typing import Any

from pydantic import Field

Expand All @@ -14,3 +15,5 @@ class FilterClauseBase(ABC, KernelBaseModel):
"""A base for all filter clauses."""

filter_clause_type: str = Field("FilterClauseBase", init=False) # type: ignore
field_name: str
value: Any
Empty file.
11 changes: 0 additions & 11 deletions python/semantic_kernel/search/const.py

This file was deleted.

Loading

0 comments on commit 1db12c4

Please sign in to comment.