Skip to content

Commit

Permalink
Version bump and client fixes (#2017)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dev-Khant authored Nov 7, 2024
1 parent 549e5e3 commit 3731965
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 12 deletions.
25 changes: 14 additions & 11 deletions mem0/client/main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging
import os
import warnings
from functools import wraps
from typing import Any, Dict, List, Optional, Union
import warnings

import httpx

Expand Down Expand Up @@ -122,7 +122,7 @@ def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str,
Raises:
APIError: If the API request fails.
"""
kwargs.update({"org_name": self.organization, "project_name": self.project})
kwargs = self._prepare_params(kwargs)
payload = self._prepare_payload(messages, kwargs)
response = self.client.post("/v1/memories/", json=payload)
response.raise_for_status()
Expand Down Expand Up @@ -163,7 +163,6 @@ def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
Raises:
APIError: If the API request fails.
"""
kwargs.update({"org_name": self.organization, "project_name": self.project})
params = self._prepare_params(kwargs)
if version == "v1":
response = self.client.get(f"/{version}/memories/", params=params)
Expand Down Expand Up @@ -195,8 +194,8 @@ def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, An
APIError: If the API request fails.
"""
payload = {"query": query}
kwargs.update({"org_name": self.organization, "project_name": self.project})
payload.update({k: v for k, v in kwargs.items() if v is not None})
params = self._prepare_params(kwargs)
payload.update(params)
response = self.client.post(f"/{version}/memories/search/", json=payload)
response.raise_for_status()
if "metadata" in kwargs:
Expand Down Expand Up @@ -250,7 +249,6 @@ def delete_all(self, **kwargs) -> Dict[str, str]:
Raises:
APIError: If the API request fails.
"""
kwargs.update({"org_name": self.organization, "project_name": self.project})
params = self._prepare_params(kwargs)
response = self.client.delete("/v1/memories/", params=params)
response.raise_for_status()
Expand Down Expand Up @@ -278,7 +276,7 @@ def history(self, memory_id: str) -> List[Dict[str, Any]]:
@api_error_handler
def users(self) -> Dict[str, Any]:
"""Get all users, agents, and sessions for which memories exist."""
params = {"org_name": self.organization, "project_name": self.project}
params = self._prepare_params()
response = self.client.get("/v1/entities/", params=params)
response.raise_for_status()
capture_client_event("client.users", self)
Expand All @@ -287,7 +285,7 @@ def users(self) -> Dict[str, Any]:
@api_error_handler
def delete_users(self) -> Dict[str, str]:
"""Delete all users, agents, or sessions."""
params = {"org_name": self.organization, "project_name": self.project}
params = self._prepare_params()
entities = self.users()
for entity in entities["results"]:
response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
Expand Down Expand Up @@ -344,7 +342,7 @@ def _prepare_payload(
payload.update({k: v for k, v in kwargs.items() if v is not None})
return payload

def _prepare_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Prepare query parameters for API requests.
Args:
Expand All @@ -356,6 +354,10 @@ def _prepare_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
Raises:
ValueError: If both org_id/project_id and org_name/project_name are provided.
"""

if kwargs is None:
kwargs = {}

has_new = bool(self.org_id or self.project_id)
has_old = bool(self.organization or self.project)

Expand Down Expand Up @@ -414,6 +416,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):

@api_error_handler
async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]:
kwargs = self.sync_client._prepare_params(kwargs)
payload = self.sync_client._prepare_payload(messages, kwargs)
response = await self.async_client.post("/v1/memories/", json=payload)
response.raise_for_status()
Expand Down Expand Up @@ -488,15 +491,15 @@ async def history(self, memory_id: str) -> List[Dict[str, Any]]:

@api_error_handler
async def users(self) -> Dict[str, Any]:
params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project}
params = self.sync_client._prepare_params()
response = await self.async_client.get("/v1/entities/", params=params)
response.raise_for_status()
capture_client_event("async_client.users", self.sync_client)
return response.json()

@api_error_handler
async def delete_users(self) -> Dict[str, str]:
params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project}
params = self.sync_client._prepare_params()
entities = await self.users()
for entity in entities["results"]:
response = await self.async_client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mem0ai"
version = "0.1.28"
version = "0.1.29"
description = "Long-term memory for AI Agents"
authors = ["Mem0 <[email protected]>"]
exclude = [
Expand Down

0 comments on commit 3731965

Please sign in to comment.