diff --git a/scripts/data/client.py b/scripts/data/client.py index 1d9d8f3c..1c786b28 100755 --- a/scripts/data/client.py +++ b/scripts/data/client.py @@ -15,7 +15,6 @@ import signal from generate_data import generate_data from format_args import format_args -import logging from logging.handlers import TimedRotatingFileHandler logger = logging.getLogger(__name__) @@ -31,7 +30,17 @@ current_weight = 0 weight_lock = threading.Condition() job_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE) -shutdown_event = threading.Event() # New event for coordinating shutdown + + +class CancellationToken: + def __init__(self): + self._is_cancelled = threading.Event() + + def cancel(self): + self._is_cancelled.set() + + def is_cancelled(self): + return self._is_cancelled.is_set() class ShutdownRequested(Exception): @@ -40,11 +49,12 @@ class ShutdownRequested(Exception): pass -def run(cmd, timeout=None): +def run(cmd, timeout=None, cancellation_token=None): + """ + Run a subprocess with proper cancellation handling """ - Run a subprocess with proper shutdown handling""" - if shutdown_event.is_set(): - raise ShutdownRequested("Shutdown requested before process start") + if cancellation_token and cancellation_token.is_cancelled(): + raise ShutdownRequested("Cancellation requested before process start") process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True @@ -53,14 +63,14 @@ def run(cmd, timeout=None): try: stdout, stderr = process.communicate(timeout=timeout) - if shutdown_event.is_set(): + if cancellation_token and cancellation_token.is_cancelled(): process.terminate() try: process.wait(timeout=5) except subprocess.TimeoutExpired: process.kill() process.wait() - raise ShutdownRequested("Shutdown requested during process execution") + raise ShutdownRequested("Cancellation requested during process execution") return stdout, stderr, process.returncode @@ -100,7 +110,9 @@ def __str__(self): # Generator function to create jobs -def job_generator(start, blocks, step, mode, strategy, execute_scripts): +def job_generator( + start, blocks, step, mode, strategy, execute_scripts, cancellation_token=None +): BASE_DIR.mkdir(exist_ok=True) end = start + blocks @@ -111,7 +123,7 @@ def job_generator(start, blocks, step, mode, strategy, execute_scripts): ) for height in height_range: - if shutdown_event.is_set(): + if cancellation_token and cancellation_token.is_cancelled(): break try: batch_file = BASE_DIR / f"{mode}_{height}_{step}.json" @@ -130,10 +142,7 @@ def job_generator(start, blocks, step, mode, strategy, execute_scripts): logger.error(f"Error while generating data for: {height}:\n{e}") -def process_batch(job): - if shutdown_event.is_set(): - return - +def process_batch(job, cancellation_token=None): arguments_file = job.batch_file.as_posix().replace(".json", "-arguments.json") with open(arguments_file, "w") as af: @@ -151,7 +160,8 @@ def process_batch(job): "main", "--arguments-file", str(arguments_file), - ] + ], + cancellation_token=cancellation_token, ) if ( @@ -196,7 +206,7 @@ def process_batch(job): logger.warning(f"{job}: no gas info found") except ShutdownRequested: - logger.debug(f"Shutdown requested while processing {job}") + logger.debug(f"Cancellation requested while processing {job}") return except subprocess.TimeoutExpired: logger.warning(f"Timeout while terminating subprocess for {job}") @@ -205,12 +215,12 @@ def process_batch(job): # Producer function: Generates data and adds jobs to the queue -def job_producer(job_gen): +def job_producer(job_gen, cancellation_token=None): global current_weight try: for job, weight in job_gen: - if shutdown_event.is_set(): + if cancellation_token and cancellation_token.is_cancelled(): break # Wait until there is enough weight capacity to add the new block @@ -218,17 +228,17 @@ def job_producer(job_gen): logger.debug( f"Adding job: {job}, current total weight: {current_weight}..." ) - while not shutdown_event.is_set() and ( + while not ( + cancellation_token and cancellation_token.is_cancelled() + ) and ( (current_weight + weight > MAX_WEIGHT_LIMIT) and current_weight != 0 or job_queue.full() ): logger.debug("Producer is waiting for weight to be released.") - weight_lock.wait( - timeout=1.0 - ) # Wait with timeout to check shutdown_event + weight_lock.wait(timeout=1.0) - if shutdown_event.is_set(): + if cancellation_token and cancellation_token.is_cancelled(): break if (current_weight + weight > MAX_WEIGHT_LIMIT) and current_weight == 0: @@ -256,15 +266,15 @@ def job_producer(job_gen): # Consumer function: Processes blocks from the queue -def job_consumer(process_job): +def job_consumer(process_job, cancellation_token=None): global current_weight - while not shutdown_event.is_set(): + while not (cancellation_token and cancellation_token.is_cancelled()): try: logger.debug( f"Consumer is waiting for a job. Queue length: {job_queue.qsize()}" ) - # Get a job from the queue with timeout to check shutdown_event + # Get a job from the queue with timeout to check cancellation try: work_to_do = job_queue.get(timeout=1.0) except queue.Empty: @@ -277,7 +287,7 @@ def job_consumer(process_job): (job, weight) = work_to_do - if shutdown_event.is_set(): + if cancellation_token and cancellation_token.is_cancelled(): with weight_lock: current_weight -= weight weight_lock.notify_all() @@ -287,7 +297,7 @@ def job_consumer(process_job): # Process the block try: logger.debug(f"Executing job: {job}...") - process_job(job) + process_job(job, cancellation_token) except Exception as e: logger.error(f"Error while processing job: {job}:\n{e}") @@ -302,24 +312,21 @@ def job_consumer(process_job): job_queue.task_done() except Exception as e: - if not shutdown_event.is_set(): + if not (cancellation_token and cancellation_token.is_cancelled()): logger.error("Error in the consumer: %s", e) break -def signal_handler(signum, frame): - """Handle shutdown signals""" - signal_name = signal.Signals(signum).name - logger.info(f"Received signal {signal_name}. Initiating graceful shutdown...") - shutdown_event.set() - - # Wake up any waiting threads - with weight_lock: - weight_lock.notify_all() +def main(start, blocks, step, mode, strategy, execute_scripts): + # Create a centralized cancellation mechanism + cancellation_token = CancellationToken() + # Set up signal handlers to use the cancellation token + def signal_handler(signum, frame): + signal_name = signal.Signals(signum).name + logger.info(f"Received signal {signal_name}. Initiating graceful shutdown...") + cancellation_token.cancel() -def main(start, blocks, step, mode, strategy, execute_scripts): - # Set up signal handlers signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) @@ -340,31 +347,38 @@ def main(start, blocks, step, mode, strategy, execute_scripts): ) # Create the job generator - job_gen = job_generator(start, blocks, step, mode, strategy, execute_scripts) + job_gen = job_generator( + start, blocks, step, mode, strategy, execute_scripts, cancellation_token + ) # Start the job producer thread - producer_thread = threading.Thread(target=job_producer, args=(job_gen,)) + producer_thread = threading.Thread( + target=job_producer, args=(job_gen, cancellation_token) + ) producer_thread.start() # Start the consumer threads using ThreadPoolExecutor with ThreadPoolExecutor(max_workers=THREAD_POOL_SIZE) as executor: futures = [ - executor.submit(job_consumer, process_batch) + executor.submit(job_consumer, process_batch, cancellation_token) for _ in range(THREAD_POOL_SIZE) ] - # Wait for producer to finish or shutdown signal + # Wait for producer to finish or cancellation producer_thread.join() - # Wait for all items in the queue to be processed or shutdown signal - while not shutdown_event.is_set() and not job_queue.empty(): + # Wait for all items in the queue to be processed or cancellation + while ( + not (cancellation_token and cancellation_token.is_cancelled()) + and not job_queue.empty() + ): try: job_queue.join() break except KeyboardInterrupt: - shutdown_event.set() + cancellation_token.cancel() - if shutdown_event.is_set(): + if cancellation_token.is_cancelled(): logger.info("Shutdown complete.") else: logger.info("All jobs have been processed.") @@ -406,7 +420,7 @@ def main(start, blocks, step, mode, strategy, execute_scripts): MAX_WEIGHT_LIMIT = args.maxweight - # file_handler = logging.FileHandler("client.log") + # Logging setup file_handler = TimedRotatingFileHandler( filename="client.log", when="midnight",