Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store user_id to vectordb #1988

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 36 additions & 26 deletions mem0/client/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@

import httpx

from mem0.memory.setup import get_user_id, setup_config
from mem0.memory.telemetry import capture_client_event
from mem0.memory.setup import get_user_id
from mem0.memory.telemetry import AnonymousTelemetry, capture_client_event

logger = logging.getLogger(__name__)

# Setup user config
setup_config()
telemetry = AnonymousTelemetry(vector_store_provider="qdrant")


class APIError(Exception):
Expand Down Expand Up @@ -83,7 +82,7 @@ def __init__(
timeout=60,
)
self._validate_api_key()
capture_client_event("client.init", self)
capture_client_event(telemetry, "client.init", self)

def _validate_api_key(self):
"""Validate the API key by making a test request."""
Expand Down Expand Up @@ -113,7 +112,7 @@ def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str,
response.raise_for_status()
if "metadata" in kwargs:
del kwargs["metadata"]
capture_client_event("client.add", self, {"keys": list(kwargs.keys())})
capture_client_event(telemetry, "client.add", self, {"keys": list(kwargs.keys())})
return response.json()

@api_error_handler
Expand All @@ -131,7 +130,7 @@ def get(self, memory_id: str) -> Dict[str, Any]:
"""
response = self.client.get(f"/v1/memories/{memory_id}/")
response.raise_for_status()
capture_client_event("client.get", self, {"memory_id": memory_id})
capture_client_event(telemetry, "client.get", self, {"memory_id": memory_id})
return response.json()

@api_error_handler
Expand All @@ -158,6 +157,7 @@ def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
if "metadata" in kwargs:
del kwargs["metadata"]
capture_client_event(
telemetry,
"client.get_all",
self,
{"api_version": version, "keys": list(kwargs.keys())},
Expand Down Expand Up @@ -186,7 +186,9 @@ def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, An
response.raise_for_status()
if "metadata" in kwargs:
del kwargs["metadata"]
capture_client_event("client.search", self, {"api_version": version, "keys": list(kwargs.keys())})
capture_client_event(
telemetry, "client.search", self, {"api_version": version, "keys": list(kwargs.keys())}
)
return response.json()

@api_error_handler
Expand All @@ -199,7 +201,7 @@ def update(self, memory_id: str, data: str) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The response from the server.
"""
capture_client_event("client.update", self, {"memory_id": memory_id})
capture_client_event(telemetry, "client.update", self, {"memory_id": memory_id})
response = self.client.put(f"/v1/memories/{memory_id}/", json={"text": data})
response.raise_for_status()
return response.json()
Expand All @@ -219,7 +221,7 @@ def delete(self, memory_id: str) -> Dict[str, Any]:
"""
response = self.client.delete(f"/v1/memories/{memory_id}/")
response.raise_for_status()
capture_client_event("client.delete", self, {"memory_id": memory_id})
capture_client_event(telemetry, "client.delete", self, {"memory_id": memory_id})
return response.json()

@api_error_handler
Expand All @@ -239,7 +241,7 @@ def delete_all(self, **kwargs) -> Dict[str, str]:
params = self._prepare_params(kwargs)
response = self.client.delete("/v1/memories/", params=params)
response.raise_for_status()
capture_client_event("client.delete_all", self, {"keys": list(kwargs.keys())})
capture_client_event(telemetry, "client.delete_all", self, {"keys": list(kwargs.keys())})
return response.json()

@api_error_handler
Expand All @@ -257,7 +259,7 @@ def history(self, memory_id: str) -> List[Dict[str, Any]]:
"""
response = self.client.get(f"/v1/memories/{memory_id}/history/")
response.raise_for_status()
capture_client_event("client.history", self, {"memory_id": memory_id})
capture_client_event(telemetry, "client.history", self, {"memory_id": memory_id})
return response.json()

@api_error_handler
Expand All @@ -266,7 +268,7 @@ def users(self) -> Dict[str, Any]:
params = {"org_name": self.organization, "project_name": self.project}
response = self.client.get("/v1/entities/", params=params)
response.raise_for_status()
capture_client_event("client.users", self)
capture_client_event(telemetry, "client.users", self)
return response.json()

@api_error_handler
Expand All @@ -278,7 +280,7 @@ def delete_users(self) -> Dict[str, str]:
response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
response.raise_for_status()

capture_client_event("client.delete_users", self)
capture_client_event(telemetry, "client.delete_users", self)
return {"message": "All users, agents, and sessions deleted."}

@api_error_handler
Expand All @@ -297,7 +299,7 @@ def reset(self) -> Dict[str, str]:
# This will also delete the memories
self.delete_users()

capture_client_event("client.reset", self)
capture_client_event(telemetry, "client.reset", self)
return {"message": "Client reset successful. All users and memories deleted."}

def chat(self):
Expand Down Expand Up @@ -372,14 +374,14 @@ async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dic
response.raise_for_status()
if "metadata" in kwargs:
del kwargs["metadata"]
capture_client_event("async_client.add", self.sync_client, {"keys": list(kwargs.keys())})
capture_client_event(telemetry, "async_client.add", self.sync_client, {"keys": list(kwargs.keys())})
return response.json()

@api_error_handler
async def get(self, memory_id: str) -> Dict[str, Any]:
response = await self.async_client.get(f"/v1/memories/{memory_id}/")
response.raise_for_status()
capture_client_event("async_client.get", self.sync_client, {"memory_id": memory_id})
capture_client_event(telemetry, "async_client.get", self.sync_client, {"memory_id": memory_id})
return response.json()

@api_error_handler
Expand All @@ -393,7 +395,10 @@ async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
if "metadata" in kwargs:
del kwargs["metadata"]
capture_client_event(
"async_client.get_all", self.sync_client, {"api_version": version, "keys": list(kwargs.keys())}
telemetry,
"async_client.get_all",
self.sync_client,
{"api_version": version, "keys": list(kwargs.keys())},
)
return response.json()

Expand All @@ -406,45 +411,50 @@ async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[s
if "metadata" in kwargs:
del kwargs["metadata"]
capture_client_event(
"async_client.search", self.sync_client, {"api_version": version, "keys": list(kwargs.keys())}
telemetry,
"async_client.search",
self.sync_client,
{"api_version": version, "keys": list(kwargs.keys())},
)
return response.json()

@api_error_handler
async def update(self, memory_id: str, data: str) -> Dict[str, Any]:
response = await self.async_client.put(f"/v1/memories/{memory_id}/", json={"text": data})
response.raise_for_status()
capture_client_event("async_client.update", self.sync_client, {"memory_id": memory_id})
capture_client_event(telemetry, "async_client.update", self.sync_client, {"memory_id": memory_id})
return response.json()

@api_error_handler
async def delete(self, memory_id: str) -> Dict[str, Any]:
response = await self.async_client.delete(f"/v1/memories/{memory_id}/")
response.raise_for_status()
capture_client_event("async_client.delete", self.sync_client, {"memory_id": memory_id})
capture_client_event(telemetry, "async_client.delete", self.sync_client, {"memory_id": memory_id})
return response.json()

@api_error_handler
async def delete_all(self, **kwargs) -> Dict[str, str]:
params = self.sync_client._prepare_params(kwargs)
response = await self.async_client.delete("/v1/memories/", params=params)
response.raise_for_status()
capture_client_event("async_client.delete_all", self.sync_client, {"keys": list(kwargs.keys())})
capture_client_event(
telemetry, "async_client.delete_all", self.sync_client, {"keys": list(kwargs.keys())}
)
return response.json()

@api_error_handler
async def history(self, memory_id: str) -> List[Dict[str, Any]]:
response = await self.async_client.get(f"/v1/memories/{memory_id}/history/")
response.raise_for_status()
capture_client_event("async_client.history", self.sync_client, {"memory_id": memory_id})
capture_client_event(telemetry, "async_client.history", self.sync_client, {"memory_id": memory_id})
return response.json()

@api_error_handler
async def users(self) -> Dict[str, Any]:
params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project}
response = await self.async_client.get("/v1/entities/", params=params)
response.raise_for_status()
capture_client_event("async_client.users", self.sync_client)
capture_client_event(telemetry, "async_client.users", self.sync_client)
return response.json()

@api_error_handler
Expand All @@ -454,13 +464,13 @@ async def delete_users(self) -> Dict[str, str]:
for entity in entities["results"]:
response = await self.async_client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
response.raise_for_status()
capture_client_event("async_client.delete_users", self.sync_client)
capture_client_event(telemetry, "async_client.delete_users", self.sync_client)
return {"message": "All users, agents, and sessions deleted."}

@api_error_handler
async def reset(self) -> Dict[str, str]:
await self.delete_users()
capture_client_event("async_client.reset", self.sync_client)
capture_client_event(telemetry, "async_client.reset", self.sync_client)
return {"message": "Client reset successful. All users and memories deleted."}

async def chat(self):
Expand Down
32 changes: 15 additions & 17 deletions mem0/memory/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,11 @@
from mem0.configs.base import MemoryConfig, MemoryItem
from mem0.configs.prompts import get_update_memory_messages
from mem0.memory.base import MemoryBase
from mem0.memory.setup import setup_config
from mem0.memory.storage import SQLiteManager
from mem0.memory.telemetry import capture_event
from mem0.memory.telemetry import AnonymousTelemetry, capture_event
from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory

# Setup user config
setup_config()

logger = logging.getLogger(__name__)


Expand All @@ -47,7 +43,8 @@ def __init__(self, config: MemoryConfig = MemoryConfig()):
self.graph = MemoryGraph(self.config)
self.enable_graph = True

capture_event("mem0.init", self)
self.telemetry = AnonymousTelemetry(self.config.vector_store.provider, self.config.vector_store)
capture_event(self.telemetry, "mem0.init", self)

@classmethod
def from_config(cls, config_dict: Dict[str, Any]):
Expand Down Expand Up @@ -233,7 +230,7 @@ def _add_to_vector_store(self, messages, metadata, filters):
except Exception as e:
logging.error(f"Error in new_memories_with_actions: {e}")

capture_event("mem0.add", self, {"version": self.api_version, "keys": list(filters.keys())})
capture_event(self.telemetry, "mem0.add", self, {"version": self.api_version, "keys": list(filters.keys())})

return returned_memories

Expand Down Expand Up @@ -263,7 +260,7 @@ def get(self, memory_id):
Returns:
dict: Retrieved memory.
"""
capture_event("mem0.get", self, {"memory_id": memory_id})
capture_event(self.telemetry, "mem0.get", self, {"memory_id": memory_id})
memory = self.vector_store.get(vector_id=memory_id)
if not memory:
return None
Expand Down Expand Up @@ -312,7 +309,7 @@ def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
if run_id:
filters["run_id"] = run_id

capture_event("mem0.get_all", self, {"limit": limit, "keys": list(filters.keys())})
capture_event(self.telemetry, "mem0.get_all", self, {"limit": limit, "keys": list(filters.keys())})

with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
Expand Down Expand Up @@ -403,6 +400,7 @@ def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, fil
raise ValueError("One of the filters: user_id, agent_id or run_id is required!")

capture_event(
self.telemetry,
"mem0.search",
self,
{"limit": limit, "version": self.api_version, "keys": list(filters.keys())},
Expand Down Expand Up @@ -485,7 +483,7 @@ def update(self, memory_id, data):
Returns:
dict: Updated memory.
"""
capture_event("mem0.update", self, {"memory_id": memory_id})
capture_event(self.telemetry, "mem0.update", self, {"memory_id": memory_id})

existing_embeddings = {data: self.embedding_model.embed(data)}

Expand All @@ -499,7 +497,7 @@ def delete(self, memory_id):
Args:
memory_id (str): ID of the memory to delete.
"""
capture_event("mem0.delete", self, {"memory_id": memory_id})
capture_event(self.telemetry, "mem0.delete", self, {"memory_id": memory_id})
self._delete_memory(memory_id)
return {"message": "Memory deleted successfully!"}

Expand All @@ -525,7 +523,7 @@ def delete_all(self, user_id=None, agent_id=None, run_id=None):
"At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
)

capture_event("mem0.delete_all", self, {"keys": list(filters.keys())})
capture_event(self.telemetry, "mem0.delete_all", self, {"keys": list(filters.keys())})
memories = self.vector_store.list(filters=filters)[0]
for memory in memories:
self._delete_memory(memory.id)
Expand All @@ -547,7 +545,7 @@ def history(self, memory_id):
Returns:
list: List of changes for the memory.
"""
capture_event("mem0.history", self, {"memory_id": memory_id})
capture_event(self.telemetry, "mem0.history", self, {"memory_id": memory_id})
return self.db.get_history(memory_id)

def _create_memory(self, data, existing_embeddings, metadata=None):
Expand All @@ -568,7 +566,7 @@ def _create_memory(self, data, existing_embeddings, metadata=None):
payloads=[metadata],
)
self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
capture_event("mem0._create_memory", self, {"memory_id": memory_id})
capture_event(self.telemetry, "mem0._create_memory", self, {"memory_id": memory_id})
return memory_id

def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
Expand Down Expand Up @@ -611,7 +609,7 @@ def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
created_at=new_metadata["created_at"],
updated_at=new_metadata["updated_at"],
)
capture_event("mem0._update_memory", self, {"memory_id": memory_id})
capture_event(self.telemetry, "mem0._update_memory", self, {"memory_id": memory_id})
return memory_id

def _delete_memory(self, memory_id):
Expand All @@ -620,7 +618,7 @@ def _delete_memory(self, memory_id):
prev_value = existing_memory.payload["data"]
self.vector_store.delete(vector_id=memory_id)
self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
capture_event("mem0._delete_memory", self, {"memory_id": memory_id})
capture_event(self.telemetry, "mem0._delete_memory", self, {"memory_id": memory_id})
return memory_id

def reset(self):
Expand All @@ -633,7 +631,7 @@ def reset(self):
self.config.vector_store.provider, self.config.vector_store.config
)
self.db.reset()
capture_event("mem0.reset", self)
capture_event(self.telemetry, "mem0.reset", self)

def chat(self, query):
raise NotImplementedError("Chat function not implemented yet.")
Loading
Loading