|
| 1 | +import datetime |
| 2 | +from typing import List, Dict, Any |
| 3 | + |
| 4 | + |
| 5 | +class Chatbot: |
| 6 | + """ |
| 7 | + A Chatbot class to manage chat conversations using an LLM service and a database to store chat data. |
| 8 | +
|
| 9 | + Methods: |
| 10 | + - start_chat: Starts a new conversation, logs the start time. |
| 11 | + - handle_user_message: Processes user input and stores user message & bot response in DB. |
| 12 | + - end_chat: Ends the conversation and logs the end time. |
| 13 | + - continue_chat: Retains only the last few messages if the conversation exceeds 1000 messages. |
| 14 | + """ |
| 15 | + |
| 16 | + def __init__(self, db: Any, llm_service: Any) -> None: |
| 17 | + """ |
| 18 | + Initialize the Chatbot with a database and an LLM service. |
| 19 | +
|
| 20 | + Parameters: |
| 21 | + - db: The database instance used for storing chat data. |
| 22 | + - llm_service: The language model service for generating responses. |
| 23 | + """ |
| 24 | + self.db = db |
| 25 | + self.llm_service = llm_service |
| 26 | + self.conversation_history: List[Dict[str, str]] = [] |
| 27 | + self.chat_id_pk: int = None |
| 28 | + |
| 29 | + def start_chat(self) -> None: |
| 30 | + """ |
| 31 | + Start a new chat session and insert chat history to the database. |
| 32 | + """ |
| 33 | + start_time = datetime.datetime.now() |
| 34 | + is_stream = 1 # Start new conversation |
| 35 | + self.db.insert_chat_history(start_time, is_stream) |
| 36 | + self.chat_id_pk = self.db.get_latest_chat_id() |
| 37 | + |
| 38 | + def handle_user_message(self, user_input: str) -> str: |
| 39 | + """ |
| 40 | + Handle user input and generate a bot response. |
| 41 | + If the user sends '/stop', the conversation is terminated. |
| 42 | +
|
| 43 | + Parameters: |
| 44 | + - user_input: The input provided by the user. |
| 45 | +
|
| 46 | + Returns: |
| 47 | + - bot_response: The response generated by the bot. |
| 48 | +
|
| 49 | + Raises: |
| 50 | + - ValueError: If user input is not a string or if no chat_id is available. |
| 51 | +
|
| 52 | + Doctest: |
| 53 | + >>> class MockDatabase: |
| 54 | + ... def __init__(self): |
| 55 | + ... self.data = [] |
| 56 | + ... def insert_chat_data(self, *args, **kwargs): |
| 57 | + ... pass |
| 58 | + ... def insert_chat_history(self, *args, **kwargs): |
| 59 | + ... pass |
| 60 | + ... def get_latest_chat_id(self): |
| 61 | + ... return 1 |
| 62 | + ... |
| 63 | + >>> class MockLLM: |
| 64 | + ... def generate_response(self, conversation_history): |
| 65 | + ... if conversation_history[-1]["content"] == "/stop": |
| 66 | + ... return "conversation-terminated" |
| 67 | + ... return "Mock response" |
| 68 | + >>> db_mock = MockDatabase() |
| 69 | + >>> llm_mock = MockLLM() |
| 70 | + >>> bot = Chatbot(db_mock, llm_mock) |
| 71 | + >>> bot.start_chat() |
| 72 | + >>> bot.handle_user_message("/stop") |
| 73 | + 'conversation-terminated' |
| 74 | + >>> bot.handle_user_message("Hello!") |
| 75 | + 'Mock response' |
| 76 | + """ |
| 77 | + if not isinstance(user_input, str): |
| 78 | + raise ValueError("User input must be a string.") |
| 79 | + |
| 80 | + if self.chat_id_pk is None: |
| 81 | + raise ValueError("Chat has not been started. Call start_chat() first.") |
| 82 | + |
| 83 | + self.conversation_history.append({"role": "user", "content": user_input}) |
| 84 | + |
| 85 | + if user_input == "/stop": |
| 86 | + self.end_chat() |
| 87 | + return "conversation-terminated" |
| 88 | + else: |
| 89 | + bot_response = self.llm_service.generate_response(self.conversation_history) |
| 90 | + print(f"Bot : ",bot_response) |
| 91 | + self.conversation_history.append( |
| 92 | + {"role": "assistant", "content": bot_response} |
| 93 | + ) |
| 94 | + self._store_message_in_db(user_input, bot_response) |
| 95 | + |
| 96 | + return bot_response |
| 97 | + |
| 98 | + def _store_message_in_db(self, user_input: str, bot_response: str) -> None: |
| 99 | + """ |
| 100 | + Store user input and bot response in the database. |
| 101 | +
|
| 102 | + Parameters: |
| 103 | + - user_input: The message from the user. |
| 104 | + - bot_response: The response generated by the bot. |
| 105 | +
|
| 106 | + Raises: |
| 107 | + - ValueError: If insertion into the database fails. |
| 108 | + """ |
| 109 | + try: |
| 110 | + self.db.insert_chat_data(self.chat_id_pk, user_input, bot_response) |
| 111 | + except Exception as e: |
| 112 | + raise ValueError(f"Failed to insert chat data: {e}") |
| 113 | + |
| 114 | + def end_chat(self) -> None: |
| 115 | + """ |
| 116 | + End the chat session and update the chat history in the database. |
| 117 | + """ |
| 118 | + current_time = datetime.datetime.now() |
| 119 | + is_stream = 2 # End of conversation |
| 120 | + try: |
| 121 | + user_input = "/stop" |
| 122 | + bot_response = "conversation-terminated" |
| 123 | + print(f"Bot : ",bot_response) |
| 124 | + self.db.insert_chat_data(self.chat_id_pk, user_input, bot_response) |
| 125 | + self.db.insert_chat_history(current_time, is_stream) |
| 126 | + except Exception as e: |
| 127 | + raise ValueError(f"Failed to update chat history: {e}") |
| 128 | + |
| 129 | + def continue_chat(self) -> None: |
| 130 | + """ |
| 131 | + Retain only the last few entries if the conversation exceeds 1000 messages. |
| 132 | + """ |
| 133 | + if len(self.conversation_history) > 1000: |
| 134 | + self.conversation_history = self.conversation_history[-3:] |
0 commit comments