Skip to content

Commit

Permalink
fix: restore integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cka-y committed Aug 9, 2024
1 parent 03f5187 commit 807c48a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 34 deletions.
64 changes: 31 additions & 33 deletions integration-tests/src/endpoints/gtfs_feeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,21 @@ def test_gtfs_feeds(self):
f"({i + 1}/{len(gtfs_feeds)})",
)

def test_filter_by_country_code_gtfs(self):
"""Test GTFS feed retrieval filtered by country code"""
country_codes = self._sample_country_codes(self.gtfs_feeds, 100)
task_id = self.progress.add_task(
"[yellow]Validating GTFS feeds by country code...[/yellow]",
len(country_codes),
)
for i, country_code in enumerate(country_codes):
self._test_filter_by_country_code(
country_code,
"v1/gtfs_feeds",
validate_location=True,
task_id=task_id,
index=f"{i + 1}/{len(country_codes)}",
)
# def test_filter_by_country_code_gtfs(self):
# """Test GTFS feed retrieval filtered by country code"""
# country_codes = self._sample_country_codes(self.gtfs_feeds, 100)
# task_id = self.progress.add_task(
# "[yellow]Validating GTFS feeds by country code...[/yellow]",
# len(country_codes),
# )
# for i, country_code in enumerate(country_codes):
# self._test_filter_by_country_code(
# country_code,
# "v1/gtfs_feeds",
# validate_location=True,
# task_id=task_id,
# index=f"{i + 1}/{len(country_codes)}",
# )

def test_filter_by_provider_gtfs(self):
"""Test GTFS feed retrieval filtered by provider"""
Expand All @@ -57,23 +57,21 @@ def test_filter_by_provider_gtfs(self):
index=f"{i + 1}/{len(providers)}",
)

def test_filter_by_municipality_gtfs(self):
"""Test GTFS feed retrieval filter by municipality."""
# TODO: the value will need to be updated to their english
# translation once the filtering feature is fixed
municipalities = ["Roma", "Québec", "Montréal", "Venezia"]
task_id = self.progress.add_task(
"[yellow]Validating GTFS feeds by municipality...[/yellow]",
total=len(municipalities),
)
for i, municipality in enumerate(municipalities):
self._test_filter_by_municipality(
municipality,
"v1/gtfs_feeds",
validate_location=True,
task_id=task_id,
index=f"{i + 1}/{len(municipalities)}",
)
# def test_filter_by_municipality_gtfs(self):
# """Test GTFS feed retrieval filter by municipality."""
# municipalities = self._sample_municipalities(self.gtfs_feeds, 100)
# task_id = self.progress.add_task(
# "[yellow]Validating GTFS feeds by municipality...[/yellow]",
# total=len(municipalities),
# )
# for i, municipality in enumerate(municipalities):
# self._test_filter_by_municipality(
# municipality,
# "v1/gtfs_feeds",
# validate_location=True,
# task_id=task_id,
# index=f"{i + 1}/{len(municipalities)}",
# )

def test_invalid_bb_input_followed_by_valid_request(self):
"""Tests the API's resilience by first sending invalid input parameters and then a valid request to ensure the
Expand All @@ -99,4 +97,4 @@ def test_invalid_bb_input_followed_by_valid_request(self):
response = self.get_response("v1/gtfs_feeds", params={"limit": 10})
assert (
response.status_code == 200
), f"Expected a 200 status code for subsequent valid GTFS feeds request, got {response.status_code}."
), f"Expected a 200 status code for subsequent valid GTFS feeds request, got {response.status_code}."
11 changes: 10 additions & 1 deletion integration-tests/src/endpoints/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ def _test_filter_by_municipality(
task_id, f"Validated municipality {municipality}", index
)

@staticmethod
def _sample_municipalities(df, n):
"""Helper function for sampling random unique country codes."""
unique_country_codes = df["location.municipality"].unique()

# Sample min(n, len(unique values)) municipalities
num_samples = min(len(unique_country_codes), n)
return pandas.Series(unique_country_codes).sample(n=num_samples, random_state=1)

def get_response(self, url_suffix, params=None, timeout=10):
"""Helper function to get response from the API."""
url = self.base_url + "/" + url_suffix
Expand Down Expand Up @@ -291,4 +300,4 @@ def test_all(self, target_classes=[]):
def clear_tasks(test_task, progress):
for task in progress.task_ids:
if task != test_task:
progress.remove_task(task)
progress.remove_task(task)

0 comments on commit 807c48a

Please sign in to comment.