forked from priyanshugandhi/ChatBot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
chatbot.py
83 lines (77 loc) · 3.83 KB
/
chatbot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
Python 2.7.13 (v2.7.13:a06454b1afa1, Dec 17 2016, 20:42:59) [MSC v.1500 32 bit (Intel)] on win32
Type "copyright", "credits" or "license()" for more information.
>>> import re
import sqlite3
from collections import Counter
from string import punctuation
from math import sqrt
# initialize the connection to the database
connection = sqlite3.connect('chatbot.sqlite')
cursor = connection.cursor()
# create the tables needed by the program
create_table_request_list = [
'CREATE TABLE words(word TEXT UNIQUE)',
'CREATE TABLE sentences(sentence TEXT UNIQUE, used INT NOT NULL DEFAULT 0)',
'CREATE TABLE associations (word_id INT NOT NULL, sentence_id INT NOT NULL, weight REAL NOT NULL)',
]
for create_table_request in create_table_request_list:
try:
cursor.execute(create_table_request)
except:
pass
def get_id(entityName, text):
"""Retrieve an entity's unique ID from the database, given its associated text.
If the row is not already present, it is inserted.
The entity can either be a sentence or a word."""
tableName = entityName + 's'
columnName = entityName
cursor.execute('SELECT rowid FROM ' + tableName + ' WHERE ' + columnName + ' = ?', (text,))
row = cursor.fetchone()
if row:
return row[0]
else:
cursor.execute('INSERT INTO ' + tableName + ' (' + columnName + ') VALUES (?)', (text,))
return cursor.lastrowid
def get_words(text):
"""Retrieve the words present in a given string of text.
The return value is a list of tuples where the first member is a lowercase word,
and the second member the number of time it is present in the text."""
wordsRegexpString = '(?:\w+|[' + re.escape(punctuation) + ']+)'
wordsRegexp = re.compile(wordsRegexpString)
wordsList = wordsRegexp.findall(text.lower())
return Counter(wordsList).items()
B = 'Hello!'
while True:
# output bot's message
print('B: ' + B)
# ask for user input; if blank line, exit the loop
H = raw_input('H: ').strip()
if H == '':
break
# store the association between the bot's message words and the user's response
words = get_words(B)
words_length = sum([n * len(word) for word, n in words])
sentence_id = get_id('sentence', H)
for word, n in words:
word_id = get_id('word', word)
weight = sqrt(n / float(words_length))
cursor.execute('INSERT INTO associations VALUES (?, ?, ?)', (word_id, sentence_id, weight))
connection.commit()
# retrieve the most likely answer from the database
cursor.execute('CREATE TEMPORARY TABLE results(sentence_id INT, sentence TEXT, weight REAL)')
words = get_words(H)
words_length = sum([n * len(word) for word, n in words])
for word, n in words:
weight = sqrt(n / float(words_length))
cursor.execute('INSERT INTO results SELECT associations.sentence_id, sentences.sentence, ?*associations.weight/(4+sentences.used) FROM words INNER JOIN associations ON associations.word_id=words.rowid INNER JOIN sentences ON sentences.rowid=associations.sentence_id WHERE words.word=?', (weight, word,))
# if matches were found, give the best one
cursor.execute('SELECT sentence_id, sentence, SUM(weight) AS sum_weight FROM results GROUP BY sentence_id ORDER BY sum_weight DESC LIMIT 1')
row = cursor.fetchone()
cursor.execute('DROP TABLE results')
# otherwise, just randomly pick one of the least used sentences
if row is None:
cursor.execute('SELECT rowid, sentence FROM sentences WHERE used = (SELECT MIN(used) FROM sentences) ORDER BY RANDOM() LIMIT 1')
row = cursor.fetchone()
# tell the database the sentence has been used once more, and prepare the sentence
B = row[1]
cursor.execute('UPDATE sentences SET used=used+1 WHERE rowid=?', (row[0],))