From 41d5680f1abac7f625f1bbcba34eb7ca1a47bff6 Mon Sep 17 00:00:00 2001 From: Eric Gustin Date: Tue, 19 Nov 2024 15:04:47 -0800 Subject: [PATCH 1/3] Add some thread tools to Gmail --- toolkits/google/arcade_google/tools/gmail.py | 155 +++++++++--- toolkits/google/arcade_google/tools/models.py | 4 +- toolkits/google/arcade_google/tools/utils.py | 18 +- toolkits/google/evals/eval_google_gmail.py | 57 +++++ toolkits/google/tests/test_gmail.py | 230 +++++++++++++++--- 5 files changed, 392 insertions(+), 72 deletions(-) diff --git a/toolkits/google/arcade_google/tools/gmail.py b/toolkits/google/arcade_google/tools/gmail.py index 3dbe5fff..909f05fb 100644 --- a/toolkits/google/arcade_google/tools/gmail.py +++ b/toolkits/google/arcade_google/tools/gmail.py @@ -1,5 +1,4 @@ import base64 -import json from email.message import EmailMessage from email.mime.text import MIMEText from typing import Annotated, Optional @@ -15,11 +14,9 @@ DateRange, build_query_string, fetch_messages, - get_draft_url, - get_email_in_trash_url, - get_sent_email_url, parse_draft_email, parse_email, + remove_none_values, ) @@ -36,7 +33,7 @@ async def send_email( recipient: Annotated[str, "The recipient of the email"], cc: Annotated[Optional[list[str]], "CC recipients of the email"] = None, bcc: Annotated[Optional[list[str]], "BCC recipients of the email"] = None, -) -> Annotated[str, "A confirmation message with the sent email ID and URL"]: +) -> Annotated[dict, "A dictionary containing the sent email details"]: """ Send an email using the Gmail API. """ @@ -61,7 +58,8 @@ async def send_email( # Send the email sent_message = service.users().messages().send(userId="me", body=email).execute() - return f"Email with ID {sent_message['id']} sent: {get_sent_email_url(sent_message['id'])}" + + return parse_email(sent_message) @tool( @@ -71,7 +69,7 @@ async def send_email( ) async def send_draft_email( context: ToolContext, email_id: Annotated[str, "The ID of the draft to send"] -) -> Annotated[str, "A confirmation message with the sent email ID and URL"]: +) -> Annotated[dict, "A dictionary containing the sent email details"]: """ Send a draft email using the Gmail API. """ @@ -82,10 +80,7 @@ async def send_draft_email( # Send the draft email sent_message = service.users().drafts().send(userId="me", body={"id": email_id}).execute() - # Construct the URL to the sent email - return ( - f"Draft email with ID {sent_message['id']} sent: {get_sent_email_url(sent_message['id'])}" - ) + return parse_email(sent_message) # Draft Management Tools @@ -101,7 +96,7 @@ async def write_draft_email( recipient: Annotated[str, "The recipient of the draft email"], cc: Annotated[Optional[list[str]], "CC recipients of the draft email"] = None, bcc: Annotated[Optional[list[str]], "BCC recipients of the draft email"] = None, -) -> Annotated[str, "A confirmation message with the draft email ID and URL"]: +) -> Annotated[dict, "A dictionary containing the created draft email details"]: """ Compose a new email draft using the Gmail API. """ @@ -123,9 +118,7 @@ async def write_draft_email( draft = {"message": {"raw": raw_message}} draft_message = service.users().drafts().create(userId="me", body=draft).execute() - return ( - f"Draft email with ID {draft_message['id']} created: {get_draft_url(draft_message['id'])}" - ) + return parse_draft_email(draft_message) @tool( @@ -141,7 +134,7 @@ async def update_draft_email( recipient: Annotated[str, "The recipient of the draft email"], cc: Annotated[Optional[list[str]], "CC recipients of the draft email"] = None, bcc: Annotated[Optional[list[str]], "BCC recipients of the draft email"] = None, -) -> Annotated[str, "A confirmation message with the updated draft email ID and URL"]: +) -> Annotated[dict, "A dictionary containing the updated draft email details"]: """ Update an existing email draft using the Gmail API. """ @@ -166,7 +159,8 @@ async def update_draft_email( updated_draft_message = ( service.users().drafts().update(userId="me", id=draft_email_id, body=draft).execute() ) - return f"Draft email with ID {updated_draft_message['id']} updated: {get_draft_url(updated_draft_message['id'])}" + + return parse_draft_email(updated_draft_message) @tool( @@ -198,7 +192,7 @@ async def delete_draft_email( ) async def trash_email( context: ToolContext, email_id: Annotated[str, "The ID of the email to trash"] -) -> Annotated[str, "A confirmation message with the trashed email ID and URL"]: +) -> Annotated[dict, "A dictionary containing the trashed email details"]: """ Move an email to the trash folder using the Gmail API. """ @@ -207,9 +201,9 @@ async def trash_email( service = build("gmail", "v1", credentials=Credentials(context.authorization.token)) # Trash the email - service.users().messages().trash(userId="me", id=email_id).execute() + trashed_email = service.users().messages().trash(userId="me", id=email_id).execute() - return f"Email with ID {email_id} trashed successfully: {get_email_in_trash_url(email_id)}" + return parse_email(trashed_email) # Draft Search Tools @@ -221,7 +215,7 @@ async def trash_email( async def list_draft_emails( context: ToolContext, n_drafts: Annotated[int, "Number of draft emails to read"] = 5, -) -> Annotated[str, "A JSON string containing a list of draft email details and their IDs"]: +) -> Annotated[dict, "A dictionary containing a list of draft email details"]: """ Lists draft emails in the user's draft mailbox using the Gmail API. """ @@ -245,7 +239,7 @@ async def list_draft_emails( except Exception as e: print(f"Error reading draft email {draft_id}: {e}") - return json.dumps({"emails": emails}) + return {"emails": emails} # Email Search Tools @@ -263,11 +257,11 @@ async def list_emails_by_header( date_range: Annotated[Optional[DateRange], "The date range of the email"] = None, limit: Annotated[Optional[int], "The maximum number of emails to return"] = 25, ) -> Annotated[ - str, "A JSON string containing a list of email details matching the search criteria" + dict, "A dictionary containing a list of email details matching the search criteria" ]: """ Search for emails by header using the Gmail API. - At least one of the following parametersMUST be provided: sender, recipient, subject, body. + At least one of the following parameters MUST be provided: sender, recipient, subject, body. """ if not any([sender, recipient, subject, body]): raise RetryableToolError( @@ -281,10 +275,10 @@ async def list_emails_by_header( messages = fetch_messages(service, query, limit) if not messages: - return json.dumps({"emails": []}) + return {"emails": []} emails = process_messages(service, messages) - return json.dumps({"emails": emails}) + return {"emails": emails} def process_messages(service, messages): @@ -307,7 +301,7 @@ def process_messages(service, messages): async def list_emails( context: ToolContext, n_emails: Annotated[int, "Number of emails to read"] = 5, -) -> Annotated[str, "A JSON string containing a list of email details"]: +) -> Annotated[dict, "A dictionary containing a list of email details"]: """ Read emails from a Gmail account and extract plain text content. """ @@ -329,4 +323,109 @@ async def list_emails( except Exception as e: print(f"Error reading email {msg['id']}: {e}") - return json.dumps({"emails": emails}) + return {"emails": emails} + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + ) +) +async def search_threads( + context: ToolContext, + page_token: Annotated[ + Optional[str], "Page token to retrieve a specific page of results in the list" + ] = None, + max_results: Annotated[int, "The maximum number of threads to return"] = 10, + include_spam_trash: Annotated[bool, "Whether to include spam and trash in the results"] = False, + label_ids: Annotated[Optional[list[str]], "The IDs of labels to filter by"] = None, + sender: Annotated[Optional[str], "The name or email address of the sender of the email"] = None, + recipient: Annotated[Optional[str], "The name or email address of the recipient"] = None, + subject: Annotated[Optional[str], "Words to find in the subject of the email"] = None, + body: Annotated[Optional[str], "Words to find in the body of the email"] = None, + date_range: Annotated[Optional[DateRange], "The date range of the email"] = None, +) -> Annotated[dict, "A dictionary containing a list of thread details"]: + """Search for threads in the user's mailbox""" + service = build("gmail", "v1", credentials=Credentials(context.authorization.token)) + + query = ( + build_query_string(sender, recipient, subject, body, date_range) + if any([sender, recipient, subject, body, date_range]) + else None + ) + + params = { + "userId": "me", + "maxResults": min(max_results, 500), + "pageToken": page_token, + "includeSpamTrash": include_spam_trash, + "labelIds": label_ids, + "q": query, + } + params = remove_none_values(params) + + threads = [] + next_page_token = None + # Paginate through thread pages until we have the desired number of threads + while len(threads) < max_results: + response = service.users().threads().list(**params).execute() + + threads.extend(response.get("threads", [])) + next_page_token = response.get("nextPageToken") + + if not next_page_token: + break + + params["pageToken"] = next_page_token + params["maxResults"] = min(max_results - len(threads), 500) + + return { + "threads": threads, + "num_threads": len(threads), + "next_page_token": next_page_token, + } + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + ) +) +async def list_threads( + context: ToolContext, + page_token: Annotated[ + Optional[str], "Page token to retrieve a specific page of results in the list" + ] = None, + max_results: Annotated[int, "The maximum number of threads to return"] = 10, + include_spam_trash: Annotated[bool, "Whether to include spam and trash in the results"] = False, +) -> Annotated[dict, "A dictionary containing a list of thread details"]: + """List threads in the user's mailbox.""" + return await search_threads(context, page_token, max_results, include_spam_trash) + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + ) +) +async def get_thread( + context: ToolContext, + thread_id: Annotated[str, "The ID of the thread to retrieve"], + metadata_headers: Annotated[ + Optional[list[str]], "When given and format is METADATA, only include headers specified." + ] = None, +) -> Annotated[dict, "A dictionary containing the thread details"]: + """Get the specified thread by ID.""" + params = { + "userId": "me", + "id": thread_id, + "format": "full", + "metadataHeaders": metadata_headers, + } + params = remove_none_values(params) + + service = build("gmail", "v1", credentials=Credentials(context.authorization.token)) + thread = service.users().threads().get(**params).execute() + thread["messages"] = [parse_email(message) for message in thread.get("messages", [])] + + return thread diff --git a/toolkits/google/arcade_google/tools/models.py b/toolkits/google/arcade_google/tools/models.py index 2130e5d4..feb6101c 100644 --- a/toolkits/google/arcade_google/tools/models.py +++ b/toolkits/google/arcade_google/tools/models.py @@ -232,7 +232,9 @@ class SendUpdatesOptions(Enum): EXTERNAL_ONLY = "externalOnly" # Notifications are sent to non-Google Calendar guests only. -# Utils for Google Drive tools +# ---------------------------------------------------------------------------- # +# Google Drive Models and Enums +# ---------------------------------------------------------------------------- # class Corpora(str, Enum): """ Bodies of items (files/documents) to which the query applies. diff --git a/toolkits/google/arcade_google/tools/utils.py b/toolkits/google/arcade_google/tools/utils.py index cf7d426b..1bc52b44 100644 --- a/toolkits/google/arcade_google/tools/utils.py +++ b/toolkits/google/arcade_google/tools/utils.py @@ -82,16 +82,17 @@ def parse_email(email_data: dict[str, Any]) -> Optional[dict[str, str]]: Optional[Dict[str, str]]: Parsed email details or None if parsing fails. """ try: - payload = email_data["payload"] - headers = {d["name"].lower(): d["value"] for d in payload["headers"]} + payload = email_data.get("payload", {}) + headers = {d["name"].lower(): d["value"] for d in payload.get("headers", [])} body_data = _get_email_body(payload) return { "id": email_data.get("id", ""), + "thread_id": email_data.get("threadId", ""), "from": headers.get("from", ""), "date": headers.get("date", ""), - "subject": headers.get("subject", "No subject"), + "subject": headers.get("subject", ""), "body": _clean_email_body(body_data) if body_data else "", } except Exception as e: @@ -110,17 +111,18 @@ def parse_draft_email(draft_email_data: dict[str, Any]) -> Optional[dict[str, st Optional[Dict[str, str]]: Parsed draft email details or None if parsing fails. """ try: - message = draft_email_data["message"] - payload = message["payload"] - headers = {d["name"].lower(): d["value"] for d in payload["headers"]} + message = draft_email_data.get("message", {}) + payload = message.get("payload", {}) + headers = {d["name"].lower(): d["value"] for d in payload.get("headers", [])} body_data = _get_email_body(payload) return { "id": draft_email_data.get("id", ""), + "thread_id": draft_email_data.get("threadId", ""), "from": headers.get("from", ""), "date": headers.get("internaldate", ""), - "subject": headers.get("subject", "No subject"), + "subject": headers.get("subject", ""), "body": _clean_email_body(body_data) if body_data else "", } except Exception as e: @@ -226,7 +228,7 @@ def _update_datetime(day: Day | None, time: TimeSlot | None, time_zone: str) -> def build_query_string(sender, recipient, subject, body, date_range): """ - Helper function to build a query string for Gmail list_emails_by_header tool. + Helper function to build a query string for Gmail list_emails_by_header and search_threads tools. """ query = [] if sender: diff --git a/toolkits/google/evals/eval_google_gmail.py b/toolkits/google/evals/eval_google_gmail.py index 420701f2..d3b9f4ac 100644 --- a/toolkits/google/evals/eval_google_gmail.py +++ b/toolkits/google/evals/eval_google_gmail.py @@ -1,7 +1,11 @@ import arcade_google from arcade_google.tools.gmail import ( + get_thread, + list_threads, + search_threads, send_email, ) +from arcade_google.tools.utils import DateRange from arcade.sdk import ToolCatalog from arcade.sdk.eval import ( @@ -57,4 +61,57 @@ def gmail_eval_suite() -> EvalSuite: ], ) + suite.add_case( + name="List threads", + user_message="Get 42 threads like right now i even wanna see the ones in my trash", + expected_tool_calls=[ + ( + list_threads, + {"max_results": 42, "include_spam_trash": True}, + ) + ], + critics=[ + BinaryCritic(critic_field="max_results", weight=0.5), + BinaryCritic(critic_field="include_spam_trash", weight=0.5), + ], + ) + + suite.add_case( + name="Search threads", + user_message="Search for threads from johndoe@example.com to janedoe@example.com about that talk about 'Arcade AI' from yesterday", + expected_tool_calls=[ + ( + search_threads, + { + "sender": "johndoe@example.com", + "recipient": "janedoe@example.com", + "body": "Arcade AI", + "date_range": DateRange.YESTERDAY, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="sender", weight=0.25), + BinaryCritic(critic_field="recipient", weight=0.25), + SimilarityCritic(critic_field="body", weight=0.25), + BinaryCritic(critic_field="date_range", weight=0.25), + ], + ) + + suite.add_case( + name="Get a thread by ID", + user_message="Get the thread r-124325435467568867667878874565464564563523424323524235242412", + expected_tool_calls=[ + ( + get_thread, + { + "thread_id": "r-124325435467568867667878874565464564563523424323524235242412", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="thread_id", weight=1.0), + ], + ) + return suite diff --git a/toolkits/google/tests/test_gmail.py b/toolkits/google/tests/test_gmail.py index 2835b999..df6e68e6 100644 --- a/toolkits/google/tests/test_gmail.py +++ b/toolkits/google/tests/test_gmail.py @@ -1,12 +1,14 @@ -import json from unittest.mock import MagicMock, patch import pytest from arcade_google.tools.gmail import ( delete_draft_email, + get_thread, list_draft_emails, list_emails, list_emails_by_header, + list_threads, + search_threads, send_draft_email, send_email, trash_email, @@ -40,8 +42,11 @@ async def test_send_email(mock_build, mock_context): recipient="test@example.com", ) - assert "Email with ID" in result - assert "sent" in result + assert isinstance(result, dict) + assert "id" in result + assert "thread_id" in result + assert "subject" in result + assert "body" in result # Test http error mock_service.users().messages().send().execute.side_effect = HttpError( @@ -72,8 +77,11 @@ async def test_write_draft_email(mock_build, mock_context): recipient="draft@example.com", ) - assert "Draft email with ID" in result - assert "created" in result + assert isinstance(result, dict) + assert "id" in result + assert "thread_id" in result + assert "subject" in result + assert "body" in result # Test http error mock_service.users().drafts().create().execute.side_effect = HttpError( @@ -105,8 +113,11 @@ async def test_update_draft_email(mock_build, mock_context): recipient="updated@example.com", ) - assert "Draft email with ID" in result - assert "updated" in result + assert isinstance(result, dict) + assert "id" in result + assert "thread_id" in result + assert "subject" in result + assert "body" in result # Test http error mock_service.users().drafts().update().execute.side_effect = HttpError( @@ -133,8 +144,11 @@ async def test_send_draft_email(mock_build, mock_context): # Test happy path result = await send_draft_email(context=mock_context, email_id="draft456") - assert "Draft email with ID" in result - assert "sent" in result + assert isinstance(result, dict) + assert "id" in result + assert "thread_id" in result + assert "subject" in result + assert "body" in result # Test http error mock_service.users().drafts().send().execute.side_effect = HttpError( @@ -226,12 +240,10 @@ async def test_get_draft_emails(mock_parse_draft_email, mock_build, mock_context # Test happy path result = await list_draft_emails(context=mock_context, n_drafts=2) - assert isinstance(result, str) - result_json = json.loads(result) - assert isinstance(result_json, dict) - assert "emails" in result_json - assert len(result_json["emails"]) == 1 - assert all("id" in draft and "subject" in draft for draft in result_json["emails"]) + assert isinstance(result, dict) + assert "emails" in result + assert len(result["emails"]) == 1 + assert all("id" in draft and "subject" in draft for draft in result["emails"]) # Test http error mock_service.users().drafts().list().execute.side_effect = HttpError( @@ -301,12 +313,10 @@ async def test_search_emails_by_header(mock_parse_email, mock_build, mock_contex # Test happy path result = await list_emails_by_header(context=mock_context, sender="noreply@github.com", limit=2) - assert isinstance(result, str) - result_json = json.loads(result) - assert isinstance(result_json, dict) - assert "emails" in result_json - assert len(result_json["emails"]) == 2 - assert all("id" in email and "subject" in email for email in result_json["emails"]) + assert isinstance(result, dict) + assert "emails" in result + assert len(result["emails"]) == 2 + assert all("id" in email and "subject" in email for email in result["emails"]) # Test http error mock_service.users().messages().list().execute.side_effect = HttpError( @@ -375,16 +385,13 @@ async def test_get_emails(mock_parse_email, mock_build, mock_context): # Test happy path result = await list_emails(context=mock_context, n_emails=1) - # Assert the result - assert isinstance(result, str) - result_json = json.loads(result) - assert isinstance(result_json, dict) - assert "emails" in result_json - assert len(result_json["emails"]) == 1 - assert "id" in result_json["emails"][0] - assert "subject" in result_json["emails"][0] - assert "date" in result_json["emails"][0] - assert "body" in result_json["emails"][0] + assert isinstance(result, dict) + assert "emails" in result + assert len(result["emails"]) == 1 + assert "id" in result["emails"][0] + assert "subject" in result["emails"][0] + assert "date" in result["emails"][0] + assert "body" in result["emails"][0] # Test http error mock_service.users().messages().list().execute.side_effect = HttpError( @@ -406,10 +413,11 @@ async def test_trash_email(mock_build, mock_context): email_id = "123456" result = await trash_email(context=mock_context, email_id=email_id) - assert ( - f"Email with ID {email_id} trashed successfully: https://mail.google.com/mail/u/0/#trash/{email_id}" - == result - ) + assert isinstance(result, dict) + assert "id" in result + assert "thread_id" in result + assert "subject" in result + assert "body" in result # Test http error mock_service.users().messages().trash().execute.side_effect = HttpError( @@ -419,3 +427,155 @@ async def test_trash_email(mock_build, mock_context): with pytest.raises(ToolExecutionError): await trash_email(context=mock_context, email_id="nonexistent_email") + + +@pytest.mark.asyncio +@patch("arcade_google.tools.gmail.build") +async def test_search_threads(mock_build, mock_context): + mock_service = MagicMock() + mock_build.return_value = mock_service + + # Setup mock response data + mock_threads_list_response = { + "threads": [ + { + "id": "thread1", + "snippet": "Thread snippet 1", + }, + { + "id": "thread2", + "snippet": "Thread snippet 2", + }, + ], + "nextPageToken": "next_token_123", + "resultSizeEstimate": 2, + } + + # Mock the Gmail API threads().list() method + mock_service.users().threads().list().execute.return_value = mock_threads_list_response + + # Test happy path + result = await search_threads( + context=mock_context, + sender="test@example.com", + max_results=2, + ) + + assert isinstance(result, dict) + assert "threads" in result + assert len(result["threads"]) == 2 + assert result["threads"][0]["id"] == "thread1" + assert "next_page_token" in result + + # Test error handling + mock_service.users().threads().list().execute.side_effect = HttpError( + resp=MagicMock(status=400), + content=b'{"error": {"message": "Invalid request"}}', + ) + + with pytest.raises(ToolExecutionError): + await search_threads( + context=mock_context, + sender="test@example.com", + max_results=2, + ) + + +@pytest.mark.asyncio +@patch("arcade_google.tools.gmail.build") +async def test_list_threads(mock_build, mock_context): + mock_service = MagicMock() + mock_build.return_value = mock_service + + # Setup mock response data + mock_threads_list_response = { + "threads": [ + { + "id": "thread1", + "snippet": "Thread snippet 1", + }, + { + "id": "thread2", + "snippet": "Thread snippet 2", + }, + ], + "nextPageToken": "next_token_123", + "resultSizeEstimate": 2, + } + + # Mock the Gmail API threads().list() method + mock_service.users().threads().list().execute.return_value = mock_threads_list_response + + # Test happy path + result = await list_threads( + context=mock_context, + max_results=2, + ) + + assert isinstance(result, dict) + assert "threads" in result + assert len(result["threads"]) == 2 + assert result["threads"][0]["id"] == "thread1" + assert "next_page_token" in result + + # Test error handling + mock_service.users().threads().list().execute.side_effect = HttpError( + resp=MagicMock(status=400), + content=b'{"error": {"message": "Invalid request"}}', + ) + + with pytest.raises(ToolExecutionError): + await list_threads( + context=mock_context, + max_results=2, + ) + + +@pytest.mark.asyncio +@patch("arcade_google.tools.gmail.build") +async def test_get_thread(mock_build, mock_context): + mock_service = MagicMock() + mock_build.return_value = mock_service + + # Setup mock response data + mock_thread_get_response = { + "id": "thread1", + "messages": [ + { + "id": "message1", + "snippet": "Message snippet 1", + }, + { + "id": "message2", + "snippet": "Message snippet 2", + }, + ], + } + + # Mock the Gmail API threads().get() method + mock_service.users().threads().get().execute.return_value = mock_thread_get_response + + # Test happy path + result = await get_thread( + context=mock_context, + thread_id="thread1", + ) + + assert isinstance(result, dict) + assert "id" in result + assert result["id"] == "thread1" + assert "messages" in result + assert len(result["messages"]) == 2 + assert result["messages"][0]["id"] == "message1" + + # Test error handling + mock_service.users().threads().get().execute.side_effect = HttpError( + resp=MagicMock(status=404), + content=b'{"error": {"message": "Thread not found"}}', + ) + + with pytest.raises(ToolExecutionError): + await get_thread( + context=mock_context, + thread_id="invalid_thread", + ) From f002c8a22b0547f96ba2aa6fdd83cf69117d1a6d Mon Sep 17 00:00:00 2001 From: Eric Gustin Date: Tue, 19 Nov 2024 15:17:14 -0800 Subject: [PATCH 2/3] Add back urls --- toolkits/google/arcade_google/tools/gmail.py | 23 +++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/toolkits/google/arcade_google/tools/gmail.py b/toolkits/google/arcade_google/tools/gmail.py index 909f05fb..53bd74c3 100644 --- a/toolkits/google/arcade_google/tools/gmail.py +++ b/toolkits/google/arcade_google/tools/gmail.py @@ -14,6 +14,9 @@ DateRange, build_query_string, fetch_messages, + get_draft_url, + get_email_in_trash_url, + get_sent_email_url, parse_draft_email, parse_email, remove_none_values, @@ -59,7 +62,9 @@ async def send_email( # Send the email sent_message = service.users().messages().send(userId="me", body=email).execute() - return parse_email(sent_message) + email = parse_email(sent_message) + email["url"] = get_sent_email_url(sent_message["id"]) + return email @tool( @@ -80,7 +85,9 @@ async def send_draft_email( # Send the draft email sent_message = service.users().drafts().send(userId="me", body={"id": email_id}).execute() - return parse_email(sent_message) + email = parse_email(sent_message) + email["url"] = get_sent_email_url(sent_message["id"]) + return email # Draft Management Tools @@ -118,7 +125,9 @@ async def write_draft_email( draft = {"message": {"raw": raw_message}} draft_message = service.users().drafts().create(userId="me", body=draft).execute() - return parse_draft_email(draft_message) + email = parse_draft_email(draft_message) + email["url"] = get_draft_url(draft_message["id"]) + return email @tool( @@ -160,7 +169,9 @@ async def update_draft_email( service.users().drafts().update(userId="me", id=draft_email_id, body=draft).execute() ) - return parse_draft_email(updated_draft_message) + email = parse_draft_email(updated_draft_message) + email["url"] = get_draft_url(updated_draft_message["id"]) + return email @tool( @@ -203,7 +214,9 @@ async def trash_email( # Trash the email trashed_email = service.users().messages().trash(userId="me", id=email_id).execute() - return parse_email(trashed_email) + email = parse_email(trashed_email) + email["url"] = get_email_in_trash_url(trashed_email["id"]) + return email # Draft Search Tools From bf19ee79036ae772b65cead246e68dd8c6ae5336 Mon Sep 17 00:00:00 2001 From: Eric Gustin Date: Wed, 20 Nov 2024 11:24:51 -0800 Subject: [PATCH 3/3] Add eval --- toolkits/google/evals/eval_google_gmail.py | 45 +++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/toolkits/google/evals/eval_google_gmail.py b/toolkits/google/evals/eval_google_gmail.py index d3b9f4ac..ede0fcc6 100644 --- a/toolkits/google/evals/eval_google_gmail.py +++ b/toolkits/google/evals/eval_google_gmail.py @@ -62,7 +62,7 @@ def gmail_eval_suite() -> EvalSuite: ) suite.add_case( - name="List threads", + name="Simple list threads", user_message="Get 42 threads like right now i even wanna see the ones in my trash", expected_tool_calls=[ ( @@ -76,6 +76,49 @@ def gmail_eval_suite() -> EvalSuite: ], ) + history = [ + {"role": "user", "content": "list 1 thread"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_X8V5Hw9iJ3wfB8WMZf8omAMi", + "type": "function", + "function": {"name": "Google_ListThreads", "arguments": '{"max_results":1}'}, + } + ], + }, + { + "role": "tool", + "content": '{"next_page_token":"10321400718999360131","num_threads":1,"threads":[{"historyId":"61691","id":"1934a8f8deccb749","snippet":"Hi Joe, I hope this email finds you well. Thank you for being a part of our community."}]}', + "tool_call_id": "call_X8V5Hw9iJ3wfB8WMZf8omAMi", + "name": "Google_ListThreads", + }, + { + "role": "assistant", + "content": "Here is one email thread:\n\n- **Snippet:** Hi Joe, I hope this email finds you well. Thank you for being a part of our community.\n- **Thread ID:** 1934a8f8deccb749\n- **History ID:** 61691", + }, + ] + suite.add_case( + name="List threads with history", + user_message="Get the next 5 threads", + additional_messages=history, + expected_tool_calls=[ + ( + list_threads, + { + "max_results": 5, + "page_token": "10321400718999360131", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="max_results", weight=0.2), + BinaryCritic(critic_field="page_token", weight=0.8), + ], + ) + suite.add_case( name="Search threads", user_message="Search for threads from johndoe@example.com to janedoe@example.com about that talk about 'Arcade AI' from yesterday",