250 lines
8.3 KiB
Go
250 lines
8.3 KiB
Go
package router
|
||
|
||
import (
|
||
"time"
|
||
|
||
"github.com/gin-contrib/cors"
|
||
"github.com/gin-gonic/gin"
|
||
"go.uber.org/dig"
|
||
|
||
"knowlege-lsxd/internal/config"
|
||
"knowlege-lsxd/internal/handler"
|
||
"knowlege-lsxd/internal/middleware"
|
||
"knowlege-lsxd/internal/types/interfaces"
|
||
)
|
||
|
||
// RouterParams 路由参数
|
||
type RouterParams struct {
|
||
dig.In
|
||
|
||
Config *config.Config
|
||
KBHandler *handler.KnowledgeBaseHandler
|
||
KnowledgeHandler *handler.KnowledgeHandler
|
||
TenantHandler *handler.TenantHandler
|
||
TenantService interfaces.TenantService
|
||
ChunkHandler *handler.ChunkHandler
|
||
SessionHandler *handler.SessionHandler
|
||
MessageHandler *handler.MessageHandler
|
||
TestDataHandler *handler.TestDataHandler
|
||
ModelHandler *handler.ModelHandler
|
||
EvaluationHandler *handler.EvaluationHandler
|
||
InitializationHandler *handler.InitializationHandler
|
||
}
|
||
|
||
// NewRouter 创建新的路由
|
||
func NewRouter(params RouterParams) *gin.Engine {
|
||
r := gin.New()
|
||
|
||
// CORS 中间件应放在最前面
|
||
r.Use(cors.New(cors.Config{
|
||
AllowOrigins: []string{"*"},
|
||
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
|
||
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-API-Key", "X-Request-ID"},
|
||
ExposeHeaders: []string{"Content-Length", "Access-Control-Allow-Origin"},
|
||
AllowCredentials: true,
|
||
MaxAge: 12 * time.Hour,
|
||
}))
|
||
|
||
// 其他中间件
|
||
r.Use(middleware.RequestID())
|
||
r.Use(middleware.Logger())
|
||
r.Use(middleware.Recovery())
|
||
r.Use(middleware.ErrorHandler())
|
||
r.Use(middleware.Auth(params.TenantService, params.Config))
|
||
|
||
// 添加OpenTelemetry追踪中间件
|
||
//r.Use(middleware.TracingMiddleware())
|
||
|
||
// 健康检查
|
||
r.GET("/health", func(c *gin.Context) {
|
||
c.JSON(200, gin.H{"status": "ok"})
|
||
})
|
||
|
||
// 测试数据接口(不需要认证)
|
||
r.GET("/api/v1/test-data", params.TestDataHandler.GetTestData)
|
||
|
||
// 初始化接口(不需要认证)
|
||
r.GET("/api/v1/initialization/status", params.InitializationHandler.CheckStatus)
|
||
r.GET("/api/v1/initialization/config", params.InitializationHandler.GetCurrentConfig)
|
||
r.POST("/api/v1/initialization/initialize", params.InitializationHandler.Initialize)
|
||
|
||
// Ollama相关接口(不需要认证)
|
||
r.GET("/api/v1/initialization/ollama/status", params.InitializationHandler.CheckOllamaStatus)
|
||
r.GET("/api/v1/initialization/ollama/models", params.InitializationHandler.ListOllamaModels)
|
||
r.POST("/api/v1/initialization/ollama/models/check", params.InitializationHandler.CheckOllamaModels)
|
||
r.POST("/api/v1/initialization/ollama/models/download", params.InitializationHandler.DownloadOllamaModel)
|
||
r.GET("/api/v1/initialization/ollama/download/progress/:taskId", params.InitializationHandler.GetDownloadProgress)
|
||
r.GET("/api/v1/initialization/ollama/download/tasks", params.InitializationHandler.ListDownloadTasks)
|
||
|
||
// 远程API相关接口(不需要认证)
|
||
r.POST("/api/v1/initialization/remote/check", params.InitializationHandler.CheckRemoteModel)
|
||
r.POST("/api/v1/initialization/embedding/test", params.InitializationHandler.TestEmbeddingModel)
|
||
r.POST("/api/v1/initialization/rerank/check", params.InitializationHandler.CheckRerankModel)
|
||
r.POST("/api/v1/initialization/multimodal/test", params.InitializationHandler.TestMultimodalFunction)
|
||
|
||
// 需要认证的API路由
|
||
v1 := r.Group("/api/v1")
|
||
{
|
||
RegisterTenantRoutes(v1, params.TenantHandler)
|
||
RegisterKnowledgeBaseRoutes(v1, params.KBHandler)
|
||
RegisterKnowledgeRoutes(v1, params.KnowledgeHandler)
|
||
RegisterChunkRoutes(v1, params.ChunkHandler)
|
||
RegisterSessionRoutes(v1, params.SessionHandler)
|
||
RegisterChatRoutes(v1, params.SessionHandler)
|
||
RegisterMessageRoutes(v1, params.MessageHandler)
|
||
RegisterModelRoutes(v1, params.ModelHandler)
|
||
RegisterEvaluationRoutes(v1, params.EvaluationHandler)
|
||
}
|
||
|
||
return r
|
||
}
|
||
|
||
// RegisterChunkRoutes 注册分块相关的路由
|
||
func RegisterChunkRoutes(r *gin.RouterGroup, handler *handler.ChunkHandler) {
|
||
// 分块路由组
|
||
chunks := r.Group("/chunks")
|
||
{
|
||
// 获取分块列表
|
||
chunks.GET("/:knowledge_id", handler.ListKnowledgeChunks)
|
||
// 删除分块
|
||
chunks.DELETE("/:knowledge_id/:id", handler.DeleteChunk)
|
||
// 删除知识下的所有分块
|
||
chunks.DELETE("/:knowledge_id", handler.DeleteChunksByKnowledgeID)
|
||
// 更新分块信息
|
||
chunks.PUT("/:knowledge_id/:id", handler.UpdateChunk)
|
||
}
|
||
}
|
||
|
||
// RegisterKnowledgeRoutes 注册知识相关的路由
|
||
func RegisterKnowledgeRoutes(r *gin.RouterGroup, handler *handler.KnowledgeHandler) {
|
||
// 知识库下的知识路由组
|
||
kb := r.Group("/knowledge-bases/:id/knowledge")
|
||
{
|
||
// 从文件创建知识
|
||
kb.POST("/file", handler.CreateKnowledgeFromFile)
|
||
// 从URL创建知识
|
||
kb.POST("/url", handler.CreateKnowledgeFromURL)
|
||
// 获取知识库下的知识列表
|
||
kb.GET("", handler.ListKnowledge)
|
||
}
|
||
|
||
// 知识路由组
|
||
k := r.Group("/knowledge")
|
||
{
|
||
// 批量获取知识
|
||
k.GET("/batch", handler.GetKnowledgeBatch)
|
||
// 获取知识详情
|
||
k.GET("/:id", handler.GetKnowledge)
|
||
// 删除知识
|
||
k.DELETE("/:id", handler.DeleteKnowledge)
|
||
// 更新知识
|
||
k.PUT("/:id", handler.UpdateKnowledge)
|
||
// 获取知识文件
|
||
k.GET("/:id/download", handler.DownloadKnowledgeFile)
|
||
// 更新图像分块信息
|
||
k.PUT("/image/:id/:chunk_id", handler.UpdateImageInfo)
|
||
}
|
||
}
|
||
|
||
// RegisterKnowledgeBaseRoutes 注册知识库相关的路由
|
||
func RegisterKnowledgeBaseRoutes(r *gin.RouterGroup, handler *handler.KnowledgeBaseHandler) {
|
||
// 知识库路由组
|
||
kb := r.Group("/knowledge-bases")
|
||
{
|
||
// 创建知识库
|
||
kb.POST("", handler.CreateKnowledgeBase)
|
||
// 获取知识库列表
|
||
kb.GET("", handler.ListKnowledgeBases)
|
||
// 获取知识库详情
|
||
kb.GET("/:id", handler.GetKnowledgeBase)
|
||
// 更新知识库
|
||
kb.PUT("/:id", handler.UpdateKnowledgeBase)
|
||
// 删除知识库
|
||
kb.DELETE("/:id", handler.DeleteKnowledgeBase)
|
||
// 混合搜索
|
||
kb.GET("/:id/hybrid-search", handler.HybridSearch)
|
||
// 拷贝知识库
|
||
kb.POST("/copy", handler.CopyKnowledgeBase)
|
||
}
|
||
}
|
||
|
||
// RegisterMessageRoutes 注册消息相关的路由
|
||
func RegisterMessageRoutes(r *gin.RouterGroup, handler *handler.MessageHandler) {
|
||
// 消息路由组
|
||
messages := r.Group("/messages")
|
||
{
|
||
// 加载更早的消息,用于向上滚动加载
|
||
messages.GET("/:session_id/load", handler.LoadMessages)
|
||
// 删除消息
|
||
messages.DELETE("/:session_id/:id", handler.DeleteMessage)
|
||
}
|
||
}
|
||
|
||
// RegisterSessionRoutes 注册路由
|
||
func RegisterSessionRoutes(r *gin.RouterGroup, handler *handler.SessionHandler) {
|
||
sessions := r.Group("/sessions")
|
||
{
|
||
sessions.POST("", handler.CreateSession)
|
||
sessions.GET("/:id", handler.GetSession)
|
||
sessions.GET("", handler.GetSessionsByTenant)
|
||
sessions.PUT("/:id", handler.UpdateSession)
|
||
sessions.DELETE("/:id", handler.DeleteSession)
|
||
sessions.POST("/:session_id/generate_title", handler.GenerateTitle)
|
||
// 继续接收活跃流
|
||
sessions.GET("/continue-stream/:session_id", handler.ContinueStream)
|
||
}
|
||
}
|
||
|
||
// RegisterChatRoutes 注册路由
|
||
func RegisterChatRoutes(r *gin.RouterGroup, handler *handler.SessionHandler) {
|
||
knowledgeChat := r.Group("/knowledge-chat")
|
||
{
|
||
knowledgeChat.POST("/:session_id", handler.KnowledgeQA)
|
||
}
|
||
|
||
// 新增知识检索接口,不需要session_id
|
||
knowledgeSearch := r.Group("/knowledge-search")
|
||
{
|
||
knowledgeSearch.POST("", handler.SearchKnowledge)
|
||
}
|
||
}
|
||
|
||
// RegisterTenantRoutes 注册租户相关的路由
|
||
func RegisterTenantRoutes(r *gin.RouterGroup, handler *handler.TenantHandler) {
|
||
// 租户路由组
|
||
tenantRoutes := r.Group("/tenants")
|
||
{
|
||
tenantRoutes.POST("", handler.CreateTenant)
|
||
tenantRoutes.GET("/:id", handler.GetTenant)
|
||
tenantRoutes.PUT("/:id", handler.UpdateTenant)
|
||
tenantRoutes.DELETE("/:id", handler.DeleteTenant)
|
||
tenantRoutes.GET("", handler.ListTenants)
|
||
}
|
||
}
|
||
|
||
// RegisterModelRoutes 注册模型相关的路由
|
||
func RegisterModelRoutes(r *gin.RouterGroup, handler *handler.ModelHandler) {
|
||
// 模型路由组
|
||
models := r.Group("/models")
|
||
{
|
||
// 创建模型
|
||
models.POST("", handler.CreateModel)
|
||
// 获取模型列表
|
||
models.GET("", handler.ListModels)
|
||
// 获取单个模型
|
||
models.GET("/:id", handler.GetModel)
|
||
// 更新模型
|
||
models.PUT("/:id", handler.UpdateModel)
|
||
// 删除模型
|
||
models.DELETE("/:id", handler.DeleteModel)
|
||
}
|
||
}
|
||
|
||
func RegisterEvaluationRoutes(r *gin.RouterGroup, handler *handler.EvaluationHandler) {
|
||
evaluationRoutes := r.Group("/evaluation")
|
||
{
|
||
evaluationRoutes.POST("/", handler.Evaluation)
|
||
evaluationRoutes.GET("/", handler.GetEvaluationResult)
|
||
}
|
||
}
|