package handler

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/Tencent/WeKnora/internal/config"
	"github.com/Tencent/WeKnora/internal/errors"
	"github.com/Tencent/WeKnora/internal/logger"
	"github.com/Tencent/WeKnora/internal/models/embedding"
	"github.com/Tencent/WeKnora/internal/models/utils/ollama"
	"github.com/Tencent/WeKnora/internal/types"
	"github.com/Tencent/WeKnora/internal/types/interfaces"
	"github.com/Tencent/WeKnora/services/docreader/src/client"
	"github.com/Tencent/WeKnora/services/docreader/src/proto"
	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	"github.com/ollama/ollama/api"
)

// DownloadTask describes one asynchronous Ollama model download.
type DownloadTask struct {
	ID        string     `json:"id"`
	ModelName string     `json:"modelName"`
	Status    string     `json:"status"` // pending, downloading, completed, failed
	Progress  float64    `json:"progress"`
	Message   string     `json:"message"`
	StartTime time.Time  `json:"startTime"`
	EndTime   *time.Time `json:"endTime,omitempty"`
}

// Global download-task registry. The map and the fields of every *DownloadTask
// stored in it are mutated by background download goroutines, so all reads and
// writes must hold tasksMutex; hand out copies, never the stored pointers.
var (
	downloadTasks = make(map[string]*DownloadTask)
	tasksMutex    sync.RWMutex
)

// defaultOllamaBaseURL is the Ollama endpoint assumed when OLLAMA_BASE_URL is unset.
const defaultOllamaBaseURL = "http://host.docker.internal:11434"

// resolveOllamaBaseURL returns the OLLAMA_BASE_URL environment variable, or
// defaultOllamaBaseURL when it is unset or empty.
func resolveOllamaBaseURL() string {
	if baseURL := os.Getenv("OLLAMA_BASE_URL"); baseURL != "" {
		return baseURL
	}
	return defaultOllamaBaseURL
}

// InitializationHandler serves the system-initialization and model-management endpoints.
type InitializationHandler struct {
	config           *config.Config
	tenantService    interfaces.TenantService
	modelService     interfaces.ModelService
	kbService        interfaces.KnowledgeBaseService
	kbRepository     interfaces.KnowledgeBaseRepository
	knowledgeService interfaces.KnowledgeService
	ollamaService    *ollama.OllamaService
	docReaderClient  *client.Client
}

// NewInitializationHandler creates an InitializationHandler from its dependencies.
func NewInitializationHandler(
	config *config.Config,
	tenantService interfaces.TenantService,
	modelService interfaces.ModelService,
	kbService interfaces.KnowledgeBaseService,
	kbRepository interfaces.KnowledgeBaseRepository,
	knowledgeService interfaces.KnowledgeService,
	ollamaService *ollama.OllamaService,
	docReaderClient *client.Client,
) *InitializationHandler {
	return &InitializationHandler{
		config:           config,
		tenantService:    tenantService,
		modelService:     modelService,
		kbService:        kbService,
		kbRepository:     kbRepository,
		knowledgeService: knowledgeService,
		ollamaService:    ollamaService,
		docReaderClient:  docReaderClient,
	}
}

// InitializationRequest is the request body of the Initialize endpoint.
type InitializationRequest struct {
	// Storage type sent by the frontend: "cos" or "minio" (case-insensitive).
	StorageType string `json:"storageType"`
	LLM         struct {
		Source    string `json:"source" binding:"required"`
		ModelName string `json:"modelName" binding:"required"`
		BaseURL   string `json:"baseUrl"`
		APIKey    string `json:"apiKey"`
	} `json:"llm" binding:"required"`
	Embedding struct {
		Source    string `json:"source" binding:"required"`
		ModelName string `json:"modelName" binding:"required"`
		BaseURL   string `json:"baseUrl"`
		APIKey    string `json:"apiKey"`
		Dimension int    `json:"dimension"` // embedding vector dimension
	} `json:"embedding" binding:"required"`
	Rerank struct {
		Enabled   bool   `json:"enabled"`
		ModelName string `json:"modelName"`
		BaseURL   string `json:"baseUrl"`
		APIKey    string `json:"apiKey"`
	} `json:"rerank"`
	Multimodal struct {
		Enabled bool `json:"enabled"`
		VLM     *struct {
			ModelName     string `json:"modelName"`
			BaseURL       string `json:"baseUrl"`
			APIKey        string `json:"apiKey"`
			InterfaceType string `json:"interfaceType"` // "ollama" or "openai"
		} `json:"vlm,omitempty"`
		COS *struct {
			SecretID   string `json:"secretId"`
			SecretKey  string `json:"secretKey"`
			Region     string `json:"region"`
			BucketName string `json:"bucketName"`
			AppID      string `json:"appId"`
			PathPrefix string `json:"pathPrefix"`
		} `json:"cos,omitempty"`
		Minio *struct {
			BucketName string `json:"bucketName"`
			PathPrefix string `json:"pathPrefix"`
		} `json:"minio,omitempty"`
	} `json:"multimodal"`
	DocumentSplitting struct {
		ChunkSize    int      `json:"chunkSize" binding:"required,min=100,max=10000"`
		ChunkOverlap int      `json:"chunkOverlap" binding:"required,min=0"`
		Separators   []string `json:"separators" binding:"required,min=1"`
	} `json:"documentSplitting" binding:"required"`
}

// vlmConfigFromRequest converts the request's VLM section into a types.VLMConfig,
// returning the zero value when no VLM section was supplied.
func vlmConfigFromRequest(req *InitializationRequest) types.VLMConfig {
	vlm := req.Multimodal.VLM
	if vlm == nil {
		return types.VLMConfig{}
	}
	return types.VLMConfig{
		ModelName:     vlm.ModelName,
		BaseURL:       vlm.BaseURL,
		APIKey:        vlm.APIKey,
		InterfaceType: vlm.InterfaceType,
	}
}

// storageConfigFromRequest builds the storage configuration selected by
// req.StorageType. ok is false when the type is not "cos"/"minio" or the matching
// section is missing; callers then leave their existing configuration untouched.
// MinIO credentials are read from MINIO_ACCESS_KEY_ID / MINIO_SECRET_ACCESS_KEY.
func storageConfigFromRequest(req *InitializationRequest) (types.StorageConfig, bool) {
	switch req.StorageType {
	case "cos":
		if cos := req.Multimodal.COS; cos != nil {
			return types.StorageConfig{
				Provider:   req.StorageType,
				SecretID:   cos.SecretID,
				SecretKey:  cos.SecretKey,
				Region:     cos.Region,
				BucketName: cos.BucketName,
				AppID:      cos.AppID,
				PathPrefix: cos.PathPrefix,
			}, true
		}
	case "minio":
		if minio := req.Multimodal.Minio; minio != nil {
			return types.StorageConfig{
				Provider:   req.StorageType,
				BucketName: minio.BucketName,
				PathPrefix: minio.PathPrefix,
				SecretID:   os.Getenv("MINIO_ACCESS_KEY_ID"),
				SecretKey:  os.Getenv("MINIO_SECRET_ACCESS_KEY"),
			}, true
		}
	}
	return types.StorageConfig{}, false
}

// CheckStatus reports whether the system has been initialized: the default
// tenant must exist and have at least one model. Lookup failures are reported
// as "not initialized" (HTTP 200) rather than as errors.
func (h *InitializationHandler) CheckStatus(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Checking system initialization status")

	respond := func(initialized bool) {
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data": gin.H{
				"initialized": initialized,
			},
		})
	}

	tenant, err := h.tenantService.GetTenantByID(ctx, types.InitDefaultTenantID)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		respond(false)
		return
	}
	// No tenant means the system has not been initialized.
	if tenant == nil {
		logger.Info(ctx, "No tenants found, system not initialized")
		respond(false)
		return
	}

	ctx = context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)
	models, err := h.modelService.ListModels(ctx)
	if err != nil || len(models) == 0 {
		logger.Info(ctx, "No models found, system not initialized")
		respond(false)
		return
	}

	logger.Info(ctx, "System is already initialized")
	respond(true)
}

// Initialize performs (or re-applies) system initialization for the default
// tenant:
//  1. validates the request;
//  2. creates or updates the default tenant;
//  3. creates/updates one default model per type and deletes rerank/VLM models
//     whose feature is disabled;
//  4. creates or updates the default knowledge base.
//
// The embedding model of an existing knowledge base is only replaced while the
// knowledge base contains no files.
func (h *InitializationHandler) Initialize(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Starting system initialization")

	var req InitializationRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse initialization request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	// Normalize once so validation and persistence see the same value; the
	// persisted provider is matched against lower-case "cos"/"minio".
	req.StorageType = strings.ToLower(strings.TrimSpace(req.StorageType))

	// Validate the multimodal configuration.
	if req.Multimodal.Enabled {
		if req.Multimodal.VLM == nil {
			logger.Error(ctx, "Multimodal enabled but missing VLM configuration")
			c.Error(errors.NewBadRequestError("启用多模态时需要配置VLM信息"))
			return
		}
		if req.Multimodal.VLM.InterfaceType == "ollama" {
			// Ollama's OpenAI-compatible API lives under /v1. Fall back to the same
			// default CheckOllamaStatus reports instead of producing a bare "/v1".
			req.Multimodal.VLM.BaseURL = strings.TrimRight(resolveOllamaBaseURL(), "/") + "/v1"
		}
		if req.Multimodal.VLM.ModelName == "" || req.Multimodal.VLM.BaseURL == "" {
			logger.Error(ctx, "VLM configuration incomplete")
			c.Error(errors.NewBadRequestError("VLM配置不完整"))
			return
		}

		switch req.StorageType {
		case "cos":
			cos := req.Multimodal.COS
			if cos == nil || cos.SecretID == "" || cos.SecretKey == "" ||
				cos.Region == "" || cos.BucketName == "" || cos.AppID == "" {
				logger.Error(ctx, "COS configuration incomplete")
				c.Error(errors.NewBadRequestError("COS配置不完整"))
				return
			}
		case "minio":
			if req.Multimodal.Minio == nil || req.Multimodal.Minio.BucketName == "" ||
				os.Getenv("MINIO_ACCESS_KEY_ID") == "" || os.Getenv("MINIO_SECRET_ACCESS_KEY") == "" {
				logger.Error(ctx, "MinIO configuration incomplete")
				c.Error(errors.NewBadRequestError("MinIO配置不完整"))
				return
			}
		}
	}

	// Validate the rerank configuration (only when enabled).
	if req.Rerank.Enabled && (req.Rerank.ModelName == "" || req.Rerank.BaseURL == "") {
		logger.Error(ctx, "Rerank configuration incomplete")
		c.Error(errors.NewBadRequestError("启用Rerank时需要配置模型名称和Base URL"))
		return
	}

	// Validate the document-splitting configuration.
	if req.DocumentSplitting.ChunkOverlap >= req.DocumentSplitting.ChunkSize {
		logger.Error(ctx, "Chunk overlap must be less than chunk size")
		c.Error(errors.NewBadRequestError("分块重叠大小必须小于分块大小"))
		return
	}
	if len(req.DocumentSplitting.Separators) == 0 {
		logger.Error(ctx, "Document separators cannot be empty")
		c.Error(errors.NewBadRequestError("文档分隔符不能为空"))
		return
	}

	var err error

	// 1. Tenant: create it if missing, otherwise normalize its name/description.
	// A lookup error is ignored here; a nil tenant leads to creation.
	tenant, _ := h.tenantService.GetTenantByID(ctx, types.InitDefaultTenantID)
	if tenant == nil {
		logger.Info(ctx, "Tenant not found, creating tenant")
		tenant = &types.Tenant{
			ID:          types.InitDefaultTenantID,
			Name:        "Default Tenant",
			Description: "System Default Tenant",
			RetrieverEngines: types.RetrieverEngines{
				Engines: []types.RetrieverEngineParams{
					{
						RetrieverType:       types.KeywordsRetrieverType,
						RetrieverEngineType: types.PostgresRetrieverEngineType,
					},
					{
						RetrieverType:       types.VectorRetrieverType,
						RetrieverEngineType: types.PostgresRetrieverEngineType,
					},
				},
			},
		}
		logger.Info(ctx, "Creating default tenant")
		tenant, err = h.tenantService.CreateTenant(ctx, tenant)
		if err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("创建租户失败: " + err.Error()))
			return
		}
	} else {
		logger.Info(ctx, "Tenant exists, updating if needed")
		updated := false
		if tenant.Name != "Default Tenant" {
			tenant.Name = "Default Tenant"
			updated = true
		}
		if tenant.Description != "System Default Tenant" {
			tenant.Description = "System Default Tenant"
			updated = true
		}
		if updated {
			if _, err = h.tenantService.UpdateTenant(ctx, tenant); err != nil {
				logger.ErrorWithFields(ctx, err, nil)
				c.Error(errors.NewInternalServerError("更新租户失败: " + err.Error()))
				return
			}
			logger.Info(ctx, "Tenant updated successfully")
		}
	}

	// All tenant-scoped service calls below use this context.
	newCtx := context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)

	// 2. Models: update the existing model of each type, or create it.
	existingModels, err := h.modelService.ListModels(newCtx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		// Proceed as if no models exist; every model will be created.
		existingModels = []*types.Model{}
	}

	// Index existing models by type (a later model of the same type wins).
	modelMap := make(map[types.ModelType]*types.Model, len(existingModels))
	for _, model := range existingModels {
		modelMap[model.Type] = model
	}

	// modelSpec is the desired configuration for one model type.
	type modelSpec struct {
		modelType   types.ModelType
		name        string
		source      types.ModelSource
		description string
		baseURL     string
		apiKey      string
		dimension   int
	}

	modelsToProcess := []modelSpec{
		{
			modelType:   types.ModelTypeKnowledgeQA,
			name:        req.LLM.ModelName,
			source:      types.ModelSource(req.LLM.Source),
			description: "LLM Model for Knowledge QA",
			baseURL:     req.LLM.BaseURL,
			apiKey:      req.LLM.APIKey,
		},
		{
			modelType:   types.ModelTypeEmbedding,
			name:        req.Embedding.ModelName,
			source:      types.ModelSource(req.Embedding.Source),
			description: "Embedding Model",
			baseURL:     req.Embedding.BaseURL,
			apiKey:      req.Embedding.APIKey,
			dimension:   req.Embedding.Dimension,
		},
	}
	if req.Rerank.Enabled {
		modelsToProcess = append(modelsToProcess, modelSpec{
			modelType:   types.ModelTypeRerank,
			name:        req.Rerank.ModelName,
			source:      types.ModelSourceRemote,
			description: "Rerank Model",
			baseURL:     req.Rerank.BaseURL,
			apiKey:      req.Rerank.APIKey,
		})
	}
	if req.Multimodal.Enabled && req.Multimodal.VLM != nil {
		modelsToProcess = append(modelsToProcess, modelSpec{
			modelType:   types.ModelTypeVLLM,
			name:        req.Multimodal.VLM.ModelName,
			source:      types.ModelSourceRemote,
			description: "Vision Language Model",
			baseURL:     req.Multimodal.VLM.BaseURL,
			apiKey:      req.Multimodal.VLM.APIKey,
		})
	}

	processedModels := make([]*types.Model, 0, len(modelsToProcess))
	for _, spec := range modelsToProcess {
		params := types.ModelParameters{
			BaseURL: spec.baseURL,
			APIKey:  spec.apiKey,
			EmbeddingParameters: types.EmbeddingParameters{
				Dimension: spec.dimension,
			},
		}

		if existingModel, exists := modelMap[spec.modelType]; exists {
			logger.Infof(ctx, "Updating existing model: %s (%s)", spec.name, spec.modelType)
			existingModel.Name = spec.name
			existingModel.Source = spec.source
			existingModel.Description = spec.description
			existingModel.Parameters = params
			existingModel.IsDefault = true
			existingModel.Status = types.ModelStatusActive
			if err := h.modelService.UpdateModel(newCtx, existingModel); err != nil {
				logger.ErrorWithFields(ctx, err, map[string]interface{}{
					"model_name": spec.name,
					"model_type": spec.modelType,
				})
				c.Error(errors.NewInternalServerError("更新模型失败: " + err.Error()))
				return
			}
			processedModels = append(processedModels, existingModel)
			continue
		}

		logger.Infof(ctx, "Creating new model: %s (%s)", spec.name, spec.modelType)
		newModel := &types.Model{
			TenantID:    types.InitDefaultTenantID,
			Name:        spec.name,
			Type:        spec.modelType,
			Source:      spec.source,
			Description: spec.description,
			Parameters:  params,
			IsDefault:   true,
			Status:      types.ModelStatusActive,
		}
		if err := h.modelService.CreateModel(newCtx, newModel); err != nil {
			logger.ErrorWithFields(ctx, err, map[string]interface{}{
				"model_name": spec.name,
				"model_type": spec.modelType,
			})
			c.Error(errors.NewInternalServerError("创建模型失败: " + err.Error()))
			return
		}
		processedModels = append(processedModels, newModel)
	}

	// Remove models whose feature is disabled. Deletion failures are logged
	// but do not abort initialization.
	if !req.Multimodal.Enabled {
		if existingVLM, exists := modelMap[types.ModelTypeVLLM]; exists {
			logger.Info(ctx, "Deleting VLM model as multimodal is disabled")
			if err := h.modelService.DeleteModel(newCtx, existingVLM.ID); err != nil {
				logger.ErrorWithFields(ctx, err, map[string]interface{}{
					"model_id": existingVLM.ID,
				})
				logger.Warn(ctx, "Failed to delete VLM model, but continuing")
			}
		}
	}
	if !req.Rerank.Enabled {
		if existingRerank, exists := modelMap[types.ModelTypeRerank]; exists {
			logger.Info(ctx, "Deleting Rerank model as rerank is disabled")
			if err := h.modelService.DeleteModel(newCtx, existingRerank.ID); err != nil {
				logger.ErrorWithFields(ctx, err, map[string]interface{}{
					"model_id": existingRerank.ID,
				})
				logger.Warn(ctx, "Failed to delete Rerank model, but continuing")
			}
		}
	}

	// 3. Knowledge base: create it if missing, otherwise update it.
	// Collect the IDs of the models processed above. rerankModelID and vlmModelID
	// stay empty when the corresponding feature is disabled.
	var embeddingModelID, llmModelID, rerankModelID, vlmModelID string
	for _, model := range processedModels {
		switch model.Type {
		case types.ModelTypeEmbedding:
			embeddingModelID = model.ID
		case types.ModelTypeKnowledgeQA:
			llmModelID = model.ID
		case types.ModelTypeRerank:
			if req.Rerank.Enabled {
				rerankModelID = model.ID
			}
		case types.ModelTypeVLLM:
			vlmModelID = model.ID
		}
	}

	chunkingConfig := types.ChunkingConfig{
		ChunkSize:        req.DocumentSplitting.ChunkSize,
		ChunkOverlap:     req.DocumentSplitting.ChunkOverlap,
		Separators:       req.DocumentSplitting.Separators,
		EnableMultimodal: req.Multimodal.Enabled,
	}

	// A lookup error is ignored here; a nil knowledge base leads to creation.
	kb, _ := h.kbService.GetKnowledgeBaseByID(newCtx, types.InitDefaultKnowledgeBaseID)
	if kb == nil {
		logger.Info(ctx, "Creating default knowledge base")
		kb = &types.KnowledgeBase{
			ID:               types.InitDefaultKnowledgeBaseID,
			Name:             "Default Knowledge Base",
			Description:      "System Default Knowledge Base",
			TenantID:         types.InitDefaultTenantID,
			ChunkingConfig:   chunkingConfig,
			EmbeddingModelID: embeddingModelID,
			SummaryModelID:   llmModelID,
			RerankModelID:    rerankModelID,
			VLMModelID:       vlmModelID,
		}
		// VLM and storage settings only apply with multimodal enabled; req.Multimodal.VLM
		// may be nil otherwise (dereferencing it unconditionally used to panic).
		if req.Multimodal.Enabled {
			kb.VLMConfig = vlmConfigFromRequest(&req)
			if storage, ok := storageConfigFromRequest(&req); ok {
				kb.StorageConfig = storage
			}
		}

		if _, err = h.kbService.CreateKnowledgeBase(newCtx, kb); err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("创建知识库失败: " + err.Error()))
			return
		}
	} else {
		logger.Info(ctx, "Updating existing knowledge base")

		// The embedding model may not be changed once the knowledge base has files
		// (presumably because stored vectors came from the current model).
		knowledgeList, err := h.knowledgeService.ListKnowledgeByKnowledgeBaseID(
			newCtx, types.InitDefaultKnowledgeBaseID,
		)
		hasFiles := err == nil && len(knowledgeList) > 0

		kb.SummaryModelID = llmModelID
		kb.RerankModelID = rerankModelID // empty when rerank is disabled

		if req.Multimodal.Enabled {
			kb.VLMModelID = vlmModelID
			kb.VLMConfig = vlmConfigFromRequest(&req)
			if storage, ok := storageConfigFromRequest(&req); ok {
				kb.StorageConfig = storage
			}
		} else {
			// Multimodal disabled: clear VLM model, VLM config and storage config.
			kb.VLMModelID = ""
			kb.VLMConfig = types.VLMConfig{}
			kb.StorageConfig = types.StorageConfig{}
		}

		if !hasFiles {
			kb.EmbeddingModelID = embeddingModelID
		}
		kb.ChunkingConfig = chunkingConfig

		// Persist basic information and configuration.
		if err = h.kbRepository.UpdateKnowledgeBase(newCtx, kb); err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("更新知识库配置失败: " + err.Error()))
			return
		}

		// While the knowledge base has no files, re-read it and write the model IDs
		// again through the repository. (kb.SummaryModelID already equals llmModelID
		// at this point, so hasFiles is the only deciding condition.)
		if !hasFiles {
			kb, err = h.kbService.GetKnowledgeBaseByID(newCtx, types.InitDefaultKnowledgeBaseID)
			if err != nil {
				logger.ErrorWithFields(ctx, err, nil)
				c.Error(errors.NewInternalServerError("获取更新后的知识库失败: " + err.Error()))
				return
			}

			kb.SummaryModelID = llmModelID
			kb.RerankModelID = rerankModelID // empty when rerank is disabled

			if err = h.kbRepository.UpdateKnowledgeBase(newCtx, kb); err != nil {
				logger.ErrorWithFields(ctx, err, nil)
				c.Error(errors.NewInternalServerError("更新知识库模型ID失败: " + err.Error()))
				return
			}
			logger.Info(ctx, "Model IDs updated successfully")
		}
	}

	logger.Info(ctx, "System initialization completed successfully")
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "系统初始化成功",
		"data": gin.H{
			"tenant":         tenant,
			"models":         processedModels,
			"knowledge_base": kb,
		},
	})
}

// CheckOllamaStatus reports whether the Ollama service can be started or reached,
// together with its version and the base URL shown to the user. Failures are
// reported as available=false with HTTP 200.
func (h *InitializationHandler) CheckOllamaStatus(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Checking Ollama service status")

	// Base URL reported to the client for display.
	baseURL := resolveOllamaBaseURL()

	if err := h.ollamaService.StartService(ctx); err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data": gin.H{
				"available": false,
				"error":     err.Error(),
				"baseUrl":   baseURL,
			},
		})
		return
	}

	version, err := h.ollamaService.GetVersion(ctx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		version = "unknown"
	}

	logger.Info(ctx, "Ollama service is available")
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"available": h.ollamaService.IsAvailable(),
			"version":   version,
			"baseUrl":   baseURL,
		},
	})
}

// CheckOllamaModels reports, for each requested model name, whether it is
// installed in Ollama. Names without a tag are checked as "<name>:latest";
// a per-model lookup error is reported as unavailable.
func (h *InitializationHandler) CheckOllamaModels(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Checking Ollama models status")

	var req struct {
		Models []string `json:"models" binding:"required"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse models check request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	if !h.ollamaService.IsAvailable() {
		if err := h.ollamaService.StartService(ctx); err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("Ollama服务不可用: " + err.Error()))
			return
		}
	}

	modelStatus := make(map[string]bool, len(req.Models))
	for _, modelName := range req.Models {
		checkModelName := modelName
		if !strings.Contains(modelName, ":") {
			checkModelName = modelName + ":latest"
		}

		available, err := h.ollamaService.IsModelAvailable(ctx, checkModelName)
		if err != nil {
			logger.ErrorWithFields(ctx, err, map[string]interface{}{
				"model_name": modelName,
			})
			modelStatus[modelName] = false
		} else {
			modelStatus[modelName] = available
		}

		logger.Infof(ctx, "Model %s availability: %v", modelName, modelStatus[modelName])
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"models": modelStatus,
		},
	})
}

// DownloadOllamaModel starts an asynchronous Ollama model download and returns
// its task ID. If the model is already installed, it reports completion. If a
// pending or downloading task for the same model exists, that task is returned
// instead of starting a new one.
func (h *InitializationHandler) DownloadOllamaModel(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Starting async Ollama model download")

	var req struct {
		ModelName string `json:"modelName" binding:"required"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse model download request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	if !h.ollamaService.IsAvailable() {
		if err := h.ollamaService.StartService(ctx); err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("Ollama服务不可用: " + err.Error()))
			return
		}
	}

	available, err := h.ollamaService.IsModelAvailable(ctx, req.ModelName)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_name": req.ModelName,
		})
		c.Error(errors.NewInternalServerError("检查模型状态失败: " + err.Error()))
		return
	}
	if available {
		logger.Infof(ctx, "Model %s already exists", req.ModelName)
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"message": "模型已存在",
			"data": gin.H{
				"modelName": req.ModelName,
				"status":    "completed",
				"progress":  100.0,
			},
		})
		return
	}

	// Check for an in-flight task and register a new one under a single write
	// lock, so concurrent requests for the same model cannot both create tasks.
	tasksMutex.Lock()
	for _, existing := range downloadTasks {
		if existing.ModelName == req.ModelName &&
			(existing.Status == "pending" || existing.Status == "downloading") {
			// Copy while locked: the download goroutine mutates the task concurrently.
			snapshot := *existing
			tasksMutex.Unlock()
			c.JSON(http.StatusOK, gin.H{
				"success": true,
				"message": "模型下载任务已存在",
				"data": gin.H{
					"taskId":    snapshot.ID,
					"modelName": snapshot.ModelName,
					"status":    snapshot.Status,
					"progress":  snapshot.Progress,
				},
			})
			return
		}
	}
	taskID := uuid.New().String()
	downloadTasks[taskID] = &DownloadTask{
		ID:        taskID,
		ModelName: req.ModelName,
		Status:    "pending",
		Progress:  0.0,
		Message:   "准备下载",
		StartTime: time.Now(),
	}
	tasksMutex.Unlock()

	// The download outlives this request, so it runs on a detached context
	// bounded to 12 hours.
	downloadCtx, cancel := context.WithTimeout(context.Background(), 12*time.Hour)
	go func() {
		defer cancel()
		h.downloadModelAsync(downloadCtx, taskID, req.ModelName)
	}()

	logger.Infof(ctx, "Created download task for model: %s, task ID: %s", req.ModelName, taskID)
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "模型下载任务已创建",
		"data": gin.H{
			"taskId":    taskID,
			"modelName": req.ModelName,
			"status":    "pending",
			"progress":  0.0,
		},
	})
}

// GetDownloadProgress returns a snapshot of the download task named by the
// "taskId" path parameter.
func (h *InitializationHandler) GetDownloadProgress(c *gin.Context) {
	taskID := c.Param("taskId")
	if taskID == "" {
		c.Error(errors.NewBadRequestError("任务ID不能为空"))
		return
	}

	// Copy under the lock; the live task is mutated by the download goroutine.
	tasksMutex.RLock()
	task, exists := downloadTasks[taskID]
	var snapshot DownloadTask
	if exists {
		snapshot = *task
	}
	tasksMutex.RUnlock()

	if !exists {
		c.Error(errors.NewNotFoundError("下载任务不存在"))
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    snapshot,
	})
}

// ListDownloadTasks returns snapshots of all known download tasks
// (in map iteration order, i.e. unordered).
func (h *InitializationHandler) ListDownloadTasks(c *gin.Context) {
	// Copy under the lock; the live tasks are mutated by download goroutines.
	tasksMutex.RLock()
	tasks := make([]DownloadTask, 0, len(downloadTasks))
	for _, task := range downloadTasks {
		tasks = append(tasks, *task)
	}
	tasksMutex.RUnlock()

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    tasks,
	})
}

// ListOllamaModels lists the models installed in Ollama, starting the service
// first if it is not available.
func (h *InitializationHandler) ListOllamaModels(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Listing installed Ollama models")

	if !h.ollamaService.IsAvailable() {
		if err := h.ollamaService.StartService(ctx); err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("Ollama服务不可用: " + err.Error()))
			return
		}
	}

	models, err := h.ollamaService.ListModels(ctx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError("获取模型列表失败: " + err.Error()))
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"models": models,
		},
	})
}

// downloadModelAsync runs a model download for taskID and records its progress
// and final status ("completed" or "failed") in the task registry.
func (h *InitializationHandler) downloadModelAsync(ctx context.Context,
	taskID, modelName string,
) {
	logger.Infof(ctx, "Starting async download for model: %s, task: %s", modelName, taskID)

	h.updateTaskStatus(taskID, "downloading", 0.0, "开始下载模型")

	err := h.pullModelWithProgress(ctx, modelName, func(progress float64, message string) {
		h.updateTaskStatus(taskID, "downloading", progress, message)
	})
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_name": modelName,
			"task_id":    taskID,
		})
		h.updateTaskStatus(taskID, "failed", 0.0, fmt.Sprintf("下载失败: %v", err))
		return
	}

	logger.Infof(ctx, "Model %s downloaded successfully, task: %s", modelName, taskID)
	h.updateTaskStatus(taskID, "completed", 100.0, "下载完成")
}

// pullModelWithProgress pulls modelName through the Ollama client and reports
// progress (percent, message) via progressCallback. If the model is already
// installed, it reports 100% and returns immediately.
func (h *InitializationHandler) pullModelWithProgress(ctx context.Context,
	modelName string,
	progressCallback func(float64, string),
) error {
	if err := h.ollamaService.StartService(ctx); err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		return err
	}

	available, err := h.ollamaService.IsModelAvailable(ctx, modelName)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_name": modelName,
		})
		return err
	}
	if available {
		progressCallback(100.0, "模型已存在")
		return nil
	}

	logger.GetLogger(ctx).Infof("Pulling model %s...", modelName)

	pullReq := &api.PullRequest{
		Name: modelName,
	}

	err = h.ollamaService.GetClient().Pull(ctx, pullReq, func(progress api.ProgressResponse) error {
		// Percentage is only known once the layer size and some bytes are
		// reported; otherwise surface Ollama's status text with 0%.
		progressPercent := 0.0
		message := "下载中"
		if progress.Total > 0 && progress.Completed > 0 {
			progressPercent = float64(progress.Completed) / float64(progress.Total) * 100
			message = fmt.Sprintf("下载中: %.1f%% (%s)", progressPercent, progress.Status)
		} else if progress.Status != "" {
			message = progress.Status
		}

		progressCallback(progressPercent, message)
		logger.Infof(ctx, "Download progress for %s: %.2f%% - %s",
			modelName, progressPercent, message,
		)
		return nil
	})
	if err != nil {
		return fmt.Errorf("failed to pull model: %w", err)
	}

	return nil
}

// updateTaskStatus updates a registered task under tasksMutex. Terminal
// statuses ("completed", "failed") also record EndTime. Unknown IDs are ignored.
func (h *InitializationHandler) updateTaskStatus(
	taskID, status string, progress float64, message string,
) {
	tasksMutex.Lock()
	defer tasksMutex.Unlock()

	task, exists := downloadTasks[taskID]
	if !exists {
		return
	}
	task.Status = status
	task.Progress = progress
	task.Message = message
	if status == "completed" || status == "failed" {
		now := time.Now()
		task.EndTime = &now
	}
}

// GetCurrentConfig returns the current model, knowledge-base and storage
// configuration of the default tenant, plus whether the default knowledge base
// already contains files.
func (h *InitializationHandler) GetCurrentConfig(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Getting current system configuration")

	newCtx := context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)

	models, err := h.modelService.ListModels(newCtx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError("获取模型列表失败: " + err.Error()))
		return
	}

	kb, err := h.kbService.GetKnowledgeBaseByID(newCtx, types.InitDefaultKnowledgeBaseID)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError("获取知识库信息失败: " + err.Error()))
		return
	}

	// Fetching a single page entry is enough to learn whether any files exist.
	knowledgeList, err := h.knowledgeService.ListPagedKnowledgeByKnowledgeBaseID(newCtx,
		types.InitDefaultKnowledgeBaseID, &types.Pagination{
			Page:     1,
			PageSize: 1,
		})
	hasFiles := err == nil && knowledgeList != nil && knowledgeList.Total > 0

	configData := buildConfigResponse(models, kb, hasFiles)

	logger.Info(ctx, "Current system configuration retrieved successfully")
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    configData,
	})
}

// buildConfigResponse assembles the configuration payload returned by
// GetCurrentConfig. Rerank and multimodal default to disabled when no model of
// that type exists. kb may be nil, in which case knowledge-base-derived fields
// (VLM interface type, document splitting, storage) are omitted.
func buildConfigResponse(models []*types.Model,
	kb *types.KnowledgeBase, hasFiles bool,
) map[string]interface{} {
	resp := map[string]interface{}{
		"hasFiles": hasFiles,
	}

	for _, model := range models {
		switch model.Type {
		case types.ModelTypeKnowledgeQA:
			resp["llm"] = map[string]interface{}{
				"source":    string(model.Source),
				"modelName": model.Name,
				"baseUrl":   model.Parameters.BaseURL,
				"apiKey":    model.Parameters.APIKey,
			}
		case types.ModelTypeEmbedding:
			resp["embedding"] = map[string]interface{}{
				"source":    string(model.Source),
				"modelName": model.Name,
				"baseUrl":   model.Parameters.BaseURL,
				"apiKey":    model.Parameters.APIKey,
				"dimension": model.Parameters.EmbeddingParameters.Dimension,
			}
		case types.ModelTypeRerank:
			resp["rerank"] = map[string]interface{}{
				"enabled":   true,
				"modelName": model.Name,
				"baseUrl":   model.Parameters.BaseURL,
				"apiKey":    model.Parameters.APIKey,
			}
		case types.ModelTypeVLLM:
			if resp["multimodal"] == nil {
				resp["multimodal"] = map[string]interface{}{
					"enabled": true,
				}
			}
			interfaceType := ""
			if kb != nil {
				interfaceType = kb.VLMConfig.InterfaceType
			}
			multimodal := resp["multimodal"].(map[string]interface{})
			multimodal["vlm"] = map[string]interface{}{
				"modelName":     model.Name,
				"baseUrl":       model.Parameters.BaseURL,
				"apiKey":        model.Parameters.APIKey,
				"interfaceType": interfaceType,
			}
		}
	}

	// No VLM model: multimodal is disabled.
	if resp["multimodal"] == nil {
		resp["multimodal"] = map[string]interface{}{
			"enabled": false,
		}
	}

	// No rerank model: rerank is disabled.
	if resp["rerank"] == nil {
		resp["rerank"] = map[string]interface{}{
			"enabled":   false,
			"modelName": "",
			"baseUrl":   "",
			"apiKey":    "",
		}
	}

	if kb == nil {
		return resp
	}

	resp["documentSplitting"] = map[string]interface{}{
		"chunkSize":    kb.ChunkingConfig.ChunkSize,
		"chunkOverlap": kb.ChunkingConfig.ChunkOverlap,
		"separators":   kb.ChunkingConfig.Separators,
	}

	// Storage settings are only reported when credentials are configured.
	// resp["multimodal"] is always set by this point.
	if kb.StorageConfig.SecretID != "" {
		multimodal := resp["multimodal"].(map[string]interface{})
		multimodal["storageType"] = kb.StorageConfig.Provider
		switch kb.StorageConfig.Provider {
		case "cos":
			multimodal["cos"] = map[string]interface{}{
				"secretId":   kb.StorageConfig.SecretID,
				"secretKey":  kb.StorageConfig.SecretKey,
				"region":     kb.StorageConfig.Region,
				"bucketName": kb.StorageConfig.BucketName,
				"appId":      kb.StorageConfig.AppID,
				"pathPrefix": kb.StorageConfig.PathPrefix,
			}
		case "minio":
			multimodal["minio"] = map[string]interface{}{
				"bucketName": kb.StorageConfig.BucketName,
				"pathPrefix": kb.StorageConfig.PathPrefix,
			}
		}
	}

	return resp
}

// RemoteModelCheckRequest is the request body of the CheckRemoteModel endpoint.
type RemoteModelCheckRequest struct {
	ModelName string `json:"modelName" binding:"required"`
	BaseURL   string `json:"baseUrl" binding:"required"`
	APIKey    string `json:"apiKey"`
}

// CheckRemoteModel checks connectivity to a remote (OpenAI-style) model API.
// The outcome is reported in the response body (available/message) with HTTP 200.
func (h *InitializationHandler) CheckRemoteModel(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Checking remote model connection")

	var req RemoteModelCheckRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse remote model check request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	if req.ModelName == "" || req.BaseURL == "" {
		logger.Error(ctx, "Model name and base URL are required")
		c.Error(errors.NewBadRequestError("模型名称和Base URL不能为空"))
		return
	}

	// Throwaway model definition used only for the connectivity probe; the
	// type is a placeholder since the check does not depend on it.
	modelConfig := &types.Model{
		Name:   req.ModelName,
		Source: "remote",
		Parameters: types.ModelParameters{
			BaseURL: req.BaseURL,
			APIKey:  req.APIKey,
		},
		Type: "llm",
	}

	available, message := h.checkRemoteModelConnection(ctx, modelConfig)

	logger.Infof(ctx,
		"Remote model check completed: modelName=%s, baseUrl=%s, available=%v, message=%s",
		req.ModelName, req.BaseURL, available, message,
	)

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"available": available,
			"message":   message,
		},
	})
}

// TestEmbeddingModel checks that an embedding endpoint (local or remote) works
// by embedding a short sample text. It reports the resulting vector dimension.
// Failures are reported in the response body with HTTP 200.
func (h *InitializationHandler) TestEmbeddingModel(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Testing embedding model connectivity and functionality")

	var req struct {
		Source    string `json:"source" binding:"required"`
		ModelName string `json:"modelName" binding:"required"`
		BaseURL   string `json:"baseUrl"`
		APIKey    string `json:"apiKey"`
		Dimension int    `json:"dimension"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse embedding test request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	cfg := embedding.Config{
		Source:               types.ModelSource(strings.ToLower(req.Source)),
		BaseURL:              req.BaseURL,
		ModelName:            req.ModelName,
		APIKey:               req.APIKey,
		TruncatePromptTokens: 256,
		Dimensions:           req.Dimension,
		ModelID:              "",
	}

	emb, err := embedding.NewEmbedder(cfg)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{"model": req.ModelName})
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data":    gin.H{`available`: false, `message`: fmt.Sprintf("创建Embedder失败: %v", err), `dimension`: 0},
		})
		return
	}

	// Issue one minimal embedding call.
	sample := "hello"
	vec, err := emb.Embed(ctx, sample)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{"model": req.ModelName})
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data":    gin.H{`available`: false, `message`: fmt.Sprintf("调用Embedding失败: %v", err), `dimension`: 0},
		})
		return
	}

	logger.Infof(ctx, "Embedding test succeeded, dim=%d", len(vec))
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    gin.H{`available`: true, `message`: fmt.Sprintf("测试成功,向量维度=%d", len(vec)), `dimension`: len(vec)},
	})
}

// checkRemoteModelConnection probes a remote model API by issuing GET
// <BaseURL>/models and interpreting the HTTP status.
func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
	model *types.Model,
) (bool, string) {
	// A simple request is sent to verify connectivity and authentication.
	client := &http.Client{
		Timeout: 10 * time.Second,
	}

	// Build the test request.
	testEndpoint := ""
	if model.Parameters.BaseURL != "" {
		testEndpoint = model.Parameters.BaseURL + "/models"
	}

	req, err := http.NewRequestWithContext(ctx, "GET", testEndpoint, nil)
	if err != nil {
		return false, fmt.Sprintf("创建请求失败: %v", err)
	}

	// Add the authentication header.
	if model.Parameters.APIKey != "" {
		req.Header.Set("Authorization", "Bearer "+model.Parameters.APIKey)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return false, fmt.Sprintf("连接失败: %v", err)
	}
	defer resp.Body.Close()

	// Check the response status.
	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
		// Connected; this does not verify that the model itself exists.
		return true, "连接正常,请自行确保模型存在"
	} else if resp.StatusCode == 401 {
		return false,
"认证失败,请检查API Key" } else if resp.StatusCode == 403 { return false, "权限不足,请检查API Key权限" } else if resp.StatusCode == 404 { return false, "API端点不存在,请检查Base URL" } else { return false, fmt.Sprintf("API返回错误状态: %d", resp.StatusCode) } } // checkModelExistence 检查指定模型是否在模型列表中存在 func (h *InitializationHandler) checkModelExistence(ctx context.Context, resp *http.Response, modelName string) (bool, string) { body, err := io.ReadAll(resp.Body) if err != nil { return true, "连接正常,但无法验证模型列表" } var modelsResp struct { Data []struct { ID string `json:"id"` Object string `json:"object"` } `json:"data"` Object string `json:"object"` } // 尝试解析模型列表响应 if err := json.Unmarshal(body, &modelsResp); err != nil { // 如果无法解析,可能是非标准API,只要连接成功就认为可用 return true, "连接正常" } // 检查模型是否在列表中 for _, model := range modelsResp.Data { if model.ID == modelName { return true, "连接正常,模型存在" } } // 模型不在列表中,返回可用的模型建议 if len(modelsResp.Data) > 0 { availableModels := make([]string, 0, min(3, len(modelsResp.Data))) for i, model := range modelsResp.Data { if i >= 3 { break } availableModels = append(availableModels, model.ID) } return false, fmt.Sprintf("模型 '%s' 不存在,可用模型: %s", modelName, strings.Join(availableModels, ", ")) } return false, fmt.Sprintf("模型 '%s' 不存在", modelName) } // min returns the minimum of two integers func min(a, b int) int { if a < b { return a } return b } // checkRerankModelConnection 检查Rerank模型连接和功能的内部方法 func (h *InitializationHandler) checkRerankModelConnection(ctx context.Context, modelName, baseURL, apiKey string) (bool, string) { client := &http.Client{ Timeout: 15 * time.Second, } // 构造重排API端点 rerankEndpoint := baseURL + "/rerank" // Mock测试数据 testQuery := "什么是人工智能?" 
testPassages := []string{ "机器学习是人工智能的一个子领域,专注于算法和统计模型,使计算机系统能够通过经验自动改进。", "深度学习是机器学习的一个子集,使用人工神经网络来模拟人脑的工作方式。", } // 构造重排请求 rerankRequest := map[string]interface{}{ "model": modelName, "query": testQuery, "documents": testPassages, "truncate_prompt_tokens": 512, } jsonData, err := json.Marshal(rerankRequest) if err != nil { return false, fmt.Sprintf("构造请求失败: %v", err) } logger.Infof(ctx, "Rerank request: %s, modelName=%s, baseURL=%s, apiKey=%s", string(jsonData), modelName, baseURL, apiKey) req, err := http.NewRequestWithContext( ctx, "POST", rerankEndpoint, strings.NewReader(string(jsonData)), ) if err != nil { return false, fmt.Sprintf("创建请求失败: %v", err) } // 添加认证头 if apiKey != "" { req.Header.Set("Authorization", "Bearer "+apiKey) } req.Header.Set("Content-Type", "application/json") resp, err := client.Do(req) if err != nil { return false, fmt.Sprintf("连接失败: %v", err) } defer resp.Body.Close() // 读取响应 body, err := io.ReadAll(resp.Body) if err != nil { return false, fmt.Sprintf("读取响应失败: %v", err) } // 检查响应状态 if resp.StatusCode >= 200 && resp.StatusCode < 300 { // 尝试解析重排响应 var rerankResp struct { Results []struct { Index int `json:"index"` Document string `json:"document"` RelevanceScore float64 `json:"relevance_score"` } `json:"results"` } if err := json.Unmarshal(body, &rerankResp); err != nil { // 如果无法解析标准重排响应,检查是否有其他格式 return true, "连接正常,但响应格式非标准" } // 检查是否返回了重排结果 if len(rerankResp.Results) > 0 { return true, fmt.Sprintf("重排功能正常,返回%d个结果", len(rerankResp.Results)) } else { return false, "重排接口连接成功,但未返回重排结果" } } else if resp.StatusCode == 401 { return false, "认证失败,请检查API Key" } else if resp.StatusCode == 403 { return false, "权限不足,请检查API Key权限" } else if resp.StatusCode == 404 { return false, "重排API端点不存在,请检查Base URL" } else if resp.StatusCode == 422 { return false, fmt.Sprintf("请求参数错误: %s", string(body)) } else { return false, fmt.Sprintf("API返回错误状态: %d, 响应: %s", resp.StatusCode, string(body)) } } // CheckRerankModel 检查Rerank模型连接和功能 func (h *InitializationHandler) 
CheckRerankModel(c *gin.Context) { ctx := c.Request.Context() logger.Info(ctx, "Checking rerank model connection and functionality") var req struct { ModelName string `json:"modelName" binding:"required"` BaseURL string `json:"baseUrl" binding:"required"` APIKey string `json:"apiKey"` } if err := c.ShouldBindJSON(&req); err != nil { logger.Error(ctx, "Failed to parse rerank model check request", err) c.Error(errors.NewBadRequestError(err.Error())) return } // 验证请求参数 if req.ModelName == "" || req.BaseURL == "" { logger.Error(ctx, "Model name and base URL are required") c.Error(errors.NewBadRequestError("模型名称和Base URL不能为空")) return } // 检查Rerank模型连接和功能 available, message := h.checkRerankModelConnection( ctx, req.ModelName, req.BaseURL, req.APIKey, ) logger.Info(ctx, fmt.Sprintf("Rerank model check completed: modelName=%s, baseUrl=%s, available=%v, message=%s", req.ModelName, req.BaseURL, available, message, ), ) c.JSON(http.StatusOK, gin.H{ "success": true, "data": gin.H{ "available": available, "message": message, }, }) } // 使用结构体解析表单数据 type testMultimodalForm struct { VLMModel string `form:"vlm_model"` VLMBaseURL string `form:"vlm_base_url"` VLMAPIKey string `form:"vlm_api_key"` VLMInterfaceType string `form:"vlm_interface_type"` StorageType string `form:"storage_type"` // COS 配置 COSSecretID string `form:"cos_secret_id"` COSSecretKey string `form:"cos_secret_key"` COSRegion string `form:"cos_region"` COSBucketName string `form:"cos_bucket_name"` COSAppID string `form:"cos_app_id"` COSPathPrefix string `form:"cos_path_prefix"` // MinIO 配置(当存储为 minio 时) MinioBucketName string `form:"minio_bucket_name"` MinioPathPrefix string `form:"minio_path_prefix"` // 文档切分配置(字符串后续自行解析,以避免类型绑定失败) ChunkSize string `form:"chunk_size"` ChunkOverlap string `form:"chunk_overlap"` SeparatorsRaw string `form:"separators"` } // TestMultimodalFunction 测试多模态功能 func (h *InitializationHandler) TestMultimodalFunction(c *gin.Context) { ctx := c.Request.Context() logger.Info(ctx, "Testing 
multimodal functionality") var req testMultimodalForm if err := c.ShouldBind(&req); err != nil { logger.Error(ctx, "Failed to parse form data", err) c.Error(errors.NewBadRequestError("表单参数解析失败")) return } // ollama 场景自动拼接 base url if req.VLMInterfaceType == "ollama" { req.VLMBaseURL = os.Getenv("OLLAMA_BASE_URL") + "/v1" } req.StorageType = strings.ToLower(req.StorageType) if req.VLMModel == "" || req.VLMBaseURL == "" { logger.Error(ctx, "VLM model name and base URL are required") c.Error(errors.NewBadRequestError("VLM模型名称和Base URL不能为空")) return } switch req.StorageType { case "cos": logger.Infof(ctx, "COS config: Region=%s, Bucket=%s, App=%s, Prefix=%s", req.COSRegion, req.COSBucketName, req.COSAppID, req.COSPathPrefix) // 必填:SecretID/SecretKey/Region/BucketName/AppID;PathPrefix 可选 if req.COSSecretID == "" || req.COSSecretKey == "" || req.COSRegion == "" || req.COSBucketName == "" || req.COSAppID == "" { logger.Error(ctx, "COS configuration is required") c.Error(errors.NewBadRequestError("COS配置信息不能为空")) return } case "minio": logger.Infof(ctx, "MinIO config: Bucket=%s, PathPrefix=%s", req.MinioBucketName, req.MinioPathPrefix) if req.MinioBucketName == "" { logger.Error(ctx, "MinIO configuration is required") c.Error(errors.NewBadRequestError("MinIO配置信息不能为空")) return } default: logger.Error(ctx, "Invalid storage type") c.Error(errors.NewBadRequestError("无效的存储类型")) return } logger.Infof(ctx, "VLM config: Model=%s, URL=%s, HasKey=%v, Type=%s", req.VLMModel, req.VLMBaseURL, req.VLMAPIKey != "", req.VLMInterfaceType) // 获取上传的图片文件 file, header, err := c.Request.FormFile("image") if err != nil { logger.Error(ctx, "Failed to get uploaded image", err) c.Error(errors.NewBadRequestError("获取上传图片失败")) return } defer file.Close() // 验证文件类型 if !strings.HasPrefix(header.Header.Get("Content-Type"), "image/") { logger.Error(ctx, "Invalid file type, only images are allowed") c.Error(errors.NewBadRequestError("只允许上传图片文件")) return } // 验证文件大小 (10MB) if header.Size > 10*1024*1024 { 
logger.Error(ctx, "File size too large") c.Error(errors.NewBadRequestError("图片文件大小不能超过10MB")) return } logger.Infof(ctx, "Processing image: %s, size: %d bytes", header.Filename, header.Size) // 解析文档分割配置 chunkSize, err := strconv.Atoi(req.ChunkSize) if err != nil || chunkSize < 100 || chunkSize > 10000 { chunkSize = 1000 } chunkOverlap, err := strconv.Atoi(req.ChunkOverlap) if err != nil || chunkOverlap < 0 || chunkOverlap >= chunkSize { chunkOverlap = 200 } var separators []string if req.SeparatorsRaw != "" { if err := json.Unmarshal([]byte(req.SeparatorsRaw), &separators); err != nil { separators = []string{"\n\n", "\n", "。", "!", "?", ";", ";"} } } else { separators = []string{"\n\n", "\n", "。", "!", "?", ";", ";"} } // 读取图片文件内容 imageContent, err := io.ReadAll(file) if err != nil { logger.Error(ctx, "Failed to read image file", err) c.Error(errors.NewBadRequestError("读取图片文件失败")) return } // 调用多模态测试 startTime := time.Now() result, err := h.testMultimodalWithDocReader( ctx, imageContent, header.Filename, chunkSize, chunkOverlap, separators, &req, ) processingTime := time.Since(startTime).Milliseconds() if err != nil { logger.ErrorWithFields(ctx, err, map[string]interface{}{ "vlm_model": req.VLMModel, "vlm_base_url": req.VLMBaseURL, "filename": header.Filename, }) c.JSON(http.StatusOK, gin.H{ "success": true, "data": gin.H{ "success": false, "message": err.Error(), "processing_time": processingTime, }, }) return } logger.Info(ctx, fmt.Sprintf("Multimodal test completed successfully in %dms", processingTime)) c.JSON(http.StatusOK, gin.H{ "success": true, "data": gin.H{ "success": true, "caption": result["caption"], "ocr": result["ocr"], "processing_time": processingTime, }, }) } // testMultimodalWithDocReader 调用docreader服务进行多模态处理 func (h *InitializationHandler) testMultimodalWithDocReader( ctx context.Context, imageContent []byte, filename string, chunkSize, chunkOverlap int, separators []string, req *testMultimodalForm, ) (map[string]string, error) { // 获取文件扩展名 
fileExt := "" if idx := strings.LastIndex(filename, "."); idx != -1 { fileExt = strings.ToLower(filename[idx+1:]) } // 检查docreader服务配置 if h.docReaderClient == nil { return nil, fmt.Errorf("DocReader service not configured") } // 构造请求 request := &proto.ReadFromFileRequest{ FileContent: imageContent, FileName: filename, FileType: fileExt, ReadConfig: &proto.ReadConfig{ ChunkSize: int32(chunkSize), ChunkOverlap: int32(chunkOverlap), Separators: separators, EnableMultimodal: true, // 启用多模态处理 VlmConfig: &proto.VLMConfig{ ModelName: req.VLMModel, BaseUrl: req.VLMBaseURL, ApiKey: req.VLMAPIKey, InterfaceType: req.VLMInterfaceType, }, }, RequestId: ctx.Value(types.RequestIDContextKey).(string), } // 设置对象存储配置(通用) switch strings.ToLower(req.StorageType) { case "cos": request.ReadConfig.StorageConfig = &proto.StorageConfig{ Provider: proto.StorageProvider_COS, Region: req.COSRegion, BucketName: req.COSBucketName, AccessKeyId: req.COSSecretID, SecretAccessKey: req.COSSecretKey, AppId: req.COSAppID, PathPrefix: req.COSPathPrefix, } case "minio": request.ReadConfig.StorageConfig = &proto.StorageConfig{ Provider: proto.StorageProvider_MINIO, BucketName: req.MinioBucketName, PathPrefix: req.MinioPathPrefix, AccessKeyId: os.Getenv("MINIO_ACCESS_KEY_ID"), SecretAccessKey: os.Getenv("MINIO_SECRET_ACCESS_KEY"), } } // 调用docreader服务 response, err := h.docReaderClient.ReadFromFile(ctx, request) if err != nil { return nil, fmt.Errorf("调用DocReader服务失败: %v", err) } if response.Error != "" { return nil, fmt.Errorf("DocReader服务返回错误: %s", response.Error) } // 处理响应,提取Caption和OCR信息 result := make(map[string]string) var allCaptions, allOCRTexts []string for _, chunk := range response.Chunks { if len(chunk.Images) > 0 { for _, image := range chunk.Images { if image.Caption != "" { allCaptions = append(allCaptions, image.Caption) } if image.OcrText != "" { allOCRTexts = append(allOCRTexts, image.OcrText) } } } } // 合并所有Caption和OCR结果 result["caption"] = strings.Join(allCaptions, "; ") result["ocr"] 
= strings.Join(allOCRTexts, "; ") return result, nil }