ai-courseware/langchain-project/tests/test_simple_chain.py

111 lines
3.1 KiB
Python

#!/usr/bin/env python3
"""
简单对话链测试
"""
import asyncio
import sys
import os
# 添加项目根目录到 Python 路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.chains.simple_chat_chain import get_simple_chat_chain, get_conversation_manager
from src.models.chat_models import ChatMessage, MessageRole
from datetime import datetime
async def test_simple_chat():
"""测试简单对话"""
print("🧪 测试简单对话链")
print("=" * 50)
try:
# 获取对话链
chain = get_simple_chat_chain()
# 测试用例
test_cases = [
"你好,我是新用户",
"你能做什么?",
"今天天气怎么样?",
"帮我查询订单 ORD12345",
"谢谢你的帮助"
]
for i, user_input in enumerate(test_cases, 1):
print(f"\n📝 测试 {i}: {user_input}")
print("🤖 AI: ", end="", flush=True)
# 测试流式输出
response = ""
async for chunk in chain.chat_streaming(user_input, session_id="test_session"):
print(chunk, end="", flush=True)
response += chunk
print(f"\n✅ 回复长度: {len(response)} 字符")
print("\n🎉 简单对话链测试完成!")
except Exception as e:
print(f"❌ 测试失败: {e}")
raise
async def test_conversation_manager():
"""测试对话管理器"""
print("\n🧪 测试对话管理器")
print("=" * 50)
try:
# 获取对话管理器
manager = get_conversation_manager()
session_id = "test_conversation"
# 测试对话
test_messages = [
"你好",
"我想了解你的功能",
"能帮我查询订单吗?"
]
for i, message in enumerate(test_messages, 1):
print(f"\n📝 对话 {i}: {message}")
print("🤖 AI: ", end="", flush=True)
# 使用对话管理器
async for chunk in manager.chat_streaming(session_id, message):
print(chunk, end="", flush=True)
# 显示对话历史
print(f"\n\n📚 对话历史:")
history = manager.get_conversation(session_id)
for msg in history:
role_name = "用户" if msg.role == MessageRole.USER else "AI"
print(f" {role_name}: {msg.content[:50]}...")
print(f"\n✅ 对话历史包含 {len(history)} 条消息")
print("🎉 对话管理器测试完成!")
except Exception as e:
print(f"❌ 测试失败: {e}")
raise
async def main():
"""主测试函数"""
print("🚀 开始测试简单对话链")
try:
await test_simple_chat()
await test_conversation_manager()
print("\n🎊 所有测试通过!")
except Exception as e:
print(f"\n💥 测试过程中出现错误: {e}")
sys.exit(1)
if __name__ == "__main__":
asyncio.run(main())