Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add surface support #106

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions fmralign/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,7 @@ def _make_parcellation(
)
err.args += (errmsg,)
raise err
labels = apply_mask_fmri(
parcellation.labels_img_, masker.mask_img_
).astype(int)
labels = masker.transform(parcellation.labels_img_)[0].astype(int)

if verbose > 0:
unique_labels, counts = np.unique(labels, return_counts=True)
Expand Down
65 changes: 39 additions & 26 deletions fmralign/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from nibabel.nifti1 import Nifti1Image
from nilearn._utils.masker_validation import check_embedded_masker
from nilearn.image import concat_imgs
from nilearn.maskers._utils import concatenate_surface_images
from nilearn.surface import SurfaceImage
from sklearn.base import BaseEstimator, TransformerMixin

from fmralign._utils import (
Expand Down Expand Up @@ -118,10 +120,7 @@ def __init__(

def _fit_masker(self, imgs):
"""Fit the masker on a single or multiple images."""
self.masker_ = check_embedded_masker(self)
self.masker_.n_jobs = self.n_jobs

if isinstance(imgs, Nifti1Image):
if isinstance(imgs, (Nifti1Image, SurfaceImage)):
imgs = [imgs]
# If images are 3D, add a fourth dimension
for i, img in enumerate(imgs):
Expand All @@ -136,32 +135,46 @@ def _fit_masker(self, imgs):
raise NotImplementedError(
"fmralign does not support images of different shapes."
)
if self.masker_.mask_img is None:
self.masker_.fit(imgs)
else:
self.masker_.fit()

if isinstance(self.clustering, Nifti1Image) or os.path.isfile(
self.clustering
):
# check that clustering provided fills the mask, if not, reduce the mask
if 0 in self.masker_.transform(self.clustering):
reduced_mask = _intersect_clustering_mask(
self.clustering, self.masker_.mask_img
)
self.mask = reduced_mask
self.masker_ = check_embedded_masker(self)
self.masker_.n_jobs = self.n_jobs

masker_type = (
"surface" if isinstance(imgs[0], SurfaceImage) else "multi_nii"
)
self.masker_ = check_embedded_masker(self, masker_type=masker_type)
self.masker_.n_jobs = self.n_jobs

# Fit the masker for volume data
if masker_type == "multi_nii":
if self.masker_.mask_img is None:
self.masker_.fit(imgs)
else:
self.masker_.fit()
warnings.warn(
"Mask used was bigger than clustering provided. "
+ "Its intersection with the clustering was used instead."
)

if isinstance(self.clustering, Nifti1Image) or os.path.isfile(
self.clustering
):
# check that clustering provided fills the mask, if not, reduce the mask
if 0 in self.masker_.transform(self.clustering):
reduced_mask = _intersect_clustering_mask(
self.clustering, self.masker_.mask_img
)
self.mask = reduced_mask
self.masker_ = check_embedded_masker(self)
self.masker_.n_jobs = self.n_jobs
self.masker_.fit()
warnings.warn(
"Mask used was bigger than clustering provided. "
+ "Its intersection with the clustering was used instead."
)
else:
self.masker_.fit(imgs)

def _one_parcellation(self, imgs):
"""Compute one parcellation for all images."""
if isinstance(imgs, list):
imgs = concat_imgs(imgs)
if isinstance(imgs[0], (Nifti1Image)):
imgs = concat_imgs(imgs)
else:
imgs = concatenate_surface_images(imgs)
self.labels = _make_parcellation(
imgs,
self.clustering,
Expand Down Expand Up @@ -234,7 +247,7 @@ def transform(self, imgs):
List of ParceledData objects containing the data and parcellation
information for each image.
"""
if isinstance(imgs, Nifti1Image):
if isinstance(imgs, (Nifti1Image, SurfaceImage)):
imgs = [imgs]

parceled_data = Parallel(n_jobs=self.n_jobs)(
Expand Down
25 changes: 25 additions & 0 deletions fmralign/tests/test_pairwise_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from nibabel import Nifti1Image
from nilearn.image import new_img_like
from nilearn.maskers import NiftiMasker
from nilearn.surface import SurfaceImage

from fmralign.pairwise_alignment import PairwiseAlignment
from fmralign.tests.utils import (
assert_algo_transform_almost_exactly,
random_niimg,
surf_img,
zero_mean_coefficient_determination,
)

Expand Down Expand Up @@ -130,3 +132,26 @@ def test_parcellation_before_fit():
AttributeError, match="Parcellation has not been computed yet"
):
alignment.get_parcellation()


def test_surface_alignment():
    """Test `PairwiseAlignment` compatibility with `SurfaceImage`.

    Checks that fitting, transforming, and parcellation retrieval all
    work end-to-end on surface data.
    """
    n_pieces = 3
    img1 = surf_img(20)
    img2 = surf_img(20)
    # Single construction; the earlier redundant default-constructed
    # instance was dead code and has been removed.
    alignment = PairwiseAlignment(n_pieces=n_pieces)

    # Test fitting
    alignment.fit(img1, img2)

    # Test transformation: output keeps the input's shape and type
    img_transformed = alignment.transform(img1)
    assert img_transformed.shape == img1.shape
    assert isinstance(img_transformed, SurfaceImage)

    # Test parcellation retrieval
    labels, parcellation_image = alignment.get_parcellation()
    assert isinstance(labels, np.ndarray)
    assert len(np.unique(labels)) == n_pieces
    assert isinstance(parcellation_image, SurfaceImage)
36 changes: 33 additions & 3 deletions fmralign/tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from fmralign._utils import ParceledData
from fmralign.preprocessing import ParcellationMasker
from fmralign.tests.utils import random_niimg
from fmralign.tests.utils import random_niimg, surf_img


def test_init_default_params():
Expand Down Expand Up @@ -181,13 +181,43 @@ def test_standardization():
assert np.abs(np.std(data_array) - 1.0) < 1e-5


def test_one_surface_image():
    """Test that ParcellationMasker can handle a single surface image."""
    img = surf_img(20)
    n_pieces = 2
    n_vertices_total = img.shape[0]
    parcel_masker = ParcellationMasker(n_pieces=n_pieces)
    # fit() returns self; rebind rather than keeping two names for the
    # same estimator (applies the reviewer's suggested change consistently).
    parcel_masker = parcel_masker.fit(img)

    assert hasattr(parcel_masker, "masker_")
    assert parcel_masker.labels is not None
    assert isinstance(parcel_masker.labels, np.ndarray)
    assert len(np.unique(parcel_masker.labels)) == n_pieces
    assert len(parcel_masker.labels) == n_vertices_total


def test_multiple_surface_images():
    """Check that ParcellationMasker accepts a list of surface images."""
    n_pieces = 2
    surface_imgs = [surf_img(20)] * 3
    expected_n_vertices = surface_imgs[0].shape[0]

    masker = ParcellationMasker(n_pieces=n_pieces)
    fitted = masker.fit(surface_imgs)

    # A fitted instance exposes the embedded masker and a label array
    # with one entry per vertex and exactly n_pieces distinct parcels.
    assert hasattr(fitted, "masker_")
    labels = fitted.labels
    assert labels is not None
    assert isinstance(labels, np.ndarray)
    assert len(np.unique(labels)) == n_pieces
    assert len(labels) == expected_n_vertices


def test_one_contrast():
    """Test that ParcellationMasker handles both 3D and 4D images
    in the case of one contrast.
    """
    img1, _ = random_niimg((8, 7, 6))
    img2, _ = random_niimg((8, 7, 6, 1))
    # Fit once on the mixed 3D/4D pair; the diff's duplicated
    # old (`pmasker`) and new (`parcel_masker`) bodies are collapsed
    # into the renamed version, and the stray literal "\n" escape is
    # removed from the docstring.
    parcel_masker = ParcellationMasker()
    parcel_masker.fit([img1, img2])


def test_get_parcellation_img():
Expand Down
55 changes: 50 additions & 5 deletions fmralign/tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import nibabel
import numpy as np
from nilearn.maskers import NiftiMasker
from nilearn.surface import InMemoryMesh, PolyMesh, SurfaceImage
from numpy.random import default_rng
from numpy.testing import assert_array_almost_equal

Expand Down Expand Up @@ -43,9 +44,7 @@ def zero_mean_coefficient_determination(
nonzero_numerator = numerator != 0
valid_score = nonzero_denominator & nonzero_numerator
output_scores = np.ones([y_true.shape[1]])
output_scores[valid_score] = 1 - (
numerator[valid_score] / denominator[valid_score]
)
output_scores[valid_score] = 1 - (numerator[valid_score] / denominator[valid_score])
output_scores[nonzero_numerator & ~nonzero_denominator] = 0

if multioutput == "raw_values":
Expand All @@ -56,8 +55,7 @@ def zero_mean_coefficient_determination(
avg_weights = None
elif multioutput == "variance_weighted":
avg_weights = (
weight
* (y_true - np.average(y_true, axis=0, weights=sample_weight)) ** 2
weight * (y_true - np.average(y_true, axis=0, weights=sample_weight)) ** 2
).sum(axis=0, dtype=np.float64)
# avoid fail on constant y or one-element arrays
if not np.any(nonzero_denominator):
Expand Down Expand Up @@ -134,7 +132,54 @@ def sample_parceled_data(n_pieces=1):
return data, masker, labels


def _make_mesh():
    """Build a small two-part sample mesh (9 vertices, 10 faces total).

    The "left" part is a tetrahedron (4 vertices, 4 faces); the "right"
    part is a pyramid (5 vertices, 6 faces) translated by +2 on every
    axis so the two parts do not overlap.
    """
    tetra_coords = np.asarray(
        [[0.0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]
    )
    tetra_faces = np.asarray(
        [[1, 0, 2], [0, 1, 3], [0, 3, 2], [1, 2, 3]]
    )
    pyramid_coords = (
        np.asarray(
            [[0.0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0], [0, 0, 1]]
        )
        + 2.0
    )
    pyramid_faces = np.asarray(
        [
            [0, 1, 4],
            [0, 3, 1],
            [1, 3, 2],
            [1, 2, 4],
            [2, 3, 4],
            [0, 4, 3],
        ]
    )
    hemispheres = {
        "left": InMemoryMesh(tetra_coords, tetra_faces),
        "right": InMemoryMesh(pyramid_coords, pyramid_faces),
    }
    return PolyMesh(**hemispheres)


def surf_img(n_samples=1):
    """Create a sample SurfaceImage on the sample mesh.

    Each mesh part receives deterministic data of shape
    (n_vertices, n_samples): consecutive values starting at 1.0, scaled
    by 10**part_index so the parts are easy to tell apart. n_samples
    defaults to 1.
    """
    mesh = _make_mesh()
    data = {}
    for part_index, (part_name, part_mesh) in enumerate(mesh.parts.items()):
        n_vertices = part_mesh.n_vertices
        # Same values as arange(...).reshape(...) + 1.0 in the original:
        # fill sample-major, then transpose to (n_vertices, n_samples).
        flat_values = np.arange(n_vertices * n_samples, dtype=float) + 1.0
        part_data = flat_values.reshape(n_samples, n_vertices).T
        data[part_name] = part_data * 10**part_index
    return SurfaceImage(mesh, data)


def sample_subjects_data(n_subjects=3):
    """Sample random data in one parcel for n_subjects.

    Parameters
    ----------
    n_subjects : int, default=3
        Number of subjects to generate data for.

    Returns
    -------
    list of numpy.ndarray
        One (10, 20) array of uniform [0, 1) samples per subject.
    """
    # Use the modern Generator API (the file already imports default_rng)
    # instead of the legacy np.random.rand.
    rng = default_rng()
    return [rng.random((10, 20)) for _ in range(n_subjects)]

Loading