Merge pull request #2611 from cta-observatory/uniqueness_check
Check uniqueness of input paths and obs_ids in merge tool
kosack authored Sep 11, 2024
2 parents b83225c + 00dddbd commit b30984e
Showing 6 changed files with 101 additions and 6 deletions.
6 changes: 6 additions & 0 deletions docs/changes/2611.feature.rst
@@ -0,0 +1,6 @@
The ``ctapipe-merge`` tool now checks for duplicated input files and
raises an error if the same path is given more than once.

The ``HDF5Merger`` class, and thus also the ``ctapipe-merge`` tool,
now checks for duplicated obs_ids during merging, to prevent
invalid output files.
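
As a rough illustration of the new behaviour (a minimal sketch mirroring the test added in this commit; file names are hypothetical), merging the same input twice into one output now fails on the second call:

from ctapipe.io.hdf5merger import CannotMerge, HDF5Merger

# hypothetical file names; any valid ctapipe HDF5 file works as input
input_file = "gamma.dl2.h5"

with HDF5Merger("merged.dl2.h5") as merger:
    merger(input_file)      # first merge succeeds, obs_ids are recorded
    try:
        merger(input_file)  # same obs_ids again -> rejected
    except CannotMerge as err:
        print(err)          # message lists the duplicated obs_ids

The same check runs when appending to an existing output file, because the merger first seeds its set of known obs_ids from that file.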
8 changes: 8 additions & 0 deletions src/ctapipe/conftest.py
@@ -592,6 +592,14 @@ def proton_train_clf(model_tmp_path, energy_regressor_path):
],
raises=True,
)

# modify obs_ids by adding a constant; this enables merging the gamma and proton files,
# which is needed by the merge tool tests.
with tables.open_file(outpath, mode="r+") as f:
for table in f.walk_nodes("/", "Table"):
if "obs_id" in table.colnames:
obs_id = table.col("obs_id")
table.modify_column(colname="obs_id", column=obs_id + 1_000_000_000)
return outpath


30 changes: 30 additions & 0 deletions src/ctapipe/io/hdf5merger.py
@@ -188,6 +188,8 @@ def __init__(self, output_path=None, **kwargs):
self.data_model_version = None
self.subarray = None
self.meta = None
self._merged_obs_ids = set()

# output file existed, so read subarray and data model version to make sure
# any file given matches what we already have
if appending:
@@ -202,6 +204,9 @@
)
self.required_nodes = _get_required_nodes(self.h5file)

# this will update _merged_obs_ids from the existing output file
self._check_obs_ids(self.h5file)

def __call__(self, other: str | Path | tables.File):
"""
Append file ``other`` to the output file
@@ -267,7 +272,32 @@ def _check_can_merge(self, other):
f"Required node {node_path} not found in {other.filename}"
)

def _check_obs_ids(self, other):
keys = [
"/configuration/observation/observation_block",
"/dl1/event/subarray/trigger",
]

for key in keys:
if key in other.root:
obs_ids = other.root[key].col("obs_id")
break
else:
raise CannotMerge(
f"Input file {other.filename} is missing keys required to"
f" check for duplicated obs_ids. Tried: {keys}"
)

duplicated = self._merged_obs_ids.intersection(obs_ids)
if len(duplicated) > 0:
msg = f"Input file {other.filename} contains obs_ids already included in output file: {duplicated}"
raise CannotMerge(msg)

self._merged_obs_ids.update(obs_ids)

def _append(self, other):
self._check_obs_ids(other)

# Configuration
self._append_subarray(other)

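
For reference, a minimal standalone sketch (hypothetical file name, assuming PyTables and a ctapipe-produced file) of the obs_id lookup that _check_obs_ids performs: it reads the obs_id column from the first candidate table that exists.

import tables

# hypothetical ctapipe HDF5 file; DL1/DL2 outputs contain one of these tables
path = "events.dl1.h5"

# same candidate tables as HDF5Merger._check_obs_ids, tried in order
keys = [
    "/configuration/observation/observation_block",
    "/dl1/event/subarray/trigger",
]

with tables.open_file(path, mode="r") as f:
    for key in keys:
        if key in f:  # tables.File supports "path in file" checks
            obs_ids = set(f.get_node(key).col("obs_id"))
            break
    else:
        raise ValueError(f"no obs_id table found, tried: {keys}")

print(sorted(obs_ids))

Comparing these sets for two files shows ahead of time whether a merge would be rejected for overlapping obs_ids.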
27 changes: 26 additions & 1 deletion src/ctapipe/io/tests/test_merge.py
@@ -68,7 +68,7 @@ def test_simple(tmp_path, gamma_train_clf, proton_train_clf):
merger(proton_train_clf)

subarray = SubarrayDescription.from_hdf(gamma_train_clf)
assert subarray == SubarrayDescription.from_hdf(output), "Subarays do not match"
assert subarray == SubarrayDescription.from_hdf(output), "Subarrays do not match"

tel_groups = [
"/dl1/event/telescope/parameters",
@@ -164,3 +164,28 @@ def test_muon(tmp_path, dl1_muon_output_file):
n_input = len(input_table)
assert len(table) == n_input
assert_table_equal(table, input_table)


def test_duplicated_obs_ids(tmp_path, dl2_shower_geometry_file):
from ctapipe.io.hdf5merger import CannotMerge, HDF5Merger

output = tmp_path / "invalid.dl1.h5"

# check for fresh file
with HDF5Merger(output) as merger:
merger(dl2_shower_geometry_file)

with pytest.raises(
CannotMerge, match="Input file .* contains obs_ids already included"
):
merger(dl2_shower_geometry_file)

# check for appending
with HDF5Merger(output, overwrite=True) as merger:
merger(dl2_shower_geometry_file)

with HDF5Merger(output, append=True) as merger:
with pytest.raises(
CannotMerge, match="Input file .* contains obs_ids already included"
):
merger(dl2_shower_geometry_file)
8 changes: 8 additions & 0 deletions src/ctapipe/tools/merge.py
@@ -3,6 +3,7 @@
"""
import sys
from argparse import ArgumentParser
from collections import Counter
from pathlib import Path

from tqdm.auto import tqdm
@@ -161,6 +162,13 @@ def setup(self):
)
sys.exit(1)

counts = Counter(self.input_files)
duplicated = [p for p, c in counts.items() if c > 1]
if len(duplicated) > 0:
raise ToolConfigurationError(
f"Same file given multiple times. Duplicated files are: {duplicated}"
)

self.merger = self.enter_context(HDF5Merger(parent=self))
if self.merger.output_path in self.input_files:
raise ToolConfigurationError(
28 changes: 23 additions & 5 deletions src/ctapipe/tools/tests/test_merge.py
@@ -5,11 +5,12 @@
from pathlib import Path

import numpy as np
import pytest
import tables
from astropy.table import vstack
from astropy.utils.diff import report_diff_values

from ctapipe.core import run_tool
from ctapipe.core import ToolConfigurationError, run_tool
from ctapipe.io import TableLoader
from ctapipe.io.astropy_helpers import read_table
from ctapipe.io.tests.test_astropy_helpers import assert_table_equal
@@ -176,7 +177,6 @@ def test_muon(tmp_path, dl1_muon_output_file):
argv=[
f"--output={output}",
str(dl1_muon_output_file),
str(dl1_muon_output_file),
],
raises=True,
)
@@ -185,6 +185,24 @@
input_table = read_table(dl1_muon_output_file, "/dl1/event/telescope/muon/tel_001")

n_input = len(input_table)
assert len(table) == 2 * n_input
assert_table_equal(table[:n_input], input_table)
assert_table_equal(table[n_input:], input_table)
assert len(table) == n_input
assert_table_equal(table, input_table)


def test_duplicated(tmp_path, dl1_file, dl1_proton_file):
from ctapipe.tools.merge import MergeTool

output = tmp_path / "invalid.dl1.h5"
with pytest.raises(ToolConfigurationError, match="Same file given multiple times"):
run_tool(
MergeTool(),
argv=[
str(dl1_file),
str(dl1_proton_file),
str(dl1_file),
f"--output={output}",
"--overwrite",
],
cwd=tmp_path,
raises=True,
)
