Skip to content

Commit

Permalink
Add ConsoleMixin
Browse files Browse the repository at this point in the history
  • Loading branch information
Florents-Tselai committed Jul 6, 2024
1 parent 0742558 commit 724a3de
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 35 deletions.
12 changes: 6 additions & 6 deletions tests/test_tsellm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import tempfile
import duckdb
import llm.cli
from tsellm.cli import cli, TsellmConsole, SQLiteConsole, DuckDBConsole
from tsellm.cli import cli, TsellmConsole, SQLiteConsole, DuckDBConsole, TsellmConsoleMixin
import unittest
from test.support import captured_stdout, captured_stderr
from test.support.os_helper import TESTFN, unlink
Expand Down Expand Up @@ -62,14 +62,14 @@ def expect_failure(self, *args):
return err

def test_sniff_sqlite(self):
self.assertTrue(TsellmConsole.is_sqlite(new_sqlite_file()))
self.assertTrue(TsellmConsoleMixin().is_sqlite(new_sqlite_file()))

def test_sniff_duckdb(self):
self.assertTrue(TsellmConsole.is_duckdb(new_duckdb_file()))
self.assertTrue(TsellmConsoleMixin().is_duckdb(new_duckdb_file()))

def test_console_factory_sqlite(self):
s = new_sqlite_file()
self.assertTrue(TsellmConsole.is_sqlite(s))
self.assertTrue(TsellmConsoleMixin().is_sqlite(s))
obj = TsellmConsole.create_console(s)
self.assertIsInstance(obj, SQLiteConsole)

Expand All @@ -93,7 +93,7 @@ def test_choose_db(self):
def test_deault_sqlite(self):
f = new_tempfile()
self.expect_success(str(f), "select 1")
self.assertTrue(TsellmConsole.is_sqlite(f))
self.assertTrue(TsellmConsoleMixin().is_sqlite(f))


class InMemorySQLiteTest(TsellmConsoleTest):
Expand Down Expand Up @@ -187,7 +187,7 @@ def setUp(self):
def test_embed_default_hazo_leaves_valid_db_behind(self):
# This should probably be called for all test cases
super().test_embed_default_hazo()
self.assertTrue(TsellmConsole.is_sqlite(self.db_fp))
self.assertTrue(TsellmConsoleMixin().is_sqlite(self.db_fp))


class InMemoryDuckDBTest(InMemorySQLiteTest):
Expand Down
101 changes: 72 additions & 29 deletions tsellm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from argparse import ArgumentParser
from code import InteractiveConsole
from textwrap import dedent

from . import __version__
from .core import (
_tsellm_init,
_prompt_model,
Expand All @@ -24,46 +26,24 @@ class DatabaseType(Enum):
ERROR = auto()


class TsellmConsole(ABC, InteractiveConsole):
_TSELLM_CONFIG_SQL = """
-- tsellm configuration table
-- need to be taken care of accross migrations and versions.
CREATE TABLE IF NOT EXISTS __tsellm (
x text
);
"""

_functions = [
("prompt", 2, _prompt_model, False),
("prompt", 1, _prompt_model_default, False),
("embed", 2, _embed_model, False),
("embed", 1, _embed_model_default, False),
]

error_class = None

@staticmethod
def is_sqlite(path):
class TsellmConsoleMixin(InteractiveConsole):
def is_sqlite(self, path):
try:
with sqlite3.connect(path) as conn:
conn.execute("SELECT 1")
return True
except:
return False

@staticmethod
def is_duckdb(path):
def is_duckdb(self, path):
try:
con = duckdb.connect(path.__str__())
con.sql("SELECT 1")
return True
except:
return False

@staticmethod
def sniff_db(path):
def sniff_db(self, path):
"""
Sniffs if the path is a SQLite or DuckDB database.
Expand All @@ -81,16 +61,59 @@ def sniff_db(path):
return DatabaseType.DUCKDB
return DatabaseType.UNKNOWN


class TsellmConsole(ABC, TsellmConsoleMixin):
_TSELLM_CONFIG_SQL = """
-- tsellm configuration table
-- need to be taken care of accross migrations and versions.
CREATE TABLE IF NOT EXISTS __tsellm (
x text
);
"""

_functions = [
("prompt", 2, _prompt_model, False),
("prompt", 1, _prompt_model_default, False),
("embed", 2, _embed_model, False),
("embed", 1, _embed_model_default, False),
]

error_class = None

@property
def tsellm_version(self) -> str:
return __version__.__version__

@property
@abstractmethod
def db_version(self) -> str:
pass

@property
@abstractmethod
def is_valid_db(self) -> bool:
pass

@abstractmethod
def complete_statement(self, source) -> str:
pass

@property
def version(self):
return self.tsellm_version + '\t' + self.db_version

def load(self):
self.execute(self._TSELLM_CONFIG_SQL)
for func_name, n_args, py_func, deterministic in self._functions:
self._con.create_function(func_name, n_args, py_func)

@staticmethod
def create_console(path):
if TsellmConsole.is_duckdb(path):
if TsellmConsoleMixin().is_duckdb(path):
return DuckDBConsole(path)
if TsellmConsole.is_sqlite(path):
if TsellmConsoleMixin().is_sqlite(path):
return SQLiteConsole(path)
else:
raise ValueError(f"Database type {path} not supported")
Expand All @@ -109,6 +132,13 @@ def runsource(self, source, filename="<input>", symbol="single"):


class SQLiteConsole(TsellmConsole):
def complete_statement(self, source) -> str:
pass

@property
def is_valid_db(self) -> bool:
pass

error_class = sqlite3.Error

def __init__(self, path):
Expand Down Expand Up @@ -140,6 +170,9 @@ def execute(self, sql, suppress_errors=True):
if not suppress_errors:
sys.exit(1)

def db_version(self):
return sqlite3.sqlite_version

def runsource(self, source, filename="<input>", symbol="single"):
"""Override runsource, the core of the InteractiveConsole REPL.
Expand All @@ -154,13 +187,20 @@ def runsource(self, source, filename="<input>", symbol="single"):
case ".quit":
sys.exit(0)
case _:
if not sqlite3.complete_statement(source):
if not self.complete_statement(source):
return True
self.execute(source)
return False


class DuckDBConsole(TsellmConsole):
def complete_statement(self, source) -> str:
pass

@property
def is_valid_db(self) -> bool:
pass

error_class = sqlite3.Error

_functions = [
Expand All @@ -180,6 +220,9 @@ def load(self):
for func_name, _, py_func, _ in self._functions:
self._con.create_function(func_name, py_func)

def db_version(self):
return "DUCKDB VERSION"

def execute(self, sql, suppress_errors=True):
"""Helper that wraps execution of SQL code.
Expand Down

0 comments on commit 724a3de

Please sign in to comment.