diff --git a/examples/plot_template_alignment.py b/examples/plot_template_alignment.py index 21225ec..b1e6154 100644 --- a/examples/plot_template_alignment.py +++ b/examples/plot_template_alignment.py @@ -1,13 +1,14 @@ # -*- coding: utf-8 -*- - """ Template-based prediction. ========================== -In this tutorial, we show how to better predict new contrasts for a target -subject using many source subjects corresponding contrasts. For this purpose, -we create a template to which we align the target subject, using shared information. -We then predict new images for the target and compare them to a baseline. +In this tutorial, we show how to improve inter-subject similarity using a template +computed across multiple source subjects. For this purpose, we create a template +using Procrustes alignment (hyperalignment) to which we align the target subject, +using shared information. We then compare the voxelwise similarity between the +target subject and the template to the similarity between the target subject and +the anatomical Euclidean average of the source subjects. We mostly rely on Python common packages and on nilearn to handle functional data in a clean fashion. @@ -36,7 +37,7 @@ ) ############################################################################### -# Definine a masker +# Define a masker # ----------------- # We define a nilearn masker that will be used to handle relevant data. # For more information, visit : @@ -64,22 +65,17 @@ template_train = [] for i in range(5): template_train.append(concat_imgs(imgs[i])) -target_train = df[df.subject == "sub-07"][df.acquisition == "ap"].path.values -# For subject sub-07, we split it in two folds: -# - target train: sub-07 AP contrasts, used to learn alignment to template -# - target test: sub-07 PA contrasts, used as a ground truth to score predictions -# We make a single 4D Niimg from our list of 3D filenames +# sub-07 (at index 5, i.e. the 6th entry in the list) will be our left-out subject. 
+# We make a single 4D Niimg from our list of 3D filenames. -target_train = concat_imgs(target_train) -target_test = df[df.subject == "sub-07"][df.acquisition == "pa"].path.values +left_out_subject = concat_imgs(imgs[5]) ############################################################################### # Compute a baseline (average of subjects) # ---------------------------------------- # We create an image with as many contrasts as any subject representing for # each contrast the average of all train subjects maps. -# import numpy as np @@ -92,70 +88,53 @@ # --------------------------------------------- # We define an estimator using the class TemplateAlignment: # * We align the whole brain through 'multiple' local alignments. -# * These alignments are calculated on a parcellation of the brain in 150 pieces, +# * These alignments are calculated on a parcellation of the brain in 50 pieces, # this parcellation creates group of functionnally similar voxels. # * The template is created iteratively, aligning all subjects data into a # common space, from which the template is inferred and aligning again to this # new template space. # -from nilearn.image import index_img - from fmralign.template_alignment import TemplateAlignment +# We use Procrustes/scaled orthogonal alignment method template_estim = TemplateAlignment( - n_pieces=150, alignment_method="ridge_cv", mask=masker + n_pieces=50, + alignment_method="scaled_orthogonal", + mask=masker, ) template_estim.fit(template_train) +procrustes_template = template_estim.template ############################################################################### # Predict new data for left-out subject # ------------------------------------- -# We use target_train data to fit the transform, indicating it corresponds to -# the contrasts indexed by train_index and predict from this learnt alignment -# contrasts corresponding to template test_index numbers. 
-# For each train subject and for the template, the AP contrasts are sorted from -# 0, to 53, and then the PA contrasts from 53 to 106. -# - -train_index = range(53) -test_index = range(53, 106) - -# We input the mapping image target_train in a list, we could have input more -# than one subject for which we'd want to predict : [train_1, train_2 ...] +# We predict the contrasts of the left-out subject using the template we just +# created. We use the transform method of the estimator. This method takes the +# left-out subject as input, computes a pairwise alignment with the template +# and returns the aligned data. -prediction_from_template = template_estim.transform( - [target_train], train_index, test_index -) - -# As a baseline prediction, let's just take the average of activations across subjects. - -prediction_from_average = index_img(average_subject, test_index) +predictions_from_template = template_estim.transform(left_out_subject) ############################################################################### # Score the baseline and the prediction # ------------------------------------- # We use a utility scoring function to measure the voxelwise correlation -# between the prediction and the ground truth. That is, for each voxel, we -# measure the correlation between its profile of activation without and with -# alignment, to see if alignment was able to predict a signal more alike the ground truth. -# +# between the images. That is, for each voxel, we measure the correlation between +# its profile of activation without and with alignment, to see if template-based +# alignment was able to improve inter-subject similarity. 
from fmralign.metrics import score_voxelwise -# Now we use this scoring function to compare the correlation of predictions -# made from group average and from template with the real PA contrasts of sub-07 - average_score = masker.inverse_transform( - score_voxelwise(target_test, prediction_from_average, masker, loss="corr") + score_voxelwise(left_out_subject, average_subject, masker, loss="corr") ) template_score = masker.inverse_transform( score_voxelwise( - target_test, prediction_from_template[0], masker, loss="corr" + predictions_from_template, procrustes_template, masker, loss="corr" ) ) - ############################################################################### # Plotting the measures # --------------------- @@ -167,13 +146,12 @@ baseline_display = plotting.plot_stat_map( average_score, display_mode="z", vmax=1, cut_coords=[-15, -5] ) -baseline_display.title("Group average correlation wt ground truth") +baseline_display.title("Left-out subject correlation with group average") display = plotting.plot_stat_map( template_score, display_mode="z", cut_coords=[-15, -5], vmax=1 ) -display.title("Template-based prediction correlation wt ground truth") +display.title("Aligned subject correlation with Procrustes template") ############################################################################### # We observe that creating a template and aligning a new subject to it yields -# a prediction that is better correlated with the ground truth than just using -# the average activations of subjects. +# better inter-subject similarity than regular euclidean averaging. diff --git a/fmralign/pairwise_alignment.py b/fmralign/pairwise_alignment.py index d05c803..4a7495c 100644 --- a/fmralign/pairwise_alignment.py +++ b/fmralign/pairwise_alignment.py @@ -237,12 +237,12 @@ def fit(self, X, Y): return self - def transform(self, X): + def transform(self, img): """Predict data from X. 
Parameters ---------- - X: Niimg-like object + img: Niimg-like object Source data Returns @@ -255,7 +255,7 @@ def transform(self, X): "This instance has not been fitted yet. " "Please call 'fit' before 'transform'." ) - parceled_data_list = self.parcel_masker.transform(X) + parceled_data_list = self.parcel_masker.transform(img) transformed_img = Parallel( self.n_jobs, prefer="threads", verbose=self.verbose )( @@ -266,7 +266,6 @@ def transform(self, X): return transformed_img[0] else: return transformed_img - return transformed_img # Make inherited function harmless def fit_transform(self): @@ -291,7 +290,7 @@ def get_parcellation(self): if hasattr(self, "parcel_masker"): check_is_fitted(self) labels = self.parcel_masker.get_labels() - parcellation_img = self.parcel_masker.get_parcellation() + parcellation_img = self.parcel_masker.get_parcellation_img() return labels, parcellation_img else: raise AttributeError( diff --git a/fmralign/preprocessing.py b/fmralign/preprocessing.py index 0cc4a0a..d271b51 100644 --- a/fmralign/preprocessing.py +++ b/fmralign/preprocessing.py @@ -194,7 +194,7 @@ def get_labels(self): ) return self.labels - def get_parcellation(self): + def get_parcellation_img(self): """Return the parcellation image. 
Returns diff --git a/fmralign/template_alignment.py b/fmralign/template_alignment.py index f07779e..779a789 100644 --- a/fmralign/template_alignment.py +++ b/fmralign/template_alignment.py @@ -8,221 +8,196 @@ import numpy as np from joblib import Memory, Parallel, delayed -from nilearn._utils.masker_validation import check_embedded_masker -from nilearn.image import concat_imgs, index_img, load_img from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.utils.validation import check_is_fitted -from fmralign.pairwise_alignment import PairwiseAlignment +from fmralign._utils import _parcels_to_array, _transform_one_img +from fmralign.pairwise_alignment import PairwiseAlignment, fit_one_piece +from fmralign.preprocessing import ParcellationMasker -def _rescaled_euclidean_mean(imgs, masker, scale_average=False): +def _rescaled_euclidean_mean(subjects_data, scale_average=False): """ - Make the Euclidian average of images. + Make the Euclidian average of `numpy.ndarray`. Parameters ---------- - imgs: list of Niimgs - Each img is 3D by default, but can also be 4D. - masker: instance of NiftiMasker or MultiNiftiMasker - Masker to be used on the data. + subjects_data: `list` of `numpy.ndarray` + Each element of the list is the data for one subject. scale_average: boolean - If true, the returned average is scaled to have the average norm of imgs - If false, it will usually have a smaller norm than initial average - because noise will cancel across images + If true, average is rescaled so that it keeps the same norm as the + average of training images. 
Returns ------- - average_img: Niimg + average_data: ndarray Average of imgs, with same shape as one img """ - masked_imgs = [masker.transform(img) for img in imgs] - average_img = np.mean(masked_imgs, axis=0) + average_data = np.mean(subjects_data, axis=0) scale = 1 if scale_average: X_norm = 0 - for img in masked_imgs: - X_norm += np.linalg.norm(img) - X_norm /= len(masked_imgs) - scale = X_norm / np.linalg.norm(average_img) - average_img *= scale + for data in subjects_data: + X_norm += np.linalg.norm(data) + X_norm /= len(subjects_data) + scale = X_norm / np.linalg.norm(average_data) + average_data *= scale + + return average_data + + +def _reconstruct_template(fit, labels, masker): + """ + Reconstruct template from fit output. + + Parameters + ---------- + fit: list of list of numpy.ndarray + Each element of the list is the list of parcels data for one subject. + labels: numpy.ndarray + Labels of the parcels. + masker: instance of NiftiMasker or MultiNiftiMasker + Masker to be used on the data. - return masker.inverse_transform(average_img) + Returns + ------- + template_img: 4D Niimg object + Models the barycenter of input imgs + template_history: list of 4D Niimgs + List of the intermediate templates computed at the end of each iteration + """ + template_parcels = [fit_i["template_data"] for fit_i in fit] + template_data = _parcels_to_array(template_parcels, labels) + template_img = masker.inverse_transform(template_data) + + n_iter = len(fit[0]["template_history"]) + template_history = [] + for i in range(n_iter): + template_parcels = [fit_j["template_history"][i] for fit_j in fit] + template_data = _parcels_to_array(template_parcels, labels) + template_history.append(masker.inverse_transform(template_data)) + + return template_img, template_history def _align_images_to_template( - imgs, + subjects_data, template, alignment_method, - n_pieces, - clustering, - masker, - memory, - memory_level, - n_jobs, - verbose, ): """ Convenience function. 
- For a list of images, return the list of estimators (PairwiseAlignment instances) + For a list of ndarrays, return the list of alignment estimators aligning each of them to a common target, the template. - All arguments are used in PairwiseAlignment. + + Parameters + ---------- + subjects_data: `list` of `numpy.ndarray` + Each element of the list is the data for one subject. + template: `numpy.ndarray` + The target data. + alignment_method: string + Algorithm used to perform alignment between sources and template. + + Returns + ------- + aligned_data: `list` of `numpy.ndarray` + List of aligned data. + piecewise_estimators: `list` of `PairwiseAlignment` + List of `Alignment` estimators. """ - aligned_imgs = [] - for img in imgs: - piecewise_estimator = PairwiseAlignment( - n_pieces=n_pieces, - alignment_method=alignment_method, - clustering=clustering, - mask=masker, - memory=memory, - memory_level=memory_level, - n_jobs=n_jobs, - verbose=verbose, + aligned_data = [] + piecewise_estimators = [] + for subject_data in subjects_data: + piecewise_estimator = fit_one_piece( + subject_data, + template, + alignment_method, ) - piecewise_estimator.fit(img, template) - aligned_imgs.append(piecewise_estimator.transform(img)) - return aligned_imgs + piecewise_estimator.fit(subject_data, template) + piecewise_estimators.append(piecewise_estimator) + aligned_data.append(piecewise_estimator.transform(subject_data)) + return aligned_data, piecewise_estimators -def _create_template( - imgs, - n_iter, - scale_template, - alignment_method, - n_pieces, - clustering, - masker, - memory, - memory_level, - n_jobs, - verbose, +def _fit_local_template( + subjects_data, + n_iter=2, + scale_template=False, + alignment_method="identity", ): """ Create template through alternate minimization. 
Compute iteratively : - * T minimizing sum(||R_i X_i-T||) which is the mean of aligned images (RX_i) + * T minimizing sum(||R X-T||) which is the mean of aligned images (RX) * align initial images to new template T - (find transform R_i minimizing ||R_i X_i-T|| for each img X_i) + (find transform R minimizing ||R X-T|| for each img X) Parameters ---------- - imgs: List of Niimg-like objects - See http://nilearn.github.io/manipulating_images/input_output.html - source data. Every img must have the same length (n_sample) - scale_template: boolean - If true, template is rescaled after each inference so that it keeps - the same norm as the average of training images. + subjects_data: `list` of `numpy.ndarray` + Each element of the list is the data for one subject. n_iter: int Number of iterations in the alternate minimization. Each image is aligned n_iter times to the evolving template. If n_iter = 0, the template is simply the mean of the input images. - All other arguments are the same are passed to PairwiseAlignment + scale_template: boolean + If true, template is rescaled after each inference so that it keeps + the same norm as the average of training images. + alignment_method: string + Algorithm used to perform alignment between sources and template. Returns ------- - template: list of 3D Niimgs of length (n_sample) - Models the barycenter of input imgs - template_history: list of list of 3D Niimgs - List of the intermediate templates computed at the end of each iteration + template_data: `numpy.ndarray` + Template data. + template_history: `list` of `numpy.ndarray` + List of the intermediate templates computed at the end of each iteration. + piecewise_estimators: `list` of `PairwiseAlignment` + List of `Alignment` estimators. 
""" - aligned_imgs = imgs + aligned_data = subjects_data template_history = [] for iter in range(n_iter): - template = _rescaled_euclidean_mean( - aligned_imgs, masker, scale_template - ) + template = _rescaled_euclidean_mean(aligned_data, scale_template) if 0 < iter < n_iter - 1: template_history.append(template) - aligned_imgs = _align_images_to_template( - imgs, + aligned_data, subjects_estimators = _align_images_to_template( + subjects_data, template, alignment_method, - n_pieces, - clustering, - masker, - memory, - memory_level, - n_jobs, - verbose, ) - return template, template_history + return { + "template_data": template, + "template_history": template_history, + "estimators": subjects_estimators, + } -def _map_template_to_image( - imgs, - train_index, - template, - alignment_method, - n_pieces, - clustering, - masker, - memory, - memory_level, - n_jobs, - verbose, -): +def _index_by_parcel(subjects_parcels): """ - Learn alignment operator from the template toward new images. + Index data by parcel. Parameters ---------- - imgs: list of 3D Niimgs - Target images to learn mapping from the template to a new subject - train_index: list of int - Matching index between imgs and the corresponding template images to use - to learn alignment. len(train_index) must be equal to len(imgs) - template: list of 3D Niimgs - Learnt in a first step now used as source image - All other arguments are the same are passed to PairwiseAlignment - + subjects_parcels: list of list of numpy.ndarray + Each element of the list is the list of parcels + data for one subject. Returns ------- - mapping: instance of PairwiseAlignment class - Alignment estimator fitted to align the template with the input images + list of list of numpy.ndarray + Each element of the list is the list of subjects + data for one parcel. 
""" - - mapping_image = index_img(template, train_index) - mapping = PairwiseAlignment( - n_pieces=n_pieces, - alignment_method=alignment_method, - clustering=clustering, - mask=masker, - memory=memory, - memory_level=memory_level, - n_jobs=n_jobs, - verbose=verbose, - ) - mapping.fit(mapping_image, imgs) - return mapping - - -def _predict_from_template_and_mapping(template, test_index, mapping): - """ - From a template and an alignment estimator, predict new contrasts. - - Parameters - ---------- - template: list of 3D Niimgs - Learnt in a first step now used to predict some new data - test_index: - Index of the images not used to learn the alignment mapping and so - predictable without overfitting - mapping: instance of PairwiseAlignment class - Alignment estimator that must have been fitted already - - Returns - ------- - transformed_image: list of Niimgs - Prediction corresponding to each template image with index in test_index - once realigned to the new subjects - """ - image_to_transform = index_img(template, test_index) - transformed_image = mapping.transform(image_to_transform) - return transformed_image + n_pieces = subjects_parcels[0].n_pieces + return [ + [subject_parcels[i] for subject_parcels in subjects_parcels] + for i in range(n_pieces) + ] class TemplateAlignment(BaseEstimator, TransformerMixin): @@ -369,130 +344,113 @@ def fit(self, imgs): """ - # Check if the input is a list, if list of lists, concatenate each subjects - # data into one unique image. - if not isinstance(imgs, (list, np.ndarray)) or len(imgs) < 2: - raise ValueError( - "The method TemplateAlignment.fit() need a list input. " - "Each element of the list (Niimg-like or list of Niimgs) " - "is the data for one subject." 
- ) - else: - if isinstance(imgs[0], (list, np.ndarray)): - imgs = [concat_imgs(img) for img in imgs] + self.parcel_masker = ParcellationMasker( + n_pieces=self.n_pieces, + clustering=self.clustering, + mask=self.mask, + smoothing_fwhm=self.smoothing_fwhm, + standardize=self.standardize, + detrend=self.detrend, + low_pass=self.low_pass, + high_pass=self.high_pass, + t_r=self.t_r, + target_affine=self.target_affine, + target_shape=self.target_shape, + memory=self.memory, + memory_level=self.memory_level, + n_jobs=self.n_jobs, + verbose=self.verbose, + ) - self.masker_ = check_embedded_masker(self) - self.masker_.n_jobs = self.n_jobs # self.n_jobs + subjects_parcels = self.parcel_masker.fit_transform(imgs) + parcels_data = _index_by_parcel(subjects_parcels) + self.masker = self.parcel_masker.masker_ + self.mask = self.parcel_masker.masker_.mask_img_ + self.labels_ = self.parcel_masker.labels + self.n_pieces = self.parcel_masker.n_pieces - # if masker_ has been provided a mask_img - if self.masker_.mask_img is None: - self.masker_.fit(imgs) - else: - self.masker_.fit() - - self.template, self.template_history = _create_template( - imgs, - self.n_iter, - self.scale_template, - self.alignment_method, - self.n_pieces, - self.clustering, - self.masker_, - self.memory, - self.memory_level, - self.n_jobs, - self.verbose, + self.fit_ = Parallel( + self.n_jobs, prefer="threads", verbose=self.verbose + )( + delayed(_fit_local_template)( + parcel_i, + self.n_iter, + self.scale_template, + self.alignment_method, + ) + for parcel_i in parcels_data + ) + + self.template, self.template_history = _reconstruct_template( + self.fit_, self.labels_, self.masker ) if self.save_template is not None: self.template.to_filename(self.save_template) - def transform(self, imgs, train_index, test_index): + def transform(self, img, subject_index=None): """ - Learn alignment between new subject and template calculated during fit, - then predict other conditions for this new subject. 
- Alignment is learnt between imgs and conditions in the template indexed by train_index. - Prediction correspond to conditions in the template index by test_index. + Transform a (new) subject image into the template space. Parameters ---------- - imgs: List of 3D Niimg-like objects - Target subjects known data. - Every img must have length (number of sample) train_index. - train_index: list of ints - Indexes of the 3D samples used to map each img to the template. - Every index should be smaller than the number of images in the template. - test_index: list of ints - Indexes of the 3D samples to predict from the template and the mapping. - Every index should be smaller than the number of images in the template. + img: 4D Niimg-like object + Subject image. + subject_index: int, optional (default = None) + Index of the subject to be transformed. It should + correspond to the index of the subject in the list of + subjects used to fit the template. If None, a new + `PairwiseAlignment` object is fitted between the new + subject and the template before transforming. Returns ------- - predicted_imgs: List of 3D Niimg-like objects - Target subjects predicted data. - Each Niimg has the same length as the list test_index + predicted_imgs: 4D Niimg object + Transformed data. """ - - if not isinstance(imgs, (list, np.ndarray)): - raise ValueError( - "The method TemplateAlignment.transform() need a list input. " - "Each element of the list (Niimg-like or list of Niimgs) " - "is the data used to align one new subject with images " - "indexed by train_index." - ) - else: - if isinstance(imgs[0], (list, np.ndarray)) and len(imgs[0]) != len( - train_index - ): - raise ValueError( - "Each element of imgs (Niimg-like or list of Niimgs) " - "should have the same length as the length of train_index." 
- ) - elif load_img(imgs[0]).shape[-1] != len(train_index): - raise ValueError( - "Each element of imgs (Niimg-like or list of Niimgs) " - "should have the same length as the length of train_index." - ) - - template_length = self.template.shape[-1] - if not ( - all(i < template_length for i in test_index) - and all(i < template_length for i in train_index) - ): + if not hasattr(self, "fit_"): raise ValueError( - f"Template has {template_length} images but you provided a " - "greater index in train_index or test_index." + "This instance has not been fitted yet. " + "Please call 'fit' before 'transform'." ) - fitted_mappings = Parallel( - self.n_jobs, prefer="threads", verbose=self.verbose - )( - delayed(_map_template_to_image)( - img, - train_index, - self.template, - self.alignment_method, - self.n_pieces, - self.clustering, - self.masker_, - self.memory, - self.memory_level, - self.n_jobs, - self.verbose, + if subject_index is None: + alignment_estimator = PairwiseAlignment( + n_pieces=self.n_pieces, + alignment_method=self.alignment_method, + clustering=self.parcel_masker.get_parcellation_img(), + mask=self.masker, + smoothing_fwhm=self.smoothing_fwhm, + standardize=self.standardize, + detrend=self.detrend, + target_affine=self.target_affine, + target_shape=self.target_shape, + low_pass=self.low_pass, + high_pass=self.high_pass, + t_r=self.t_r, + memory=self.memory, + memory_level=self.memory_level, + n_jobs=self.n_jobs, + verbose=self.verbose, ) - for img in imgs - ) - - predicted_imgs = Parallel( - self.n_jobs, prefer="threads", verbose=self.verbose - )( - delayed(_predict_from_template_and_mapping)( - self.template, test_index, mapping + alignment_estimator.fit(img, self.template) + return alignment_estimator.transform(img) + else: + parceled_data_list = self.parcel_masker.transform(img) + subject_estimators = [ + fit_i["estimators"][subject_index] for fit_i in self.fit_ + ] + transformed_img = Parallel( + self.n_jobs, prefer="threads", verbose=self.verbose + 
)( + delayed(_transform_one_img)(parceled_data, subject_estimators) + for parceled_data in parceled_data_list ) - for mapping in fitted_mappings - ) - return predicted_imgs + if len(transformed_img) == 1: + return transformed_img[0] + else: + return transformed_img # Make inherited function harmless def fit_transform(self): @@ -500,3 +458,26 @@ def fit_transform(self): raise AttributeError( "type object 'PairwiseAlignment' has no attribute 'fit_transform'" ) + + def get_parcellation(self): + """Get the parcellation masker used for alignment. + + Returns + ------- + labels: `list` of `int` + Labels of the parcellation masker. + parcellation_img: Niimg-like object + Parcellation image. + """ + if hasattr(self, "parcel_masker"): + check_is_fitted(self) + labels = self.parcel_masker.get_labels() + parcellation_img = self.parcel_masker.get_parcellation_img() + return labels, parcellation_img + else: + raise AttributeError( + ( + "Parcellation has not been computed yet," + "please fit the alignment estimator first." 
+ ) + ) diff --git a/fmralign/tests/test_preprocessing.py b/fmralign/tests/test_preprocessing.py index 970c008..0efa96f 100644 --- a/fmralign/tests/test_preprocessing.py +++ b/fmralign/tests/test_preprocessing.py @@ -190,13 +190,13 @@ def test_one_contrast(): pmasker.fit([img1, img2]) -def test_get_parcellation(): +def test_get_parcellation_img(): """Test that ParcellationMasker returns the parcellation mask""" n_pieces = 2 img, _ = random_niimg((8, 7, 6)) parcel_masker = ParcellationMasker(n_pieces=n_pieces) parcel_masker.fit(img) - parcellation_img = parcel_masker.get_parcellation() + parcellation_img = parcel_masker.get_parcellation_img() labels = parcel_masker.get_labels() assert isinstance(parcellation_img, Nifti1Image) @@ -207,4 +207,3 @@ def test_get_parcellation(): assert np.allclose(data, labels) assert len(np.unique(data)) == n_pieces - diff --git a/fmralign/tests/test_template_alignment.py b/fmralign/tests/test_template_alignment.py index 245eddd..7b914ec 100644 --- a/fmralign/tests/test_template_alignment.py +++ b/fmralign/tests/test_template_alignment.py @@ -1,19 +1,107 @@ -# -*- coding: utf-8 -*- +import numpy as np import pytest -from nilearn.image import concat_imgs, index_img, math_img +from nibabel import Nifti1Image +from nilearn.image import concat_imgs, math_img from nilearn.maskers import NiftiMasker from numpy.testing import assert_array_almost_equal +from fmralign._utils import ParceledData +from fmralign.preprocessing import ParcellationMasker from fmralign.template_alignment import ( TemplateAlignment, + _align_images_to_template, + _fit_local_template, + _index_by_parcel, + _reconstruct_template, _rescaled_euclidean_mean, ) from fmralign.tests.utils import ( random_niimg, + sample_parceled_data, + sample_subjects_data, zero_mean_coefficient_determination, ) +@pytest.mark.parametrize("scale_average", [True, False]) +def test_rescaled_euclidean_mean(scale_average): + subjects_data = sample_subjects_data() + average_data = 
_rescaled_euclidean_mean(subjects_data) + assert average_data.shape == subjects_data[0].shape + assert average_data.dtype == subjects_data[0].dtype + + if scale_average is False: + assert np.allclose(average_data, np.mean(subjects_data, axis=0)) + + +def test_reconstruct_template(): + n_subjects = 3 + n_iter = 3 + n_pieces = 2 + imgs = [random_niimg((8, 7, 6, 20))[0]] * n_subjects + parcel_masker = ParcellationMasker(n_pieces=n_pieces) + subjects_parcels = parcel_masker.fit_transform(imgs) + parcels_subjects = _index_by_parcel(subjects_parcels) + masker = parcel_masker.masker_ + labels = parcel_masker.labels + + fit = [ + _fit_local_template(parcel_i, n_iter=n_iter) + for parcel_i in parcels_subjects + ] + template, template_history = _reconstruct_template(fit, labels, masker) + + assert template.shape == imgs[0].shape + assert len(template_history) == n_iter - 2 + for template_i in template_history: + assert template_i.shape == imgs[0].shape + + +def test_align_images_to_template(): + subjects_data = sample_subjects_data() + template = _rescaled_euclidean_mean(subjects_data) + aligned_data, subjects_estimators = _align_images_to_template( + subjects_data, + template, + alignment_method="identity", + ) + assert len(aligned_data) == len(subjects_data) + assert len(subjects_estimators) == len(subjects_data) + assert aligned_data[0].shape == subjects_data[0].shape + + +def test_fit_local_template(): + n_subjects = 3 + n_iter = 3 + subjects_data = sample_subjects_data(n_subjects=n_subjects) + fit = _fit_local_template( + subjects_data, + n_iter=n_iter, + alignment_method="identity", + scale_template=False, + ) + template_data = fit["template_data"] + template_history = fit["template_history"] + estimators = fit["estimators"] + + assert template_data.shape == subjects_data[0].shape + assert len(template_history) == n_iter - 2 + assert len(estimators) == n_subjects + + +def test_index_by_parcel(): + n_subjects = 3 + n_pieces = 2 + subjects_parcels = [ + 
ParceledData(*sample_parceled_data(n_pieces)) + for _ in range(n_subjects) + ] + parcels_subjects = _index_by_parcel(subjects_parcels) + assert len(parcels_subjects) == n_pieces + assert len(parcels_subjects[0]) == n_subjects + assert parcels_subjects[0][0].shape == subjects_parcels[0][0].shape + + def test_template_identity(): n = 10 im, mask_img = random_niimg((6, 5, 3)) @@ -28,15 +116,9 @@ def test_template_identity(): subs = [sub_1, sub_2, sub_3] - # test euclidian mean function - euclidian_template = _rescaled_euclidean_mean(subs, masker) - assert_array_almost_equal( - ref_template.get_fdata(), euclidian_template.get_fdata() - ) - # test different fit() accept list of list of 3D Niimgs as input. algo = TemplateAlignment(alignment_method="identity", mask=masker) - algo.fit([n * [im]] * 3) + algo.fit([concat_imgs(n * [im])] * 3) # test template assert_array_almost_equal(sub_1.get_fdata(), algo.template.get_fdata()) @@ -54,51 +136,53 @@ def test_template_identity(): for args in args_list: algo = TemplateAlignment(**args) - # Learning a template which is algo.fit(subs) # test template assert_array_almost_equal( - ref_template.get_fdata(), algo.template.get_fdata() + ref_template.get_fdata(), + algo.template.get_fdata(), ) - predicted_imgs = algo.transform( - [index_img(sub_1, range(8))], - train_index=range(8), - test_index=range(8, 10), + predicted_imgs = algo.transform(ref_template) + assert_array_almost_equal( + predicted_imgs.get_fdata(), + ref_template.get_fdata(), ) - ground_truth = index_img(ref_template, range(8, 10)) + predicted_imgs = algo.transform(ref_template, subject_index=1) assert_array_almost_equal( - ground_truth.get_fdata(), predicted_imgs[0].get_fdata() + predicted_imgs.get_fdata(), + ref_template.get_fdata(), ) - # test transform() with wrong indexes length or content (on previous fitted algo) - train_inds, test_inds = ( - [[0, 1], [1, 10], [4, 11], [0, 1, 2]], - [ - [6, 8, 29], - [4, 6], - [4, 11], - [4, 5], - ], + +def 
test_template_diagonal(): + n = 10 + im, mask_img = random_niimg((6, 5, 3)) + + sub_1 = concat_imgs(n * [im]) + sub_2 = math_img("2 * img", img=sub_1) + sub_3 = math_img("3 * img", img=sub_1) + + ref_template = sub_2 + masker = NiftiMasker(mask_img=mask_img) + masker.fit() + + subs = [sub_1, sub_2, sub_3] + + # Test without subject_index + algo = TemplateAlignment(alignment_method="diagonal", mask=masker) + algo.fit(subs) + predicted_imgs = algo.transform(sub_1, subject_index=None) + assert_array_almost_equal( + ref_template.get_fdata(), + predicted_imgs.get_fdata(), ) - for train_ind, test_ind in zip(train_inds, test_inds): - with pytest.raises(Exception): - assert algo.transform( - [index_img(sub_1, range(2))], - train_index=train_ind, - test_index=test_ind, - ) - - # test wrong images input in fit() and transform method - with pytest.raises(Exception): - assert algo.transform( - [n * [im]] * 2, - train_index=train_inds[-1], - test_index=test_inds[-1], - ) - assert algo.fit([im]) - assert algo.transform( - [im], train_index=train_inds[-1], test_index=test_inds[-1] + # Test with subject_index + for i, sub in enumerate(subs): + predicted_imgs = algo.transform(sub, subject_index=i) + assert_array_almost_equal( + predicted_imgs.get_fdata(), + ref_template.get_fdata(), ) @@ -116,8 +200,7 @@ def test_template_closer_to_target(): sub_1 = masker.transform(subject_1) sub_2 = masker.transform(subject_2) subs = [subject_1, subject_2] - average_img = _rescaled_euclidean_mean(subs, masker) - avg_data = masker.transform(average_img) + avg_data = np.mean([sub_1, sub_2], axis=0) mean_distance_1 = zero_mean_coefficient_determination(sub_1, avg_data) mean_distance_2 = zero_mean_coefficient_determination(sub_2, avg_data) @@ -129,16 +212,51 @@ def test_template_closer_to_target(): "diagonal", ]: algo = TemplateAlignment( - alignment_method=alignment_method, n_pieces=3, mask=masker + alignment_method=alignment_method, + n_pieces=3, + mask=masker, ) # Learn template algo.fit(subs) # 
Assess template is closer to mean than both images template_data = masker.transform(algo.template) template_mean_distance = zero_mean_coefficient_determination( - avg_data, template_data + avg_data, + template_data, ) assert template_mean_distance >= mean_distance_1 assert ( template_mean_distance >= mean_distance_2 - 1.0e-2 ) # for robustness + + +def test_parcellation_retrieval(): + """Test that TemplateAlignment returns both the\n + labels and the parcellation image + """ + n_pieces = 3 + imgs = [random_niimg((8, 7, 6))[0]] * 3 + alignment = TemplateAlignment(n_pieces=n_pieces) + alignment.fit(imgs) + + labels, parcellation_image = alignment.get_parcellation() + assert isinstance(labels, np.ndarray) + assert len(np.unique(labels)) == n_pieces + assert isinstance(parcellation_image, Nifti1Image) + assert parcellation_image.shape == imgs[0].shape[:-1] + + +def test_parcellation_before_fit(): + """Test that TemplateAlignment raises an error if\n + the parcellation is retrieved before fitting + """ + alignment = TemplateAlignment() + with pytest.raises( + AttributeError, + match="Parcellation has not been computed yet", + ): + alignment.get_parcellation() + + +if __name__ == "__main__": + test_parcellation_retrieval() diff --git a/fmralign/tests/utils.py b/fmralign/tests/utils.py index 35133c6..61af947 100644 --- a/fmralign/tests/utils.py +++ b/fmralign/tests/utils.py @@ -132,3 +132,9 @@ def sample_parceled_data(n_pieces=1): data = masker.fit_transform(img) labels = _make_parcellation(img, "kmeans", n_pieces, masker) return data, masker, labels + + +def sample_subjects_data(n_subjects=3): + """Sample data in one parcel for n_subjects""" + subjects_data = [np.random.rand(10, 20) for _ in range(n_subjects)] + return subjects_data