From de3c28ce14404ca20e6650f42316775034269a74 Mon Sep 17 00:00:00 2001 From: sdreyer Date: Fri, 18 Oct 2024 12:48:02 -0700 Subject: [PATCH 1/8] Working --- arcade/arcade/actor/core/base.py | 35 ++++++++++++++ arcade/arcade/actor/core/common.py | 15 +++++- arcade/arcade/actor/core/components.py | 31 +++++++++++- arcade/arcade/cli/serve.py | 67 ++++++++++++++++++++++++++ arcade/arcade/core/catalog.py | 4 +- arcade/arcade/core/schema.py | 14 ++++++ 6 files changed, 162 insertions(+), 4 deletions(-) diff --git a/arcade/arcade/actor/core/base.py b/arcade/arcade/actor/core/base.py index 523aa6f0..0ef03ee7 100644 --- a/arcade/arcade/actor/core/base.py +++ b/arcade/arcade/actor/core/base.py @@ -1,6 +1,8 @@ +import asyncio import logging import os import time +import uuid from datetime import datetime from typing import Any, Callable, ClassVar @@ -13,6 +15,7 @@ CallToolComponent, CatalogComponent, HealthCheckComponent, + ToolStatusComponent, ) from arcade.core.catalog import ToolCatalog, Toolkit from arcade.core.executor import ToolExecutor @@ -20,6 +23,8 @@ ToolCallRequest, ToolCallResponse, ToolDefinition, + ToolStatusRequest, + ToolStatusResponse, ) logger = logging.getLogger(__name__) @@ -37,6 +42,7 @@ class BaseActor(Actor): CatalogComponent, CallToolComponent, HealthCheckComponent, + ToolStatusComponent, ) def __init__( @@ -47,6 +53,7 @@ def __init__( If no secret is provided, the actor will use the ARCADE_ACTOR_SECRET environment variable. """ self.catalog = ToolCatalog() + self._update_catalog_uuid() self.disable_auth = disable_auth if disable_auth: logger.warning( @@ -60,6 +67,9 @@ def __init__( "tool_call", "requests", "Total number of tools called" ) + def _update_catalog_uuid(self) -> None: + self.uuid = str(uuid.uuid4()) + def _set_secret(self, secret: str | None, disable_auth: bool) -> str: if disable_auth: return "" @@ -77,6 +87,10 @@ def _set_secret(self, secret: str | None, disable_auth: bool) -> str: "No secret provided for actor. Set the ARCADE_ACTOR_SECRET environment variable." ) + def new_catalog(self) -> None: + self.catalog = ToolCatalog() + self._update_catalog_uuid() + def get_catalog(self) -> list[ToolDefinition]: """ Get the catalog as a list of ToolDefinitions. @@ -88,12 +102,15 @@ def register_tool(self, tool: Callable, toolkit_name: str) -> None: Register a tool to the catalog. """ self.catalog.add_tool(tool, toolkit_name) + self._update_catalog_uuid() def register_toolkit(self, toolkit: Toolkit) -> None: """ Register a toolkit to the catalog. """ self.catalog.add_toolkit(toolkit) + self._update_catalog_uuid() + print(self.uuid) async def call_tool(self, tool_request: ToolCallRequest) -> ToolCallResponse: """ @@ -169,6 +186,24 @@ async def call_tool(self, tool_request: ToolCallRequest) -> ToolCallResponse: output=output, ) + async def tool_status(self, tool_request: ToolStatusRequest) -> ToolStatusResponse: + if tool_request.uuid != self.uuid: + return ToolStatusResponse(uuid=self.uuid) + # If no update, wait for changes + try: + await self._wait_for_catalog_update(tool_request.uuid) + return ToolStatusResponse(uuid=self.uuid) + except asyncio.TimeoutError: + # If no update after timeout, return current status + return ToolStatusResponse(uuid=self.uuid) + + async def _wait_for_catalog_update(self, uuid: str, timeout: float = 30.0) -> None: + start_time = asyncio.get_event_loop().time() + while self.uuid == uuid: + if asyncio.get_event_loop().time() - start_time > timeout: + raise asyncio.TimeoutError("Timeout waiting for catalog update") + await asyncio.sleep(1) + def health_check(self) -> dict[str, Any]: """ Provide a health check that serves as a heartbeat of actor health. diff --git a/arcade/arcade/actor/core/common.py b/arcade/arcade/actor/core/common.py index 4bcb6657..e50337eb 100644 --- a/arcade/arcade/actor/core/common.py +++ b/arcade/arcade/actor/core/common.py @@ -3,7 +3,13 @@ from pydantic import BaseModel -from arcade.core.schema import ToolCallRequest, ToolCallResponse, ToolDefinition +from arcade.core.schema import ( + ToolCallRequest, + ToolCallResponse, + ToolDefinition, + ToolStatusRequest, + ToolStatusResponse, +) class RequestData(BaseModel): @@ -56,6 +62,13 @@ async def call_tool(self, request: ToolCallRequest) -> ToolCallResponse: """ pass + @abstractmethod + async def tool_status(self, request: ToolStatusRequest) -> ToolStatusResponse: + """ + Send a request to get the last time tools were updated + """ + pass + @abstractmethod def health_check(self) -> dict[str, Any]: """ diff --git a/arcade/arcade/actor/core/components.py b/arcade/arcade/actor/core/components.py index 106e94a9..f1543c93 100644 --- a/arcade/arcade/actor/core/components.py +++ b/arcade/arcade/actor/core/components.py @@ -1,9 +1,16 @@ +import time from typing import Any from opentelemetry import trace from arcade.actor.core.common import Actor, ActorComponent, RequestData, Router -from arcade.core.schema import ToolCallRequest, ToolCallResponse, ToolDefinition +from arcade.core.schema import ( + ToolCallRequest, + ToolCallResponse, + ToolDefinition, + ToolStatusRequest, + ToolStatusResponse, +) class CatalogComponent(ActorComponent): @@ -46,6 +53,28 @@ async def __call__(self, request: RequestData) -> ToolCallResponse: return await self.actor.call_tool(call_tool_request) +class ToolStatusComponent(ActorComponent): + def __init__(self, actor: Actor) -> None: + self.actor = actor + self.last_update_time = time.time() + + def register(self, router: Router) -> None: + """ + Register the tool status check route with the router. + """ + router.add_route("tools/status", self, method="POST") + + async def __call__(self, request: RequestData) -> ToolStatusResponse: + """ + Handle long-polling requests for tool status updates. + """ + tracer = trace.get_tracer(__name__) + with tracer.start_as_current_span("ToolStatusCheck"): + call_tool_status_data = request.body_json + call_status_request = ToolStatusRequest.model_validate(call_tool_status_data) + return await self.actor.tool_status(call_status_request) + + class HealthCheckComponent(ActorComponent): def __init__(self, actor: Actor) -> None: self.actor = actor diff --git a/arcade/arcade/cli/serve.py b/arcade/arcade/cli/serve.py index 5ac66b73..8624d363 100644 --- a/arcade/arcade/cli/serve.py +++ b/arcade/arcade/cli/serve.py @@ -2,11 +2,13 @@ import logging import os import sys +import threading from contextlib import asynccontextmanager from typing import Any from loguru import logger +from arcade.core.schema import FullyQualifiedName from arcade.core.telemetry import OTELHandler try: @@ -80,6 +82,59 @@ async def lifespan(app: fastapi.FastAPI): # type: ignore[no-untyped-def] logger.debug("Lifespan cancelled.") +class ToolkitWatcher: + def __init__( + self, + initial: list[Toolkit], + actor: FastAPIActor, + shutdown_event: threading.Event, + ): + self.current_tools: list[FullyQualifiedName] = self._list_tools(initial) + self.actor = actor + self.shutdown_event = shutdown_event + + async def start(self, interval: int = 1) -> None: + while not self.shutdown_event.is_set(): + try: + new_toolkits = Toolkit.find_all_arcade_toolkits() + new_tools = self._list_tools(new_toolkits) + if new_tools != self.current_tools: + logger.info("Toolkit changes detected. Updating actor's catalog...") + + for tool in new_tools: + if tool not in self.current_tools: + logger.info(f"New tool added: {tool}") + + for tool in self.current_tools: + if tool not in new_tools: + logger.info(f"Toolkit removed: {tool}") + + self.actor.new_catalog() + for toolkit in new_toolkits: + self.actor.register_toolkit(toolkit) + + self.current_tools = new_tools + + logger.info("Actor's catalog has been updated.") + else: + pass + + except Exception: + logger.exception("Error while polling toolkits") + + await asyncio.sleep(interval) + + def _list_tools(self, toolkits: list[Toolkit]) -> list[FullyQualifiedName]: + tools_list = [] + for toolkit in toolkits: + for _, tools in toolkit.tools.items(): + if len(tools) != 0: + tools_list.extend([ + FullyQualifiedName(tool, toolkit.name, toolkit.version) for tool in tools + ]) + return tools_list + + def serve_default_actor( host: str = "127.0.0.1", port: int = 8002, @@ -127,6 +182,16 @@ def serve_default_actor( for toolkit in toolkits: actor.register_toolkit(toolkit) + shutdown_event = threading.Event() + + toolkit_watcher = ToolkitWatcher(toolkits, actor, shutdown_event) + + def run_polling() -> None: + asyncio.run(toolkit_watcher.start()) + + polling_thread = threading.Thread(target=run_polling, daemon=True) + polling_thread.start() + logger.info("Starting FastAPI server...") class CustomUvicornServer(uvicorn.Server): @@ -154,4 +219,6 @@ async def serve() -> None: finally: if enable_otel: otel_handler.shutdown() + shutdown_event.set() + polling_thread.join(timeout=5) logger.debug("Server shutdown complete.") diff --git a/arcade/arcade/core/catalog.py b/arcade/arcade/core/catalog.py index 496e8284..a9451a4b 100644 --- a/arcade/arcade/core/catalog.py +++ b/arcade/arcade/core/catalog.py @@ -384,7 +384,7 @@ def create_output_definition(func: Callable) -> ToolOutput: ) if hasattr(return_type, "__metadata__"): - description = return_type.__metadata__[0] if return_type.__metadata__ else None # type: ignore[assignment] + description = return_type.__metadata__[0] if return_type.__metadata__ else None return_type = return_type.__origin__ # Unwrap Optional types @@ -542,7 +542,7 @@ def get_wire_type_info(_type: type) -> WireTypeInfo: # Special case: Enum can be enumerated on the wire elif issubclass(type_to_check, Enum): is_enum = True - enum_values = [e.value for e in type_to_check] # type: ignore[union-attr] + enum_values = [e.value for e in type_to_check] return WireTypeInfo(wire_type, inner_wire_type, enum_values if is_enum else None) diff --git a/arcade/arcade/core/schema.py b/arcade/arcade/core/schema.py index 7fa3772c..29e8439a 100644 --- a/arcade/arcade/core/schema.py +++ b/arcade/arcade/core/schema.py @@ -252,6 +252,13 @@ class ToolCallRequest(BaseModel): """The context for the tool invocation.""" +class ToolStatusRequest(BaseModel): + """The request to check for updates for tools.""" + + uuid: str + """The timestamp to compare the last update against""" + + class ToolCallError(BaseModel): """The error that occurred during the tool invocation.""" @@ -317,3 +324,10 @@ class ToolCallResponse(BaseModel): """Whether the tool invocation was successful.""" output: ToolCallOutput | None = None """The output of the tool invocation.""" + + +class ToolStatusResponse(BaseModel): + """The response to a status invocation.""" + + uuid: str + """The timestamp when the tools were last updated.""" From 044f6a34b6bdd855e431b4eb1b64f6106324ed97 Mon Sep 17 00:00:00 2001 From: sdreyer Date: Fri, 18 Oct 2024 12:55:50 -0700 Subject: [PATCH 2/8] cleanup --- arcade/arcade/actor/core/base.py | 1 - arcade/arcade/actor/core/common.py | 2 +- arcade/arcade/actor/core/components.py | 2 -- arcade/arcade/core/schema.py | 4 ++-- 4 files changed, 3 insertions(+), 6 deletions(-) diff --git a/arcade/arcade/actor/core/base.py b/arcade/arcade/actor/core/base.py index 0ef03ee7..ce11b400 100644 --- a/arcade/arcade/actor/core/base.py +++ b/arcade/arcade/actor/core/base.py @@ -110,7 +110,6 @@ def register_toolkit(self, toolkit: Toolkit) -> None: """ self.catalog.add_toolkit(toolkit) self._update_catalog_uuid() - print(self.uuid) async def call_tool(self, tool_request: ToolCallRequest) -> ToolCallResponse: """ diff --git a/arcade/arcade/actor/core/common.py b/arcade/arcade/actor/core/common.py index e50337eb..ad8efb3a 100644 --- a/arcade/arcade/actor/core/common.py +++ b/arcade/arcade/actor/core/common.py @@ -65,7 +65,7 @@ async def call_tool(self, request: ToolCallRequest) -> ToolCallResponse: @abstractmethod async def tool_status(self, request: ToolStatusRequest) -> ToolStatusResponse: """ - Send a request to get the last time tools were updated + Send a request to get the last uuid of the toolkits """ pass diff --git a/arcade/arcade/actor/core/components.py b/arcade/arcade/actor/core/components.py index f1543c93..1288e33a 100644 --- a/arcade/arcade/actor/core/components.py +++ b/arcade/arcade/actor/core/components.py @@ -1,4 +1,3 @@ -import time from typing import Any from opentelemetry import trace @@ -56,7 +55,6 @@ async def __call__(self, request: RequestData) -> ToolCallResponse: class ToolStatusComponent(ActorComponent): def __init__(self, actor: Actor) -> None: self.actor = actor - self.last_update_time = time.time() def register(self, router: Router) -> None: """ diff --git a/arcade/arcade/core/schema.py b/arcade/arcade/core/schema.py index 29e8439a..48e06e24 100644 --- a/arcade/arcade/core/schema.py +++ b/arcade/arcade/core/schema.py @@ -256,7 +256,7 @@ class ToolStatusRequest(BaseModel): """The request to check for updates for tools.""" uuid: str - """The timestamp to compare the last update against""" + """The UUID to compare against the current toolkits UUID""" class ToolCallError(BaseModel): @@ -330,4 +330,4 @@ class ToolStatusResponse(BaseModel): """The response to a status invocation.""" uuid: str - """The timestamp when the tools were last updated.""" + """The current UUID of the registered toolkits.""" From f7ccdfdd7c12a28b34441e79408df851873c4a6d Mon Sep 17 00:00:00 2001 From: sdreyer Date: Fri, 18 Oct 2024 15:32:01 -0700 Subject: [PATCH 3/8] Watcher tests --- arcade/arcade/cli/serve.py | 55 +---------- arcade/arcade/cli/watcher.py | 61 +++++++++++++ arcade/tests/cli/test_watcher.py | 152 +++++++++++++++++++++++++++++++ 3 files changed, 214 insertions(+), 54 deletions(-) create mode 100644 arcade/arcade/cli/watcher.py create mode 100644 arcade/tests/cli/test_watcher.py diff --git a/arcade/arcade/cli/serve.py b/arcade/arcade/cli/serve.py index 8624d363..a7ffb599 100644 --- a/arcade/arcade/cli/serve.py +++ b/arcade/arcade/cli/serve.py @@ -8,7 +8,6 @@ from loguru import logger -from arcade.core.schema import FullyQualifiedName from arcade.core.telemetry import OTELHandler try: @@ -26,6 +25,7 @@ ) from arcade.actor.fastapi.actor import FastAPIActor +from arcade.cli.watcher import ToolkitWatcher from arcade.core.toolkit import Toolkit @@ -82,59 +82,6 @@ async def lifespan(app: fastapi.FastAPI): # type: ignore[no-untyped-def] logger.debug("Lifespan cancelled.") -class ToolkitWatcher: - def __init__( - self, - initial: list[Toolkit], - actor: FastAPIActor, - shutdown_event: threading.Event, - ): - self.current_tools: list[FullyQualifiedName] = self._list_tools(initial) - self.actor = actor - self.shutdown_event = shutdown_event - - async def start(self, interval: int = 1) -> None: - while not self.shutdown_event.is_set(): - try: - new_toolkits = Toolkit.find_all_arcade_toolkits() - new_tools = self._list_tools(new_toolkits) - if new_tools != self.current_tools: - logger.info("Toolkit changes detected. Updating actor's catalog...") - - for tool in new_tools: - if tool not in self.current_tools: - logger.info(f"New tool added: {tool}") - - for tool in self.current_tools: - if tool not in new_tools: - logger.info(f"Toolkit removed: {tool}") - - self.actor.new_catalog() - for toolkit in new_toolkits: - self.actor.register_toolkit(toolkit) - - self.current_tools = new_tools - - logger.info("Actor's catalog has been updated.") - else: - pass - - except Exception: - logger.exception("Error while polling toolkits") - - await asyncio.sleep(interval) - - def _list_tools(self, toolkits: list[Toolkit]) -> list[FullyQualifiedName]: - tools_list = [] - for toolkit in toolkits: - for _, tools in toolkit.tools.items(): - if len(tools) != 0: - tools_list.extend([ - FullyQualifiedName(tool, toolkit.name, toolkit.version) for tool in tools - ]) - return tools_list - - def serve_default_actor( host: str = "127.0.0.1", port: int = 8002, diff --git a/arcade/arcade/cli/watcher.py b/arcade/arcade/cli/watcher.py new file mode 100644 index 00000000..61c701d2 --- /dev/null +++ b/arcade/arcade/cli/watcher.py @@ -0,0 +1,61 @@ +import asyncio +import threading + +from loguru import logger + +from arcade.actor.fastapi.actor import FastAPIActor +from arcade.core.schema import FullyQualifiedName +from arcade.core.toolkit import Toolkit + + +class ToolkitWatcher: + def __init__( + self, + initial: list[Toolkit], + actor: FastAPIActor, + shutdown_event: threading.Event, + ): + self.current_tools: list[FullyQualifiedName] = self._list_tools(initial) + self.actor = actor + self.shutdown_event = shutdown_event + + async def start(self, interval: int = 1) -> None: + while not self.shutdown_event.is_set(): + try: + new_toolkits = Toolkit.find_all_arcade_toolkits() + new_tools = self._list_tools(new_toolkits) + if new_tools != self.current_tools: + logger.info("Toolkit changes detected. Updating actor's catalog...") + + for tool in new_tools: + if tool not in self.current_tools: + logger.info(f"New tool added: {tool}") + + for tool in self.current_tools: + if tool not in new_tools: + logger.info(f"Toolkit removed: {tool}") + + self.actor.new_catalog() + for toolkit in new_toolkits: + self.actor.register_toolkit(toolkit) + + self.current_tools = new_tools + + logger.info("Actor's catalog has been updated.") + else: + pass + + except Exception: + logger.exception("Error while polling toolkits") + + await asyncio.sleep(interval) + + def _list_tools(self, toolkits: list[Toolkit]) -> list[FullyQualifiedName]: + tools_list = [] + for toolkit in toolkits: + for _, tools in toolkit.tools.items(): + if len(tools) != 0: + tools_list.extend([ + FullyQualifiedName(tool, toolkit.name, toolkit.version) for tool in tools + ]) + return tools_list diff --git a/arcade/tests/cli/test_watcher.py b/arcade/tests/cli/test_watcher.py new file mode 100644 index 00000000..74125c6c --- /dev/null +++ b/arcade/tests/cli/test_watcher.py @@ -0,0 +1,152 @@ +import asyncio +import threading +import unittest +from unittest.mock import MagicMock, patch + +from arcade.actor.fastapi.actor import FastAPIActor +from arcade.cli.watcher import FullyQualifiedName, ToolkitWatcher +from arcade.core.toolkit import Toolkit + + +class TestToolkitWatcher(unittest.IsolatedAsyncioTestCase): + async def test_start_detects_toolkit_changes(self): + # Setup initial toolkits + initial_toolkit = MagicMock(spec=Toolkit) + initial_toolkit.name = "Toolkit1" + initial_toolkit.version = "1.0.0" + initial_toolkit.tools = {"group1": ["ToolA"]} + initial = [initial_toolkit] + + # Mock actor + actor = MagicMock(spec=FastAPIActor) + actor.register_toolkit = MagicMock() + actor.new_catalog = MagicMock() + + # Create shutdown event + shutdown_event = threading.Event() + + # Instantiate ToolkitWatcher + watcher = ToolkitWatcher(initial=initial, actor=actor, shutdown_event=shutdown_event) + + # Mock Toolkit.find_all_arcade_toolkits to return updated toolkits on the second call + updated_toolkit = MagicMock(spec=Toolkit) + updated_toolkit.name = "Toolkit1" + updated_toolkit.version = "1.0.1" + updated_toolkit.tools = {"group1": ["ToolA", "ToolB"]} + + with patch( + "arcade.core.toolkit.Toolkit.find_all_arcade_toolkits", + side_effect=[initial, [updated_toolkit], []], + ): + # Run watcher.start() in background + async def run_watcher(): + await watcher.start(interval=0.5) + + watcher_task = asyncio.create_task(run_watcher()) + + # Allow some time for the watcher to process + await asyncio.sleep(0.7) + + # Trigger shutdown to stop the watcher + shutdown_event.set() + + # Await the watcher task to finish + await watcher_task + + # Assertions + actor.new_catalog.assert_called() + actor.register_toolkit.assert_called_with(updated_toolkit) + + # Check that current_tools has been updated + expected_tools = [ + FullyQualifiedName("ToolA", "Toolkit1", "1.0.1"), + FullyQualifiedName("ToolB", "Toolkit1", "1.0.1"), + ] + self.assertEqual(watcher.current_tools, expected_tools) + + async def test_start_no_changes(self): + # Setup initial toolkits + initial_toolkit = MagicMock(spec=Toolkit) + initial_toolkit.name = "Toolkit1" + initial_toolkit.version = "1.0.0" + initial_toolkit.tools = {"group1": ["ToolA"]} + initial = [initial_toolkit] + + # Mock actor + actor = MagicMock(spec=FastAPIActor) + actor.register_toolkit = MagicMock() + actor.new_catalog = MagicMock() + + # Create shutdown event + shutdown_event = threading.Event() + + # Instantiate ToolkitWatcher + watcher = ToolkitWatcher(initial=initial, actor=actor, shutdown_event=shutdown_event) + + # Mock Toolkit.find_all_arcade_toolkits to always return the initial toolkit + with patch("arcade.core.toolkit.Toolkit.find_all_arcade_toolkits", return_value=initial): + # Run watcher.start() in background + async def run_watcher(): + await watcher.start(interval=0.1) + + watcher_task = asyncio.create_task(run_watcher()) + + # Allow some time for the watcher to process + await asyncio.sleep(0.3) + + # Trigger shutdown to stop the watcher + shutdown_event.set() + + # Await the watcher task to finish + await watcher_task + + # Assert that actor.new_catalog was never called since there were no changes + actor.new_catalog.assert_not_called() + actor.register_toolkit.assert_not_called() + + async def test_start_handles_exception(self): + # Setup initial toolkits + initial_toolkit = MagicMock(spec=Toolkit) + initial_toolkit.name = "Toolkit1" + initial_toolkit.version = "1.0.0" + initial_toolkit.tools = {"group1": ["ToolA"]} + initial = [initial_toolkit] + + # Mock actor + actor = MagicMock(spec=FastAPIActor) + actor.register_toolkit = MagicMock() + actor.new_catalog = MagicMock() + + # Create shutdown event + shutdown_event = threading.Event() + + # Instantiate ToolkitWatcher + watcher = ToolkitWatcher(initial=initial, actor=actor, shutdown_event=shutdown_event) + + # Mock Toolkit.find_all_arcade_toolkits to raise an exception + with patch( + "arcade.core.toolkit.Toolkit.find_all_arcade_toolkits", + side_effect=Exception("Test Exception"), + ): + # Run watcher.start() in background + async def run_watcher(): + await watcher.start(interval=0.1) + + watcher_task = asyncio.create_task(run_watcher()) + + # Allow some time for the watcher to process + await asyncio.sleep(0.3) + + # Trigger shutdown to stop the watcher + shutdown_event.set() + + # Await the watcher task to finish + await watcher_task + + # Since the exception is caught and logged, we can check that it does not crash + # Assert that actor.new_catalog was never called + actor.new_catalog.assert_not_called() + + +if __name__ == "__main__": + unittest.main() From a6433f8aef85e6ba60fe0483659655f414738422 Mon Sep 17 00:00:00 2001 From: sdreyer Date: Fri, 18 Oct 2024 15:32:44 -0700 Subject: [PATCH 4/8] Linting --- arcade/arcade/core/catalog.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arcade/arcade/core/catalog.py b/arcade/arcade/core/catalog.py index a9451a4b..496e8284 100644 --- a/arcade/arcade/core/catalog.py +++ b/arcade/arcade/core/catalog.py @@ -384,7 +384,7 @@ def create_output_definition(func: Callable) -> ToolOutput: ) if hasattr(return_type, "__metadata__"): - description = return_type.__metadata__[0] if return_type.__metadata__ else None + description = return_type.__metadata__[0] if return_type.__metadata__ else None # type: ignore[assignment] return_type = return_type.__origin__ # Unwrap Optional types @@ -542,7 +542,7 @@ def get_wire_type_info(_type: type) -> WireTypeInfo: # Special case: Enum can be enumerated on the wire elif issubclass(type_to_check, Enum): is_enum = True - enum_values = [e.value for e in type_to_check] + enum_values = [e.value for e in type_to_check] # type: ignore[union-attr] return WireTypeInfo(wire_type, inner_wire_type, enum_values if is_enum else None) From 204d53744e7645767b6cec2bc8cea27a5c667ef3 Mon Sep 17 00:00:00 2001 From: sdreyer Date: Fri, 18 Oct 2024 16:06:44 -0700 Subject: [PATCH 5/8] actor tests --- arcade/tests/actor/core/test_base.py | 67 ++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 arcade/tests/actor/core/test_base.py diff --git a/arcade/tests/actor/core/test_base.py b/arcade/tests/actor/core/test_base.py new file mode 100644 index 00000000..fe1c7ca4 --- /dev/null +++ b/arcade/tests/actor/core/test_base.py @@ -0,0 +1,67 @@ +import asyncio +import os +from unittest import mock + +import pytest + +from arcade.actor.core.base import BaseActor +from arcade.core.schema import ToolStatusRequest, ToolStatusResponse + + +@pytest.mark.asyncio +@mock.patch.dict(os.environ, {"ARCADE_ACTOR_SECRET": "test-secret"}) +class TestBaseActor: + async def test_tool_status_different_uuid(self): + actor = BaseActor() + actor.uuid = "actor-uuid" + request = ToolStatusRequest(uuid="different-uuid") + + response = await actor.tool_status(request) + + assert isinstance(response, ToolStatusResponse) + assert response.uuid == actor.uuid + + async def test_tool_status_timeout(self): + actor = BaseActor() + actor.uuid = "actor-uuid" + request = ToolStatusRequest(uuid="actor-uuid") + + # Mock _wait_for_catalog_update to simulate timeout + original_wait_for_catalog_update = actor._wait_for_catalog_update + + async def mock_wait_for_catalog_update(uuid: str, timeout: float = 0.5): + await asyncio.sleep(timeout) + raise asyncio.TimeoutError() + + actor._wait_for_catalog_update = mock_wait_for_catalog_update + + response = await actor.tool_status(request) + + assert isinstance(response, ToolStatusResponse) + assert response.uuid == actor.uuid + + # Cleanup + actor._wait_for_catalog_update = original_wait_for_catalog_update + + async def test_wait_for_catalog_update_update_happens(self): + actor = BaseActor() + actor.uuid = "actor-uuid" + + wait_task = asyncio.create_task(actor._wait_for_catalog_update("actor-uuid", timeout=1.0)) + + # Simulate catalog update after 0.1 second + await asyncio.sleep(0.1) + actor.uuid = "new-actor-uuid" + + await wait_task + + assert True + + async def test_wait_for_catalog_update_timeout(self): + actor = BaseActor() + actor.uuid = "actor-uuid" + + wait_task = asyncio.create_task(actor._wait_for_catalog_update("actor-uuid", timeout=0.5)) + + with pytest.raises(asyncio.TimeoutError): + await wait_task From e3354fd3443212351e7a546fde306d3f05dcc02c Mon Sep 17 00:00:00 2001 From: sdreyer Date: Wed, 23 Oct 2024 11:02:06 -0700 Subject: [PATCH 6/8] Added schemas --- .../preview/tools_status_request.schema.jsonc | 17 +++++++++++++++++ schemas/tools_status_response.schema.jsonc | 17 +++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100644 schemas/preview/tools_status_request.schema.jsonc create mode 100644 schemas/tools_status_response.schema.jsonc diff --git a/schemas/preview/tools_status_request.schema.jsonc b/schemas/preview/tools_status_request.schema.jsonc new file mode 100644 index 00000000..dbe3b038 --- /dev/null +++ b/schemas/preview/tools_status_request.schema.jsonc @@ -0,0 +1,17 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "$schema": { + // Explicitly allow JSON-Schema to be referenced (needed due to additionalProperties: false) + "type": "string", + "format": "uri" + }, + "uuid": { + "type": "string", + "description": "The UUID to compare against the current toolkits UUID" + } + }, + "required": ["uuid"], + "additionalProperties": false +} diff --git a/schemas/tools_status_response.schema.jsonc b/schemas/tools_status_response.schema.jsonc new file mode 100644 index 00000000..56356152 --- /dev/null +++ b/schemas/tools_status_response.schema.jsonc @@ -0,0 +1,17 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "$schema": { + // Explicitly allow JSON-Schema to be referenced (needed due to additionalProperties: false) + "type": "string", + "format": "uri" + }, + "uuid": { + "type": "string", + "description": "The current UUID of the registered toolkits." + } + }, + "required": ["uuid"], + "additionalProperties": false +} From d16b29dec07befdc7ec288385e8bb17c4906976a Mon Sep 17 00:00:00 2001 From: sdreyer Date: Wed, 23 Oct 2024 11:03:21 -0700 Subject: [PATCH 7/8] Fixed file location --- schemas/{ => preview}/tools_status_response.schema.jsonc | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename schemas/{ => preview}/tools_status_response.schema.jsonc (100%) diff --git a/schemas/tools_status_response.schema.jsonc b/schemas/preview/tools_status_response.schema.jsonc similarity index 100% rename from schemas/tools_status_response.schema.jsonc rename to schemas/preview/tools_status_response.schema.jsonc From ecfc44dc9c325ababac96e185ba4ab81d5df3fb1 Mon Sep 17 00:00:00 2001 From: sdreyer Date: Wed, 23 Oct 2024 14:29:38 -0700 Subject: [PATCH 8/8] Update catalog fn name --- arcade/arcade/actor/core/base.py | 2 +- arcade/arcade/cli/watcher.py | 2 +- arcade/tests/cli/test_watcher.py | 16 ++++++++-------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/arcade/arcade/actor/core/base.py b/arcade/arcade/actor/core/base.py index ce11b400..5d2af198 100644 --- a/arcade/arcade/actor/core/base.py +++ b/arcade/arcade/actor/core/base.py @@ -87,7 +87,7 @@ def _set_secret(self, secret: str | None, disable_auth: bool) -> str: "No secret provided for actor. Set the ARCADE_ACTOR_SECRET environment variable." ) - def new_catalog(self) -> None: + def clear_catalog(self) -> None: self.catalog = ToolCatalog() self._update_catalog_uuid() diff --git a/arcade/arcade/cli/watcher.py b/arcade/arcade/cli/watcher.py index 61c701d2..6efa3ebd 100644 --- a/arcade/arcade/cli/watcher.py +++ b/arcade/arcade/cli/watcher.py @@ -35,7 +35,7 @@ async def start(self, interval: int = 1) -> None: if tool not in new_tools: logger.info(f"Toolkit removed: {tool}") - self.actor.new_catalog() + self.actor.clear_catalog() for toolkit in new_toolkits: self.actor.register_toolkit(toolkit) diff --git a/arcade/tests/cli/test_watcher.py b/arcade/tests/cli/test_watcher.py index 74125c6c..c81be6df 100644 --- a/arcade/tests/cli/test_watcher.py +++ b/arcade/tests/cli/test_watcher.py @@ -20,7 +20,7 @@ async def test_start_detects_toolkit_changes(self): # Mock actor actor = MagicMock(spec=FastAPIActor) actor.register_toolkit = MagicMock() - actor.new_catalog = MagicMock() + actor.clear_catalog = MagicMock() # Create shutdown event shutdown_event = threading.Event() @@ -54,7 +54,7 @@ async def run_watcher(): await watcher_task # Assertions - actor.new_catalog.assert_called() + actor.clear_catalog.assert_called() actor.register_toolkit.assert_called_with(updated_toolkit) # Check that current_tools has been updated @@ -75,7 +75,7 @@ async def test_start_no_changes(self): # Mock actor actor = MagicMock(spec=FastAPIActor) actor.register_toolkit = MagicMock() - actor.new_catalog = MagicMock() + actor.clear_catalog = MagicMock() # Create shutdown event shutdown_event = threading.Event() @@ -100,8 +100,8 @@ async def run_watcher(): # Await the watcher task to finish await watcher_task - # Assert that actor.new_catalog was never called since there were no changes - actor.new_catalog.assert_not_called() + # Assert that actor.clear_catalog was never called since there were no changes + actor.clear_catalog.assert_not_called() actor.register_toolkit.assert_not_called() async def test_start_handles_exception(self): @@ -115,7 +115,7 @@ async def test_start_handles_exception(self): # Mock actor actor = MagicMock(spec=FastAPIActor) actor.register_toolkit = MagicMock() - actor.new_catalog = MagicMock() + actor.clear_catalog = MagicMock() # Create shutdown event shutdown_event = threading.Event() @@ -144,8 +144,8 @@ async def run_watcher(): await watcher_task # Since the exception is caught and logged, we can check that it does not crash - # Assert that actor.new_catalog was never called - actor.new_catalog.assert_not_called() + # Assert that actor.clear_catalog was never called + actor.clear_catalog.assert_not_called() if __name__ == "__main__":