Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mudata 0.4 tests #70

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mudata/_core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
OPTIONS = {
"display_style": "text",
"display_html_expand": 0b010,
"pull_on_update": None,
"pull_on_update": False,
}

_VALID_OPTIONS = {
Expand Down
2 changes: 2 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_write_read_h5mu_mod_obs_colname(self, mdata, filepath_h5mu):
mdata.obs["mod1:column"] = 1
mdata["mod1"].obs["column"] = 2
mdata.update()
mdata.pull_obs()
mdata.write(filepath_h5mu)
mdata_ = mudata.read(filepath_h5mu)
assert "column" in mdata_.obs.columns
Expand All @@ -51,6 +52,7 @@ def test_write_read_zarr_mod_obs_colname(self, mdata, filepath_zarr):
mdata.obs["mod1:column"] = 1
mdata["mod1"].obs["column"] = 2
mdata.update()
mdata.pull_obs()
mdata.write_zarr(filepath_zarr)
mdata_ = mudata.read_zarr(filepath_zarr)
assert "column" in mdata_.obs.columns
Expand Down
2 changes: 1 addition & 1 deletion tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def mdata():

@pytest.mark.usefixtures("filepath_h5mu", "filepath_zarr")
class TestMuData:
def test_merge(self, mdata, filepath_h5mu):
def test_merge(self, mdata):
mdata1, mdata2 = mdata[:N1, :].copy(), mdata[N1:, :].copy()
mdata_ = mudata.concat([mdata1, mdata2])
assert list(mdata_.mod.keys()) == ["mod1", "mod2"]
Expand Down
1 change: 1 addition & 0 deletions tests/test_nullable.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def mdata():

class TestMuData:
def test_mdata_bool_boolean(self, mdata):
mdata.pull_var()
assert mdata.var["assert-bool"].dtype == bool
assert isinstance(mdata.var["mod1:assert-boolean-1"].dtype, pd.BooleanDtype)
assert isinstance(mdata.var["mod2:assert-boolean-2"].dtype, pd.BooleanDtype)
10 changes: 7 additions & 3 deletions tests/test_obs_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,26 @@ def test_obs_global_columns(self, mdata, filepath_h5mu):
mod.obs["demo"] = m
mdata.obs["demo"] = "global"
mdata.update()
assert list(mdata.obs.columns.values) == [f"{m}:demo" for m in mdata.mod.keys()] + ["demo"]
mdata.pull_obs()
assert list(mdata.obs.columns.values) == ["demo"] + [f"{m}:demo" for m in mdata.mod.keys()]
mdata.write(filepath_h5mu)
mdata_ = mudata.read(filepath_h5mu)
assert list(mdata_.obs.columns.values) == [f"{m}:demo" for m in mdata_.mod.keys()] + [
"demo"
assert list(mdata_.obs.columns.values) == ["demo"] + [
f"{m}:demo" for m in mdata_.mod.keys()
]

def test_var_global_columns(self, mdata, filepath_h5mu):
for m, mod in mdata.mod.items():
mod.var["demo"] = m
mdata.update()
mdata.pull_var()
mdata.var["global"] = "global_var"
mdata.update()
mdata.pull_var()
assert list(mdata.var.columns.values) == ["demo", "global"]
del mdata.var["global"]
mdata.update()
mdata.pull_var()
assert list(mdata.var.columns.values) == ["demo"]
mdata.write(filepath_h5mu)
mdata_ = mudata.read(filepath_h5mu)
Expand Down
38 changes: 38 additions & 0 deletions tests/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def test_update_simple(self, modalities):
mod.var_names = [f"{m}_var{j}" for j in range(mod.n_vars)]
mdata = MuData(modalities)
mdata.update()
mdata.pull_obs()
mdata.pull_var()

# Variables are different across modalities
assert "mod" in mdata.var.columns
Expand Down Expand Up @@ -124,6 +126,8 @@ def test_update_duplicates(self, modalities):
mod.var_names = [f"{m}_var{j // 2}" for j in range(mod.n_vars)]
mdata = MuData(modalities)
mdata.update()
mdata.pull_obs()
mdata.pull_var()

# Variables are different across modalities
assert "mod" in mdata.var.columns
Expand Down Expand Up @@ -152,6 +156,12 @@ def test_update_intersecting(self, modalities):
mod.var_names = [f"{m}_var{j}" if j != 0 else f"var_{j}" for j in range(mod.n_vars)]
mdata = MuData(modalities)
mdata.update()
mdata.pull_obs()
# New behaviour since v0.4:
# - Will add a single column 'mod' with the correct labels even with intersecting var_names
mdata.pull_var()
# - Will add the columns with modality prefixes
mdata.pull_var(join_common=False)

for m, mod in modalities.items():
# Observations are the same across modalities
Expand All @@ -177,6 +187,7 @@ def test_update_after_filter_obs_adata(self, mdata):
# mu.pp.filter_obs(mdata['mod1'], 'min_count', lambda x: (x < -2))
mdata.mod["mod1"] = mdata["mod1"][mdata["mod1"].obs["min_count"] < -2].copy()
mdata.update()
mdata.pull_obs()
assert mdata.obs["batch"].isna().sum() == 0

@pytest.mark.parametrize("obs_mod", ["unique"])
Expand Down Expand Up @@ -207,6 +218,33 @@ def test_update_after_obs_reordered(self, mdata):
[all(true_obsm_values[i] == test_obsm_values[i]) for i in range(len(true_obsm_values))]
)

@pytest.mark.parametrize("obs_mod", ["unique"])
@pytest.mark.parametrize("obs_across", ["intersecting"])
@pytest.mark.parametrize("obs_n", ["joint", "disjoint"])
def test_update_intersecting_var_names_after_filtering(self, mdata):
orig_shape = mdata.shape
mdata.mod["mod1"].var_names = [str(i) for i in range(mdata["mod1"].n_vars)]
mdata.mod["mod2"].var_names = [str(i) for i in range(mdata["mod2"].n_vars)]
mdata.update()
mdata.mod["mod1"] = mdata["mod1"][:, :5].copy()
mdata["mod1"].var["true"] = True
mdata["mod2"].var["false"] = False
assert mdata["mod1"].n_vars == 5
mdata.update()
mdata.pull_var(prefix_unique=False)
assert mdata.n_obs == orig_shape[0]
assert mdata.n_vars == mdata["mod1"].n_vars + mdata["mod2"].n_vars
assert mdata.var["true"].sum() == 5
assert (~mdata.var["false"]).sum() == (~mdata["mod2"].var["false"]).sum()

@pytest.mark.parametrize("obs_mod", ["unique"])
@pytest.mark.parametrize("obs_across", ["intersecting"])
@pytest.mark.parametrize("obs_n", ["joint", "disjoint"])
def test_update_to_new_names(self, mdata):
mdata["mod1"].var_names = [f"_mod1_var{i}" for i in range(1, mdata["mod1"].n_vars + 1)]
mdata["mod2"].var_names = [f"_mod2_var{i}" for i in range(1, mdata["mod2"].n_vars + 1)]
mdata.update()


# @pytest.mark.usefixtures("filepath_h5mu")
# class TestMuDataSameVars:
Expand Down
12 changes: 11 additions & 1 deletion tests/test_update_axis_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,13 @@ def test_update_simple(self, datasets):
datasets[d].obs_names = [f"{d}_obs{j}" for j in range(dset.n_obs)]
mdata = MuData(datasets, axis=1)
mdata.update()
mdata.pull_obs()
mdata.pull_var()

# Observations are different across datasets
assert "dataset" in mdata.obs.columns
for d, dset in datasets.items():
# Veriables are the same across datasets
# Variables are the same across datasets
# hence /mod/mod1/var/dataset -> /var/mod1:dataset
assert f"{d}:dataset" in mdata.var.columns
# Columns are intact in individual datasets
Expand All @@ -124,6 +126,8 @@ def test_update_duplicates(self, datasets):
dset.obs_names = [f"{d}_obs{j // 2}" for j in range(dset.n_obs)]
mdata = MuData(datasets, axis=1)
mdata.update()
mdata.pull_obs()
mdata.pull_var()

# Observations are different across datasets
assert "dataset" in mdata.obs.columns
Expand Down Expand Up @@ -151,6 +155,12 @@ def test_update_intersecting(self, datasets):
dset.obs_names = [f"{d}_obs{j}" if j != 0 else f"obs_{j}" for j in range(dset.n_obs)]
mdata = MuData(datasets, axis=1)
mdata.update()
# New behaviour since v0.4:
# - Will add a single column 'mod' with the correct labels even with intersecting obs_names
mdata.pull_obs()
# - Will add the columns with modality prefixes
mdata.pull_obs(join_common=False)
mdata.pull_var()

for d, dset in datasets.items():
# Variables are the same across datasets
Expand Down
Loading