Skip to content

Commit

Permalink
feat: added translation to rt and search endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
cka-y committed Aug 9, 2024
1 parent affb66a commit 03f5187
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 58 deletions.
115 changes: 62 additions & 53 deletions api/src/feeds/impl/feeds_api_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List, Union

from sqlalchemy.orm import joinedload

from sqlalchemy.orm.query import Query
from database.database import Database
from database_gen.sqlacodegen_models import (
Feed,
Expand Down Expand Up @@ -37,7 +37,7 @@
from feeds_gen.models.gtfs_feed import GtfsFeed
from feeds_gen.models.gtfs_rt_feed import GtfsRTFeed
from utils.date_utils import valid_iso_date
from utils.location_translation import create_location_translation_object
from utils.location_translation import create_location_translation_object, LocationTranslation


class FeedsApiImpl(BaseFeedsApi):
Expand Down Expand Up @@ -93,9 +93,17 @@ def get_gtfs_feed(
id: str,
) -> GtfsFeed:
"""Get the specified gtfs feed from the Mobility Database."""
feed, translations = self._get_gtfs_feed(id)
if feed:
return GtfsFeedImpl.from_orm(feed, translations)
else:
raise_http_error(404, gtfs_feed_not_found.format(id))

@staticmethod
def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationTranslation]]:
results = (
FeedFilter(
stable_id=id,
stable_id=stable_id,
status=None,
provider__ilike=None,
producer_url__ilike=None,
Expand All @@ -112,9 +120,8 @@ def get_gtfs_feed(
).all()
if len(results) > 0 and results[0].Gtfsfeed:
translations = {result[1]: create_location_translation_object(result) for result in results}
return GtfsFeedImpl.from_orm(results[0].Gtfsfeed, translations)
else:
raise_http_error(404, gtfs_feed_not_found.format(id))
return results[0].Gtfsfeed, translations
return None, {}

def get_gtfs_feed_datasets(
self,
Expand Down Expand Up @@ -174,7 +181,7 @@ def get_gtfs_feeds(
dataset_latitudes: str,
dataset_longitudes: str,
bounding_filter_method: str,
) -> List[GtfsFeed]:
) -> list[BasicFeed | None]:
"""Get some (or all) GTFS feeds from the Mobility Database."""
gtfs_feed_filter = GtfsFeedFilter(
stable_id=None,
Expand Down Expand Up @@ -204,33 +211,36 @@ def get_gtfs_feeds(
gtfs_feed_query = DatasetsApiImpl.apply_bounding_filtering(
gtfs_feed_query, dataset_latitudes, dataset_longitudes, bounding_filter_method
)
if limit is not None:
gtfs_feed_query = gtfs_feed_query.limit(limit)
if offset is not None:
gtfs_feed_query = gtfs_feed_query.offset(offset)
results = gtfs_feed_query.all()
location_translations = {row[1]: create_location_translation_object(row) for row in results}
response = [GtfsFeedImpl.from_orm(gtfs_feed.Gtfsfeed, location_translations) for gtfs_feed in results]
return list({feed.id: feed for feed in response}.values())
return self._get_response(gtfs_feed_query, limit, offset, GtfsFeedImpl)

def get_gtfs_rt_feed(
self,
id: str,
) -> GtfsRTFeed:
"""Get the specified GTFS Realtime feed from the Mobility Database."""
feed = (
GtfsRtFeedFilter(
stable_id=id,
provider__ilike=None,
producer_url__ilike=None,
entity_types=None,
location=None,
)
.filter(Database().get_query_model(Gtfsrealtimefeed))
.first()
gtfs_rt_feed_filter = GtfsRtFeedFilter(
stable_id=id,
provider__ilike=None,
producer_url__ilike=None,
entity_types=None,
location=None,
)
if feed:
return GtfsRTFeedImpl.from_orm(feed)
results = gtfs_rt_feed_filter.filter(
Database()
.get_session()
.query(Gtfsrealtimefeed, t_location_with_translations)
.outerjoin(Location, Gtfsrealtimefeed.locations)
.outerjoin(t_location_with_translations, Location.id == t_location_with_translations.c.location_id)
.options(
joinedload(Gtfsrealtimefeed.entitytypes),
joinedload(Gtfsrealtimefeed.gtfs_feeds),
*BasicFeedImpl.get_joinedload_options(),
)
).all()

if len(results) > 0 and results[0].Gtfsrealtimefeed:
translations = {result[1]: create_location_translation_object(result) for result in results}
return GtfsRTFeedImpl.from_orm(results[0].Gtfsrealtimefeed, translations)
else:
raise_http_error(404, gtfs_rt_feed_not_found.format(id))

Expand Down Expand Up @@ -269,42 +279,41 @@ def get_gtfs_rt_feeds(
municipality__ilike=municipality,
),
)
gtfs_rt_feed_query = gtfs_rt_feed_filter.filter(Database().get_query_model(Gtfsrealtimefeed)).options(
*BasicFeedImpl.get_joinedload_options()
)
gtfs_rt_feed_query = gtfs_rt_feed_query.outerjoin(Entitytype, Gtfsrealtimefeed.entitytypes).options(
joinedload(Gtfsrealtimefeed.entitytypes)
gtfs_rt_feed_query = gtfs_rt_feed_filter.filter(
Database().get_session().query(Gtfsrealtimefeed, t_location_with_translations)
)
gtfs_rt_feed_query = gtfs_rt_feed_query.outerjoin(Location, Feed.locations).options(
joinedload(Gtfsrealtimefeed.locations)
)
gtfs_rt_feed_query = gtfs_rt_feed_query.outerjoin(Gtfsfeed, Gtfsrealtimefeed.gtfs_feeds).options(
joinedload(Gtfsrealtimefeed.gtfs_feeds)
gtfs_rt_feed_query = (
gtfs_rt_feed_query.outerjoin(Location, Gtfsrealtimefeed.locations)
.outerjoin(t_location_with_translations, Location.id == t_location_with_translations.c.location_id)
.outerjoin(Entitytype, Gtfsrealtimefeed.entitytypes)
.options(
joinedload(Gtfsrealtimefeed.entitytypes),
joinedload(Gtfsrealtimefeed.gtfs_feeds),
*BasicFeedImpl.get_joinedload_options(),
)
.order_by(Gtfsrealtimefeed.provider, Gtfsrealtimefeed.stable_id)
)
gtfs_rt_feed_query = gtfs_rt_feed_query.order_by(Gtfsrealtimefeed.provider, Gtfsrealtimefeed.stable_id)
return self._get_response(gtfs_rt_feed_query, limit, offset, GtfsRTFeedImpl)

@staticmethod
def _get_response(feed_query: Query, limit: int, offset: int, impl_cls: type[BasicFeedImpl]):
"""Get the response for the feed query."""
if limit is not None:
gtfs_rt_feed_query = gtfs_rt_feed_query.limit(limit)
feed_query = feed_query.limit(limit)
if offset is not None:
gtfs_rt_feed_query = gtfs_rt_feed_query.offset(offset)
results = gtfs_rt_feed_query.all()
return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed) for gtfs_rt_feed in results]
feed_query = feed_query.offset(offset)
results = feed_query.all()
location_translations = {row[1]: create_location_translation_object(row) for row in results}
response = [impl_cls.from_orm(feed[0], location_translations) for feed in results]
return list({feed.id: feed for feed in response}.values())

def get_gtfs_feed_gtfs_rt_feeds(
self,
id: str,
) -> List[GtfsRTFeed]:
"""Get a list of GTFS Realtime related to a GTFS feed."""
feed = (
FeedFilter(
stable_id=id,
status=None,
provider__ilike=None,
producer_url__ilike=None,
)
.filter(Database().get_query_model(Gtfsfeed))
.first()
)
feed, translations = self._get_gtfs_feed(id)
if feed:
return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed) for gtfs_rt_feed in feed.gtfs_rt_feeds]
return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed, translations) for gtfs_rt_feed in feed.gtfs_rt_feeds]
else:
raise_http_error(404, gtfs_feed_not_found.format(id))
2 changes: 1 addition & 1 deletion api/src/feeds/impl/models/basic_feed_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Config:
from_attributes = True

@classmethod
def from_orm(cls, feed: Feed | None) -> BasicFeed | None:
def from_orm(cls, feed: Feed | None, _=None) -> BasicFeed | None:
if not feed:
return None
return cls(
Expand Down
2 changes: 1 addition & 1 deletion api/src/feeds/impl/models/gtfs_feed_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def from_orm(
) -> GtfsFeed | None:
if location_translations is not None:
translate_feed_locations(feed, location_translations)
gtfs_feed = super().from_orm(feed)
gtfs_feed: GtfsFeed = super().from_orm(feed)
if not gtfs_feed:
return None
gtfs_feed.locations = [LocationImpl.from_orm(item) for item in feed.locations]
Expand Down
11 changes: 9 additions & 2 deletions api/src/feeds/impl/models/gtfs_rt_feed_impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Dict

from database_gen.sqlacodegen_models import Gtfsrealtimefeed as GtfsRTFeedOrm
from feeds.impl.models.basic_feed_impl import BaseFeedImpl
from feeds.impl.models.location_impl import LocationImpl
from feeds_gen.models.gtfs_rt_feed import GtfsRTFeed
from utils.location_translation import LocationTranslation, translate_feed_locations


class GtfsRTFeedImpl(BaseFeedImpl, GtfsRTFeed):
Expand All @@ -14,8 +17,12 @@ class Config:
from_attributes = True

@classmethod
def from_orm(cls, feed: GtfsRTFeedOrm | None) -> GtfsRTFeed | None:
gtfs_rt_feed = super().from_orm(feed)
def from_orm(
cls, feed: GtfsRTFeedOrm | None, location_translations: Dict[str, LocationTranslation] = None
) -> GtfsRTFeed | None:
if location_translations is not None:
translate_feed_locations(feed, location_translations)
gtfs_rt_feed: GtfsRTFeed = super().from_orm(feed)
if not gtfs_rt_feed:
return None
gtfs_rt_feed.locations = [LocationImpl.from_orm(item) for item in feed.locations]
Expand Down
27 changes: 27 additions & 0 deletions api/src/feeds/impl/models/search_feed_item_result_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,38 @@ def from_orm_gtfs_rt(cls, feed_search_row):
feed_references=feed_search_row.feed_reference_ids,
)

@classmethod
def _translate_locations(cls, feed_search_row):
"""Translate location information in the feed search row."""
country_translations = cls._create_translation_dict(feed_search_row.country_translations)
subdivision_translations = cls._create_translation_dict(feed_search_row.subdivision_name_translations)
municipality_translations = cls._create_translation_dict(feed_search_row.municipality_translations)

for location in feed_search_row.locations:
location["country"] = country_translations.get(location["country"], location["country"])
location["subdivision_name"] = subdivision_translations.get(
location["subdivision_name"], location["subdivision_name"]
)
location["municipality"] = municipality_translations.get(location["municipality"], location["municipality"])

@staticmethod
def _create_translation_dict(translations):
"""Helper method to create a translation dictionary."""
if translations:
return {
elem.get("key"): elem.get("value") for elem in translations if elem.get("key") and elem.get("value")
}
return {}

@classmethod
def from_orm(cls, feed_search_row):
"""Create a model instance from a SQLAlchemy row object."""
if feed_search_row is None:
return None

# Translate location data
cls._translate_locations(feed_search_row)

match feed_search_row.data_type:
case "gtfs":
return cls.from_orm_gtfs(feed_search_row)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def __init__(self, **kwargs):
feed_reference_ids=[],
entities=["sa"],
locations=[],
country_translations=[],
subdivision_name_translations=[],
municipality_translations=[],
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Location,
Translation,
t_feedsearch,
Gtfsfeed,
)
from helpers.database import refresh_materialized_view
from .geocoded_location import GeocodedLocation
Expand Down Expand Up @@ -192,13 +193,27 @@ def update_location(

if len(locations) == 0:
raise Exception("No locations found for the dataset.")
logging.info(f"Updating dataset with stable ID {dataset.stable_id}")
dataset.locations.clear()
dataset.locations = locations

# Update the location of the related feed as well
# Update the location of the related feeds as well
logging.info(f"Updating feed with stable ID {dataset.feed.stable_id}")
dataset.feed.locations.clear()
dataset.feed.locations = locations

gtfs_feed: Gtfsfeed | None = (
session.query(Gtfsfeed)
.filter(Gtfsfeed.stable_id == dataset.feed.stable_id)
.one_or_none()
)

for gtfs_rt_feed in gtfs_feed.gtfs_rt_feeds:
logging.info(f"Updating GTFS-RT feed with stable ID {gtfs_rt_feed.stable_id}")
gtfs_rt_feed.locations.clear()
gtfs_rt_feed.locations = locations
session.add(gtfs_rt_feed)

refresh_materialized_view(session, t_feedsearch.name)
session.add(dataset)
session.commit()
Expand Down

0 comments on commit 03f5187

Please sign in to comment.