This repository has been archived by the owner on Apr 2, 2020. It is now read-only.

First level residuals #410

Merged
merged 39 commits into from
Jan 13, 2020
Changes from 22 commits
2b7df92
Allow for storing residuals in First Level Models
Gilles86 Dec 7, 2018
0663e7a
import missing modules and update calls to functions
jdkent Nov 20, 2019
1b58ebb
fix flake8 errors
jdkent Nov 20, 2019
4909feb
remove @set_attr_on_read
jdkent Nov 20, 2019
acac09b
modify tests to reflect updated code base
jdkent Nov 20, 2019
4e7c897
fix typo and simplify loop
jdkent Nov 20, 2019
0b15ec6
respond to review comments:
jdkent Dec 2, 2019
5f260c9
change rsq to r_square
jdkent Dec 2, 2019
adb504c
change rsq to r_square in tests
jdkent Dec 2, 2019
9b434bc
fix function calls
jdkent Dec 2, 2019
d416d0b
Example of how to use for
Gilles86 Dec 18, 2019
65010cf
Made ValueError for storing model attributes more verbose.
Gilles86 Dec 18, 2019
0b5a170
Also include R-squared
Gilles86 Dec 18, 2019
740f3bb
Merge pull request #1 from Gilles86/first_level_residuals
jdkent Dec 18, 2019
3d0ed81
Merge branch 'master' of https://github.com/nistats/nistats into firs…
jdkent Dec 18, 2019
5b61be4
fix heading underlines in example
jdkent Dec 18, 2019
bf319c1
Merge branch 'master' of https://github.com/nistats/nistats into firs…
jdkent Dec 19, 2019
08acd3a
Merge branch 'first_level_residuals' of https://github.com/jdkent/nis…
jdkent Dec 19, 2019
fcf8320
fix grammar
jdkent Dec 19, 2019
a864919
fix code formatting and do not standardize
jdkent Dec 19, 2019
953d15b
Merge branch 'first_level_residuals' of https://github.com/jdkent/nis…
jdkent Dec 19, 2019
806f11d
change parameter timeseries to result_as_time_series
jdkent Dec 19, 2019
51002ea
Merge branch 'master' of https://github.com/nistats/nistats into firs…
kchawla-pi Jan 8, 2020
20e2201
Merge pull request #2 from kchawla-pi/first_level_residuals
jdkent Jan 8, 2020
a47586f
attempt to address @bthirion comments
jdkent Jan 8, 2020
cdf2161
split imports statements
jdkent Jan 8, 2020
b8c483d
always return list get_voxelwise_model_attribute_
jdkent Jan 8, 2020
72aec1d
change docstrings for output to always be a list
jdkent Jan 9, 2020
a245635
modify tests to treat output as list
jdkent Jan 9, 2020
afef0a9
make _get_voxelwise_model_attribute private and improve documentation
jdkent Jan 9, 2020
1cbd072
fix formatting of function call
jdkent Jan 9, 2020
1a8e0c5
add empty line back in
jdkent Jan 9, 2020
e801f8e
revert regression.py to master
jdkent Jan 9, 2020
bf67066
make result_as_time_series mandatory
jdkent Jan 9, 2020
b55c13b
add newlines to docs
jdkent Jan 9, 2020
999caf9
add newline to end of file
jdkent Jan 9, 2020
4b5d485
fix missing newline
jdkent Jan 9, 2020
78995d0
add James Kent to .mailmap
jdkent Jan 11, 2020
44b0e85
add entry for the new attributes to FirstLevelModel
jdkent Jan 11, 2020
120 changes: 120 additions & 0 deletions examples/02_first_level_models/plot_predictions_residuals.py
@@ -0,0 +1,120 @@
"""
Predicted time series and residuals
===================================

Here we fit a First Level GLM with the `minimize_memory`-argument set to `False`.
By doing so, the `FirstLevelModel`-object stores the residuals, which we can then inspect.
Also, the predicted time series can be extracted, which is useful to assess the quality of the model fit.
"""


#########################################################################
# Import modules
# --------------
from nistats.datasets import fetch_spm_auditory
from nilearn import input_data, image
import matplotlib.pyplot as plt
from nilearn import plotting, masking
from nistats.reporting import get_clusters_table
from nistats.first_level_model import FirstLevelModel
import pandas as pd

# load fMRI data
subject_data = fetch_spm_auditory()
fmri_img = image.concat_imgs(subject_data.func)

# Make an average
mean_img = image.mean_img(fmri_img)
mask = masking.compute_epi_mask(mean_img)

# Clean and smooth data
fmri_img = image.clean_img(fmri_img, standardize=False)
fmri_img = image.smooth_img(fmri_img, 5.)

# load events
events = pd.read_table(subject_data['events'])


#########################################################################
# Fit model
# ---------
# Note that `minimize_memory` is set to `False` so that `FirstLevelModel`
# stores the residuals.
# `signal_scaling` is set to False, so we keep the same scaling as the
# original data in `fmri_img`.

fmri_glm = FirstLevelModel(t_r=7,
                           drift_model='cosine',
                           signal_scaling=False,
                           mask_img=mask,
                           minimize_memory=False)

fmri_glm = fmri_glm.fit(fmri_img, events)


#########################################################################
# Calculate and plot contrast
# ---------------------------
z_map = fmri_glm.compute_contrast('active - rest')

plotting.plot_stat_map(z_map, bg_img=mean_img, threshold=3.1)

#########################################################################
# Extract the largest clusters
# ----------------------------

table = get_clusters_table(z_map, stat_threshold=3.1,
                           cluster_threshold=20).set_index('Cluster ID', drop=True)
table.head()

masker = input_data.NiftiSpheresMasker(table.loc[range(1, 7), ['X', 'Y', 'Z']].values)
Member

Can the corresponding locations be plotted? Otherwise, this looks a bit cryptic.
Also, you probably want to explain what you're doing here.


real_timeseries = masker.fit_transform(fmri_img)
predicted_timeseries = masker.fit_transform(fmri_glm.predicted)


#########################################################################
# Plot predicted and actual time series for 6 most significant clusters
# ---------------------------------------------------------------------
for i in range(1, 7):
    plt.subplot(2, 3, i)
    plt.title('Cluster peak {}\n'.format(table.loc[i, ['X', 'Y', 'Z']].tolist()))
    plt.plot(real_timeseries[:, i-1], c='k', lw=2)
    plt.plot(predicted_timeseries[:, i-1], c='r', ls='--', lw=2)
    plt.xlabel('Time')
    plt.ylabel('Signal intensity')

plt.gcf().set_size_inches(12, 7)
plt.tight_layout()

#########################################################################
# Get residuals
# -------------
resid = masker.fit_transform(fmri_glm.residuals)


#########################################################################
# Plot distribution of residuals
# ------------------------------
# Note that residuals are not really distributed normally.


for i in range(1, 7):
    plt.subplot(2, 3, i)
    plt.title('Cluster peak {}\n'.format(table.loc[i, ['X', 'Y', 'Z']].tolist()))
    plt.hist(resid[:, i-1])
    print('Mean residuals: {}'.format(resid[:, i-1].mean()))

plt.gcf().set_size_inches(12, 7)
plt.tight_layout()


#########################################################################
# Plot R-squared
# --------------
# Because we stored the residuals, we can plot the R-squared: the proportion
# of explained variance of the GLM as a whole. Note that the R-squared is markedly
# lower deep down the brain, where there is more physiological noise and we
# are further away from the receive coils.
Member

I'm not a big fan of R-squared in this context: as it is calculated in-sample, it has no mechanism to prevent overfitting, hence a bigger model always has a higher R-squared.
Besides, this one does not make the distinction between effects of interest (conditions) and effects of no interest (drifts, motion), so the interpretation is not straightforward.
This probably deserves a comment, and possibly a plot of the "effects of interest" F test.
Does this make sense?

Collaborator

@bthirion has this been resolved to your satisfaction?

Contributor Author

I added more text to illustrate the drawbacks of r-squared and added a section for calculating/plotting an f-test.
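
To make the concern about in-sample R-squared concrete, here is a minimal NumPy sketch. It is not part of this PR: the simulated data, the `r_square` helper and the regressor names are all assumptions made for illustration. It fits two nested OLS models and computes R-squared the same way this PR does (variance of the predictions over variance of the data), to show that adding a pure-noise regressor does not lower the value.

```python
import numpy as np

rng = np.random.default_rng(0)
n_scans = 100

# Hypothetical design: intercept + one task regressor
task = np.sin(np.linspace(0, 6 * np.pi, n_scans))
X_small = np.column_stack([np.ones(n_scans), task])
# Same design plus a regressor of pure noise (unrelated to the signal)
X_big = np.column_stack([X_small, rng.standard_normal(n_scans)])

# Simulated voxel time series: task effect plus noise
y = 2.0 * task + rng.standard_normal(n_scans)


def r_square(X, y):
    """In-sample R-squared as var(predicted) / var(observed)."""
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)
    predicted = X @ beta
    return np.var(predicted) / np.var(y)


r2_small = r_square(X_small, y)
r2_big = r_square(X_big, y)
print(r2_small, r2_big)
```

Because the noise regressor can only absorb residual variance, `r2_big` is never smaller than `r2_small`, which is why the map is best read alongside an F test on the effects of interest.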

plotting.plot_stat_map(fmri_glm.r_square,
                       bg_img=mean_img, threshold=.1, display_mode='z', cut_coords=7)
98 changes: 94 additions & 4 deletions nistats/first_level_model.py
@@ -18,6 +18,7 @@
import numpy as np
import pandas as pd
from nibabel import Nifti1Image
from nibabel.onetime import setattr_on_read

from sklearn.base import (BaseEstimator,
clone,
@@ -117,14 +118,14 @@ def run_glm(Y, X, noise_model='ar1', bins=100, n_jobs=1, verbose=0):
acceptable_noise_models = ['ar1', 'ols']
if noise_model not in acceptable_noise_models:
raise ValueError(
"Acceptable noise models are {0}. You provided 'noise_model={1}'".\
format(acceptable_noise_models, noise_model))
"Acceptable noise models are {0}. You provided 'noise_model={1}'".
format(acceptable_noise_models, noise_model))

if Y.shape[0] != X.shape[0]:
raise ValueError(
'The number of rows of Y should match the number of rows of X.'
' You provided X with shape {0} and Y with shape {1}'.\
format(X.shape, Y.shape))
' You provided X with shape {0} and Y with shape {1}'.
format(X.shape, Y.shape))

# Create the model
ols_result = OLSModel(X).fit(Y)
@@ -309,6 +310,7 @@ def __init__(self, t_r=None, slice_time_ref=0., hrf_model='glover',
else:
raise ValueError('signal_scaling must be "False", "0", "1"'
' or "(0, 1)"')

self.noise_model = noise_model
self.verbose = verbose
self.n_jobs = n_jobs
@@ -583,6 +585,94 @@ def compute_contrast(self, contrast_def, stat_type=None,

return outputs if output_type == 'all' else output

    def get_voxelwise_model_attribute_(self, attribute, result_as_time_series=True):
        """Transform RegressionResults instances within a dictionary
        (whose keys represent the autoregressive coefficient under the 'ar1'
        noise model or only 0.0 under 'ols' noise_model and values are the
        RegressionResults instances) into input nifti space.

        Parameters
        ----------
        attribute : str
Member

attribute is defined within a fixed dictionary of values, I guess? This should be provided if the function is meant to be used publicly.

Contributor Author

@jdkent jdkent Jan 9, 2020

I believe this should really be a private method (I've changed it to be a private method), but I added additional documentation and an internal check to be explicit anyway.
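
For readers following this thread, a check of that kind could look roughly like the sketch below. This is an assumption about its shape, not the code that landed in the PR: `ALLOWED_ATTRIBUTES` and `check_attribute` are hypothetical names, and the allowed values simply mirror the three properties added in this diff.

```python
# Hypothetical allow-list; the names mirror the three properties added in this PR.
ALLOWED_ATTRIBUTES = ('resid', 'predicted', 'r_square')


def check_attribute(attribute):
    """Raise a descriptive error if `attribute` is not supported."""
    if attribute not in ALLOWED_ATTRIBUTES:
        raise ValueError(
            "attribute must be one of {}, got {!r}".format(
                ALLOWED_ATTRIBUTES, attribute))
    return attribute


check_attribute('resid')  # passes silently
try:
    check_attribute('rsq')  # the old name is rejected
except ValueError as err:
    message = str(err)
    print(message)
```

Listing the accepted values in the error message covers the documentation request even if the method stays private.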

            an attribute of a RegressionResults instance
        result_as_time_series : bool, optional
            whether the RegressionResult attribute has a value
            per timepoint of the input nifti image.

        Returns
        -------
        output : list or Nifti1Image
            a list of Nifti1Images if FirstLevelModel is fit with
            a list of Nifti1Images, or a single Nifti1Image otherwise.
        """
        if self.minimize_memory:
            raise ValueError('To access voxelwise attributes like R-squared, residuals, '
                             'and predictions, the `FirstLevelModel`-object needs to store '
                             'these attributes. To do so, set `minimize_memory` to `False` '
                             'when initializing the `FirstLevelModel`-object.')

        if self.labels_ is None or self.results_ is None:
            raise ValueError('The model has not been fit yet')

        output = []

        for design_matrix, labels, results in zip(self.design_matrices_, self.labels_, self.results_):

            if result_as_time_series:
                voxelwise_attribute = np.zeros((design_matrix.shape[0], len(labels)))
            else:
                voxelwise_attribute = np.zeros((1, len(labels)))

            for label_ in results:
                label_mask = labels == label_
                voxelwise_attribute[:, label_mask] = getattr(results[label_], attribute)

            output.append(self.masker_.inverse_transform(voxelwise_attribute))

        if len(output) == 1:
            return output[0]
        else:
            return output

    @setattr_on_read
    def residuals(self):
        """Transform voxelwise residuals to the same shape
        as the input Nifti1Image(s)

        Returns
        -------
        output : list or Nifti1Image
            a list of Nifti1Images if FirstLevelModel is fit with
            a list of Nifti1Images, or a single Nifti1Image otherwise.
        """
        return self.get_voxelwise_model_attribute_('resid')

    @setattr_on_read
    def predicted(self):
        """Transform voxelwise predicted values to the same shape
        as the input Nifti1Image(s)

        Returns
        -------
        output : list or Nifti1Image
            a list of Nifti1Images if FirstLevelModel is fit with
            a list of Nifti1Images, or a single Nifti1Image otherwise.
        """
        return self.get_voxelwise_model_attribute_('predicted')

    @setattr_on_read
    def r_square(self):
        """Transform voxelwise r-squared values to the same shape
        as the input Nifti1Image(s)

        Returns
        -------
        output : list or Nifti1Image
            a list of Nifti1Images if FirstLevelModel is fit with
            a list of Nifti1Images, or a single Nifti1Image otherwise.
        """
        return self.get_voxelwise_model_attribute_('r_square', result_as_time_series=False)
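
The inner loop of `get_voxelwise_model_attribute_` may be easier to follow on toy data. The sketch below is illustrative only: the labels, array shapes and fill values are made up, and a plain dictionary of arrays stands in for the dictionary of `RegressionResults` instances. It shows how per-label results are scattered back into one voxelwise array with a boolean mask.

```python
import numpy as np

n_scans = 4
# One label per voxel (under 'ar1', the autoregressive coefficient key of that voxel)
labels = np.array([0.0, 0.1, 0.0, 0.1, 0.1])

# Hypothetical per-label results: each holds an array of shape
# (n_scans, number of voxels carrying that label)
results = {
    0.0: np.full((n_scans, 2), 1.0),
    0.1: np.full((n_scans, 3), 2.0),
}

voxelwise = np.zeros((n_scans, len(labels)))
for label in results:
    label_mask = labels == label  # boolean mask over voxels
    voxelwise[:, label_mask] = results[label]

print(voxelwise[0])
```

Each voxel column ends up filled from the results object that shares its label, which is what allows `masker_.inverse_transform` to map the array back into image space afterwards.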


@replace_parameters({'mask': 'mask_img'}, end_version='next')
def first_level_models_from_bids(
49 changes: 17 additions & 32 deletions nistats/regression.py
@@ -61,7 +61,6 @@ class OLSModel(object):
df_resid : scalar
Degrees of freedom of the residuals. Number of observations less the
rank of the design.

df_model : scalar
Degrees of freedom of the model. The rank of the design.
"""
@@ -276,6 +275,7 @@ def __init__(self, theta, Y, model, wY, wresid, cov=None, dispersion=1.,
dispersion, nuisance)
self.wY = wY
self.wresid = wresid
self.wdesign = model.wdesign

@setattr_on_read
def resid(self):
Collaborator

@bthirion We should deprecate this and replace it with residuals in a separate PR, and others like it.
Or I should ask JD Kent to revert the changes I asked him to make and rename the resid method in this PR back to resid so it stays consistent.
I like the former idea, renaming all uses of resid to residuals.
Does Nilearn have resid or residuals?

Member

Yes, let's do it in a separate PR.
I'm not aware of resid in nilearn. We should go for residuals.

Collaborator

Don't think Nilearn has any resid method.

Contributor Author

so we will keep the method as resid for this pull request, I will not change this
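
For the follow-up PR discussed above, one common way to rename `resid` to `residuals` without breaking callers is a thin deprecated alias. The sketch below is only an illustration of that pattern: the `Results` class is a hypothetical stand-in, not the nistats `RegressionResults`, and nothing in this thread fixes the warning text or category.

```python
import warnings


class Results:
    """Hypothetical stand-in for a results class; not the nistats one."""

    def __init__(self, residuals):
        self._residuals = residuals

    @property
    def residuals(self):
        return self._residuals

    @property
    def resid(self):
        # Old name kept as a thin alias that warns and forwards.
        warnings.warn("'resid' is deprecated, use 'residuals' instead",
                      DeprecationWarning, stacklevel=2)
        return self.residuals


res = Results([0.5, -0.5])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    old_value = res.resid
```

Existing code that reads `resid` keeps working for a release cycle while the warning points users to the new name.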

@@ -310,7 +310,7 @@ def predicted(self):
"""
beta = self.theta
# the LikelihoodModelResults has parameters named 'theta'
X = self.model.design
X = self.wdesign
return np.dot(X, beta)

@setattr_on_read
@@ -319,13 +319,20 @@ def SSE(self):
"""
return (self.wresid ** 2).sum(0)

@setattr_on_read
def r_square(self):
"""Proportion of explained variance.
If not from an OLS model this is "pseudo"-R2.
"""
return np.var(self.predicted, 0) / np.var(self.wY, 0)
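
As a reading aid (not part of the diff), the quantity returned by `r_square` can be written as below, where \hat{y} is `predicted` and y is `wY`, both taken per voxel along the time axis. The second equality is stated under an assumption: it holds for an ordinary least-squares fit whose design contains a constant column, and need not hold exactly for whitened AR(1) data, which is one way to read the docstring's "pseudo"-R2.

```latex
R^2 = \frac{\operatorname{Var}(\hat{y})}{\operatorname{Var}(y)}
    = 1 - \frac{\sum_t (y_t - \hat{y}_t)^2}{\sum_t (y_t - \bar{y})^2}
```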

@setattr_on_read
def MSE(self):
""" Mean square (error) """
return self.SSE / self.df_resid


class SimpleRegressionResults(LikelihoodModelResults):
class SimpleRegressionResults(RegressionResults):
"""This class contains only information of the model fit necessary
for contrast computation.

@@ -347,41 +354,19 @@ def __init__(self, results):
# put this as a parameter of LikelihoodModel
self.df_resid = self.df_total - self.df_model

self.wdesign = results.model.wdesign

def logL(self, Y):
"""
The maximized log-likelihood
"""
raise ValueError('can not use this method for simple results')
raise ValueError('minimize_memory should be set to False to make residuals or predictions.')

def resid(self, Y):
def resid(self):
"""
Residuals from the fit.
"""
return Y - self.predicted

def norm_resid(self, Y):
"""
Residuals, normalized to have unit length.

Notes
-----
Is this supposed to return "stanardized residuals,"
residuals standardized
to have mean zero and approximately unit variance?
raise ValueError('minimize_memory should be set to False to make residuals or predictions.')
Collaborator

@kchawla-pi kchawla-pi Jan 9, 2020

@bthirion @jdkent @Gilles86
I'm still confused about the suitability of this and other such methods. It feels like logL(), resid() and norm_resid() are, in effect, traps, waiting for unsuspecting users to call them and watch their programs fail.

If these methods do not have anything to say, then we should simply remove all of them, or convert this into a message rather than an error.

Member

Agreed.

Contributor Author

I believe the motivation for this was to provide a hint to the user on how to get residuals from their model instead of an attribute error, but I see this check is handled here. I'm fine with removing those methods.
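
The alternative agreed on in this thread, dropping the raising stubs and relying on a single check in the owning model, can be sketched as follows. All class names here are hypothetical stand-ins rather than the nistats classes; only the placement of the guard is the point.

```python
class SlimResults:
    """Hypothetical memory-light results: no residuals stored."""
    theta = [1.0]


class FullResults(SlimResults):
    """Hypothetical full results that also keep residuals."""
    resid = [0.1, -0.1]


class Model:
    """Hypothetical owner object; only the guard logic matters here."""

    def __init__(self, minimize_memory=True):
        self.minimize_memory = minimize_memory
        self.results_ = SlimResults() if minimize_memory else FullResults()

    @property
    def residuals(self):
        # Single gate: explain how to get residuals instead of letting
        # an AttributeError surface from the slim results object.
        if self.minimize_memory:
            raise ValueError('Set minimize_memory=False to store residuals.')
        return self.results_.resid


full = Model(minimize_memory=False).residuals
try:
    Model(minimize_memory=True).residuals
except ValueError as err:
    hint = str(err)
```

With the hint raised in one place, the slim results class no longer needs methods whose only job is to fail.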


d_i = e_i / sqrt(MS_E)

Where MS_E = SSE / (n - k)

See: Montgomery and Peck 3.2.1 p. 68
Davidson and MacKinnon 15.2 p 662
"""
return self.resid(Y) * positive_reciprocal(np.sqrt(self.dispersion))

def predicted(self):
""" Return linear predictor values from a design matrix.
"""
beta = self.theta
# the LikelihoodModelResults has parameters named 'theta'
X = self.model.design
return np.dot(X, beta)
def norm_resid(self):
raise ValueError('minimize_memory should be set to False to make residuals or predictions.')