diff --git a/src/mudata/_core/config.py b/src/mudata/_core/config.py index 93631cf..b1bbec3 100644 --- a/src/mudata/_core/config.py +++ b/src/mudata/_core/config.py @@ -3,7 +3,7 @@ OPTIONS = { "display_style": "text", "display_html_expand": 0b010, - "pull_on_update": None, + "pull_on_update": False, } _VALID_OPTIONS = { diff --git a/tests/test_io.py b/tests/test_io.py index 273b0ab..08f2ea3 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -39,6 +39,7 @@ def test_write_read_h5mu_mod_obs_colname(self, mdata, filepath_h5mu): mdata.obs["mod1:column"] = 1 mdata["mod1"].obs["column"] = 2 mdata.update() + mdata.pull_obs() mdata.write(filepath_h5mu) mdata_ = mudata.read(filepath_h5mu) assert "column" in mdata_.obs.columns @@ -51,6 +52,7 @@ def test_write_read_zarr_mod_obs_colname(self, mdata, filepath_zarr): mdata.obs["mod1:column"] = 1 mdata["mod1"].obs["column"] = 2 mdata.update() + mdata.pull_obs() mdata.write_zarr(filepath_zarr) mdata_ = mudata.read_zarr(filepath_zarr) assert "column" in mdata_.obs.columns diff --git a/tests/test_merge.py b/tests/test_merge.py index 69e8301..aae3db1 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -29,7 +29,7 @@ def mdata(): @pytest.mark.usefixtures("filepath_h5mu", "filepath_zarr") class TestMuData: - def test_merge(self, mdata, filepath_h5mu): + def test_merge(self, mdata): mdata1, mdata2 = mdata[:N1, :].copy(), mdata[N1:, :].copy() mdata_ = mudata.concat([mdata1, mdata2]) assert list(mdata_.mod.keys()) == ["mod1", "mod2"] diff --git a/tests/test_nullable.py b/tests/test_nullable.py index e5cc6e2..5b53e4b 100644 --- a/tests/test_nullable.py +++ b/tests/test_nullable.py @@ -35,6 +35,7 @@ def mdata(): class TestMuData: def test_mdata_bool_boolean(self, mdata): + mdata.pull_var() assert mdata.var["assert-bool"].dtype == bool assert isinstance(mdata.var["mod1:assert-boolean-1"].dtype, pd.BooleanDtype) assert isinstance(mdata.var["mod2:assert-boolean-2"].dtype, pd.BooleanDtype) diff --git a/tests/test_obs_var.py 
b/tests/test_obs_var.py index 164efa5..a67cd3a 100644 --- a/tests/test_obs_var.py +++ b/tests/test_obs_var.py @@ -27,22 +27,26 @@ def test_obs_global_columns(self, mdata, filepath_h5mu): mod.obs["demo"] = m mdata.obs["demo"] = "global" mdata.update() - assert list(mdata.obs.columns.values) == [f"{m}:demo" for m in mdata.mod.keys()] + ["demo"] + mdata.pull_obs() + assert list(mdata.obs.columns.values) == ["demo"] + [f"{m}:demo" for m in mdata.mod.keys()] mdata.write(filepath_h5mu) mdata_ = mudata.read(filepath_h5mu) - assert list(mdata_.obs.columns.values) == [f"{m}:demo" for m in mdata_.mod.keys()] + [ - "demo" + assert list(mdata_.obs.columns.values) == ["demo"] + [ + f"{m}:demo" for m in mdata_.mod.keys() ] def test_var_global_columns(self, mdata, filepath_h5mu): for m, mod in mdata.mod.items(): mod.var["demo"] = m mdata.update() + mdata.pull_var() mdata.var["global"] = "global_var" mdata.update() + mdata.pull_var() assert list(mdata.var.columns.values) == ["demo", "global"] del mdata.var["global"] mdata.update() + mdata.pull_var() assert list(mdata.var.columns.values) == ["demo"] mdata.write(filepath_h5mu) mdata_ = mudata.read(filepath_h5mu) diff --git a/tests/test_update.py b/tests/test_update.py index 2a4e3ee..e6a37f4 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -97,6 +97,8 @@ def test_update_simple(self, modalities): mod.var_names = [f"{m}_var{j}" for j in range(mod.n_vars)] mdata = MuData(modalities) mdata.update() + mdata.pull_obs() + mdata.pull_var() # Variables are different across modalities assert "mod" in mdata.var.columns @@ -124,6 +126,8 @@ def test_update_duplicates(self, modalities): mod.var_names = [f"{m}_var{j // 2}" for j in range(mod.n_vars)] mdata = MuData(modalities) mdata.update() + mdata.pull_obs() + mdata.pull_var() # Variables are different across modalities assert "mod" in mdata.var.columns @@ -152,6 +156,12 @@ def test_update_intersecting(self, modalities): mod.var_names = [f"{m}_var{j}" if j != 0 else f"var_{j}" for 
j in range(mod.n_vars)] mdata = MuData(modalities) mdata.update() + mdata.pull_obs() + # New behaviour since v0.4: + # - Will add a single column 'mod' with the correct labels even with intersecting var_names + mdata.pull_var() + # - Will add the columns with modality prefixes + mdata.pull_var(join_common=False) for m, mod in modalities.items(): # Observations are the same across modalities @@ -177,6 +187,7 @@ def test_update_after_filter_obs_adata(self, mdata): # mu.pp.filter_obs(mdata['mod1'], 'min_count', lambda x: (x < -2)) mdata.mod["mod1"] = mdata["mod1"][mdata["mod1"].obs["min_count"] < -2].copy() mdata.update() + mdata.pull_obs() assert mdata.obs["batch"].isna().sum() == 0 @pytest.mark.parametrize("obs_mod", ["unique"]) @@ -207,6 +218,33 @@ def test_update_after_obs_reordered(self, mdata): [all(true_obsm_values[i] == test_obsm_values[i]) for i in range(len(true_obsm_values))] ) + @pytest.mark.parametrize("obs_mod", ["unique"]) + @pytest.mark.parametrize("obs_across", ["intersecting"]) + @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) + def test_update_intersecting_var_names_after_filtering(self, mdata): + orig_shape = mdata.shape + mdata.mod["mod1"].var_names = [str(i) for i in range(mdata["mod1"].n_vars)] + mdata.mod["mod2"].var_names = [str(i) for i in range(mdata["mod2"].n_vars)] + mdata.update() + mdata.mod["mod1"] = mdata["mod1"][:, :5].copy() + mdata["mod1"].var["true"] = True + mdata["mod2"].var["false"] = False + assert mdata["mod1"].n_vars == 5 + mdata.update() + mdata.pull_var(prefix_unique=False) + assert mdata.n_obs == orig_shape[0] + assert mdata.n_vars == mdata["mod1"].n_vars + mdata["mod2"].n_vars + assert mdata.var["true"].sum() == 5 + assert (~mdata.var["false"]).sum() == (~mdata["mod2"].var["false"]).sum() + + @pytest.mark.parametrize("obs_mod", ["unique"]) + @pytest.mark.parametrize("obs_across", ["intersecting"]) + @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) + def test_update_to_new_names(self, mdata): + 
mdata["mod1"].var_names = [f"_mod1_var{i}" for i in range(1, mdata["mod1"].n_vars + 1)] + mdata["mod2"].var_names = [f"_mod2_var{i}" for i in range(1, mdata["mod2"].n_vars + 1)] + mdata.update() + # @pytest.mark.usefixtures("filepath_h5mu") # class TestMuDataSameVars: diff --git a/tests/test_update_axis_1.py b/tests/test_update_axis_1.py index 5192a37..e0b527d 100644 --- a/tests/test_update_axis_1.py +++ b/tests/test_update_axis_1.py @@ -97,11 +97,13 @@ def test_update_simple(self, datasets): datasets[d].obs_names = [f"{d}_obs{j}" for j in range(dset.n_obs)] mdata = MuData(datasets, axis=1) mdata.update() + mdata.pull_obs() + mdata.pull_var() # Variables are different across datasets assert "dataset" in mdata.obs.columns for d, dset in datasets.items(): - # Veriables are the same across datasets + # Variables are the same across datasets # hence /mod/mod1/var/dataset -> /var/mod1:dataset assert f"{d}:dataset" in mdata.var.columns # Columns are intact in individual datasets @@ -124,6 +126,8 @@ def test_update_duplicates(self, datasets): dset.obs_names = [f"{d}_obs{j // 2}" for j in range(dset.n_obs)] mdata = MuData(datasets, axis=1) mdata.update() + mdata.pull_obs() + mdata.pull_var() # Observations are different across datasets assert "dataset" in mdata.obs.columns @@ -151,6 +155,12 @@ def test_update_intersecting(self, datasets): dset.obs_names = [f"{d}_obs{j}" if j != 0 else f"obs_{j}" for j in range(dset.n_obs)] mdata = MuData(datasets, axis=1) mdata.update() + # New behaviour since v0.4: + # - Will add a single column 'mod' with the correct labels even with intersecting obs_names + mdata.pull_obs() + # - Will add the columns with modality prefixes + mdata.pull_obs(join_common=False) + mdata.pull_var() for d, dset in datasets.items(): - # Veriables are the same across datasets + # Variables are the same across datasets