Skip to content

Commit

Permalink
Use the provided session within a with statement
Browse files Browse the repository at this point in the history
  • Loading branch information
qcdyx committed Dec 18, 2024
1 parent d982f3d commit cf2a50a
Showing 1 changed file with 76 additions and 73 deletions.
149 changes: 76 additions & 73 deletions api/src/scripts/populate_db_gbfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def deprecate_feeds(self, deprecated_feeds):
if gbfs_feed:
self.logger.info(f"Deprecating feed with stable_id={stable_id}")
gbfs_feed.status = "deprecated"
session.flush()
# session.flush()

def populate_db(self):
"""Populate the database with the GBFS feeds"""
Expand All @@ -57,81 +57,84 @@ def populate_db(self):
# Compare the database to the CSV file
df_from_db = generate_system_csv_from_db(self.df, session)
added_or_updated_feeds, deprecated_feeds = compare_db_to_csv(df_from_db, self.df, self.logger)
except Exception as e:
self.logger.error(f"Failed to compare the database to the CSV file. Error: {e}")
return

self.deprecate_feeds(deprecated_feeds)
if added_or_updated_feeds is None:
added_or_updated_feeds = self.df
for index, row in added_or_updated_feeds.iterrows():
self.logger.info(f"Processing row {index + 1} of {len(added_or_updated_feeds)}")
stable_id = self.get_stable_id(row)
gbfs_feed = self.query_feed_by_stable_id(stable_id, "gbfs")
fetched_data = fetch_data(
row["Auto-Discovery URL"], self.logger, ["system_information", "gbfs_versions"], ["version"]
)
# If the feed already exists, update it. Otherwise, create a new feed.
if gbfs_feed:
feed_id = gbfs_feed.id
self.logger.info(f"Updating feed {stable_id} - {row['Name']}")
else:
feed_id = generate_unique_id()
self.logger.info(f"Creating new feed for {stable_id} - {row['Name']}")
gbfs_feed = Gbfsfeed(
id=feed_id,
data_type="gbfs",
stable_id=stable_id,
created_at=datetime.now(pytz.utc),
)
gbfs_feed.externalids = [self.get_external_id(feed_id, row["System ID"])]
self.db.session.add(gbfs_feed)

system_information_content = get_data_content(fetched_data.get("system_information"), self.logger)
gbfs_feed.license_url = get_license_url(system_information_content, self.logger)
gbfs_feed.feed_contact_email = (
system_information_content.get("feed_contact_email") if system_information_content else None
)
gbfs_feed.operator = row["Name"]
gbfs_feed.operator_url = row["URL"]
gbfs_feed.auto_discovery_url = row["Auto-Discovery URL"]
gbfs_feed.updated_at = datetime.now(pytz.utc)

country_code = self.get_safe_value(row, "Country Code", "")
municipality = self.get_safe_value(row, "Location", "")
location_id = self.get_location_id(country_code, None, municipality)
country = pycountry.countries.get(alpha_2=country_code) if country_code else None
location = self.db.session.get(Location, location_id) or Location(
id=location_id,
country_code=country_code,
country=country.name if country else None,
municipality=municipality,
)
gbfs_feed.locations.clear()
gbfs_feed.locations = [location]

# Add the GBFS versions
versions = get_gbfs_versions(
fetched_data.get("gbfs_versions"), row["Auto-Discovery URL"], fetched_data.get("version"), self.logger
)
existing_versions = [version.version for version in gbfs_feed.gbfsversions]
for version in versions:
version_value = version.get("version")
if version_value.upper() in OFFICIAL_VERSIONS and version_value not in existing_versions:
gbfs_feed.gbfsversions.append(
Gbfsversion(
feed_id=feed_id,
url=version.get("url"),
version=version_value,
)
self.deprecate_feeds(deprecated_feeds)
if added_or_updated_feeds is None:
added_or_updated_feeds = self.df
for index, row in added_or_updated_feeds.iterrows():
self.logger.info(f"Processing row {index + 1} of {len(added_or_updated_feeds)}")
stable_id = self.get_stable_id(row)
gbfs_feed = self.query_feed_by_stable_id(session, stable_id, "gbfs")
fetched_data = fetch_data(
row["Auto-Discovery URL"], self.logger, ["system_information", "gbfs_versions"], ["version"]
)
# If the feed already exists, update it. Otherwise, create a new feed.
if gbfs_feed:
feed_id = gbfs_feed.id
self.logger.info(f"Updating feed {stable_id} - {row['Name']}")
else:
feed_id = generate_unique_id()
self.logger.info(f"Creating new feed for {stable_id} - {row['Name']}")
gbfs_feed = Gbfsfeed(
id=feed_id,
data_type="gbfs",
stable_id=stable_id,
created_at=datetime.now(pytz.utc),
)
gbfs_feed.externalids = [self.get_external_id(feed_id, row["System ID"])]
session.add(gbfs_feed)

self.db.session.flush()
self.logger.info(80 * "-")

self.db.session.commit()
end_time = datetime.now()
self.logger.info(f"Time taken: {end_time - start_time} seconds")
system_information_content = get_data_content(fetched_data.get("system_information"), self.logger)
gbfs_feed.license_url = get_license_url(system_information_content, self.logger)
gbfs_feed.feed_contact_email = (
system_information_content.get("feed_contact_email") if system_information_content else None
)
gbfs_feed.operator = row["Name"]
gbfs_feed.operator_url = row["URL"]
gbfs_feed.auto_discovery_url = row["Auto-Discovery URL"]
gbfs_feed.updated_at = datetime.now(pytz.utc)

country_code = self.get_safe_value(row, "Country Code", "")
municipality = self.get_safe_value(row, "Location", "")
location_id = self.get_location_id(country_code, None, municipality)
country = pycountry.countries.get(alpha_2=country_code) if country_code else None
location = session.get(Location, location_id) or Location(
id=location_id,
country_code=country_code,
country=country.name if country else None,
municipality=municipality,
)
gbfs_feed.locations.clear()
gbfs_feed.locations = [location]

# Add the GBFS versions
versions = get_gbfs_versions(
fetched_data.get("gbfs_versions"),
row["Auto-Discovery URL"],
fetched_data.get("version"),
self.logger,
)
existing_versions = [version.version for version in gbfs_feed.gbfsversions]
for version in versions:
version_value = version.get("version")
if version_value.upper() in OFFICIAL_VERSIONS and version_value not in existing_versions:
gbfs_feed.gbfsversions.append(
Gbfsversion(
feed_id=feed_id,
url=version.get("url"),
version=version_value,
)
)

# self.db.session.flush()
self.logger.info(80 * "-")

# self.db.session.commit()
end_time = datetime.now()
self.logger.info(f"Time taken: {end_time - start_time} seconds")
except Exception as e:
self.logger.error(f"Error populating the database: {e}")
raise e


if __name__ == "__main__":
Expand Down

0 comments on commit cf2a50a

Please sign in to comment.