298 lines
11 KiB
Python
298 lines
11 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
命令行对话程序
|
||
提供交互式的命令行聊天界面
|
||
"""
|
||
|
||
import asyncio
|
||
import sys
|
||
import uuid
|
||
from typing import Optional, List
|
||
from datetime import datetime
|
||
|
||
# 添加项目根目录到 Python 路径
|
||
sys.path.append('.')
|
||
|
||
from src.chains.simple_chat_chain import get_conversation_manager, get_simple_chat_chain
|
||
from src.workflows.chat_workflow import get_chat_workflow, create_chat_session
|
||
from src.models.chat_models import ChatMessage, MessageRole
|
||
from src.utils.config import get_config
|
||
from src.utils.logger import get_workflow_logger
|
||
|
||
config = get_config()
|
||
logger = get_workflow_logger()
|
||
|
||
|
||
class CLIChatBot:
|
||
"""命令行聊天机器人"""
|
||
|
||
def __init__(self, use_workflow: bool = True):
|
||
"""
|
||
初始化聊天机器人
|
||
|
||
Args:
|
||
use_workflow: 是否使用 LangGraph 工作流,False 则使用简单链
|
||
"""
|
||
self.use_workflow = use_workflow
|
||
self.session_id = str(uuid.uuid4())
|
||
self.conversation_manager = get_conversation_manager()
|
||
|
||
if use_workflow:
|
||
self.workflow = get_chat_workflow()
|
||
logger.info("🤖 使用 LangGraph 工作流模式")
|
||
else:
|
||
self.simple_chain = get_simple_chat_chain()
|
||
logger.info("🔗 使用简单链模式")
|
||
|
||
print(f"🤖 聊天机器人已启动 (会话ID: {self.session_id[:8]}...)")
|
||
print(f"📋 模式: {'LangGraph 工作流' if use_workflow else '简单对话链'}")
|
||
print("💡 输入 'help' 查看帮助,输入 'quit' 或 'exit' 退出")
|
||
print("-" * 50)
|
||
|
||
async def start_chat(self):
|
||
"""开始聊天循环"""
|
||
try:
|
||
while True:
|
||
# 获取用户输入
|
||
user_input = input("\n👤 您: ").strip()
|
||
|
||
if not user_input:
|
||
continue
|
||
|
||
# 处理特殊命令
|
||
if user_input.lower() in ['quit', 'exit', '退出']:
|
||
print("👋 再见!")
|
||
break
|
||
elif user_input.lower() == 'help':
|
||
self.show_help()
|
||
continue
|
||
elif user_input.lower() == 'clear':
|
||
self.clear_history()
|
||
continue
|
||
elif user_input.lower() == 'history':
|
||
self.show_history()
|
||
continue
|
||
elif user_input.lower() == 'switch':
|
||
await self.switch_mode()
|
||
continue
|
||
elif user_input.lower() == 'status':
|
||
await self.show_status()
|
||
continue
|
||
|
||
# 处理聊天消息
|
||
await self.process_message(user_input)
|
||
|
||
except KeyboardInterrupt:
|
||
print("\n\n👋 程序被中断,再见!")
|
||
except Exception as e:
|
||
print(f"\n❌ 发生错误: {e}")
|
||
logger.error(f"CLI 聊天错误: {e}")
|
||
|
||
async def process_message(self, user_input: str):
|
||
"""处理用户消息"""
|
||
try:
|
||
print("\n🤖 AI: ", end="", flush=True)
|
||
|
||
if self.use_workflow:
|
||
# 使用工作流处理
|
||
await self.process_with_workflow(user_input)
|
||
else:
|
||
# 使用简单链处理
|
||
await self.process_with_simple_chain(user_input)
|
||
|
||
except Exception as e:
|
||
print(f"\n❌ 处理消息时出错: {e}")
|
||
logger.error(f"消息处理错误: {e}")
|
||
|
||
async def process_with_workflow(self, user_input: str):
|
||
"""使用工作流处理消息"""
|
||
try:
|
||
# 获取对话历史
|
||
history = self.conversation_manager.get_conversation(self.session_id)
|
||
|
||
# 流式处理
|
||
full_response = ""
|
||
current_type = None
|
||
|
||
async for result in self.workflow.run_streaming(user_input, self.session_id, history):
|
||
result_type = result.get("type", "")
|
||
content = result.get("content", "")
|
||
|
||
# 根据不同类型显示不同的输出
|
||
if result_type == "workflow_start":
|
||
print(f"🔄 {content}", flush=True)
|
||
elif result_type == "intent_analysis":
|
||
print(f"\n🎯 意图分析: {content}", flush=True)
|
||
elif result_type == "route_decision":
|
||
print(f"\n🚦 {content}", flush=True)
|
||
elif result_type == "order_info":
|
||
print(f"\n📦 订单信息已获取", flush=True)
|
||
elif result_type == "thinking_start":
|
||
print(f"\n🤔 {content}", flush=True)
|
||
elif result_type == "thinking":
|
||
print(content, end="", flush=True)
|
||
elif result_type == "diagnosis_result":
|
||
print(f"\n🔍 诊断完成", flush=True)
|
||
elif result_type == "chat_start":
|
||
print(f"💭 {content}", flush=True)
|
||
elif result_type == "chat_chunk":
|
||
print(content, end="", flush=True)
|
||
full_response += content
|
||
elif result_type == "final_response":
|
||
if not full_response: # 如果没有通过 chat_chunk 获取到回复
|
||
print(content, flush=True)
|
||
full_response = content
|
||
elif result_type == "chat_complete":
|
||
if not full_response:
|
||
full_response = content
|
||
# 在对话完成后添加换行符,确保后续日志信息不会紧跟在AI回复后面
|
||
print("", flush=True)
|
||
elif result_type == "error":
|
||
print(f"\n❌ {content}", flush=True)
|
||
return
|
||
|
||
# 更新对话历史
|
||
if full_response:
|
||
user_message = ChatMessage(
|
||
role=MessageRole.USER,
|
||
content=user_input,
|
||
timestamp=datetime.now()
|
||
)
|
||
ai_message = ChatMessage(
|
||
role=MessageRole.ASSISTANT,
|
||
content=full_response,
|
||
timestamp=datetime.now()
|
||
)
|
||
|
||
self.conversation_manager.add_message(self.session_id, user_message)
|
||
self.conversation_manager.add_message(self.session_id, ai_message)
|
||
|
||
except Exception as e:
|
||
print(f"\n❌ 工作流处理错误: {e}")
|
||
logger.error(f"工作流处理错误: {e}")
|
||
|
||
async def process_with_simple_chain(self, user_input: str):
|
||
"""使用简单链处理消息"""
|
||
try:
|
||
# 流式输出
|
||
async for chunk in self.conversation_manager.chat_streaming(self.session_id, user_input):
|
||
print(chunk, end="", flush=True)
|
||
|
||
except Exception as e:
|
||
print(f"\n❌ 简单链处理错误: {e}")
|
||
logger.error(f"简单链处理错误: {e}")
|
||
|
||
def show_help(self):
|
||
"""显示帮助信息"""
|
||
help_text = """
|
||
🆘 命令帮助:
|
||
help - 显示此帮助信息
|
||
clear - 清除对话历史
|
||
history - 显示对话历史
|
||
switch - 切换对话模式 (工作流 ↔ 简单链)
|
||
status - 显示当前状态
|
||
quit/exit - 退出程序
|
||
|
||
💡 使用提示:
|
||
- 直接输入消息开始对话
|
||
- 支持订单查询 (如: "查询订单 ORD12345")
|
||
- 支持自然对话和闲聊
|
||
- 工作流模式支持智能意图识别
|
||
"""
|
||
print(help_text)
|
||
|
||
def clear_history(self):
|
||
"""清除对话历史"""
|
||
self.conversation_manager.clear_conversation(self.session_id)
|
||
print("🧹 对话历史已清除")
|
||
|
||
def show_history(self):
|
||
"""显示对话历史"""
|
||
history = self.conversation_manager.get_conversation(self.session_id)
|
||
|
||
if not history:
|
||
print("📝 暂无对话历史")
|
||
return
|
||
|
||
print(f"\n📚 对话历史 (共 {len(history)} 条消息):")
|
||
print("-" * 50)
|
||
|
||
for i, msg in enumerate(history[-10:], 1): # 只显示最近10条
|
||
role_icon = "👤" if msg.role == MessageRole.USER else "🤖"
|
||
role_name = "您" if msg.role == MessageRole.USER else "AI"
|
||
timestamp = msg.timestamp.strftime("%H:%M:%S")
|
||
|
||
# 限制消息长度
|
||
content = msg.content[:100] + "..." if len(msg.content) > 100 else msg.content
|
||
|
||
print(f"{i}. [{timestamp}] {role_icon} {role_name}: {content}")
|
||
|
||
if len(history) > 10:
|
||
print(f"... (还有 {len(history) - 10} 条更早的消息)")
|
||
|
||
async def switch_mode(self):
|
||
"""切换对话模式"""
|
||
self.use_workflow = not self.use_workflow
|
||
|
||
if self.use_workflow:
|
||
self.workflow = get_chat_workflow()
|
||
print("🔄 已切换到 LangGraph 工作流模式")
|
||
else:
|
||
self.simple_chain = get_simple_chat_chain()
|
||
print("🔄 已切换到简单对话链模式")
|
||
|
||
async def show_status(self):
|
||
"""显示当前状态"""
|
||
history = self.conversation_manager.get_conversation(self.session_id)
|
||
|
||
status_info = f"""
|
||
📊 当前状态:
|
||
会话ID: {self.session_id}
|
||
对话模式: {'LangGraph 工作流' if self.use_workflow else '简单对话链'}
|
||
消息数量: {len(history)}
|
||
启动时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
||
"""
|
||
print(status_info)
|
||
|
||
|
||
async def main():
|
||
"""主函数"""
|
||
print("🚀 LangChain 聊天机器人 CLI")
|
||
print("=" * 50)
|
||
|
||
# 检查命令行参数
|
||
use_workflow = True
|
||
if len(sys.argv) > 1:
|
||
if sys.argv[1].lower() in ['simple', 'chain']:
|
||
use_workflow = False
|
||
elif sys.argv[1].lower() in ['workflow', 'graph']:
|
||
use_workflow = True
|
||
elif sys.argv[1].lower() in ['help', '-h', '--help']:
|
||
print("""
|
||
使用方法:
|
||
python cli_chat.py [模式]
|
||
|
||
模式选项:
|
||
workflow/graph - 使用 LangGraph 工作流 (默认)
|
||
simple/chain - 使用简单对话链
|
||
help - 显示此帮助
|
||
|
||
示例:
|
||
python cli_chat.py # 使用工作流模式
|
||
python cli_chat.py simple # 使用简单链模式
|
||
python cli_chat.py workflow # 使用工作流模式
|
||
""")
|
||
return
|
||
|
||
# 创建并启动聊天机器人
|
||
try:
|
||
chatbot = CLIChatBot(use_workflow=use_workflow)
|
||
await chatbot.start_chat()
|
||
except Exception as e:
|
||
print(f"❌ 启动失败: {e}")
|
||
logger.error(f"CLI 启动失败: {e}")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# 运行主程序
|
||
asyncio.run(main()) |