From a0768d8ab640df12c14a0d9658a2b76ea9539f8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1rton=20Kardos?= Date: Tue, 6 Feb 2024 15:53:27 +0100 Subject: [PATCH] Fixed import error --- turftopic/base.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/turftopic/base.py b/turftopic/base.py index 7d9cdb8..e36daf2 100644 --- a/turftopic/base.py +++ b/turftopic/base.py @@ -6,6 +6,7 @@ from rich.table import Table from sentence_transformers import SentenceTransformer from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.exceptions import NotFittedError from turftopic.data import TopicData from turftopic.encoders import ExternalEncoder @@ -291,17 +292,21 @@ def prepare_topic_data( if embeddings is None: embeddings = self.encode_documents(corpus) try: - document_topic_matrix = self.transform(corpus, embeddings=embeddings) + document_topic_matrix = self.transform( + corpus, embeddings=embeddings + ) except (AttributeError, NotFittedError): - document_topic_matrix = self.fit_transform(corpus, embeddings=embeddings) - dtm = self.vectorizer.transform(corpus) # type: ignore + document_topic_matrix = self.fit_transform( + corpus, embeddings=embeddings + ) + dtm = self.vectorizer.transform(corpus) # type: ignore res: TopicData = { "corpus": corpus, "document_term_matrix": dtm, "vocab": self.get_vocab(), "document_topic_matrix": document_topic_matrix, "document_representation": embeddings, - "topic_term_matrix": self.components_, # type: ignore + "topic_term_matrix": self.components_, # type: ignore "transform": getattr(self, "transform", None), } return res