diff --git a/neural_network/chatbot/README.md b/neural_network/chatbot/README.md
new file mode 100644
index 000000000000..acddd4c8f671
--- /dev/null
+++ b/neural_network/chatbot/README.md
@@ -0,0 +1,83 @@
+
+
+# Chatbot with LLM Integration and Database Storage
+
+This chatbot application integrates LLM (Large Language Model) API services, **Together** and **Groq** (you can use either one of them), to generate AI-driven responses. It stores conversation history in a MySQL database and manages chat sessions with triggers that update the status of conversations automatically.
+
+## Features
+- Supports LLM response generation using **Together** and **Groq** APIs.
+- Stores chat sessions and message exchanges in MySQL database tables.
+- Automatically updates chat session status using database triggers.
+- Manages conversation history with user-assistant interaction.
+
+## Requirements
+
+Before running the application, ensure the following dependencies are installed:
+
+- Python 3.13+
+- MySQL Server
+- The following Python libraries:
+  ```bash
+  pip3 install -r requirements.txt
+  ```
+
+## Setup Instructions
+
+### Step 1: Set Up Environment Variables
+
+Create a `.env` file in the root directory of your project and add the following entries for your database credentials and API keys:
+
+```
+# Together API key
+TOGETHER_API_KEY="YOUR_API_KEY"
+
+# Groq API key
+GROQ_API_KEY = "YOUR_API_KEY"
+
+# MySQL connection settings (if you're running locally)
+DB_USER = ""
+DB_PASSWORD = ""
+DB_HOST = "127.0.0.1"
+DB_NAME = "ChatDB"
+PORT = "3306"
+
+# API service to use ("Groq" or "Together")
+API_SERVICE = "Groq"
+```
+
+### Step 2: Create MySQL Tables and Trigger
+
+The `create_tables()` function in the script automatically creates the necessary tables and a trigger for updating chat session statuses. To ensure the database is set up correctly, the function is called at the beginning of the script.
+
+Ensure that your MySQL server is running and accessible before running the code.
+
+### Step 3: Run the Application
+
+To start the chatbot:
+
+1. Ensure your MySQL server is running.
+2. Open a terminal and run the Python script:
+
+```bash
+python3 chat_db.py
+```
+
+The chatbot will initialize, and you can interact with it by typing your inputs. Type `/stop` to end the conversation.
+
+### Step 4: Test and Validate Code
+
+This project uses doctests to ensure that the functions work as expected. To run the doctests:
+
+```bash
+python3 -m doctest -v chat_db.py
+```
+
+Make sure to add doctests to all your functions where applicable, to validate both valid and erroneous inputs.
+
+### Key Functions
+
+- **create_tables()**: Sets up the MySQL tables (`Chat_history` and `Chat_data`) and the `update_is_stream` trigger.
+- **insert_chat_history()**: Inserts a new chat session into the `Chat_history` table.
+- **insert_chat_data()**: Inserts user-assistant message pairs into the `Chat_data` table.
+- **generate_llm_response()**: Generates a response from the selected LLM API service, either **Together** or **Groq**.
+ diff --git a/neural_network/chatbot/chat_db.py b/neural_network/chatbot/chat_db.py new file mode 100644 index 000000000000..5f27512942a0 --- /dev/null +++ b/neural_network/chatbot/chat_db.py @@ -0,0 +1,246 @@ +""" +credits : https://medium.com/google-developer-experts/beyond-live-sessions-building-persistent-memory-chatbots-with-langchain-gemini-pro-and-firebase-19d6f84e21d3 + +""" + +import os +import datetime +from dotenv import load_dotenv +import mysql.connector +from together import Together +from groq import Groq + +load_dotenv() + +# Database configuration +db_config = { + "user": os.environ.get("DB_USER"), + "password": os.environ.get("DB_PASSWORD"), + "host": os.environ.get("DB_HOST"), + "database": os.environ.get("DB_NAME"), +} + +api_service = os.environ.get("API_SERVICE") + + +def create_tables() -> None: + """ + Create the ChatDB.Chat_history and ChatDB.Chat_data tables + if they do not exist.Also, create a trigger to update is_stream + in Chat_data when Chat_history.is_stream is updated. 
+ """ + try: + conn = mysql.connector.connect(**db_config) + cursor = conn.cursor() + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS ChatDB.Chat_history ( + chat_id INT AUTO_INCREMENT PRIMARY KEY, + start_time DATETIME, + is_stream INT + ) + """ + ) + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS ChatDB.Chat_data ( + id INT AUTO_INCREMENT PRIMARY KEY, + chat_id INT, + user TEXT, + assistant TEXT, + FOREIGN KEY (chat_id) REFERENCES ChatDB.Chat_history(chat_id) + ) + """ + ) + + cursor.execute("DROP TRIGGER IF EXISTS update_is_stream;") + + cursor.execute( + """ + CREATE TRIGGER update_is_stream + AFTER UPDATE ON ChatDB.Chat_history + FOR EACH ROW + BEGIN + UPDATE ChatDB.Chat_data + SET is_stream = NEW.is_stream + WHERE chat_id = NEW.chat_id; + END; + """ + ) + + conn.commit() + except mysql.connector.Error as err: + print(f"Error: {err}") + finally: + cursor.close() + conn.close() + print("Tables and trigger created successfully") + + +def insert_chat_history(start_time: datetime.datetime, is_stream: int) -> None: + """ + Insert a new row into the ChatDB.Chat_history table. + :param start_time: Timestamp of when the chat started + :param is_stream: Indicator of whether the conversation is + ongoing, starting, or ending + """ + try: + conn = mysql.connector.connect(**db_config) + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO ChatDB.Chat_history (start_time, is_stream) + VALUES (%s, %s) + """, + (start_time, is_stream), + ) + conn.commit() + except mysql.connector.Error as err: + print(f"Error: {err}") + finally: + cursor.close() + conn.close() + + +def get_latest_chat_id() -> int: + """ + Retrieve the latest chat_id from the ChatDB.Chat_history table. + :return: The latest chat_id or None if no chat_id exists. 
+ """ + try: + conn = mysql.connector.connect(**db_config) + cursor = conn.cursor() + cursor.execute( + """ + SELECT chat_id FROM ChatDB.Chat_history + ORDER BY chat_id DESC LIMIT 1 + """ + ) + chat_id = cursor.fetchone()[0] + return chat_id if chat_id else None + except mysql.connector.Error as err: + print(f"Error: {err}") + return 0 + finally: + cursor.close() + conn.close() + + +def insert_chat_data(chat_id: int, user_message: str, assistant_message: str) -> None: + """ + Insert a new row into the ChatDB.Chat_data table. + :param chat_id: The ID of the chat session + :param user_message: The user's message + :param assistant_message: The assistant's message + """ + try: + conn = mysql.connector.connect(**db_config) + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO ChatDB.Chat_data (chat_id, user, assistant) + VALUES (%s, %s, %s) + """, + (chat_id, user_message, assistant_message), + ) + conn.commit() + except mysql.connector.Error as err: + print(f"Error: {err}") + finally: + cursor.close() + conn.close() + + +def generate_llm_response( + conversation_history: list[dict], api_service: str = "Groq" +) -> str: + """ + Generate a response from the LLM based on the conversation history. 
+ :param conversation_history: List of dictionaries representing + the conversation so far + :param api_service: Choose between "Together" or "Groq" as the + API service + :return: Assistant's response as a string + """ + bot_response = "" + if api_service == "Together": + client = Together(api_key=os.environ.get("TOGETHER_API_KEY")) + response = client.chat.completions.create( + model="meta-llama/Llama-3.2-3B-Instruct-Turbo", + messages=conversation_history, + max_tokens=512, + temperature=0.3, + top_p=0.7, + top_k=50, + repetition_penalty=1, + stop=["<|eot_id|>", "<|eom_id|>"], + stream=False, + ) + bot_response = response.choices[0].message.content + else: + client = Groq(api_key=os.environ.get("GROQ_API_KEY")) + response = client.chat.completions.create( + model="llama3-8b-8192", + messages=conversation_history, + max_tokens=1024, + temperature=0.3, + top_p=0.7, + stop=["<|eot_id|>", "<|eom_id|>"], + stream=False, + ) + bot_response = response.choices[0].message.content + + return bot_response + + +def chat_session() -> None: + """ + Start a chatbot session, allowing the user to interact with the LLM. + Saves conversation history in the database and ends the session on "/stop" command. + """ + print("Welcome to the chatbot! 
Type '/stop' to end the conversation.") + + conversation_history = [] + start_time = datetime.datetime.now(datetime.timezone.utc) + + chat_id_pk = None + api_service = "Groq" # or "Together" + + while True: + user_input = input("\nYou: ").strip() + conversation_history.append({"role": "user", "content": user_input}) + + if chat_id_pk is None: + if user_input.lower() == "/stop": + break + bot_response = generate_llm_response(conversation_history, api_service) + conversation_history.append({"role": "assistant", "content": bot_response}) + + is_stream = 1 # New conversation + insert_chat_history(start_time, is_stream) + chat_id_pk = get_latest_chat_id() + insert_chat_data(chat_id_pk, user_input, bot_response) + else: + if user_input.lower() == "/stop": + is_stream = 2 # End of conversation + current_time = datetime.datetime.now(datetime.timezone.utc) + insert_chat_history(current_time, is_stream) + break + + bot_response = generate_llm_response(conversation_history, api_service) + conversation_history.append({"role": "assistant", "content": bot_response}) + + is_stream = 0 # Continuation of conversation + current_time = datetime.datetime.now(datetime.timezone.utc) + insert_chat_history(current_time, is_stream) + insert_chat_data(chat_id_pk, user_input, bot_response) + + if len(conversation_history) > 1000: + conversation_history = conversation_history[-3:] + + +# starting a chat session +create_tables() +chat_session() diff --git a/neural_network/chatbot/requirements.txt b/neural_network/chatbot/requirements.txt new file mode 100644 index 000000000000..0f1204243a5d --- /dev/null +++ b/neural_network/chatbot/requirements.txt @@ -0,0 +1,57 @@ +aiohappyeyeballs==2.4.2 +aiohttp==3.10.8 +aiosignal==1.3.1 +annotated-types==0.7.0 +anyio==4.6.0 +asgiref==3.8.1 +attrs==24.2.0 +black==24.10.0 +certifi==2024.8.30 +cfgv==3.4.0 +charset-normalizer==3.3.2 +click==8.1.7 +distlib==0.3.9 +distro==1.9.0 +Django==5.1.1 +djangorestframework==3.15.2 +eval_type_backport==0.2.0 
+filelock==3.16.1 +frozenlist==1.4.1 +groq==0.11.0 +h11==0.14.0 +httpcore==1.0.5 +httpx==0.27.2 +identify==2.6.1 +idna==3.10 +markdown-it-py==3.0.0 +mdurl==0.1.2 +multidict==6.1.0 +mypy-extensions==1.0.0 +mysql-connector-python==9.0.0 +nodeenv==1.9.1 +numpy==2.1.1 +packaging==24.1 +pathspec==0.12.1 +pillow==10.4.0 +platformdirs==4.3.6 +pre_commit==4.0.1 +pyarrow==17.0.0 +pydantic==2.9.2 +pydantic_core==2.23.4 +Pygments==2.18.0 +python-dotenv==1.0.1 +PyYAML==6.0.2 +requests==2.32.3 +rich==13.8.1 +ruff==0.7.0 +shellingham==1.5.4 +sniffio==1.3.1 +sqlparse==0.5.1 +tabulate==0.9.0 +together==1.3.0 +tqdm==4.66.5 +typer==0.12.5 +typing_extensions==4.12.2 +urllib3==2.2.3 +virtualenv==20.27.0 +yarl==1.13.1