Skip to content

Commit

Permalink
Add Gmail Thread Tools (#159)
Browse files Browse the repository at this point in the history
# PR Description
1. This PR adds three new tools:
    - GetThread (by ID)
    - ListThreads
    - SearchThreads
2. This PR updates the return type for various Gmail tools from str to
dict.
3. This PR adds evals and tests for the added tools
  • Loading branch information
EricGustin authored Nov 20, 2024
1 parent 82afd7e commit 2798cc0
Show file tree
Hide file tree
Showing 5 changed files with 445 additions and 69 deletions.
162 changes: 137 additions & 25 deletions toolkits/google/arcade_google/tools/gmail.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import base64
import json
from email.message import EmailMessage
from email.mime.text import MIMEText
from typing import Annotated, Optional
Expand All @@ -20,6 +19,7 @@
get_sent_email_url,
parse_draft_email,
parse_email,
remove_none_values,
)


Expand All @@ -36,7 +36,7 @@ async def send_email(
recipient: Annotated[str, "The recipient of the email"],
cc: Annotated[Optional[list[str]], "CC recipients of the email"] = None,
bcc: Annotated[Optional[list[str]], "BCC recipients of the email"] = None,
) -> Annotated[str, "A confirmation message with the sent email ID and URL"]:
) -> Annotated[dict, "A dictionary containing the sent email details"]:
"""
Send an email using the Gmail API.
"""
Expand All @@ -61,7 +61,10 @@ async def send_email(

# Send the email
sent_message = service.users().messages().send(userId="me", body=email).execute()
return f"Email with ID {sent_message['id']} sent: {get_sent_email_url(sent_message['id'])}"

email = parse_email(sent_message)
email["url"] = get_sent_email_url(sent_message["id"])
return email


@tool(
Expand All @@ -71,7 +74,7 @@ async def send_email(
)
async def send_draft_email(
context: ToolContext, email_id: Annotated[str, "The ID of the draft to send"]
) -> Annotated[str, "A confirmation message with the sent email ID and URL"]:
) -> Annotated[dict, "A dictionary containing the sent email details"]:
"""
Send a draft email using the Gmail API.
"""
Expand All @@ -82,10 +85,9 @@ async def send_draft_email(
# Send the draft email
sent_message = service.users().drafts().send(userId="me", body={"id": email_id}).execute()

# Construct the URL to the sent email
return (
f"Draft email with ID {sent_message['id']} sent: {get_sent_email_url(sent_message['id'])}"
)
email = parse_email(sent_message)
email["url"] = get_sent_email_url(sent_message["id"])
return email


# Draft Management Tools
Expand All @@ -101,7 +103,7 @@ async def write_draft_email(
recipient: Annotated[str, "The recipient of the draft email"],
cc: Annotated[Optional[list[str]], "CC recipients of the draft email"] = None,
bcc: Annotated[Optional[list[str]], "BCC recipients of the draft email"] = None,
) -> Annotated[str, "A confirmation message with the draft email ID and URL"]:
) -> Annotated[dict, "A dictionary containing the created draft email details"]:
"""
Compose a new email draft using the Gmail API.
"""
Expand All @@ -123,9 +125,9 @@ async def write_draft_email(
draft = {"message": {"raw": raw_message}}

draft_message = service.users().drafts().create(userId="me", body=draft).execute()
return (
f"Draft email with ID {draft_message['id']} created: {get_draft_url(draft_message['id'])}"
)
email = parse_draft_email(draft_message)
email["url"] = get_draft_url(draft_message["id"])
return email


@tool(
Expand All @@ -141,7 +143,7 @@ async def update_draft_email(
recipient: Annotated[str, "The recipient of the draft email"],
cc: Annotated[Optional[list[str]], "CC recipients of the draft email"] = None,
bcc: Annotated[Optional[list[str]], "BCC recipients of the draft email"] = None,
) -> Annotated[str, "A confirmation message with the updated draft email ID and URL"]:
) -> Annotated[dict, "A dictionary containing the updated draft email details"]:
"""
Update an existing email draft using the Gmail API.
"""
Expand All @@ -166,7 +168,10 @@ async def update_draft_email(
updated_draft_message = (
service.users().drafts().update(userId="me", id=draft_email_id, body=draft).execute()
)
return f"Draft email with ID {updated_draft_message['id']} updated: {get_draft_url(updated_draft_message['id'])}"

email = parse_draft_email(updated_draft_message)
email["url"] = get_draft_url(updated_draft_message["id"])
return email


@tool(
Expand Down Expand Up @@ -198,7 +203,7 @@ async def delete_draft_email(
)
async def trash_email(
context: ToolContext, email_id: Annotated[str, "The ID of the email to trash"]
) -> Annotated[str, "A confirmation message with the trashed email ID and URL"]:
) -> Annotated[dict, "A dictionary containing the trashed email details"]:
"""
Move an email to the trash folder using the Gmail API.
"""
Expand All @@ -207,9 +212,11 @@ async def trash_email(
service = build("gmail", "v1", credentials=Credentials(context.authorization.token))

# Trash the email
service.users().messages().trash(userId="me", id=email_id).execute()
trashed_email = service.users().messages().trash(userId="me", id=email_id).execute()

return f"Email with ID {email_id} trashed successfully: {get_email_in_trash_url(email_id)}"
email = parse_email(trashed_email)
email["url"] = get_email_in_trash_url(trashed_email["id"])
return email


# Draft Search Tools
Expand All @@ -221,7 +228,7 @@ async def trash_email(
async def list_draft_emails(
context: ToolContext,
n_drafts: Annotated[int, "Number of draft emails to read"] = 5,
) -> Annotated[str, "A JSON string containing a list of draft email details and their IDs"]:
) -> Annotated[dict, "A dictionary containing a list of draft email details"]:
"""
Lists draft emails in the user's draft mailbox using the Gmail API.
"""
Expand All @@ -245,7 +252,7 @@ async def list_draft_emails(
except Exception as e:
print(f"Error reading draft email {draft_id}: {e}")

return json.dumps({"emails": emails})
return {"emails": emails}


# Email Search Tools
Expand All @@ -263,11 +270,11 @@ async def list_emails_by_header(
date_range: Annotated[Optional[DateRange], "The date range of the email"] = None,
limit: Annotated[Optional[int], "The maximum number of emails to return"] = 25,
) -> Annotated[
str, "A JSON string containing a list of email details matching the search criteria"
dict, "A dictionary containing a list of email details matching the search criteria"
]:
"""
Search for emails by header using the Gmail API.
At least one of the following parametersMUST be provided: sender, recipient, subject, body.
At least one of the following parameters MUST be provided: sender, recipient, subject, body.
"""
if not any([sender, recipient, subject, body]):
raise RetryableToolError(
Expand All @@ -281,10 +288,10 @@ async def list_emails_by_header(
messages = fetch_messages(service, query, limit)

if not messages:
return json.dumps({"emails": []})
return {"emails": []}

emails = process_messages(service, messages)
return json.dumps({"emails": emails})
return {"emails": emails}


def process_messages(service, messages):
Expand All @@ -307,7 +314,7 @@ def process_messages(service, messages):
async def list_emails(
context: ToolContext,
n_emails: Annotated[int, "Number of emails to read"] = 5,
) -> Annotated[str, "A JSON string containing a list of email details"]:
) -> Annotated[dict, "A dictionary containing a list of email details"]:
"""
Read emails from a Gmail account and extract plain text content.
"""
Expand All @@ -329,4 +336,109 @@ async def list_emails(
except Exception as e:
print(f"Error reading email {msg['id']}: {e}")

return json.dumps({"emails": emails})
return {"emails": emails}


@tool(
    requires_auth=Google(
        scopes=["https://www.googleapis.com/auth/gmail.readonly"],
    )
)
async def search_threads(
    context: ToolContext,
    page_token: Annotated[
        Optional[str], "Page token to retrieve a specific page of results in the list"
    ] = None,
    max_results: Annotated[int, "The maximum number of threads to return"] = 10,
    include_spam_trash: Annotated[bool, "Whether to include spam and trash in the results"] = False,
    label_ids: Annotated[Optional[list[str]], "The IDs of labels to filter by"] = None,
    sender: Annotated[Optional[str], "The name or email address of the sender of the email"] = None,
    recipient: Annotated[Optional[str], "The name or email address of the recipient"] = None,
    subject: Annotated[Optional[str], "Words to find in the subject of the email"] = None,
    body: Annotated[Optional[str], "Words to find in the body of the email"] = None,
    date_range: Annotated[Optional[DateRange], "The date range of the email"] = None,
) -> Annotated[dict, "A dictionary containing a list of thread details"]:
    """Search for threads in the user's mailbox"""
    # Returns a dict with three keys: "threads" (the accumulated entries from
    # each list response), "num_threads" (their count) and "next_page_token"
    # (the token from the last response, or None when there was none).
    gmail = build("gmail", "v1", credentials=Credentials(context.authorization.token))

    # Only build a search query when at least one filter was supplied;
    # otherwise "q" stays None and is dropped from the request below.
    has_filters = any([sender, recipient, subject, body, date_range])
    query = (
        build_query_string(sender, recipient, subject, body, date_range) if has_filters else None
    )

    # Each request asks for at most 500 threads; unset options are stripped
    # by remove_none_values before the call.
    request_args = remove_none_values({
        "userId": "me",
        "maxResults": min(max_results, 500),
        "pageToken": page_token,
        "includeSpamTrash": include_spam_trash,
        "labelIds": label_ids,
        "q": query,
    })

    collected = []
    token = None
    remaining = max_results
    # Keep requesting pages until enough threads were gathered or the
    # response no longer carries a nextPageToken.
    while remaining > 0:
        page = gmail.users().threads().list(**request_args).execute()

        collected += page.get("threads", [])
        token = page.get("nextPageToken")
        if not token:
            break

        # Ask the next page only for what is still missing.
        remaining = max_results - len(collected)
        request_args["pageToken"] = token
        request_args["maxResults"] = min(remaining, 500)

    return {
        "threads": collected,
        "num_threads": len(collected),
        "next_page_token": token,
    }


@tool(
    requires_auth=Google(
        scopes=["https://www.googleapis.com/auth/gmail.readonly"],
    )
)
async def list_threads(
    context: ToolContext,
    page_token: Annotated[
        Optional[str], "Page token to retrieve a specific page of results in the list"
    ] = None,
    max_results: Annotated[int, "The maximum number of threads to return"] = 10,
    include_spam_trash: Annotated[bool, "Whether to include spam and trash in the results"] = False,
) -> Annotated[dict, "A dictionary containing a list of thread details"]:
    """List threads in the user's mailbox."""
    # Listing is a search with no label, header or date filters, so this
    # delegates to search_threads and returns its result unchanged.
    listing = await search_threads(context, page_token, max_results, include_spam_trash)
    return listing


@tool(
    requires_auth=Google(
        scopes=["https://www.googleapis.com/auth/gmail.readonly"],
    )
)
async def get_thread(
    context: ToolContext,
    thread_id: Annotated[str, "The ID of the thread to retrieve"],
    metadata_headers: Annotated[
        Optional[list[str]], "When given and format is METADATA, only include headers specified."
    ] = None,
) -> Annotated[dict, "A dictionary containing the thread details"]:
    """Get the specified thread by ID."""
    # NOTE(review): "format" is hard-coded to "full" here, while the
    # metadata_headers description refers to the METADATA format -- confirm
    # whether the header filter has any effect on this request.
    request_args = remove_none_values({
        "userId": "me",
        "id": thread_id,
        "format": "full",
        "metadataHeaders": metadata_headers,
    })

    gmail = build("gmail", "v1", credentials=Credentials(context.authorization.token))
    thread = gmail.users().threads().get(**request_args).execute()

    # Replace each message in the response with the output of parse_email;
    # a response without "messages" yields an empty list.
    parsed_messages = []
    for raw_message in thread.get("messages", []):
        parsed_messages.append(parse_email(raw_message))
    thread["messages"] = parsed_messages

    return thread
4 changes: 3 additions & 1 deletion toolkits/google/arcade_google/tools/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,9 @@ class SendUpdatesOptions(Enum):
EXTERNAL_ONLY = "externalOnly" # Notifications are sent to non-Google Calendar guests only.


# Utils for Google Drive tools
# ---------------------------------------------------------------------------- #
# Google Drive Models and Enums
# ---------------------------------------------------------------------------- #
class Corpora(str, Enum):
"""
Bodies of items (files/documents) to which the query applies.
Expand Down
18 changes: 10 additions & 8 deletions toolkits/google/arcade_google/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,17 @@ def parse_email(email_data: dict[str, Any]) -> Optional[dict[str, str]]:
Optional[Dict[str, str]]: Parsed email details or None if parsing fails.
"""
try:
payload = email_data["payload"]
headers = {d["name"].lower(): d["value"] for d in payload["headers"]}
payload = email_data.get("payload", {})
headers = {d["name"].lower(): d["value"] for d in payload.get("headers", [])}

body_data = _get_email_body(payload)

return {
"id": email_data.get("id", ""),
"thread_id": email_data.get("threadId", ""),
"from": headers.get("from", ""),
"date": headers.get("date", ""),
"subject": headers.get("subject", "No subject"),
"subject": headers.get("subject", ""),
"body": _clean_email_body(body_data) if body_data else "",
}
except Exception as e:
Expand All @@ -110,17 +111,18 @@ def parse_draft_email(draft_email_data: dict[str, Any]) -> Optional[dict[str, st
Optional[Dict[str, str]]: Parsed draft email details or None if parsing fails.
"""
try:
message = draft_email_data["message"]
payload = message["payload"]
headers = {d["name"].lower(): d["value"] for d in payload["headers"]}
message = draft_email_data.get("message", {})
payload = message.get("payload", {})
headers = {d["name"].lower(): d["value"] for d in payload.get("headers", [])}

body_data = _get_email_body(payload)

return {
"id": draft_email_data.get("id", ""),
"thread_id": draft_email_data.get("threadId", ""),
"from": headers.get("from", ""),
"date": headers.get("internaldate", ""),
"subject": headers.get("subject", "No subject"),
"subject": headers.get("subject", ""),
"body": _clean_email_body(body_data) if body_data else "",
}
except Exception as e:
Expand Down Expand Up @@ -226,7 +228,7 @@ def _update_datetime(day: Day | None, time: TimeSlot | None, time_zone: str) ->

def build_query_string(sender, recipient, subject, body, date_range):
"""
Helper function to build a query string for Gmail list_emails_by_header tool.
Helper function to build a query string for Gmail list_emails_by_header and search_threads tools.
"""
query = []
if sender:
Expand Down
Loading

0 comments on commit 2798cc0

Please sign in to comment.