Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tools Hot Reload #115

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions arcade/arcade/actor/core/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import logging
import os
import time
import uuid
from datetime import datetime
from typing import Any, Callable, ClassVar

Expand All @@ -13,13 +15,16 @@
CallToolComponent,
CatalogComponent,
HealthCheckComponent,
ToolStatusComponent,
)
from arcade.core.catalog import ToolCatalog, Toolkit
from arcade.core.executor import ToolExecutor
from arcade.core.schema import (
ToolCallRequest,
ToolCallResponse,
ToolDefinition,
ToolStatusRequest,
ToolStatusResponse,
)

logger = logging.getLogger(__name__)
Expand All @@ -37,6 +42,7 @@
CatalogComponent,
CallToolComponent,
HealthCheckComponent,
ToolStatusComponent,
)

def __init__(
Expand All @@ -47,6 +53,7 @@
If no secret is provided, the actor will use the ARCADE_ACTOR_SECRET environment variable.
"""
self.catalog = ToolCatalog()
self._update_catalog_uuid()
self.disable_auth = disable_auth
if disable_auth:
logger.warning(
Expand All @@ -60,6 +67,9 @@
"tool_call", "requests", "Total number of tools called"
)

def _update_catalog_uuid(self) -> None:
self.uuid = str(uuid.uuid4())

def _set_secret(self, secret: str | None, disable_auth: bool) -> str:
if disable_auth:
return ""
Expand All @@ -77,6 +87,10 @@
"No secret provided for actor. Set the ARCADE_ACTOR_SECRET environment variable."
)

def clear_catalog(self) -> None:
self.catalog = ToolCatalog()
self._update_catalog_uuid()

Check warning on line 92 in arcade/arcade/actor/core/base.py

View check run for this annotation

Codecov / codecov/patch

arcade/arcade/actor/core/base.py#L91-L92

Added lines #L91 - L92 were not covered by tests

def get_catalog(self) -> list[ToolDefinition]:
"""
Get the catalog as a list of ToolDefinitions.
Expand All @@ -88,12 +102,14 @@
Register a tool to the catalog.
"""
self.catalog.add_tool(tool, toolkit_name)
self._update_catalog_uuid()

Check warning on line 105 in arcade/arcade/actor/core/base.py

View check run for this annotation

Codecov / codecov/patch

arcade/arcade/actor/core/base.py#L105

Added line #L105 was not covered by tests

def register_toolkit(self, toolkit: Toolkit) -> None:
"""
Register a toolkit to the catalog.
"""
self.catalog.add_toolkit(toolkit)
self._update_catalog_uuid()

Check warning on line 112 in arcade/arcade/actor/core/base.py

View check run for this annotation

Codecov / codecov/patch

arcade/arcade/actor/core/base.py#L112

Added line #L112 was not covered by tests

async def call_tool(self, tool_request: ToolCallRequest) -> ToolCallResponse:
"""
Expand Down Expand Up @@ -169,6 +185,24 @@
output=output,
)

async def tool_status(self, tool_request: ToolStatusRequest) -> ToolStatusResponse:
if tool_request.uuid != self.uuid:
return ToolStatusResponse(uuid=self.uuid)
# If no update, wait for changes
try:
await self._wait_for_catalog_update(tool_request.uuid)
return ToolStatusResponse(uuid=self.uuid)

Check warning on line 194 in arcade/arcade/actor/core/base.py

View check run for this annotation

Codecov / codecov/patch

arcade/arcade/actor/core/base.py#L194

Added line #L194 was not covered by tests
except asyncio.TimeoutError:
# If no update after timeout, return current status
return ToolStatusResponse(uuid=self.uuid)

async def _wait_for_catalog_update(self, uuid: str, timeout: float = 30.0) -> None:
start_time = asyncio.get_event_loop().time()
while self.uuid == uuid:
if asyncio.get_event_loop().time() - start_time > timeout:
raise asyncio.TimeoutError("Timeout waiting for catalog update")
await asyncio.sleep(1)

def health_check(self) -> dict[str, Any]:
"""
Provide a health check that serves as a heartbeat of actor health.
Expand Down
15 changes: 14 additions & 1 deletion arcade/arcade/actor/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@

from pydantic import BaseModel

from arcade.core.schema import ToolCallRequest, ToolCallResponse, ToolDefinition
from arcade.core.schema import (
ToolCallRequest,
ToolCallResponse,
ToolDefinition,
ToolStatusRequest,
ToolStatusResponse,
)


class RequestData(BaseModel):
Expand Down Expand Up @@ -56,6 +62,13 @@
"""
pass

@abstractmethod
async def tool_status(self, request: ToolStatusRequest) -> ToolStatusResponse:
"""
Send a request to get the last uuid of the toolkits
"""
pass

Check warning on line 70 in arcade/arcade/actor/core/common.py

View check run for this annotation

Codecov / codecov/patch

arcade/arcade/actor/core/common.py#L70

Added line #L70 was not covered by tests

@abstractmethod
def health_check(self) -> dict[str, Any]:
"""
Expand Down
29 changes: 28 additions & 1 deletion arcade/arcade/actor/core/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
from opentelemetry import trace

from arcade.actor.core.common import Actor, ActorComponent, RequestData, Router
from arcade.core.schema import ToolCallRequest, ToolCallResponse, ToolDefinition
from arcade.core.schema import (
ToolCallRequest,
ToolCallResponse,
ToolDefinition,
ToolStatusRequest,
ToolStatusResponse,
)


class CatalogComponent(ActorComponent):
Expand Down Expand Up @@ -46,6 +52,27 @@
return await self.actor.call_tool(call_tool_request)


class ToolStatusComponent(ActorComponent):
def __init__(self, actor: Actor) -> None:
self.actor = actor

Check warning on line 57 in arcade/arcade/actor/core/components.py

View check run for this annotation

Codecov / codecov/patch

arcade/arcade/actor/core/components.py#L57

Added line #L57 was not covered by tests

def register(self, router: Router) -> None:
"""
Register the tool status check route with the router.
"""
router.add_route("tools/status", self, method="POST")

Check warning on line 63 in arcade/arcade/actor/core/components.py

View check run for this annotation

Codecov / codecov/patch

arcade/arcade/actor/core/components.py#L63

Added line #L63 was not covered by tests

async def __call__(self, request: RequestData) -> ToolStatusResponse:
"""
Handle long-polling requests for tool status updates.
"""
tracer = trace.get_tracer(__name__)
with tracer.start_as_current_span("ToolStatusCheck"):
call_tool_status_data = request.body_json
call_status_request = ToolStatusRequest.model_validate(call_tool_status_data)
return await self.actor.tool_status(call_status_request)

Check warning on line 73 in arcade/arcade/actor/core/components.py

View check run for this annotation

Codecov / codecov/patch

arcade/arcade/actor/core/components.py#L69-L73

Added lines #L69 - L73 were not covered by tests


class HealthCheckComponent(ActorComponent):
def __init__(self, actor: Actor) -> None:
self.actor = actor
Expand Down
14 changes: 14 additions & 0 deletions arcade/arcade/cli/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import sys
import threading
from contextlib import asynccontextmanager
from typing import Any

Expand All @@ -24,6 +25,7 @@
)

from arcade.actor.fastapi.actor import FastAPIActor
from arcade.cli.watcher import ToolkitWatcher
from arcade.core.toolkit import Toolkit


Expand Down Expand Up @@ -127,6 +129,16 @@ def serve_default_actor(
for toolkit in toolkits:
actor.register_toolkit(toolkit)

shutdown_event = threading.Event()

toolkit_watcher = ToolkitWatcher(toolkits, actor, shutdown_event)

def run_polling() -> None:
asyncio.run(toolkit_watcher.start())

polling_thread = threading.Thread(target=run_polling, daemon=True)
polling_thread.start()

logger.info("Starting FastAPI server...")

class CustomUvicornServer(uvicorn.Server):
Expand Down Expand Up @@ -154,4 +166,6 @@ async def serve() -> None:
finally:
if enable_otel:
otel_handler.shutdown()
shutdown_event.set()
polling_thread.join(timeout=5)
logger.debug("Server shutdown complete.")
61 changes: 61 additions & 0 deletions arcade/arcade/cli/watcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import asyncio
import threading

from loguru import logger

from arcade.actor.fastapi.actor import FastAPIActor
from arcade.core.schema import FullyQualifiedName
from arcade.core.toolkit import Toolkit


class ToolkitWatcher:
def __init__(
self,
initial: list[Toolkit],
actor: FastAPIActor,
shutdown_event: threading.Event,
):
self.current_tools: list[FullyQualifiedName] = self._list_tools(initial)
self.actor = actor
self.shutdown_event = shutdown_event

async def start(self, interval: int = 1) -> None:
while not self.shutdown_event.is_set():
try:
new_toolkits = Toolkit.find_all_arcade_toolkits()
new_tools = self._list_tools(new_toolkits)
if new_tools != self.current_tools:
logger.info("Toolkit changes detected. Updating actor's catalog...")

for tool in new_tools:
if tool not in self.current_tools:
logger.info(f"New tool added: {tool}")

for tool in self.current_tools:
if tool not in new_tools:
logger.info(f"Toolkit removed: {tool}")

self.actor.clear_catalog()
for toolkit in new_toolkits:
self.actor.register_toolkit(toolkit)

self.current_tools = new_tools

logger.info("Actor's catalog has been updated.")
else:
pass

except Exception:
logger.exception("Error while polling toolkits")

await asyncio.sleep(interval)

def _list_tools(self, toolkits: list[Toolkit]) -> list[FullyQualifiedName]:
tools_list = []
for toolkit in toolkits:
for _, tools in toolkit.tools.items():
if len(tools) != 0:
tools_list.extend([
FullyQualifiedName(tool, toolkit.name, toolkit.version) for tool in tools
])
return tools_list
14 changes: 14 additions & 0 deletions arcade/arcade/core/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,13 @@ class ToolCallRequest(BaseModel):
"""The context for the tool invocation."""


class ToolStatusRequest(BaseModel):
"""The request to check for updates for tools."""

uuid: str
"""The UUID to compare against the current toolkits UUID"""


class ToolCallError(BaseModel):
"""The error that occurred during the tool invocation."""

Expand Down Expand Up @@ -317,3 +324,10 @@ class ToolCallResponse(BaseModel):
"""Whether the tool invocation was successful."""
output: ToolCallOutput | None = None
"""The output of the tool invocation."""


class ToolStatusResponse(BaseModel):
"""The response to a status invocation."""

uuid: str
"""The current UUID of the registered toolkits."""
67 changes: 67 additions & 0 deletions arcade/tests/actor/core/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import asyncio
import os
from unittest import mock

import pytest

from arcade.actor.core.base import BaseActor
from arcade.core.schema import ToolStatusRequest, ToolStatusResponse


@pytest.mark.asyncio
@mock.patch.dict(os.environ, {"ARCADE_ACTOR_SECRET": "test-secret"})
class TestBaseActor:
async def test_tool_status_different_uuid(self):
actor = BaseActor()
actor.uuid = "actor-uuid"
request = ToolStatusRequest(uuid="different-uuid")

response = await actor.tool_status(request)

assert isinstance(response, ToolStatusResponse)
assert response.uuid == actor.uuid

async def test_tool_status_timeout(self):
actor = BaseActor()
actor.uuid = "actor-uuid"
request = ToolStatusRequest(uuid="actor-uuid")

# Mock _wait_for_catalog_update to simulate timeout
original_wait_for_catalog_update = actor._wait_for_catalog_update

async def mock_wait_for_catalog_update(uuid: str, timeout: float = 0.5):
await asyncio.sleep(timeout)
raise asyncio.TimeoutError()

actor._wait_for_catalog_update = mock_wait_for_catalog_update

response = await actor.tool_status(request)

assert isinstance(response, ToolStatusResponse)
assert response.uuid == actor.uuid

# Cleanup
actor._wait_for_catalog_update = original_wait_for_catalog_update

async def test_wait_for_catalog_update_update_happens(self):
actor = BaseActor()
actor.uuid = "actor-uuid"

wait_task = asyncio.create_task(actor._wait_for_catalog_update("actor-uuid", timeout=1.0))

# Simulate catalog update after 0.1 second
await asyncio.sleep(0.1)
actor.uuid = "new-actor-uuid"

await wait_task

assert True

async def test_wait_for_catalog_update_timeout(self):
actor = BaseActor()
actor.uuid = "actor-uuid"

wait_task = asyncio.create_task(actor._wait_for_catalog_update("actor-uuid", timeout=0.5))

with pytest.raises(asyncio.TimeoutError):
await wait_task
Loading