diff --git a/src/pycram/failure_handling.py b/src/pycram/failure_handling.py index 8fb266282..3e409eb6b 100644 --- a/src/pycram/failure_handling.py +++ b/src/pycram/failure_handling.py @@ -1,8 +1,12 @@ +from .datastructures.enums import State from .designator import DesignatorDescription from .plan_failures import PlanFailure +from threading import Lock +from typing_extensions import Union, Tuple, Any, List +from .language import Language, Monitor -class FailureHandling: +class FailureHandling(Language): """ Base class for failure handling mechanisms in automated systems or workflows. @@ -11,11 +15,12 @@ class FailureHandling: to be extended by subclasses that implement specific failure handling behaviors. """ - def __init__(self, designator_description: DesignatorDescription): + def __init__(self, designator_description: Union[DesignatorDescription, Monitor]): """ Initializes a new instance of the FailureHandling class. - :param designator_description: The description or context of the task or process for which the failure handling is being set up. + :param Union[DesignatorDescription, Monitor] designator_description: The description or context of the task + or process for which the failure handling is being set up. """ self.designator_description = designator_description @@ -37,15 +42,10 @@ class Retry(FailureHandling): This class represents a specific failure handling strategy where the system attempts to retry a failed action a certain number of times before giving up. - - Attributes: - max_tries (int): The maximum number of attempts to retry the action. - - Inherits: - All attributes and methods from the FailureHandling class. - - Overrides: - perform(): Implements the retry logic. + """ + max_tries: int + """ + The maximum number of attempts to retry the action. """ def __init__(self, designator_description: DesignatorDescription, max_tries: int = 3): @@ -58,7 +58,7 @@ def __init__(self, designator_description: DesignatorDescription, max_tries: int super().__init__(designator_description) self.max_tries = max_tries - def perform(self): + def perform(self) -> Tuple[State, List[Any]]: """ Implementation of the retry mechanism. @@ -79,5 +79,93 @@ def perform(self): raise e +class RetryMonitor(FailureHandling): + """ + A subclass of FailureHandling that implements a retry mechanism that works with a Monitor. + This class represents a specific failure handling strategy that allows us to retry a demo that is + being monitored, in case that monitoring condition is triggered. + """ + max_tries: int + """ + The maximum number of attempts to retry the action. + """ + recovery: dict + """ + A dictionary that maps exception types to recovery actions + """ + def __init__(self, designator_description: Monitor, max_tries: int = 3, recovery: dict = None): + """ + Initializes a new instance of the RetryMonitor class. + :param Monitor designator_description: The Monitor instance to be used. + :param int max_tries: The maximum number of attempts to retry. Defaults to 3. + :param dict recovery: A dictionary that maps exception types to recovery actions. Defaults to None. + """ + super().__init__(designator_description) + self.max_tries = max_tries + self.lock = Lock() + if recovery is None: + self.recovery = {} + else: + if not isinstance(recovery, dict): + raise ValueError( + "Recovery must be a dictionary with exception types as keys and Language instances as values.") + for key, value in recovery.items(): + if not issubclass(key, BaseException): + raise TypeError("Keys in the recovery dictionary must be exception types.") + if not isinstance(value, Language): + raise TypeError("Values in the recovery dictionary must be instances of the Language class.") + self.recovery = recovery + + def perform(self) -> Tuple[State, List[Any]]: + """ + This method attempts to perform the Monitor + plan specified in the designator_description. If the action + fails, it is retried up to max_tries times. If all attempts fail, the last exception is raised. In every + loop, we need to clear the kill_event, and set all relevant 'interrupted' variables to False, to make sure + the Monitor and plan are executed properly again. + + :raises PlanFailure: If all retry attempts fail. + + :return: The state of the execution performed, as well as a flattened list of the + results, in the correct order + """ + + def reset_interrupted(child): + child.interrupted = False + try: + for sub_child in child.children: + reset_interrupted(sub_child) + except AttributeError: + pass + + def flatten(result): + flattened_list = [] + if result: + for item in result: + if isinstance(item, list): + flattened_list.extend(item) + else: + flattened_list.append(item) + return flattened_list + return None + + status, res = None, None + with self.lock: + tries = 0 + while True: + self.designator_description.kill_event.clear() + self.designator_description.interrupted = False + for child in self.designator_description.children: + reset_interrupted(child) + try: + status, res = self.designator_description.perform() + break + except PlanFailure as e: + tries += 1 + if tries >= self.max_tries: + raise e + exception_type = type(e) + if exception_type in self.recovery: + self.recovery[exception_type].perform() + return status, flatten(res) diff --git a/src/pycram/language.py b/src/pycram/language.py index 81612a78b..f63b56a34 100644 --- a/src/pycram/language.py +++ b/src/pycram/language.py @@ -1,8 +1,9 @@ # used for delayed evaluation of typing until python 3.11 becomes mainstream from __future__ import annotations -import time -from typing_extensions import Iterable, Optional, Callable, Dict, Any, List, Union +from queue import Queue +import rospy +from typing_extensions import Iterable, Optional, Callable, Dict, Any, List, Union, Tuple from anytree import NodeMixin, Node, PreOrderIter from pycram.datastructures.enums import State @@ -260,6 +261,7 @@ def __init__(self, condition: Union[Callable, Fluent] = None): """ super().__init__(None, None) self.kill_event = threading.Event() + self.exception_queue = Queue() if callable(condition): self.condition = Fluent(condition) elif isinstance(condition, Fluent): @@ -267,27 +269,43 @@ def __init__(self, condition: Union[Callable, Fluent] = None): else: raise AttributeError("The condition of a Monitor has to be a Callable or a Fluent") - def perform(self): + def perform(self) -> Tuple[State, Any]: """ Behavior of the Monitor, starts a new Thread which checks the condition and then performs the attached language expression - :return: The result of the attached language expression + :return: The state of the attached language expression, as well as a list of the results of the children """ def check_condition(): - while not self.condition.get_value() and not self.kill_event.is_set(): - time.sleep(0.1) - if self.kill_event.is_set(): - return - for child in self.children: - child.interrupt() + while not self.kill_event.is_set(): + try: + cond = self.condition.get_value() + if cond: + for child in self.children: + try: + child.interrupt() + except NotImplementedError: + pass + if isinstance(cond, type) and issubclass(cond, Exception): + self.exception_queue.put(cond) + else: + self.exception_queue.put(PlanFailure("Condition met in Monitor")) + return + except Exception as e: + self.exception_queue.put(e) + return + rospy.sleep(0.1) t = threading.Thread(target=check_condition) t.start() - res = self.children[0].perform() - self.kill_event.set() - t.join() - return res + try: + state, result = self.children[0].perform() + if not self.exception_queue.empty(): + raise self.exception_queue.get() + finally: + self.kill_event.set() + t.join() + return state, result def interrupt(self) -> None: """ @@ -303,28 +321,35 @@ class Sequential(Language): Instead, the exception is saved to a list of all exceptions thrown during execution and returned. Behaviour: - Return the state :py:attr:`~State.SUCCEEDED` *iff* all children are executed without exception. - In any other case the State :py:attr:`~State.FAILED` will be returned. + Returns a tuple containing the final state of execution (SUCCEEDED, FAILED) and a list of results from each + child's perform() method. The state is :py:attr:`~State.SUCCEEDED` *iff* all children are executed without + exception. In any other case the State :py:attr:`~State.FAILED` will be returned. """ - def perform(self) -> State: + def perform(self) -> Tuple[State, List[Any]]: """ Behaviour of Sequential, calls perform() on each child sequentially - :return: The state according to the behaviour described in :func:`Sequential` + :return: The state and list of results according to the behaviour described in :func:`Sequential` """ + children_return_values = [None] * len(self.children) try: - for child in self.children: + for index, child in enumerate(self.children): if self.interrupted: if threading.get_ident() in self.block_list: self.block_list.remove(threading.get_ident()) - return + return State.FAILED, children_return_values self.root.executing_thread[child] = threading.get_ident() - child.resolve().perform() + ret_val = child.resolve().perform() + if isinstance(ret_val, tuple): + child_state, child_result = ret_val + children_return_values[index] = child_result + else: + children_return_values[index] = ret_val except PlanFailure as e: self.root.exceptions[self] = e - return State.FAILED - return State.SUCCEEDED + return State.FAILED, children_return_values + return State.SUCCEEDED, children_return_values def interrupt(self) -> None: """ @@ -343,33 +368,40 @@ class TryInOrder(Language): Instead, the exception is saved to a list of all exceptions thrown during execution and returned. Behaviour: - Returns the State :py:attr:`~State.SUCCEEDED` if one or more children are executed without + Returns a tuple containing the final state of execution (SUCCEEDED, FAILED) and a list of results from each + child's perform() method. The state is :py:attr:`~State.SUCCEEDED` if one or more children are executed without exception. In the case that all children could not be executed the State :py:attr:`~State.FAILED` will be returned. """ - def perform(self) -> State: + def perform(self) -> Tuple[State, List[Any]]: """ Behaviour of TryInOrder, calls perform() on each child sequentially and catches raised exceptions. - :return: The state according to the behaviour described in :func:`TryInOrder` + :return: The state and list of results according to the behaviour described in :func:`TryInOrder` """ failure_list = [] - for child in self.children: + children_return_values = [None] * len(self.children) + for index, child in enumerate(self.children): if self.interrupted: if threading.get_ident() in self.block_list: self.block_list.remove(threading.get_ident()) - return + return State.INTERRUPTED, children_return_values try: - child.resolve().perform() + ret_val = child.resolve().perform() + if isinstance(ret_val, tuple): + child_state, child_result = ret_val + children_return_values[index] = child_result + else: + children_return_values[index] = ret_val except PlanFailure as e: failure_list.append(e) if len(failure_list) > 0: self.root.exceptions[self] = failure_list if len(failure_list) == len(self.children): self.root.exceptions[self] = failure_list - return State.FAILED + return State.FAILED, children_return_values else: - return State.SUCCEEDED + return State.SUCCEEDED, children_return_values def interrupt(self) -> None: """ @@ -388,19 +420,27 @@ class Parallel(Language): exceptions during execution will be caught, saved to a list and returned upon end. Behaviour: - Returns the State :py:attr:`~State.SUCCEEDED` *iff* all children could be executed without an exception. In any - other case the State :py:attr:`~State.FAILED` will be returned. + Returns a tuple containing the final state of execution (SUCCEEDED, FAILED) and a list of results from + each child's perform() method. The state is :py:attr:`~State.SUCCEEDED` *iff* all children could be executed without + an exception. In any other case the State :py:attr:`~State.FAILED` will be returned. + """ - def perform(self) -> State: + def perform(self) -> Tuple[State, List[Any]]: """ Behaviour of Parallel, creates a new thread for each child and calls perform() of the child in the respective thread. - :return: The state according to the behaviour described in :func:`Parallel` + :return: The state and list of results according to the behaviour described in :func:`Parallel` + """ + results = [None] * len(self.children) + self.threads: List[threading.Thread] = [] + state = State.SUCCEEDED + results_lock = threading.Lock() - def lang_call(child_node): + def lang_call(child_node, index): + nonlocal state if ("DesignatorDescription" in [cls.__name__ for cls in child_node.__class__.__mro__] and self.__class__.__name__ not in self.do_not_use_giskard): if self not in giskard.par_threads.keys(): @@ -409,26 +449,39 @@ def lang_call(child_node): giskard.par_threads[self].append(threading.get_ident()) try: self.root.executing_thread[child] = threading.get_ident() - child_node.resolve().perform() + result = child_node.resolve().perform() + if isinstance(result, tuple): + child_state, child_result = result + with results_lock: + results[index] = child_result + else: + with results_lock: + results[index] = result except PlanFailure as e: + nonlocal state + with results_lock: + state = State.FAILED if self in self.root.exceptions.keys(): self.root.exceptions[self].append(e) else: self.root.exceptions[self] = [e] - for child in self.children: + for index, child in enumerate(self.children): if self.interrupted: + state = State.FAILED break - t = threading.Thread(target=lambda: lang_call(child)) + t = threading.Thread(target=lambda: lang_call(child, index)) t.start() self.threads.append(t) for thread in self.threads: thread.join() - if thread.ident in self.block_list: - self.block_list.remove(thread.ident) + with results_lock: + for thread in self.threads: + if thread.ident in self.block_list: + self.block_list.remove(thread.ident) if self in self.root.exceptions.keys() and len(self.root.exceptions[self]) != 0: - return State.FAILED - return State.SUCCEEDED + state = State.FAILED + return state, results def interrupt(self) -> None: """ @@ -448,20 +501,24 @@ class TryAll(Language): exceptions during execution will be caught, saved to a list and returned upon end. Behaviour: - Returns the State :py:attr:`~State.SUCCEEDED` if one or more children could be executed without raising an - exception. If all children fail the State :py:attr:`~State.FAILED` will be returned. + Returns a tuple containing the final state of execution (SUCCEEDED, FAILED) and a list of results from each + child's perform() method. The state is :py:attr:`~State.SUCCEEDED` if one or more children could be executed + without raising an exception. If all children fail the State :py:attr:`~State.FAILED` will be returned. """ - def perform(self) -> State: + def perform(self) -> Tuple[State, List[Any]]: """ Behaviour of TryAll, creates a new thread for each child and executes all children in their respective threads. - :return: The state according to the behaviour described in :func:`TryAll` + :return: The state and list of results according to the behaviour described in :func:`TryAll` """ + results = [None] * len(self.children) + results_lock = threading.Lock() + state = State.SUCCEEDED self.threads: List[threading.Thread] = [] failure_list = [] - def lang_call(child_node): + def lang_call(child_node, index): if ("DesignatorDescription" in [cls.__name__ for cls in child_node.__class__.__mro__] and self.__class__.__name__ not in self.do_not_use_giskard): if self not in giskard.par_threads.keys(): @@ -469,27 +526,37 @@ def lang_call(child_node): else: giskard.par_threads[self].append(threading.get_ident()) try: - child_node.resolve().perform() + result = child_node.resolve().perform() + if isinstance(result, tuple): + child_state, child_result = result + with results_lock: + results[index] = child_result + else: + with results_lock: + results[index] = result except PlanFailure as e: failure_list.append(e) if self in self.root.exceptions.keys(): self.root.exceptions[self].append(e) else: self.root.exceptions[self] = [e] - - for child in self.children: - t = threading.Thread(target=lambda: lang_call(child)) - t.start() + for index, child in enumerate(self.children): + if self.interrupted: + state = State.FAILED + break + t = threading.Thread(target=lambda: lang_call(child, index)) self.threads.append(t) + t.start() for thread in self.threads: thread.join() - if thread.ident in self.block_list: - self.block_list.remove(thread.ident) + with results_lock: + for thread in self.threads: + if thread.ident in self.block_list: + self.block_list.remove(thread.ident) if len(self.children) == len(failure_list): self.root.exceptions[self] = failure_list - return State.FAILED - else: - return State.SUCCEEDED + state = State.FAILED + return state, results def interrupt(self) -> None: """ @@ -529,9 +596,16 @@ def execute(self) -> Any: """ Execute the code with its arguments - :returns: Anything that the function associated with this object will return. + :returns: State.SUCCEEDED, and anything that the function associated with this object will return. """ - return self.function(**self.kwargs) + child_state = State.SUCCEEDED + ret_val = self.function(**self.kwargs) + if isinstance(ret_val, tuple): + child_state, child_result = ret_val + else: + child_result = ret_val + + return child_state, child_result def interrupt(self) -> None: raise NotImplementedError diff --git a/test/test_language.py b/test/test_language.py index bb3fec509..bfa41d647 100644 --- a/test/test_language.py +++ b/test/test_language.py @@ -4,8 +4,9 @@ from pycram.designators.action_designator import * from pycram.designators.object_designator import BelieveObject from pycram.datastructures.enums import ObjectType, State +from pycram.failure_handling import RetryMonitor from pycram.fluent import Fluent -from pycram.plan_failures import PlanFailure +from pycram.plan_failures import PlanFailure, NotALanguageExpression from pycram.datastructures.pose import Pose from pycram.language import Sequential, Language, Parallel, TryAll, TryInOrder, Monitor, Code from pycram.process_module import simulated_robot @@ -115,6 +116,80 @@ def monitor_func(): self.assertRaises(AttributeError, lambda: Monitor(monitor_func) >> Monitor(monitor_func)) + def test_retry_monitor_construction(self): + act = ParkArmsAction([Arms.BOTH]) + act2 = MoveTorsoAction([0.3]) + + def monitor_func(): + time.sleep(1) + return True + + def recovery1(): + return + + recover1 = Code(lambda: recovery1()) + recovery = {NotALanguageExpression: recover1} + + subplan = act + act2 >> Monitor(monitor_func) + plan = RetryMonitor(subplan, max_tries=6, recovery=recovery) + self.assertEqual(len(plan.recovery), 1) + self.assertIsInstance(plan.designator_description, Monitor) + + def test_retry_monitor_tries(self): + act = ParkArmsAction([Arms.BOTH]) + act2 = MoveTorsoAction([0.3]) + tries_counter = 0 + + def monitor_func(): + nonlocal tries_counter + tries_counter += 1 + return True + + subplan = act + act2 >> Monitor(monitor_func) + plan = RetryMonitor(subplan, max_tries=6) + try: + plan.perform() + except PlanFailure as e: + pass + self.assertEqual(tries_counter, 6) + + def test_retry_monitor_recovery(self): + recovery1_counter = 0 + recovery2_counter = 0 + + def monitor_func(): + if not hasattr(monitor_func, 'tries_counter'): + monitor_func.tries_counter = 0 + if monitor_func.tries_counter % 2: + monitor_func.tries_counter += 1 + return NotALanguageExpression + monitor_func.tries_counter += 1 + return PlanFailure + + def recovery1(): + nonlocal recovery1_counter + recovery1_counter += 1 + + def recovery2(): + nonlocal recovery2_counter + recovery2_counter += 1 + + recover1 = Code(lambda: recovery1()) + recover2 = Code(lambda: recovery2()) + recovery = {NotALanguageExpression: recover1, + PlanFailure: recover2} + + act = ParkArmsAction([Arms.BOTH]) + act2 = MoveTorsoAction([0.3]) + subplan = act + act2 >> Monitor(monitor_func) + plan = RetryMonitor(subplan, max_tries=6, recovery=recovery) + try: + plan.perform() + except PlanFailure as e: + pass + self.assertEqual(recovery1_counter, 2) + self.assertEqual(recovery2_counter, 3) + def test_repeat_construction(self): act = ParkArmsAction([Arms.BOTH]) act2 = MoveTorsoAction([0.3]) @@ -196,7 +271,7 @@ def raise_except(): plan = act + code with simulated_robot: - state = plan.perform() + state, _ = plan.perform() self.assertIsInstance(plan.exceptions[plan], PlanFailure) self.assertEqual(len(plan.exceptions.keys()), 1) self.assertEqual(state, State.FAILED) @@ -209,7 +284,7 @@ def raise_except(): plan = act - code with simulated_robot: - state = plan.perform() + state, _ = plan.perform() self.assertIsInstance(plan.exceptions[plan], list) self.assertIsInstance(plan.exceptions[plan][0], PlanFailure) self.assertEqual(len(plan.exceptions.keys()), 1) @@ -223,7 +298,7 @@ def raise_except(): plan = act | code with simulated_robot: - state = plan.perform() + state, _ = plan.perform() self.assertIsInstance(plan.exceptions[plan], list) self.assertIsInstance(plan.exceptions[plan][0], PlanFailure) self.assertEqual(len(plan.exceptions.keys()), 1) @@ -237,7 +312,7 @@ def raise_except(): plan = act ^ code with simulated_robot: - state = plan.perform() + state, _ = plan.perform() self.assertIsInstance(plan.exceptions[plan], list) self.assertIsInstance(plan.exceptions[plan][0], PlanFailure) self.assertEqual(len(plan.exceptions.keys()), 1)