diff --git a/functions-python/dataset_service/main.py b/functions-python/dataset_service/main.py
index b8f514f5a..53c761217 100644
--- a/functions-python/dataset_service/main.py
+++ b/functions-python/dataset_service/main.py
@@ -22,6 +22,7 @@
 from google.cloud import datastore
 from google.cloud.datastore import Client
 
+
 # This files contains the dataset trace and batch execution models and services.
 # The dataset trace is used to store the trace of a dataset and the batch execution
 # One batch execution can have multiple dataset traces.
@@ -32,8 +33,16 @@
 # Status of the dataset trace
 class Status(Enum):
     FAILED = "FAILED"
+    SUCCESS = "SUCCESS"
     PUBLISHED = "PUBLISHED"
     NOT_PUBLISHED = "NOT_PUBLISHED"
+    PROCESSING = "PROCESSING"
+
+
+# Stage of the pipeline
+class PipelineStage(Enum):
+    DATASET_PROCESSING = "DATASET_PROCESSING"
+    LOCATION_EXTRACTION = "LOCATION_EXTRACTION"
 
 
 # Dataset trace class to store the trace of a dataset
@@ -42,10 +51,12 @@ class DatasetTrace:
     stable_id: str
     status: Status
     timestamp: datetime
+    dataset_id: Optional[str] = None
     trace_id: Optional[str] = None
     execution_id: Optional[str] = None
     file_sha256_hash: Optional[str] = None
     hosted_url: Optional[str] = None
+    pipeline_stage: PipelineStage = PipelineStage.DATASET_PROCESSING
     error_message: Optional[str] = None
 
 
@@ -98,7 +109,9 @@ def get_by_execution_and_stable_ids(
     # Transform the dataset trace to entity
     def _dataset_trace_to_entity(self, dataset_trace: DatasetTrace) -> datastore.Entity:
-        trace_id = str(uuid.uuid4())
+        trace_id = (
+            str(uuid.uuid4()) if not dataset_trace.trace_id else dataset_trace.trace_id
+        )
         key = self.client.key(dataset_trace_collection, trace_id)
         entity = datastore.Entity(key=key)
 
diff --git a/functions-python/extract_bb/.coveragerc b/functions-python/extract_bb/.coveragerc
index d3ef5cbc8..ae792ac20 100644
--- a/functions-python/extract_bb/.coveragerc
+++ b/functions-python/extract_bb/.coveragerc
@@ -3,6 +3,7 @@ omit =
     */test*/*
     */helpers/*
     */database_gen/*
+    */dataset_service/*
 
 [report]
 exclude_lines =
diff --git a/functions-python/extract_bb/function_config.json b/functions-python/extract_bb/function_config.json
index f81c07ae1..c82c23e16 100644
--- a/functions-python/extract_bb/function_config.json
+++ b/functions-python/extract_bb/function_config.json
@@ -5,7 +5,7 @@
   "timeout": 540,
   "memory": "8Gi",
   "trigger_http": false,
-  "include_folders": ["database_gen", "helpers"],
+  "include_folders": ["database_gen", "helpers", "dataset_service"],
   "environment_variables": [],
   "secret_environment_variables": [
     {
diff --git a/functions-python/extract_bb/src/main.py b/functions-python/extract_bb/src/main.py
index 13fb308e3..b18de2984 100644
--- a/functions-python/extract_bb/src/main.py
+++ b/functions-python/extract_bb/src/main.py
@@ -2,6 +2,8 @@
 import json
 import logging
 import os
+import uuid
+from datetime import datetime
 
 import functions_framework
 import gtfs_kit
@@ -13,18 +15,16 @@
 from database_gen.sqlacodegen_models import Gtfsdataset
 from helpers.database import start_db_session
 from helpers.logger import Logger
+from dataset_service.main import (
+    DatasetTraceService,
+    DatasetTrace,
+    Status,
+    PipelineStage,
+)
 
 logging.basicConfig(level=logging.INFO)
 
 
-class Request:
-    def __init__(self, json):
-        self.json = json
-
-    def get_json(self):
-        return self.json
-
-
 def parse_resource_data(data: dict) -> tuple:
     """
     Parse the cloud event data to extract resource information.
@@ -101,6 +101,10 @@ def extract_bounding_box_pubsub(cloud_event: CloudEvent):
     @param cloud_event: The CloudEvent containing the Pub/Sub message.
     """
     Logger.init_logger()
+    try:
+        maximum_executions = int(os.getenv("MAXIMUM_EXECUTIONS", 1))
+    except ValueError:
+        maximum_executions = 1
     data = cloud_event.data
     logging.info(f"Function triggered with Pub/Sub event data: {data}")
 
@@ -110,7 +114,7 @@ def extract_bounding_box_pubsub(cloud_event: CloudEvent):
         message_json = json.loads(base64.b64decode(message_data).decode("utf-8"))
     except Exception as e:
         logging.error(f"Error parsing message data: {e}")
-        return "Invalid Pub/Sub message data.", 400
+        return "Invalid Pub/Sub message data."
 
     logging.info(f"Parsed message data: {message_json}")
 
@@ -121,37 +125,73 @@ def extract_bounding_box_pubsub(cloud_event: CloudEvent):
         or "url" not in message_json
     ):
         logging.error("Invalid message data.")
-        return (
-            "Invalid message data. Expected 'stable_id', 'dataset_id', and 'url' in the message.",
-            400,
-        )
+        return "Invalid message data. Expected 'stable_id', 'dataset_id', and 'url' in the message."
 
     stable_id = message_json["stable_id"]
     dataset_id = message_json["dataset_id"]
     url = message_json["url"]
-
-    logging.info(f"[{dataset_id}] accessing url: {url}")
-    try:
-        bounds = get_gtfs_feed_bounds(url, dataset_id)
-    except Exception as e:
-        return f"Error processing GTFS feed: {e}", 500
-    logging.info(f"[{dataset_id}] extracted bounding box = {bounds}")
-
-    geometry_polygon = create_polygon_wkt_element(bounds)
-
-    session = None
+    execution_id = message_json.get("execution_id", None)
+    if execution_id is None:
+        logging.warning(f"[{dataset_id}] No execution ID found in message")
+        execution_id = str(uuid.uuid4())
+        logging.info(f"[{dataset_id}] Generated execution ID: {execution_id}")
+    trace_service = DatasetTraceService()
+    trace = trace_service.get_by_execution_and_stable_ids(execution_id, stable_id)
+    logging.info(f"[{dataset_id}] Trace: {trace}")
+    executions = len(trace) if trace else 0
+    print(f"[{dataset_id}] Executions: {executions}")
+    print(trace_service.get_by_execution_and_stable_ids(execution_id, stable_id))
+    logging.info(f"[{dataset_id}] Executions: {executions}")
+    if executions > 0 and executions >= maximum_executions:
+        logging.warning(
+            f"[{dataset_id}] Maximum executions reached. Skipping processing."
+        )
+        return f"Maximum executions reached for {dataset_id}."
+    trace_id = str(uuid.uuid4())
+    error = None
+    # Saving trace before starting in case we run into memory problems or uncatchable errors
+    trace = DatasetTrace(
+        trace_id=trace_id,
+        stable_id=stable_id,
+        execution_id=execution_id,
+        status=Status.PROCESSING,
+        timestamp=datetime.now(),
+        hosted_url=url,
+        dataset_id=dataset_id,
+        pipeline_stage=PipelineStage.LOCATION_EXTRACTION,
+    )
+    trace_service.save(trace)
     try:
-        session = start_db_session(os.getenv("FEEDS_DATABASE_URL"))
-        update_dataset_bounding_box(session, dataset_id, geometry_polygon)
-    except Exception as e:
-        logging.error(f"[{dataset_id}] Error while processing: {e}")
-        if session is not None:
-            session.rollback()
-        raise e
+        logging.info(f"[{dataset_id}] accessing url: {url}")
+        try:
+            bounds = get_gtfs_feed_bounds(url, dataset_id)
+        except Exception as e:
+            error = f"Error processing GTFS feed: {e}"
+            raise e
+        logging.info(f"[{dataset_id}] extracted bounding box = {bounds}")
+
+        geometry_polygon = create_polygon_wkt_element(bounds)
+
+        session = None
+        try:
+            session = start_db_session(os.getenv("FEEDS_DATABASE_URL"))
+            update_dataset_bounding_box(session, dataset_id, geometry_polygon)
+        except Exception as e:
+            error = f"Error updating bounding box in database: {e}"
+            logging.error(f"[{dataset_id}] Error while processing: {e}")
+            if session is not None:
+                session.rollback()
+            raise e
+        finally:
+            if session is not None:
+                session.close()
+        logging.info(f"[{stable_id} - {dataset_id}] Bounding box updated successfully.")
+    except Exception:
+        pass
     finally:
-        if session is not None:
-            session.close()
-        logging.info(f"[{stable_id} - {dataset_id}] Bounding box updated successfully.")
+        trace.status = Status.FAILED if error is not None else Status.SUCCESS
+        trace.error_message = error
+        trace_service.save(trace)
 
 
 @functions_framework.cloud_event
diff --git a/functions-python/extract_bb/tests/test_extract_bb.py b/functions-python/extract_bb/tests/test_extract_bb.py
index 25df1bc87..a3871284b 100644
--- a/functions-python/extract_bb/tests/test_extract_bb.py
+++ b/functions-python/extract_bb/tests/test_extract_bb.py
@@ -72,11 +72,8 @@ def test_get_gtfs_feed_bounds(self, mock_gtfs_kit):
 
     @patch("extract_bb.src.main.Logger")
     def test_extract_bb_exception(self, _):
-        data = {
-            "stable_id": faker.pystr(),
-            "dataset_id": faker.pystr(),
-            "url": faker.url(),
-        }
+        # Data with missing url
+        data = {"stable_id": faker.pystr(), "dataset_id": faker.pystr()}
         message_data = base64.b64encode(json.dumps(data).encode("utf-8")).decode(
             "utf-8"
         )
@@ -97,6 +94,18 @@ def test_extract_bb_exception(self, _):
             self.assertTrue(False)
         except Exception:
             self.assertTrue(True)
+        data = {}  # empty data
+        message_data = base64.b64encode(json.dumps(data).encode("utf-8")).decode(
+            "utf-8"
+        )
+        cloud_event = CloudEvent(
+            attributes=attributes, data={"message": {"data": message_data}}
+        )
+        try:
+            extract_bounding_box_pubsub(cloud_event)
+            self.assertTrue(False)
+        except Exception:
+            self.assertTrue(True)
 
     @mock.patch.dict(
         os.environ,
@@ -107,10 +116,15 @@ def test_extract_bb_exception(self, _):
     @patch("extract_bb.src.main.get_gtfs_feed_bounds")
     @patch("extract_bb.src.main.update_dataset_bounding_box")
     @patch("extract_bb.src.main.Logger")
-    def test_extract_bb(self, _, update_bb_mock, get_gtfs_feed_bounds_mock):
+    @patch("extract_bb.src.main.DatasetTraceService")
+    def test_extract_bb(
+        self, __, mock_dataset_trace, update_bb_mock, get_gtfs_feed_bounds_mock
+    ):
         get_gtfs_feed_bounds_mock.return_value = np.array(
             [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()]
         )
+        mock_dataset_trace.save.return_value = None
+        mock_dataset_trace.get_by_execution_and_stable_ids.return_value = 0
 
         data = {
             "stable_id": faker.pystr(),
@@ -134,6 +148,47 @@ def test_extract_bb(self, _, update_bb_mock, get_gtfs_feed_bounds_mock):
         extract_bounding_box_pubsub(cloud_event)
         update_bb_mock.assert_called_once()
 
+    @mock.patch.dict(
+        os.environ,
+        {
+            "FEEDS_DATABASE_URL": default_db_url,
+            "MAXIMUM_EXECUTIONS": "1",
+        },
+    )
+    @patch("extract_bb.src.main.get_gtfs_feed_bounds")
+    @patch("extract_bb.src.main.update_dataset_bounding_box")
+    @patch("extract_bb.src.main.DatasetTraceService.get_by_execution_and_stable_ids")
+    @patch("extract_bb.src.main.Logger")
+    def test_extract_bb_max_executions(
+        self, _, mock_dataset_trace, update_bb_mock, get_gtfs_feed_bounds_mock
+    ):
+        get_gtfs_feed_bounds_mock.return_value = np.array(
+            [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()]
+        )
+        mock_dataset_trace.return_value = [1, 2, 3]
+
+        data = {
+            "stable_id": faker.pystr(),
+            "dataset_id": faker.pystr(),
+            "url": faker.url(),
+        }
+        message_data = base64.b64encode(json.dumps(data).encode("utf-8")).decode(
+            "utf-8"
+        )
+
+        # Creating attributes for CloudEvent, including required fields
+        attributes = {
+            "type": "com.example.someevent",
+            "source": "https://example.com/event-source",
+        }
+
+        # Constructing the CloudEvent object
+        cloud_event = CloudEvent(
+            attributes=attributes, data={"message": {"data": message_data}}
+        )
+        extract_bounding_box_pubsub(cloud_event)
+        update_bb_mock.assert_not_called()
+
     @mock.patch.dict(
         os.environ,
         {
@@ -142,11 +197,16 @@ def test_extract_bb(self, _, update_bb_mock, get_gtfs_feed_bounds_mock):
     )
     @patch("extract_bb.src.main.get_gtfs_feed_bounds")
     @patch("extract_bb.src.main.update_dataset_bounding_box")
+    @patch("extract_bb.src.main.DatasetTraceService")
     @patch("extract_bb.src.main.Logger")
-    def test_extract_bb_cloud_event(self, _, update_bb_mock, get_gtfs_feed_bounds_mock):
+    def test_extract_bb_cloud_event(
+        self, _, mock_dataset_trace, update_bb_mock, get_gtfs_feed_bounds_mock
+    ):
         get_gtfs_feed_bounds_mock.return_value = np.array(
             [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()]
         )
+        mock_dataset_trace.save.return_value = None
+        mock_dataset_trace.get_by_execution_and_stable_ids.return_value = 0
 
         file_name = faker.file_name()
         resource_name = (
diff --git a/infra/batch/main.tf b/infra/batch/main.tf
index 93d19915a..0808987eb 100644
--- a/infra/batch/main.tf
+++ b/infra/batch/main.tf
@@ -200,32 +200,6 @@ resource "google_cloud_run_service_iam_member" "batch_datasets_cloud_run_invoker
 }
-
-resource "google_datastore_index" "dataset_processing_index_execution_id_stable_id_status" {
-  project = var.project_id
-  kind = "historical_dataset_batch"
-  properties {
-    name = "execution_id"
-    direction = "ASCENDING"
-  }
-  properties {
-    name = "stable_id"
-    direction = "ASCENDING"
-  }
-}
-
-resource "google_datastore_index" "dataset_processing_index_execution_id_timestamp" {
-  project = var.project_id
-  kind = "historical_dataset_batch"
-  properties {
-    name = "execution_id"
-    direction = "ASCENDING"
-  }
-  properties {
-    name = "timestamp"
-    direction = "ASCENDING"
-  }
-}
-
 resource "google_datastore_index" "batch_execution_index_execution_id_timestamp" {
   project = var.project_id
   kind = "batch_execution"
diff --git a/infra/functions-python/main.tf b/infra/functions-python/main.tf
index 03a36f02e..69a253fc7 100644
--- a/infra/functions-python/main.tf
+++ b/infra/functions-python/main.tf
@@ -484,4 +484,11 @@ resource "google_pubsub_topic_iam_binding" "functions_subscriber" {
   role = "roles/pubsub.subscriber"
   topic = google_pubsub_topic.dataset_updates.name
   members = ["serviceAccount:${google_service_account.functions_service_account.email}"]
+}
+
+# Grant permissions to the service account to write/read in datastore
+resource "google_project_iam_member" "datastore_owner" {
+  project = var.project_id
+  role = "roles/datastore.owner"
+  member = "serviceAccount:${google_service_account.functions_service_account.email}"
 }
\ No newline at end of file
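
Note: the updated extract_bounding_box_pubsub handler above expects a base64-encoded JSON Pub/Sub payload containing stable_id, dataset_id, and url, plus an optional execution_id that DatasetTraceService uses to count prior attempts against MAXIMUM_EXECUTIONS. The sketch below is not part of this change; it only illustrates, under assumptions, how such a message could be published for manual testing. The project and topic names are hypothetical placeholders (the Terraform above references a google_pubsub_topic.dataset_updates resource, whose actual topic name may differ).

# Hypothetical publisher sketch; project and topic names are placeholders, not part of this change.
import json
import uuid

from google.cloud import pubsub_v1

publisher = pubsub_v1.PublisherClient()
topic_path = publisher.topic_path("my-gcp-project", "dataset-updates")  # assumed names

payload = {
    "stable_id": "mdb-1",                   # required by the handler
    "dataset_id": "mdb-1-202401010000",     # required by the handler
    "url": "https://example.com/gtfs.zip",  # required by the handler
    # Optional: reuse the same execution_id across retries so the handler can
    # count prior attempts for this stable_id against MAXIMUM_EXECUTIONS.
    "execution_id": str(uuid.uuid4()),
}

# Pub/Sub delivers these bytes base64-encoded inside the CloudEvent, which the
# handler then decodes and parses as JSON.
future = publisher.publish(topic_path, data=json.dumps(payload).encode("utf-8"))
print(f"Published message ID: {future.result()}")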