From fcd1ee06528033a1df65f625b93a61577dcfa2c5 Mon Sep 17 00:00:00 2001 From: RF-Tar-Railt Date: Wed, 16 Oct 2024 19:11:42 +0800 Subject: [PATCH] :sparkles: all tasks put into refs --- nonebot/adapters/satori/adapter.py | 10 ++++++++-- tests/test_connection.py | 4 +++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/nonebot/adapters/satori/adapter.py b/nonebot/adapters/satori/adapter.py index a6be1b4..d2beb10 100644 --- a/nonebot/adapters/satori/adapter.py +++ b/nonebot/adapters/satori/adapter.py @@ -48,7 +48,7 @@ def __init__(self, driver: Driver, **kwargs: Any): super().__init__(driver, **kwargs) # 读取适配器所需的配置项 self.satori_config: Config = get_plugin_config(Config) - self.tasks: list[asyncio.Task] = [] # 存储 ws 任务 + self.tasks: set[asyncio.Task] = set() # 存储 ws 任务等 self.sequences: dict[str, int] = {} # 存储 连接序列号 self._bots: defaultdict[str, set[str]] = defaultdict(set) # 存储 identity 和 bot_id 的映射 self.setup() @@ -80,7 +80,9 @@ def setup(self) -> None: async def startup(self) -> None: """定义启动时的操作,例如和平台建立连接""" for client in self.satori_config.satori_clients: - self.tasks.append(asyncio.create_task(self.ws(client))) + t = asyncio.create_task(self.ws(client)) + self.tasks.add(t) + t.add_done_callback(self.tasks.discard) async def shutdown(self) -> None: for task in self.tasks: @@ -184,6 +186,8 @@ async def ws(self, info: ClientInfo) -> None: await asyncio.sleep(3) continue heartbeat_task = asyncio.create_task(self._heartbeat(info, ws)) + self.tasks.add(heartbeat_task) + heartbeat_task.add_done_callback(self.tasks.discard) await self._loop(info, ws) except WebSocketClosed as e: log( @@ -266,6 +270,8 @@ async def _loop(self, info: ClientInfo, ws: WebSocket): if isinstance(event, (MessageEvent, InteractionEvent)): event = event.convert() _t = asyncio.create_task(bot.handle_event(event)) + self.tasks.add(_t) + _t.add_done_callback(self.tasks.discard) elif isinstance(payload, PongPayload): log("TRACE", "Pong") continue diff --git a/tests/test_connection.py b/tests/test_connection.py index 9167f15..05d010e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -39,7 +39,9 @@ def _ping(json: dict) -> dict: return {"op": 2} for client in adapter.satori_config.satori_clients: - adapter.tasks.append(asyncio.create_task(adapter.ws(client))) + task = asyncio.create_task(adapter.ws(client)) + adapter.tasks.add(task) + task.add_done_callback(adapter.tasks.discard) await asyncio.sleep(5) bots = nonebot.get_bots()