Refactor early-exit in data_collector into the WorkerPool. #209

Open · wants to merge 6 commits into main
15 changes: 14 additions & 1 deletion compiler_opt/distributed/local/local_worker_manager.py
@@ -37,6 +37,7 @@
from absl import logging
# pylint: disable=unused-import
from compiler_opt.distributed import worker
from compiler_opt.distributed import buffered_scheduler

from contextlib import AbstractContextManager
from multiprocessing import connection
@@ -238,6 +239,18 @@ def __dir__(self):
return _Stub()


class LocalWorkerPool(worker.FixedWorkerPool):
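"""A FixedWorkerPool that schedules work via the buffered scheduler."""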

def __init__(self, workers: List[Any], worker_concurrency: int):
super().__init__(workers=workers, worker_concurrency=worker_concurrency)

def schedule(self, work: List[Any]) -> List[worker.WorkerFuture]:
return buffered_scheduler.schedule(
work,
workers=self.get_currently_active(),
buffer=self.get_worker_concurrency())


class LocalWorkerPoolManager(AbstractContextManager):
"""A pool of workers hosted on the local machines, each in its own process."""

@@ -251,7 +264,7 @@ def __init__(self, worker_class: 'type[worker.Worker]', count: Optional[int],
]

def __enter__(self) -> worker.FixedWorkerPool:
return worker.FixedWorkerPool(workers=self._stubs, worker_concurrency=10)
return LocalWorkerPool(workers=self._stubs, worker_concurrency=10)

def __exit__(self, *args):
# first, trigger killing the worker process and exiting of the msg pump,
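For context, a minimal sketch of how the pool returned by LocalWorkerPoolManager might be exercised now that schedule() lives on the pool itself. MyWorker, its compile_module method, and the shape of each work item (a callable applied to a worker stub, following buffered_scheduler's contract) are illustrative assumptions, not part of this change:

from compiler_opt.distributed import worker
from compiler_opt.distributed.local import local_worker_manager


class MyWorker(worker.Worker):
  """Hypothetical worker; any worker.Worker implementation would do."""

  def compile_module(self, module_name: str) -> str:
    return f'compiled {module_name}'


with local_worker_manager.LocalWorkerPoolManager(MyWorker, count=4) as pool:
  # Each work item is a callable taking a worker stub; the buffered scheduler
  # fans them out, keeping a bounded number of items in flight per worker.
  jobs = [
      lambda w, name=name: w.compile_module(name)
      for name in ('a.bc', 'b.bc', 'c.bc')
  ]
  futures = pool.schedule(jobs)
  worker.wait_for(futures)
  results = [f.result() for f in futures]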
86 changes: 58 additions & 28 deletions compiler_opt/distributed/worker.py
@@ -15,6 +15,7 @@
"""Common abstraction for a worker contract."""

import abc
import concurrent.futures
import sys
from typing import Any, List, Iterable, Optional, Protocol, TypeVar

@@ -32,34 +33,6 @@ def is_priority_method(cls, method_name: str) -> bool:
T = TypeVar('T')


class WorkerPool(metaclass=abc.ABCMeta):
"""Abstraction of a pool of workers that may be refreshed."""

# Issue #155 would strongly-type the return type.
@abc.abstractmethod
def get_currently_active(self) -> List[Any]:
raise NotImplementedError()

@abc.abstractmethod
def get_worker_concurrency(self) -> int:
raise NotImplementedError()


class FixedWorkerPool(WorkerPool):
"""A WorkerPool built from a fixed list of workers."""

# Issue #155 would strongly-type `workers`
def __init__(self, workers: List[Any], worker_concurrency: int = 2):
self._workers = workers
self._worker_concurrency = worker_concurrency

def get_currently_active(self):
return self._workers

def get_worker_concurrency(self):
return self._worker_concurrency


# Dask's Futures are limited. This captures that.
class WorkerFuture(Protocol[T]):

@@ -91,6 +64,63 @@ def get_exception(worker_future: WorkerFuture) -> Optional[Exception]:
return e


def lift_futures_through_list(future_list: WorkerFuture,
expected_size: int) -> List[WorkerFuture]:
"""Convert Future[List] to List[Future]."""
flattened = [concurrent.futures.Future() for _ in range(expected_size)]

def _handler(fut):
if e := get_exception(fut):
for f in flattened:
f.set_exception(e)
return

for i, res in enumerate(fut.result()):
assert i < expected_size
if isinstance(res, Exception):
flattened[i].set_exception(res)
else:
flattened[i].set_result(res)
for j in range(i + 1, expected_size):
flattened[j].set_exception(
ValueError(f'No value returned for index {j} in future_list'))

future_list.add_done_callback(_handler)
return flattened


class WorkerPool(metaclass=abc.ABCMeta):
"""Abstraction of a pool of workers that may be refreshed."""

# Issue #155 would strongly-type the return type.
@abc.abstractmethod
def get_currently_active(self) -> List[Any]:
raise NotImplementedError()

@abc.abstractmethod
def get_worker_concurrency(self) -> int:
raise NotImplementedError()

@abc.abstractmethod
def schedule(self, work: List[Any]) -> List[WorkerFuture]:
raise NotImplementedError()


class FixedWorkerPool(WorkerPool):
"""A WorkerPool built from a fixed list of workers."""

# Issue #155 would strongly-type `workers`
def __init__(self, workers: List[Any], worker_concurrency: int = 2):
self._workers = workers
self._worker_concurrency = worker_concurrency

def get_currently_active(self):
return self._workers

def get_worker_concurrency(self):
return self._worker_concurrency


def get_full_worker_args(worker_class: 'type[Worker]', current_kwargs):
"""Get the union of given kwargs and gin config.

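To illustrate the new contract (schedule() returning a List[WorkerFuture]) and where lift_futures_through_list fits, here is a hypothetical pool whose backend produces a single future resolving to a list of per-item results; the BatchBackedWorkerPool class and its _batch_run helper are invented for the sketch:

import concurrent.futures
from typing import Any, List

from compiler_opt.distributed import worker


class BatchBackedWorkerPool(worker.FixedWorkerPool):
  """Hypothetical pool whose backend returns one Future[List] per batch."""

  def schedule(self, work: List[Any]) -> List[worker.WorkerFuture]:
    batch_future = self._batch_run(work)
    # Flatten Future[List] into List[Future] so callers (e.g. an early-exit
    # checker) can poll per-item .done() and inspect per-item exceptions.
    return worker.lift_futures_through_list(batch_future, len(work))

  def _batch_run(self, work: List[Any]) -> concurrent.futures.Future:
    # Stand-in for a real batched backend; everything "succeeds" immediately
    # with a placeholder result so the sketch stays self-contained.
    f = concurrent.futures.Future()
    f.set_result([f'done: {item}' for item in work])
    return f


pool = BatchBackedWorkerPool(workers=[], worker_concurrency=1)
futures = pool.schedule(['module_a.bc', 'module_b.bc'])
worker.wait_for(futures)
assert [f.result() for f in futures] == ['done: module_a.bc', 'done: module_b.bc']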
60 changes: 60 additions & 0 deletions compiler_opt/distributed/worker_test.py
@@ -0,0 +1,60 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for worker."""

from absl.testing import absltest
import concurrent.futures

from compiler_opt.distributed import worker


class LiftFuturesThroughListTest(absltest.TestCase):

def test_normal_path(self):
expected_list = [1, True, [2.0, False]]
future_list = concurrent.futures.Future()
list_future = worker.lift_futures_through_list(future_list,
len(expected_list))
future_list.set_result(expected_list)
worker.wait_for(list_future)

self.assertEqual([f.result() for f in list_future], expected_list)

def test_with_exceptions_in_list(self):
expected_list = [1, ValueError('error')]
future_list = concurrent.futures.Future()
list_future = worker.lift_futures_through_list(future_list,
len(expected_list))
future_list.set_result(expected_list)
worker.wait_for(list_future)

self.assertEqual(list_future[0].result(), expected_list[0])
self.assertTrue(
isinstance(worker.get_exception(list_future[1]), ValueError))

def test_list_is_exception(self):
expected_size = 42
future_list = concurrent.futures.Future()
list_future = worker.lift_futures_through_list(future_list, expected_size)
future_list.set_exception(ValueError('error'))

worker.wait_for(list_future)
self.assertEqual(len(list_future), expected_size)
for f in list_future:
self.assertTrue(isinstance(worker.get_exception(f), ValueError))


if __name__ == '__main__':
absltest.main()
88 changes: 87 additions & 1 deletion compiler_opt/rl/data_collector.py
@@ -15,13 +15,17 @@
"""Data collection module."""

import abc
import concurrent.futures
import time
from typing import Dict, Iterator, Tuple, Sequence
from typing import Any, Dict, Iterator, List, Optional, Tuple, Sequence

from absl import logging
import numpy as np
from compiler_opt.rl import policy_saver
from tf_agents.trajectories import trajectory

from compiler_opt.distributed import worker

# Deadline for data collection.
DEADLINE_IN_SECONDS = 30

@@ -138,3 +142,85 @@ def wait(self, get_num_finished_work):

def waited_time(self):
return self._waited_time


class CancelledForEarlyExitException(Exception):
...


def _create_cancelled_future():
f = concurrent.futures.Future()
f.set_exception(CancelledForEarlyExitException())
return f


class EarlyExitWorkerPool(worker.WorkerPool):
"""Worker pool wrapper which performs early-exit checking.

Note that this worker pool wraps another worker pool, and this wrapper only
manages cancelling work from the underlying pool. Also, due to the nature of
"early exit," the futures that this pool's schedule() method returns are all
already .done().
"""

def __init__(self,
worker_pool: worker.WorkerPool,
exit_checker_ctor=EarlyExitChecker):
"""
Args:
worker_pool: the underlying worker pool to schedule work on.
exit_checker_ctor: the exit checker constructor to use.
"""
self._worker_pool = worker_pool
self._reset_workers_pool = concurrent.futures.ThreadPoolExecutor()
self._reset_workers_future: Optional[concurrent.futures.Future] = None
self._exit_checker_ctor = exit_checker_ctor

def get_currently_active(self) -> List[Any]:
return self._worker_pool.get_currently_active()

def get_worker_concurrency(self) -> int:
return self._worker_pool.get_worker_concurrency()

def schedule(self, work: List[Any]) -> List[worker.WorkerFuture]:
"""Schedule the provided work on the underlying worker pool.

After the work is scheduled, this method blocks until the early exit
checker deems it ok to exit early. Work that was cancelled will have a
future with a CancelledForEarlyExitException error.

Args:
work: the work to schedule.

Returns:
a list of futures which all are already .done().
"""

t1 = time.time()
if self._reset_workers_future:
concurrent.futures.wait([self._reset_workers_future])
self._reset_workers_future = None
logging.info('Waiting for pending work took %f', time.time() - t1)

result_futures = self._worker_pool.schedule(work)
early_exit = self._exit_checker_ctor(num_modules=len(work))
early_exit.wait(lambda: sum(res.done() for res in result_futures))

def _wrapup():
workers = self._worker_pool.get_currently_active()
cancel_futures = [wkr.cancel_all_work() for wkr in workers]
worker.wait_for(cancel_futures)
# Now that the workers have killed pending compilations, make sure they have
# drained their work queues - the remaining futures should all complete
# quickly, since the cancellation manager immediately kills any process that
# starts.
worker.wait_for(result_futures)
worker.wait_for([wkr.enable() for wkr in workers])

def _process_future(f):
if f.done():
return f
return _create_cancelled_future()

results = [_process_future(f) for f in result_futures]
self._reset_workers_future = self._reset_workers_pool.submit(_wrapup)
return results
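
Finally, a minimal sketch of how the new wrapper might be driven from a collection loop. The names base_pool (any worker.WorkerPool, e.g. the one from LocalWorkerPoolManager), jobs (whatever work list that pool's schedule() accepts), and collect_with_early_exit are assumptions of the sketch, not part of this change:

from typing import Any, List

from compiler_opt.distributed import worker
from compiler_opt.rl import data_collector


def collect_with_early_exit(base_pool: worker.WorkerPool,
                            jobs: List[Any]) -> List[Any]:
  """Schedules `jobs` and keeps only results not cut short by early exit."""
  pool = data_collector.EarlyExitWorkerPool(base_pool)
  # schedule() blocks until the EarlyExitChecker decides enough work finished;
  # every future it returns is already .done().
  futures = pool.schedule(jobs)
  results = []
  for f in futures:
    e = worker.get_exception(f)
    if isinstance(e, data_collector.CancelledForEarlyExitException):
      continue  # this item was cancelled when the pool exited early
    if e is None:
      results.append(f.result())
  return results

Note that the wrap-up path relies on the underlying workers exposing cancel_all_work() and enable(), as used in _wrapup above.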