fix: separated shutdown subtleties from business logic
raizo07 committed Nov 27, 2024
1 parent acb909b commit c15d94e
Showing 1 changed file with 63 additions and 49 deletions.
scripts/data/client.py (112 changes: 63 additions & 49 deletions)
@@ -15,7 +15,6 @@
 import signal
 from generate_data import generate_data
 from format_args import format_args
-import logging
 from logging.handlers import TimedRotatingFileHandler
 
 logger = logging.getLogger(__name__)
@@ -31,7 +30,17 @@
 current_weight = 0
 weight_lock = threading.Condition()
 job_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
-shutdown_event = threading.Event()  # New event for coordinating shutdown
+
+
+class CancellationToken:
+    def __init__(self):
+        self._is_cancelled = threading.Event()
+
+    def cancel(self):
+        self._is_cancelled.set()
+
+    def is_cancelled(self):
+        return self._is_cancelled.is_set()
 
 
 class ShutdownRequested(Exception):
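
Note: the CancellationToken introduced above is a thin wrapper around threading.Event with an explicit, passable interface, replacing the module-level shutdown_event. A minimal sketch of how such a token is shared between a signal handler and a worker (the worker function and sleep interval here are illustrative, not part of the commit):

import signal
import threading
import time

token = CancellationToken()

def worker(token):
    # Poll the token each iteration instead of consulting a global event
    while not token.is_cancelled():
        time.sleep(0.1)  # stand-in for a unit of real work

# SIGINT flips the token; every component holding it observes the change
signal.signal(signal.SIGINT, lambda signum, frame: token.cancel())

t = threading.Thread(target=worker, args=(token,))
t.start()
t.join()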
@@ -40,11 +49,12 @@ class ShutdownRequested(Exception):
     pass
 
 
-def run(cmd, timeout=None):
+def run(cmd, timeout=None, cancellation_token=None):
     """
-    Run a subprocess with proper shutdown handling"""
-    if shutdown_event.is_set():
-        raise ShutdownRequested("Shutdown requested before process start")
+    Run a subprocess with proper cancellation handling
+    """
+    if cancellation_token and cancellation_token.is_cancelled():
+        raise ShutdownRequested("Cancellation requested before process start")
 
     process = subprocess.Popen(
         cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
@@ -53,14 +63,14 @@ def run(cmd, timeout=None):
     try:
         stdout, stderr = process.communicate(timeout=timeout)
 
-        if shutdown_event.is_set():
+        if cancellation_token and cancellation_token.is_cancelled():
             process.terminate()
             try:
                 process.wait(timeout=5)
             except subprocess.TimeoutExpired:
                 process.kill()
                 process.wait()
-            raise ShutdownRequested("Shutdown requested during process execution")
+            raise ShutdownRequested("Cancellation requested during process execution")
 
         return stdout, stderr, process.returncode
 
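Note: the terminate-then-kill sequence in run() is the standard escalation for stopping a child process gracefully before forcing it. A standalone sketch of that escalation (the helper name and grace period are illustrative):

import subprocess

def stop_process(process, grace_seconds=5):
    # Ask politely first (SIGTERM), then force (SIGKILL) if it lingers
    process.terminate()
    try:
        process.wait(timeout=grace_seconds)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()  # reap the child to avoid leaving a zombie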
@@ -100,7 +110,9 @@ def __str__(self):
 
 
 # Generator function to create jobs
-def job_generator(start, blocks, step, mode, strategy, execute_scripts):
+def job_generator(
+    start, blocks, step, mode, strategy, execute_scripts, cancellation_token=None
+):
     BASE_DIR.mkdir(exist_ok=True)
     end = start + blocks
 
@@ -111,7 +123,7 @@ def job_generator(start, blocks, step, mode, strategy, execute_scripts):
     )
 
     for height in height_range:
-        if shutdown_event.is_set():
+        if cancellation_token and cancellation_token.is_cancelled():
             break
         try:
             batch_file = BASE_DIR / f"{mode}_{height}_{step}.json"
@@ -130,10 +142,7 @@ def job_generator(start, blocks, step, mode, strategy, execute_scripts):
             logger.error(f"Error while generating data for: {height}:\n{e}")
 
 
-def process_batch(job):
-    if shutdown_event.is_set():
-        return
-
+def process_batch(job, cancellation_token=None):
     arguments_file = job.batch_file.as_posix().replace(".json", "-arguments.json")
 
     with open(arguments_file, "w") as af:
@@ -151,7 +160,8 @@ def process_batch(job):
                 "main",
                 "--arguments-file",
                 str(arguments_file),
-            ]
+            ],
+            cancellation_token=cancellation_token,
         )
 
         if (
@@ -196,7 +206,7 @@ def process_batch(job):
             logger.warning(f"{job}: no gas info found")
 
     except ShutdownRequested:
-        logger.debug(f"Shutdown requested while processing {job}")
+        logger.debug(f"Cancellation requested while processing {job}")
         return
     except subprocess.TimeoutExpired:
         logger.warning(f"Timeout while terminating subprocess for {job}")
@@ -205,30 +215,30 @@ def process_batch(job):
 
 
 # Producer function: Generates data and adds jobs to the queue
-def job_producer(job_gen):
+def job_producer(job_gen, cancellation_token=None):
     global current_weight
 
     try:
         for job, weight in job_gen:
-            if shutdown_event.is_set():
+            if cancellation_token and cancellation_token.is_cancelled():
                 break
 
             # Wait until there is enough weight capacity to add the new block
             with weight_lock:
                 logger.debug(
                     f"Adding job: {job}, current total weight: {current_weight}..."
                 )
-                while not shutdown_event.is_set() and (
+                while not (
+                    cancellation_token and cancellation_token.is_cancelled()
+                ) and (
                     (current_weight + weight > MAX_WEIGHT_LIMIT)
                     and current_weight != 0
                     or job_queue.full()
                 ):
                     logger.debug("Producer is waiting for weight to be released.")
-                    weight_lock.wait(
-                        timeout=1.0
-                    )  # Wait with timeout to check shutdown_event
+                    weight_lock.wait(timeout=1.0)
 
-                if shutdown_event.is_set():
+                if cancellation_token and cancellation_token.is_cancelled():
                     break
 
                 if (current_weight + weight > MAX_WEIGHT_LIMIT) and current_weight == 0:
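
Note: the producer's wait loop is a bounded-buffer pattern on threading.Condition: waiting with a timeout makes the cancellation check re-run even if no consumer ever signals. A reduced sketch of just that pattern (the names and capacity figure are illustrative):

import threading

capacity_lock = threading.Condition()
in_flight = 0
LIMIT = 100

def reserve(weight, token):
    # Block until the reservation fits, waking periodically to honor cancellation
    global in_flight
    with capacity_lock:
        while not token.is_cancelled() and in_flight + weight > LIMIT:
            capacity_lock.wait(timeout=1.0)
        if token.is_cancelled():
            return False
        in_flight += weight
        return True

def release(weight):
    # Free capacity and wake any producer blocked in reserve()
    global in_flight
    with capacity_lock:
        in_flight -= weight
        capacity_lock.notify_all()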
@@ -256,15 +266,15 @@ def job_producer(job_gen):
 
 
 # Consumer function: Processes blocks from the queue
-def job_consumer(process_job):
+def job_consumer(process_job, cancellation_token=None):
     global current_weight
 
-    while not shutdown_event.is_set():
+    while not (cancellation_token and cancellation_token.is_cancelled()):
         try:
             logger.debug(
                 f"Consumer is waiting for a job. Queue length: {job_queue.qsize()}"
             )
-            # Get a job from the queue with timeout to check shutdown_event
+            # Get a job from the queue with timeout to check cancellation
             try:
                 work_to_do = job_queue.get(timeout=1.0)
             except queue.Empty:
@@ -277,7 +287,7 @@ def job_consumer(process_job):
 
             (job, weight) = work_to_do
 
-            if shutdown_event.is_set():
+            if cancellation_token and cancellation_token.is_cancelled():
                 with weight_lock:
                     current_weight -= weight
                     weight_lock.notify_all()
@@ -287,7 +297,7 @@ def job_consumer(process_job):
             # Process the block
             try:
                 logger.debug(f"Executing job: {job}...")
-                process_job(job)
+                process_job(job, cancellation_token)
             except Exception as e:
                 logger.error(f"Error while processing job: {job}:\n{e}")
 
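Note: the consumer pairs a timeout on queue.get with task_done accounting, so the loop notices cancellation between jobs and Queue.join() callers still make progress. A trimmed sketch of that polling loop (names are illustrative):

import queue

jobs = queue.Queue()

def consume(token, handle):
    # Poll with a timeout so cancellation is noticed even when the queue is idle
    while not token.is_cancelled():
        try:
            job = jobs.get(timeout=1.0)
        except queue.Empty:
            continue
        try:
            handle(job)
        finally:
            jobs.task_done()  # unblocks any Queue.join() caller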
@@ -302,24 +312,21 @@ def job_consumer(process_job):
             job_queue.task_done()
 
         except Exception as e:
-            if not shutdown_event.is_set():
+            if not (cancellation_token and cancellation_token.is_cancelled()):
                 logger.error("Error in the consumer: %s", e)
                 break
 
 
-def signal_handler(signum, frame):
-    """Handle shutdown signals"""
-    signal_name = signal.Signals(signum).name
-    logger.info(f"Received signal {signal_name}. Initiating graceful shutdown...")
-    shutdown_event.set()
-
-    # Wake up any waiting threads
-    with weight_lock:
-        weight_lock.notify_all()
+def main(start, blocks, step, mode, strategy, execute_scripts):
+    # Create a centralized cancellation mechanism
+    cancellation_token = CancellationToken()
 
+    # Set up signal handlers to use the cancellation token
+    def signal_handler(signum, frame):
+        signal_name = signal.Signals(signum).name
+        logger.info(f"Received signal {signal_name}. Initiating graceful shutdown...")
+        cancellation_token.cancel()
 
-def main(start, blocks, step, mode, strategy, execute_scripts):
-    # Set up signal handlers
     signal.signal(signal.SIGTERM, signal_handler)
     signal.signal(signal.SIGINT, signal_handler)
 
@@ -340,31 +347,38 @@ def main(start, blocks, step, mode, strategy, execute_scripts):
     )
 
     # Create the job generator
-    job_gen = job_generator(start, blocks, step, mode, strategy, execute_scripts)
+    job_gen = job_generator(
+        start, blocks, step, mode, strategy, execute_scripts, cancellation_token
+    )
 
     # Start the job producer thread
-    producer_thread = threading.Thread(target=job_producer, args=(job_gen,))
+    producer_thread = threading.Thread(
+        target=job_producer, args=(job_gen, cancellation_token)
+    )
     producer_thread.start()
 
     # Start the consumer threads using ThreadPoolExecutor
     with ThreadPoolExecutor(max_workers=THREAD_POOL_SIZE) as executor:
         futures = [
-            executor.submit(job_consumer, process_batch)
+            executor.submit(job_consumer, process_batch, cancellation_token)
             for _ in range(THREAD_POOL_SIZE)
         ]
 
-        # Wait for producer to finish or shutdown signal
+        # Wait for producer to finish or cancellation
        producer_thread.join()
 
-        # Wait for all items in the queue to be processed or shutdown signal
-        while not shutdown_event.is_set() and not job_queue.empty():
+        # Wait for all items in the queue to be processed or cancellation
+        while (
+            not (cancellation_token and cancellation_token.is_cancelled())
+            and not job_queue.empty()
+        ):
             try:
                 job_queue.join()
                 break
             except KeyboardInterrupt:
-                shutdown_event.set()
+                cancellation_token.cancel()
 
-        if shutdown_event.is_set():
+        if cancellation_token.is_cancelled():
             logger.info("Shutdown complete.")
         else:
             logger.info("All jobs have been processed.")
@@ -406,7 +420,7 @@ def main(start, blocks, step, mode, strategy, execute_scripts):
 
     MAX_WEIGHT_LIMIT = args.maxweight
 
-    # file_handler = logging.FileHandler("client.log")
+    # Logging setup
     file_handler = TimedRotatingFileHandler(
         filename="client.log",
         when="midnight",
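Note: the diff is cut off just as the TimedRotatingFileHandler is being configured. For reference, a complete rotating-logger setup in this style might look like the following; the backupCount and format values are assumptions, not taken from the commit:

import logging
from logging.handlers import TimedRotatingFileHandler

logger = logging.getLogger(__name__)

file_handler = TimedRotatingFileHandler(
    filename="client.log",
    when="midnight",  # rotate once a day at midnight
    backupCount=7,    # assumed: keep a week of old log files
)
file_handler.setFormatter(
    logging.Formatter("%(asctime)s %(levelname)s %(message)s")  # assumed format
)
logger.addHandler(file_handler)
logger.setLevel(logging.DEBUG)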
