Skip to content

Commit 2dad12b

Browse files
authored
Add files via upload
1 parent f3d43e8 commit 2dad12b

File tree

5 files changed

+512
-0
lines changed

5 files changed

+512
-0
lines changed

neural_network/chatbot/chatbot.py

+134
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import datetime
2+
from typing import List, Dict, Any
3+
4+
5+
class Chatbot:
6+
"""
7+
A Chatbot class to manage chat conversations using an LLM service and a database to store chat data.
8+
9+
Methods:
10+
- start_chat: Starts a new conversation, logs the start time.
11+
- handle_user_message: Processes user input and stores user message & bot response in DB.
12+
- end_chat: Ends the conversation and logs the end time.
13+
- continue_chat: Retains only the last few messages if the conversation exceeds 1000 messages.
14+
"""
15+
16+
def __init__(self, db: Any, llm_service: Any) -> None:
17+
"""
18+
Initialize the Chatbot with a database and an LLM service.
19+
20+
Parameters:
21+
- db: The database instance used for storing chat data.
22+
- llm_service: The language model service for generating responses.
23+
"""
24+
self.db = db
25+
self.llm_service = llm_service
26+
self.conversation_history: List[Dict[str, str]] = []
27+
self.chat_id_pk: int = None
28+
29+
def start_chat(self) -> None:
30+
"""
31+
Start a new chat session and insert chat history to the database.
32+
"""
33+
start_time = datetime.datetime.now()
34+
is_stream = 1 # Start new conversation
35+
self.db.insert_chat_history(start_time, is_stream)
36+
self.chat_id_pk = self.db.get_latest_chat_id()
37+
38+
def handle_user_message(self, user_input: str) -> str:
39+
"""
40+
Handle user input and generate a bot response.
41+
If the user sends '/stop', the conversation is terminated.
42+
43+
Parameters:
44+
- user_input: The input provided by the user.
45+
46+
Returns:
47+
- bot_response: The response generated by the bot.
48+
49+
Raises:
50+
- ValueError: If user input is not a string or if no chat_id is available.
51+
52+
Doctest:
53+
>>> class MockDatabase:
54+
... def __init__(self):
55+
... self.data = []
56+
... def insert_chat_data(self, *args, **kwargs):
57+
... pass
58+
... def insert_chat_history(self, *args, **kwargs):
59+
... pass
60+
... def get_latest_chat_id(self):
61+
... return 1
62+
...
63+
>>> class MockLLM:
64+
... def generate_response(self, conversation_history):
65+
... if conversation_history[-1]["content"] == "/stop":
66+
... return "conversation-terminated"
67+
... return "Mock response"
68+
>>> db_mock = MockDatabase()
69+
>>> llm_mock = MockLLM()
70+
>>> bot = Chatbot(db_mock, llm_mock)
71+
>>> bot.start_chat()
72+
>>> bot.handle_user_message("/stop")
73+
'conversation-terminated'
74+
>>> bot.handle_user_message("Hello!")
75+
'Mock response'
76+
"""
77+
if not isinstance(user_input, str):
78+
raise ValueError("User input must be a string.")
79+
80+
if self.chat_id_pk is None:
81+
raise ValueError("Chat has not been started. Call start_chat() first.")
82+
83+
self.conversation_history.append({"role": "user", "content": user_input})
84+
85+
if user_input == "/stop":
86+
self.end_chat()
87+
return "conversation-terminated"
88+
else:
89+
bot_response = self.llm_service.generate_response(self.conversation_history)
90+
print(f"Bot : ",bot_response)
91+
self.conversation_history.append(
92+
{"role": "assistant", "content": bot_response}
93+
)
94+
self._store_message_in_db(user_input, bot_response)
95+
96+
return bot_response
97+
98+
def _store_message_in_db(self, user_input: str, bot_response: str) -> None:
99+
"""
100+
Store user input and bot response in the database.
101+
102+
Parameters:
103+
- user_input: The message from the user.
104+
- bot_response: The response generated by the bot.
105+
106+
Raises:
107+
- ValueError: If insertion into the database fails.
108+
"""
109+
try:
110+
self.db.insert_chat_data(self.chat_id_pk, user_input, bot_response)
111+
except Exception as e:
112+
raise ValueError(f"Failed to insert chat data: {e}")
113+
114+
def end_chat(self) -> None:
115+
"""
116+
End the chat session and update the chat history in the database.
117+
"""
118+
current_time = datetime.datetime.now()
119+
is_stream = 2 # End of conversation
120+
try:
121+
user_input = "/stop"
122+
bot_response = "conversation-terminated"
123+
print(f"Bot : ",bot_response)
124+
self.db.insert_chat_data(self.chat_id_pk, user_input, bot_response)
125+
self.db.insert_chat_history(current_time, is_stream)
126+
except Exception as e:
127+
raise ValueError(f"Failed to update chat history: {e}")
128+
129+
def continue_chat(self) -> None:
130+
"""
131+
Retain only the last few entries if the conversation exceeds 1000 messages.
132+
"""
133+
if len(self.conversation_history) > 1000:
134+
self.conversation_history = self.conversation_history[-3:]

neural_network/chatbot/db.py

+199
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
import os
2+
from dotenv import load_dotenv
3+
import mysql.connector
4+
from mysql.connector import MySQLConnection
5+
6+
load_dotenv()
7+
8+
9+
class Database:
10+
"""
11+
A class to manage the connection to the MySQL database using configuration from environment variables.
12+
13+
Attributes:
14+
-----------
15+
config : dict
16+
The database connection parameters like user, password, host, and database name.
17+
"""
18+
19+
def __init__(self) -> None:
20+
self.config = {
21+
"user": os.environ.get("DB_USER"),
22+
"password": os.environ.get("DB_PASSWORD"),
23+
"host": os.environ.get("DB_HOST"),
24+
"database": os.environ.get("DB_NAME"),
25+
}
26+
27+
def connect(self) -> MySQLConnection:
28+
"""
29+
Establish a connection to the MySQL database.
30+
31+
Returns:
32+
--------
33+
MySQLConnection
34+
A connection object for interacting with the MySQL database.
35+
36+
Raises:
37+
-------
38+
mysql.connector.Error
39+
If the connection to the database fails.
40+
"""
41+
return mysql.connector.connect(**self.config)
42+
43+
44+
class ChatDatabase:
45+
"""
46+
A class to manage chat-related database operations, such as creating tables,
47+
inserting chat history, and retrieving chat data.
48+
49+
Attributes:
50+
-----------
51+
db : Database
52+
An instance of the `Database` class for establishing connections to the MySQL database.
53+
"""
54+
55+
def __init__(self, db: Database) -> None:
56+
self.db = db
57+
58+
def create_tables(self) -> None:
59+
"""
60+
Create the necessary tables for chat history and chat data in the database.
61+
If the tables already exist, they will not be created again.
62+
63+
Raises:
64+
-------
65+
mysql.connector.Error
66+
If there is any error executing the SQL statements.
67+
"""
68+
conn = self.db.connect()
69+
cursor = conn.cursor()
70+
71+
cursor.execute(
72+
"""
73+
CREATE TABLE IF NOT EXISTS ChatDB.Chat_history (
74+
chat_id INT AUTO_INCREMENT PRIMARY KEY,
75+
start_time DATETIME,
76+
is_stream INT
77+
)
78+
"""
79+
)
80+
81+
cursor.execute(
82+
"""
83+
CREATE TABLE IF NOT EXISTS ChatDB.Chat_data (
84+
id INT AUTO_INCREMENT PRIMARY KEY,
85+
chat_id INT,
86+
user TEXT,
87+
assistant TEXT,
88+
FOREIGN KEY (chat_id) REFERENCES ChatDB.Chat_history(chat_id)
89+
)
90+
"""
91+
)
92+
93+
cursor.execute("DROP TRIGGER IF EXISTS update_is_stream")
94+
95+
cursor.execute(
96+
"""
97+
CREATE TRIGGER update_is_stream
98+
AFTER UPDATE ON ChatDB.Chat_history
99+
FOR EACH ROW
100+
BEGIN
101+
UPDATE ChatDB.Chat_data
102+
SET is_stream = NEW.is_stream
103+
WHERE chat_id = NEW.chat_id;
104+
END;
105+
"""
106+
)
107+
108+
conn.commit()
109+
cursor.close()
110+
conn.close()
111+
112+
def insert_chat_history(self, start_time: str, is_stream: int) -> None:
113+
"""
114+
Insert a new chat history record into the database.
115+
116+
Parameters:
117+
-----------
118+
start_time : str
119+
The starting time of the chat session.
120+
is_stream : int
121+
An integer indicating whether the chat is in progress (1) or ended (2).
122+
123+
Raises:
124+
-------
125+
mysql.connector.Error
126+
If there is any error executing the SQL statements.
127+
"""
128+
conn = self.db.connect()
129+
cursor = conn.cursor()
130+
cursor.execute(
131+
"""
132+
INSERT INTO ChatDB.Chat_history (start_time, is_stream)
133+
VALUES (%s, %s)
134+
""",
135+
(start_time, is_stream),
136+
)
137+
conn.commit()
138+
cursor.close()
139+
conn.close()
140+
141+
def get_latest_chat_id(self) -> int:
142+
"""
143+
Retrieve the chat ID of the most recent chat session from the database.
144+
145+
Returns:
146+
--------
147+
int
148+
The ID of the latest chat session.
149+
150+
Raises:
151+
-------
152+
mysql.connector.Error
153+
If there is any error executing the SQL statements.
154+
"""
155+
conn = self.db.connect()
156+
cursor = conn.cursor()
157+
cursor.execute(
158+
"""
159+
SELECT chat_id FROM ChatDB.Chat_history WHERE
160+
chat_id=(SELECT MAX(chat_id) FROM ChatDB.Chat_history)
161+
"""
162+
)
163+
chat_id_pk = cursor.fetchone()[0]
164+
cursor.close()
165+
conn.close()
166+
return chat_id_pk
167+
168+
def insert_chat_data(
169+
self, chat_id: int, user_message: str, assistant_message: str
170+
) -> None:
171+
"""
172+
Insert a new chat data record into the database.
173+
174+
Parameters:
175+
-----------
176+
chat_id : int
177+
The ID of the chat session to which this data belongs.
178+
user_message : str
179+
The message provided by the user in the chat session.
180+
assistant_message : str
181+
The response from the assistant in the chat session.
182+
183+
Raises:
184+
-------
185+
mysql.connector.Error
186+
If there is any error executing the SQL statements.
187+
"""
188+
conn = self.db.connect()
189+
cursor = conn.cursor()
190+
cursor.execute(
191+
"""
192+
INSERT INTO ChatDB.Chat_data (chat_id, user, assistant)
193+
VALUES (%s, %s, %s)
194+
""",
195+
(chat_id, user_message, assistant_message),
196+
)
197+
conn.commit()
198+
cursor.close()
199+
conn.close()

0 commit comments

Comments
 (0)