Skip to content

Commit

Permalink
Merge pull request #837 from MobilityData/293-Psycopg
Browse files Browse the repository at this point in the history
feat: Use psycopg2 for Connection Pooling and Implement Global Engine with Context Manager for Session Management
  • Loading branch information
qcdyx authored Dec 17, 2024
2 parents b0d5f89 + b95e99b commit 2aec616
Show file tree
Hide file tree
Showing 42 changed files with 706 additions and 875 deletions.
301 changes: 77 additions & 224 deletions api/src/database/database.py

Large diffs are not rendered by default.

15 changes: 7 additions & 8 deletions api/src/feeds/impl/datasets_api_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from geoalchemy2 import WKTElement
from sqlalchemy import or_
from sqlalchemy.orm import Query
from sqlalchemy.orm import Query, Session

from database.database import Database
from database.database import Database, with_db_session
from database_gen.sqlacodegen_models import (
Gtfsdataset,
Feed,
Expand Down Expand Up @@ -93,9 +93,10 @@ def apply_bounding_filtering(
raise_http_validation_error(invalid_bounding_method.format(bounding_filter_method))

@staticmethod
def get_datasets_gtfs(query: Query, limit: int = None, offset: int = None) -> List[GtfsDataset]:
def get_datasets_gtfs(query: Query, session: Session, limit: int = None, offset: int = None) -> List[GtfsDataset]:
# Results are sorted by stable_id because Database.select(group_by=) requires it so
dataset_groups = Database().select(
session=session,
query=query.order_by(Gtfsdataset.stable_id),
limit=limit,
offset=offset,
Expand All @@ -109,15 +110,13 @@ def get_datasets_gtfs(query: Query, limit: int = None, offset: int = None) -> Li
gtfs_datasets.append(GtfsDatasetImpl.from_orm(dataset_objects[0]))
return gtfs_datasets

def get_dataset_gtfs(
self,
id: str,
) -> GtfsDataset:
@with_db_session
def get_dataset_gtfs(self, id: str, db_session: Session) -> GtfsDataset:
"""Get the specified dataset from the Mobility Database."""

query = DatasetsApiImpl.create_dataset_query().filter(Gtfsdataset.stable_id == id)

if (ret := DatasetsApiImpl.get_datasets_gtfs(query)) and len(ret) == 1:
if (ret := DatasetsApiImpl.get_datasets_gtfs(query, db_session)) and len(ret) == 1:
return ret[0]
else:
raise_http_error(404, dataset_not_found.format(id))
84 changes: 39 additions & 45 deletions api/src/feeds/impl/feeds_api_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import joinedload, Session
from sqlalchemy.orm.query import Query

from database.database import Database
from database.database import Database, with_db_session
from database_gen.sqlacodegen_models import (
Feed,
Gtfsdataset,
Expand Down Expand Up @@ -63,17 +63,15 @@ class FeedsApiImpl(BaseFeedsApi):
def __init__(self) -> None:
self.logger = Logger("FeedsApiImpl").get_logger()

def get_feed(
self,
id: str,
) -> BasicFeed:
@with_db_session
def get_feed(self, id: str, db_session: Session) -> BasicFeed:
"""Get the specified feed from the Mobility Database."""
is_email_restricted = is_user_email_restricted()
self.logger.info(f"User email is restricted: {is_email_restricted}")

feed = (
FeedFilter(stable_id=id, provider__ilike=None, producer_url__ilike=None, status=None)
.filter(Database().get_query_model(Feed))
.filter(Database().get_query_model(db_session, Feed))
.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
.filter(
or_(
Expand All @@ -89,6 +87,7 @@ def get_feed(
else:
raise_http_error(404, feed_not_found.format(id))

@with_db_session
def get_feeds(
self,
limit: int,
Expand All @@ -97,14 +96,15 @@ def get_feeds(
provider: str,
producer_url: str,
is_official: bool,
db_session: Session,
) -> List[BasicFeed]:
"""Get some (or all) feeds from the Mobility Database."""
is_email_restricted = is_user_email_restricted()
self.logger.info(f"User email is restricted: {is_email_restricted}")
feed_filter = FeedFilter(
status=status, provider__ilike=provider, producer_url__ilike=producer_url, stable_id=None
)
feed_query = feed_filter.filter(Database().get_query_model(Feed))
feed_query = feed_filter.filter(Database().get_query_model(db_session, Feed))
if is_official:
feed_query = feed_query.filter(Feed.official)
feed_query = feed_query.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
Expand All @@ -126,27 +126,25 @@ def get_feeds(
results = feed_query.all()
return [BasicFeedImpl.from_orm(feed) for feed in results]

def get_gtfs_feed(
self,
id: str,
) -> GtfsFeed:
@with_db_session
def get_gtfs_feed(self, id: str, db_session: Session) -> GtfsFeed:
"""Get the specified gtfs feed from the Mobility Database."""
feed, translations = self._get_gtfs_feed(id)
feed, translations = self._get_gtfs_feed(id, db_session)
if feed:
return GtfsFeedImpl.from_orm(feed, translations)
else:
raise_http_error(404, gtfs_feed_not_found.format(id))

@staticmethod
def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationTranslation]]:
def _get_gtfs_feed(stable_id: str, db_session: Session) -> tuple[Gtfsfeed | None, dict[str, LocationTranslation]]:
results = (
FeedFilter(
stable_id=stable_id,
status=None,
provider__ilike=None,
producer_url__ilike=None,
)
.filter(Database().get_session().query(Gtfsfeed, t_location_with_translations_en))
.filter(db_session.query(Gtfsfeed, t_location_with_translations_en))
.filter(
or_(
Gtfsfeed.operational_status == None, # noqa: E711
Expand All @@ -168,6 +166,7 @@ def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationT
return results[0].Gtfsfeed, translations
return None, {}

@with_db_session
def get_gtfs_feed_datasets(
self,
gtfs_feed_id: str,
Expand All @@ -176,6 +175,7 @@ def get_gtfs_feed_datasets(
offset: int,
downloaded_after: str,
downloaded_before: str,
db_session: Session,
) -> List[GtfsDataset]:
"""Get a list of datasets related to a feed."""
if downloaded_before and not valid_iso_date(downloaded_before):
Expand All @@ -191,7 +191,7 @@ def get_gtfs_feed_datasets(
provider__ilike=None,
producer_url__ilike=None,
)
.filter(Database().get_query_model(Gtfsfeed))
.filter(Database().get_query_model(db_session, Gtfsfeed))
.filter(
or_(
Feed.operational_status == None, # noqa: E711
Expand All @@ -208,19 +208,20 @@ def get_gtfs_feed_datasets(
# Replace Z with +00:00 to make the datetime object timezone aware
# Due to https://github.com/python/cpython/issues/80010, once migrate to Python 3.11, we can use fromisoformat
query = GtfsDatasetFilter(
downloaded_at__lte=datetime.fromisoformat(downloaded_before.replace("Z", "+00:00"))
if downloaded_before
else None,
downloaded_at__gte=datetime.fromisoformat(downloaded_after.replace("Z", "+00:00"))
if downloaded_after
else None,
downloaded_at__lte=(
datetime.fromisoformat(downloaded_before.replace("Z", "+00:00")) if downloaded_before else None
),
downloaded_at__gte=(
datetime.fromisoformat(downloaded_after.replace("Z", "+00:00")) if downloaded_after else None
),
).filter(DatasetsApiImpl.create_dataset_query().filter(Feed.stable_id == gtfs_feed_id))

if latest:
query = query.filter(Gtfsdataset.latest)

return DatasetsApiImpl.get_datasets_gtfs(query, limit=limit, offset=offset)
return DatasetsApiImpl.get_datasets_gtfs(query, session=db_session, limit=limit, offset=offset)

@with_db_session
def get_gtfs_feeds(
self,
limit: int,
Expand All @@ -234,6 +235,7 @@ def get_gtfs_feeds(
dataset_longitudes: str,
bounding_filter_method: str,
is_official: bool,
db_session: Session,
) -> List[GtfsFeed]:
"""Get some (or all) GTFS feeds from the Mobility Database."""
gtfs_feed_filter = GtfsFeedFilter(
Expand All @@ -255,9 +257,7 @@ def get_gtfs_feeds(
is_email_restricted = is_user_email_restricted()
self.logger.info(f"User email is restricted: {is_email_restricted}")
feed_query = (
Database()
.get_session()
.query(Gtfsfeed)
db_session.query(Gtfsfeed)
.filter(Gtfsfeed.id.in_(subquery))
.filter(
or_(
Expand All @@ -277,12 +277,10 @@ def get_gtfs_feeds(
if is_official:
feed_query = feed_query.filter(Feed.official)
feed_query = feed_query.limit(limit).offset(offset)
return self._get_response(feed_query, GtfsFeedImpl)
return self._get_response(feed_query, GtfsFeedImpl, db_session)

def get_gtfs_rt_feed(
self,
id: str,
) -> GtfsRTFeed:
@with_db_session
def get_gtfs_rt_feed(self, id: str, db_session: Session) -> GtfsRTFeed:
"""Get the specified GTFS Realtime feed from the Mobility Database."""
gtfs_rt_feed_filter = GtfsRtFeedFilter(
stable_id=id,
Expand All @@ -292,9 +290,7 @@ def get_gtfs_rt_feed(
location=None,
)
results = gtfs_rt_feed_filter.filter(
Database()
.get_session()
.query(Gtfsrealtimefeed, t_location_with_translations_en)
db_session.query(Gtfsrealtimefeed, t_location_with_translations_en)
.filter(
or_(
Gtfsrealtimefeed.operational_status == None, # noqa: E711
Expand All @@ -317,6 +313,7 @@ def get_gtfs_rt_feed(
else:
raise_http_error(404, gtfs_rt_feed_not_found.format(id))

@with_db_session
def get_gtfs_rt_feeds(
self,
limit: int,
Expand All @@ -328,6 +325,7 @@ def get_gtfs_rt_feeds(
subdivision_name: str,
municipality: str,
is_official: bool,
db_session: Session,
) -> List[GtfsRTFeed]:
"""Get some (or all) GTFS Realtime feeds from the Mobility Database."""
entity_types_list = entity_types.split(",") if entity_types else None
Expand Down Expand Up @@ -359,9 +357,7 @@ def get_gtfs_rt_feeds(
.join(Entitytype, Gtfsrealtimefeed.entitytypes)
).subquery()
feed_query = (
Database()
.get_session()
.query(Gtfsrealtimefeed)
db_session.query(Gtfsrealtimefeed)
.filter(Gtfsrealtimefeed.id.in_(subquery))
.filter(
or_(
Expand All @@ -380,22 +376,20 @@ def get_gtfs_rt_feeds(
if is_official:
feed_query = feed_query.filter(Feed.official)
feed_query = feed_query.limit(limit).offset(offset)
return self._get_response(feed_query, GtfsRTFeedImpl)
return self._get_response(feed_query, GtfsRTFeedImpl, db_session)

@staticmethod
def _get_response(feed_query: Query, impl_cls: type[T]) -> List[T]:
def _get_response(feed_query: Query, impl_cls: type[T], db_session: "Session") -> List[T]:
"""Get the response for the feed query."""
results = feed_query.all()
location_translations = get_feeds_location_translations(results)
location_translations = get_feeds_location_translations(results, db_session)
response = [impl_cls.from_orm(feed, location_translations) for feed in results]
return list({feed.id: feed for feed in response}.values())

def get_gtfs_feed_gtfs_rt_feeds(
self,
id: str,
) -> List[GtfsRTFeed]:
@with_db_session
def get_gtfs_feed_gtfs_rt_feeds(self, id: str, db_session: Session) -> List[GtfsRTFeed]:
"""Get a list of GTFS Realtime related to a GTFS feed."""
feed, translations = self._get_gtfs_feed(id)
feed, translations = self._get_gtfs_feed(id, db_session)
if feed:
return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed, translations) for gtfs_rt_feed in feed.gtfs_rt_feeds]
else:
Expand Down
8 changes: 6 additions & 2 deletions api/src/feeds/impl/search_api_impl.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import List

from sqlalchemy import func, select
from sqlalchemy.orm import Query
from sqlalchemy.orm import Query, Session

from database.database import Database
from database.database import Database, with_db_session
from database.sql_functions.unaccent import unaccent
from database_gen.sqlacodegen_models import t_feedsearch
from feeds.impl.models.search_feed_item_result_impl import SearchFeedItemResultImpl
Expand Down Expand Up @@ -93,6 +93,7 @@ def create_search_query(
query = SearchApiImpl.add_search_query_filters(query, search_query, data_type, feed_id, status, is_official)
return query.order_by(rank_expression.desc())

@with_db_session
def search_feeds(
self,
limit: int,
Expand All @@ -102,15 +103,18 @@ def search_feeds(
data_type: str,
is_official: bool,
search_query: str,
db_session: "Session",
) -> SearchFeeds200Response:
"""Search feeds using full-text search on feed, location and provider's information."""
query = self.create_search_query(status, feed_id, data_type, is_official, search_query)
feed_rows = Database().select(
session=db_session,
query=query,
limit=limit,
offset=offset,
)
feed_total_count = Database().select(
session=db_session,
query=self.create_count_search_query(status, feed_id, data_type, is_official, search_query),
)
if feed_rows is None or feed_total_count is None:
Expand Down
11 changes: 8 additions & 3 deletions api/src/scripts/populate_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
from pathlib import Path
from typing import Type
from typing import Type, TYPE_CHECKING

import pandas
from dotenv import load_dotenv
Expand All @@ -11,6 +11,9 @@
from database_gen.sqlacodegen_models import Feed, Gtfsrealtimefeed, Gtfsfeed, Gbfsfeed
from utils.logger import Logger

if TYPE_CHECKING:
from sqlalchemy.orm import Session

logging.basicConfig()
logging.getLogger("sqlalchemy.engine").setLevel(logging.ERROR)

Expand Down Expand Up @@ -56,12 +59,14 @@ def __init__(self, filepaths):

self.filter_data()

def query_feed_by_stable_id(self, stable_id: str, data_type: str | None) -> Gtfsrealtimefeed | Gtfsfeed | None:
def query_feed_by_stable_id(
self, session: "Session", stable_id: str, data_type: str | None
) -> Gtfsrealtimefeed | Gtfsfeed | None:
"""
Query the feed by stable id
"""
model = self.get_model(data_type)
return self.db.session.query(model).filter(model.stable_id == stable_id).first()
return session.query(model).filter(model.stable_id == stable_id).first()

@staticmethod
def get_model(data_type: str | None) -> Type[Feed]:
Expand Down
16 changes: 9 additions & 7 deletions api/src/scripts/populate_db_gbfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,16 @@ def deprecate_feeds(self, deprecated_feeds):
if deprecated_feeds is None or deprecated_feeds.empty:
self.logger.info("No feeds to deprecate.")
return

self.logger.info(f"Deprecating {len(deprecated_feeds)} feed(s).")
for index, row in deprecated_feeds.iterrows():
stable_id = self.get_stable_id(row)
gbfs_feed = self.query_feed_by_stable_id(stable_id, "gbfs")
if gbfs_feed:
self.logger.info(f"Deprecating feed with stable_id={stable_id}")
gbfs_feed.status = "deprecated"
self.db.session.flush()
with self.db.start_db_session() as session:
for index, row in deprecated_feeds.iterrows():
stable_id = self.get_stable_id(row)
gbfs_feed = self.query_feed_by_stable_id(session, stable_id, "gbfs")
if gbfs_feed:
self.logger.info(f"Deprecating feed with stable_id={stable_id}")
gbfs_feed.status = "deprecated"
session.flush()

def populate_db(self):
"""Populate the database with the GBFS feeds"""
Expand Down
Loading

0 comments on commit 2aec616

Please sign in to comment.