Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: automate locations - modified API responses #661

Merged
merged 13 commits into from
Aug 13, 2024
3 changes: 2 additions & 1 deletion api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,5 @@ cloud-sql-python-connector[pg8000]
fastapi-filter[sqlalchemy]==1.0.0
PyJWT
shapely
google-cloud-pubsub
google-cloud-pubsub
pycountry
153 changes: 91 additions & 62 deletions api/src/feeds/impl/feeds_api_impl.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from datetime import datetime
from typing import List, Union
from typing import List, Union, TypeVar

from sqlalchemy.orm import joinedload

from sqlalchemy.orm.query import Query
from database.database import Database
from database_gen.sqlacodegen_models import (
Feed,
Expand All @@ -12,6 +12,7 @@
Location,
Validationreport,
Entitytype,
t_location_with_translations_en,
)
from feeds.filters.feed_filter import FeedFilter
from feeds.filters.gtfs_dataset_filter import GtfsDatasetFilter
Expand All @@ -36,6 +37,9 @@
from feeds_gen.models.gtfs_feed import GtfsFeed
from feeds_gen.models.gtfs_rt_feed import GtfsRTFeed
from utils.date_utils import valid_iso_date
from utils.location_translation import create_location_translation_object, LocationTranslation

T = TypeVar("T", bound="BasicFeed")


class FeedsApiImpl(BaseFeedsApi):
Expand Down Expand Up @@ -91,20 +95,35 @@ def get_gtfs_feed(
id: str,
) -> GtfsFeed:
"""Get the specified gtfs feed from the Mobility Database."""
feed = (
feed, translations = self._get_gtfs_feed(id)
if feed:
return GtfsFeedImpl.from_orm(feed, translations)
else:
raise_http_error(404, gtfs_feed_not_found.format(id))

@staticmethod
def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationTranslation]]:
results = (
FeedFilter(
stable_id=id,
stable_id=stable_id,
status=None,
provider__ilike=None,
producer_url__ilike=None,
)
.filter(Database().get_query_model(Gtfsfeed))
.first()
)
if feed:
return GtfsFeedImpl.from_orm(feed)
else:
raise_http_error(404, gtfs_feed_not_found.format(id))
.filter(Database().get_session().query(Gtfsfeed, t_location_with_translations_en))
.outerjoin(Location, Feed.locations)
.outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
.options(
joinedload(Gtfsfeed.gtfsdatasets)
.joinedload(Gtfsdataset.validation_reports)
.joinedload(Validationreport.notices),
*BasicFeedImpl.get_joinedload_options(),
)
).all()
if len(results) > 0 and results[0].Gtfsfeed:
translations = {result[1]: create_location_translation_object(result) for result in results}
return results[0].Gtfsfeed, translations
return None, {}

def get_gtfs_feed_datasets(
self,
Expand Down Expand Up @@ -176,43 +195,54 @@ def get_gtfs_feeds(
municipality__ilike=municipality,
),
)
gtfs_feed_query = gtfs_feed_filter.filter(Database().get_query_model(Gtfsfeed))

gtfs_feed_query = gtfs_feed_query.outerjoin(Location, Feed.locations).options(
joinedload(Gtfsfeed.gtfsdatasets)
.joinedload(Gtfsdataset.validation_reports)
.joinedload(Validationreport.notices),
*BasicFeedImpl.get_joinedload_options(),
gtfs_feed_query = gtfs_feed_filter.filter(
Database().get_session().query(Gtfsfeed, t_location_with_translations_en)
)
gtfs_feed_query = (
gtfs_feed_query.outerjoin(Location, Feed.locations)
.outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
.options(
joinedload(Gtfsfeed.gtfsdatasets)
.joinedload(Gtfsdataset.validation_reports)
.joinedload(Validationreport.notices),
*BasicFeedImpl.get_joinedload_options(),
)
.order_by(Gtfsfeed.provider, Gtfsfeed.stable_id)
)
gtfs_feed_query = gtfs_feed_query.order_by(Gtfsfeed.provider, Gtfsfeed.stable_id)
gtfs_feed_query = DatasetsApiImpl.apply_bounding_filtering(
gtfs_feed_query, dataset_latitudes, dataset_longitudes, bounding_filter_method
)
if limit is not None:
gtfs_feed_query = gtfs_feed_query.limit(limit)
if offset is not None:
gtfs_feed_query = gtfs_feed_query.offset(offset)
results = gtfs_feed_query.all()
return [GtfsFeedImpl.from_orm(gtfs_feed) for gtfs_feed in results]
return self._get_response(gtfs_feed_query, limit, offset, GtfsFeedImpl)

def get_gtfs_rt_feed(
self,
id: str,
) -> GtfsRTFeed:
"""Get the specified GTFS Realtime feed from the Mobility Database."""
feed = (
GtfsRtFeedFilter(
stable_id=id,
provider__ilike=None,
producer_url__ilike=None,
entity_types=None,
location=None,
)
.filter(Database().get_query_model(Gtfsrealtimefeed))
.first()
gtfs_rt_feed_filter = GtfsRtFeedFilter(
stable_id=id,
provider__ilike=None,
producer_url__ilike=None,
entity_types=None,
location=None,
)
if feed:
return GtfsRTFeedImpl.from_orm(feed)
results = gtfs_rt_feed_filter.filter(
Database()
.get_session()
.query(Gtfsrealtimefeed, t_location_with_translations_en)
.outerjoin(Location, Gtfsrealtimefeed.locations)
.outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
.options(
joinedload(Gtfsrealtimefeed.entitytypes),
joinedload(Gtfsrealtimefeed.gtfs_feeds),
*BasicFeedImpl.get_joinedload_options(),
)
).all()

if len(results) > 0 and results[0].Gtfsrealtimefeed:
translations = {result[1]: create_location_translation_object(result) for result in results}
return GtfsRTFeedImpl.from_orm(results[0].Gtfsrealtimefeed, translations)
else:
raise_http_error(404, gtfs_rt_feed_not_found.format(id))

Expand Down Expand Up @@ -251,42 +281,41 @@ def get_gtfs_rt_feeds(
municipality__ilike=municipality,
),
)
gtfs_rt_feed_query = gtfs_rt_feed_filter.filter(Database().get_query_model(Gtfsrealtimefeed)).options(
*BasicFeedImpl.get_joinedload_options()
gtfs_rt_feed_query = gtfs_rt_feed_filter.filter(
Database().get_session().query(Gtfsrealtimefeed, t_location_with_translations_en)
)
gtfs_rt_feed_query = gtfs_rt_feed_query.outerjoin(Entitytype, Gtfsrealtimefeed.entitytypes).options(
joinedload(Gtfsrealtimefeed.entitytypes)
)
gtfs_rt_feed_query = gtfs_rt_feed_query.outerjoin(Location, Feed.locations).options(
joinedload(Gtfsrealtimefeed.locations)
)
gtfs_rt_feed_query = gtfs_rt_feed_query.outerjoin(Gtfsfeed, Gtfsrealtimefeed.gtfs_feeds).options(
joinedload(Gtfsrealtimefeed.gtfs_feeds)
gtfs_rt_feed_query = (
gtfs_rt_feed_query.outerjoin(Location, Gtfsrealtimefeed.locations)
.outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
.outerjoin(Entitytype, Gtfsrealtimefeed.entitytypes)
.options(
joinedload(Gtfsrealtimefeed.entitytypes),
joinedload(Gtfsrealtimefeed.gtfs_feeds),
*BasicFeedImpl.get_joinedload_options(),
)
.order_by(Gtfsrealtimefeed.provider, Gtfsrealtimefeed.stable_id)
)
gtfs_rt_feed_query = gtfs_rt_feed_query.order_by(Gtfsrealtimefeed.provider, Gtfsrealtimefeed.stable_id)
return self._get_response(gtfs_rt_feed_query, limit, offset, GtfsRTFeedImpl)

@staticmethod
def _get_response(feed_query: Query, limit: int, offset: int, impl_cls: type[T]) -> List[T]:
"""Get the response for the feed query."""
if limit is not None:
gtfs_rt_feed_query = gtfs_rt_feed_query.limit(limit)
feed_query = feed_query.limit(limit)
if offset is not None:
gtfs_rt_feed_query = gtfs_rt_feed_query.offset(offset)
results = gtfs_rt_feed_query.all()
return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed) for gtfs_rt_feed in results]
feed_query = feed_query.offset(offset)
results = feed_query.all()
location_translations = {row[1]: create_location_translation_object(row) for row in results}
response = [impl_cls.from_orm(feed[0], location_translations) for feed in results]
return list({feed.id: feed for feed in response}.values())

def get_gtfs_feed_gtfs_rt_feeds(
self,
id: str,
) -> List[GtfsRTFeed]:
"""Get a list of GTFS Realtime related to a GTFS feed."""
feed = (
FeedFilter(
stable_id=id,
status=None,
provider__ilike=None,
producer_url__ilike=None,
)
.filter(Database().get_query_model(Gtfsfeed))
.first()
)
feed, translations = self._get_gtfs_feed(id)
if feed:
return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed) for gtfs_rt_feed in feed.gtfs_rt_feeds]
return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed, translations) for gtfs_rt_feed in feed.gtfs_rt_feeds]
else:
raise_http_error(404, gtfs_feed_not_found.format(id))
2 changes: 1 addition & 1 deletion api/src/feeds/impl/models/basic_feed_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Config:
from_attributes = True

@classmethod
def from_orm(cls, feed: Feed | None) -> BasicFeed | None:
def from_orm(cls, feed: Feed | None, _=None) -> BasicFeed | None:
if not feed:
return None
return cls(
Expand Down
12 changes: 9 additions & 3 deletions api/src/feeds/impl/models/gtfs_feed_impl.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Dict

from database_gen.sqlacodegen_models import Gtfsfeed as GtfsfeedOrm
from feeds.impl.models.basic_feed_impl import BaseFeedImpl
from feeds.impl.models.latest_dataset_impl import LatestDatasetImpl
from feeds.impl.models.location_impl import LocationImpl
from feeds_gen.models.gtfs_feed import GtfsFeed
from utils.location_translation import LocationTranslation, translate_feed_locations


class GtfsFeedImpl(BaseFeedImpl, GtfsFeed):
Expand All @@ -17,12 +20,15 @@ class Config:
from_attributes = True

@classmethod
def from_orm(cls, feed: GtfsfeedOrm | None) -> GtfsFeed | None:
gtfs_feed = super().from_orm(feed)
def from_orm(
cls, feed: GtfsfeedOrm | None, location_translations: Dict[str, LocationTranslation] = None
) -> GtfsFeed | None:
if location_translations is not None:
translate_feed_locations(feed, location_translations)
gtfs_feed: GtfsFeed = super().from_orm(feed)
if not gtfs_feed:
return None
gtfs_feed.locations = [LocationImpl.from_orm(item) for item in feed.locations]

latest_dataset = next(
(dataset for dataset in feed.gtfsdatasets if dataset is not None and dataset.latest), None
)
Expand Down
11 changes: 9 additions & 2 deletions api/src/feeds/impl/models/gtfs_rt_feed_impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Dict

from database_gen.sqlacodegen_models import Gtfsrealtimefeed as GtfsRTFeedOrm
from feeds.impl.models.basic_feed_impl import BaseFeedImpl
from feeds.impl.models.location_impl import LocationImpl
from feeds_gen.models.gtfs_rt_feed import GtfsRTFeed
from utils.location_translation import LocationTranslation, translate_feed_locations


class GtfsRTFeedImpl(BaseFeedImpl, GtfsRTFeed):
Expand All @@ -14,8 +17,12 @@ class Config:
from_attributes = True

@classmethod
def from_orm(cls, feed: GtfsRTFeedOrm | None) -> GtfsRTFeed | None:
gtfs_rt_feed = super().from_orm(feed)
def from_orm(
cls, feed: GtfsRTFeedOrm | None, location_translations: Dict[str, LocationTranslation] = None
) -> GtfsRTFeed | None:
if location_translations is not None:
translate_feed_locations(feed, location_translations)
gtfs_rt_feed: GtfsRTFeed = super().from_orm(feed)
if not gtfs_rt_feed:
return None
gtfs_rt_feed.locations = [LocationImpl.from_orm(item) for item in feed.locations]
Expand Down
1 change: 1 addition & 0 deletions api/src/feeds/impl/models/location_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def from_orm(cls, location: LocationOrm | None) -> Location | None:
return None
return cls(
country_code=location.country_code,
country=location.country,
subdivision_name=location.subdivision_name,
municipality=location.municipality,
)
30 changes: 30 additions & 0 deletions api/src/feeds/impl/models/search_feed_item_result_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from feeds_gen.models.latest_dataset import LatestDataset
from feeds_gen.models.search_feed_item_result import SearchFeedItemResult
from feeds_gen.models.source_info import SourceInfo
import pycountry


class SearchFeedItemResultImpl(SearchFeedItemResult):
Expand Down Expand Up @@ -74,11 +75,40 @@ def from_orm_gtfs_rt(cls, feed_search_row):
feed_references=feed_search_row.feed_reference_ids,
)

@classmethod
def _translate_locations(cls, feed_search_row):
"""Translate location information in the feed search row."""
country_translations = cls._create_translation_dict(feed_search_row.country_translations)
subdivision_translations = cls._create_translation_dict(feed_search_row.subdivision_name_translations)
municipality_translations = cls._create_translation_dict(feed_search_row.municipality_translations)

for location in feed_search_row.locations:
location["country"] = country_translations.get(location["country"], location["country"])
if location["country"] is None:
location["country"] = pycountry.countries.get(alpha_2=location["country_code"]).name
location["subdivision_name"] = subdivision_translations.get(
location["subdivision_name"], location["subdivision_name"]
)
location["municipality"] = municipality_translations.get(location["municipality"], location["municipality"])

@staticmethod
def _create_translation_dict(translations):
"""Helper method to create a translation dictionary."""
if translations:
return {
elem.get("key"): elem.get("value") for elem in translations if elem.get("key") and elem.get("value")
}
return {}

@classmethod
def from_orm(cls, feed_search_row):
"""Create a model instance from a SQLAlchemy row object."""
if feed_search_row is None:
return None

# Translate location data
cls._translate_locations(feed_search_row)

match feed_search_row.data_type:
case "gtfs":
return cls.from_orm_gtfs(feed_search_row)
Expand Down
Loading
Loading