Skip to content

Commit

Permalink
refactor job callback's tests
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Oct 13, 2023
1 parent 1b177cd commit d1a7400
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 27 deletions.
58 changes: 36 additions & 22 deletions tests/nn/pipeline_parallel_2/job/test_callback.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
import pytest

from pipegoose.nn.pipeline_parallel2._job.callback import Callback
from pipegoose.nn.pipeline_parallel2._job.job import Job


class DummyJob(Job):
def run_compute(self):
return self.function(self.input.data)


@pytest.fixture
def job(forward_package):
def function(*args, **kwargs):
pass

return DummyJob(function, forward_package)


@pytest.fixture(scope="function")
Expand All @@ -17,21 +31,21 @@ class Callback3(Callback):
return [Callback1(), Callback2, Callback3]


def test_a_callback_access_job_attributes(forward_job):
def test_a_callback_access_job_attributes(job):
QUEUE = []

class AccessJobAttributesCallback(Callback):
def after_compute(self):
QUEUE.append(self.job.key)

forward_job.add_cb(AccessJobAttributesCallback)
forward_job.compute()
job.add_cb(AccessJobAttributesCallback)
job.compute()

assert len(QUEUE) == 1
assert QUEUE == [forward_job.key]
assert QUEUE == [job.key]


def test_run_callbacks_by_order(forward_job):
def test_run_callbacks_by_order(job):
QUEUE = []

class Callback1(Callback):
Expand All @@ -52,13 +66,13 @@ class Callback3(Callback):
def after_compute(self):
QUEUE.append(3)

forward_job.add_cbs([Callback3, Callback1, Callback2])
forward_job.compute()
job.add_cbs([Callback3, Callback1, Callback2])
job.compute()

assert QUEUE == [1, 2, 3]


def test_create_and_run_a_callback(forward_job):
def test_create_and_run_a_callback(job):
QUEUE = []

class AddToQueueCallback(Callback):
Expand All @@ -69,31 +83,31 @@ def after_compute(self):

assert isinstance(cb.order, int)

forward_job.add_cb(cb)
forward_job.compute()
job.add_cb(cb)
job.compute()

assert QUEUE == [69]


def test_add_and_remove_a_callback(forward_job):
def test_add_and_remove_a_callback(job):
class ToyCallback(Callback):
pass

N_ORIG_CBS = len(forward_job.cbs)
N_ORIG_CBS = len(job.cbs)
cb = ToyCallback()

forward_job.add_cb(cb)
assert len(forward_job.cbs) == 1 + N_ORIG_CBS
job.add_cb(cb)
assert len(job.cbs) == 1 + N_ORIG_CBS

forward_job.remove_cb(cb)
assert len(forward_job.cbs) == N_ORIG_CBS
job.remove_cb(cb)
assert len(job.cbs) == N_ORIG_CBS


def test_add_and_remove_a_list_of_callback(forward_job, cbs):
N_ORIG_CBS = len(forward_job.cbs)
def test_add_and_remove_a_list_of_callback(job, cbs):
N_ORIG_CBS = len(job.cbs)

forward_job.add_cbs(cbs)
assert len(forward_job.cbs) == 3 + N_ORIG_CBS
job.add_cbs(cbs)
assert len(job.cbs) == 3 + N_ORIG_CBS

forward_job.remove_cbs(cbs)
assert len(forward_job.cbs) == N_ORIG_CBS
job.remove_cbs(cbs)
assert len(job.cbs) == N_ORIG_CBS
10 changes: 5 additions & 5 deletions tests/nn/pipeline_parallel_2/job/test_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@


# @pytest.mark.parametrize("package", ["forward_package", "backward_package"])
# def test_the_job_status_after_executing_a_job(request, package, parallel_context, pipeline_context):
# package = request.getfixturevalue(package)
# job = create_job(function, package, parallel_context, pipeline_context)
def test_backward_job(backward_package, parallel_context, pipeline_context):
# package = request.getfixturevalue(package)
job = create_job(function, backward_package, parallel_context, pipeline_context)

# job.compute()
job.compute()

# assert job.status == JobStatus.EXECUTED
assert job.status == JobStatus.EXECUTED


def run_create_a_job_from_package(
Expand Down

0 comments on commit d1a7400

Please sign in to comment.