From d1a7400d00252d911776842bef29eee203dc32db Mon Sep 17 00:00:00 2001 From: xrsrke Date: Fri, 13 Oct 2023 14:07:12 +0700 Subject: [PATCH] refactor job callback's tests --- .../pipeline_parallel_2/job/test_callback.py | 58 ++++++++++++------- .../pipeline_parallel_2/job/test_creator.py | 10 ++-- 2 files changed, 41 insertions(+), 27 deletions(-) diff --git a/tests/nn/pipeline_parallel_2/job/test_callback.py b/tests/nn/pipeline_parallel_2/job/test_callback.py index 4428643..0f1e8f6 100644 --- a/tests/nn/pipeline_parallel_2/job/test_callback.py +++ b/tests/nn/pipeline_parallel_2/job/test_callback.py @@ -1,6 +1,20 @@ import pytest from pipegoose.nn.pipeline_parallel2._job.callback import Callback +from pipegoose.nn.pipeline_parallel2._job.job import Job + + +class DummyJob(Job): + def run_compute(self): + return self.function(self.input.data) + + +@pytest.fixture +def job(forward_package): + def function(*args, **kwargs): + pass + + return DummyJob(function, forward_package) @pytest.fixture(scope="function") @@ -17,21 +31,21 @@ class Callback3(Callback): return [Callback1(), Callback2, Callback3] -def test_a_callback_access_job_attributes(forward_job): +def test_a_callback_access_job_attributes(job): QUEUE = [] class AccessJobAttributesCallback(Callback): def after_compute(self): QUEUE.append(self.job.key) - forward_job.add_cb(AccessJobAttributesCallback) - forward_job.compute() + job.add_cb(AccessJobAttributesCallback) + job.compute() assert len(QUEUE) == 1 - assert QUEUE == [forward_job.key] + assert QUEUE == [job.key] -def test_run_callbacks_by_order(forward_job): +def test_run_callbacks_by_order(job): QUEUE = [] class Callback1(Callback): @@ -52,13 +66,13 @@ class Callback3(Callback): def after_compute(self): QUEUE.append(3) - forward_job.add_cbs([Callback3, Callback1, Callback2]) - forward_job.compute() + job.add_cbs([Callback3, Callback1, Callback2]) + job.compute() assert QUEUE == [1, 2, 3] -def test_create_and_run_a_callback(forward_job): +def test_create_and_run_a_callback(job): QUEUE = [] class AddToQueueCallback(Callback): @@ -69,31 +83,31 @@ def after_compute(self): assert isinstance(cb.order, int) - forward_job.add_cb(cb) - forward_job.compute() + job.add_cb(cb) + job.compute() assert QUEUE == [69] -def test_add_and_remove_a_callback(forward_job): +def test_add_and_remove_a_callback(job): class ToyCallback(Callback): pass - N_ORIG_CBS = len(forward_job.cbs) + N_ORIG_CBS = len(job.cbs) cb = ToyCallback() - forward_job.add_cb(cb) - assert len(forward_job.cbs) == 1 + N_ORIG_CBS + job.add_cb(cb) + assert len(job.cbs) == 1 + N_ORIG_CBS - forward_job.remove_cb(cb) - assert len(forward_job.cbs) == N_ORIG_CBS + job.remove_cb(cb) + assert len(job.cbs) == N_ORIG_CBS -def test_add_and_remove_a_list_of_callback(forward_job, cbs): - N_ORIG_CBS = len(forward_job.cbs) +def test_add_and_remove_a_list_of_callback(job, cbs): + N_ORIG_CBS = len(job.cbs) - forward_job.add_cbs(cbs) - assert len(forward_job.cbs) == 3 + N_ORIG_CBS + job.add_cbs(cbs) + assert len(job.cbs) == 3 + N_ORIG_CBS - forward_job.remove_cbs(cbs) - assert len(forward_job.cbs) == N_ORIG_CBS + job.remove_cbs(cbs) + assert len(job.cbs) == N_ORIG_CBS diff --git a/tests/nn/pipeline_parallel_2/job/test_creator.py b/tests/nn/pipeline_parallel_2/job/test_creator.py index 7da37a1..c169672 100644 --- a/tests/nn/pipeline_parallel_2/job/test_creator.py +++ b/tests/nn/pipeline_parallel_2/job/test_creator.py @@ -18,13 +18,13 @@ # @pytest.mark.parametrize("package", ["forward_package", "backward_package"]) -# def test_the_job_status_after_executing_a_job(request, package, parallel_context, pipeline_context): -# package = request.getfixturevalue(package) -# job = create_job(function, package, parallel_context, pipeline_context) +def test_backward_job(backward_package, parallel_context, pipeline_context): + # package = request.getfixturevalue(package) + job = create_job(function, backward_package, parallel_context, pipeline_context) -# job.compute() + job.compute() -# assert job.status == JobStatus.EXECUTED + assert job.status == JobStatus.EXECUTED def run_create_a_job_from_package(