feat: added trace to bb extraction
cka-y committed Jul 25, 2024
1 parent af1e388 commit 2f3b06b
Showing 7 changed files with 164 additions and 69 deletions.
15 changes: 14 additions & 1 deletion functions-python/dataset_service/main.py
@@ -22,6 +22,7 @@
from google.cloud import datastore
from google.cloud.datastore import Client


# This file contains the dataset trace and batch execution models and services.
# The dataset trace stores the processing trace of a dataset within a batch execution;
# one batch execution can have multiple dataset traces.
@@ -32,8 +33,16 @@
# Status of the dataset trace
class Status(Enum):
FAILED = "FAILED"
SUCCESS = "SUCCESS"
PUBLISHED = "PUBLISHED"
NOT_PUBLISHED = "NOT_PUBLISHED"
PROCESSING = "PROCESSING"


# Stage of the pipeline
class PipelineStage(Enum):
DATASET_PROCESSING = "DATASET_PROCESSING"
LOCATION_EXTRACTION = "LOCATION_EXTRACTION"


# Dataset trace class to store the trace of a dataset
@@ -42,10 +51,12 @@ class DatasetTrace:
stable_id: str
status: Status
timestamp: datetime
dataset_id: Optional[str] = None
trace_id: Optional[str] = None
execution_id: Optional[str] = None
file_sha256_hash: Optional[str] = None
hosted_url: Optional[str] = None
pipeline_stage: PipelineStage = PipelineStage.DATASET_PROCESSING
error_message: Optional[str] = None


@@ -98,7 +109,9 @@ def get_by_execution_and_stable_ids(

# Transform the dataset trace to a Datastore entity
def _dataset_trace_to_entity(self, dataset_trace: DatasetTrace) -> datastore.Entity:
trace_id = str(uuid.uuid4())
trace_id = (
str(uuid.uuid4()) if not dataset_trace.trace_id else dataset_trace.trace_id
)
key = self.client.key(dataset_trace_collection, trace_id)
entity = datastore.Entity(key=key)

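As an aside (not part of the diff), here is a minimal sketch of how the updated model could be exercised, assuming dataset_service.main is importable and Datastore credentials are configured; the stable and dataset ids are purely illustrative.

import uuid
from datetime import datetime

from dataset_service.main import (
    DatasetTrace,
    DatasetTraceService,
    PipelineStage,
    Status,
)

# Setting trace_id explicitly means a later save() updates the same Datastore
# entity instead of creating a new one (a UUID is only generated when trace_id
# is missing, per _dataset_trace_to_entity above).
trace = DatasetTrace(
    trace_id=str(uuid.uuid4()),
    stable_id="mdb-1",                # illustrative
    dataset_id="mdb-1-202407250000",  # illustrative
    execution_id=str(uuid.uuid4()),
    status=Status.PROCESSING,
    timestamp=datetime.now(),
    pipeline_stage=PipelineStage.LOCATION_EXTRACTION,
)

service = DatasetTraceService()
service.save(trace)

# Later, record the outcome on the same trace.
trace.status = Status.SUCCESS
service.save(trace)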
1 change: 1 addition & 0 deletions functions-python/extract_bb/.coveragerc
@@ -3,6 +3,7 @@ omit =
*/test*/*
*/helpers/*
*/database_gen/*
*/dataset_service/*

[report]
exclude_lines =
2 changes: 1 addition & 1 deletion functions-python/extract_bb/function_config.json
@@ -5,7 +5,7 @@
"timeout": 540,
"memory": "8Gi",
"trigger_http": false,
"include_folders": ["database_gen", "helpers"],
"include_folders": ["database_gen", "helpers", "dataset_service"],
"environment_variables": [],
"secret_environment_variables": [
{
108 changes: 74 additions & 34 deletions functions-python/extract_bb/src/main.py
@@ -2,6 +2,8 @@
import json
import logging
import os
import uuid
from datetime import datetime

import functions_framework
import gtfs_kit
@@ -13,18 +15,16 @@
from database_gen.sqlacodegen_models import Gtfsdataset
from helpers.database import start_db_session
from helpers.logger import Logger
from dataset_service.main import (
DatasetTraceService,
DatasetTrace,
Status,
PipelineStage,
)

logging.basicConfig(level=logging.INFO)


class Request:
def __init__(self, json):
self.json = json

def get_json(self):
return self.json


def parse_resource_data(data: dict) -> tuple:
"""
Parse the cloud event data to extract resource information.
@@ -101,6 +101,10 @@ def extract_bounding_box_pubsub(cloud_event: CloudEvent):
@param cloud_event: The CloudEvent containing the Pub/Sub message.
"""
Logger.init_logger()
try:
maximum_executions = int(os.getenv("MAXIMUM_EXECUTIONS", 1))
except ValueError:
maximum_executions = 1
data = cloud_event.data
logging.info(f"Function triggered with Pub/Sub event data: {data}")

@@ -110,7 +114,7 @@ def extract_bounding_box_pubsub(cloud_event: CloudEvent):
message_json = json.loads(base64.b64decode(message_data).decode("utf-8"))
except Exception as e:
logging.error(f"Error parsing message data: {e}")
return "Invalid Pub/Sub message data.", 400
return "Invalid Pub/Sub message data."

logging.info(f"Parsed message data: {message_json}")

@@ -121,37 +125,73 @@ def extract_bounding_box_pubsub(cloud_event: CloudEvent):
or "url" not in message_json
):
logging.error("Invalid message data.")
return (
"Invalid message data. Expected 'stable_id', 'dataset_id', and 'url' in the message.",
400,
)
return "Invalid message data. Expected 'stable_id', 'dataset_id', and 'url' in the message."

stable_id = message_json["stable_id"]
dataset_id = message_json["dataset_id"]
url = message_json["url"]

logging.info(f"[{dataset_id}] accessing url: {url}")
try:
bounds = get_gtfs_feed_bounds(url, dataset_id)
except Exception as e:
return f"Error processing GTFS feed: {e}", 500
logging.info(f"[{dataset_id}] extracted bounding box = {bounds}")

geometry_polygon = create_polygon_wkt_element(bounds)

session = None
execution_id = message_json.get("execution_id", None)
if execution_id is None:
logging.warning(f"[{dataset_id}] No execution ID found in message")
execution_id = str(uuid.uuid4())
logging.info(f"[{dataset_id}] Generated execution ID: {execution_id}")
trace_service = DatasetTraceService()
trace = trace_service.get_by_execution_and_stable_ids(execution_id, stable_id)
logging.info(f"[{dataset_id}] Trace: {trace}")
executions = len(trace) if trace else 0
logging.info(f"[{dataset_id}] Executions: {executions}")
if executions > 0 and executions >= maximum_executions:
logging.warning(
f"[{dataset_id}] Maximum executions reached. Skipping processing."
)
return f"Maximum executions reached for {dataset_id}."
trace_id = str(uuid.uuid4())
error = None
# Saving trace before starting in case we run into memory problems or uncatchable errors
trace = DatasetTrace(
trace_id=trace_id,
stable_id=stable_id,
execution_id=execution_id,
status=Status.PROCESSING,
timestamp=datetime.now(),
hosted_url=url,
dataset_id=dataset_id,
pipeline_stage=PipelineStage.LOCATION_EXTRACTION,
)
trace_service.save(trace)
try:
session = start_db_session(os.getenv("FEEDS_DATABASE_URL"))
update_dataset_bounding_box(session, dataset_id, geometry_polygon)
except Exception as e:
logging.error(f"[{dataset_id}] Error while processing: {e}")
if session is not None:
session.rollback()
raise e
logging.info(f"[{dataset_id}] accessing url: {url}")
try:
bounds = get_gtfs_feed_bounds(url, dataset_id)
except Exception as e:
error = f"Error processing GTFS feed: {e}"
raise e
logging.info(f"[{dataset_id}] extracted bounding box = {bounds}")

geometry_polygon = create_polygon_wkt_element(bounds)

session = None
try:
session = start_db_session(os.getenv("FEEDS_DATABASE_URL"))
update_dataset_bounding_box(session, dataset_id, geometry_polygon)
except Exception as e:
error = f"Error updating bounding box in database: {e}"
logging.error(f"[{dataset_id}] Error while processing: {e}")
if session is not None:
session.rollback()
raise e
finally:
if session is not None:
session.close()
logging.info(f"[{stable_id} - {dataset_id}] Bounding box updated successfully.")
except Exception:
pass
finally:
if session is not None:
session.close()
logging.info(f"[{stable_id} - {dataset_id}] Bounding box updated successfully.")
trace.status = Status.FAILED if error is not None else Status.SUCCESS
trace.error_message = error
trace_service.save(trace)


@functions_framework.cloud_event
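To make the control flow above easier to follow, here is a condensed sketch of the retry-guard pattern the handler now implements. It is illustrative only: do_extraction is a placeholder rather than a real helper in the repo, and the MAXIMUM_EXECUTIONS parsing and error handling are trimmed.

import os
import uuid
from datetime import datetime
from typing import Optional

from dataset_service.main import (
    DatasetTrace,
    DatasetTraceService,
    PipelineStage,
    Status,
)


def do_extraction(url: str, dataset_id: str) -> None:
    """Placeholder for get_gtfs_feed_bounds + update_dataset_bounding_box."""
    raise NotImplementedError


def process_with_trace(
    stable_id: str, dataset_id: str, url: str, execution_id: Optional[str]
) -> str:
    """Condensed version of the guard used in extract_bounding_box_pubsub."""
    maximum_executions = int(os.getenv("MAXIMUM_EXECUTIONS", 1))
    execution_id = execution_id or str(uuid.uuid4())

    trace_service = DatasetTraceService()
    prior = trace_service.get_by_execution_and_stable_ids(execution_id, stable_id)
    if prior and len(prior) >= maximum_executions:
        return f"Maximum executions reached for {dataset_id}."

    # Persist a PROCESSING trace before doing any work so the attempt is
    # recorded even if the function dies from an uncatchable error.
    trace = DatasetTrace(
        trace_id=str(uuid.uuid4()),
        stable_id=stable_id,
        dataset_id=dataset_id,
        execution_id=execution_id,
        status=Status.PROCESSING,
        timestamp=datetime.now(),
        hosted_url=url,
        pipeline_stage=PipelineStage.LOCATION_EXTRACTION,
    )
    trace_service.save(trace)

    error = None
    try:
        do_extraction(url, dataset_id)
    except Exception as e:
        error = str(e)

    # Record the outcome on the same trace (same trace_id, so the entity is updated).
    trace.status = Status.FAILED if error is not None else Status.SUCCESS
    trace.error_message = error
    trace_service.save(trace)
    return error or "OK"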
74 changes: 67 additions & 7 deletions functions-python/extract_bb/tests/test_extract_bb.py
@@ -72,11 +72,8 @@ def test_get_gtfs_feed_bounds(self, mock_gtfs_kit):

@patch("extract_bb.src.main.Logger")
def test_extract_bb_exception(self, _):
data = {
"stable_id": faker.pystr(),
"dataset_id": faker.pystr(),
"url": faker.url(),
}
# Data with missing url
data = {"stable_id": faker.pystr(), "dataset_id": faker.pystr()}
message_data = base64.b64encode(json.dumps(data).encode("utf-8")).decode(
"utf-8"
)
@@ -97,6 +94,18 @@ def test_extract_bb_exception(self, _):
self.assertTrue(False)
except Exception:
self.assertTrue(True)
data = {} # empty data
message_data = base64.b64encode(json.dumps(data).encode("utf-8")).decode(
"utf-8"
)
cloud_event = CloudEvent(
attributes=attributes, data={"message": {"data": message_data}}
)
try:
extract_bounding_box_pubsub(cloud_event)
self.assertTrue(False)
except Exception:
self.assertTrue(True)

@mock.patch.dict(
os.environ,
@@ -107,10 +116,15 @@ def test_extract_bb_exception(self, _):
@patch("extract_bb.src.main.get_gtfs_feed_bounds")
@patch("extract_bb.src.main.update_dataset_bounding_box")
@patch("extract_bb.src.main.Logger")
def test_extract_bb(self, _, update_bb_mock, get_gtfs_feed_bounds_mock):
@patch("extract_bb.src.main.DatasetTraceService")
def test_extract_bb(
self, mock_dataset_trace, __, update_bb_mock, get_gtfs_feed_bounds_mock
):
get_gtfs_feed_bounds_mock.return_value = np.array(
[faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()]
)
mock_dataset_trace.save.return_value = None
mock_dataset_trace.get_by_execution_and_stable_ids.return_value = 0

data = {
"stable_id": faker.pystr(),
@@ -134,6 +148,47 @@ def test_extract_bb(self, _, update_bb_mock, get_gtfs_feed_bounds_mock):
extract_bounding_box_pubsub(cloud_event)
update_bb_mock.assert_called_once()

@mock.patch.dict(
os.environ,
{
"FEEDS_DATABASE_URL": default_db_url,
"MAXIMUM_EXECUTIONS": "1",
},
)
@patch("extract_bb.src.main.get_gtfs_feed_bounds")
@patch("extract_bb.src.main.update_dataset_bounding_box")
@patch("extract_bb.src.main.DatasetTraceService.get_by_execution_and_stable_ids")
@patch("extract_bb.src.main.Logger")
def test_extract_bb_max_executions(
self, _, mock_dataset_trace, update_bb_mock, get_gtfs_feed_bounds_mock
):
get_gtfs_feed_bounds_mock.return_value = np.array(
[faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()]
)
mock_dataset_trace.return_value = [1, 2, 3]

data = {
"stable_id": faker.pystr(),
"dataset_id": faker.pystr(),
"url": faker.url(),
}
message_data = base64.b64encode(json.dumps(data).encode("utf-8")).decode(
"utf-8"
)

# Creating attributes for CloudEvent, including required fields
attributes = {
"type": "com.example.someevent",
"source": "https://example.com/event-source",
}

# Constructing the CloudEvent object
cloud_event = CloudEvent(
attributes=attributes, data={"message": {"data": message_data}}
)
extract_bounding_box_pubsub(cloud_event)
update_bb_mock.assert_not_called()

@mock.patch.dict(
os.environ,
{
@@ -142,11 +197,16 @@ def test_extract_bb(self, _, update_bb_mock, get_gtfs_feed_bounds_mock):
)
@patch("extract_bb.src.main.get_gtfs_feed_bounds")
@patch("extract_bb.src.main.update_dataset_bounding_box")
@patch("extract_bb.src.main.DatasetTraceService")
@patch("extract_bb.src.main.Logger")
def test_extract_bb_cloud_event(self, _, update_bb_mock, get_gtfs_feed_bounds_mock):
def test_extract_bb_cloud_event(
self, _, mock_dataset_trace, update_bb_mock, get_gtfs_feed_bounds_mock
):
get_gtfs_feed_bounds_mock.return_value = np.array(
[faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()]
)
mock_dataset_trace.save.return_value = None
mock_dataset_trace.get_by_execution_and_stable_ids.return_value = 0

file_name = faker.file_name()
resource_name = (
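One detail worth keeping in mind when reading the stacked @patch decorators in these tests: unittest.mock applies decorators bottom-up, so the decorator closest to the def supplies the first mock argument after self. A small standalone illustration, unrelated to this repository's modules:

import os
import uuid
from unittest import TestCase, main
from unittest.mock import patch


class PatchOrderExample(TestCase):
    @patch("os.getenv")   # outermost decorator -> last mock argument
    @patch("uuid.uuid4")  # innermost decorator -> first mock argument
    def test_order(self, mock_uuid4, mock_getenv):
        # The decorator closest to the function is applied first,
        # so its mock is the first parameter after self.
        mock_uuid4.return_value = "fixed-id"
        mock_getenv.return_value = "1"
        self.assertEqual(uuid.uuid4(), "fixed-id")
        self.assertEqual(os.getenv("MAXIMUM_EXECUTIONS"), "1")


if __name__ == "__main__":
    main()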
26 changes: 0 additions & 26 deletions infra/batch/main.tf
@@ -200,32 +200,6 @@ resource "google_cloud_run_service_iam_member" "batch_datasets_cloud_run_invoker
}


resource "google_datastore_index" "dataset_processing_index_execution_id_stable_id_status" {
project = var.project_id
kind = "historical_dataset_batch"
properties {
name = "execution_id"
direction = "ASCENDING"
}
properties {
name = "stable_id"
direction = "ASCENDING"
}
}

resource "google_datastore_index" "dataset_processing_index_execution_id_timestamp" {
project = var.project_id
kind = "historical_dataset_batch"
properties {
name = "execution_id"
direction = "ASCENDING"
}
properties {
name = "timestamp"
direction = "ASCENDING"
}
}

resource "google_datastore_index" "batch_execution_index_execution_id_timestamp" {
project = var.project_id
kind = "batch_execution"
7 changes: 7 additions & 0 deletions infra/functions-python/main.tf
@@ -484,4 +484,11 @@ resource "google_pubsub_topic_iam_binding" "functions_subscriber" {
role = "roles/pubsub.subscriber"
topic = google_pubsub_topic.dataset_updates.name
members = ["serviceAccount:${google_service_account.functions_service_account.email}"]
}

# Grant permissions to the service account to write/read in datastore
resource "google_project_iam_member" "datastore_owner" {
project = var.project_id
role = "roles/datastore.owner"
member = "serviceAccount:${google_service_account.functions_service_account.email}"
}
