From 652a933a1b91f9c0dab6295d2c8a2c0de53fad1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1rton=20Kardos?= Date: Thu, 1 Feb 2024 12:12:41 +0100 Subject: [PATCH] Fixed issues with default vectorizer --- turftopic/models/decomp.py | 2 +- turftopic/models/gmm.py | 7 +++---- turftopic/models/keynmf.py | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/turftopic/models/decomp.py b/turftopic/models/decomp.py index e5a3310..ffca417 100644 --- a/turftopic/models/decomp.py +++ b/turftopic/models/decomp.py @@ -46,7 +46,7 @@ def __init__( else: self.encoder_ = encoder if vectorizer is None: - self.vectorizer = default_vectorizer + self.vectorizer = default_vectorizer() else: self.vectorizer = vectorizer self.objective = objective diff --git a/turftopic/models/gmm.py b/turftopic/models/gmm.py index af759df..d777d2a 100644 --- a/turftopic/models/gmm.py +++ b/turftopic/models/gmm.py @@ -47,20 +47,19 @@ def __init__( Encoder, str ] = "sentence-transformers/all-MiniLM-L6-v2", vectorizer: Optional[CountVectorizer] = None, - weight_prior: Literal[ - "dirichlet", "dirichlet_process", None - ] = "dirichlet", + weight_prior: Literal["dirichlet", "dirichlet_process", None] = None, gamma: Optional[float] = None, ): self.n_components = n_components self.encoder = encoder self.weight_prior = weight_prior + self.gamma = gamma if isinstance(encoder, str): self.encoder_ = SentenceTransformer(encoder) else: self.encoder_ = encoder if vectorizer is None: - self.vectorizer = default_vectorizer + self.vectorizer = default_vectorizer() else: self.vectorizer = vectorizer if self.weight_prior is not None: diff --git a/turftopic/models/keynmf.py b/turftopic/models/keynmf.py index 9595875..9c906f2 100644 --- a/turftopic/models/keynmf.py +++ b/turftopic/models/keynmf.py @@ -82,7 +82,7 @@ def __init__( else: self.encoder_ = encoder if vectorizer is None: - self.vectorizer = default_vectorizer + self.vectorizer = default_vectorizer() else: self.vectorizer = vectorizer self.dict_vectorizer_ = DictVectorizer()