"""FastAPI streaming chat API.

HTTP service layer for the LangChain / LangGraph chatbot: session
management, streaming (SSE-style) and simple chat endpoints, plus
health/config introspection.
"""
|
import asyncio
import json
import uuid
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, Dict, List, Optional

from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field

from src.models.chat_models import APIRequest, APIResponse, ChatMessage, MessageRole
from src.services.langsmith_service import get_langsmith_service
from src.services.ollama_service import get_ollama_service
from src.utils.config import get_config
from src.utils.logger import get_api_logger
from src.workflows.chat_workflow import (
    create_chat_session,
    get_chat_workflow,
    get_state_manager,
    process_message_simple,
    process_message_streaming,
)
# 获取配置和服务
|
|
config = get_config()
|
|
api_logger = get_api_logger()
|
|
langsmith_service = get_langsmith_service()
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: startup probes and shutdown logging.

    On startup, checks Ollama connectivity and eagerly constructs the chat
    workflow singleton so the first request does not pay the cost. Failures
    are logged but deliberately non-fatal, so the API can still come up and
    serve /health diagnostics.
    """
    api_logger.info("🚀 启动 FastAPI 应用")

    # Probe the Ollama backend; an unreachable backend only warns.
    ollama_service = get_ollama_service()
    try:
        health_status = await ollama_service.health_check()
        if health_status["status"] == "healthy":
            api_logger.info("✅ Ollama 服务连接正常")
        else:
            api_logger.warning("⚠️ Ollama 服务连接异常")
    except Exception as e:
        api_logger.error(f"❌ Ollama 服务检查失败: {e}")

    # Warm up the workflow singleton (return value intentionally unused —
    # the original bound it to a local that was never read).
    try:
        get_chat_workflow()
        api_logger.info("✅ 工作流初始化完成")
    except Exception as e:
        api_logger.error(f"❌ 工作流初始化失败: {e}")

    yield

    # Shutdown: nothing to release beyond logging.
    api_logger.info("🛑 关闭 FastAPI 应用")
# 创建 FastAPI 应用
|
|
app = FastAPI(
|
|
title="LangChain 聊天机器人 API",
|
|
description="基于 LangChain 和 LangGraph 的智能对话机器人",
|
|
version="1.0.0",
|
|
lifespan=lifespan
|
|
)
|
|
|
|
# 添加 CORS 中间件
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
allow_origins=["*"], # 生产环境中应该限制具体域名
|
|
allow_credentials=True,
|
|
allow_methods=["*"],
|
|
allow_headers=["*"],
|
|
)
|
|
|
|
|
|
# 请求模型
|
|
class ChatRequest(BaseModel):
|
|
"""聊天请求模型"""
|
|
message: str = Field(..., description="用户消息")
|
|
session_id: Optional[str] = Field(None, description="会话ID")
|
|
stream: bool = Field(True, description="是否使用流式输出")
|
|
chat_history: Optional[List[Dict[str, Any]]] = Field(None, description="对话历史")
|
|
|
|
|
|
class SessionRequest(BaseModel):
|
|
"""会话请求模型"""
|
|
session_id: Optional[str] = Field(None, description="指定会话ID")
|
|
|
|
|
|
# 响应模型
|
|
class ChatResponse(BaseModel):
|
|
"""聊天响应模型"""
|
|
session_id: str = Field(..., description="会话ID")
|
|
response: str = Field(..., description="回复内容")
|
|
timestamp: datetime = Field(..., description="时间戳")
|
|
metadata: Optional[Dict[str, Any]] = Field(None, description="元数据")
|
|
|
|
|
|
class SessionResponse(BaseModel):
|
|
"""会话响应模型"""
|
|
session_id: str = Field(..., description="会话ID")
|
|
created_at: datetime = Field(..., description="创建时间")
|
|
status: str = Field(..., description="状态")
|
|
|
|
|
|
class HealthResponse(BaseModel):
|
|
"""健康检查响应模型"""
|
|
status: str = Field(..., description="服务状态")
|
|
timestamp: datetime = Field(..., description="检查时间")
|
|
services: Dict[str, Any] = Field(..., description="各服务状态")
|
|
|
|
|
|
@app.get("/", response_model=Dict[str, str])
async def root():
    """Landing endpoint: basic API metadata plus the docs path."""
    info = {
        "message": "LangChain 聊天机器人 API",
        "version": "1.0.0",
        "docs": "/docs",
    }
    return info
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Aggregate health check over Ollama, the chat workflow, and LangSmith.

    Returns:
        HealthResponse: overall status ("healthy" or "degraded") plus the
        individual status dict for each service.

    Raises:
        HTTPException: 500 when the check itself fails unexpectedly.
    """
    try:
        # Probe the Ollama backend (remote call; may report unhealthy).
        ollama_service = get_ollama_service()
        ollama_status = await ollama_service.health_check()

        # The workflow is healthy iff its singleton can be constructed.
        # (Original set "healthy" twice and bound the workflow to an
        # unused local; both removed.)
        try:
            get_chat_workflow()
            workflow_status: Dict[str, Any] = {"status": "healthy"}
        except Exception as e:
            workflow_status = {"status": "unhealthy", "error": str(e)}

        # LangSmith is optional tracing; "disabled" is not an error state.
        langsmith_status = {
            "status": "healthy" if langsmith_service.is_enabled() else "disabled"
        }

        # Degrade overall status when a required service is unhealthy.
        # .get() guards against a malformed status dict missing "status",
        # which previously raised KeyError and turned into a 500.
        overall_status = "healthy"
        if (ollama_status.get("status") != "healthy"
                or workflow_status.get("status") != "healthy"):
            overall_status = "degraded"

        return HealthResponse(
            status=overall_status,
            timestamp=datetime.now(),
            services={
                "ollama": ollama_status,
                "workflow": workflow_status,
                "langsmith": langsmith_status
            }
        )

    except Exception as e:
        api_logger.error(f"❌ 健康检查失败: {e}")
        raise HTTPException(status_code=500, detail=f"健康检查失败: {str(e)}")
@app.post("/sessions", response_model=SessionResponse)
async def create_session(request: SessionRequest):
    """Create a new chat session, honoring an optional caller-supplied ID."""
    try:
        new_id = await create_chat_session(request.session_id)
        api_logger.info(f"🆕 创建会话: {new_id}")
        response = SessionResponse(
            session_id=new_id,
            created_at=datetime.now(),
            status="active",
        )
        return response
    except Exception as e:
        api_logger.error(f"❌ 创建会话失败: {e}")
        raise HTTPException(status_code=500, detail=f"创建会话失败: {str(e)}")
@app.get("/sessions", response_model=List[str])
async def list_sessions():
    """Return the IDs of all currently active chat sessions."""
    try:
        active = get_state_manager().get_active_sessions()
        api_logger.info(f"📋 获取会话列表: {len(active)} 个会话")
        return active
    except Exception as e:
        api_logger.error(f"❌ 获取会话列表失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取会话列表失败: {str(e)}")
@app.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
    """Clear the stored state for one chat session."""
    try:
        manager = get_state_manager()
        manager.clear_session(session_id)
        api_logger.info(f"🗑️ 删除会话: {session_id}")
        return {"message": f"会话 {session_id} 已删除"}
    except Exception as e:
        api_logger.error(f"❌ 删除会话失败: {e}")
        raise HTTPException(status_code=500, detail=f"删除会话失败: {str(e)}")
@app.post("/chat")
async def chat(request: ChatRequest):
    """Main chat endpoint: SSE-style streaming by default, JSON otherwise."""
    try:
        # Reuse the caller's session or mint a fresh one.
        sid = request.session_id or str(uuid.uuid4())

        api_logger.info(f"💬 收到聊天请求: {sid}")
        api_logger.debug(f"用户输入: {request.message}")

        if not request.stream:
            # Plain request/response round trip.
            result = await process_message_simple(
                request.message, sid, request.chat_history
            )
            return ChatResponse(
                session_id=sid,
                response=result.get("response", ""),
                timestamp=datetime.now(),
                metadata=result,
            )

        # Streaming path: hand the async generator to StreamingResponse.
        generator = stream_chat_response(request.message, sid, request.chat_history)
        return StreamingResponse(generator, media_type="text/plain; charset=utf-8")

    except Exception as e:
        api_logger.error(f"❌ 聊天处理失败: {e}")
        raise HTTPException(status_code=500, detail=f"聊天处理失败: {str(e)}")
async def stream_chat_response(
    message: str,
    session_id: str,
    chat_history: Optional[List[Dict[str, Any]]] = None
):
    """Async generator yielding SSE-formatted chunks for one chat turn.

    Each workflow result is emitted as a `data: <json>` frame; the stream
    ends with a `data: [DONE]` sentinel. Errors are reported in-band as a
    final error frame rather than raised.
    """
    try:
        async for chunk in process_message_streaming(message, session_id, chat_history):
            payload = json.dumps(chunk, ensure_ascii=False, default=str)
            yield f"data: {payload}\n\n"
            # Tiny pause so clients observe incremental delivery.
            await asyncio.sleep(0.01)

        yield "data: [DONE]\n\n"

    except Exception as e:
        api_logger.error(f"❌ 流式响应生成失败: {e}")
        failure = {
            "type": "error",
            "content": f"流式响应生成失败: {str(e)}",
            "session_id": session_id
        }
        yield f"data: {json.dumps(failure, ensure_ascii=False)}\n\n"
@app.post("/chat/simple", response_model=ChatResponse)
async def chat_simple(request: ChatRequest):
    """Non-streaming chat endpoint; always returns a single ChatResponse."""
    try:
        sid = request.session_id if request.session_id else str(uuid.uuid4())
        api_logger.info(f"💬 收到简单聊天请求: {sid}")

        result = await process_message_simple(request.message, sid, request.chat_history)

        return ChatResponse(
            session_id=sid,
            response=result.get("response", ""),
            timestamp=datetime.now(),
            metadata=result,
        )
    except Exception as e:
        api_logger.error(f"❌ 简单聊天处理失败: {e}")
        raise HTTPException(status_code=500, detail=f"简单聊天处理失败: {str(e)}")
@app.get("/models")
async def list_models():
    """Describe the configured router/chat models plus a live health probe."""
    try:
        info: Dict[str, Any] = {
            "router_model": {
                "name": config.router_model_name,
                "base_url": config.router_model_base_url,
                "type": "router",
            },
            "chat_model": {
                "name": config.chat_model_name,
                "base_url": config.chat_model_base_url,
                "type": "chat",
            },
        }

        # Attach current backend health so clients get one combined view.
        info["health_status"] = await get_ollama_service().health_check()
        return info

    except Exception as e:
        api_logger.error(f"❌ 获取模型列表失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}")
@app.get("/config")
async def get_app_config():
    """Expose non-secret application configuration for clients and debugging."""
    try:
        return {
            "app_name": config.app_name,
            "app_version": config.app_version,
            "debug": config.debug,
            "api_host": config.api_host,
            "api_port": config.api_port,
            "langsmith_enabled": langsmith_service.is_enabled(),
            "models": {
                "router_model": config.router_model_name,
                "chat_model": config.chat_model_name,
            },
        }
    except Exception as e:
        api_logger.error(f"❌ 获取配置失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取配置失败: {str(e)}")
# Error handlers
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Render an HTTPException as a JSON body with an ISO timestamp."""
    api_logger.error(f"HTTP 异常: {exc.status_code} - {exc.detail}")
    body = {"error": exc.detail, "timestamp": datetime.now().isoformat()}
    return JSONResponse(status_code=exc.status_code, content=body)
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Catch-all handler: log the error, return a generic 500 payload."""
    api_logger.error(f"未处理的异常: {str(exc)}")
    payload = {"error": "内部服务器错误", "timestamp": datetime.now().isoformat()}
    return JSONResponse(status_code=500, content=payload)
# 后台任务
|
|
async def cleanup_old_sessions():
|
|
"""清理旧会话的后台任务"""
|
|
state_manager = get_state_manager()
|
|
# 这里可以添加清理逻辑
|
|
api_logger.info("🧹 执行会话清理任务")
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Launch the ASGI server with values taken from the application config;
    # debug mode enables auto-reload and verbose logging.
    log_level = "debug" if config.debug else "info"
    uvicorn.run(
        "main:app",
        host=config.api_host,
        port=config.api_port,
        reload=config.debug,
        log_level=log_level,
    )