diff --git a/turftopic/base.py b/turftopic/base.py index 7d9cdb8..e36daf2 100644 --- a/turftopic/base.py +++ b/turftopic/base.py @@ -6,6 +6,7 @@ from rich.table import Table from sentence_transformers import SentenceTransformer from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.exceptions import NotFittedError from turftopic.data import TopicData from turftopic.encoders import ExternalEncoder @@ -291,17 +292,21 @@ def prepare_topic_data( if embeddings is None: embeddings = self.encode_documents(corpus) try: - document_topic_matrix = self.transform(corpus, embeddings=embeddings) + document_topic_matrix = self.transform( + corpus, embeddings=embeddings + ) except (AttributeError, NotFittedError): - document_topic_matrix = self.fit_transform(corpus, embeddings=embeddings) - dtm = self.vectorizer.transform(corpus) # type: ignore + document_topic_matrix = self.fit_transform( + corpus, embeddings=embeddings + ) + dtm = self.vectorizer.transform(corpus) # type: ignore res: TopicData = { "corpus": corpus, "document_term_matrix": dtm, "vocab": self.get_vocab(), "document_topic_matrix": document_topic_matrix, "document_representation": embeddings, - "topic_term_matrix": self.components_, # type: ignore + "topic_term_matrix": self.components_, # type: ignore "transform": getattr(self, "transform", None), } return res