Skip to content

Commit

Permalink
Refactored naming code and switched to SmolLM-1.7b
Browse files Browse the repository at this point in the history
  • Loading branch information
x-tabdeveloping committed Nov 4, 2024
1 parent 72511f2 commit 75e5302
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 111 deletions.
51 changes: 14 additions & 37 deletions turftopic/namers/base.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,16 @@
from abc import ABC, abstractmethod
from typing import Optional

DEFAULT_POSITIVE_PROMPT = """
You will be tasked with naming a topic.
The topic is described by the following set of keywords: {positive}.
Based on the keywords, create a short label that best summarizes the topics.
The topic name should be at maximum three words long.
Only respond with a short topic name and nothing else.
"""
from rich.progress import track

DEFAULT_NEGATIVE_PROMPT = """
DEFAULT_PROMPT = """
You will be tasked with naming a topic.
The topic is described with most relevant positive and negative terms.
Make sure to consider the negative terms as well when naming the topic.
An example of a topic name like this would be: Oriental vs. European Cuisine
Positive terms: {positive}
Negative terms: {negative}
Based on the keywords, create a short label that best summarizes the topics.
Only respond with a short, human readable topic name and nothing else.
Based on the keywords, create a short label (5 words maximum) that best summarizes the topics.
Only respond with the topic name and nothing else.
The topic is described by the following set of keywords: {keywords}.
"""


DEFAULT_SYSTEM_PROMPT = """
You are a topic namer. When the user gives you a set of keywords, you respond with a name for the topic they describe.
You only repond briefly with the name of the topic, and nothing else.
Expand All @@ -34,18 +21,14 @@ class TopicNamer(ABC):
@abstractmethod
def name_topic(
self,
positive: list[str],
negative: Optional[list[str]] = None,
keywords: list[str],
) -> str:
"""Names one topics based on top descriptive terms.
Parameters
----------
positive: list[str]
keywords: list[str]
Top K highest ranking terms on the topic.
negative: list[str], default None
Top K lowest ranking terms on the topic.
(this is only relevant in the context of $S^3$)
Returns
-------
Expand All @@ -56,27 +39,21 @@ def name_topic(

def name_topics(
self,
positive: list[list[str]],
negative: Optional[list[list[str]]] = None,
keywords: list[list[str]],
) -> list[str]:
"""Names all topics based on top descriptive terms.
Parameters
----------
positive: list[list[str]]
keywords: list[list[str]]
Top K highest ranking terms on the topics.
negative: list[list[str]], default None
Top K lowest ranking terms on the topics
(this is only relevant in the context of $S^3$)
Returns
-------
list[str]
Topic names returned by the namer.
"""
if negative is not None:
return [
self.name_topic(pos, neg)
for pos, neg in zip(positive, negative)
]
return [self.name_topic(pos) for pos in positive]
names = []
for keys in track(keywords, description="Naming topics..."):
names.append(self.name_topic(keys))
return names
84 changes: 10 additions & 74 deletions turftopic/namers/hf_transformers.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,18 @@
from typing import Optional
from transformers import pipeline

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from turftopic.namers.base import (DEFAULT_PROMPT, DEFAULT_SYSTEM_PROMPT,
TopicNamer)

from turftopic.namers.base import (
DEFAULT_NEGATIVE_PROMPT,
DEFAULT_POSITIVE_PROMPT,
DEFAULT_SYSTEM_PROMPT,
TopicNamer,
)


class Text2TextTopicNamer(TopicNamer):
"""Name topics with a Text2Text model (e.g. Google's T5).
Parameters
----------
model_name: str, default 'google/flan-t5-large'
Model to load from :hugs: Hub.
prompt_template: str
Prompt template to use when no negative terms are specified.
axis_prompt_template: str
Prompt template to use when negative terms are also specified.
device: str, default 'cpu'
Device to run the model on.
"""

def __init__(
self,
model_name: str = "google/flan-t5-large",
prompt_template: str = DEFAULT_POSITIVE_PROMPT,
axis_prompt_template: str = DEFAULT_NEGATIVE_PROMPT,
device: str = "cpu",
):
self.model_name = model_name
self.prompt_template = prompt_template
self.axis_prompt_template = axis_prompt_template
self.device = device
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(
self.device
)

def name_topic(
self,
positive: list[list[str]],
negative: Optional[list[list[str]]] = None,
) -> str:
if negative is not None:
prompt = self.axis_prompt_template.format(
positive=", ".join(positive), negative=", ".join(negative)
)
else:
prompt = self.prompt_template.format(positive=", ".join(positive))
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
output = self.model.generate(**inputs, max_new_tokens=24)
label = self.tokenizer.decode(output[0], skip_special_tokens=True)
return label


class ChatTopicNamer(TopicNamer):
"""Name topics with a Chat model, e.g. Zephyr-7b-beta
class LLMTopicNamer(TopicNamer):
"""Name topics with an instruction-finetuned LLM, e.g. Zephyr-7b-beta
Parameters
----------
model_name: str, default 'HuggingFaceH4/zephyr-7b-beta'
model_name: str, default 'HuggingFaceTB/SmolLM2-1.7B-Instruct'
Model to load from :hugs: Hub.
prompt_template: str
Prompt template to use when no negative terms are specified.
axis_prompt_template: str
Prompt template to use when negative terms are also specified.
system_prompt: str
System prompt to use for the language model.
device: str, default 'cpu'
Expand All @@ -77,15 +21,13 @@ class ChatTopicNamer(TopicNamer):

def __init__(
self,
model_name: str = "HuggingFaceH4/zephyr-7b-beta",
prompt_template: str = DEFAULT_POSITIVE_PROMPT,
axis_prompt_template: str = DEFAULT_NEGATIVE_PROMPT,
model_name: str = "HuggingFaceTB/SmolLM2-1.7B-Instruct",
prompt_template: str = DEFAULT_PROMPT,
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
device: str = "cpu",
):
self.model_name = model_name
self.prompt_template = prompt_template
self.axis_prompt_template = axis_prompt_template
self.system_prompt = system_prompt
self.device = device
self.pipe = pipeline(
Expand All @@ -94,15 +36,9 @@ def __init__(

def name_topic(
self,
positive: list[list[str]],
negative: Optional[list[list[str]]] = None,
keywords: list[list[str]],
) -> str:
if negative is not None:
prompt = self.axis_prompt_template.format(
positive=", ".join(positive), negative=", ".join(negative)
)
else:
prompt = self.prompt_template.format(positive=", ".join(positive))
prompt = self.prompt_template.format(keywords=", ".join(keywords))
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": prompt},
Expand Down

0 comments on commit 75e5302

Please sign in to comment.