111 lines
3.1 KiB
Python
111 lines
3.1 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
简单对话链测试
|
|
"""
|
|
|
|
import asyncio
|
|
import sys
|
|
import os
|
|
|
|
# 添加项目根目录到 Python 路径
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from src.chains.simple_chat_chain import get_simple_chat_chain, get_conversation_manager
|
|
from src.models.chat_models import ChatMessage, MessageRole
|
|
from datetime import datetime
|
|
|
|
|
|
async def test_simple_chat():
|
|
"""测试简单对话"""
|
|
print("🧪 测试简单对话链")
|
|
print("=" * 50)
|
|
|
|
try:
|
|
# 获取对话链
|
|
chain = get_simple_chat_chain()
|
|
|
|
# 测试用例
|
|
test_cases = [
|
|
"你好,我是新用户",
|
|
"你能做什么?",
|
|
"今天天气怎么样?",
|
|
"帮我查询订单 ORD12345",
|
|
"谢谢你的帮助"
|
|
]
|
|
|
|
for i, user_input in enumerate(test_cases, 1):
|
|
print(f"\n📝 测试 {i}: {user_input}")
|
|
print("🤖 AI: ", end="", flush=True)
|
|
|
|
# 测试流式输出
|
|
response = ""
|
|
async for chunk in chain.chat_streaming(user_input, session_id="test_session"):
|
|
print(chunk, end="", flush=True)
|
|
response += chunk
|
|
|
|
print(f"\n✅ 回复长度: {len(response)} 字符")
|
|
|
|
print("\n🎉 简单对话链测试完成!")
|
|
|
|
except Exception as e:
|
|
print(f"❌ 测试失败: {e}")
|
|
raise
|
|
|
|
|
|
async def test_conversation_manager():
|
|
"""测试对话管理器"""
|
|
print("\n🧪 测试对话管理器")
|
|
print("=" * 50)
|
|
|
|
try:
|
|
# 获取对话管理器
|
|
manager = get_conversation_manager()
|
|
session_id = "test_conversation"
|
|
|
|
# 测试对话
|
|
test_messages = [
|
|
"你好",
|
|
"我想了解你的功能",
|
|
"能帮我查询订单吗?"
|
|
]
|
|
|
|
for i, message in enumerate(test_messages, 1):
|
|
print(f"\n📝 对话 {i}: {message}")
|
|
print("🤖 AI: ", end="", flush=True)
|
|
|
|
# 使用对话管理器
|
|
async for chunk in manager.chat_streaming(session_id, message):
|
|
print(chunk, end="", flush=True)
|
|
|
|
# 显示对话历史
|
|
print(f"\n\n📚 对话历史:")
|
|
history = manager.get_conversation(session_id)
|
|
for msg in history:
|
|
role_name = "用户" if msg.role == MessageRole.USER else "AI"
|
|
print(f" {role_name}: {msg.content[:50]}...")
|
|
|
|
print(f"\n✅ 对话历史包含 {len(history)} 条消息")
|
|
print("🎉 对话管理器测试完成!")
|
|
|
|
except Exception as e:
|
|
print(f"❌ 测试失败: {e}")
|
|
raise
|
|
|
|
|
|
async def main():
|
|
"""主测试函数"""
|
|
print("🚀 开始测试简单对话链")
|
|
|
|
try:
|
|
await test_simple_chat()
|
|
await test_conversation_manager()
|
|
|
|
print("\n🎊 所有测试通过!")
|
|
|
|
except Exception as e:
|
|
print(f"\n💥 测试过程中出现错误: {e}")
|
|
sys.exit(1)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
asyncio.run(main()) |