diff --git a/fmralign/preprocessing.py b/fmralign/preprocessing.py index da3f798..d271b51 100644 --- a/fmralign/preprocessing.py +++ b/fmralign/preprocessing.py @@ -123,6 +123,14 @@ def _fit_masker(self, imgs): if isinstance(imgs, Nifti1Image): imgs = [imgs] + # If images are 3D, add a fourth dimension + for i, img in enumerate(imgs): + if len(img.shape) == 3: + imgs[i] = Nifti1Image( + np.expand_dims(img.get_fdata(), axis=-1), + img.affine, + img.header, + ) # Assert that all images have the same shape if len(set([img.shape for img in imgs])) > 1: raise NotImplementedError( diff --git a/fmralign/tests/test_preprocessing.py b/fmralign/tests/test_preprocessing.py index 5ff180e..570abeb 100644 --- a/fmralign/tests/test_preprocessing.py +++ b/fmralign/tests/test_preprocessing.py @@ -181,7 +181,16 @@ def test_standardization(): assert np.abs(np.std(data_array) - 1.0) < 1e-5 -def test_get_parcellatio_img(): +def test_one_contrast(): + """Test that ParcellationMasker handles both 3D and\n + 4D images in the case of one contrast""" + img1, _ = random_niimg((8, 7, 6)) + img2, _ = random_niimg((8, 7, 6, 1)) + pmasker = ParcellationMasker() + pmasker.fit([img1, img2]) + + +def test_get_parcellation_img(): """Test that ParcellationMasker returns the parcellation mask""" n_pieces = 2 img, _ = random_niimg((8, 7, 6)) @@ -198,3 +207,4 @@ def test_get_parcellatio_img(): assert np.allclose(data, labels) assert len(np.unique(data)) == n_pieces +