From f9d1bd14e4435fd634ba49c5e2036816263977d3 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Wed, 27 Nov 2024 19:01:49 +0100 Subject: [PATCH 01/13] Refactor _make_parcellation function for improved readability and add _concat_surf_imgs utility function --- fmralign/_utils.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/fmralign/_utils.py b/fmralign/_utils.py index 55170f6..9b6e4da 100644 --- a/fmralign/_utils.py +++ b/fmralign/_utils.py @@ -8,6 +8,7 @@ from nilearn.image import new_img_like, smooth_img from nilearn.masking import apply_mask_fmri, intersect_masks from nilearn.regions.parcellations import Parcellations +from nilearn.surface import SurfaceImage class ParceledData: @@ -258,3 +259,15 @@ def _make_parcellation(imgs, clustering, n_pieces, masker, smoothing_fwhm=5, ver _check_labels(labels) return labels + + +def _concat_surf_imgs(imgs): + mesh = imgs[0].mesh + data_concat = {} + for key, val in mesh.parts.items(): + for img in imgs: + if key not in data_concat: + data_concat[key] = img.data.parts[key] + else: + data_concat[key] = np.hstack((data_concat[key], img.data.parts[key])) + return SurfaceImage(mesh, data_concat) From 1968f551281cdffe6484437163dde4a8554f1675 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Wed, 27 Nov 2024 19:02:09 +0100 Subject: [PATCH 02/13] Enhance ParcellationMasker to support SurfaceImage and refactor _fit_masker method for clarity --- fmralign/preprocessing.py | 59 ++++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 23 deletions(-) diff --git a/fmralign/preprocessing.py b/fmralign/preprocessing.py index 7c9812e..c078829 100644 --- a/fmralign/preprocessing.py +++ b/fmralign/preprocessing.py @@ -8,9 +8,11 @@ from nibabel.nifti1 import Nifti1Image from nilearn._utils.masker_validation import check_embedded_masker from nilearn.image import concat_imgs +from nilearn.surface import SurfaceImage from sklearn.base import BaseEstimator, TransformerMixin from fmralign._utils import ( + _concat_surf_imgs, _img_to_parceled_data, _intersect_clustering_mask, _make_parcellation, @@ -118,40 +120,51 @@ def __init__( def _fit_masker(self, imgs): """Fit the masker on a single or multiple images.""" - self.masker_ = check_embedded_masker(self) - self.masker_.n_jobs = self.n_jobs - - if isinstance(imgs, Nifti1Image): + if isinstance(imgs, (Nifti1Image, SurfaceImage)): imgs = [imgs] # Assert that all images have the same shape if len(set([img.shape for img in imgs])) > 1: raise NotImplementedError( "fmralign does not support images of different shapes." ) - if self.masker_.mask_img is None: - self.masker_.fit(imgs) - else: - self.masker_.fit() - - if isinstance(self.clustering, Nifti1Image) or os.path.isfile(self.clustering): - # check that clustering provided fills the mask, if not, reduce the mask - if 0 in self.masker_.transform(self.clustering): - reduced_mask = _intersect_clustering_mask( - self.clustering, self.masker_.mask_img - ) - self.mask = reduced_mask - self.masker_ = check_embedded_masker(self) - self.masker_.n_jobs = self.n_jobs + + masker_type = "surface" if isinstance(imgs[0], SurfaceImage) else "multi_nii" + self.masker_ = check_embedded_masker(self, masker_type=masker_type) + self.masker_.n_jobs = self.n_jobs + + # Fit the masker for volume data + if masker_type == "multi_nii": + if self.masker_.mask_img is None: + self.masker_.fit(imgs) + else: self.masker_.fit() - warnings.warn( - "Mask used was bigger than clustering provided. " - + "Its intersection with the clustering was used instead." - ) + + if isinstance(self.clustering, Nifti1Image) or os.path.isfile( + self.clustering + ): + # check that clustering provided fills the mask, if not, reduce the mask + if 0 in self.masker_.transform(self.clustering): + reduced_mask = _intersect_clustering_mask( + self.clustering, self.masker_.mask_img + ) + self.mask = reduced_mask + self.masker_ = check_embedded_masker(self) + self.masker_.n_jobs = self.n_jobs + self.masker_.fit() + warnings.warn( + "Mask used was bigger than clustering provided. " + + "Its intersection with the clustering was used instead." + ) + else: + self.masker_.fit(imgs) def _one_parcellation(self, imgs): """Compute one parcellation for all images.""" if isinstance(imgs, list): - imgs = concat_imgs(imgs) + if isinstance(imgs[0], (Nifti1Image)): + imgs = concat_imgs(imgs) + else: + imgs = _concat_surf_imgs(imgs) self.labels = _make_parcellation( imgs, self.clustering, From 8567664523c73cf84d454278919fe5397d7ce216 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Wed, 27 Nov 2024 19:02:29 +0100 Subject: [PATCH 03/13] Add test for ParcellationMasker to handle surface images and improve existing test readability --- fmralign/tests/test_preprocessing.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/fmralign/tests/test_preprocessing.py b/fmralign/tests/test_preprocessing.py index 9b81e2c..1968877 100644 --- a/fmralign/tests/test_preprocessing.py +++ b/fmralign/tests/test_preprocessing.py @@ -5,7 +5,7 @@ from fmralign._utils import ParceledData from fmralign.preprocessing import ParcellationMasker -from fmralign.tests.utils import random_niimg +from fmralign.tests.utils import random_niimg, surf_img def test_init_default_params(): @@ -176,3 +176,15 @@ def test_standardization(): data_array = transformed_data[0].data assert np.abs(np.mean(data_array)) < 1e-5 assert np.abs(np.std(data_array) - 1.0) < 1e-5 + + +def test_one_surface_image(): + """Test that ParcellationMasker can handle surface images""" + img = surf_img(20) + pmasker = ParcellationMasker(n_pieces=2) + fitted_pmasker = pmasker.fit([img, img]) + + assert hasattr(fitted_pmasker, "masker_") + assert fitted_pmasker.labels is not None + assert isinstance(fitted_pmasker.labels, np.ndarray) + assert len(np.unique(fitted_pmasker.labels)) == 2 # n_pieces=2 From 181fa358bb7a6aae689c91385f95b36f8f4738e0 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Wed, 27 Nov 2024 19:02:37 +0100 Subject: [PATCH 04/13] Remove redundant main block from test_non_contiguous_labels function in test_utils.py --- fmralign/tests/test_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/fmralign/tests/test_utils.py b/fmralign/tests/test_utils.py index c15f00c..8e2a944 100644 --- a/fmralign/tests/test_utils.py +++ b/fmralign/tests/test_utils.py @@ -189,7 +189,3 @@ def test_non_contiguous_labels(): # Test accessing by label same_parcel = parceled.get_parcel(1) assert_array_equal(same_parcel, expected) - - -if __name__ == "__main__": - test_non_contiguous_labels() From 8d25cc2d7a59d367e39120653c31fd227b25a644 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Wed, 27 Nov 2024 19:04:31 +0100 Subject: [PATCH 05/13] Add functions to create sample mesh and surface images for testing (from nilearn) --- fmralign/tests/utils.py | 54 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/fmralign/tests/utils.py b/fmralign/tests/utils.py index 35133c6..eb3c0bc 100644 --- a/fmralign/tests/utils.py +++ b/fmralign/tests/utils.py @@ -1,6 +1,7 @@ import nibabel import numpy as np from nilearn.maskers import NiftiMasker +from nilearn.surface import InMemoryMesh, PolyMesh, SurfaceImage from numpy.random import default_rng from numpy.testing import assert_array_almost_equal @@ -43,9 +44,7 @@ def zero_mean_coefficient_determination( nonzero_numerator = numerator != 0 valid_score = nonzero_denominator & nonzero_numerator output_scores = np.ones([y_true.shape[1]]) - output_scores[valid_score] = 1 - ( - numerator[valid_score] / denominator[valid_score] - ) + output_scores[valid_score] = 1 - (numerator[valid_score] / denominator[valid_score]) output_scores[nonzero_numerator & ~nonzero_denominator] = 0 if multioutput == "raw_values": @@ -56,8 +55,7 @@ def zero_mean_coefficient_determination( avg_weights = None elif multioutput == "variance_weighted": avg_weights = ( - weight - * (y_true - np.average(y_true, axis=0, weights=sample_weight)) ** 2 + weight * (y_true - np.average(y_true, axis=0, weights=sample_weight)) ** 2 ).sum(axis=0, dtype=np.float64) # avoid fail on constant y or one-element arrays if not np.any(nonzero_denominator): @@ -132,3 +130,49 @@ def sample_parceled_data(n_pieces=1): data = masker.fit_transform(img) labels = _make_parcellation(img, "kmeans", n_pieces, masker) return data, masker, labels + + +def _make_mesh(): + """Create a sample mesh with two parts: left and right, and total of + 9 vertices and 10 faces. + + The left part is a tetrahedron with four vertices and four faces. + The right part is a pyramid with five vertices and six faces. + """ + left_coords = np.asarray([[0.0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]) + left_faces = np.asarray([[1, 0, 2], [0, 1, 3], [0, 3, 2], [1, 2, 3]]) + right_coords = ( + np.asarray([[0.0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0], [0, 0, 1]]) + 2.0 + ) + right_faces = np.asarray( + [ + [0, 1, 4], + [0, 3, 1], + [1, 3, 2], + [1, 2, 4], + [2, 3, 4], + [0, 4, 3], + ] + ) + return PolyMesh( + left=InMemoryMesh(left_coords, left_faces), + right=InMemoryMesh(right_coords, right_faces), + ) + + +def surf_img(n_samples=1): + """Create a sample surface image using the sample mesh. # noqa: D202 + This will add some random data to the vertices of the mesh. + The shape of the data will be (n_vertices, n_samples). + n_samples by default is 1. + """ + + mesh = _make_mesh() + data = {} + for i, (key, val) in enumerate(mesh.parts.items()): + data_shape = (val.n_vertices, n_samples) + data_part = ( + np.arange(np.prod(data_shape)).reshape(data_shape[::-1]) + 1.0 + ) * 10**i + data[key] = data_part.T + return SurfaceImage(mesh, data) From 6a5aa67da3c3ff07888feff8788e6cdbffeb620a Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Thu, 28 Nov 2024 10:05:23 +0100 Subject: [PATCH 06/13] Refactor _make_parcellation function for improved readability and format code; remove unused _concat_surf_imgs function --- fmralign/_utils.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/fmralign/_utils.py b/fmralign/_utils.py index 9b6e4da..55170f6 100644 --- a/fmralign/_utils.py +++ b/fmralign/_utils.py @@ -8,7 +8,6 @@ from nilearn.image import new_img_like, smooth_img from nilearn.masking import apply_mask_fmri, intersect_masks from nilearn.regions.parcellations import Parcellations -from nilearn.surface import SurfaceImage class ParceledData: @@ -259,15 +258,3 @@ def _make_parcellation(imgs, clustering, n_pieces, masker, smoothing_fwhm=5, ver _check_labels(labels) return labels - - -def _concat_surf_imgs(imgs): - mesh = imgs[0].mesh - data_concat = {} - for key, val in mesh.parts.items(): - for img in imgs: - if key not in data_concat: - data_concat[key] = img.data.parts[key] - else: - data_concat[key] = np.hstack((data_concat[key], img.data.parts[key])) - return SurfaceImage(mesh, data_concat) From 90390316e100232d1e92e5408074bfc8f5c400a3 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Thu, 28 Nov 2024 10:05:57 +0100 Subject: [PATCH 07/13] Refactor ParcellationMasker to use concatenate_surface_images function and improve code readability --- fmralign/preprocessing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fmralign/preprocessing.py b/fmralign/preprocessing.py index c078829..f7ea5bc 100644 --- a/fmralign/preprocessing.py +++ b/fmralign/preprocessing.py @@ -8,11 +8,11 @@ from nibabel.nifti1 import Nifti1Image from nilearn._utils.masker_validation import check_embedded_masker from nilearn.image import concat_imgs +from nilearn.maskers._utils import concatenate_surface_images from nilearn.surface import SurfaceImage from sklearn.base import BaseEstimator, TransformerMixin from fmralign._utils import ( - _concat_surf_imgs, _img_to_parceled_data, _intersect_clustering_mask, _make_parcellation, @@ -164,7 +164,7 @@ def _one_parcellation(self, imgs): if isinstance(imgs[0], (Nifti1Image)): imgs = concat_imgs(imgs) else: - imgs = _concat_surf_imgs(imgs) + imgs = concatenate_surface_images(imgs) self.labels = _make_parcellation( imgs, self.clustering, From acc297f4afc512eade02de2e6ecc7641dcb60e53 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Thu, 19 Dec 2024 18:50:23 +0100 Subject: [PATCH 08/13] Refactor label extraction in parcellation function to use masker.transform --- fmralign/_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/fmralign/_utils.py b/fmralign/_utils.py index 7f8e4ac..938fce6 100644 --- a/fmralign/_utils.py +++ b/fmralign/_utils.py @@ -256,9 +256,7 @@ def _make_parcellation( ) err.args += (errmsg,) raise err - labels = apply_mask_fmri( - parcellation.labels_img_, masker.mask_img_ - ).astype(int) + labels = masker.transform(parcellation.labels_img_)[0].astype(int) if verbose > 0: unique_labels, counts = np.unique(labels, return_counts=True) From 8831be5138228d469fa58e4ccb113e8df5d44d56 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Thu, 19 Dec 2024 18:50:43 +0100 Subject: [PATCH 09/13] Format masker_type assignment for improved readability --- fmralign/preprocessing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fmralign/preprocessing.py b/fmralign/preprocessing.py index d2be5e1..1ef05f1 100644 --- a/fmralign/preprocessing.py +++ b/fmralign/preprocessing.py @@ -136,7 +136,9 @@ def _fit_masker(self, imgs): "fmralign does not support images of different shapes." ) - masker_type = "surface" if isinstance(imgs[0], SurfaceImage) else "multi_nii" + masker_type = ( + "surface" if isinstance(imgs[0], SurfaceImage) else "multi_nii" + ) self.masker_ = check_embedded_masker(self, masker_type=masker_type) self.masker_.n_jobs = self.n_jobs From a4d1a361abfabc88d4dcb0825b880688372c3c7b Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Thu, 19 Dec 2024 18:50:57 +0100 Subject: [PATCH 10/13] Enhance tests for ParcellationMasker to support multiple surface images and improve assertions --- fmralign/tests/test_preprocessing.py | 35 +++++++++++++++++++++------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/fmralign/tests/test_preprocessing.py b/fmralign/tests/test_preprocessing.py index 4540456..802c524 100644 --- a/fmralign/tests/test_preprocessing.py +++ b/fmralign/tests/test_preprocessing.py @@ -184,13 +184,31 @@ def test_standardization(): def test_one_surface_image(): """Test that ParcellationMasker can handle surface images""" img = surf_img(20) - pmasker = ParcellationMasker(n_pieces=2) - fitted_pmasker = pmasker.fit([img, img]) + n_pieces = 2 + n_vertices_total = img.shape[0] + parcel_masker = ParcellationMasker(n_pieces=n_pieces) + fitted_parcel_masker = parcel_masker.fit(img) - assert hasattr(fitted_pmasker, "masker_") - assert fitted_pmasker.labels is not None - assert isinstance(fitted_pmasker.labels, np.ndarray) - assert len(np.unique(fitted_pmasker.labels)) == 2 # n_pieces=2 + assert hasattr(fitted_parcel_masker, "masker_") + assert fitted_parcel_masker.labels is not None + assert isinstance(fitted_parcel_masker.labels, np.ndarray) + assert len(np.unique(fitted_parcel_masker.labels)) == n_pieces + assert len(fitted_parcel_masker.labels) == n_vertices_total + + +def test_multiple_surface_images(): + """Test that ParcellationMasker can handle multiple surface images""" + imgs = [surf_img(20)] * 3 + n_pieces = 2 + n_vertices_total = imgs[0].shape[0] + parcel_masker = ParcellationMasker(n_pieces=n_pieces) + fitted_parcel_masker = parcel_masker.fit(imgs) + + assert hasattr(fitted_parcel_masker, "masker_") + assert fitted_parcel_masker.labels is not None + assert isinstance(fitted_parcel_masker.labels, np.ndarray) + assert len(np.unique(fitted_parcel_masker.labels)) == n_pieces + assert len(fitted_parcel_masker.labels) == n_vertices_total def test_one_contrast(): @@ -198,8 +216,8 @@ def test_one_contrast(): 4D images in the case of one contrast""" img1, _ = random_niimg((8, 7, 6)) img2, _ = random_niimg((8, 7, 6, 1)) - pmasker = ParcellationMasker() - pmasker.fit([img1, img2]) + parcel_masker = ParcellationMasker() + parcel_masker.fit([img1, img2]) def test_get_parcellation_img(): @@ -219,4 +237,3 @@ def test_get_parcellation_img(): assert np.allclose(data, labels) assert len(np.unique(data)) == n_pieces - From aecec66aa3f4ef4f1b95508ce202a770111655f9 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Thu, 19 Dec 2024 19:02:50 +0100 Subject: [PATCH 11/13] Add test for surface alignment compatibility in PairwiseAlignment --- fmralign/tests/test_pairwise_alignment.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/fmralign/tests/test_pairwise_alignment.py b/fmralign/tests/test_pairwise_alignment.py index cd4a375..301fea0 100644 --- a/fmralign/tests/test_pairwise_alignment.py +++ b/fmralign/tests/test_pairwise_alignment.py @@ -11,6 +11,7 @@ from fmralign.tests.utils import ( assert_algo_transform_almost_exactly, random_niimg, + surf_img, zero_mean_coefficient_determination, ) @@ -130,3 +131,13 @@ def test_parcellation_before_fit(): AttributeError, match="Parcellation has not been computed yet" ): alignment.get_parcellation() + + +def test_surface_alignment(): + """Test compatibility with `SurfaceImage`""" + alignment = PairwiseAlignment() + n_pieces = 3 + img1 = surf_img(20) + img2 = surf_img(20) + alignment = PairwiseAlignment(n_pieces=n_pieces) + alignment.fit(img1, img2) From f3ae8476ee2222eab3fece699634d71b2f9a04b7 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Thu, 19 Dec 2024 19:08:16 +0100 Subject: [PATCH 12/13] Support SurfaceImage type in ParcellationMasker for improved image handling --- fmralign/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fmralign/preprocessing.py b/fmralign/preprocessing.py index 1ef05f1..369be67 100644 --- a/fmralign/preprocessing.py +++ b/fmralign/preprocessing.py @@ -247,7 +247,7 @@ def transform(self, imgs): List of ParceledData objects containing the data and parcelation information for each image. """ - if isinstance(imgs, Nifti1Image): + if isinstance(imgs, (Nifti1Image, SurfaceImage)): imgs = [imgs] parceled_data = Parallel(n_jobs=self.n_jobs)( From 298749edeac3515c8a9ac6f613116ac87e95fa92 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Thu, 19 Dec 2024 19:08:34 +0100 Subject: [PATCH 13/13] Add tests for SurfaceImage compatibility in PairwiseAlignment --- fmralign/tests/test_pairwise_alignment.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/fmralign/tests/test_pairwise_alignment.py b/fmralign/tests/test_pairwise_alignment.py index 301fea0..f7ed206 100644 --- a/fmralign/tests/test_pairwise_alignment.py +++ b/fmralign/tests/test_pairwise_alignment.py @@ -6,6 +6,7 @@ from nibabel import Nifti1Image from nilearn.image import new_img_like from nilearn.maskers import NiftiMasker +from nilearn.surface import SurfaceImage from fmralign.pairwise_alignment import PairwiseAlignment from fmralign.tests.utils import ( @@ -140,4 +141,17 @@ def test_surface_alignment(): img1 = surf_img(20) img2 = surf_img(20) alignment = PairwiseAlignment(n_pieces=n_pieces) + + # Test fitting alignment.fit(img1, img2) + + # Test transformation + img_transformed = alignment.transform(img1) + assert img_transformed.shape == img1.shape + assert isinstance(img_transformed, SurfaceImage) + + # Test parcellation retrieval + labels, parcellation_image = alignment.get_parcellation() + assert isinstance(labels, np.ndarray) + assert len(np.unique(labels)) == n_pieces + assert isinstance(parcellation_image, SurfaceImage)