Skip to content

Commit

Permalink
Run black
Browse files Browse the repository at this point in the history
  • Loading branch information
maximusunc committed Jun 20, 2024
1 parent 0b42ba2 commit 913b846
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 26 deletions.
24 changes: 21 additions & 3 deletions test_harness/result_collector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The Collector of Results."""

from typing import Union
from translator_testing_model.datamodel.pydanticmodel import TestAsset, TestCase

Expand All @@ -10,7 +11,15 @@ class ResultCollector:

def __init__(self):
"""Initialize the Collector."""
self.agents = ["ars", "aragorn", "arax", "biothings-explorer", "improving-agent", "unsecret-agent", "cqs"]
self.agents = [
"ars",
"aragorn",
"arax",
"biothings-explorer",
"improving-agent",
"unsecret-agent",
"cqs",
]
query_types = ["TopAnswer", "Acceptable", "BadButForgivable", "NeverShow"]
self.result_types = {
"PASSED": "PASSED",
Expand All @@ -30,7 +39,14 @@ def __init__(self):
header = ",".join(self.columns)
self.csv = f"{header}\n"

def collect_result(self, test: TestCase, asset: TestAsset, report: dict, parent_pk: Union[str, None], url: str):
def collect_result(
self,
test: TestCase,
asset: TestAsset,
report: dict,
parent_pk: Union[str, None],
url: str,
):
"""Add a single report to the total output."""
# add result to stats
for agent in self.agents:
Expand All @@ -45,7 +61,9 @@ def collect_result(self, test: TestCase, asset: TestAsset, report: dict, parent_
agent_results = ",".join(
get_tag(report[agent]) for agent in self.agents if agent in report
)
pk_url = f"https://arax.ncats.io/?r={parent_pk}" if parent_pk is not None else ""
pk_url = (
f"https://arax.ncats.io/?r={parent_pk}" if parent_pk is not None else ""
)
self.csv += (
f"""{asset.name},{url},{pk_url},{test.id},{asset.id},{agent_results}\n"""
)
30 changes: 19 additions & 11 deletions test_harness/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Dict

from ARS_Test_Runner.semantic_test import pass_fail_analysis

# from benchmarks_runner import run_benchmarks

from translator_testing_model.datamodel.pydanticmodel import TestCase
Expand Down Expand Up @@ -65,7 +66,7 @@ async def run_tests(
test_ids.append(test_id)
except Exception:
logger.error(f"Failed to create test: {test.id}")

test_asset_hash = hash_test_asset(asset)
test_query = query_responses.get(test_asset_hash)
if test_query is not None:
Expand All @@ -90,16 +91,24 @@ async def run_tests(
if response["status_code"] == "598":
agent_report["message"] = "Timed out"
else:
agent_report["message"] = f"Status code: {response['status_code']}"
agent_report["message"] = (
f"Status code: {response['status_code']}"
)
elif (
response["response"]["message"]["results"] is None or
len(response["response"]["message"]["results"]) == 0
response["response"]["message"]["results"] is None
or len(response["response"]["message"]["results"]) == 0
):
agent_report["status"] = "DONE"
agent_report["message"] = "No results"
else:
await pass_fail_analysis(report["result"], agent, response["response"]["message"]["results"], query_runner.normalized_curies[asset.output_id], asset.expected_output)

await pass_fail_analysis(
report["result"],
agent,
response["response"]["message"]["results"],
query_runner.normalized_curies[asset.output_id],
asset.expected_output,
)

status = "PASSED"
# grab only ars result if it exists, otherwise default to failed
ars_status = report["result"].get("ars", {}).get("status")
Expand All @@ -122,18 +131,17 @@ async def run_tests(
"key": ara,
"value": get_tag(report["result"][ara]),
}
for ara in collector.agents if ara in report["result"]
for ara in collector.agents
if ara in report["result"]
]
await reporter.upload_labels(test_id, labels)
except Exception as e:
logger.warning(f"[{test.id}] failed to upload labels: {e}")
try:
await reporter.upload_log(
test_id, json.dumps(report, indent=4)
)
await reporter.upload_log(test_id, json.dumps(report, indent=4))
except Exception:
logger.error(f"[{test.id}] failed to upload logs.")

try:
await reporter.finish_test(test_id, status)
except Exception:
Expand Down
8 changes: 6 additions & 2 deletions test_harness/runner/generate_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def generate_query(test_asset: TestAsset) -> dict:
raise Exception("Unsupported input category for MVP1")
# add knowledge_type
if "inferred" in test_asset.test_runner_settings:
query["message"]["query_graph"]["edges"]["t_edge"]["knowledge_type"] = "inferred"
query["message"]["query_graph"]["edges"]["t_edge"][
"knowledge_type"
] = "inferred"
elif test_asset.predicate_id == "biolink:affects":
# MVP2
query = MVP2
Expand All @@ -96,6 +98,8 @@ def generate_query(test_asset: TestAsset) -> dict:
][1]["qualifier_value"] = direction_qualifier
# add knowledge_type
if "inferred" in test_asset.test_runner_settings:
query["message"]["query_graph"]["edges"]["t_edge"]["knowledge_type"] = "inferred"
query["message"]["query_graph"]["edges"]["t_edge"][
"knowledge_type"
] = "inferred"

return query
38 changes: 28 additions & 10 deletions test_harness/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ async def run_queries(
# component = "ara"
# loop over all specified components, i.e. ars, ara, kp, utilities
semaphore = asyncio.Semaphore(concurrency)
self.logger.info(f"Sending queries to {self.registry[env_map[test_case.test_env]][component]}")
self.logger.info(
f"Sending queries to {self.registry[env_map[test_case.test_env]][component]}"
)
tasks = [
asyncio.create_task(
self.run_query(
Expand All @@ -133,7 +135,9 @@ async def run_queries(

return queries

async def get_ars_responses(self, parent_pk: str, base_url: str) -> Tuple[Dict[str, dict], Dict[str, str]]:
async def get_ars_responses(
self, parent_pk: str, base_url: str
) -> Tuple[Dict[str, dict], Dict[str, str]]:
"""Given a parent pk, get responses for all ARS things."""
responses = {}
pks = {
Expand Down Expand Up @@ -175,7 +179,9 @@ async def get_ars_responses(self, parent_pk: str, base_url: str) -> Tuple[Dict[s
current_time = time.time()
await asyncio.sleep(5)
else:
self.logger.warning(f"Timed out getting ARS child messages after {MAX_QUERY_TIME / 60} minutes.")
self.logger.warning(
f"Timed out getting ARS child messages after {MAX_QUERY_TIME / 60} minutes."
)

# add response to output
if response is not None:
Expand All @@ -189,14 +195,18 @@ async def get_ars_responses(self, parent_pk: str, base_url: str) -> Tuple[Dict[s
current_time = time.time()
while current_time - start_time <= MAX_QUERY_TIME:
async with httpx.AsyncClient(timeout=30) as client:
res = await client.get(f"{base_url}/ars/api/messages/{parent_pk}?trace=y")
res = await client.get(
f"{base_url}/ars/api/messages/{parent_pk}?trace=y"
)
res.raise_for_status()
response = res.json()
status = response.get("status")
if status == "Done" or status == "Error":
merged_pk = response.get("merged_version")
if merged_pk is None:
self.logger.error(f"Failed to get the ARS merged message from pk: {parent_pk}.")
self.logger.error(
f"Failed to get the ARS merged message from pk: {parent_pk}."
)
pks["ars"] = "None"
responses["ars"] = {
"response": {"message": {"results": []}},
Expand All @@ -206,12 +216,18 @@ async def get_ars_responses(self, parent_pk: str, base_url: str) -> Tuple[Dict[s
# add final ars pk
pks["ars"] = merged_pk
# get full merged pk
res = await client.get(f"{base_url}/ars/api/messages/{merged_pk}")
res = await client.get(
f"{base_url}/ars/api/messages/{merged_pk}"
)
res.raise_for_status()
merged_message = res.json()
responses["ars"] = {
"response": merged_message.get("fields", {}).get("data", {}),
"status_code": merged_message.get("fields", {}).get("code", 410),
"response": merged_message.get("fields", {}).get(
"data", {}
),
"status_code": merged_message.get("fields", {}).get(
"code", 410
),
}
self.logger.info("Got ARS merged message!")
break
Expand All @@ -220,11 +236,13 @@ async def get_ars_responses(self, parent_pk: str, base_url: str) -> Tuple[Dict[s
current_time = time.time()
await asyncio.sleep(5)
else:
self.logger.warning(f"ARS merging took greater than {MAX_QUERY_TIME / 60} minutes.")
self.logger.warning(
f"ARS merging took greater than {MAX_QUERY_TIME / 60} minutes."
)
pks["ars"] = "None"
responses["ars"] = {
"response": {"message": {"results": []}},
"status_code": 410,
}

return responses, pks
return responses, pks

0 comments on commit 913b846

Please sign in to comment.