From 724a3dea988260226277abfcb650a497e160415b Mon Sep 17 00:00:00 2001 From: Florents Tselai Date: Sat, 6 Jul 2024 14:35:08 +0300 Subject: [PATCH] Add ConsoleMixin --- tests/test_tsellm.py | 12 ++--- tsellm/cli.py | 101 ++++++++++++++++++++++++++++++------------- 2 files changed, 78 insertions(+), 35 deletions(-) diff --git a/tests/test_tsellm.py b/tests/test_tsellm.py index 9738533..60ea167 100644 --- a/tests/test_tsellm.py +++ b/tests/test_tsellm.py @@ -1,7 +1,7 @@ import tempfile import duckdb import llm.cli -from tsellm.cli import cli, TsellmConsole, SQLiteConsole, DuckDBConsole +from tsellm.cli import cli, TsellmConsole, SQLiteConsole, DuckDBConsole, TsellmConsoleMixin import unittest from test.support import captured_stdout, captured_stderr from test.support.os_helper import TESTFN, unlink @@ -62,14 +62,14 @@ def expect_failure(self, *args): return err def test_sniff_sqlite(self): - self.assertTrue(TsellmConsole.is_sqlite(new_sqlite_file())) + self.assertTrue(TsellmConsoleMixin().is_sqlite(new_sqlite_file())) def test_sniff_duckdb(self): - self.assertTrue(TsellmConsole.is_duckdb(new_duckdb_file())) + self.assertTrue(TsellmConsoleMixin().is_duckdb(new_duckdb_file())) def test_console_factory_sqlite(self): s = new_sqlite_file() - self.assertTrue(TsellmConsole.is_sqlite(s)) + self.assertTrue(TsellmConsoleMixin().is_sqlite(s)) obj = TsellmConsole.create_console(s) self.assertIsInstance(obj, SQLiteConsole) @@ -93,7 +93,7 @@ def test_choose_db(self): def test_deault_sqlite(self): f = new_tempfile() self.expect_success(str(f), "select 1") - self.assertTrue(TsellmConsole.is_sqlite(f)) + self.assertTrue(TsellmConsoleMixin().is_sqlite(f)) class InMemorySQLiteTest(TsellmConsoleTest): @@ -187,7 +187,7 @@ def setUp(self): def test_embed_default_hazo_leaves_valid_db_behind(self): # This should probably be called for all test cases super().test_embed_default_hazo() - self.assertTrue(TsellmConsole.is_sqlite(self.db_fp)) + self.assertTrue(TsellmConsoleMixin().is_sqlite(self.db_fp)) class InMemoryDuckDBTest(InMemorySQLiteTest): diff --git a/tsellm/cli.py b/tsellm/cli.py index 31c6b16..c3ff0cb 100644 --- a/tsellm/cli.py +++ b/tsellm/cli.py @@ -4,6 +4,8 @@ from argparse import ArgumentParser from code import InteractiveConsole from textwrap import dedent + +from . import __version__ from .core import ( _tsellm_init, _prompt_model, @@ -24,28 +26,8 @@ class DatabaseType(Enum): ERROR = auto() -class TsellmConsole(ABC, InteractiveConsole): - _TSELLM_CONFIG_SQL = """ --- tsellm configuration table --- need to be taken care of accross migrations and versions. - -CREATE TABLE IF NOT EXISTS __tsellm ( -x text -); - -""" - - _functions = [ - ("prompt", 2, _prompt_model, False), - ("prompt", 1, _prompt_model_default, False), - ("embed", 2, _embed_model, False), - ("embed", 1, _embed_model_default, False), - ] - - error_class = None - - @staticmethod - def is_sqlite(path): +class TsellmConsoleMixin(InteractiveConsole): + def is_sqlite(self, path): try: with sqlite3.connect(path) as conn: conn.execute("SELECT 1") @@ -53,8 +35,7 @@ def is_sqlite(path): except: return False - @staticmethod - def is_duckdb(path): + def is_duckdb(self, path): try: con = duckdb.connect(path.__str__()) con.sql("SELECT 1") @@ -62,8 +43,7 @@ def is_duckdb(path): except: return False - @staticmethod - def sniff_db(path): + def sniff_db(self, path): """ Sniffs if the path is a SQLite or DuckDB database. @@ -81,6 +61,49 @@ def sniff_db(path): return DatabaseType.DUCKDB return DatabaseType.UNKNOWN + +class TsellmConsole(ABC, TsellmConsoleMixin): + _TSELLM_CONFIG_SQL = """ +-- tsellm configuration table +-- need to be taken care of accross migrations and versions. + +CREATE TABLE IF NOT EXISTS __tsellm ( +x text +); + +""" + + _functions = [ + ("prompt", 2, _prompt_model, False), + ("prompt", 1, _prompt_model_default, False), + ("embed", 2, _embed_model, False), + ("embed", 1, _embed_model_default, False), + ] + + error_class = None + + @property + def tsellm_version(self) -> str: + return __version__.__version__ + + @property + @abstractmethod + def db_version(self) -> str: + pass + + @property + @abstractmethod + def is_valid_db(self) -> bool: + pass + + @abstractmethod + def complete_statement(self, source) -> str: + pass + + @property + def version(self): + return self.tsellm_version + '\t' + self.db_version + def load(self): self.execute(self._TSELLM_CONFIG_SQL) for func_name, n_args, py_func, deterministic in self._functions: @@ -88,9 +111,9 @@ def load(self): @staticmethod def create_console(path): - if TsellmConsole.is_duckdb(path): + if TsellmConsoleMixin().is_duckdb(path): return DuckDBConsole(path) - if TsellmConsole.is_sqlite(path): + if TsellmConsoleMixin().is_sqlite(path): return SQLiteConsole(path) else: raise ValueError(f"Database type {path} not supported") @@ -109,6 +132,13 @@ def runsource(self, source, filename="", symbol="single"): class SQLiteConsole(TsellmConsole): + def complete_statement(self, source) -> str: + pass + + @property + def is_valid_db(self) -> bool: + pass + error_class = sqlite3.Error def __init__(self, path): @@ -140,6 +170,9 @@ def execute(self, sql, suppress_errors=True): if not suppress_errors: sys.exit(1) + def db_version(self): + return sqlite3.sqlite_version + def runsource(self, source, filename="", symbol="single"): """Override runsource, the core of the InteractiveConsole REPL. @@ -154,13 +187,20 @@ def runsource(self, source, filename="", symbol="single"): case ".quit": sys.exit(0) case _: - if not sqlite3.complete_statement(source): + if not self.complete_statement(source): return True self.execute(source) return False class DuckDBConsole(TsellmConsole): + def complete_statement(self, source) -> str: + pass + + @property + def is_valid_db(self) -> bool: + pass + error_class = sqlite3.Error _functions = [ @@ -180,6 +220,9 @@ def load(self): for func_name, _, py_func, _ in self._functions: self._con.create_function(func_name, py_func) + def db_version(self): + return "DUCKDB VERSION" + def execute(self, sql, suppress_errors=True): """Helper that wraps execution of SQL code.