Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add fugw interface #83

Merged
merged 22 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
709f59b
feat: Add fugw wrapper to alignment methods
May 6, 2024
7e0a345
feat: Add FugwAlignment to alignment methods tests
May 6, 2024
d32db7d
refactor: Update method parameter to use hyphen instead of underscore
pbarbarant May 6, 2024
b6e0dd1
feat: Add pytest parametrize to fugw_alignment test
pbarbarant May 6, 2024
a834316
chore: Update pyproject.toml with fugw>=0.1.1 dependency
pbarbarant May 6, 2024
dc9dc05
chore: Update pyproject.toml with torch dependency
pbarbarant May 6, 2024
5f65a83
refactor: Remove unused import of nilearn in test_alignment_methods.py
pbarbarant May 6, 2024
9676c1c
Remove unused import of nilearn in alignment_methods.py
pbarbarant May 6, 2024
4fb6903
Fix flake8 linting
pbarbarant May 6, 2024
0c00261
Transpose features for fugw
pbarbarant Jun 4, 2024
915a9e1
refactor: Update FugwAlignment.transform method parameter to default …
pbarbarant Jun 4, 2024
1e57eff
Fix fugw tests
pbarbarant Jun 4, 2024
8283025
refactor: Improve template alignment test robustness
pbarbarant Jun 17, 2024
83ef66e
chore: Update numpy dependency to version <2
pbarbarant Jun 23, 2024
2b0971d
refactor: Change args passing to fugw transform
pbarbarant Jun 23, 2024
572720c
refactor: Update segmentation shape in test_fugw_alignment
pbarbarant Jun 23, 2024
43a3090
FIx trailng whitespace
pbarbarant Jun 24, 2024
a903290
Refactor the fugw api to fit sklearn's standards
pbarbarant Jun 27, 2024
1ce9aaf
Adapt the tests
pbarbarant Jun 27, 2024
9a7e79c
Fix whitespaces
pbarbarant Jun 27, 2024
7f670b5
Fix typo
pbarbarant Jun 27, 2024
83e62da
fix: Remove NaNs when normalizing
pbarbarant Jul 1, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 271 additions & 2 deletions fmralign/alignment_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings

import numpy as np
import torch
import scipy
from joblib import Parallel, delayed
from scipy import linalg
Expand All @@ -18,6 +19,9 @@
from fmralign.hyperalignment.linalg import safe_svd, svd_pca
from fmralign.hyperalignment.piecewise_alignment import PiecewiseAlignment

from fugw.mappings import FUGW, FUGWSparse
from fugw.scripts import coarse_to_fine, lmds


def scaled_procrustes(X, Y, scaling=False, primal=None):
"""
Expand Down Expand Up @@ -92,7 +96,9 @@ def optimal_permutation(X, Y):
dist = pairwise_distances(X.T, Y.T)
u = linear_sum_assignment(dist)
u = np.array(list(zip(*u)))
permutation = scipy.sparse.csr_matrix((np.ones(X.shape[1]), (u[:, 0], u[:, 1]))).T
permutation = scipy.sparse.csr_matrix(
(np.ones(X.shape[1]), (u[:, 0], u[:, 1]))
).T
return permutation


Expand Down Expand Up @@ -583,7 +589,9 @@ def _tuning_estimator(shared_response, target):
return np.linalg.inv(shared_response).dot(target)

@staticmethod
def _stimulus_estimator(full_signal, n_subjects, latent_dim=None, scaling=True):
def _stimulus_estimator(
full_signal, n_subjects, latent_dim=None, scaling=True
):
"""
Estimates the stimulus matrix for the Individualized Neural Tuning model.

Expand Down Expand Up @@ -731,3 +739,264 @@ def transform(self, X, verbose=False):
)

return np.array(reconstructed_signal, dtype=np.float32)


class FugwAlignment:
"""Wrapper for FUGW alignment"""

def __init__(
self,
segmentation,
alpha_coarse=0.5,
alpha_fine=0.5,
rho_coarse=1.0,
rho_fine=1.0,
eps_coarse=1.0,
eps_fine=1.0,
anisotropy=(3, 3, 3),
reg_mode="independent",
divergence="kl",
method="coarse-to-fine",
n_landmarks=1000,
n_samples=100,
radius=5,
id_reg=0.0,
device="auto",
verbose=False,
**kwargs,
) -> None:
"""Initialize FUGW alignment

Parameters
----------
segmentation : ndarray,
Segmentation of the mask
alpha_coarse : float, optional, by default 0.5.
rho_coarse : float, optional, by default 1.
eps_coarse : float, optional, by default 1.
alpha_fine : float, optional, by default 0.5.
rho_fine : float, optional, by default 1.
eps_fine : float, optional, by default 1e-6.
anisotropy : tuple, optional.
Anisotropy of the fmri mask, by default (3, 3, 3)
reg_mode : str, optional
Regularization mode, by default "independent"
divergence : str, optional
Divergence used in the FUGW alignment, by default "kl".
method : str, optional
Method used to compute FUGW alignments, by default "coarse-to-fine".
n_landmarks : int, optional
Number of landmarks used in the embedding, by default 1000.
n_samples : int, optional
Number of samples points passed to
sklearn.cluster.AgglomerativeClustering, by default 100.
radius : int, optional
Radius around the sampled points in mm, by default 5.
id_reg: float, in the [0, 1] interval, defaults to 0
If source/target share the same geometry,
interpolate the transport plan with the identity
using the provided coefficient.
A value of 1 (resp. 0) will rely solely on the identity
(resp. the transport plan).
device : torch.device, optional, by default "auto"
Device on which to perform the computation.
verbose : bool, optional, by default True
**kwargs : dict
Additional parameters passed to the FUGW mapping.fit method.
"""
self.segmentation = segmentation
self.alpha_coarse = alpha_coarse
self.rho_coarse = rho_coarse
self.eps_coarse = eps_coarse
self.alpha_fine = alpha_fine
self.rho_fine = rho_fine
self.eps_fine = eps_fine
self.anisotropy = anisotropy
self.reg_mode = reg_mode
self.divergence = divergence
self.method = method
self.n_landmarks = n_landmarks
self.n_samples = n_samples
self.radius = radius
self.id_reg = id_reg
self.verbose = verbose
self.kwargs = kwargs

self.device = self._get_device(device)
if self.verbose:
print("Computing geometry embedding...")
(
self.geometry_embedding,
self.geometry_embedding_normalized,
self.max_distance,
) = self._prepare_geometry_embedding(
self.segmentation,
self.n_landmarks,
self.anisotropy,
self.verbose,
)
if self.verbose:
print("Geometry embedding computed")

def _get_device(self, device):
"""Set the device on which to perform the computation"""
if device == "auto":
device = torch.device(
"cuda:0" if torch.cuda.is_available() else "cpu"
)
return device

def _normalize(self, X):
"""Normalize the input data"""
return np.nan_to_num((X / np.linalg.norm(X, axis=1).reshape(-1, 1)).T)

def _prepare_geometry_embedding(
self, segmentation, n_landmarks, anisotropy, verbose
):
"""Compute the normalized geometry embedding"""
geometry_embedding = lmds.compute_lmds_volume(
segmentation,
k=12,
n_landmarks=n_landmarks,
anisotropy=anisotropy,
verbose=verbose,
).nan_to_num()

(
geometry_embedding_normalized,
max_distance,
) = coarse_to_fine.random_normalizing(geometry_embedding)

return (
geometry_embedding,
geometry_embedding_normalized,
max_distance,
)

def _sample_geometry(self, segmentation, geometry_embedding, n_samples):
"""Sample the geometry of the mask"""
return coarse_to_fine.sample_volume_uniformly(
segmentation,
embeddings=geometry_embedding,
n_samples=n_samples,
)

def fit(
self,
X,
Y,
):
"""Fit FUGW alignment

Parameters
----------
X : ndarray of shape (n_samples, n_features)
Source features
Y : ndarray of shape (n_samples, n_features)
Target features

Returns
-------
self : FugwAlignment
Fitted FUGW alignment
"""
source_features_normalized = self._normalize(X.T)
target_features_normalized = self._normalize(Y.T)
if self.verbose:
print("Features normalized")

if self.method == "dense":
mapping = FUGW(
alpha=self.alpha_coarse,
rho=self.rho_coarse,
eps=self.eps_coarse,
reg_mode=self.reg_mode,
divergence=self.divergence,
)

mapping.fit(
source_features=source_features_normalized,
target_features=target_features_normalized,
source_geometry=self.geometry_embedding_normalized
@ self.geometry_embedding_normalized.T,
target_geometry=self.geometry_embedding_normalized
@ self.geometry_embedding_normalized.T,
verbose=self.verbose,
**self.kwargs,
)

self.mapping = mapping

elif self.method == "coarse-to-fine":
# Subsample vertices as uniformly as possible on the surface
sampled_geometry = self._sample_geometry(
self.segmentation, self.geometry_embedding, self.n_samples
)

if self.verbose:
print("Samples computed")

coarse_mapping = FUGW(
alpha=self.alpha_coarse,
rho=self.rho_coarse,
eps=self.eps_coarse,
reg_mode=self.reg_mode,
divergence=self.divergence,
)

fine_mapping = FUGWSparse(
alpha=self.alpha_fine,
rho=self.rho_fine,
eps=self.eps_fine,
reg_mode=self.reg_mode,
divergence=self.divergence,
)

coarse_to_fine.fit(
source_features=source_features_normalized,
target_features=target_features_normalized,
source_geometry_embeddings=self.geometry_embedding_normalized,
target_geometry_embeddings=self.geometry_embedding_normalized,
source_sample=sampled_geometry,
target_sample=sampled_geometry,
coarse_mapping=coarse_mapping,
source_selection_radius=(self.radius / self.max_distance),
target_selection_radius=(self.radius / self.max_distance),
fine_mapping=fine_mapping,
device=self.device,
verbose=self.verbose,
**self.kwargs,
)

self.mapping = fine_mapping

return self

def transform(
self,
X,
):
"""Project features using the fitted FUGW alignment

Parameters
----------
X : ndarray of shape (n_samples, n_features)
Source features

Returns
-------
ndarray
Projected features
"""

if self.mapping is None:
raise ValueError(
"FUGW alignment must be fitted before transforming data"
)

# If id_reg is True, interpolate the resulting
# mapping with the identity matrix
transformed_features = self.mapping.transform(
X, id_reg=self.id_reg, device=self.device
)
return transformed_features
33 changes: 28 additions & 5 deletions fmralign/tests/test_alignment_methods.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal
from scipy.sparse import csc_matrix
from scipy.linalg import orthogonal_procrustes
Expand All @@ -10,6 +11,7 @@
Hungarian,
Identity,
OptimalTransportAlignment,
FugwAlignment,
POTAlignment,
RidgeAlignment,
ScaledOrthogonalAlignment,
Expand Down Expand Up @@ -79,14 +81,18 @@ def test_scaled_procrustes_on_simple_exact_cases():
assert_array_almost_equal(R_test.T, R)

"""Scaled Matrix"""
X = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 3.0, 4.0, 6.0], [7.0, 8.0, -5.0, -2.0]])
X = np.array(
[[1.0, 2.0, 3.0, 4.0], [5.0, 3.0, 4.0, 6.0], [7.0, 8.0, -5.0, -2.0]]
)

X = X - X.mean(axis=1, keepdims=True)

Y = 2 * X
Y = Y - Y.mean(axis=1, keepdims=True)

assert_array_almost_equal(scaled_procrustes(X.T, Y.T, scaling=True)[0], np.eye(3))
assert_array_almost_equal(
scaled_procrustes(X.T, Y.T, scaling=True)[0], np.eye(3)
)
assert_array_almost_equal(scaled_procrustes(X.T, Y.T, scaling=True)[1], 2)

"""3D Rotation"""
Expand All @@ -106,7 +112,8 @@ def test_scaled_procrustes_on_simple_exact_cases():
R.dot(np.array([0.0, 1.0, 0.0])), np.array([0.0, np.cos(1), np.sin(1)])
)
assert_array_almost_equal(
R.dot(np.array([0.0, 0.0, 1.0])), np.array([0.0, -np.sin(1), np.cos(1)])
R.dot(np.array([0.0, 0.0, 1.0])),
np.array([0.0, -np.sin(1), np.cos(1)]),
)
assert_array_almost_equal(R, R_test.T)

Expand Down Expand Up @@ -189,7 +196,21 @@ def test_all_classes_R_and_pred_shape_and_better_than_identity():
assert algo_score >= identity_baseline_score


# %%
@pytest.mark.parametrize("method", ["dense", "coarse-to-fine"])
def test_fugw_alignment(method):
# Create a fake segmentation
segmentation = np.ones((10, 10, 10))
n_features = 3
n_samples = int(segmentation.sum())
X = np.random.randn(n_samples, n_features).T
Y = np.random.randn(n_samples, n_features).T

fugw_alignment = FugwAlignment(segmentation, method=method)
fugw_alignment.fit(X, Y)
assert fugw_alignment.transform(X).shape == X.shape
assert fugw_alignment.transform(X).shape == Y.shape


def test_ott_backend():
n_samples, n_features = 100, 20
epsilon = 0.1
Expand All @@ -198,7 +219,9 @@ def test_ott_backend():
algo = OptimalTransportAlignment(
reg=epsilon, metric="euclidean", tol=1e-5, max_iter=10000
)
old_implem = POTAlignment(reg=epsilon, metric="euclidean", tol=1e-5, max_iter=10000)
old_implem = POTAlignment(
reg=epsilon, metric="euclidean", tol=1e-5, max_iter=10000
)
algo.fit(X, Y)
old_implem.fit(X, Y)
assert_array_almost_equal(algo.R, old_implem.R, decimal=3)
2 changes: 1 addition & 1 deletion fmralign/tests/test_template_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,5 @@ def test_template_closer_to_target():
)
assert template_mean_distance >= mean_distance_1
assert (
template_mean_distance >= mean_distance_2 - 1.0e-3
template_mean_distance >= mean_distance_2 - 1.0e-2
) # for robustness
Loading
Loading