1839 lines
53 KiB
Go
1839 lines
53 KiB
Go
package handler
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"os"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"strconv"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/google/uuid"
|
||
"github.com/ollama/ollama/api"
|
||
"knowlege-lsxd/internal/config"
|
||
"knowlege-lsxd/internal/errors"
|
||
"knowlege-lsxd/internal/logger"
|
||
"knowlege-lsxd/internal/models/embedding"
|
||
"knowlege-lsxd/internal/models/utils/ollama"
|
||
"knowlege-lsxd/internal/types"
|
||
"knowlege-lsxd/internal/types/interfaces"
|
||
"knowlege-lsxd/services/docreader/src/client"
|
||
"knowlege-lsxd/services/docreader/src/proto"
|
||
)
|
||
|
||
// DownloadTask tracks the lifecycle of one asynchronous Ollama model pull.
// Instances live in the package-level downloadTasks map and must only be
// read or mutated while holding tasksMutex.
type DownloadTask struct {
	// ID is the unique task identifier (a UUID string).
	ID string `json:"id"`
	// ModelName is the Ollama model being pulled.
	ModelName string `json:"modelName"`
	Status    string `json:"status"` // pending, downloading, completed, failed
	// Progress is the completion percentage in [0, 100].
	Progress float64 `json:"progress"`
	// Message is a human-readable status/progress description.
	Message string `json:"message"`
	// StartTime records when the task was created.
	StartTime time.Time `json:"startTime"`
	// EndTime is set once Status becomes "completed" or "failed";
	// omitted from JSON while the task is still running.
	EndTime *time.Time `json:"endTime,omitempty"`
}
|
||
|
||
// Global download task registry. tasksMutex guards all access to
// downloadTasks; no eviction of finished tasks is performed in this file,
// so completed entries remain queryable by task ID.
var (
	downloadTasks = make(map[string]*DownloadTask)
	tasksMutex    sync.RWMutex
)
|
||
|
||
// InitializationHandler serves the first-time-setup endpoints: checking and
// performing system initialization, and managing Ollama model downloads.
// All dependencies are injected via NewInitializationHandler.
type InitializationHandler struct {
	config           *config.Config                     // application configuration
	tenantService    interfaces.TenantService           // default-tenant lookup/creation
	modelService     interfaces.ModelService            // model CRUD for the default tenant
	kbService        interfaces.KnowledgeBaseService    // default knowledge-base lookup/creation
	kbRepository     interfaces.KnowledgeBaseRepository // direct KB updates bypassing service rules
	knowledgeService interfaces.KnowledgeService        // used to detect whether documents exist
	ollamaService    *ollama.OllamaService              // local Ollama daemon control and model pulls
	docReaderClient  *client.Client                     // document reader gRPC client
}
|
||
|
||
// NewInitializationHandler 创建初始化处理器
|
||
func NewInitializationHandler(
|
||
config *config.Config,
|
||
tenantService interfaces.TenantService,
|
||
modelService interfaces.ModelService,
|
||
kbService interfaces.KnowledgeBaseService,
|
||
kbRepository interfaces.KnowledgeBaseRepository,
|
||
knowledgeService interfaces.KnowledgeService,
|
||
ollamaService *ollama.OllamaService,
|
||
docReaderClient *client.Client,
|
||
) *InitializationHandler {
|
||
return &InitializationHandler{
|
||
config: config,
|
||
tenantService: tenantService,
|
||
modelService: modelService,
|
||
kbService: kbService,
|
||
kbRepository: kbRepository,
|
||
knowledgeService: knowledgeService,
|
||
ollamaService: ollamaService,
|
||
docReaderClient: docReaderClient,
|
||
}
|
||
}
|
||
|
||
// InitializationRequest is the payload accepted by Initialize. It carries the
// model configuration (LLM, embedding, optional rerank and VLM), the storage
// backend for multimodal assets, and the document-splitting settings.
type InitializationRequest struct {
	// Storage backend selected by the frontend: "cos" or "minio".
	StorageType string `json:"storageType"`

	// LLM configures the chat/QA model.
	LLM struct {
		Source    string `json:"source" binding:"required"`
		ModelName string `json:"modelName" binding:"required"`
		BaseURL   string `json:"baseUrl"`
		APIKey    string `json:"apiKey"`
	} `json:"llm" binding:"required"`

	// Embedding configures the vector embedding model.
	Embedding struct {
		Source    string `json:"source" binding:"required"`
		ModelName string `json:"modelName" binding:"required"`
		BaseURL   string `json:"baseUrl"`
		APIKey    string `json:"apiKey"`
		Dimension int    `json:"dimension"` // embedding vector dimension
	} `json:"embedding" binding:"required"`

	// Rerank optionally configures a reranking model; ModelName and BaseURL
	// are required by Initialize when Enabled is true.
	Rerank struct {
		Enabled   bool   `json:"enabled"`
		ModelName string `json:"modelName"`
		BaseURL   string `json:"baseUrl"`
		APIKey    string `json:"apiKey"`
	} `json:"rerank"`

	// Multimodal optionally enables vision support; when Enabled, VLM must be
	// set and the storage backend matching StorageType must be configured.
	Multimodal struct {
		Enabled bool `json:"enabled"`
		VLM     *struct {
			ModelName     string `json:"modelName"`
			BaseURL       string `json:"baseUrl"`
			APIKey        string `json:"apiKey"`
			InterfaceType string `json:"interfaceType"` // "ollama" or "openai"
		} `json:"vlm,omitempty"`
		COS *struct {
			SecretID   string `json:"secretId"`
			SecretKey  string `json:"secretKey"`
			Region     string `json:"region"`
			BucketName string `json:"bucketName"`
			AppID      string `json:"appId"`
			PathPrefix string `json:"pathPrefix"`
		} `json:"cos,omitempty"`
		Minio *struct {
			BucketName string `json:"bucketName"`
			PathPrefix string `json:"pathPrefix"`
		} `json:"minio,omitempty"`
	} `json:"multimodal"`

	// DocumentSplitting configures chunking for ingestion; Initialize
	// additionally enforces ChunkOverlap < ChunkSize.
	DocumentSplitting struct {
		ChunkSize    int      `json:"chunkSize" binding:"required,min=100,max=10000"`
		ChunkOverlap int      `json:"chunkOverlap" binding:"required,min=0"`
		Separators   []string `json:"separators" binding:"required,min=1"`
	} `json:"documentSplitting" binding:"required"`
}
|
||
|
||
// CheckStatus reports whether the system has completed first-time setup.
// The system counts as initialized when the default tenant exists and at
// least one model is registered under it. The endpoint always answers 200
// with {"success": true, "data": {"initialized": bool}}; lookup failures are
// treated as "not initialized" rather than errors.
func (h *InitializationHandler) CheckStatus(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Checking system initialization status")

	// All three outcomes share the same response shape.
	reply := func(initialized bool) {
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data": gin.H{
				"initialized": initialized,
			},
		})
	}

	// Without the default tenant the system cannot be initialized.
	tenant, err := h.tenantService.GetTenantByID(ctx, types.InitDefaultTenantID)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		reply(false)
		return
	}
	if tenant == nil {
		logger.Info(ctx, "No tenants found, system not initialized")
		reply(false)
		return
	}

	// Subsequent service calls are tenant-scoped.
	ctx = context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)

	// A tenant without any models is still uninitialized.
	models, err := h.modelService.ListModels(ctx)
	if err != nil || len(models) == 0 {
		logger.Info(ctx, "No models found, system not initialized")
		reply(false)
		return
	}

	logger.Info(ctx, "System is already initialized")
	reply(true)
}
|
||
|
||
// Initialize performs first-time system setup. It validates the request,
// then creates or updates: (1) the default tenant, (2) the configured models
// (LLM, embedding, optional rerank and VLM), and (3) the default knowledge
// base with chunking, VLM and storage configuration.
//
// Fixes in this revision:
//   - Creating a fresh knowledge base no longer dereferences
//     req.Multimodal.VLM unconditionally; when multimodal is disabled VLM is
//     nil and the old code panicked on any fresh install without multimodal.
//   - The storage-config switches now match the lowercased storage type,
//     consistent with the validation above (which lowercases before
//     checking); previously a value like "COS" passed validation but
//     silently produced an empty storage config.
func (h *InitializationHandler) Initialize(c *gin.Context) {
	ctx := c.Request.Context()

	logger.Info(ctx, "Starting system initialization")

	var req InitializationRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse initialization request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	// Normalized storage backend selector, used by validation and by the
	// knowledge-base storage configuration below.
	storageType := strings.ToLower(req.StorageType)

	// Validate multimodal configuration.
	if req.Multimodal.Enabled {
		if req.Multimodal.VLM == nil {
			logger.Error(ctx, "Multimodal enabled but missing VLM configuration")
			c.Error(errors.NewBadRequestError("启用多模态时需要配置VLM信息"))
			return
		}
		// Ollama-backed VLMs always go through the local daemon's OpenAI
		// compatible endpoint; the client-supplied URL is ignored.
		if req.Multimodal.VLM.InterfaceType == "ollama" {
			req.Multimodal.VLM.BaseURL = os.Getenv("OLLAMA_BASE_URL") + "/v1"
		}
		if req.Multimodal.VLM.ModelName == "" || req.Multimodal.VLM.BaseURL == "" {
			logger.Error(ctx, "VLM configuration incomplete")
			c.Error(errors.NewBadRequestError("VLM配置不完整"))
			return
		}
		switch storageType {
		case "cos":
			if req.Multimodal.COS == nil || req.Multimodal.COS.SecretID == "" || req.Multimodal.COS.SecretKey == "" ||
				req.Multimodal.COS.Region == "" || req.Multimodal.COS.BucketName == "" ||
				req.Multimodal.COS.AppID == "" {
				logger.Error(ctx, "COS configuration incomplete")
				c.Error(errors.NewBadRequestError("COS配置不完整"))
				return
			}
		case "minio":
			// MinIO credentials come from the environment, not the request.
			if req.Multimodal.Minio == nil || req.Multimodal.Minio.BucketName == "" ||
				os.Getenv("MINIO_ACCESS_KEY_ID") == "" || os.Getenv("MINIO_SECRET_ACCESS_KEY") == "" {
				logger.Error(ctx, "MinIO configuration incomplete")
				c.Error(errors.NewBadRequestError("MinIO配置不完整"))
				return
			}
		}
	}

	// Validate rerank configuration (when enabled).
	if req.Rerank.Enabled {
		if req.Rerank.ModelName == "" || req.Rerank.BaseURL == "" {
			logger.Error(ctx, "Rerank configuration incomplete")
			c.Error(errors.NewBadRequestError("启用Rerank时需要配置模型名称和Base URL"))
			return
		}
	}

	// Validate document-splitting configuration.
	if req.DocumentSplitting.ChunkOverlap >= req.DocumentSplitting.ChunkSize {
		logger.Error(ctx, "Chunk overlap must be less than chunk size")
		c.Error(errors.NewBadRequestError("分块重叠大小必须小于分块大小"))
		return
	}
	if len(req.DocumentSplitting.Separators) == 0 {
		logger.Error(ctx, "Document separators cannot be empty")
		c.Error(errors.NewBadRequestError("文档分隔符不能为空"))
		return
	}

	var err error

	// 1. Tenant: create the default tenant if missing, otherwise normalize
	// its name/description. A lookup error is treated like "not found" and
	// falls through to creation (best-effort bootstrap).
	tenant, _ := h.tenantService.GetTenantByID(ctx, types.InitDefaultTenantID)
	if tenant == nil {
		logger.Info(ctx, "Tenant not found, creating tenant")
		tenant = &types.Tenant{
			ID:          types.InitDefaultTenantID,
			Name:        "Default Tenant",
			Description: "System Default Tenant",
			RetrieverEngines: types.RetrieverEngines{
				Engines: []types.RetrieverEngineParams{
					{
						RetrieverType:       types.KeywordsRetrieverType,
						RetrieverEngineType: types.PostgresRetrieverEngineType,
					},
					{
						RetrieverType:       types.VectorRetrieverType,
						RetrieverEngineType: types.PostgresRetrieverEngineType,
					},
				},
			},
		}
		logger.Info(ctx, "Creating default tenant")
		tenant, err = h.tenantService.CreateTenant(ctx, tenant)
		if err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("创建租户失败: " + err.Error()))
			return
		}
	} else {
		logger.Info(ctx, "Tenant exists, updating if needed")
		updated := false
		if tenant.Name != "Default Tenant" {
			tenant.Name = "Default Tenant"
			updated = true
		}
		if tenant.Description != "System Default Tenant" {
			tenant.Description = "System Default Tenant"
			updated = true
		}
		if updated {
			_, err = h.tenantService.UpdateTenant(ctx, tenant)
			if err != nil {
				logger.ErrorWithFields(ctx, err, nil)
				c.Error(errors.NewInternalServerError("更新租户失败: " + err.Error()))
				return
			}
			logger.Info(ctx, "Tenant updated successfully")
		}
	}

	// Tenant-scoped context for all subsequent model/KB operations.
	newCtx := context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)

	// 2. Models: update existing models of each type or create new ones.
	existingModels, err := h.modelService.ListModels(newCtx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		// On lookup failure, proceed as if no models exist yet.
		existingModels = []*types.Model{}
	}

	// Index existing models by type (one default model per type).
	modelMap := make(map[types.ModelType]*types.Model)
	for _, model := range existingModels {
		modelMap[model.Type] = model
	}

	// modelSpec describes one model to create or update.
	type modelSpec struct {
		modelType   types.ModelType
		name        string
		source      types.ModelSource
		description string
		baseURL     string
		apiKey      string
		dimension   int
	}

	modelsToProcess := []modelSpec{
		{
			modelType:   types.ModelTypeKnowledgeQA,
			name:        req.LLM.ModelName,
			source:      types.ModelSource(req.LLM.Source),
			description: "LLM Model for Knowledge QA",
			baseURL:     req.LLM.BaseURL,
			apiKey:      req.LLM.APIKey,
		},
		{
			modelType:   types.ModelTypeEmbedding,
			name:        req.Embedding.ModelName,
			source:      types.ModelSource(req.Embedding.Source),
			description: "Embedding Model",
			baseURL:     req.Embedding.BaseURL,
			apiKey:      req.Embedding.APIKey,
			dimension:   req.Embedding.Dimension,
		},
	}

	if req.Rerank.Enabled {
		modelsToProcess = append(modelsToProcess, modelSpec{
			modelType:   types.ModelTypeRerank,
			name:        req.Rerank.ModelName,
			source:      types.ModelSourceRemote,
			description: "Rerank Model",
			baseURL:     req.Rerank.BaseURL,
			apiKey:      req.Rerank.APIKey,
		})
	}

	if req.Multimodal.Enabled && req.Multimodal.VLM != nil {
		modelsToProcess = append(modelsToProcess, modelSpec{
			modelType:   types.ModelTypeVLLM,
			name:        req.Multimodal.VLM.ModelName,
			source:      types.ModelSourceRemote,
			description: "Vision Language Model",
			baseURL:     req.Multimodal.VLM.BaseURL,
			apiKey:      req.Multimodal.VLM.APIKey,
		})
	}

	var processedModels []*types.Model
	for _, modelConfig := range modelsToProcess {
		existingModel, exists := modelMap[modelConfig.modelType]

		if exists {
			// Update the existing model of this type in place.
			logger.Infof(ctx, "Updating existing model: %s (%s)",
				modelConfig.name, modelConfig.modelType,
			)
			existingModel.Name = modelConfig.name
			existingModel.Source = modelConfig.source
			existingModel.Description = modelConfig.description
			existingModel.Parameters = types.ModelParameters{
				BaseURL: modelConfig.baseURL,
				APIKey:  modelConfig.apiKey,
				EmbeddingParameters: types.EmbeddingParameters{
					Dimension: modelConfig.dimension,
				},
			}
			existingModel.IsDefault = true
			existingModel.Status = types.ModelStatusActive

			err := h.modelService.UpdateModel(newCtx, existingModel)
			if err != nil {
				logger.ErrorWithFields(ctx, err, map[string]interface{}{
					"model_name": modelConfig.name,
					"model_type": modelConfig.modelType,
				})
				c.Error(errors.NewInternalServerError("更新模型失败: " + err.Error()))
				return
			}
			processedModels = append(processedModels, existingModel)
		} else {
			// No model of this type yet; create one.
			logger.Infof(ctx,
				"Creating new model: %s (%s)",
				modelConfig.name, modelConfig.modelType,
			)
			newModel := &types.Model{
				TenantID:    types.InitDefaultTenantID,
				Name:        modelConfig.name,
				Type:        modelConfig.modelType,
				Source:      modelConfig.source,
				Description: modelConfig.description,
				Parameters: types.ModelParameters{
					BaseURL: modelConfig.baseURL,
					APIKey:  modelConfig.apiKey,
					EmbeddingParameters: types.EmbeddingParameters{
						Dimension: modelConfig.dimension,
					},
				},
				IsDefault: true,
				Status:    types.ModelStatusActive,
			}

			err := h.modelService.CreateModel(newCtx, newModel)
			if err != nil {
				logger.ErrorWithFields(ctx, err, map[string]interface{}{
					"model_name": modelConfig.name,
					"model_type": modelConfig.modelType,
				})
				c.Error(errors.NewInternalServerError("创建模型失败: " + err.Error()))
				return
			}
			processedModels = append(processedModels, newModel)
		}
	}

	// Remove a stale VLM model when multimodal is disabled. Failure here is
	// logged but does not abort initialization.
	if !req.Multimodal.Enabled {
		if existingVLM, exists := modelMap[types.ModelTypeVLLM]; exists {
			logger.Info(ctx, "Deleting VLM model as multimodal is disabled")
			err := h.modelService.DeleteModel(newCtx, existingVLM.ID)
			if err != nil {
				logger.ErrorWithFields(ctx, err, map[string]interface{}{
					"model_id": existingVLM.ID,
				})
				logger.Warn(ctx, "Failed to delete VLM model, but continuing")
			}
		}
	}

	// Remove a stale Rerank model when rerank is disabled (same best-effort
	// semantics as above).
	if !req.Rerank.Enabled {
		if existingRerank, exists := modelMap[types.ModelTypeRerank]; exists {
			logger.Info(ctx, "Deleting Rerank model as rerank is disabled")
			err := h.modelService.DeleteModel(newCtx, existingRerank.ID)
			if err != nil {
				logger.ErrorWithFields(ctx, err, map[string]interface{}{
					"model_id": existingRerank.ID,
				})
				logger.Warn(ctx, "Failed to delete Rerank model, but continuing")
			}
		}
	}

	// 3. Knowledge base: create if missing, otherwise update in place.
	// NOTE(review): the lookup error is intentionally not fatal — a nil kb
	// (whether from "not found" or an error) routes into the creation path.
	kb, err := h.kbService.GetKnowledgeBaseByID(newCtx, types.InitDefaultKnowledgeBaseID)

	// Resolve the IDs of the models processed above by type.
	var embeddingModelID, llmModelID, rerankModelID, vlmModelID string
	for _, model := range processedModels {
		if model.Type == types.ModelTypeEmbedding {
			embeddingModelID = model.ID
		}
		if model.Type == types.ModelTypeKnowledgeQA {
			llmModelID = model.ID
		}
		if model.Type == types.ModelTypeRerank && req.Rerank.Enabled {
			rerankModelID = model.ID
		}
		if model.Type == types.ModelTypeVLLM {
			vlmModelID = model.ID
		}
	}

	if kb == nil {
		// Create the default knowledge base.
		logger.Info(ctx, "Creating default knowledge base")
		kb = &types.KnowledgeBase{
			ID:          types.InitDefaultKnowledgeBaseID,
			Name:        "Default Knowledge Base",
			Description: "System Default Knowledge Base",
			TenantID:    types.InitDefaultTenantID,
			ChunkingConfig: types.ChunkingConfig{
				ChunkSize:        req.DocumentSplitting.ChunkSize,
				ChunkOverlap:     req.DocumentSplitting.ChunkOverlap,
				Separators:       req.DocumentSplitting.Separators,
				EnableMultimodal: req.Multimodal.Enabled,
			},
			EmbeddingModelID: embeddingModelID,
			SummaryModelID:   llmModelID,
			RerankModelID:    rerankModelID,
			VLMModelID:       vlmModelID,
		}
		// BUGFIX: only dereference the VLM config when it is present. With
		// multimodal disabled req.Multimodal.VLM is nil and the previous
		// unconditional access panicked.
		if req.Multimodal.Enabled && req.Multimodal.VLM != nil {
			kb.VLMConfig = types.VLMConfig{
				ModelName:     req.Multimodal.VLM.ModelName,
				BaseURL:       req.Multimodal.VLM.BaseURL,
				APIKey:        req.Multimodal.VLM.APIKey,
				InterfaceType: req.Multimodal.VLM.InterfaceType,
			}
		}
		switch storageType {
		case "cos":
			if req.Multimodal.COS != nil {
				kb.StorageConfig = types.StorageConfig{
					Provider:   req.StorageType,
					BucketName: req.Multimodal.COS.BucketName,
					AppID:      req.Multimodal.COS.AppID,
					PathPrefix: req.Multimodal.COS.PathPrefix,
					SecretID:   req.Multimodal.COS.SecretID,
					SecretKey:  req.Multimodal.COS.SecretKey,
					Region:     req.Multimodal.COS.Region,
				}
			}
		case "minio":
			if req.Multimodal.Minio != nil {
				kb.StorageConfig = types.StorageConfig{
					Provider:   req.StorageType,
					BucketName: req.Multimodal.Minio.BucketName,
					PathPrefix: req.Multimodal.Minio.PathPrefix,
					SecretID:   os.Getenv("MINIO_ACCESS_KEY_ID"),
					SecretKey:  os.Getenv("MINIO_SECRET_ACCESS_KEY"),
				}
			}
		}

		_, err = h.kbService.CreateKnowledgeBase(newCtx, kb)
		if err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("创建知识库失败: " + err.Error()))
			return
		}
	} else {
		// Update the existing knowledge base.
		logger.Info(ctx, "Updating existing knowledge base")

		// If the KB already contains documents, the embedding model must not
		// change (existing vectors would become incompatible).
		knowledgeList, err := h.knowledgeService.ListKnowledgeByKnowledgeBaseID(
			newCtx, types.InitDefaultKnowledgeBaseID,
		)
		hasFiles := err == nil && len(knowledgeList) > 0

		kb.SummaryModelID = llmModelID
		if req.Rerank.Enabled {
			kb.RerankModelID = rerankModelID
		} else {
			kb.RerankModelID = "" // clear the rerank model reference
		}
		if req.Multimodal.Enabled {
			// Validation guarantees req.Multimodal.VLM is non-nil here.
			kb.VLMModelID = vlmModelID
			kb.VLMConfig = types.VLMConfig{
				ModelName:     req.Multimodal.VLM.ModelName,
				BaseURL:       req.Multimodal.VLM.BaseURL,
				APIKey:        req.Multimodal.VLM.APIKey,
				InterfaceType: req.Multimodal.VLM.InterfaceType,
			}
			switch storageType {
			case "cos":
				if req.Multimodal.COS != nil {
					kb.StorageConfig = types.StorageConfig{
						Provider:   req.StorageType,
						SecretID:   req.Multimodal.COS.SecretID,
						SecretKey:  req.Multimodal.COS.SecretKey,
						Region:     req.Multimodal.COS.Region,
						BucketName: req.Multimodal.COS.BucketName,
						AppID:      req.Multimodal.COS.AppID,
						PathPrefix: req.Multimodal.COS.PathPrefix,
					}
				}
			case "minio":
				if req.Multimodal.Minio != nil {
					kb.StorageConfig = types.StorageConfig{
						Provider:   req.StorageType,
						BucketName: req.Multimodal.Minio.BucketName,
						PathPrefix: req.Multimodal.Minio.PathPrefix,
						SecretID:   os.Getenv("MINIO_ACCESS_KEY_ID"),
						SecretKey:  os.Getenv("MINIO_SECRET_ACCESS_KEY"),
					}
				}
			}
		} else {
			kb.VLMModelID = "" // clear the VLM model reference
			kb.VLMConfig = types.VLMConfig{}
			kb.StorageConfig = types.StorageConfig{}
		}
		if !hasFiles {
			kb.EmbeddingModelID = embeddingModelID
		}
		kb.ChunkingConfig = types.ChunkingConfig{
			ChunkSize:        req.DocumentSplitting.ChunkSize,
			ChunkOverlap:     req.DocumentSplitting.ChunkOverlap,
			Separators:       req.DocumentSplitting.Separators,
			EnableMultimodal: req.Multimodal.Enabled,
		}

		// Persist the updated configuration via the repository.
		err = h.kbRepository.UpdateKnowledgeBase(newCtx, kb)
		if err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("更新知识库配置失败: " + err.Error()))
			return
		}

		// NOTE(review): kb.SummaryModelID was assigned llmModelID above, so
		// the second operand is always false and this block effectively runs
		// only when !hasFiles; preserved as-is to avoid a behavior change.
		if !hasFiles || kb.SummaryModelID != llmModelID {
			// Re-fetch so we update the freshest copy.
			kb, err = h.kbService.GetKnowledgeBaseByID(newCtx, types.InitDefaultKnowledgeBaseID)
			if err != nil {
				logger.ErrorWithFields(ctx, err, nil)
				c.Error(errors.NewInternalServerError("获取更新后的知识库失败: " + err.Error()))
				return
			}

			kb.SummaryModelID = llmModelID
			if req.Rerank.Enabled {
				kb.RerankModelID = rerankModelID
			} else {
				kb.RerankModelID = "" // clear the rerank model reference
			}

			err = h.kbRepository.UpdateKnowledgeBase(newCtx, kb)
			if err != nil {
				logger.ErrorWithFields(ctx, err, nil)
				c.Error(errors.NewInternalServerError("更新知识库模型ID失败: " + err.Error()))
				return
			}

			logger.Info(ctx, "Model IDs updated successfully")
		}
	}

	logger.Info(ctx, "System initialization completed successfully")
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "系统初始化成功",
		"data": gin.H{
			"tenant":         tenant,
			"models":         processedModels,
			"knowledge_base": kb,
		},
	})
}
|
||
|
||
// CheckOllamaStatus probes the Ollama daemon and reports availability,
// version and the base URL in use. The endpoint always answers 200; a
// failed probe is reported via "available": false plus an "error" field.
func (h *InitializationHandler) CheckOllamaStatus(c *gin.Context) {
	ctx := c.Request.Context()

	logger.Info(ctx, "Checking Ollama service status")

	// Resolve the base URL for display, falling back to the Docker-host
	// default when the environment does not specify one.
	displayURL := os.Getenv("OLLAMA_BASE_URL")
	if displayURL == "" {
		displayURL = "http://host.docker.internal:11434"
	}

	// Probe the service; a start failure means "unavailable", not an error.
	if err := h.ollamaService.StartService(ctx); err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data": gin.H{
				"available": false,
				"error":     err.Error(),
				"baseUrl":   displayURL,
			},
		})
		return
	}

	// Version lookup is best-effort.
	version, err := h.ollamaService.GetVersion(ctx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		version = "unknown"
	}

	logger.Info(ctx, "Ollama service is available")
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"available": h.ollamaService.IsAvailable(),
			"version":   version,
			"baseUrl":   displayURL,
		},
	})
}
|
||
|
||
// CheckOllamaModels reports, for each requested model name, whether that
// model is already present in the local Ollama installation. Names without
// an explicit tag are checked as "<name>:latest"; per-model lookup errors
// are reported as "not available" rather than failing the request.
func (h *InitializationHandler) CheckOllamaModels(c *gin.Context) {
	ctx := c.Request.Context()

	logger.Info(ctx, "Checking Ollama models status")

	var payload struct {
		Models []string `json:"models" binding:"required"`
	}
	if err := c.ShouldBindJSON(&payload); err != nil {
		logger.Error(ctx, "Failed to parse models check request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	// Lazily start the Ollama service if it is not up yet.
	if !h.ollamaService.IsAvailable() {
		if err := h.ollamaService.StartService(ctx); err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("Ollama服务不可用: " + err.Error()))
			return
		}
	}

	statusByModel := make(map[string]bool)
	for _, name := range payload.Models {
		// Untagged model names resolve to the ":latest" tag in Ollama.
		qualified := name
		if !strings.Contains(name, ":") {
			qualified = name + ":latest"
		}
		ok, err := h.ollamaService.IsModelAvailable(ctx, qualified)
		if err != nil {
			logger.ErrorWithFields(ctx, err, map[string]interface{}{
				"model_name": name,
			})
			ok = false
		}
		statusByModel[name] = ok

		logger.Infof(ctx, "Model %s availability: %v", name, statusByModel[name])
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"models": statusByModel,
		},
	})
}
|
||
|
||
// DownloadOllamaModel starts an asynchronous pull of an Ollama model and
// returns a task ID that clients poll via GetDownloadProgress. If the model
// already exists locally, a synthetic "completed" response is returned; if a
// pull for the same model is already pending or running, the existing task
// is returned instead of starting a duplicate.
//
// Fix: the duplicate-task check and the task insertion previously ran under
// separate lock acquisitions (an RLock scan followed by a Lock insert), so
// two concurrent requests for the same model could both pass the scan and
// both create a task. Check-and-insert is now one critical section.
func (h *InitializationHandler) DownloadOllamaModel(c *gin.Context) {
	ctx := c.Request.Context()

	logger.Info(ctx, "Starting async Ollama model download")

	var req struct {
		ModelName string `json:"modelName" binding:"required"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse model download request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	// Lazily start the Ollama service if it is not up yet.
	if !h.ollamaService.IsAvailable() {
		if err := h.ollamaService.StartService(ctx); err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("Ollama服务不可用: " + err.Error()))
			return
		}
	}

	// Nothing to do when the model is already installed.
	available, err := h.ollamaService.IsModelAvailable(ctx, req.ModelName)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_name": req.ModelName,
		})
		c.Error(errors.NewInternalServerError("检查模型状态失败: " + err.Error()))
		return
	}
	if available {
		logger.Infof(ctx, "Model %s already exists", req.ModelName)
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"message": "模型已存在",
			"data": gin.H{
				"modelName": req.ModelName,
				"status":    "completed",
				"progress":  100.0,
			},
		})
		return
	}

	// Atomically: reuse an in-flight task for this model, or register a new
	// one. Holding the write lock across check+insert closes the race.
	tasksMutex.Lock()
	for _, existing := range downloadTasks {
		if existing.ModelName == req.ModelName && (existing.Status == "pending" || existing.Status == "downloading") {
			// Snapshot the fields before unlocking so the response cannot
			// observe a concurrent mutation.
			data := gin.H{
				"taskId":    existing.ID,
				"modelName": existing.ModelName,
				"status":    existing.Status,
				"progress":  existing.Progress,
			}
			tasksMutex.Unlock()
			c.JSON(http.StatusOK, gin.H{
				"success": true,
				"message": "模型下载任务已存在",
				"data":    data,
			})
			return
		}
	}
	taskID := uuid.New().String()
	downloadTasks[taskID] = &DownloadTask{
		ID:        taskID,
		ModelName: req.ModelName,
		Status:    "pending",
		Progress:  0.0,
		Message:   "准备下载",
		StartTime: time.Now(),
	}
	tasksMutex.Unlock()

	// Run the pull in the background on a fresh context so it survives this
	// HTTP request; the 12h timeout bounds runaway downloads.
	dlCtx, cancel := context.WithTimeout(context.Background(), 12*time.Hour)
	go func() {
		defer cancel()
		h.downloadModelAsync(dlCtx, taskID, req.ModelName)
	}()

	logger.Infof(ctx, "Created download task for model: %s, task ID: %s", req.ModelName, taskID)
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "模型下载任务已创建",
		"data": gin.H{
			"taskId":    taskID,
			"modelName": req.ModelName,
			"status":    "pending",
			"progress":  0.0,
		},
	})
}
|
||
|
||
// GetDownloadProgress returns the DownloadTask identified by the ":taskId"
// path parameter, or a not-found error when no such task is registered.
func (h *InitializationHandler) GetDownloadProgress(c *gin.Context) {
	id := c.Param("taskId")
	if id == "" {
		c.Error(errors.NewBadRequestError("任务ID不能为空"))
		return
	}

	// Read the task under the shared lock.
	tasksMutex.RLock()
	task, found := downloadTasks[id]
	tasksMutex.RUnlock()

	if !found {
		c.Error(errors.NewNotFoundError("下载任务不存在"))
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    task,
	})
}
|
||
|
||
// ListDownloadTasks returns every registered download task (running and
// finished) in unspecified order.
func (h *InitializationHandler) ListDownloadTasks(c *gin.Context) {
	// Collect pointers under the read lock; the slice itself is local.
	tasksMutex.RLock()
	snapshot := make([]*DownloadTask, 0, len(downloadTasks))
	for _, t := range downloadTasks {
		snapshot = append(snapshot, t)
	}
	tasksMutex.RUnlock()

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    snapshot,
	})
}
|
||
|
||
// ListOllamaModels returns the models currently installed in the local
// Ollama instance, starting the service first if necessary.
func (h *InitializationHandler) ListOllamaModels(c *gin.Context) {
	ctx := c.Request.Context()

	logger.Info(ctx, "Listing installed Ollama models")

	// Lazily start the Ollama service if it is not up yet.
	if !h.ollamaService.IsAvailable() {
		if err := h.ollamaService.StartService(ctx); err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("Ollama服务不可用: " + err.Error()))
			return
		}
	}

	installed, err := h.ollamaService.ListModels(ctx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError("获取模型列表失败: " + err.Error()))
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"models": installed,
		},
	})
}
|
||
|
||
// downloadModelAsync drives one background model pull: it marks the task as
// downloading, streams progress updates into the task registry, and records
// the terminal state ("completed" or "failed") when the pull ends. Intended
// to run in its own goroutine with a long-lived context.
func (h *InitializationHandler) downloadModelAsync(ctx context.Context,
	taskID, modelName string,
) {
	logger.Infof(ctx, "Starting async download for model: %s, task: %s", modelName, taskID)

	h.updateTaskStatus(taskID, "downloading", 0.0, "开始下载模型")

	// Each progress callback simply mirrors into the task registry.
	report := func(progress float64, message string) {
		h.updateTaskStatus(taskID, "downloading", progress, message)
	}

	if err := h.pullModelWithProgress(ctx, modelName, report); err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_name": modelName,
			"task_id":    taskID,
		})
		h.updateTaskStatus(taskID, "failed", 0.0, fmt.Sprintf("下载失败: %v", err))
		return
	}

	logger.Infof(ctx, "Model %s downloaded successfully, task: %s", modelName, taskID)
	h.updateTaskStatus(taskID, "completed", 100.0, "下载完成")
}
|
||
|
||
// pullModelWithProgress 下载模型并提供进度回调
|
||
func (h *InitializationHandler) pullModelWithProgress(ctx context.Context,
|
||
modelName string,
|
||
progressCallback func(float64, string),
|
||
) error {
|
||
// 检查服务是否可用
|
||
if err := h.ollamaService.StartService(ctx); err != nil {
|
||
logger.ErrorWithFields(ctx, err, nil)
|
||
return err
|
||
}
|
||
|
||
// 检查模型是否已存在
|
||
available, err := h.ollamaService.IsModelAvailable(ctx, modelName)
|
||
if err != nil {
|
||
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
||
"model_name": modelName,
|
||
})
|
||
return err
|
||
}
|
||
if available {
|
||
progressCallback(100.0, "模型已存在")
|
||
return nil
|
||
}
|
||
|
||
logger.GetLogger(ctx).Infof("Pulling model %s...", modelName)
|
||
|
||
// 创建下载请求
|
||
pullReq := &api.PullRequest{
|
||
Name: modelName,
|
||
}
|
||
|
||
// 使用Ollama客户端的Pull方法,带进度回调
|
||
err = h.ollamaService.GetClient().Pull(ctx, pullReq, func(progress api.ProgressResponse) error {
|
||
var progressPercent float64 = 0.0
|
||
var message string = "下载中"
|
||
|
||
if progress.Total > 0 && progress.Completed > 0 {
|
||
progressPercent = float64(progress.Completed) / float64(progress.Total) * 100
|
||
message = fmt.Sprintf("下载中: %.1f%% (%s)", progressPercent, progress.Status)
|
||
} else if progress.Status != "" {
|
||
message = progress.Status
|
||
}
|
||
|
||
// 调用进度回调
|
||
progressCallback(progressPercent, message)
|
||
|
||
logger.Infof(ctx,
|
||
"Download progress for %s: %.2f%% - %s",
|
||
modelName, progressPercent, message,
|
||
)
|
||
return nil
|
||
})
|
||
|
||
if err != nil {
|
||
return fmt.Errorf("failed to pull model: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// updateTaskStatus records the latest status, progress and message for the
// given task under the registry lock. Terminal states ("completed" or
// "failed") additionally stamp the end time. Unknown task IDs are ignored.
func (h *InitializationHandler) updateTaskStatus(
	taskID, status string, progress float64, message string,
) {
	tasksMutex.Lock()
	defer tasksMutex.Unlock()

	task, ok := downloadTasks[taskID]
	if !ok {
		return
	}

	task.Status = status
	task.Progress = progress
	task.Message = message

	if status == "completed" || status == "failed" {
		finished := time.Now()
		task.EndTime = &finished
	}
}
|
||
|
||
// GetCurrentConfig returns the current system configuration: the registered
// models, the default knowledge base, and whether any documents have been
// ingested (which determines if the embedding model may still be changed).
func (h *InitializationHandler) GetCurrentConfig(c *gin.Context) {
	ctx := c.Request.Context()

	logger.Info(ctx, "Getting current system configuration")

	// All lookups are scoped to the default tenant.
	tenantCtx := context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)

	models, err := h.modelService.ListModels(tenantCtx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError("获取模型列表失败: " + err.Error()))
		return
	}

	kb, err := h.kbService.GetKnowledgeBaseByID(tenantCtx, types.InitDefaultKnowledgeBaseID)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError("获取知识库信息失败: " + err.Error()))
		return
	}

	// A one-item page is enough to learn whether any documents exist;
	// lookup failures are treated as "no files".
	page, err := h.knowledgeService.ListPagedKnowledgeByKnowledgeBaseID(tenantCtx,
		types.InitDefaultKnowledgeBaseID, &types.Pagination{
			Page:     1,
			PageSize: 1,
		})
	hasFiles := err == nil && page != nil && page.Total > 0

	payload := buildConfigResponse(models, kb, hasFiles)

	logger.Info(ctx, "Current system configuration retrieved successfully")
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    payload,
	})
}
|
||
|
||
// buildConfigResponse 构建配置响应数据
|
||
func buildConfigResponse(models []*types.Model,
|
||
kb *types.KnowledgeBase, hasFiles bool,
|
||
) map[string]interface{} {
|
||
config := map[string]interface{}{
|
||
"hasFiles": hasFiles,
|
||
}
|
||
|
||
// 按类型分组模型
|
||
for _, model := range models {
|
||
switch model.Type {
|
||
case types.ModelTypeKnowledgeQA:
|
||
config["llm"] = map[string]interface{}{
|
||
"source": string(model.Source),
|
||
"modelName": model.Name,
|
||
"baseUrl": model.Parameters.BaseURL,
|
||
"apiKey": model.Parameters.APIKey,
|
||
}
|
||
case types.ModelTypeEmbedding:
|
||
config["embedding"] = map[string]interface{}{
|
||
"source": string(model.Source),
|
||
"modelName": model.Name,
|
||
"baseUrl": model.Parameters.BaseURL,
|
||
"apiKey": model.Parameters.APIKey,
|
||
"dimension": model.Parameters.EmbeddingParameters.Dimension,
|
||
}
|
||
case types.ModelTypeRerank:
|
||
config["rerank"] = map[string]interface{}{
|
||
"enabled": true,
|
||
"modelName": model.Name,
|
||
"baseUrl": model.Parameters.BaseURL,
|
||
"apiKey": model.Parameters.APIKey,
|
||
}
|
||
case types.ModelTypeVLLM:
|
||
if config["multimodal"] == nil {
|
||
config["multimodal"] = map[string]interface{}{
|
||
"enabled": true,
|
||
}
|
||
}
|
||
multimodal := config["multimodal"].(map[string]interface{})
|
||
multimodal["vlm"] = map[string]interface{}{
|
||
"modelName": model.Name,
|
||
"baseUrl": model.Parameters.BaseURL,
|
||
"apiKey": model.Parameters.APIKey,
|
||
"interfaceType": kb.VLMConfig.InterfaceType,
|
||
}
|
||
}
|
||
}
|
||
|
||
// 如果没有VLM模型,设置multimodal为disabled
|
||
if config["multimodal"] == nil {
|
||
config["multimodal"] = map[string]interface{}{
|
||
"enabled": false,
|
||
}
|
||
}
|
||
|
||
// 如果没有Rerank模型,设置rerank为disabled
|
||
if config["rerank"] == nil {
|
||
config["rerank"] = map[string]interface{}{
|
||
"enabled": false,
|
||
"modelName": "",
|
||
"baseUrl": "",
|
||
"apiKey": "",
|
||
}
|
||
}
|
||
|
||
// 添加知识库的文档分割配置
|
||
if kb != nil {
|
||
config["documentSplitting"] = map[string]interface{}{
|
||
"chunkSize": kb.ChunkingConfig.ChunkSize,
|
||
"chunkOverlap": kb.ChunkingConfig.ChunkOverlap,
|
||
"separators": kb.ChunkingConfig.Separators,
|
||
}
|
||
|
||
// 添加多模态的COS配置信息
|
||
if kb.StorageConfig.SecretID != "" {
|
||
if config["multimodal"] == nil {
|
||
config["multimodal"] = map[string]interface{}{
|
||
"enabled": true,
|
||
}
|
||
}
|
||
multimodal := config["multimodal"].(map[string]interface{})
|
||
multimodal["storageType"] = kb.StorageConfig.Provider
|
||
switch kb.StorageConfig.Provider {
|
||
case "cos":
|
||
multimodal["cos"] = map[string]interface{}{
|
||
"secretId": kb.StorageConfig.SecretID,
|
||
"secretKey": kb.StorageConfig.SecretKey,
|
||
"region": kb.StorageConfig.Region,
|
||
"bucketName": kb.StorageConfig.BucketName,
|
||
"appId": kb.StorageConfig.AppID,
|
||
"pathPrefix": kb.StorageConfig.PathPrefix,
|
||
}
|
||
case "minio":
|
||
multimodal["minio"] = map[string]interface{}{
|
||
"bucketName": kb.StorageConfig.BucketName,
|
||
"pathPrefix": kb.StorageConfig.PathPrefix,
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
return config
|
||
}
|
||
|
||
// RemoteModelCheckRequest is the JSON request body accepted by
// CheckRemoteModel. It identifies a remote model API endpoint to probe.
type RemoteModelCheckRequest struct {
	// ModelName is the identifier of the model to check (required).
	ModelName string `json:"modelName" binding:"required"`
	// BaseURL is the base URL of the remote API (required); the probe
	// requests "{BaseURL}/models".
	BaseURL string `json:"baseUrl" binding:"required"`
	// APIKey is the optional bearer token sent in the Authorization header.
	APIKey string `json:"apiKey"`
}
|
||
|
||
// CheckRemoteModel 检查远程API模型连接
|
||
func (h *InitializationHandler) CheckRemoteModel(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
|
||
logger.Info(ctx, "Checking remote model connection")
|
||
|
||
var req RemoteModelCheckRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
logger.Error(ctx, "Failed to parse remote model check request", err)
|
||
c.Error(errors.NewBadRequestError(err.Error()))
|
||
return
|
||
}
|
||
|
||
// 验证请求参数
|
||
if req.ModelName == "" || req.BaseURL == "" {
|
||
logger.Error(ctx, "Model name and base URL are required")
|
||
c.Error(errors.NewBadRequestError("模型名称和Base URL不能为空"))
|
||
return
|
||
}
|
||
|
||
// 创建模型配置进行测试
|
||
modelConfig := &types.Model{
|
||
Name: req.ModelName,
|
||
Source: "remote",
|
||
Parameters: types.ModelParameters{
|
||
BaseURL: req.BaseURL,
|
||
APIKey: req.APIKey,
|
||
},
|
||
Type: "llm", // 默认类型,实际检查时不区分具体类型
|
||
}
|
||
|
||
// 检查远程模型连接
|
||
available, message := h.checkRemoteModelConnection(ctx, modelConfig)
|
||
|
||
logger.Info(ctx,
|
||
fmt.Sprintf(
|
||
"Remote model check completed: modelName=%s, baseUrl=%s, available=%v, message=%s",
|
||
req.ModelName, req.BaseURL, available, message,
|
||
),
|
||
)
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"data": gin.H{
|
||
"available": available,
|
||
"message": message,
|
||
},
|
||
})
|
||
}
|
||
|
||
// TestEmbeddingModel 测试 Embedding 接口(本地或远程)是否可用
|
||
func (h *InitializationHandler) TestEmbeddingModel(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
|
||
logger.Info(ctx, "Testing embedding model connectivity and functionality")
|
||
|
||
var req struct {
|
||
Source string `json:"source" binding:"required"`
|
||
ModelName string `json:"modelName" binding:"required"`
|
||
BaseURL string `json:"baseUrl"`
|
||
APIKey string `json:"apiKey"`
|
||
Dimension int `json:"dimension"`
|
||
}
|
||
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
logger.Error(ctx, "Failed to parse embedding test request", err)
|
||
c.Error(errors.NewBadRequestError(err.Error()))
|
||
return
|
||
}
|
||
|
||
// 构造 embedder 配置
|
||
cfg := embedding.Config{
|
||
Source: types.ModelSource(strings.ToLower(req.Source)),
|
||
BaseURL: req.BaseURL,
|
||
ModelName: req.ModelName,
|
||
APIKey: req.APIKey,
|
||
TruncatePromptTokens: 256,
|
||
Dimensions: req.Dimension,
|
||
ModelID: "",
|
||
}
|
||
|
||
emb, err := embedding.NewEmbedder(cfg)
|
||
if err != nil {
|
||
logger.ErrorWithFields(ctx, err, map[string]interface{}{"model": req.ModelName})
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"data": gin.H{`available`: false, `message`: fmt.Sprintf("创建Embedder失败: %v", err), `dimension`: 0},
|
||
})
|
||
return
|
||
}
|
||
|
||
// 执行一次最小化 embedding 调用
|
||
sample := "hello"
|
||
vec, err := emb.Embed(ctx, sample)
|
||
if err != nil {
|
||
logger.ErrorWithFields(ctx, err, map[string]interface{}{"model": req.ModelName})
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"data": gin.H{`available`: false, `message`: fmt.Sprintf("调用Embedding失败: %v", err), `dimension`: 0},
|
||
})
|
||
return
|
||
}
|
||
|
||
logger.Infof(ctx, "Embedding test succeeded, dim=%d", len(vec))
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"data": gin.H{`available`: true, `message`: fmt.Sprintf("测试成功,向量维度=%d", len(vec)), `dimension`: len(vec)},
|
||
})
|
||
}
|
||
|
||
// checkRemoteModelConnection 检查远程模型连接的内部方法
|
||
func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
|
||
model *types.Model,
|
||
) (bool, string) {
|
||
// 这里需要根据实际情况实现远程API的连接检查
|
||
// 可以发送一个简单的请求来验证连接和认证
|
||
|
||
client := &http.Client{
|
||
Timeout: 10 * time.Second,
|
||
}
|
||
|
||
// 根据不同的API类型构造测试请求
|
||
testEndpoint := ""
|
||
if model.Parameters.BaseURL != "" {
|
||
testEndpoint = model.Parameters.BaseURL + "/models"
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "GET", testEndpoint, nil)
|
||
if err != nil {
|
||
return false, fmt.Sprintf("创建请求失败: %v", err)
|
||
}
|
||
|
||
// 添加认证头
|
||
if model.Parameters.APIKey != "" {
|
||
req.Header.Set("Authorization", "Bearer "+model.Parameters.APIKey)
|
||
}
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return false, fmt.Sprintf("连接失败: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 检查响应状态
|
||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||
// 连接成功,现在检查模型是否存在
|
||
return true, "连接正常,请自行确保模型存在"
|
||
} else if resp.StatusCode == 401 {
|
||
return false, "认证失败,请检查API Key"
|
||
} else if resp.StatusCode == 403 {
|
||
return false, "权限不足,请检查API Key权限"
|
||
} else if resp.StatusCode == 404 {
|
||
return false, "API端点不存在,请检查Base URL"
|
||
} else {
|
||
return false, fmt.Sprintf("API返回错误状态: %d", resp.StatusCode)
|
||
}
|
||
}
|
||
|
||
// checkModelExistence 检查指定模型是否在模型列表中存在
|
||
func (h *InitializationHandler) checkModelExistence(ctx context.Context,
|
||
resp *http.Response, modelName string) (bool, string) {
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return true, "连接正常,但无法验证模型列表"
|
||
}
|
||
|
||
var modelsResp struct {
|
||
Data []struct {
|
||
ID string `json:"id"`
|
||
Object string `json:"object"`
|
||
} `json:"data"`
|
||
Object string `json:"object"`
|
||
}
|
||
|
||
// 尝试解析模型列表响应
|
||
if err := json.Unmarshal(body, &modelsResp); err != nil {
|
||
// 如果无法解析,可能是非标准API,只要连接成功就认为可用
|
||
return true, "连接正常"
|
||
}
|
||
|
||
// 检查模型是否在列表中
|
||
for _, model := range modelsResp.Data {
|
||
if model.ID == modelName {
|
||
return true, "连接正常,模型存在"
|
||
}
|
||
}
|
||
|
||
// 模型不在列表中,返回可用的模型建议
|
||
if len(modelsResp.Data) > 0 {
|
||
availableModels := make([]string, 0, min(3, len(modelsResp.Data)))
|
||
for i, model := range modelsResp.Data {
|
||
if i >= 3 {
|
||
break
|
||
}
|
||
availableModels = append(availableModels, model.ID)
|
||
}
|
||
return false, fmt.Sprintf("模型 '%s' 不存在,可用模型: %s", modelName, strings.Join(availableModels, ", "))
|
||
}
|
||
|
||
return false, fmt.Sprintf("模型 '%s' 不存在", modelName)
|
||
}
|
||
|
||
// min returns the smaller of the two integers a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
||
|
||
// checkRerankModelConnection 检查Rerank模型连接和功能的内部方法
|
||
func (h *InitializationHandler) checkRerankModelConnection(ctx context.Context,
|
||
modelName, baseURL, apiKey string) (bool, string) {
|
||
client := &http.Client{
|
||
Timeout: 15 * time.Second,
|
||
}
|
||
|
||
// 构造重排API端点
|
||
rerankEndpoint := baseURL + "/rerank"
|
||
|
||
// Mock测试数据
|
||
testQuery := "什么是人工智能?"
|
||
testPassages := []string{
|
||
"机器学习是人工智能的一个子领域,专注于算法和统计模型,使计算机系统能够通过经验自动改进。",
|
||
"深度学习是机器学习的一个子集,使用人工神经网络来模拟人脑的工作方式。",
|
||
}
|
||
|
||
// 构造重排请求
|
||
rerankRequest := map[string]interface{}{
|
||
"model": modelName,
|
||
"query": testQuery,
|
||
"documents": testPassages,
|
||
"truncate_prompt_tokens": 512,
|
||
}
|
||
|
||
jsonData, err := json.Marshal(rerankRequest)
|
||
if err != nil {
|
||
return false, fmt.Sprintf("构造请求失败: %v", err)
|
||
}
|
||
|
||
logger.Infof(ctx, "Rerank request: %s, modelName=%s, baseURL=%s, apiKey=%s",
|
||
string(jsonData), modelName, baseURL, apiKey)
|
||
|
||
req, err := http.NewRequestWithContext(
|
||
ctx, "POST", rerankEndpoint, strings.NewReader(string(jsonData)),
|
||
)
|
||
if err != nil {
|
||
return false, fmt.Sprintf("创建请求失败: %v", err)
|
||
}
|
||
|
||
// 添加认证头
|
||
if apiKey != "" {
|
||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||
}
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return false, fmt.Sprintf("连接失败: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 读取响应
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return false, fmt.Sprintf("读取响应失败: %v", err)
|
||
}
|
||
|
||
// 检查响应状态
|
||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||
// 尝试解析重排响应
|
||
var rerankResp struct {
|
||
Results []struct {
|
||
Index int `json:"index"`
|
||
Document string `json:"document"`
|
||
RelevanceScore float64 `json:"relevance_score"`
|
||
} `json:"results"`
|
||
}
|
||
|
||
if err := json.Unmarshal(body, &rerankResp); err != nil {
|
||
// 如果无法解析标准重排响应,检查是否有其他格式
|
||
return true, "连接正常,但响应格式非标准"
|
||
}
|
||
|
||
// 检查是否返回了重排结果
|
||
if len(rerankResp.Results) > 0 {
|
||
return true, fmt.Sprintf("重排功能正常,返回%d个结果", len(rerankResp.Results))
|
||
} else {
|
||
return false, "重排接口连接成功,但未返回重排结果"
|
||
}
|
||
} else if resp.StatusCode == 401 {
|
||
return false, "认证失败,请检查API Key"
|
||
} else if resp.StatusCode == 403 {
|
||
return false, "权限不足,请检查API Key权限"
|
||
} else if resp.StatusCode == 404 {
|
||
return false, "重排API端点不存在,请检查Base URL"
|
||
} else if resp.StatusCode == 422 {
|
||
return false, fmt.Sprintf("请求参数错误: %s", string(body))
|
||
} else {
|
||
return false, fmt.Sprintf("API返回错误状态: %d, 响应: %s", resp.StatusCode, string(body))
|
||
}
|
||
}
|
||
|
||
// CheckRerankModel 检查Rerank模型连接和功能
|
||
func (h *InitializationHandler) CheckRerankModel(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
|
||
logger.Info(ctx, "Checking rerank model connection and functionality")
|
||
|
||
var req struct {
|
||
ModelName string `json:"modelName" binding:"required"`
|
||
BaseURL string `json:"baseUrl" binding:"required"`
|
||
APIKey string `json:"apiKey"`
|
||
}
|
||
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
logger.Error(ctx, "Failed to parse rerank model check request", err)
|
||
c.Error(errors.NewBadRequestError(err.Error()))
|
||
return
|
||
}
|
||
|
||
// 验证请求参数
|
||
if req.ModelName == "" || req.BaseURL == "" {
|
||
logger.Error(ctx, "Model name and base URL are required")
|
||
c.Error(errors.NewBadRequestError("模型名称和Base URL不能为空"))
|
||
return
|
||
}
|
||
|
||
// 检查Rerank模型连接和功能
|
||
available, message := h.checkRerankModelConnection(
|
||
ctx, req.ModelName, req.BaseURL, req.APIKey,
|
||
)
|
||
|
||
logger.Info(ctx,
|
||
fmt.Sprintf("Rerank model check completed: modelName=%s, baseUrl=%s, available=%v, message=%s",
|
||
req.ModelName, req.BaseURL, available, message,
|
||
),
|
||
)
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"data": gin.H{
|
||
"available": available,
|
||
"message": message,
|
||
},
|
||
})
|
||
}
|
||
|
||
// testMultimodalForm is the multipart form payload accepted by
// TestMultimodalFunction. The document-splitting fields are declared as
// strings and parsed manually so that malformed values degrade to defaults
// instead of failing the whole binding.
type testMultimodalForm struct {
	// VLM (vision-language model) connection settings.
	VLMModel         string `form:"vlm_model"`
	VLMBaseURL       string `form:"vlm_base_url"`
	VLMAPIKey        string `form:"vlm_api_key"`
	VLMInterfaceType string `form:"vlm_interface_type"`

	// StorageType selects the object storage provider ("cos" or "minio").
	StorageType string `form:"storage_type"`

	// COS configuration (used when StorageType is "cos").
	COSSecretID   string `form:"cos_secret_id"`
	COSSecretKey  string `form:"cos_secret_key"`
	COSRegion     string `form:"cos_region"`
	COSBucketName string `form:"cos_bucket_name"`
	COSAppID      string `form:"cos_app_id"`
	COSPathPrefix string `form:"cos_path_prefix"`

	// MinIO configuration (used when StorageType is "minio").
	MinioBucketName string `form:"minio_bucket_name"`
	MinioPathPrefix string `form:"minio_path_prefix"`

	// Document-splitting configuration (kept as raw strings and parsed
	// later, to avoid type-binding failures on malformed input).
	ChunkSize     string `form:"chunk_size"`
	ChunkOverlap  string `form:"chunk_overlap"`
	SeparatorsRaw string `form:"separators"`
}
|
||
|
||
// TestMultimodalFunction tests the multimodal pipeline end to end: it accepts
// an uploaded image plus VLM/storage/splitting settings as multipart form
// fields, runs the image through the DocReader service, and returns the
// extracted caption and OCR text along with the processing time.
//
// Validation failures are reported via c.Error; processing failures are
// reported as success=false inside a 200 response so the frontend can show
// the message inline.
func (h *InitializationHandler) TestMultimodalFunction(c *gin.Context) {
	ctx := c.Request.Context()

	logger.Info(ctx, "Testing multimodal functionality")

	var req testMultimodalForm
	if err := c.ShouldBind(&req); err != nil {
		logger.Error(ctx, "Failed to parse form data", err)
		c.Error(errors.NewBadRequestError("表单参数解析失败"))
		return
	}
	// For the ollama interface type the base URL is derived from the
	// environment rather than taken from the form.
	if req.VLMInterfaceType == "ollama" {
		req.VLMBaseURL = os.Getenv("OLLAMA_BASE_URL") + "/v1"
	}

	req.StorageType = strings.ToLower(req.StorageType)

	if req.VLMModel == "" || req.VLMBaseURL == "" {
		logger.Error(ctx, "VLM model name and base URL are required")
		c.Error(errors.NewBadRequestError("VLM模型名称和Base URL不能为空"))
		return
	}
	// Validate the storage configuration for the selected provider.
	switch req.StorageType {
	case "cos":
		logger.Infof(ctx, "COS config: Region=%s, Bucket=%s, App=%s, Prefix=%s",
			req.COSRegion, req.COSBucketName, req.COSAppID, req.COSPathPrefix)
		// Required: SecretID/SecretKey/Region/BucketName/AppID; PathPrefix is optional.
		if req.COSSecretID == "" || req.COSSecretKey == "" ||
			req.COSRegion == "" || req.COSBucketName == "" ||
			req.COSAppID == "" {
			logger.Error(ctx, "COS configuration is required")
			c.Error(errors.NewBadRequestError("COS配置信息不能为空"))
			return
		}
	case "minio":
		logger.Infof(ctx, "MinIO config: Bucket=%s, PathPrefix=%s", req.MinioBucketName, req.MinioPathPrefix)
		if req.MinioBucketName == "" {
			logger.Error(ctx, "MinIO configuration is required")
			c.Error(errors.NewBadRequestError("MinIO配置信息不能为空"))
			return
		}
	default:
		logger.Error(ctx, "Invalid storage type")
		c.Error(errors.NewBadRequestError("无效的存储类型"))
		return
	}

	logger.Infof(ctx, "VLM config: Model=%s, URL=%s, HasKey=%v, Type=%s",
		req.VLMModel, req.VLMBaseURL, req.VLMAPIKey != "", req.VLMInterfaceType)

	// Fetch the uploaded image file from the multipart form.
	file, header, err := c.Request.FormFile("image")
	if err != nil {
		logger.Error(ctx, "Failed to get uploaded image", err)
		c.Error(errors.NewBadRequestError("获取上传图片失败"))
		return
	}
	defer file.Close()

	// Only image content types are accepted.
	if !strings.HasPrefix(header.Header.Get("Content-Type"), "image/") {
		logger.Error(ctx, "Invalid file type, only images are allowed")
		c.Error(errors.NewBadRequestError("只允许上传图片文件"))
		return
	}

	// Reject files larger than 10MB.
	if header.Size > 10*1024*1024 {
		logger.Error(ctx, "File size too large")
		c.Error(errors.NewBadRequestError("图片文件大小不能超过10MB"))
		return
	}
	logger.Infof(ctx, "Processing image: %s, size: %d bytes", header.Filename, header.Size)

	// Parse the document-splitting settings, falling back to defaults on
	// missing or out-of-range values.
	chunkSize, err := strconv.Atoi(req.ChunkSize)
	if err != nil || chunkSize < 100 || chunkSize > 10000 {
		chunkSize = 1000
	}

	chunkOverlap, err := strconv.Atoi(req.ChunkOverlap)
	if err != nil || chunkOverlap < 0 || chunkOverlap >= chunkSize {
		chunkOverlap = 200
	}

	// Separators arrive as a JSON array string; fall back to the default
	// separator set when absent or malformed.
	var separators []string
	if req.SeparatorsRaw != "" {
		if err := json.Unmarshal([]byte(req.SeparatorsRaw), &separators); err != nil {
			separators = []string{"\n\n", "\n", "。", "!", "?", ";", ";"}
		}
	} else {
		separators = []string{"\n\n", "\n", "。", "!", "?", ";", ";"}
	}

	// Read the whole image into memory (bounded by the 10MB check above).
	imageContent, err := io.ReadAll(file)
	if err != nil {
		logger.Error(ctx, "Failed to read image file", err)
		c.Error(errors.NewBadRequestError("读取图片文件失败"))
		return
	}

	// Run the multimodal test and measure how long it takes.
	startTime := time.Now()
	result, err := h.testMultimodalWithDocReader(
		ctx,
		imageContent, header.Filename,
		chunkSize, chunkOverlap, separators, &req,
	)
	processingTime := time.Since(startTime).Milliseconds()

	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"vlm_model":    req.VLMModel,
			"vlm_base_url": req.VLMBaseURL,
			"filename":     header.Filename,
		})
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data": gin.H{
				"success":         false,
				"message":         err.Error(),
				"processing_time": processingTime,
			},
		})
		return
	}

	logger.Info(ctx, fmt.Sprintf("Multimodal test completed successfully in %dms", processingTime))

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"success":         true,
			"caption":         result["caption"],
			"ocr":             result["ocr"],
			"processing_time": processingTime,
		},
	})
}
|
||
|
||
// testMultimodalWithDocReader 调用docreader服务进行多模态处理
|
||
func (h *InitializationHandler) testMultimodalWithDocReader(
|
||
ctx context.Context,
|
||
imageContent []byte, filename string,
|
||
chunkSize, chunkOverlap int, separators []string,
|
||
req *testMultimodalForm,
|
||
) (map[string]string, error) {
|
||
// 获取文件扩展名
|
||
fileExt := ""
|
||
if idx := strings.LastIndex(filename, "."); idx != -1 {
|
||
fileExt = strings.ToLower(filename[idx+1:])
|
||
}
|
||
|
||
// 检查docreader服务配置
|
||
if h.docReaderClient == nil {
|
||
return nil, fmt.Errorf("DocReader service not configured")
|
||
}
|
||
|
||
// 构造请求
|
||
request := &proto.ReadFromFileRequest{
|
||
FileContent: imageContent,
|
||
FileName: filename,
|
||
FileType: fileExt,
|
||
ReadConfig: &proto.ReadConfig{
|
||
ChunkSize: int32(chunkSize),
|
||
ChunkOverlap: int32(chunkOverlap),
|
||
Separators: separators,
|
||
EnableMultimodal: true, // 启用多模态处理
|
||
VlmConfig: &proto.VLMConfig{
|
||
ModelName: req.VLMModel,
|
||
BaseUrl: req.VLMBaseURL,
|
||
ApiKey: req.VLMAPIKey,
|
||
InterfaceType: req.VLMInterfaceType,
|
||
},
|
||
},
|
||
RequestId: ctx.Value(types.RequestIDContextKey).(string),
|
||
}
|
||
|
||
// 设置对象存储配置(通用)
|
||
switch strings.ToLower(req.StorageType) {
|
||
case "cos":
|
||
request.ReadConfig.StorageConfig = &proto.StorageConfig{
|
||
Provider: proto.StorageProvider_COS,
|
||
Region: req.COSRegion,
|
||
BucketName: req.COSBucketName,
|
||
AccessKeyId: req.COSSecretID,
|
||
SecretAccessKey: req.COSSecretKey,
|
||
AppId: req.COSAppID,
|
||
PathPrefix: req.COSPathPrefix,
|
||
}
|
||
case "minio":
|
||
request.ReadConfig.StorageConfig = &proto.StorageConfig{
|
||
Provider: proto.StorageProvider_MINIO,
|
||
BucketName: req.MinioBucketName,
|
||
PathPrefix: req.MinioPathPrefix,
|
||
AccessKeyId: os.Getenv("MINIO_ACCESS_KEY_ID"),
|
||
SecretAccessKey: os.Getenv("MINIO_SECRET_ACCESS_KEY"),
|
||
}
|
||
}
|
||
|
||
// 调用docreader服务
|
||
response, err := h.docReaderClient.ReadFromFile(ctx, request)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("调用DocReader服务失败: %v", err)
|
||
}
|
||
|
||
if response.Error != "" {
|
||
return nil, fmt.Errorf("DocReader服务返回错误: %s", response.Error)
|
||
}
|
||
|
||
// 处理响应,提取Caption和OCR信息
|
||
result := make(map[string]string)
|
||
var allCaptions, allOCRTexts []string
|
||
|
||
for _, chunk := range response.Chunks {
|
||
if len(chunk.Images) > 0 {
|
||
for _, image := range chunk.Images {
|
||
if image.Caption != "" {
|
||
allCaptions = append(allCaptions, image.Caption)
|
||
}
|
||
if image.OcrText != "" {
|
||
allOCRTexts = append(allOCRTexts, image.OcrText)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 合并所有Caption和OCR结果
|
||
result["caption"] = strings.Join(allCaptions, "; ")
|
||
result["ocr"] = strings.Join(allOCRTexts, "; ")
|
||
|
||
return result, nil
|
||
}
|