l_ai_knowledge/internal/handler/initialization.go

1839 lines
53 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package handler
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"sync"
"time"
"strconv"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/ollama/ollama/api"
"knowlege-lsxd/internal/config"
"knowlege-lsxd/internal/errors"
"knowlege-lsxd/internal/logger"
"knowlege-lsxd/internal/models/embedding"
"knowlege-lsxd/internal/models/utils/ollama"
"knowlege-lsxd/internal/types"
"knowlege-lsxd/internal/types/interfaces"
"knowlege-lsxd/services/docreader/src/client"
"knowlege-lsxd/services/docreader/src/proto"
)
// DownloadTask describes one asynchronous Ollama model download tracked in
// the package-level downloadTasks registry and reported to the frontend.
type DownloadTask struct {
	ID        string     `json:"id"`        // task identifier (UUID)
	ModelName string     `json:"modelName"` // Ollama model being pulled
	Status    string     `json:"status"`    // pending, downloading, completed, failed
	Progress  float64    `json:"progress"`  // download progress, 0-100 percent
	Message   string     `json:"message"`   // human-readable status detail
	StartTime time.Time  `json:"startTime"` // when the task was created
	EndTime   *time.Time `json:"endTime,omitempty"` // set once status becomes completed or failed
}
// Global download task manager: registry of in-flight and finished model
// download tasks, guarded by tasksMutex for concurrent handler access.
var (
	downloadTasks = make(map[string]*DownloadTask)
	tasksMutex    sync.RWMutex
)
// InitializationHandler serves the system-initialization endpoints: first-run
// setup, configuration retrieval, and Ollama service/model management.
type InitializationHandler struct {
	config           *config.Config                      // application configuration
	tenantService    interfaces.TenantService            // default tenant create/update
	modelService     interfaces.ModelService             // model CRUD within a tenant
	kbService        interfaces.KnowledgeBaseService     // knowledge-base service layer
	kbRepository     interfaces.KnowledgeBaseRepository  // direct KB persistence (bypasses service checks)
	knowledgeService interfaces.KnowledgeService         // used to detect existing files in a KB
	ollamaService    *ollama.OllamaService               // local Ollama lifecycle and model pulls
	docReaderClient  *client.Client                      // document reader gRPC client
}
// NewInitializationHandler wires up an InitializationHandler with all the
// services it needs. The first parameter is renamed from "config" to "cfg"
// so it no longer shadows the imported config package inside the function.
func NewInitializationHandler(
	cfg *config.Config,
	tenantService interfaces.TenantService,
	modelService interfaces.ModelService,
	kbService interfaces.KnowledgeBaseService,
	kbRepository interfaces.KnowledgeBaseRepository,
	knowledgeService interfaces.KnowledgeService,
	ollamaService *ollama.OllamaService,
	docReaderClient *client.Client,
) *InitializationHandler {
	return &InitializationHandler{
		config:           cfg,
		tenantService:    tenantService,
		modelService:     modelService,
		kbService:        kbService,
		kbRepository:     kbRepository,
		knowledgeService: knowledgeService,
		ollamaService:    ollamaService,
		docReaderClient:  docReaderClient,
	}
}
// InitializationRequest is the payload for the one-shot system setup
// endpoint: model choices, optional rerank/multimodal settings, object
// storage, and document-splitting parameters.
type InitializationRequest struct {
	// Storage backend selected by the frontend: "cos" or "minio".
	StorageType string `json:"storageType"`
	// LLM used for knowledge QA.
	LLM struct {
		Source    string `json:"source" binding:"required"`
		ModelName string `json:"modelName" binding:"required"`
		BaseURL   string `json:"baseUrl"`
		APIKey    string `json:"apiKey"`
	} `json:"llm" binding:"required"`
	// Embedding model used for vector indexing.
	Embedding struct {
		Source    string `json:"source" binding:"required"`
		ModelName string `json:"modelName" binding:"required"`
		BaseURL   string `json:"baseUrl"`
		APIKey    string `json:"apiKey"`
		Dimension int    `json:"dimension"` // embedding vector dimension
	} `json:"embedding" binding:"required"`
	// Optional rerank model; ModelName and BaseURL are required when Enabled.
	Rerank struct {
		Enabled   bool   `json:"enabled"`
		ModelName string `json:"modelName"`
		BaseURL   string `json:"baseUrl"`
		APIKey    string `json:"apiKey"`
	} `json:"rerank"`
	// Optional multimodal (vision) support: VLM plus object storage config.
	Multimodal struct {
		Enabled bool `json:"enabled"`
		VLM *struct {
			ModelName     string `json:"modelName"`
			BaseURL       string `json:"baseUrl"`
			APIKey        string `json:"apiKey"`
			InterfaceType string `json:"interfaceType"` // "ollama" or "openai"
		} `json:"vlm,omitempty"`
		COS *struct {
			SecretID   string `json:"secretId"`
			SecretKey  string `json:"secretKey"`
			Region     string `json:"region"`
			BucketName string `json:"bucketName"`
			AppID      string `json:"appId"`
			PathPrefix string `json:"pathPrefix"`
		} `json:"cos,omitempty"`
		// MinIO credentials come from the environment, not this struct.
		Minio *struct {
			BucketName string `json:"bucketName"`
			PathPrefix string `json:"pathPrefix"`
		} `json:"minio,omitempty"`
	} `json:"multimodal"`
	// Chunking parameters applied to the default knowledge base.
	DocumentSplitting struct {
		ChunkSize    int      `json:"chunkSize" binding:"required,min=100,max=10000"`
		ChunkOverlap int      `json:"chunkOverlap" binding:"required,min=0"`
		Separators   []string `json:"separators" binding:"required,min=1"`
	} `json:"documentSplitting" binding:"required"`
}
// CheckStatus reports whether the system has been initialized: the default
// tenant must exist and at least one model must be registered under it.
func (h *InitializationHandler) CheckStatus(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Checking system initialization status")

	// All three outcomes share the same response shape.
	respond := func(initialized bool) {
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data": gin.H{
				"initialized": initialized,
			},
		})
	}

	// A lookup failure or a missing default tenant both mean "not initialized".
	tenant, err := h.tenantService.GetTenantByID(ctx, types.InitDefaultTenantID)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		respond(false)
		return
	}
	if tenant == nil {
		logger.Info(ctx, "No tenants found, system not initialized")
		respond(false)
		return
	}

	// Model listing is scoped to the default tenant via the context value.
	ctx = context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)
	models, err := h.modelService.ListModels(ctx)
	if err != nil || len(models) == 0 {
		logger.Info(ctx, "No models found, system not initialized")
		respond(false)
		return
	}

	logger.Info(ctx, "System is already initialized")
	respond(true)
}
// Initialize performs (or re-applies) first-time system setup: it validates
// the request, then creates/updates the default tenant, the configured
// models (LLM, embedding, optional rerank and VLM), and the default
// knowledge base.
//
// Fix: the knowledge-base creation path previously read req.Multimodal.VLM
// unconditionally when building VLMConfig, panicking with a nil pointer
// dereference whenever multimodal was disabled (VLM omitted from the
// request). The VLM config is now only copied when a VLM section exists.
func (h *InitializationHandler) Initialize(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Starting system initialization")
	var req InitializationRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse initialization request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}
	// Validate the multimodal configuration.
	if req.Multimodal.Enabled {
		storageType := strings.ToLower(req.StorageType)
		if req.Multimodal.VLM == nil {
			logger.Error(ctx, "Multimodal enabled but missing VLM configuration")
			c.Error(errors.NewBadRequestError("启用多模态时需要配置VLM信息"))
			return
		}
		// Ollama-backed VLMs are reached through the OpenAI-compatible /v1 path.
		if req.Multimodal.VLM.InterfaceType == "ollama" {
			req.Multimodal.VLM.BaseURL = os.Getenv("OLLAMA_BASE_URL") + "/v1"
		}
		if req.Multimodal.VLM.ModelName == "" || req.Multimodal.VLM.BaseURL == "" {
			logger.Error(ctx, "VLM configuration incomplete")
			c.Error(errors.NewBadRequestError("VLM配置不完整"))
			return
		}
		switch storageType {
		case "cos":
			if req.Multimodal.COS == nil || req.Multimodal.COS.SecretID == "" || req.Multimodal.COS.SecretKey == "" ||
				req.Multimodal.COS.Region == "" || req.Multimodal.COS.BucketName == "" ||
				req.Multimodal.COS.AppID == "" {
				logger.Error(ctx, "COS configuration incomplete")
				c.Error(errors.NewBadRequestError("COS配置不完整"))
				return
			}
		case "minio":
			// MinIO credentials come from the environment, not the request body.
			if req.Multimodal.Minio == nil || req.Multimodal.Minio.BucketName == "" ||
				os.Getenv("MINIO_ACCESS_KEY_ID") == "" || os.Getenv("MINIO_SECRET_ACCESS_KEY") == "" {
				logger.Error(ctx, "MinIO configuration incomplete")
				c.Error(errors.NewBadRequestError("MinIO配置不完整"))
				return
			}
		}
	}
	// Validate the rerank configuration when enabled.
	if req.Rerank.Enabled {
		if req.Rerank.ModelName == "" || req.Rerank.BaseURL == "" {
			logger.Error(ctx, "Rerank configuration incomplete")
			c.Error(errors.NewBadRequestError("启用Rerank时需要配置模型名称和Base URL"))
			return
		}
	}
	// Validate the document splitting configuration.
	if req.DocumentSplitting.ChunkOverlap >= req.DocumentSplitting.ChunkSize {
		logger.Error(ctx, "Chunk overlap must be less than chunk size")
		c.Error(errors.NewBadRequestError("分块重叠大小必须小于分块大小"))
		return
	}
	if len(req.DocumentSplitting.Separators) == 0 {
		logger.Error(ctx, "Document separators cannot be empty")
		c.Error(errors.NewBadRequestError("文档分隔符不能为空"))
		return
	}
	var err error
	// 1. Tenant: create the default tenant if absent, otherwise normalize
	// its name/description.
	tenant, _ := h.tenantService.GetTenantByID(ctx, types.InitDefaultTenantID)
	if tenant == nil {
		logger.Info(ctx, "Tenant not found, creating tenant")
		tenant = &types.Tenant{
			ID:          types.InitDefaultTenantID,
			Name:        "Default Tenant",
			Description: "System Default Tenant",
			RetrieverEngines: types.RetrieverEngines{
				Engines: []types.RetrieverEngineParams{
					{
						RetrieverType:       types.KeywordsRetrieverType,
						RetrieverEngineType: types.PostgresRetrieverEngineType,
					},
					{
						RetrieverType:       types.VectorRetrieverType,
						RetrieverEngineType: types.PostgresRetrieverEngineType,
					},
				},
			},
		}
		logger.Info(ctx, "Creating default tenant")
		tenant, err = h.tenantService.CreateTenant(ctx, tenant)
		if err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("创建租户失败: " + err.Error()))
			return
		}
	} else {
		logger.Info(ctx, "Tenant exists, updating if needed")
		// Only write back when something actually changed.
		updated := false
		if tenant.Name != "Default Tenant" {
			tenant.Name = "Default Tenant"
			updated = true
		}
		if tenant.Description != "System Default Tenant" {
			tenant.Description = "System Default Tenant"
			updated = true
		}
		if updated {
			_, err = h.tenantService.UpdateTenant(ctx, tenant)
			if err != nil {
				logger.ErrorWithFields(ctx, err, nil)
				c.Error(errors.NewInternalServerError("更新租户失败: " + err.Error()))
				return
			}
			logger.Info(ctx, "Tenant updated successfully")
		}
	}
	// All subsequent service calls run within the default tenant's scope.
	newCtx := context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)
	// 2. Models: update the existing model of each type, or create it.
	existingModels, err := h.modelService.ListModels(newCtx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		// Listing failed; fall through and treat every model as new.
		existingModels = []*types.Model{}
	}
	// Index existing models by type (one default model per type).
	modelMap := make(map[types.ModelType]*types.Model)
	for _, model := range existingModels {
		modelMap[model.Type] = model
	}
	// modelSpec describes one model to create or update (replaces the
	// thrice-repeated anonymous struct literal of the original).
	type modelSpec struct {
		modelType   types.ModelType
		name        string
		source      types.ModelSource
		description string
		baseURL     string
		apiKey      string
		dimension   int
	}
	modelsToProcess := []modelSpec{
		{
			modelType:   types.ModelTypeKnowledgeQA,
			name:        req.LLM.ModelName,
			source:      types.ModelSource(req.LLM.Source),
			description: "LLM Model for Knowledge QA",
			baseURL:     req.LLM.BaseURL,
			apiKey:      req.LLM.APIKey,
		},
		{
			modelType:   types.ModelTypeEmbedding,
			name:        req.Embedding.ModelName,
			source:      types.ModelSource(req.Embedding.Source),
			description: "Embedding Model",
			baseURL:     req.Embedding.BaseURL,
			apiKey:      req.Embedding.APIKey,
			dimension:   req.Embedding.Dimension,
		},
	}
	// Add the rerank model when enabled.
	if req.Rerank.Enabled {
		modelsToProcess = append(modelsToProcess, modelSpec{
			modelType:   types.ModelTypeRerank,
			name:        req.Rerank.ModelName,
			source:      types.ModelSourceRemote,
			description: "Rerank Model",
			baseURL:     req.Rerank.BaseURL,
			apiKey:      req.Rerank.APIKey,
		})
	}
	// Add the VLM when multimodal is enabled.
	if req.Multimodal.Enabled && req.Multimodal.VLM != nil {
		modelsToProcess = append(modelsToProcess, modelSpec{
			modelType:   types.ModelTypeVLLM,
			name:        req.Multimodal.VLM.ModelName,
			source:      types.ModelSourceRemote,
			description: "Vision Language Model",
			baseURL:     req.Multimodal.VLM.BaseURL,
			apiKey:      req.Multimodal.VLM.APIKey,
		})
	}
	// Create or update each configured model.
	var processedModels []*types.Model
	for _, modelConfig := range modelsToProcess {
		existingModel, exists := modelMap[modelConfig.modelType]
		if exists {
			logger.Infof(ctx, "Updating existing model: %s (%s)",
				modelConfig.name, modelConfig.modelType,
			)
			existingModel.Name = modelConfig.name
			existingModel.Source = modelConfig.source
			existingModel.Description = modelConfig.description
			existingModel.Parameters = types.ModelParameters{
				BaseURL: modelConfig.baseURL,
				APIKey:  modelConfig.apiKey,
				EmbeddingParameters: types.EmbeddingParameters{
					Dimension: modelConfig.dimension,
				},
			}
			existingModel.IsDefault = true
			existingModel.Status = types.ModelStatusActive
			err := h.modelService.UpdateModel(newCtx, existingModel)
			if err != nil {
				logger.ErrorWithFields(ctx, err, map[string]interface{}{
					"model_name": modelConfig.name,
					"model_type": modelConfig.modelType,
				})
				c.Error(errors.NewInternalServerError("更新模型失败: " + err.Error()))
				return
			}
			processedModels = append(processedModels, existingModel)
		} else {
			logger.Infof(ctx,
				"Creating new model: %s (%s)",
				modelConfig.name, modelConfig.modelType,
			)
			newModel := &types.Model{
				TenantID:    types.InitDefaultTenantID,
				Name:        modelConfig.name,
				Type:        modelConfig.modelType,
				Source:      modelConfig.source,
				Description: modelConfig.description,
				Parameters: types.ModelParameters{
					BaseURL: modelConfig.baseURL,
					APIKey:  modelConfig.apiKey,
					EmbeddingParameters: types.EmbeddingParameters{
						Dimension: modelConfig.dimension,
					},
				},
				IsDefault: true,
				Status:    types.ModelStatusActive,
			}
			err := h.modelService.CreateModel(newCtx, newModel)
			if err != nil {
				logger.ErrorWithFields(ctx, err, map[string]interface{}{
					"model_name": modelConfig.name,
					"model_type": modelConfig.modelType,
				})
				c.Error(errors.NewInternalServerError("创建模型失败: " + err.Error()))
				return
			}
			processedModels = append(processedModels, newModel)
		}
	}
	// Drop the stale VLM model when multimodal is disabled (best-effort).
	if !req.Multimodal.Enabled {
		if existingVLM, exists := modelMap[types.ModelTypeVLLM]; exists {
			logger.Info(ctx, "Deleting VLM model as multimodal is disabled")
			err := h.modelService.DeleteModel(newCtx, existingVLM.ID)
			if err != nil {
				logger.ErrorWithFields(ctx, err, map[string]interface{}{
					"model_id": existingVLM.ID,
				})
				// Log but do not abort initialization.
				logger.Warn(ctx, "Failed to delete VLM model, but continuing")
			}
		}
	}
	// Drop the stale rerank model when rerank is disabled (best-effort).
	if !req.Rerank.Enabled {
		if existingRerank, exists := modelMap[types.ModelTypeRerank]; exists {
			logger.Info(ctx, "Deleting Rerank model as rerank is disabled")
			err := h.modelService.DeleteModel(newCtx, existingRerank.ID)
			if err != nil {
				logger.ErrorWithFields(ctx, err, map[string]interface{}{
					"model_id": existingRerank.ID,
				})
				// Log but do not abort initialization.
				logger.Warn(ctx, "Failed to delete Rerank model, but continuing")
			}
		}
	}
	// 3. Knowledge base: create the default KB if absent, else update it.
	kb, err := h.kbService.GetKnowledgeBaseByID(newCtx, types.InitDefaultKnowledgeBaseID)
	// Resolve the IDs of the models we just processed, by type.
	var embeddingModelID, llmModelID, rerankModelID, vlmModelID string
	for _, model := range processedModels {
		if model.Type == types.ModelTypeEmbedding {
			embeddingModelID = model.ID
		}
		if model.Type == types.ModelTypeKnowledgeQA {
			llmModelID = model.ID
		}
		if model.Type == types.ModelTypeRerank && req.Rerank.Enabled {
			rerankModelID = model.ID
		}
		if model.Type == types.ModelTypeVLLM {
			vlmModelID = model.ID
		}
	}
	if kb == nil {
		// Create the default knowledge base.
		logger.Info(ctx, "Creating default knowledge base")
		kb = &types.KnowledgeBase{
			ID:          types.InitDefaultKnowledgeBaseID,
			Name:        "Default Knowledge Base",
			Description: "System Default Knowledge Base",
			TenantID:    types.InitDefaultTenantID,
			ChunkingConfig: types.ChunkingConfig{
				ChunkSize:        req.DocumentSplitting.ChunkSize,
				ChunkOverlap:     req.DocumentSplitting.ChunkOverlap,
				Separators:       req.DocumentSplitting.Separators,
				EnableMultimodal: req.Multimodal.Enabled,
			},
			EmbeddingModelID: embeddingModelID,
			SummaryModelID:   llmModelID,
			RerankModelID:    rerankModelID,
			VLMModelID:       vlmModelID,
		}
		// BUG FIX: only copy the VLM settings when a VLM section is present;
		// the original unconditional access panicked when multimodal was
		// disabled and req.Multimodal.VLM was nil.
		if req.Multimodal.VLM != nil {
			kb.VLMConfig = types.VLMConfig{
				ModelName:     req.Multimodal.VLM.ModelName,
				BaseURL:       req.Multimodal.VLM.BaseURL,
				APIKey:        req.Multimodal.VLM.APIKey,
				InterfaceType: req.Multimodal.VLM.InterfaceType,
			}
		}
		switch req.StorageType {
		case "cos":
			if req.Multimodal.COS != nil {
				kb.StorageConfig = types.StorageConfig{
					Provider:   req.StorageType,
					BucketName: req.Multimodal.COS.BucketName,
					AppID:      req.Multimodal.COS.AppID,
					PathPrefix: req.Multimodal.COS.PathPrefix,
					SecretID:   req.Multimodal.COS.SecretID,
					SecretKey:  req.Multimodal.COS.SecretKey,
					Region:     req.Multimodal.COS.Region,
				}
			}
		case "minio":
			if req.Multimodal.Minio != nil {
				kb.StorageConfig = types.StorageConfig{
					Provider:   req.StorageType,
					BucketName: req.Multimodal.Minio.BucketName,
					PathPrefix: req.Multimodal.Minio.PathPrefix,
					SecretID:   os.Getenv("MINIO_ACCESS_KEY_ID"),
					SecretKey:  os.Getenv("MINIO_SECRET_ACCESS_KEY"),
				}
			}
		}
		_, err = h.kbService.CreateKnowledgeBase(newCtx, kb)
		if err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("创建知识库失败: " + err.Error()))
			return
		}
	} else {
		// Update the existing knowledge base.
		logger.Info(ctx, "Updating existing knowledge base")
		// If the KB already contains files, the embedding model must not change
		// (existing vectors would become inconsistent).
		knowledgeList, err := h.knowledgeService.ListKnowledgeByKnowledgeBaseID(
			newCtx, types.InitDefaultKnowledgeBaseID,
		)
		hasFiles := err == nil && len(knowledgeList) > 0
		// Update model IDs directly on the object.
		kb.SummaryModelID = llmModelID
		if req.Rerank.Enabled {
			kb.RerankModelID = rerankModelID
		} else {
			kb.RerankModelID = "" // clear the rerank model ID
		}
		if req.Multimodal.Enabled {
			// Validation above guarantees req.Multimodal.VLM != nil here.
			kb.VLMModelID = vlmModelID
			kb.VLMConfig = types.VLMConfig{
				ModelName:     req.Multimodal.VLM.ModelName,
				BaseURL:       req.Multimodal.VLM.BaseURL,
				APIKey:        req.Multimodal.VLM.APIKey,
				InterfaceType: req.Multimodal.VLM.InterfaceType,
			}
			switch req.StorageType {
			case "cos":
				if req.Multimodal.COS != nil {
					kb.StorageConfig = types.StorageConfig{
						Provider:   req.StorageType,
						SecretID:   req.Multimodal.COS.SecretID,
						SecretKey:  req.Multimodal.COS.SecretKey,
						Region:     req.Multimodal.COS.Region,
						BucketName: req.Multimodal.COS.BucketName,
						AppID:      req.Multimodal.COS.AppID,
						PathPrefix: req.Multimodal.COS.PathPrefix,
					}
				}
			case "minio":
				if req.Multimodal.Minio != nil {
					kb.StorageConfig = types.StorageConfig{
						Provider:   req.StorageType,
						BucketName: req.Multimodal.Minio.BucketName,
						PathPrefix: req.Multimodal.Minio.PathPrefix,
						SecretID:   os.Getenv("MINIO_ACCESS_KEY_ID"),
						SecretKey:  os.Getenv("MINIO_SECRET_ACCESS_KEY"),
					}
				}
			}
		} else {
			// Multimodal disabled: clear VLM and storage configuration.
			kb.VLMModelID = ""
			kb.VLMConfig = types.VLMConfig{}
			kb.StorageConfig = types.StorageConfig{}
		}
		if !hasFiles {
			kb.EmbeddingModelID = embeddingModelID
		}
		kb.ChunkingConfig = types.ChunkingConfig{
			ChunkSize:        req.DocumentSplitting.ChunkSize,
			ChunkOverlap:     req.DocumentSplitting.ChunkOverlap,
			Separators:       req.DocumentSplitting.Separators,
			EnableMultimodal: req.Multimodal.Enabled,
		}
		// Persist base info and configuration via the repository.
		err = h.kbRepository.UpdateKnowledgeBase(newCtx, kb)
		if err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("更新知识库配置失败: " + err.Error()))
			return
		}
		// NOTE(review): kb.SummaryModelID was already set to llmModelID above,
		// so the second operand is always false and this branch effectively
		// runs only when !hasFiles — preserved as-is to avoid behavior change.
		if !hasFiles || kb.SummaryModelID != llmModelID {
			// Reload the KB to work on the freshest copy.
			kb, err = h.kbService.GetKnowledgeBaseByID(newCtx, types.InitDefaultKnowledgeBaseID)
			if err != nil {
				logger.ErrorWithFields(ctx, err, nil)
				c.Error(errors.NewInternalServerError("获取更新后的知识库失败: " + err.Error()))
				return
			}
			// Re-apply the model IDs.
			kb.SummaryModelID = llmModelID
			if req.Rerank.Enabled {
				kb.RerankModelID = rerankModelID
			} else {
				kb.RerankModelID = "" // clear the rerank model ID
			}
			// Persist the model IDs directly through the repository.
			err = h.kbRepository.UpdateKnowledgeBase(newCtx, kb)
			if err != nil {
				logger.ErrorWithFields(ctx, err, nil)
				c.Error(errors.NewInternalServerError("更新知识库模型ID失败: " + err.Error()))
				return
			}
			logger.Info(ctx, "Model IDs updated successfully")
		}
	}
	logger.Info(ctx, "System initialization completed successfully")
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "系统初始化成功",
		"data": gin.H{
			"tenant":         tenant,
			"models":         processedModels,
			"knowledge_base": kb,
		},
	})
}
// CheckOllamaStatus reports whether the Ollama service is reachable, along
// with its version and the base URL the backend uses to reach it.
func (h *InitializationHandler) CheckOllamaStatus(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Checking Ollama service status")

	// Resolve the base URL shown to the caller; fall back to the default
	// docker-host address when the environment does not specify one.
	displayURL := os.Getenv("OLLAMA_BASE_URL")
	if displayURL == "" {
		displayURL = "http://host.docker.internal:11434"
	}

	// Attempt to (re)start the service; a failure means it is unavailable.
	if err := h.ollamaService.StartService(ctx); err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data": gin.H{
				"available": false,
				"error":     err.Error(),
				"baseUrl":   displayURL,
			},
		})
		return
	}

	// Version is informational only; report "unknown" on failure.
	version, err := h.ollamaService.GetVersion(ctx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		version = "unknown"
	}

	logger.Info(ctx, "Ollama service is available")
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"available": h.ollamaService.IsAvailable(),
			"version":   version,
			"baseUrl":   displayURL,
		},
	})
}
// CheckOllamaModels reports, for each requested model name, whether that
// model is already present in the local Ollama installation.
func (h *InitializationHandler) CheckOllamaModels(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Checking Ollama models status")

	var req struct {
		Models []string `json:"models" binding:"required"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse models check request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	// Make sure the service is reachable before querying model state.
	if !h.ollamaService.IsAvailable() {
		if err := h.ollamaService.StartService(ctx); err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("Ollama服务不可用: " + err.Error()))
			return
		}
	}

	modelStatus := make(map[string]bool, len(req.Models))
	for _, name := range req.Models {
		// Untagged names are checked against the implicit ":latest" tag.
		qualified := name
		if !strings.Contains(name, ":") {
			qualified = name + ":latest"
		}
		ok, err := h.ollamaService.IsModelAvailable(ctx, qualified)
		if err != nil {
			logger.ErrorWithFields(ctx, err, map[string]interface{}{
				"model_name": name,
			})
			ok = false
		}
		modelStatus[name] = ok
		logger.Infof(ctx, "Model %s availability: %v", name, modelStatus[name])
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"models": modelStatus,
		},
	})
}
// DownloadOllamaModel kicks off an asynchronous pull of an Ollama model and
// returns a task ID for progress polling. If the model already exists, or an
// active task for it is already registered, that state is returned instead.
//
// Fix: the original checked for a duplicate active task under an RLock and
// then inserted the new task under a separate write lock, so two concurrent
// requests for the same model could both pass the check and each create a
// task. The check and insert now happen atomically under one write lock.
func (h *InitializationHandler) DownloadOllamaModel(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Starting async Ollama model download")
	var req struct {
		ModelName string `json:"modelName" binding:"required"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse model download request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}
	// Make sure the Ollama service is reachable.
	if !h.ollamaService.IsAvailable() {
		err := h.ollamaService.StartService(ctx)
		if err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("Ollama服务不可用: " + err.Error()))
			return
		}
	}
	// Short-circuit when the model is already installed.
	available, err := h.ollamaService.IsModelAvailable(ctx, req.ModelName)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_name": req.ModelName,
		})
		c.Error(errors.NewInternalServerError("检查模型状态失败: " + err.Error()))
		return
	}
	if available {
		logger.Infof(ctx, "Model %s already exists", req.ModelName)
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"message": "模型已存在",
			"data": gin.H{
				"modelName": req.ModelName,
				"status":    "completed",
				"progress":  100.0,
			},
		})
		return
	}
	// Atomically look for an existing active task for this model and, if
	// none exists, register a new one — all under a single write lock so
	// concurrent requests cannot both create a task for the same model.
	tasksMutex.Lock()
	for _, existing := range downloadTasks {
		if existing.ModelName == req.ModelName &&
			(existing.Status == "pending" || existing.Status == "downloading") {
			// Copy the fields we respond with while still holding the lock.
			id, name, status, progress := existing.ID, existing.ModelName, existing.Status, existing.Progress
			tasksMutex.Unlock()
			c.JSON(http.StatusOK, gin.H{
				"success": true,
				"message": "模型下载任务已存在",
				"data": gin.H{
					"taskId":    id,
					"modelName": name,
					"status":    status,
					"progress":  progress,
				},
			})
			return
		}
	}
	taskID := uuid.New().String()
	task := &DownloadTask{
		ID:        taskID,
		ModelName: req.ModelName,
		Status:    "pending",
		Progress:  0.0,
		Message:   "准备下载",
		StartTime: time.Now(),
	}
	downloadTasks[taskID] = task
	tasksMutex.Unlock()
	// Run the download in the background with a generous 12-hour budget,
	// detached from the request context so it survives the response.
	newCtx, cancel := context.WithTimeout(context.Background(), 12*time.Hour)
	go func() {
		defer cancel()
		h.downloadModelAsync(newCtx, taskID, req.ModelName)
	}()
	logger.Infof(ctx, "Created download task for model: %s, task ID: %s", req.ModelName, taskID)
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "模型下载任务已创建",
		"data": gin.H{
			"taskId":    taskID,
			"modelName": req.ModelName,
			"status":    "pending",
			"progress":  0.0,
		},
	})
}
// GetDownloadProgress returns the current state of one download task,
// looked up by the taskId path parameter.
func (h *InitializationHandler) GetDownloadProgress(c *gin.Context) {
	taskID := c.Param("taskId")
	if taskID == "" {
		c.Error(errors.NewBadRequestError("任务ID不能为空"))
		return
	}

	// Read the task under the shared lock.
	tasksMutex.RLock()
	task, found := downloadTasks[taskID]
	tasksMutex.RUnlock()

	if !found {
		c.Error(errors.NewNotFoundError("下载任务不存在"))
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    task,
	})
}
// ListDownloadTasks returns a snapshot of every known download task.
func (h *InitializationHandler) ListDownloadTasks(c *gin.Context) {
	// Collect the task pointers under the read lock.
	tasksMutex.RLock()
	snapshot := make([]*DownloadTask, 0, len(downloadTasks))
	for _, t := range downloadTasks {
		snapshot = append(snapshot, t)
	}
	tasksMutex.RUnlock()

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    snapshot,
	})
}
// ListOllamaModels returns the models currently installed in the local
// Ollama instance.
func (h *InitializationHandler) ListOllamaModels(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Listing installed Ollama models")

	// Start the service if it is not yet reachable.
	if !h.ollamaService.IsAvailable() {
		if err := h.ollamaService.StartService(ctx); err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("Ollama服务不可用: " + err.Error()))
			return
		}
	}

	installed, err := h.ollamaService.ListModels(ctx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError("获取模型列表失败: " + err.Error()))
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"models": installed,
		},
	})
}
// downloadModelAsync runs a model pull to completion, recording progress and
// the final state ("completed" or "failed") in the shared task registry.
func (h *InitializationHandler) downloadModelAsync(ctx context.Context,
	taskID, modelName string,
) {
	logger.Infof(ctx, "Starting async download for model: %s, task: %s", modelName, taskID)

	// Mark the task as started.
	h.updateTaskStatus(taskID, "downloading", 0.0, "开始下载模型")

	// Stream progress updates into the task record while pulling.
	onProgress := func(progress float64, message string) {
		h.updateTaskStatus(taskID, "downloading", progress, message)
	}
	if err := h.pullModelWithProgress(ctx, modelName, onProgress); err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_name": modelName,
			"task_id":    taskID,
		})
		h.updateTaskStatus(taskID, "failed", 0.0, fmt.Sprintf("下载失败: %v", err))
		return
	}

	logger.Infof(ctx, "Model %s downloaded successfully, task: %s", modelName, taskID)
	h.updateTaskStatus(taskID, "completed", 100.0, "下载完成")
}
// pullModelWithProgress pulls an Ollama model, invoking progressCallback
// with a percentage (0-100) and a status message as the download advances.
// It returns nil immediately when the model is already present locally.
func (h *InitializationHandler) pullModelWithProgress(ctx context.Context,
	modelName string,
	progressCallback func(float64, string),
) error {
	// The service must be running before any API call.
	if err := h.ollamaService.StartService(ctx); err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		return err
	}

	// Skip the pull entirely when the model already exists.
	available, err := h.ollamaService.IsModelAvailable(ctx, modelName)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_name": modelName,
		})
		return err
	}
	if available {
		progressCallback(100.0, "模型已存在")
		return nil
	}

	logger.GetLogger(ctx).Infof("Pulling model %s...", modelName)

	pullReq := &api.PullRequest{Name: modelName}
	// Translate each Ollama progress event into a percentage + message.
	onProgress := func(progress api.ProgressResponse) error {
		percent := 0.0
		message := "下载中"
		switch {
		case progress.Total > 0 && progress.Completed > 0:
			percent = float64(progress.Completed) / float64(progress.Total) * 100
			message = fmt.Sprintf("下载中: %.1f%% (%s)", percent, progress.Status)
		case progress.Status != "":
			message = progress.Status
		}
		progressCallback(percent, message)
		logger.Infof(ctx,
			"Download progress for %s: %.2f%% - %s",
			modelName, percent, message,
		)
		return nil
	}
	if err := h.ollamaService.GetClient().Pull(ctx, pullReq, onProgress); err != nil {
		return fmt.Errorf("failed to pull model: %w", err)
	}
	return nil
}
// updateTaskStatus mutates a download task's status, progress and message
// under the registry lock; terminal states also stamp the end time.
// Unknown task IDs are silently ignored.
func (h *InitializationHandler) updateTaskStatus(
	taskID, status string, progress float64, message string,
) {
	tasksMutex.Lock()
	defer tasksMutex.Unlock()

	task, ok := downloadTasks[taskID]
	if !ok {
		return // task was never registered
	}
	task.Status = status
	task.Progress = progress
	task.Message = message

	switch status {
	case "completed", "failed":
		now := time.Now()
		task.EndTime = &now
	}
}
// GetCurrentConfig returns the current system configuration derived from the
// default tenant's models and the default knowledge base.
func (h *InitializationHandler) GetCurrentConfig(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Getting current system configuration")

	// Scope all lookups to the default tenant.
	tenantCtx := context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)

	models, err := h.modelService.ListModels(tenantCtx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError("获取模型列表失败: " + err.Error()))
		return
	}

	kb, err := h.kbService.GetKnowledgeBaseByID(tenantCtx, types.InitDefaultKnowledgeBaseID)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError("获取知识库信息失败: " + err.Error()))
		return
	}

	// Probe for at least one knowledge entry to decide hasFiles; a probe
	// error simply counts as "no files".
	page, err := h.knowledgeService.ListPagedKnowledgeByKnowledgeBaseID(tenantCtx,
		types.InitDefaultKnowledgeBaseID, &types.Pagination{
			Page:     1,
			PageSize: 1,
		})
	hasFiles := err == nil && page != nil && page.Total > 0

	// cfg (not "config") avoids shadowing the imported config package.
	cfg := buildConfigResponse(models, kb, hasFiles)
	logger.Info(ctx, "Current system configuration retrieved successfully")
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    cfg,
	})
}
// buildConfigResponse assembles the frontend-facing configuration map from
// the stored models and the knowledge base. kb may be nil; hasFiles tells
// the frontend whether the embedding model is locked.
//
// Fix: the VLM branch previously read kb.VLMConfig.InterfaceType without a
// nil check, even though kb is explicitly nil-checked later in this same
// function — a nil knowledge base combined with a VLM model panicked.
func buildConfigResponse(models []*types.Model,
	kb *types.KnowledgeBase, hasFiles bool,
) map[string]interface{} {
	config := map[string]interface{}{
		"hasFiles": hasFiles,
	}
	// Group models into the response by type.
	for _, model := range models {
		switch model.Type {
		case types.ModelTypeKnowledgeQA:
			config["llm"] = map[string]interface{}{
				"source":    string(model.Source),
				"modelName": model.Name,
				"baseUrl":   model.Parameters.BaseURL,
				"apiKey":    model.Parameters.APIKey,
			}
		case types.ModelTypeEmbedding:
			config["embedding"] = map[string]interface{}{
				"source":    string(model.Source),
				"modelName": model.Name,
				"baseUrl":   model.Parameters.BaseURL,
				"apiKey":    model.Parameters.APIKey,
				"dimension": model.Parameters.EmbeddingParameters.Dimension,
			}
		case types.ModelTypeRerank:
			config["rerank"] = map[string]interface{}{
				"enabled":   true,
				"modelName": model.Name,
				"baseUrl":   model.Parameters.BaseURL,
				"apiKey":    model.Parameters.APIKey,
			}
		case types.ModelTypeVLLM:
			if config["multimodal"] == nil {
				config["multimodal"] = map[string]interface{}{
					"enabled": true,
				}
			}
			// BUG FIX: guard against a nil knowledge base before reading
			// its VLM interface type.
			interfaceType := ""
			if kb != nil {
				interfaceType = kb.VLMConfig.InterfaceType
			}
			multimodal := config["multimodal"].(map[string]interface{})
			multimodal["vlm"] = map[string]interface{}{
				"modelName":     model.Name,
				"baseUrl":       model.Parameters.BaseURL,
				"apiKey":        model.Parameters.APIKey,
				"interfaceType": interfaceType,
			}
		}
	}
	// No VLM model present: report multimodal as disabled.
	if config["multimodal"] == nil {
		config["multimodal"] = map[string]interface{}{
			"enabled": false,
		}
	}
	// No rerank model present: report rerank as disabled.
	if config["rerank"] == nil {
		config["rerank"] = map[string]interface{}{
			"enabled":   false,
			"modelName": "",
			"baseUrl":   "",
			"apiKey":    "",
		}
	}
	// Knowledge-base-derived settings: chunking and object storage.
	if kb != nil {
		config["documentSplitting"] = map[string]interface{}{
			"chunkSize":    kb.ChunkingConfig.ChunkSize,
			"chunkOverlap": kb.ChunkingConfig.ChunkOverlap,
			"separators":   kb.ChunkingConfig.Separators,
		}
		// Storage credentials are only exposed when a SecretID is stored.
		if kb.StorageConfig.SecretID != "" {
			if config["multimodal"] == nil {
				config["multimodal"] = map[string]interface{}{
					"enabled": true,
				}
			}
			multimodal := config["multimodal"].(map[string]interface{})
			multimodal["storageType"] = kb.StorageConfig.Provider
			switch kb.StorageConfig.Provider {
			case "cos":
				multimodal["cos"] = map[string]interface{}{
					"secretId":   kb.StorageConfig.SecretID,
					"secretKey":  kb.StorageConfig.SecretKey,
					"region":     kb.StorageConfig.Region,
					"bucketName": kb.StorageConfig.BucketName,
					"appId":      kb.StorageConfig.AppID,
					"pathPrefix": kb.StorageConfig.PathPrefix,
				}
			case "minio":
				multimodal["minio"] = map[string]interface{}{
					"bucketName": kb.StorageConfig.BucketName,
					"pathPrefix": kb.StorageConfig.PathPrefix,
				}
			}
		}
	}
	return config
}
// RemoteModelCheckRequest is the payload for probing a remote model API:
// the model to test, its endpoint, and an optional API key.
type RemoteModelCheckRequest struct {
	ModelName string `json:"modelName" binding:"required"`
	BaseURL   string `json:"baseUrl" binding:"required"`
	APIKey    string `json:"apiKey"` // optional; empty for unauthenticated endpoints
}
// CheckRemoteModel verifies connectivity (and, implicitly, authentication)
// against a remote model API endpoint and reports the result.
func (h *InitializationHandler) CheckRemoteModel(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Checking remote model connection")

	var req RemoteModelCheckRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse remote model check request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}
	// Defensive re-check even though binding already enforces these fields.
	if req.ModelName == "" || req.BaseURL == "" {
		logger.Error(ctx, "Model name and base URL are required")
		c.Error(errors.NewBadRequestError("模型名称和Base URL不能为空"))
		return
	}

	// Build a throwaway model definition for the connectivity probe.
	probe := &types.Model{
		Name:   req.ModelName,
		Source: "remote",
		Parameters: types.ModelParameters{
			BaseURL: req.BaseURL,
			APIKey:  req.APIKey,
		},
		Type: "llm", // the probe does not depend on the concrete model type
	}

	available, message := h.checkRemoteModelConnection(ctx, probe)
	logger.Info(ctx,
		fmt.Sprintf(
			"Remote model check completed: modelName=%s, baseUrl=%s, available=%v, message=%s",
			req.ModelName, req.BaseURL, available, message,
		),
	)
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"available": available,
			"message":   message,
		},
	})
}
// TestEmbeddingModel verifies that an embedding endpoint (local or remote)
// can be constructed and can embed a minimal sample string, reporting the
// resulting vector dimension on success.
func (h *InitializationHandler) TestEmbeddingModel(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Testing embedding model connectivity and functionality")

	var req struct {
		Source    string `json:"source" binding:"required"`
		ModelName string `json:"modelName" binding:"required"`
		BaseURL   string `json:"baseUrl"`
		APIKey    string `json:"apiKey"`
		Dimension int    `json:"dimension"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse embedding test request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	// All three outcomes share the same response payload shape.
	report := func(available bool, message string, dimension int) {
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data": gin.H{
				"available": available,
				"message":   message,
				"dimension": dimension,
			},
		})
	}

	// Assemble the embedder configuration from the request.
	cfg := embedding.Config{
		Source:               types.ModelSource(strings.ToLower(req.Source)),
		BaseURL:              req.BaseURL,
		ModelName:            req.ModelName,
		APIKey:               req.APIKey,
		TruncatePromptTokens: 256,
		Dimensions:           req.Dimension,
		ModelID:              "",
	}
	emb, err := embedding.NewEmbedder(cfg)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{"model": req.ModelName})
		report(false, fmt.Sprintf("创建Embedder失败: %v", err), 0)
		return
	}

	// One minimal embedding call proves the endpoint works end to end.
	vec, err := emb.Embed(ctx, "hello")
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{"model": req.ModelName})
		report(false, fmt.Sprintf("调用Embedding失败: %v", err), 0)
		return
	}

	logger.Infof(ctx, "Embedding test succeeded, dim=%d", len(vec))
	report(true, fmt.Sprintf("测试成功,向量维度=%d", len(vec)), len(vec))
}
// checkRemoteModelConnection probes a remote OpenAI-compatible API by
// listing its /models endpoint. It returns whether the endpoint is
// reachable and authenticated, together with a human-readable message.
//
// On a 2xx response it additionally verifies that model.Name appears in
// the returned model list via checkModelExistence (previously that helper
// was dead code and the comment's promise to check existence was unmet).
func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
	model *types.Model,
) (bool, string) {
	client := &http.Client{
		Timeout: 10 * time.Second,
	}

	// GET {baseURL}/models is the standard OpenAI-compatible listing endpoint.
	testEndpoint := ""
	if model.Parameters.BaseURL != "" {
		testEndpoint = model.Parameters.BaseURL + "/models"
	}

	req, err := http.NewRequestWithContext(ctx, "GET", testEndpoint, nil)
	if err != nil {
		return false, fmt.Sprintf("创建请求失败: %v", err)
	}
	// Attach bearer auth only when a key was supplied.
	if model.Parameters.APIKey != "" {
		req.Header.Set("Authorization", "Bearer "+model.Parameters.APIKey)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return false, fmt.Sprintf("连接失败: %v", err)
	}
	defer resp.Body.Close()

	switch {
	case resp.StatusCode >= 200 && resp.StatusCode < 300:
		// Connection succeeded; now verify the requested model is listed.
		return h.checkModelExistence(ctx, resp, model.Name)
	case resp.StatusCode == 401:
		return false, "认证失败请检查API Key"
	case resp.StatusCode == 403:
		return false, "权限不足请检查API Key权限"
	case resp.StatusCode == 404:
		return false, "API端点不存在请检查Base URL"
	default:
		return false, fmt.Sprintf("API返回错误状态: %d", resp.StatusCode)
	}
}
// checkModelExistence inspects a /models listing response and reports
// whether modelName is present. Read/parse failures are treated leniently:
// the connection itself is still considered usable.
func (h *InitializationHandler) checkModelExistence(ctx context.Context,
	resp *http.Response, modelName string) (bool, string) {
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return true, "连接正常,但无法验证模型列表"
	}

	var listing struct {
		Data []struct {
			ID     string `json:"id"`
			Object string `json:"object"`
		} `json:"data"`
		Object string `json:"object"`
	}
	// A non-standard API may not return the OpenAI listing shape; in that
	// case a successful connection is good enough.
	if err := json.Unmarshal(raw, &listing); err != nil {
		return true, "连接正常"
	}

	// Scan the listing for an exact id match.
	for _, entry := range listing.Data {
		if entry.ID == modelName {
			return true, "连接正常,模型存在"
		}
	}

	// Model not listed: suggest up to three available alternatives.
	if len(listing.Data) > 0 {
		suggestions := listing.Data
		if len(suggestions) > 3 {
			suggestions = suggestions[:3]
		}
		names := make([]string, 0, len(suggestions))
		for _, entry := range suggestions {
			names = append(names, entry.ID)
		}
		return false, fmt.Sprintf("模型 '%s' 不存在,可用模型: %s", modelName, strings.Join(names, ", "))
	}
	return false, fmt.Sprintf("模型 '%s' 不存在", modelName)
}
// min reports the smaller of the two integers a and b.
// Kept as a package-level helper for pre-1.21 toolchains; on Go 1.21+
// this declaration shadows the builtin of the same name.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// checkRerankModelConnection exercises a rerank endpoint with a fixed
// query and two sample passages and reports whether it returns usable
// rerank results. baseURL is the API root; the probe POSTs to
// {baseURL}/rerank with an optional bearer apiKey.
func (h *InitializationHandler) checkRerankModelConnection(ctx context.Context,
	modelName, baseURL, apiKey string) (bool, string) {
	client := &http.Client{
		Timeout: 15 * time.Second,
	}
	rerankEndpoint := baseURL + "/rerank"

	// Fixed mock data: one query and two candidate passages.
	testQuery := "什么是人工智能?"
	testPassages := []string{
		"机器学习是人工智能的一个子领域,专注于算法和统计模型,使计算机系统能够通过经验自动改进。",
		"深度学习是机器学习的一个子集,使用人工神经网络来模拟人脑的工作方式。",
	}

	rerankRequest := map[string]interface{}{
		"model":                  modelName,
		"query":                  testQuery,
		"documents":              testPassages,
		"truncate_prompt_tokens": 512,
	}
	jsonData, err := json.Marshal(rerankRequest)
	if err != nil {
		return false, fmt.Sprintf("构造请求失败: %v", err)
	}

	// SECURITY FIX: never write the API key itself to the log; record only
	// whether one was supplied (the original logged the raw secret).
	logger.Infof(ctx, "Rerank request: %s, modelName=%s, baseURL=%s, hasApiKey=%v",
		string(jsonData), modelName, baseURL, apiKey != "")

	req, err := http.NewRequestWithContext(
		ctx, "POST", rerankEndpoint, strings.NewReader(string(jsonData)),
	)
	if err != nil {
		return false, fmt.Sprintf("创建请求失败: %v", err)
	}
	if apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return false, fmt.Sprintf("连接失败: %v", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return false, fmt.Sprintf("读取响应失败: %v", err)
	}

	switch {
	case resp.StatusCode >= 200 && resp.StatusCode < 300:
		var rerankResp struct {
			Results []struct {
				Index          int     `json:"index"`
				Document       string  `json:"document"`
				RelevanceScore float64 `json:"relevance_score"`
			} `json:"results"`
		}
		// Tolerate non-standard response shapes as long as the call succeeded.
		if err := json.Unmarshal(body, &rerankResp); err != nil {
			return true, "连接正常,但响应格式非标准"
		}
		if len(rerankResp.Results) > 0 {
			return true, fmt.Sprintf("重排功能正常,返回%d个结果", len(rerankResp.Results))
		}
		return false, "重排接口连接成功,但未返回重排结果"
	case resp.StatusCode == 401:
		return false, "认证失败请检查API Key"
	case resp.StatusCode == 403:
		return false, "权限不足请检查API Key权限"
	case resp.StatusCode == 404:
		return false, "重排API端点不存在请检查Base URL"
	case resp.StatusCode == 422:
		return false, fmt.Sprintf("请求参数错误: %s", string(body))
	default:
		return false, fmt.Sprintf("API返回错误状态: %d, 响应: %s", resp.StatusCode, string(body))
	}
}
// CheckRerankModel probes a rerank endpoint with a small fixed
// query/document pair and reports whether reranking works end to end.
// The HTTP status is always 200; the probe outcome is in data.
func (h *InitializationHandler) CheckRerankModel(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Checking rerank model connection and functionality")

	var payload struct {
		ModelName string `json:"modelName" binding:"required"`
		BaseURL   string `json:"baseUrl" binding:"required"`
		APIKey    string `json:"apiKey"`
	}
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		logger.Error(ctx, "Failed to parse rerank model check request", bindErr)
		c.Error(errors.NewBadRequestError(bindErr.Error()))
		return
	}

	// Defensive re-validation on top of the binding:"required" tags.
	if payload.ModelName == "" || payload.BaseURL == "" {
		logger.Error(ctx, "Model name and base URL are required")
		c.Error(errors.NewBadRequestError("模型名称和Base URL不能为空"))
		return
	}

	ok, detail := h.checkRerankModelConnection(
		ctx, payload.ModelName, payload.BaseURL, payload.APIKey,
	)
	logger.Info(ctx,
		fmt.Sprintf("Rerank model check completed: modelName=%s, baseUrl=%s, available=%v, message=%s",
			payload.ModelName, payload.BaseURL, ok, detail,
		),
	)

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    gin.H{"available": ok, "message": detail},
	})
}
// testMultimodalForm binds the multipart form fields of the multimodal
// test endpoint. Chunking fields are bound as plain strings and parsed
// later so that malformed numbers do not fail the whole binding.
type testMultimodalForm struct {
	VLMModel         string `form:"vlm_model"`          // VLM model name
	VLMBaseURL       string `form:"vlm_base_url"`       // VLM API base URL
	VLMAPIKey        string `form:"vlm_api_key"`        // optional VLM API key
	VLMInterfaceType string `form:"vlm_interface_type"` // e.g. "ollama" or an OpenAI-compatible type
	StorageType      string `form:"storage_type"`       // object-storage provider: "cos" or "minio"
	// COS (Tencent Cloud Object Storage) settings, used when StorageType == "cos".
	COSSecretID   string `form:"cos_secret_id"`
	COSSecretKey  string `form:"cos_secret_key"`
	COSRegion     string `form:"cos_region"`
	COSBucketName string `form:"cos_bucket_name"`
	COSAppID      string `form:"cos_app_id"`
	COSPathPrefix string `form:"cos_path_prefix"`
	// MinIO settings, used when StorageType == "minio".
	MinioBucketName string `form:"minio_bucket_name"`
	MinioPathPrefix string `form:"minio_path_prefix"`
	// Document-splitting settings (raw strings; parsed with fallbacks to
	// avoid type-binding failures).
	ChunkSize     string `form:"chunk_size"`
	ChunkOverlap  string `form:"chunk_overlap"`
	SeparatorsRaw string `form:"separators"`
}
// TestMultimodalFunction validates the multimodal pipeline end to end:
// it binds the VLM/storage/chunking form fields, validates the uploaded
// image, runs it through the DocReader service, and returns the extracted
// caption/OCR text together with the processing time in milliseconds.
// The HTTP status is always 200; pipeline failure is reported in data.
func (h *InitializationHandler) TestMultimodalFunction(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Testing multimodal functionality")

	var req testMultimodalForm
	if err := c.ShouldBind(&req); err != nil {
		logger.Error(ctx, "Failed to parse form data", err)
		c.Error(errors.NewBadRequestError("表单参数解析失败"))
		return
	}

	// For ollama, the base URL is derived from the environment.
	// NOTE(review): if OLLAMA_BASE_URL is unset this yields "/v1" —
	// confirm the deployment always sets it.
	if req.VLMInterfaceType == "ollama" {
		req.VLMBaseURL = os.Getenv("OLLAMA_BASE_URL") + "/v1"
	}
	req.StorageType = strings.ToLower(req.StorageType)

	if req.VLMModel == "" || req.VLMBaseURL == "" {
		logger.Error(ctx, "VLM model name and base URL are required")
		c.Error(errors.NewBadRequestError("VLM模型名称和Base URL不能为空"))
		return
	}

	// Validate the storage-provider configuration.
	switch req.StorageType {
	case "cos":
		logger.Infof(ctx, "COS config: Region=%s, Bucket=%s, App=%s, Prefix=%s",
			req.COSRegion, req.COSBucketName, req.COSAppID, req.COSPathPrefix)
		// Required: SecretID/SecretKey/Region/BucketName/AppID; PathPrefix is optional.
		if req.COSSecretID == "" || req.COSSecretKey == "" ||
			req.COSRegion == "" || req.COSBucketName == "" ||
			req.COSAppID == "" {
			logger.Error(ctx, "COS configuration is required")
			c.Error(errors.NewBadRequestError("COS配置信息不能为空"))
			return
		}
	case "minio":
		logger.Infof(ctx, "MinIO config: Bucket=%s, PathPrefix=%s", req.MinioBucketName, req.MinioPathPrefix)
		if req.MinioBucketName == "" {
			logger.Error(ctx, "MinIO configuration is required")
			c.Error(errors.NewBadRequestError("MinIO配置信息不能为空"))
			return
		}
	default:
		logger.Error(ctx, "Invalid storage type")
		c.Error(errors.NewBadRequestError("无效的存储类型"))
		return
	}

	logger.Infof(ctx, "VLM config: Model=%s, URL=%s, HasKey=%v, Type=%s",
		req.VLMModel, req.VLMBaseURL, req.VLMAPIKey != "", req.VLMInterfaceType)

	// Fetch and validate the uploaded image.
	file, header, err := c.Request.FormFile("image")
	if err != nil {
		logger.Error(ctx, "Failed to get uploaded image", err)
		c.Error(errors.NewBadRequestError("获取上传图片失败"))
		return
	}
	defer file.Close()

	// Only image/* content types are accepted.
	if !strings.HasPrefix(header.Header.Get("Content-Type"), "image/") {
		logger.Error(ctx, "Invalid file type, only images are allowed")
		c.Error(errors.NewBadRequestError("只允许上传图片文件"))
		return
	}
	// Reject images larger than 10MB.
	if header.Size > 10*1024*1024 {
		logger.Error(ctx, "File size too large")
		c.Error(errors.NewBadRequestError("图片文件大小不能超过10MB"))
		return
	}
	logger.Infof(ctx, "Processing image: %s, size: %d bytes", header.Filename, header.Size)

	// Parse chunking settings, falling back to safe defaults on bad input.
	chunkSize, err := strconv.Atoi(req.ChunkSize)
	if err != nil || chunkSize < 100 || chunkSize > 10000 {
		chunkSize = 1000
	}
	chunkOverlap, err := strconv.Atoi(req.ChunkOverlap)
	if err != nil || chunkOverlap < 0 || chunkOverlap >= chunkSize {
		chunkOverlap = 200
	}

	// DRY fix: single source of truth for the default separator list
	// (previously duplicated in both branches below).
	defaultSeparators := []string{"\n\n", "\n", "。", "", "", ";", ""}
	var separators []string
	if req.SeparatorsRaw != "" {
		if err := json.Unmarshal([]byte(req.SeparatorsRaw), &separators); err != nil {
			separators = defaultSeparators
		}
	} else {
		separators = defaultSeparators
	}

	// Read the image payload into memory.
	imageContent, err := io.ReadAll(file)
	if err != nil {
		logger.Error(ctx, "Failed to read image file", err)
		c.Error(errors.NewBadRequestError("读取图片文件失败"))
		return
	}

	// Run the multimodal pipeline and time it.
	startTime := time.Now()
	result, err := h.testMultimodalWithDocReader(
		ctx,
		imageContent, header.Filename,
		chunkSize, chunkOverlap, separators, &req,
	)
	processingTime := time.Since(startTime).Milliseconds()

	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"vlm_model":    req.VLMModel,
			"vlm_base_url": req.VLMBaseURL,
			"filename":     header.Filename,
		})
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data": gin.H{
				"success":         false,
				"message":         err.Error(),
				"processing_time": processingTime,
			},
		})
		return
	}

	logger.Info(ctx, fmt.Sprintf("Multimodal test completed successfully in %dms", processingTime))
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"success":         true,
			"caption":         result["caption"],
			"ocr":             result["ocr"],
			"processing_time": processingTime,
		},
	})
}
// testMultimodalWithDocReader sends the uploaded image to the DocReader
// service with multimodal processing enabled and collects the generated
// image captions and OCR text from the returned chunks. The result map
// has keys "caption" and "ocr", each a "; "-joined aggregation.
func (h *InitializationHandler) testMultimodalWithDocReader(
	ctx context.Context,
	imageContent []byte, filename string,
	chunkSize, chunkOverlap int, separators []string,
	req *testMultimodalForm,
) (map[string]string, error) {
	// Derive the lowercase file extension (without the dot).
	fileExt := ""
	if idx := strings.LastIndex(filename, "."); idx != -1 {
		fileExt = strings.ToLower(filename[idx+1:])
	}

	// The DocReader client is optional wiring; fail fast if absent.
	if h.docReaderClient == nil {
		return nil, fmt.Errorf("DocReader service not configured")
	}

	// BUGFIX: use the comma-ok form so a missing or mistyped request ID in
	// the context yields "" instead of panicking on the type assertion.
	requestID, _ := ctx.Value(types.RequestIDContextKey).(string)

	// Build the DocReader request with multimodal (caption/OCR) enabled.
	request := &proto.ReadFromFileRequest{
		FileContent: imageContent,
		FileName:    filename,
		FileType:    fileExt,
		ReadConfig: &proto.ReadConfig{
			ChunkSize:        int32(chunkSize),
			ChunkOverlap:     int32(chunkOverlap),
			Separators:       separators,
			EnableMultimodal: true, // enable caption/OCR extraction
			VlmConfig: &proto.VLMConfig{
				ModelName:     req.VLMModel,
				BaseUrl:       req.VLMBaseURL,
				ApiKey:        req.VLMAPIKey,
				InterfaceType: req.VLMInterfaceType,
			},
		},
		RequestId: requestID,
	}

	// Attach the object-storage configuration for the selected provider.
	switch strings.ToLower(req.StorageType) {
	case "cos":
		request.ReadConfig.StorageConfig = &proto.StorageConfig{
			Provider:        proto.StorageProvider_COS,
			Region:          req.COSRegion,
			BucketName:      req.COSBucketName,
			AccessKeyId:     req.COSSecretID,
			SecretAccessKey: req.COSSecretKey,
			AppId:           req.COSAppID,
			PathPrefix:      req.COSPathPrefix,
		}
	case "minio":
		// MinIO credentials come from the environment rather than the form.
		request.ReadConfig.StorageConfig = &proto.StorageConfig{
			Provider:        proto.StorageProvider_MINIO,
			BucketName:      req.MinioBucketName,
			PathPrefix:      req.MinioPathPrefix,
			AccessKeyId:     os.Getenv("MINIO_ACCESS_KEY_ID"),
			SecretAccessKey: os.Getenv("MINIO_SECRET_ACCESS_KEY"),
		}
	}

	response, err := h.docReaderClient.ReadFromFile(ctx, request)
	if err != nil {
		return nil, fmt.Errorf("调用DocReader服务失败: %v", err)
	}
	if response.Error != "" {
		return nil, fmt.Errorf("DocReader服务返回错误: %s", response.Error)
	}

	// Gather caption/OCR text across all images in all chunks
	// (ranging over an empty Images slice is a no-op, so no guard needed).
	var allCaptions, allOCRTexts []string
	for _, chunk := range response.Chunks {
		for _, image := range chunk.Images {
			if image.Caption != "" {
				allCaptions = append(allCaptions, image.Caption)
			}
			if image.OcrText != "" {
				allOCRTexts = append(allOCRTexts, image.OcrText)
			}
		}
	}

	result := map[string]string{
		"caption": strings.Join(allCaptions, "; "),
		"ocr":     strings.Join(allOCRTexts, "; "),
	}
	return result, nil
}