diff --git a/paperqa/litqa.py b/paperqa/litqa.py index 631b9d04..b1b72741 100644 --- a/paperqa/litqa.py +++ b/paperqa/litqa.py @@ -43,7 +43,8 @@ def make_mc_options( split_distractors = [d.strip("'[ ]\"") for d in distractors.split(",")] options: list[str] = split_distractors else: - options = distractors + # We are going to modify options in-place, so copy it first + options = distractors.copy() if ideal == "null": if not unsure_option: diff --git a/tests/test_litqa.py b/tests/test_litqa.py index 633b58bc..57bdbb29 100644 --- a/tests/test_litqa.py +++ b/tests/test_litqa.py @@ -1,36 +1,86 @@ from typing import cast +import pytest + from paperqa.litqa import LitQAEvaluation, read_litqa_v2_from_hub -def test_creating_litqa_questions() -> None: - """Test making LitQA eval questions after downloading from Hugging Face Hub.""" - _, eval_split = read_litqa_v2_from_hub(seed=42) - assert len(eval_split) > 3 - assert [ - LitQAEvaluation.from_question( - ideal=cast(str, row.ideal), - distractors=cast(list[str], row.distractors), - question=cast(str, row.question), - seed=42, - )[0] - for row in eval_split[:3].itertuples() - ] == [ - ( - "Q: Which of the following mutations in yeast Pbs2 increases its" - " interaction with SH3?\n\nOptions:\nA) S83F\nB) I87W\nC) N92H\nD) K85W\nE)" - " Insufficient information to answer this question\nF) N92S\nG) P97A" - ), - ( - "Q: What percentage of colorectal cancer-associated fibroblasts typically" - " survive at 2 weeks if cultured with the platinum-based chemotherapy" - " oxaliplatin?\n\nOptions:\nA) 80-99%\nB) 1-20%\nC) 20-50%\nD) 50-80%\nE)" - " 0%\nF) Insufficient information to answer this question" - ), - ( - "Q: Which of the following genes shows the greatest difference in gene" - " expression between homologous cell types in mouse and human" - " brain?\n\nOptions:\nA) Htr3a\nB) Htr5a\nC) Htr6\nD) Insufficient" - " information to answer this question\nE) Htr1d" - ), - ] +class TestLitQAEvaluation: + @staticmethod + def _assert_prompt_is_valid( + qa_prompt: str, question: str, ideal: str, distractors: list[str] + ): + for substr in (question, "Insufficient information", ideal, *distractors): + assert qa_prompt.count(substr) == 1 + + @pytest.mark.asyncio + @pytest.mark.vcr + async def test_from_question(self) -> None: + """Tests that we can create a LitQA question and evaluate answers.""" + question = "What is my office's zip code?" + ideal = "94107" + distractors = ["-8", "94106", "cheesecake"] + + qa_prompt, eval_fn = LitQAEvaluation.from_question( + ideal=ideal, distractors=distractors, question=question + ) + self._assert_prompt_is_valid(qa_prompt, question, ideal, distractors) + + for answer, expected in ( + ("the answer is 94107", LitQAEvaluation.CORRECT), + # NOTE: The below case fails this test, because the LM doesn't accept an answer not in the options. + # See https://github.com/Future-House/paper-qa/issues/693 + # ("the answer is 14004", LitQAEvaluation.INCORRECT), + ("the answer is 94106", LitQAEvaluation.INCORRECT), + ("Insufficient information to answer", LitQAEvaluation.UNSURE), + ): + result = await eval_fn(answer) + assert result == expected + + def test_consistent_mc_options(self) -> None: + """Tests that creating multiple evaluations with the same seed results in the same prompt.""" + question = "What is the meaning of life?" + ideal = "42" + distractors = ["-84", "11", "cheesecake"] + + qa_prompt_1, _ = LitQAEvaluation.from_question( + ideal=ideal, distractors=distractors, question=question, seed=0 + ) + self._assert_prompt_is_valid(qa_prompt_1, question, ideal, distractors) + + qa_prompt_2, _ = LitQAEvaluation.from_question( + ideal=ideal, distractors=distractors, question=question, seed=0 + ) + assert qa_prompt_1 == qa_prompt_2 + + def test_creating_litqa_questions(self) -> None: + """Test making LitQA eval questions after downloading from Hugging Face Hub.""" + _, eval_split = read_litqa_v2_from_hub(seed=42) + assert len(eval_split) > 3 + assert [ + LitQAEvaluation.from_question( + ideal=cast(str, row.ideal), + distractors=cast(list[str], row.distractors), + question=cast(str, row.question), + seed=42, + )[0] + for row in eval_split[:3].itertuples() + ] == [ + ( + "Q: Which of the following mutations in yeast Pbs2 increases its" + " interaction with SH3?\n\nOptions:\nA) S83F\nB) I87W\nC) N92H\nD) K85W\nE)" + " Insufficient information to answer this question\nF) N92S\nG) P97A" + ), + ( + "Q: What percentage of colorectal cancer-associated fibroblasts typically" + " survive at 2 weeks if cultured with the platinum-based chemotherapy" + " oxaliplatin?\n\nOptions:\nA) 80-99%\nB) 1-20%\nC) 20-50%\nD) 50-80%\nE)" + " 0%\nF) Insufficient information to answer this question" + ), + ( + "Q: Which of the following genes shows the greatest difference in gene" + " expression between homologous cell types in mouse and human" + " brain?\n\nOptions:\nA) Htr3a\nB) Htr5a\nC) Htr6\nD) Insufficient" + " information to answer this question\nE) Htr1d" + ), + ] diff --git a/uv.lock b/uv.lock index 3f5e6a26..526f6b68 100644 --- a/uv.lock +++ b/uv.lock @@ -674,11 +674,11 @@ wheels = [ [[package]] name = "identify" -version = "2.6.1" +version = "2.6.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/29/bb/25024dbcc93516c492b75919e76f389bac754a3e4248682fba32b250c880/identify-2.6.1.tar.gz", hash = "sha256:91478c5fb7c3aac5ff7bf9b4344f803843dc586832d5f110d672b19aa1984c98", size = 99097 } +sdist = { url = "https://files.pythonhosted.org/packages/02/79/7a520fc5011e02ca3f3285b5f6820eaf80443eb73e3733f73c02fb42ba0b/identify-2.6.2.tar.gz", hash = "sha256:fab5c716c24d7a789775228823797296a2994b075fb6080ac83a102772a98cbd", size = 99113 } wheels = [ - { url = "https://files.pythonhosted.org/packages/7d/0c/4ef72754c050979fdcc06c744715ae70ea37e734816bb6514f79df77a42f/identify-2.6.1-py2.py3-none-any.whl", hash = "sha256:53863bcac7caf8d2ed85bd20312ea5dcfc22226800f6d6881f232d861db5a8f0", size = 98972 }, + { url = "https://files.pythonhosted.org/packages/e0/86/c4395700f3c5475424fb5c41e20c16be28d10c904aee4d005ba3217fc8e7/identify-2.6.2-py2.py3-none-any.whl", hash = "sha256:c097384259f49e372f4ea00a19719d95ae27dd5ff0fd77ad630aa891306b82f3", size = 98982 }, ] [[package]] @@ -752,14 +752,14 @@ wheels = [ [[package]] name = "jedi" -version = "0.19.1" +version = "0.19.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "parso" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d6/99/99b493cec4bf43176b678de30f81ed003fd6a647a301b9c927280c600f0a/jedi-0.19.1.tar.gz", hash = "sha256:cf0496f3651bc65d7174ac1b7d043eff454892c708a87d1b683e57b569927ffd", size = 1227821 } +sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287 } wheels = [ - { url = "https://files.pythonhosted.org/packages/20/9f/bc63f0f0737ad7a60800bfd472a4836661adae21f9c2535f3957b1e54ceb/jedi-0.19.1-py2.py3-none-any.whl", hash = "sha256:e983c654fe5c02867aef4cdfce5a2fbb4a50adc0af145f70504238f18ef5e7e0", size = 1569361 }, + { url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278 }, ] [[package]]