diff --git a/broadcast_service/_core.py b/broadcast_service/_core.py index bf6749e..9b88c48 100644 --- a/broadcast_service/_core.py +++ b/broadcast_service/_core.py @@ -39,9 +39,11 @@ def _invoke_callback( ) -> Any: if enable_async: future_result = thread_pool.submit(callback, *args, **kwargs) - if future_result.result() is not None: - logger.debug(f"[broadcast-service invoke_callback result] {future_result.result()}") - return future_result.result() + def handle_future_result(future): + result = future.result() + if result is not None: + logger.debug(f"[broadcast-service invoke_callback result] {result}") + future_result.add_done_callback(handle_future_result) else: return callback(*args, **kwargs) @@ -161,7 +163,8 @@ def _invoke_broadcast_topic(self, topic_name: str, *args, **kwargs): def _final_invoke_listen_callback(self, callback: Callable, *args, **kwargs) -> Any: self.logger.debug(f"[broadcast-service] {callback.__name__} is called") - return _invoke_callback(callback, self.thread_pool, self.enable_async, *args, **kwargs) + _invoke_callback(callback, self.thread_pool, self.enable_async, *args, **kwargs) + return None def stop_listen(self, topic_name: str, callback: Callable): if topic_name not in self.pubsub_channels.keys(): @@ -375,12 +378,11 @@ def _invoke_broadcast_topic(self, topic_name: str, *args, **kwargs): f"[broadcast-service] start_publisher_callback_or_not: {self.cur_publisher_dispatch_config.start_publisher_callback_or_not}") if self.enable_config and self.cur_publisher_dispatch_config.start_publisher_callback_or_not: self._invoke_finish_callback() + return None def _final_invoke_listen_callback(self, callback: Callable, *args, **kwargs): - result = super()._final_invoke_listen_callback(callback, *args, **kwargs) - - if result: - self.cur_publisher_dispatch_config.append_sub_callback_results(result) + super()._final_invoke_listen_callback(callback, *args, **kwargs) + return None broadcast_service = BroadcastService() diff --git a/tests/test_invoke_callback.py b/tests/test_invoke_callback.py new file mode 100644 index 0000000..d46f4bf --- /dev/null +++ b/tests/test_invoke_callback.py @@ -0,0 +1,33 @@ +import unittest +from concurrent.futures import ThreadPoolExecutor + +from broadcast_service import BroadcastService + + +class TestInvokeCallback(unittest.TestCase): + def setUp(self): + self.broadcast_service = BroadcastService() + self.thread_pool = ThreadPoolExecutor(max_workers=5) + self.callback1 = lambda x: x * 2 + self.callback2 = lambda x: x + 2 + + def test_invoke_callback_async(self): + result = self.broadcast_service._invoke_callback(self.callback1, self.thread_pool, True, 5) + self.assertEqual(result, 10) + + def test_invoke_callback_sync(self): + result = self.broadcast_service._invoke_callback(self.callback1, self.thread_pool, False, 5) + self.assertEqual(result, 10) + + def test_multiple_callbacks(self): + self.broadcast_service.listen('Test', self.callback1) + self.broadcast_service.listen('Test', self.callback2) + self.broadcast_service.broadcast('Test', 5) + self.assertEqual(self.broadcast_service.pubsub_channels['Test'][0](5), 10) + self.assertEqual(self.broadcast_service.pubsub_channels['Test'][1](5), 7) + + def tearDown(self): + self.thread_pool.shutdown() + +if __name__ == '__main__': + unittest.main()