Skip to content

Commit

Permalink
add given biases
Browse files Browse the repository at this point in the history
CorrNMF and MultimodalCorrNMF now allow fixing the signature or sample biases during model training.
Also fixed the COSMIC catalog format from a previous commit.
  • Loading branch information
BeGeiger committed Oct 31, 2023
1 parent ec6bca1 commit b11550d
Show file tree
Hide file tree
Showing 7 changed files with 326 additions and 224 deletions.
97 changes: 97 additions & 0 deletions data/COSMIC_v3.3.1_SBS_GRCh38.csv

Large diffs are not rendered by default.

97 changes: 0 additions & 97 deletions data/COSMIC_v3.3.1_SBS_GRCh38.txt

This file was deleted.

117 changes: 82 additions & 35 deletions src/salamander/nmf_framework/corrnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,27 +338,56 @@ def _update_u(self, index, aux_col, outer_prods_L):
shape: (n_signatures, dim_embeddings, dim_embeddings)
"""

def _check_given_signature_embeddings(self, given_signature_embeddings: np.ndarray):
type_checker("signature embeddings", given_signature_embeddings, np.ndarray)
shape_checker(
"given_signature_embeddings",
given_signature_embeddings,
(self.dim_embeddings, self.n_signatures),
)
def _check_given_biases(self, given_biases, expected_n_biases, name):
    """Raise if the given biases are not a numpy array of the expected length.

    Parameters
    ----------
    given_biases : np.ndarray
        User-provided bias vector to validate.

    expected_n_biases : int
        The required length of the bias vector.

    name : str
        Argument name used in the error messages of the checkers.
    """
    type_checker(name, given_biases, np.ndarray)
    expected_shape = (expected_n_biases,)
    shape_checker(name, given_biases, expected_shape)

def _check_given_sample_embeddings(self, given_sample_embeddings: np.ndarray):
type_checker("sample embeddings", given_sample_embeddings, np.ndarray)
def _check_given_embeddings(self, given_embeddings, expected_n_embeddings, name):
    """Raise if the given embeddings are not a numpy array of the expected shape.

    Parameters
    ----------
    given_embeddings : np.ndarray
        User-provided embedding matrix to validate; the expected shape is
        (dim_embeddings, expected_n_embeddings).

    expected_n_embeddings : int
        The required number of embedding columns.

    name : str
        Argument name used in the error messages of the checkers.
    """
    type_checker(name, given_embeddings, np.ndarray)
    expected_shape = (self.dim_embeddings, expected_n_embeddings)
    shape_checker(name, given_embeddings, expected_shape)

def _check_given_parameters(
self,
given_signatures,
given_signature_biases,
given_signature_embeddings,
given_sample_biases,
given_sample_embeddings,
):
if given_signatures is not None:
self._check_given_signatures(given_signatures)

if given_signature_biases is not None:
self._check_given_biases(
given_signature_biases, self.n_signatures, "given_signature_biases"
)

if given_signature_embeddings is not None:
self._check_given_embeddings(
given_signature_embeddings,
self.n_signatures,
"given_signature_embeddings",
)

if given_sample_biases is not None:
self._check_given_biases(
given_sample_biases, self.n_samples, "given_sample_biases"
)

if given_sample_embeddings is not None:
self._check_given_embeddings(
given_sample_embeddings, self.n_samples, "given_sample_embeddings"
)

def _initialize(
self,
given_signatures=None,
given_signature_biases=None,
given_signature_embeddings=None,
given_sample_embeddings=True,
given_sample_biases=None,
given_sample_embeddings=None,
init_kwargs=None,
):
"""
Expand All @@ -374,9 +403,17 @@ def _initialize(
algorithm instance, and the mutation type names have to match
the mutation types of the count data.
given_signature_biases : np.ndarray, default=None
Known signature biases of shape (n_signatures,) that will be fixed
during model fitting.
given_signature_embeddings : np.ndarray, default=None
A priori known signature embeddings of shape (dim_embeddings, n_signatures).
given_sample_biases : np.ndarray, default=None
Known sample biases of shape (n_samples,) that will be fixed
during model fitting.
given_sample_embeddings : np.ndarray, default=None
A priori known sample embeddings of shape (dim_embeddings, n_samples).
Expand All @@ -385,18 +422,19 @@ def _initialize(
This includes, for example, a possible 'seed' keyword argument
for all stochastic methods.
"""
self._check_given_parameters(
given_signatures,
given_signature_biases,
given_signature_embeddings,
given_sample_biases,
given_sample_embeddings,
)

if given_signatures is not None:
self._check_given_signatures(given_signatures)
self.n_given_signatures = len(given_signatures.columns)
else:
self.n_given_signatures = 0

if given_signature_embeddings is not None:
self._check_given_signature_embeddings(given_signature_embeddings)

if given_sample_embeddings is not None:
self._check_given_sample_embeddings(given_sample_embeddings)

init_kwargs = {} if init_kwargs is None else init_kwargs.copy()

if self.init_method == "custom":
Expand Down Expand Up @@ -426,32 +464,41 @@ def _initialize(
self.signature_names = np.concatenate(
[given_signatures_names, new_signatures_names]
)

else:
self.signature_names = np.array(
[f"Sig{k+1}" for k in range(self.n_signatures)], dtype="<U20"
)

self.W /= np.sum(self.W, axis=0)
self.W = self.W.clip(EPSILON)
self.alpha = np.zeros(self.n_samples)
self.beta = np.zeros(self.n_signatures)
self.sigma_sq = 1.0
self.L = np.random.multivariate_normal(
np.zeros(self.dim_embeddings),
np.identity(self.dim_embeddings),
size=self.n_signatures,
).T
self.U = np.random.multivariate_normal(
np.zeros(self.dim_embeddings),
np.identity(self.dim_embeddings),
size=self.n_samples,
).T

if given_signature_embeddings is not None:
if given_signature_biases is None:
self.beta = np.zeros(self.n_signatures)
else:
self.beta = given_signature_biases

if given_signature_embeddings is None:
self.L = np.random.multivariate_normal(
np.zeros(self.dim_embeddings),
np.identity(self.dim_embeddings),
size=self.n_signatures,
).T
else:
self.L = given_signature_embeddings

if given_sample_embeddings is not None:
if given_sample_biases is None:
self.alpha = np.zeros(self.n_samples)
else:
self.alpha = given_sample_biases

if given_sample_embeddings is None:
self.U = np.random.multivariate_normal(
np.zeros(self.dim_embeddings),
np.identity(self.dim_embeddings),
size=self.n_samples,
).T
else:
self.U = given_sample_embeddings

@property
Expand Down
40 changes: 29 additions & 11 deletions src/salamander/nmf_framework/corrnmf_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,15 @@ class CorrNMFDet(CorrNMF):
for given signatures and mutation count data
"""

def _update_alpha(self):
self.alpha = _utils_corrnmf.update_alpha(self.X, self.beta, self.L, self.U)

def _update_beta(self, p):
self.beta = _utils_corrnmf.update_beta(self.X, p, self.alpha, self.L, self.U)
def _update_alpha(self, given_sample_biases=None):
if given_sample_biases is None:
self.alpha = _utils_corrnmf.update_alpha(self.X, self.beta, self.L, self.U)

def _update_beta(self, p, given_signature_biases=None):
if given_signature_biases is None:
self.beta = _utils_corrnmf.update_beta(
self.X, p, self.alpha, self.L, self.U
)

def _update_sigma_sq(self):
embeddings = np.concatenate([self.L, self.U], axis=1)
Expand Down Expand Up @@ -170,7 +174,9 @@ def _update_U(self, aux):
for d, aux_col in enumerate(aux.T):
self._update_u(d, aux_col, outer_prods_L)

def _update_LU(self, p, given_signature_embeddings, given_sample_embeddings):
def _update_LU(
self, p, given_signature_embeddings=None, given_sample_embeddings=None
):
aux = np.einsum("vd,vkd->kd", self.X, p)

if given_signature_embeddings is None:
Expand All @@ -183,7 +189,9 @@ def fit(
self,
data: pd.DataFrame,
given_signatures=None,
given_signature_biases=None,
given_signature_embeddings=None,
given_sample_biases=None,
given_sample_embeddings=None,
init_kwargs=None,
history=False,
Expand All @@ -198,13 +206,21 @@ def fit(
The mutation count data
given_signatures: pd.DataFrame, default=None
Known signatures which will be fixed during model fitting.
Known signatures that will be fixed during model fitting.
given_signature_biases : np.ndarray, default=None
Known signature biases of shape (n_signatures,) that will be fixed
during model fitting.
given_signature_embeddings: np.ndarray, default=None
Known signature embeddings which will be fixed during model fitting.
Known signature embeddings that will be fixed during model fitting.
given_sample_biases : np.ndarray, default=None
Known sample biases of shape (n_samples,) that will be fixed
during model fitting.
given_sample_embeddings: np.ndarray, default=None
Known sample embeddings which will be fixed during model fitting.
Known sample embeddings that will be fixed during model fitting.
init_kwargs: dict
Any further keywords arguments to be passed to the initialization method.
Expand All @@ -226,7 +242,9 @@ def fit(
self._setup_data_parameters(data)
self._initialize(
given_signatures=given_signatures,
given_signature_biases=given_signature_biases,
given_signature_embeddings=given_signature_embeddings,
given_sample_biases=given_sample_biases,
given_sample_embeddings=given_sample_embeddings,
init_kwargs=init_kwargs,
)
Expand All @@ -240,9 +258,9 @@ def fit(
if verbose and n_iteration % 100 == 0:
print(f"iteration: {n_iteration}; objective: {of_values[-1]:.2f}")

self._update_alpha()
self._update_alpha(given_sample_biases)
p = self._update_p()
self._update_beta(p)
self._update_beta(p, given_signature_biases)
self._update_LU(p, given_signature_embeddings, given_sample_embeddings)
self._update_sigma_sq()

Expand Down
41 changes: 32 additions & 9 deletions src/salamander/nmf_framework/multimodal_corrnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,13 @@ def _n_parameters(self) -> int:
def bic(self) -> float:
return self._n_parameters * np.log(self.n_samples) - 2 * self.loglikelihood()

def _update_alphas(self):
for model in self.models:
model._update_alpha()
def _update_alphas(self, given_sample_biases):
for model, given_sam_biases in zip(self.models, given_sample_biases):
model._update_alpha(given_sam_biases)

def _update_betas(self, ps):
for model, p in zip(self.models, ps):
model._update_beta(p)
def _update_betas(self, ps, given_signature_biases):
for model, p, given_sig_biases in zip(self.models, ps, given_signature_biases):
model._update_beta(p, given_sig_biases)

def _update_sigma_sq(self):
Ls = np.concatenate([model.L for model in self.models], axis=1)
Expand Down Expand Up @@ -345,7 +345,9 @@ def _setup_data_parameters(self, data: list):
def _initialize(
self,
given_signatures=None,
given_signature_biases=None,
given_signature_embeddings=None,
given_sample_biases=None,
given_sample_embeddings=None,
init_kwargs=None,
):
Expand All @@ -358,14 +360,25 @@ def _initialize(
else:
U = given_sample_embeddings

for model, modality_name, given_sigs, given_sig_embs in zip(
for (
model,
modality_name,
given_sigs,
given_sig_biases,
given_sig_embs,
given_sam_biases,
) in zip(
self.models,
self.modality_names,
given_signatures,
given_signature_biases,
given_signature_embeddings,
given_sample_biases,
):
model._initialize(
given_signatures=given_sigs,
given_signature_biases=given_sig_biases,
given_sample_biases=given_sam_biases,
given_signature_embeddings=given_sig_embs,
given_sample_embeddings=U,
init_kwargs=init_kwargs,
Expand All @@ -378,7 +391,9 @@ def fit(
self,
data: list,
given_signatures=None,
given_signature_biases=None,
given_signature_embeddings=None,
given_sample_biases=None,
given_sample_embeddings=None,
init_kwargs=None,
history=False,
Expand All @@ -387,13 +402,21 @@ def fit(
if given_signatures is None:
given_signatures = [None for _ in range(self.n_modalities)]

if given_signature_biases is None:
given_signature_biases = [None for _ in range(self.n_modalities)]

if given_signature_embeddings is None:
given_signature_embeddings = [None for _ in range(self.n_modalities)]

if given_sample_biases is None:
given_sample_biases = [None for _ in range(self.n_modalities)]

self._setup_data_parameters(data)
self._initialize(
given_signatures=given_signatures,
given_signature_biases=given_signature_biases,
given_signature_embeddings=given_signature_embeddings,
given_sample_biases=given_sample_biases,
given_sample_embeddings=given_sample_embeddings,
init_kwargs=init_kwargs,
)
Expand All @@ -407,9 +430,9 @@ def fit(
if verbose and n_iteration % 100 == 0:
print(f"iteration: {n_iteration}; objective: {of_values[-1]:.2f}")

self._update_alphas()
self._update_alphas(given_sample_biases)
ps = self._update_ps()
self._update_betas(ps)
self._update_betas(ps, given_signature_biases)
self._update_LsU(ps, given_signature_embeddings, given_sample_embeddings)
self._update_sigma_sq()
self._update_Ws()
Expand Down
Loading

0 comments on commit b11550d

Please sign in to comment.