Refactor template alignments to keep a fixed parcellation #116

Merged (24 commits) on Dec 18, 2024
Commits (24)
1749a3a
Add index_by_parcel
Dec 7, 2024
00751ab
Modify fit method
Dec 7, 2024
916fdfc
Replace _create_template by _fit_local_template
Dec 7, 2024
9f9fe29
Add template reconstruction + history
Dec 7, 2024
7c1cc25
Rename get_parcellation method to get_parcellation_img for clarity
Dec 7, 2024
22d2b6e
Rename transform parameter from X to img for clarity and update relat…
Dec 7, 2024
931511c
Rename test_get_parcellation to test_get_parcellation_img
Dec 7, 2024
5612e75
Refactor transform method in TemplateAlignment for clarity and functi…
Dec 7, 2024
169fd45
Remove unused _map_template_to_image function to streamline codebase
Dec 7, 2024
3bf5571
Remove unused _predict_from_template_and_mapping function
Dec 7, 2024
a88606c
Add tests for parcellation retrieval in TemplateAlignment
Dec 7, 2024
218fa6c
Refactor variable names in TemplateAlignment for clarity in parcel da…
Dec 8, 2024
bf33bbd
Rename index_by_parcel function to _index_by_parcel for clarity and u…
Dec 8, 2024
1d32ad7
Set default values for parameters in _fit_local_template function for…
Dec 8, 2024
885a05b
Add unit tests for template alignment functions and parcellation proc…
Dec 8, 2024
5883e48
Add function to sample subjects data for testing purposes
Dec 8, 2024
9411aee
Update documentation for image processing functions to clarify parame…
Dec 8, 2024
e55d982
Enhance tutorial on template-based prediction by improving clarity an…
Dec 9, 2024
a413c75
Merge branch 'main' into feat/one-parcellation-template
pbarbarant Dec 10, 2024
e22e7f8
Add test for ParcellationMasker to handle 3D and 4D images with one c…
Dec 3, 2024
2379b81
Rebase with main
Dec 10, 2024
78ff3b3
Fix shape assertion in parcellation retrieval test
Dec 10, 2024
25b39ee
Update examples/plot_template_alignment.py
pbarbarant Dec 17, 2024
2ba236d
Update examples/plot_template_alignment.py
pbarbarant Dec 17, 2024
80 changes: 29 additions & 51 deletions examples/plot_template_alignment.py
@@ -1,13 +1,14 @@
# -*- coding: utf-8 -*-

"""
Template-based prediction.
==========================

In this tutorial, we show how to better predict new contrasts for a target
subject using many source subjects corresponding contrasts. For this purpose,
we create a template to which we align the target subject, using shared information.
We then predict new images for the target and compare them to a baseline.
In this tutorial, we show how to improve inter-subject similarity using a template
computed across multiple source subjects. For this purpose, we create a template
using Procrustes alignment (hyperalignment) to which we align the target subject,
using shared information. We then compare the voxelwise similarity between the
target subject and the template to the similarity between the target subject and
the anatomical Euclidean average of the source subjects.

We mostly rely on common Python packages and on nilearn to handle
functional data in a clean fashion.
@@ -36,7 +37,7 @@
)

###############################################################################
# Definine a masker
# Define a masker
# -----------------
# We define a nilearn masker that will be used to handle relevant data.
# For more information, visit :
@@ -64,22 +65,17 @@
template_train = []
for i in range(5):
    template_train.append(concat_imgs(imgs[i]))
target_train = df[df.subject == "sub-07"][df.acquisition == "ap"].path.values

# For subject sub-07, we split it in two folds:
# - target train: sub-07 AP contrasts, used to learn alignment to template
# - target test: sub-07 PA contrasts, used as a ground truth to score predictions
# We make a single 4D Niimg from our list of 3D filenames
# sub-07 (that is 5th in the list) will be our left-out subject.
# We make a single 4D Niimg from our list of 3D filenames.

target_train = concat_imgs(target_train)
target_test = df[df.subject == "sub-07"][df.acquisition == "pa"].path.values
left_out_subject = concat_imgs(imgs[5])

###############################################################################
# Compute a baseline (average of subjects)
# ----------------------------------------
# We create an image with as many contrasts as any subject, representing for
# each contrast the average of all train subjects' maps.
#

import numpy as np

@@ -92,70 +88,53 @@
# ---------------------------------------------
# We define an estimator using the class TemplateAlignment:
# * We align the whole brain through 'multiple' local alignments.
# * These alignments are calculated on a parcellation of the brain in 150 pieces,
# * These alignments are calculated on a parcellation of the brain in 50 pieces,
#   this parcellation creates groups of functionally similar voxels.
# * The template is created iteratively, aligning all subjects data into a
# common space, from which the template is inferred and aligning again to this
# new template space.
#

from nilearn.image import index_img

from fmralign.template_alignment import TemplateAlignment

# We use Procrustes/scaled orthogonal alignment method
template_estim = TemplateAlignment(
    n_pieces=150, alignment_method="ridge_cv", mask=masker
    n_pieces=50,
    alignment_method="scaled_orthogonal",
    mask=masker,
)
template_estim.fit(template_train)
procrustes_template = template_estim.template

###############################################################################
# Predict new data for left-out subject
# -------------------------------------
# We use target_train data to fit the transform, indicating it corresponds to
# the contrasts indexed by train_index and predict from this learnt alignment
# contrasts corresponding to template test_index numbers.
# For each train subject and for the template, the AP contrasts are sorted from
# 0, to 53, and then the PA contrasts from 53 to 106.
#

train_index = range(53)
test_index = range(53, 106)

# We input the mapping image target_train in a list, we could have input more
# than one subject for which we'd want to predict : [train_1, train_2 ...]
# We predict the contrasts of the left-out subject using the template we just
# created. We use the transform method of the estimator. This method takes the
# left-out subject as input, computes a pairwise alignment with the template
# and returns the aligned data.

prediction_from_template = template_estim.transform(
    [target_train], train_index, test_index
)

# As a baseline prediction, let's just take the average of activations across subjects.

prediction_from_average = index_img(average_subject, test_index)
predictions_from_template = template_estim.transform(left_out_subject)

###############################################################################
# Score the baseline and the prediction
# -------------------------------------
# We use a utility scoring function to measure the voxelwise correlation
# between the prediction and the ground truth. That is, for each voxel, we
# measure the correlation between its profile of activation without and with
# alignment, to see if alignment was able to predict a signal more alike the ground truth.
#
# between the images. That is, for each voxel, we measure the correlation between
# its profile of activation without and with alignment, to see if template-based
# alignment was able to improve inter-subject similarity.

from fmralign.metrics import score_voxelwise

# Now we use this scoring function to compare the correlation of predictions
# made from group average and from template with the real PA contrasts of sub-07

average_score = masker.inverse_transform(
    score_voxelwise(target_test, prediction_from_average, masker, loss="corr")
    score_voxelwise(left_out_subject, average_subject, masker, loss="corr")
)
template_score = masker.inverse_transform(
    score_voxelwise(
        target_test, prediction_from_template[0], masker, loss="corr"
        predictions_from_template, procrustes_template, masker, loss="corr"
    )
)


###############################################################################
# Plotting the measures
# ---------------------
@@ -167,13 +146,12 @@
baseline_display = plotting.plot_stat_map(
    average_score, display_mode="z", vmax=1, cut_coords=[-15, -5]
)
baseline_display.title("Group average correlation wt ground truth")
baseline_display.title("Left-out subject correlation with group average")
display = plotting.plot_stat_map(
    template_score, display_mode="z", cut_coords=[-15, -5], vmax=1
)
display.title("Template-based prediction correlation wt ground truth")
display.title("Aligned subject correlation with Procrustes template")

###############################################################################
# We observe that creating a template and aligning a new subject to it yields
# a prediction that is better correlated with the ground truth than just using
# the average activations of subjects.
# better inter-subject similarity than regular Euclidean averaging.
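The `score_voxelwise(..., loss="corr")` calls in the example above compare two images voxel by voxel. A minimal numpy sketch of that idea (illustrative only, not fmralign's actual implementation): for each voxel, correlate its profile of activation across contrasts between the two images.

```python
import numpy as np


def voxelwise_corr(X, Y):
    """Correlate each voxel's activation profile across contrasts.

    X, Y: arrays of shape (n_contrasts, n_voxels), as returned by a
    nilearn masker's transform. Returns one correlation per voxel.
    """
    Xc = X - X.mean(axis=0)
    Yc = Y - Y.mean(axis=0)
    num = (Xc * Yc).sum(axis=0)
    denom = np.sqrt((Xc ** 2).sum(axis=0) * (Yc ** 2).sum(axis=0))
    return num / denom


rng = np.random.default_rng(0)
a = rng.standard_normal((106, 4))  # 106 contrasts, 4 voxels
b = a + 0.1 * rng.standard_normal((106, 4))  # a noisy version of a
scores = voxelwise_corr(a, b)
print(scores.shape)  # one correlation score per voxel: (4,)
```

A well-aligned subject should show scores closer to 1 against the template than against the plain group average.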
9 changes: 4 additions & 5 deletions fmralign/pairwise_alignment.py
@@ -237,12 +237,12 @@ def fit(self, X, Y):

return self

def transform(self, X):
def transform(self, img):
"""Predict data from X.

Parameters
----------
X: Niimg-like object
img: Niimg-like object
Source data

Returns
@@ -255,7 +255,7 @@ def transform(self, X):
"This instance has not been fitted yet. "
"Please call 'fit' before 'transform'."
)
parceled_data_list = self.parcel_masker.transform(X)
parceled_data_list = self.parcel_masker.transform(img)
transformed_img = Parallel(
self.n_jobs, prefer="threads", verbose=self.verbose
)(
@@ -266,7 +266,6 @@ def transform(self, X):
return transformed_img[0]
else:
return transformed_img

# Make inherited function harmless
def fit_transform(self):
@@ -291,7 +290,7 @@ def get_parcellation(self):
if hasattr(self, "parcel_masker"):
check_is_fitted(self)
labels = self.parcel_masker.get_labels()
parcellation_img = self.parcel_masker.get_parcellation()
parcellation_img = self.parcel_masker.get_parcellation_img()
return labels, parcellation_img
else:
raise AttributeError(
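The `transform` change in this diff keeps a convenience convention: a single input image yields a single output, a list of images yields a list. A toy sketch of that return pattern (`transform_like` is a hypothetical stand-in, not fmralign code, and the `"aligned(...)"` string stands in for the real alignment computation):

```python
def transform_like(imgs):
    # Mirror the convention shown in the diff: single input -> single
    # output, list input -> list of outputs.
    single = not isinstance(imgs, list)
    batch = [imgs] if single else imgs
    transformed = ["aligned({})".format(img) for img in batch]
    return transformed[0] if single else transformed


print(transform_like("img_a"))      # aligned(img_a)
print(transform_like(["a", "b"]))   # ['aligned(a)', 'aligned(b)']
```

Unwrapping the singleton at the end (rather than always returning a list) keeps the common one-subject call site simple.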
2 changes: 1 addition & 1 deletion fmralign/preprocessing.py
@@ -194,7 +194,7 @@ def get_labels(self):
)
return self.labels

def get_parcellation(self):
def get_parcellation_img(self):
"""Return the parcellation image.

Returns
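The core idea of this PR is to fit one fixed parcellation and index every subject's data by the same labels, so the local alignments stay comparable across subjects. A hedged numpy sketch of that indexing step (the real `_index_by_parcel` helper may differ in signature and detail):

```python
import numpy as np


def index_by_parcel(data, labels):
    """Split a (n_samples, n_voxels) array into per-parcel blocks.

    labels holds one parcel label per voxel; using the same label
    vector for all subjects keeps parcels aligned across subjects.
    """
    return {lab: data[:, labels == lab] for lab in np.unique(labels)}


labels = np.array([0, 0, 1, 1, 1])       # a fixed parcellation: 2 parcels
data = np.arange(10).reshape(2, 5)       # 2 samples, 5 voxels
parcels = index_by_parcel(data, labels)
print(parcels[0].shape, parcels[1].shape)  # (2, 2) (2, 3)
```

Each local alignment is then fit on one such block per subject, and the blocks are reassembled into the template afterwards.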