Skip to content

Commit

Permalink
Fixed import error
Browse files Browse the repository at this point in the history
  • Loading branch information
x-tabdeveloping committed Feb 6, 2024
1 parent 8c1d4d4 commit a0768d8
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions turftopic/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from rich.table import Table
from sentence_transformers import SentenceTransformer
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.exceptions import NotFittedError

from turftopic.data import TopicData
from turftopic.encoders import ExternalEncoder
Expand Down Expand Up @@ -291,17 +292,21 @@ def prepare_topic_data(
if embeddings is None:
embeddings = self.encode_documents(corpus)
try:
document_topic_matrix = self.transform(corpus, embeddings=embeddings)
document_topic_matrix = self.transform(
corpus, embeddings=embeddings
)
except (AttributeError, NotFittedError):
document_topic_matrix = self.fit_transform(corpus, embeddings=embeddings)
dtm = self.vectorizer.transform(corpus) # type: ignore
document_topic_matrix = self.fit_transform(
corpus, embeddings=embeddings
)
dtm = self.vectorizer.transform(corpus) # type: ignore
res: TopicData = {
"corpus": corpus,
"document_term_matrix": dtm,
"vocab": self.get_vocab(),
"document_topic_matrix": document_topic_matrix,
"document_representation": embeddings,
"topic_term_matrix": self.components_, # type: ignore
"topic_term_matrix": self.components_, # type: ignore
"transform": getattr(self, "transform", None),
}
return res

0 comments on commit a0768d8

Please sign in to comment.