diff --git a/dataset_builders/pie/drugprot/drugprot.py b/dataset_builders/pie/drugprot/drugprot.py index 7f2dd46b..f07756ac 100644 --- a/dataset_builders/pie/drugprot/drugprot.py +++ b/dataset_builders/pie/drugprot/drugprot.py @@ -107,7 +107,8 @@ class Drugprot(GeneratorBasedBuilder): } BASE_DATASET_PATH = "bigbio/drugprot" - BASE_DATASET_REVISION = "38ff03d68347aaf694e598c50cb164191f50f61c" + # This revision includes the "test_background" split (see https://github.com/bigscience-workshop/biomedical/pull/928) + BASE_DATASET_REVISION = "0cc98b3d292242e69adcfd2c3e5eea94baaca8ea" BUILDER_CONFIGS = [ datasets.BuilderConfig( diff --git a/tests/dataset_builders/pie/drugprot/test_drugprot.py b/tests/dataset_builders/pie/drugprot/test_drugprot.py index 4cf58b8b..63f82d83 100644 --- a/tests/dataset_builders/pie/drugprot/test_drugprot.py +++ b/tests/dataset_builders/pie/drugprot/test_drugprot.py @@ -24,8 +24,9 @@ DATASET_NAME = "drugprot" PIE_DATASET_PATH = PIE_BASE_PATH / DATASET_NAME HF_DATASET_PATH = Drugprot.BASE_DATASET_PATH -SPLIT_NAMES = {"train", "validation"} -SPLIT_SIZES = {"train": 3500, "validation": 750} +HF_DATASET_REVISION = Drugprot.BASE_DATASET_REVISION +SPLIT_NAMES = {"train", "validation", "test_background"} +SPLIT_SIZES = {"train": 3500, "validation": 750, "test_background": 10750} @pytest.fixture(params=[config.name for config in Drugprot.BUILDER_CONFIGS], scope="module") @@ -35,7 +36,9 @@ def dataset_variant(request) -> str: @pytest.fixture(scope="module") def hf_dataset(dataset_variant) -> datasets.DatasetDict: - return datasets.load_dataset(HF_DATASET_PATH, name=dataset_variant) + return datasets.load_dataset( + HF_DATASET_PATH, revision=HF_DATASET_REVISION, name=dataset_variant + ) def test_hf_dataset(hf_dataset): @@ -44,6 +47,11 @@ def test_hf_dataset(hf_dataset): assert split_sizes == SPLIT_SIZES +@pytest.fixture(scope="module", params=list(SPLIT_NAMES)) +def split(request) -> str: + return request.param + + @pytest.fixture(scope="module") def hf_example(hf_dataset) -> Dict[str, Any]: return hf_dataset["train"][0] @@ -276,6 +284,39 @@ def test_hf_example(hf_example, dataset_variant): raise ValueError(f"Unknown dataset variant: {dataset_variant}") +def test_hf_example_for_every_split(hf_dataset, dataset_variant, split): + # covers both dataset variants + example = hf_dataset[split][0] + if split == "train": + assert example["document_id"] == "17512723" + assert len(example["entities"]) == 13 + assert len(example["relations"]) == 1 + elif split == "validation": + assert example["document_id"] == "17651117" + assert len(example["entities"]) == 18 + assert len(example["relations"]) == 0 + elif split == "test_background": + assert example["document_id"] == "32733640" + assert len(example["entities"]) == 37 + assert len(example["relations"]) == 0 + else: + raise ValueError(f"Unknown dataset split: {split}") + + +def test_hf_dataset_all(hf_dataset, split): + # covers both dataset variants + for example in hf_dataset[split]: + assert example["document_id"] is not None + assert len(example["entities"]) > 0 + + # The split "test_background" does not contain any relations + if split == "test_background": + assert len(example["relations"]) == 0 + # The splits "train" and "validation" sometimes contain no relation + else: + assert len(example["relations"]) >= 0 + + @pytest.fixture(scope="module") def builder(dataset_variant) -> Drugprot: return Drugprot(config_name=dataset_variant)