diff --git a/tests/unit/test_environment.py b/tests/unit/test_environment.py index 3bdd7b37b..83feb08ee 100644 --- a/tests/unit/test_environment.py +++ b/tests/unit/test_environment.py @@ -269,7 +269,9 @@ def test_executed_jp_action(self): with self.subTest(task_cls=task_cls): task = self.get_task( task_cls, JointPosition(True)) - demos = task.get_demos(5, live_demos=True) + num_episodes = 20 + demos = task.get_demos(num_episodes, live_demos=True) + total_reward = 0.0 # Check if executed joint position action is stored for demo in demos: jp_action = [] @@ -287,6 +289,9 @@ def test_executed_jp_action(self): obs, reward, term = task.step(action) if term: break - self.assertEqual(reward, 1.0) + total_reward += reward + + success_rate = total_reward / num_episodes + self.assertTrue(success_rate >= 0.9) self.env.shutdown() \ No newline at end of file