Skip to content

Commit

Permalink
Worker checks Task should_run_task to decide if it should execute run…
Browse files Browse the repository at this point in the history
…_task or republish with delay (#60)

Worker now allows Tasks to override should_run_task. The default implementation returns True, in which case Worker executes run_task immediately after should_run_task. If a Task returns False, Worker does not execute run_task and instead republishes the message with a delay. This allows API users to cancel the execution of a task without forcing them to raise an exception.
  • Loading branch information
dickinsonm authored Apr 28, 2020
1 parent 8a3e2be commit 4bbcd55
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 6 deletions.
16 changes: 11 additions & 5 deletions kale/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,21 +164,24 @@ def handle_failure(cls, message, raised_exception, increment_failure_num=True):
PERMANENT_FAILURE_RETRIES_EXCEEDED, False)
return False

failure_count = message.task_failure_num
if increment_failure_num:
failure_count = failure_count + 1
cls.republish(message, failure_count)
return True

@classmethod
def republish(cls, message, failure_count):
    """Publish a fresh copy of ``message`` so the task is attempted again.

    The copy carries the original args, kwargs and app data, bumps the retry
    number by one, and is delayed by whatever ``_get_delay_sec_for_retry``
    returns for the message's current retry number.

    :param message: the message whose task should be re-queued; its
        ``task_id``, ``task_args``, ``task_kwargs``, ``task_app_data`` and
        ``task_retry_num`` attributes are read.
    :param failure_count: failure number to record on the republished
        message. The caller decides whether this attempt counts as a
        failure, so the value is forwarded untouched.
    :return: True after the message has been handed to the publisher.
    """
    payload = {
        'args': message.task_args,
        'kwargs': message.task_kwargs,
        'app_data': message.task_app_data}

    retry_count = message.task_retry_num + 1
    # The delay is looked up with the current (pre-increment) retry number.
    delay_sec = cls._get_delay_sec_for_retry(message.task_retry_num)
    pub = cls._get_publisher()
    pub.publish(
        cls, message.task_id, payload,
        current_retry_num=retry_count, current_failure_num=failure_count, delay_sec=delay_sec)
    return True

def run(self, *args, **kwargs):
"""Wrap the run_task method of tasks.
Expand Down Expand Up @@ -220,6 +223,9 @@ def run_task(self, *args, **kwargs):
"""Run the task, this must be implemented by subclasses."""
raise NotImplementedError()

def should_run_task(self, *args, **kwargs):
    """Decide whether run_task should be executed for these arguments.

    Subclasses may override this hook to decline execution by returning
    False. The base implementation ignores its arguments and always
    answers True.

    :param args: positional task arguments (unused here).
    :param kwargs: keyword task arguments (unused here).
    :return: True.
    """
    return True

def _check_blacklist(self, *args, **kwargs):
"""Raises an exception if a task should not run.
Expand Down
13 changes: 13 additions & 0 deletions kale/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ def run_task(self, *args, **kwargs):
raise exceptions.TaskException('Task failed.')


class ShouldNotRunTask(task.Task):
    """Test task whose should_run_task hook always declines to run."""

    @classmethod
    def _get_task_id(cls, *args, **kwargs):
        """Return the fixed task id used for this test task."""
        return "should_not_run_task"

    def should_run_task(self, *args, **kwargs):
        """Always answer False, whatever the arguments."""
        return False

    def run_task(self, *args, **kwargs):
        """Do nothing."""
        pass


class TimeoutTask(task.Task):

@classmethod
Expand Down
29 changes: 29 additions & 0 deletions kale/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,35 @@ def testRunBatchTaskException(self):
self.assertEqual(0, len(worker_inst._permanent_failures))
self.assertEqual(num_messages, len(worker_inst._failed_messages))

def testRunBatchTaskShouldNotRun(self):
    """Test batch with a task which should not run."""
    mock_consumer = self._create_patch('kale.consumer.Consumer')
    get_time = self._create_patch('time.time')
    mock_republish = self._create_patch(
        'kale.test_utils.ShouldNotRunTask.republish')

    worker_inst = worker.Worker()
    worker_inst._batch_queue = worker_inst._queue_selector.get_queue()
    mock_consumer.assert_called_once_with()

    worker_inst._batch_stop_time = 100
    # _batch_stop_time - (get_time + task.time_limit) > 0
    # (100 - (10 + 60)) > 0
    get_time.return_value = 10

    message = test_utils.new_mock_message(task_class=test_utils.ShouldNotRunTask)
    message_batch = [message]

    worker_inst._run_batch(message_batch)

    # Check the call count first so a missing republish fails with a clear
    # assertion instead of a TypeError when unpacking call_args below.
    self.assertEqual(1, mock_republish.call_count)
    message_republished, failure_count = mock_republish.call_args[0]
    self.assertEqual(message, message_republished)
    self.assertEqual(0, failure_count)
    # A task that declined to run is not recorded in any result bucket.
    self.assertEqual(0, len(worker_inst._incomplete_messages))
    self.assertEqual(0, len(worker_inst._successful_messages))
    self.assertEqual(0, len(worker_inst._permanent_failures))
    self.assertEqual(0, len(worker_inst._failed_messages))

def testRunBatchTaskExceptionPermanentFailure(self):
"""Test batch with a task exception."""
mock_consumer = self._create_patch('kale.consumer.Consumer')
Expand Down
2 changes: 1 addition & 1 deletion kale/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = '2.2.0'  # http://semver.org/
6 changes: 6 additions & 0 deletions kale/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,9 @@ def _run_single_message(self, message, time_remaining_sec):
task_inst = message.task_inst
try:
with timeout.time_limit(task_inst.time_limit):
if not self.should_run_task(message):
task_inst.__class__.republish(message, message.task_failure_num)
return
self.run_task(message)
except Exception as err:
# Re-publish failed tasks.
Expand Down Expand Up @@ -391,6 +394,9 @@ def remove_message_or_exit(self, message):
# Cleanup happened due to the signal handler - make sure we exit immediately.
sys.exit(0)

def should_run_task(self, message):
    """Ask the message's task whether it should be run.

    :param message: message carrying the task instance (``task_inst``) and
        the arguments (``task_args``, ``task_kwargs``) it would run with.
    :return: whatever the task's own ``should_run_task`` returns when called
        with the message's positional and keyword arguments.
    """
    return message.task_inst.should_run_task(*message.task_args, **message.task_kwargs)

def run_task(self, message):
"""Run the task contained in the message.
:param message: message.KaleMessage containing the task and arguments to run.
Expand Down

0 comments on commit 4bbcd55

Please sign in to comment.