ai-courseware/langchain-project/run_tests.py

196 lines
5.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
测试运行器 - 运行所有测试
"""
import asyncio
import sys
import os
import subprocess
from typing import List, Tuple
# 添加项目根目录到 Python 路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
class TestRunner:
"""测试运行器"""
def __init__(self):
"""初始化测试运行器"""
self.test_files = [
("Ollama 服务测试", "tests/test_ollama.py"),
("简单对话链测试", "tests/test_simple_chain.py"),
("LangGraph 工作流测试", "tests/test_workflow.py"),
("FastAPI 接口测试", "tests/test_api.py")
]
self.results: List[Tuple[str, bool, str]] = []
def print_header(self, title: str):
"""打印测试标题"""
print("\n" + "=" * 60)
print(f"🧪 {title}")
print("=" * 60)
def print_separator(self):
"""打印分隔符"""
print("-" * 60)
async def run_test_file(self, name: str, file_path: str) -> Tuple[bool, str]:
"""运行单个测试文件"""
self.print_header(name)
if not os.path.exists(file_path):
error_msg = f"测试文件不存在: {file_path}"
print(f"{error_msg}")
return False, error_msg
try:
# 运行测试文件
process = await asyncio.create_subprocess_exec(
sys.executable, file_path,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
# 输出结果
if stdout:
print(stdout.decode('utf-8'))
if stderr:
print("⚠️ 错误输出:")
print(stderr.decode('utf-8'))
# 检查返回码
if process.returncode == 0:
print(f"{name} 测试通过")
return True, "测试通过"
else:
error_msg = f"测试失败,返回码: {process.returncode}"
print(f"{name} {error_msg}")
return False, error_msg
except Exception as e:
error_msg = f"运行测试时出现异常: {e}"
print(f"{name} {error_msg}")
return False, error_msg
def check_dependencies(self):
"""检查依赖"""
print("🔍 检查项目依赖...")
required_files = [
"requirements.txt",
".env",
"src/services/ollama_service.py",
"src/workflows/chat_workflow.py",
"src/chains/simple_chat_chain.py",
"main.py"
]
missing_files = []
for file_path in required_files:
if not os.path.exists(file_path):
missing_files.append(file_path)
if missing_files:
print("❌ 缺少以下必要文件:")
for file_path in missing_files:
print(f" - {file_path}")
return False
print("✅ 所有必要文件都存在")
return True
def check_environment(self):
"""检查环境变量"""
print("\n🔍 检查环境变量...")
required_env_vars = [
"OLLAMA_BASE_URL",
"CHAT_MODEL",
"ROUTER_MODEL"
]
missing_vars = []
for var in required_env_vars:
if not os.getenv(var):
missing_vars.append(var)
if missing_vars:
print("⚠️ 缺少以下环境变量:")
for var in missing_vars:
print(f" - {var}")
print("💡 请检查 .env 文件是否正确配置")
else:
print("✅ 所有必要环境变量都已设置")
return len(missing_vars) == 0
def print_summary(self):
"""打印测试总结"""
self.print_header("测试总结")
total_tests = len(self.results)
passed_tests = sum(1 for _, success, _ in self.results if success)
failed_tests = total_tests - passed_tests
print(f"📊 总测试数: {total_tests}")
print(f"✅ 通过: {passed_tests}")
print(f"❌ 失败: {failed_tests}")
print(f"📈 成功率: {(passed_tests/total_tests*100):.1f}%")
if failed_tests > 0:
print("\n❌ 失败的测试:")
for name, success, error in self.results:
if not success:
print(f" - {name}: {error}")
print("\n" + "=" * 60)
if failed_tests == 0:
print("🎊 所有测试都通过了!项目可以正常运行。")
else:
print("⚠️ 有测试失败,请检查相关组件。")
print("=" * 60)
async def run_all_tests(self):
"""运行所有测试"""
print("🚀 开始运行所有测试")
print("=" * 60)
# 检查依赖和环境
if not self.check_dependencies():
print("❌ 依赖检查失败,无法继续测试")
return
self.check_environment()
# 运行测试
for name, file_path in self.test_files:
success, error = await self.run_test_file(name, file_path)
self.results.append((name, success, error))
# 在测试之间添加延迟,避免资源冲突
await asyncio.sleep(1)
# 打印总结
self.print_summary()
async def main():
"""主函数"""
runner = TestRunner()
await runner.run_all_tests()
if __name__ == "__main__":
# 加载环境变量
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
print("⚠️ 未安装 python-dotenv跳过 .env 文件加载")
asyncio.run(main())