ai-courseware/langchain-project/tests/test_api.py

288 lines
11 KiB
Python

#!/usr/bin/env python3
"""
FastAPI 接口测试
"""
import asyncio
import sys
import os
import json
import aiohttp
from typing import Dict, Any
# 添加项目根目录到 Python 路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
class APITester:
"""API 测试器"""
def __init__(self, base_url: str = "http://localhost:8000"):
"""初始化测试器"""
self.base_url = base_url
self.session = None
async def __aenter__(self):
"""异步上下文管理器入口"""
self.session = aiohttp.ClientSession()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""异步上下文管理器出口"""
if self.session:
await self.session.close()
async def test_health_check(self):
"""测试健康检查"""
print("🧪 测试健康检查")
print("-" * 30)
try:
async with self.session.get(f"{self.base_url}/health") as response:
if response.status == 200:
data = await response.json()
print(f"✅ 健康检查通过")
print(f"📊 状态: {data.get('status', 'unknown')}")
services = data.get('services', {})
for service, status in services.items():
service_status = status.get('status', 'unknown')
icon = "" if service_status == "healthy" else "⚠️"
print(f" {icon} {service}: {service_status}")
else:
print(f"❌ 健康检查失败: HTTP {response.status}")
except Exception as e:
print(f"❌ 健康检查异常: {e}")
async def test_create_session(self):
"""测试创建会话"""
print("\n🧪 测试创建会话")
print("-" * 30)
try:
payload = {}
async with self.session.post(
f"{self.base_url}/sessions",
json=payload
) as response:
if response.status == 200:
data = await response.json()
session_id = data.get('session_id')
print(f"✅ 会话创建成功")
print(f"🆔 会话ID: {session_id[:8]}...")
return session_id
else:
print(f"❌ 会话创建失败: HTTP {response.status}")
return None
except Exception as e:
print(f"❌ 会话创建异常: {e}")
return None
async def test_simple_chat(self, session_id: str = None):
"""测试简单聊天"""
print("\n🧪 测试简单聊天")
print("-" * 30)
try:
payload = {
"message": "你好,我是测试用户",
"session_id": session_id,
"stream": False
}
async with self.session.post(
f"{self.base_url}/chat/simple",
json=payload
) as response:
if response.status == 200:
data = await response.json()
print(f"✅ 简单聊天成功")
print(f"🤖 AI回复: {data.get('response', '')[:100]}...")
print(f"🆔 会话ID: {data.get('session_id', '')[:8]}...")
else:
print(f"❌ 简单聊天失败: HTTP {response.status}")
except Exception as e:
print(f"❌ 简单聊天异常: {e}")
async def test_streaming_chat(self, session_id: str = None):
"""测试流式聊天"""
print("\n🧪 测试流式聊天")
print("-" * 30)
try:
payload = {
"message": "帮我查询订单 ORD12345",
"session_id": session_id,
"stream": True
}
print("📝 用户: 帮我查询订单 ORD12345")
print("🤖 AI: ", end="", flush=True)
async with self.session.post(
f"{self.base_url}/chat",
json=payload
) as response:
if response.status == 200:
chunk_count = 0
async for line in response.content:
line = line.decode('utf-8').strip()
if line.startswith('data: '):
data_str = line[6:] # 移除 'data: ' 前缀
if data_str == '[DONE]':
break
try:
data = json.loads(data_str)
result_type = data.get('type', '')
content = data.get('content', '')
if result_type == "chat_chunk":
print(content, end="", flush=True)
elif result_type == "final_response":
if content:
print(f"\n💬 最终回复: {content[:50]}...")
elif result_type == "error":
print(f"\n❌ 错误: {content}")
chunk_count += 1
except json.JSONDecodeError:
continue
print(f"\n✅ 流式聊天完成,收到 {chunk_count} 个数据块")
else:
print(f"❌ 流式聊天失败: HTTP {response.status}")
except Exception as e:
print(f"❌ 流式聊天异常: {e}")
async def test_list_sessions(self):
"""测试获取会话列表"""
print("\n🧪 测试获取会话列表")
print("-" * 30)
try:
async with self.session.get(f"{self.base_url}/sessions") as response:
if response.status == 200:
sessions = await response.json()
print(f"✅ 获取会话列表成功")
print(f"📋 活跃会话数: {len(sessions)}")
for i, session_id in enumerate(sessions[:5], 1): # 只显示前5个
print(f" {i}. {session_id[:8]}...")
if len(sessions) > 5:
print(f" ... 还有 {len(sessions) - 5} 个会话")
else:
print(f"❌ 获取会话列表失败: HTTP {response.status}")
except Exception as e:
print(f"❌ 获取会话列表异常: {e}")
async def test_get_models(self):
"""测试获取模型信息"""
print("\n🧪 测试获取模型信息")
print("-" * 30)
try:
async with self.session.get(f"{self.base_url}/models") as response:
if response.status == 200:
data = await response.json()
print(f"✅ 获取模型信息成功")
router_model = data.get('router_model', {})
chat_model = data.get('chat_model', {})
print(f"🎯 路由模型: {router_model.get('name', 'unknown')}")
print(f"💬 聊天模型: {chat_model.get('name', 'unknown')}")
health_status = data.get('health_status', {})
print(f"🏥 健康状态: {health_status.get('status', 'unknown')}")
else:
print(f"❌ 获取模型信息失败: HTTP {response.status}")
except Exception as e:
print(f"❌ 获取模型信息异常: {e}")
async def test_get_config(self):
"""测试获取配置信息"""
print("\n🧪 测试获取配置信息")
print("-" * 30)
try:
async with self.session.get(f"{self.base_url}/config") as response:
if response.status == 200:
data = await response.json()
print(f"✅ 获取配置信息成功")
print(f"📱 应用名称: {data.get('app_name', 'unknown')}")
print(f"🔢 应用版本: {data.get('app_version', 'unknown')}")
print(f"🐛 调试模式: {data.get('debug', False)}")
print(f"🔬 LangSmith: {'启用' if data.get('langsmith_enabled') else '禁用'}")
else:
print(f"❌ 获取配置信息失败: HTTP {response.status}")
except Exception as e:
print(f"❌ 获取配置信息异常: {e}")
async def test_server_running():
"""测试服务器是否运行"""
print("🔍 检查服务器状态...")
try:
async with aiohttp.ClientSession() as session:
async with session.get("http://localhost:8000/") as response:
if response.status == 200:
data = await response.json()
print(f"✅ 服务器运行正常")
print(f"📱 {data.get('message', 'LangChain 聊天机器人 API')}")
return True
else:
print(f"⚠️ 服务器响应异常: HTTP {response.status}")
return False
except Exception as e:
print(f"❌ 无法连接到服务器: {e}")
print("💡 请确保服务器正在运行: python main.py")
return False
async def main():
"""主测试函数"""
print("🚀 开始测试 FastAPI 接口")
print("=" * 50)
# 检查服务器状态
if not await test_server_running():
print("\n💡 启动服务器命令: python main.py")
return
# 运行测试
async with APITester() as tester:
try:
# 基础测试
await tester.test_health_check()
await tester.test_get_config()
await tester.test_get_models()
# 会话测试
session_id = await tester.test_create_session()
await tester.test_list_sessions()
# 聊天测试
await tester.test_simple_chat(session_id)
await tester.test_streaming_chat(session_id)
print("\n🎊 所有 API 测试通过!")
except Exception as e:
print(f"\n💥 测试过程中出现错误: {e}")
sys.exit(1)
if __name__ == "__main__":
asyncio.run(main())