Skip to content

Commit

Permalink
Sniff if a file is duckdb or sqlite
Browse files Browse the repository at this point in the history
  • Loading branch information
Florents-Tselai committed Jul 4, 2024
1 parent 9273f3e commit 67f0b72
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_long_description():
license="BSD License",
version=VERSION,
packages=["tsellm"],
install_requires=["llm", "setuptools", "pip"],
install_requires=["llm", "setuptools", "pip", "duckdb"],
extras_require={
"test": [
"pytest",
Expand Down
28 changes: 27 additions & 1 deletion tests/test_tsellm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import tempfile
import duckdb
import llm.cli
from sqlite_utils import Database
from tsellm.cli import cli
from tsellm.cli import cli, TsellmConsole
import unittest
from test.support import captured_stdout, captured_stderr, captured_stdin, os_helper
from test.support.os_helper import TESTFN, unlink
from llm import models
import sqlite3
from llm import cli as llm_cli
from tempfile import tempdir
from pathlib import Path
import sqlite3


class CommandLineInterface(unittest.TestCase):
Expand All @@ -16,6 +21,10 @@ def setUp(self):
llm_cli.set_default_model("markov")
llm_cli.set_default_embedding_model("hazo")

@staticmethod
def tempfile():
    """Return the path ``test.db`` inside a newly created temporary directory.

    Only the directory is created; the database file itself is not.
    """
    scratch_dir = Path(tempfile.mkdtemp())
    return scratch_dir / 'test.db'

def _do_test(self, *args, expect_success=True):
with (
captured_stdout() as out,
Expand Down Expand Up @@ -43,6 +52,23 @@ def expect_failure(self, *args):
self.assertEqual(out, "")
return err

def test_sniff_sqlite(self):
    """A database written through ``sqlite3`` is reported as SQLite."""
    # Local imports keep this test self-contained; the module-level import
    # block only pulls ``cli`` and ``TsellmConsole`` from ``tsellm.cli``.
    from contextlib import closing

    from tsellm.cli import DatabaseType

    f = self.tempfile()
    self.assertTrue(str(f).endswith("db"))
    # ``with sqlite3.connect(...)`` only manages the transaction; ``closing``
    # makes sure the handle is released before the file is sniffed.
    with closing(sqlite3.connect(f)) as db:
        db.execute("CREATE TABLE test (id INTEGER PRIMARY KEY)")
        db.commit()

    # Every DatabaseType member is truthy, so assertTrue() could never fail
    # here; compare against the exact expected member instead.
    self.assertIs(TsellmConsole.is_sqlite(f), DatabaseType.SQLITE)

def test_sniff_duckdb(self):
    """A database written through ``duckdb`` is reported as DuckDB."""
    f = self.tempfile()
    self.assertTrue(str(f).endswith("db"))
    con = duckdb.connect(str(f))
    try:
        con.sql("CREATE TABLE test (id INTEGER PRIMARY KEY)")
    finally:
        # Close the writer so the file is released before it is sniffed.
        con.close()

    # is_duckdb() reports failures as (truthy) DatabaseType members, so
    # assertTrue() would also accept an error result; require the literal
    # ``True`` that signals success.
    self.assertIs(TsellmConsole.is_duckdb(f), True)

def test_cli_help(self):
out = self.expect_success("-h")
self.assertIn("usage: python -m tsellm", out)
Expand Down
83 changes: 80 additions & 3 deletions tsellm/cli.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,93 @@
import sqlite3
import sys

import duckdb
from argparse import ArgumentParser
from code import InteractiveConsole
from textwrap import dedent
from .core import _tsellm_init
from .core import _tsellm_init, _prompt_model, _prompt_model_default, _embed_model, _embed_model_default
from abc import ABC, abstractmethod, abstractproperty

from enum import Enum, auto


class DatabaseType(Enum):
    """Possible outcomes of probing a path for a supported database format.

    Note that, like any plain ``Enum`` member, every value here is truthy,
    so results must be compared against a specific member rather than
    truth-tested.
    """

    # Recognised database formats.
    SQLITE = auto()
    DUCKDB = auto()

    # The path was inspected but no supported format was recognised.
    UNKNOWN = auto()

    # Probing could not be completed.
    FILE_NOT_FOUND = auto()
    ERROR = auto()


class TsellmConsole(ABC, InteractiveConsole):
    """Abstract interactive console bound to a database connection.

    Concrete consoles (e.g. ``SQLiteConsole``) override the class
    attributes declared here.
    """

    # SQL run by ``load()`` so that the tsellm configuration table exists.
    _TSELLM_CONFIG_SQL = """
-- tsellm configuration table
-- need to be taken care of accross migrations and versions.
CREATE TABLE IF NOT EXISTS __tsellm (
x text
);
"""

    # SQL functions registered by ``load()``; each entry is unpacked there as
    # ``(func_name, n_args, py_func, deterministic)``. Empty on the base class.
    _functions = []

    # Backend error class; ``None`` on the base class, ``sqlite3.Error`` on
    # ``SQLiteConsole``.
    error_class = None

def __init__(self, path):
    """Open the database at ``path`` and prepare the console for use.

    Args:
        path: Database location, passed straight to ``sqlite3.connect``.
    """
    super().__init__()
    # isolation_level=None: the sqlite3 module opens no implicit
    # transactions on this connection (autocommit behaviour).
    self._con = sqlite3.connect(path, isolation_level=None)
    self._cur = self._con.cursor()

    # Order matters: the connection must exist before the project-level
    # initialisation (from ``.core``) and before ``load()`` uses it.
    _tsellm_init(self._con)
    self.load()

@staticmethod
def is_sqlite(path):
    """Probe ``path`` for the SQLite 3 file header.

    Args:
        path: Location of the file to inspect.

    Returns:
        DatabaseType: ``SQLITE`` when the first 16 bytes start with the
        ``SQLite format 3`` magic string, ``UNKNOWN`` when they do not,
        ``FILE_NOT_FOUND`` when the file does not exist, and ``ERROR`` for
        any other failure while reading it.

    Note:
        Every possible result is a truthy enum member, so callers must
        compare against ``DatabaseType.SQLITE`` instead of truth-testing
        the return value.
    """
    try:
        with open(path, 'rb') as db_file:
            magic = db_file.read(16)
    except FileNotFoundError:
        return DatabaseType.FILE_NOT_FOUND
    except Exception:
        return DatabaseType.ERROR

    if magic.startswith(b'SQLite format 3'):
        return DatabaseType.SQLITE
    return DatabaseType.UNKNOWN

@staticmethod
def is_duckdb(path):
    """Probe whether ``path`` can be opened as a DuckDB database.

    Args:
        path: Location of the file to inspect (``str`` or path-like).

    Returns:
        ``True`` when DuckDB opens the file and accepts a trivial query,
        ``DatabaseType.FILE_NOT_FOUND`` when nothing exists at ``path``,
        and ``DatabaseType.ERROR`` when DuckDB cannot open or query it.

    Note:
        The failure results are truthy enum members, so callers must
        check for ``is True`` rather than truth-testing the return value.
    """
    import os  # local import: not part of the module's import block

    db_path = str(path)
    # duckdb.connect() silently creates a new database when the file is
    # missing, so existence has to be checked up front; otherwise probing
    # would leave a stray file behind and FILE_NOT_FOUND would be
    # unreachable.
    if not os.path.exists(db_path):
        return DatabaseType.FILE_NOT_FOUND

    try:
        con = duckdb.connect(db_path)
    except FileNotFoundError:
        # The file vanished between the existence check and the connect.
        return DatabaseType.FILE_NOT_FOUND
    except Exception:
        return DatabaseType.ERROR

    try:
        con.sql("SELECT 1")
        return True
    except Exception:
        return DatabaseType.ERROR
    finally:
        # Always release the handle (and DuckDB's lock on the file).
        con.close()

@staticmethod
def sniff_db(path):
    """Sniff whether ``path`` is a SQLite or DuckDB database.

    Args:
        path (str): The file path to check.

    Returns:
        DatabaseType: ``SQLITE`` or ``DUCKDB`` when the format is
        recognised, ``FILE_NOT_FOUND`` when the path does not exist,
        ``ERROR`` when the file cannot be read, and ``UNKNOWN`` when the
        file is readable but matches neither format.
    """
    # The is_* helpers report failures as DatabaseType members, which are
    # all truthy; their results must be compared explicitly, never
    # truth-tested (a plain ``if`` would classify every path as SQLite).
    sqlite_result = TsellmConsole.is_sqlite(path)
    if sqlite_result is True or sqlite_result is DatabaseType.SQLITE:
        return DatabaseType.SQLITE
    if sqlite_result in (DatabaseType.FILE_NOT_FOUND, DatabaseType.ERROR):
        # Missing or unreadable file: report it as-is and do not hand the
        # path to DuckDB, which could create a new database there.
        return sqlite_result

    duckdb_result = TsellmConsole.is_duckdb(path)
    if duckdb_result is True or duckdb_result is DatabaseType.DUCKDB:
        return DatabaseType.DUCKDB

    # Readable file that neither engine recognises.
    return DatabaseType.UNKNOWN

def load(self):
    """Create the tsellm config table and register the console's SQL functions.

    Runs ``_TSELLM_CONFIG_SQL`` through ``self.execute`` and then registers
    every ``(func_name, n_args, py_func, deterministic)`` entry of
    ``self._functions`` on the underlying sqlite3 connection.
    """
    self.execute(self._TSELLM_CONFIG_SQL)
    for func_name, n_args, py_func, deterministic in self._functions:
        # Forward the declared flag instead of dropping it, so functions
        # marked deterministic are registered as such with SQLite.
        self._con.create_function(
            func_name, n_args, py_func, deterministic=deterministic
        )

@property
def connection(self):
Expand All @@ -33,6 +104,12 @@ def runsource(self, source, filename="<input>", symbol="single"):

class SQLiteConsole(TsellmConsole):
error_class = sqlite3.Error
_functions = [
("prompt", 2, _prompt_model, False),
("prompt", 1, _prompt_model_default, False),
("embed", 2, _embed_model, False),
("embed", 1, _embed_model_default, False)
]

def execute(self, sql, suppress_errors=True):
"""Helper that wraps execution of SQL code.
Expand Down

0 comments on commit 67f0b72

Please sign in to comment.