Skip to content

Commit

Permalink
Add prompts.sql
Browse files Browse the repository at this point in the history
  • Loading branch information
Florents-Tselai committed Jun 23, 2024
1 parent acc8f1f commit 1297231
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 122 deletions.
10 changes: 10 additions & 0 deletions prompts.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
-- Seed script: creates the prompts table and loads four sample prompts.
-- Turn off foreign-key enforcement for this connection before loading.
PRAGMA foreign_keys=OFF;
-- Wrap table creation and inserts in one transaction so they apply together.
BEGIN TRANSACTION;
-- Single-column table; each row holds one prompt string.
CREATE TABLE [prompts] (
[prompt] TEXT
);
-- Sample rows.
INSERT INTO prompts VALUES('hello world!');
INSERT INTO prompts VALUES('how are you?');
INSERT INTO prompts VALUES('is this real life?');
INSERT INTO prompts VALUES('1+1=?');
COMMIT;
6 changes: 4 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ def get_long_description():
"CI": "https://github.com/Florents-Tselai/tsellm/actions",
"Changelog": "https://github.com/Florents-Tselai/tsellm/releases",
},
license="MIT License",
license="BSD License",
version=VERSION,
packages=["tsellm"],
install_requires=["llm", "setuptools", "pip"],
extras_require={"test": ["pytest", "pytest-cov", "black", "ruff", "sqlite_utils", "llm-markov"]},
extras_require={
"test": ["pytest", "pytest-cov", "black", "ruff", "sqlite_utils", "llm-markov"]
},
python_requires=">=3.7",
)
123 changes: 4 additions & 119 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,119 +16,6 @@ def pytest_configure(config):
sys._called_from_test = True


@pytest.fixture
def user_path(tmpdir):
    """Create a ``tsellm`` directory under ``tmpdir`` and return its path."""
    # Named ``target`` rather than ``dir`` so the builtin is not shadowed.
    target = tmpdir / "tsellm"
    target.mkdir()
    return target


@pytest.fixture
def logs_db(user_path):
    """Return a ``sqlite_utils.Database`` opened on ``logs.db`` inside ``user_path``."""
    logs_file = user_path / "logs.db"
    return sqlite_utils.Database(str(logs_file))


@pytest.fixture
def user_path_with_embeddings(user_path):
    """Embed two sample items into a "demo" collection in ``embeddings.db``.

    The database file lives under ``user_path``. The fixture itself returns
    nothing.
    """
    db_file = str(user_path / "embeddings.db")
    demo = llm.Collection(
        "demo", sqlite_utils.Database(db_file), model_id="embed-demo"
    )
    for item_id, content in (("1", "hello world"), ("2", "goodbye world")):
        demo.embed(item_id, content)


class MockModel(llm.Model):
    """Scripted test model with the id ``mock-echo``.

    Batches of messages are queued with :meth:`enqueue`; each run of
    :meth:`execute` replays the oldest queued batch and logs its arguments
    in ``history``.
    """

    model_id = "mock-echo"

    class Options(llm.Options):
        max_tokens: Optional[int] = Field(
            default=None, description="Maximum number of tokens to generate."
        )

    def __init__(self):
        # FIFO of message lists awaiting replay, and a log of execute() calls.
        self._queue = []
        self.history = []

    def enqueue(self, messages):
        """Queue a list of messages to be yielded by a later ``execute`` call."""
        assert isinstance(messages, list)
        self._queue.append(messages)

    def execute(self, prompt, stream, response, conversation):
        """Log the call arguments, then yield the oldest queued batch.

        Yields nothing when the queue is empty.
        """
        self.history.append((prompt, stream, response, conversation))
        try:
            yield from self._queue.pop(0)
        except IndexError:
            # Empty queue: finish without producing any output.
            pass


class EmbedDemo(llm.EmbeddingModel):
    """Deterministic embedding model for tests, registered as ``embed-demo``."""

    model_id = "embed-demo"
    batch_size = 10
    supports_binary = True

    def __init__(self):
        # Every item passed through embed_batch, in the order it was seen.
        self.embedded_content = []

    def embed_batch(self, texts):
        """Yield one 16-element vector per item in ``texts``.

        Each vector holds the lengths of the first 16 whitespace-separated
        words, padded with zeros on the right. ``batch_count`` is incremented
        once per batch.
        """
        # batch_count does not exist until the first batch is processed.
        self.batch_count = getattr(self, "batch_count", 0) + 1
        for text in texts:
            self.embedded_content.append(text)
            vector = [len(word) for word in text.split()[:16]]
            # Pad with 0 up to 16 entries.
            vector.extend([0] * (16 - len(vector)))
            yield vector


class EmbedBinaryOnly(EmbedDemo):
    """``EmbedDemo`` variant flagged as supporting binary input but not text."""

    model_id = "embed-binary-only"
    supports_binary = True
    supports_text = False


class EmbedTextOnly(EmbedDemo):
    """``EmbedDemo`` variant flagged as supporting text input but not binary."""

    model_id = "embed-text-only"
    supports_binary = False
    supports_text = True


@pytest.fixture
def embed_demo():
    """Provide a newly constructed ``EmbedDemo`` instance."""
    demo = EmbedDemo()
    return demo


@pytest.fixture
def mock_model():
    """Provide a newly constructed ``MockModel`` instance."""
    model = MockModel()
    return model


@pytest.fixture(autouse=True)
def register_embed_demo_model(embed_demo, mock_model):
    """Register the mock models with ``pm`` around each test.

    A throwaway plugin exposing the embedding models and the mock model is
    registered before the test body runs and unregistered afterwards, even
    if the test fails.
    """
    plugin_name = "undo-mock-models-plugin"

    class MockModelsPlugin:
        __name__ = "MockModelsPlugin"

        @llm.hookimpl
        def register_embedding_models(self, register):
            register(embed_demo)
            register(EmbedBinaryOnly())
            register(EmbedTextOnly())

        @llm.hookimpl
        def register_models(self, register):
            register(mock_model)

    plugin = MockModelsPlugin()
    pm.register(plugin, name=plugin_name)
    try:
        yield
    finally:
        # Always unregister so the plugin does not outlive the test.
        pm.unregister(name=plugin_name)


@pytest.fixture
def db_path(tmpdir):
path = str(tmpdir / "test.db")
Expand All @@ -139,20 +26,18 @@ def db_path(tmpdir):
def fresh_db_path(db_path):
return db_path


@pytest.fixture
def existing_db_path(fresh_db_path):
db = Database(fresh_db_path)
table = db.create_table(
"prompts",
{
"prompt": str,
"generated": str,
"model": str,
"embedding": dict,
},
{"prompt": str},
)

table.insert({"prompt": "hello world!"})
table.insert({"prompt": "how are you?"})
table.insert({"prompt": "is this real life?"})
table.insert({"prompt": "1+1=?"})

return fresh_db_path
6 changes: 5 additions & 1 deletion tests/test_tsellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@ def test_cli_prompt_mock(existing_db_path):
("hello world!",),
("how are you?",),
("is this real life?",),
("1+1=?",),
]

cli([existing_db_path, "UPDATE prompts SET generated=prompt(prompt, 'markov')"])

for prompt, generated in db.execute("select prompt, generated from prompts").fetchall():
for prompt, generated in db.execute(
"select prompt, generated from prompts"
).fetchall():
words = generated.strip().split()
# Every word should be one of the original prompt (see https://github.com/simonw/llm-markov/blob/657ca504bcf9f0bfc1c6ee5fe838cde9a8976381/tests/test_llm_markov.py#L20)
prompt_words = prompt.split()
for word in words:
assert word in prompt_words
assert existing_db_path == 'gdfg'
2 changes: 2 additions & 0 deletions tsellm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
"""


def _prompt_model(prompt, model):
    """Run ``prompt`` against the llm model named ``model`` and return its text."""
    target = llm.get_model(model)
    response = target.prompt(prompt)
    return response.text()


def _tsellm_init(con):
"""Entry-point for tsellm initialization."""
con.execute(TSELLM_CONFIG_SQL)
Expand Down

0 comments on commit 1297231

Please sign in to comment.