250 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			250 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Go
		
	
	
	
package router
 | 
						||
 | 
						||
import (
 | 
						||
	"time"
 | 
						||
 | 
						||
	"github.com/gin-contrib/cors"
 | 
						||
	"github.com/gin-gonic/gin"
 | 
						||
	"go.uber.org/dig"
 | 
						||
 | 
						||
	"github.com/Tencent/WeKnora/internal/config"
 | 
						||
	"github.com/Tencent/WeKnora/internal/handler"
 | 
						||
	"github.com/Tencent/WeKnora/internal/middleware"
 | 
						||
	"github.com/Tencent/WeKnora/internal/types/interfaces"
 | 
						||
)
 | 
						||
 | 
						||
// RouterParams 路由参数
 | 
						||
type RouterParams struct {
 | 
						||
	dig.In
 | 
						||
 | 
						||
	Config                *config.Config
 | 
						||
	KBHandler             *handler.KnowledgeBaseHandler
 | 
						||
	KnowledgeHandler      *handler.KnowledgeHandler
 | 
						||
	TenantHandler         *handler.TenantHandler
 | 
						||
	TenantService         interfaces.TenantService
 | 
						||
	ChunkHandler          *handler.ChunkHandler
 | 
						||
	SessionHandler        *handler.SessionHandler
 | 
						||
	MessageHandler        *handler.MessageHandler
 | 
						||
	TestDataHandler       *handler.TestDataHandler
 | 
						||
	ModelHandler          *handler.ModelHandler
 | 
						||
	EvaluationHandler     *handler.EvaluationHandler
 | 
						||
	InitializationHandler *handler.InitializationHandler
 | 
						||
}
 | 
						||
 | 
						||
// NewRouter 创建新的路由
 | 
						||
func NewRouter(params RouterParams) *gin.Engine {
 | 
						||
	r := gin.New()
 | 
						||
 | 
						||
	// CORS 中间件应放在最前面
 | 
						||
	r.Use(cors.New(cors.Config{
 | 
						||
		AllowOrigins:     []string{"*"},
 | 
						||
		AllowMethods:     []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
 | 
						||
		AllowHeaders:     []string{"Origin", "Content-Type", "Accept", "Authorization", "X-API-Key", "X-Request-ID"},
 | 
						||
		ExposeHeaders:    []string{"Content-Length", "Access-Control-Allow-Origin"},
 | 
						||
		AllowCredentials: true,
 | 
						||
		MaxAge:           12 * time.Hour,
 | 
						||
	}))
 | 
						||
 | 
						||
	// 其他中间件
 | 
						||
	r.Use(middleware.RequestID())
 | 
						||
	r.Use(middleware.Logger())
 | 
						||
	r.Use(middleware.Recovery())
 | 
						||
	r.Use(middleware.ErrorHandler())
 | 
						||
	r.Use(middleware.Auth(params.TenantService, params.Config))
 | 
						||
 | 
						||
	// 添加OpenTelemetry追踪中间件
 | 
						||
	r.Use(middleware.TracingMiddleware())
 | 
						||
 | 
						||
	// 健康检查
 | 
						||
	r.GET("/health", func(c *gin.Context) {
 | 
						||
		c.JSON(200, gin.H{"status": "ok"})
 | 
						||
	})
 | 
						||
 | 
						||
	// 测试数据接口(不需要认证)
 | 
						||
	r.GET("/api/v1/test-data", params.TestDataHandler.GetTestData)
 | 
						||
 | 
						||
	// 初始化接口(不需要认证)
 | 
						||
	r.GET("/api/v1/initialization/status", params.InitializationHandler.CheckStatus)
 | 
						||
	r.GET("/api/v1/initialization/config", params.InitializationHandler.GetCurrentConfig)
 | 
						||
	r.POST("/api/v1/initialization/initialize", params.InitializationHandler.Initialize)
 | 
						||
 | 
						||
	// Ollama相关接口(不需要认证)
 | 
						||
	r.GET("/api/v1/initialization/ollama/status", params.InitializationHandler.CheckOllamaStatus)
 | 
						||
	r.GET("/api/v1/initialization/ollama/models", params.InitializationHandler.ListOllamaModels)
 | 
						||
	r.POST("/api/v1/initialization/ollama/models/check", params.InitializationHandler.CheckOllamaModels)
 | 
						||
	r.POST("/api/v1/initialization/ollama/models/download", params.InitializationHandler.DownloadOllamaModel)
 | 
						||
	r.GET("/api/v1/initialization/ollama/download/progress/:taskId", params.InitializationHandler.GetDownloadProgress)
 | 
						||
	r.GET("/api/v1/initialization/ollama/download/tasks", params.InitializationHandler.ListDownloadTasks)
 | 
						||
 | 
						||
	// 远程API相关接口(不需要认证)
 | 
						||
	r.POST("/api/v1/initialization/remote/check", params.InitializationHandler.CheckRemoteModel)
 | 
						||
	r.POST("/api/v1/initialization/embedding/test", params.InitializationHandler.TestEmbeddingModel)
 | 
						||
	r.POST("/api/v1/initialization/rerank/check", params.InitializationHandler.CheckRerankModel)
 | 
						||
	r.POST("/api/v1/initialization/multimodal/test", params.InitializationHandler.TestMultimodalFunction)
 | 
						||
 | 
						||
	// 需要认证的API路由
 | 
						||
	v1 := r.Group("/api/v1")
 | 
						||
	{
 | 
						||
		RegisterTenantRoutes(v1, params.TenantHandler)
 | 
						||
		RegisterKnowledgeBaseRoutes(v1, params.KBHandler)
 | 
						||
		RegisterKnowledgeRoutes(v1, params.KnowledgeHandler)
 | 
						||
		RegisterChunkRoutes(v1, params.ChunkHandler)
 | 
						||
		RegisterSessionRoutes(v1, params.SessionHandler)
 | 
						||
		RegisterChatRoutes(v1, params.SessionHandler)
 | 
						||
		RegisterMessageRoutes(v1, params.MessageHandler)
 | 
						||
		RegisterModelRoutes(v1, params.ModelHandler)
 | 
						||
		RegisterEvaluationRoutes(v1, params.EvaluationHandler)
 | 
						||
	}
 | 
						||
 | 
						||
	return r
 | 
						||
}
 | 
						||
 | 
						||
// RegisterChunkRoutes 注册分块相关的路由
 | 
						||
func RegisterChunkRoutes(r *gin.RouterGroup, handler *handler.ChunkHandler) {
 | 
						||
	// 分块路由组
 | 
						||
	chunks := r.Group("/chunks")
 | 
						||
	{
 | 
						||
		// 获取分块列表
 | 
						||
		chunks.GET("/:knowledge_id", handler.ListKnowledgeChunks)
 | 
						||
		// 删除分块
 | 
						||
		chunks.DELETE("/:knowledge_id/:id", handler.DeleteChunk)
 | 
						||
		// 删除知识下的所有分块
 | 
						||
		chunks.DELETE("/:knowledge_id", handler.DeleteChunksByKnowledgeID)
 | 
						||
		// 更新分块信息
 | 
						||
		chunks.PUT("/:knowledge_id/:id", handler.UpdateChunk)
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// RegisterKnowledgeRoutes 注册知识相关的路由
 | 
						||
func RegisterKnowledgeRoutes(r *gin.RouterGroup, handler *handler.KnowledgeHandler) {
 | 
						||
	// 知识库下的知识路由组
 | 
						||
	kb := r.Group("/knowledge-bases/:id/knowledge")
 | 
						||
	{
 | 
						||
		// 从文件创建知识
 | 
						||
		kb.POST("/file", handler.CreateKnowledgeFromFile)
 | 
						||
		// 从URL创建知识
 | 
						||
		kb.POST("/url", handler.CreateKnowledgeFromURL)
 | 
						||
		// 获取知识库下的知识列表
 | 
						||
		kb.GET("", handler.ListKnowledge)
 | 
						||
	}
 | 
						||
 | 
						||
	// 知识路由组
 | 
						||
	k := r.Group("/knowledge")
 | 
						||
	{
 | 
						||
		// 批量获取知识
 | 
						||
		k.GET("/batch", handler.GetKnowledgeBatch)
 | 
						||
		// 获取知识详情
 | 
						||
		k.GET("/:id", handler.GetKnowledge)
 | 
						||
		// 删除知识
 | 
						||
		k.DELETE("/:id", handler.DeleteKnowledge)
 | 
						||
		// 更新知识
 | 
						||
		k.PUT("/:id", handler.UpdateKnowledge)
 | 
						||
		// 获取知识文件
 | 
						||
		k.GET("/:id/download", handler.DownloadKnowledgeFile)
 | 
						||
		// 更新图像分块信息
 | 
						||
		k.PUT("/image/:id/:chunk_id", handler.UpdateImageInfo)
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// RegisterKnowledgeBaseRoutes 注册知识库相关的路由
 | 
						||
func RegisterKnowledgeBaseRoutes(r *gin.RouterGroup, handler *handler.KnowledgeBaseHandler) {
 | 
						||
	// 知识库路由组
 | 
						||
	kb := r.Group("/knowledge-bases")
 | 
						||
	{
 | 
						||
		// 创建知识库
 | 
						||
		kb.POST("", handler.CreateKnowledgeBase)
 | 
						||
		// 获取知识库列表
 | 
						||
		kb.GET("", handler.ListKnowledgeBases)
 | 
						||
		// 获取知识库详情
 | 
						||
		kb.GET("/:id", handler.GetKnowledgeBase)
 | 
						||
		// 更新知识库
 | 
						||
		kb.PUT("/:id", handler.UpdateKnowledgeBase)
 | 
						||
		// 删除知识库
 | 
						||
		kb.DELETE("/:id", handler.DeleteKnowledgeBase)
 | 
						||
		// 混合搜索
 | 
						||
		kb.GET("/:id/hybrid-search", handler.HybridSearch)
 | 
						||
		// 拷贝知识库
 | 
						||
		kb.POST("/copy", handler.CopyKnowledgeBase)
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// RegisterMessageRoutes 注册消息相关的路由
 | 
						||
func RegisterMessageRoutes(r *gin.RouterGroup, handler *handler.MessageHandler) {
 | 
						||
	// 消息路由组
 | 
						||
	messages := r.Group("/messages")
 | 
						||
	{
 | 
						||
		// 加载更早的消息,用于向上滚动加载
 | 
						||
		messages.GET("/:session_id/load", handler.LoadMessages)
 | 
						||
		// 删除消息
 | 
						||
		messages.DELETE("/:session_id/:id", handler.DeleteMessage)
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// RegisterSessionRoutes 注册路由
 | 
						||
func RegisterSessionRoutes(r *gin.RouterGroup, handler *handler.SessionHandler) {
 | 
						||
	sessions := r.Group("/sessions")
 | 
						||
	{
 | 
						||
		sessions.POST("", handler.CreateSession)
 | 
						||
		sessions.GET("/:id", handler.GetSession)
 | 
						||
		sessions.GET("", handler.GetSessionsByTenant)
 | 
						||
		sessions.PUT("/:id", handler.UpdateSession)
 | 
						||
		sessions.DELETE("/:id", handler.DeleteSession)
 | 
						||
		sessions.POST("/:session_id/generate_title", handler.GenerateTitle)
 | 
						||
		// 继续接收活跃流
 | 
						||
		sessions.GET("/continue-stream/:session_id", handler.ContinueStream)
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// RegisterChatRoutes 注册路由
 | 
						||
func RegisterChatRoutes(r *gin.RouterGroup, handler *handler.SessionHandler) {
 | 
						||
	knowledgeChat := r.Group("/knowledge-chat")
 | 
						||
	{
 | 
						||
		knowledgeChat.POST("/:session_id", handler.KnowledgeQA)
 | 
						||
	}
 | 
						||
 | 
						||
	// 新增知识检索接口,不需要session_id
 | 
						||
	knowledgeSearch := r.Group("/knowledge-search")
 | 
						||
	{
 | 
						||
		knowledgeSearch.POST("", handler.SearchKnowledge)
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// RegisterTenantRoutes 注册租户相关的路由
 | 
						||
func RegisterTenantRoutes(r *gin.RouterGroup, handler *handler.TenantHandler) {
 | 
						||
	// 租户路由组
 | 
						||
	tenantRoutes := r.Group("/tenants")
 | 
						||
	{
 | 
						||
		tenantRoutes.POST("", handler.CreateTenant)
 | 
						||
		tenantRoutes.GET("/:id", handler.GetTenant)
 | 
						||
		tenantRoutes.PUT("/:id", handler.UpdateTenant)
 | 
						||
		tenantRoutes.DELETE("/:id", handler.DeleteTenant)
 | 
						||
		tenantRoutes.GET("", handler.ListTenants)
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// RegisterModelRoutes 注册模型相关的路由
 | 
						||
func RegisterModelRoutes(r *gin.RouterGroup, handler *handler.ModelHandler) {
 | 
						||
	// 模型路由组
 | 
						||
	models := r.Group("/models")
 | 
						||
	{
 | 
						||
		// 创建模型
 | 
						||
		models.POST("", handler.CreateModel)
 | 
						||
		// 获取模型列表
 | 
						||
		models.GET("", handler.ListModels)
 | 
						||
		// 获取单个模型
 | 
						||
		models.GET("/:id", handler.GetModel)
 | 
						||
		// 更新模型
 | 
						||
		models.PUT("/:id", handler.UpdateModel)
 | 
						||
		// 删除模型
 | 
						||
		models.DELETE("/:id", handler.DeleteModel)
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
func RegisterEvaluationRoutes(r *gin.RouterGroup, handler *handler.EvaluationHandler) {
 | 
						||
	evaluationRoutes := r.Group("/evaluation")
 | 
						||
	{
 | 
						||
		evaluationRoutes.POST("/", handler.Evaluation)
 | 
						||
		evaluationRoutes.GET("/", handler.GetEvaluationResult)
 | 
						||
	}
 | 
						||
}
 |