// geoGo/internal/biz/ai_collect.go (87 lines, 2.3 KiB, Go)

package biz
import (
"context"
"fmt"
"geo/internal/collect"
"geo/internal/config"
"github.com/gofiber/fiber/v2/log"
)
// CollectBiz is the business-layer entry point for AI answer collection.
// Its methods delegate to a collect.CollectManager; the configuration and
// logger handed to the constructor are kept alongside it.
type CollectBiz struct {
manager *collect.CollectManager // manager every platform operation is delegated to
config *config.Config // configuration supplied at construction time
logger log.AllLogger // logger used to report per-platform failures
}
// NewCollectBiz creates a CollectBiz.
//
// A collect.CollectManager is built from ctx, cfg and logger and stored on
// the returned value together with cfg and logger themselves.
func NewCollectBiz(ctx context.Context, cfg *config.Config, logger log.AllLogger) *CollectBiz {
	b := &CollectBiz{
		config: cfg,
		logger: logger,
	}
	b.manager = collect.NewCollectManager(ctx, cfg, logger)
	return b
}
// AskAIQuestion sends a question to a single AI platform through the
// collect manager.
//
// Parameters:
//   - platform: platform identifier, e.g. wenxin, deepseek, doubao, qianwen.
//   - requestID: identifier of this request, forwarded in the collect params.
//   - question: the question text.
//   - headless: whether to run in headless mode, forwarded in the collect params.
//
// On success the manager's result is returned unchanged. On failure the
// result is nil and the manager's error is wrapped (with %w) in a message
// that names the platform.
func (b *CollectBiz) AskAIQuestion(platform string, requestID, question string, headless bool) (*collect.CollectResult, error) {
	answer, askErr := b.manager.AskQuestion(platform, &collect.CollectParams{
		Headless:  headless,
		RequestID: requestID,
		Platform:  platform,
	}, question)
	if askErr == nil {
		return answer, nil
	}
	return nil, fmt.Errorf("向%s提问失败: %w", platform, askErr)
}
// WaitAILogin waits for login on the given AI platform by delegating to the
// collect manager's WaitLogin, passing along the request ID and headless flag.
//
// The (bool, string) pair is returned exactly as the manager produced it.
// NOTE(review): presumably the bool reports whether login succeeded and the
// string carries an accompanying message — confirm against
// collect.CollectManager.WaitLogin.
func (b *CollectBiz) WaitAILogin(platform string, requestID string, headless bool) (bool, string) {
	var loginParams collect.CollectParams
	loginParams.Platform = platform
	loginParams.RequestID = requestID
	loginParams.Headless = headless
	return b.manager.WaitLogin(platform, &loginParams)
}
// ListAIPlatforms returns the list of supported AI platforms as reported by
// the collect manager.
func (b *CollectBiz) ListAIPlatforms() []string {
	names := b.manager.ListPlatforms()
	return names
}
// AskMultipleAI asks the same question on several AI platforms, one after
// another, and collects one result per platform.
//
// Parameters:
//   - platforms: platform identifiers to query; a nil or empty slice yields
//     an empty, non-nil map.
//   - requestID: base request ID; each platform gets "<requestID>_<platform>".
//   - question: the question text.
//   - headless: whether to run in headless mode.
//
// The returned map is keyed by platform name, so a platform listed more than
// once keeps only its last result. A failure on one platform does not stop
// the others: it is logged and recorded as a placeholder result whose Answer
// holds the error text and whose ShareLink is empty.
func (b *CollectBiz) AskMultipleAI(platforms []string, requestID, question string, headless bool) map[string]*collect.CollectResult {
	// The number of entries is bounded by len(platforms), so size the map once.
	results := make(map[string]*collect.CollectResult, len(platforms))

	for _, platform := range platforms {
		// Give every platform its own request ID derived from the base ID.
		platformRequestID := requestID + "_" + platform

		result, err := b.AskAIQuestion(platform, platformRequestID, question, headless)
		if err != nil {
			// err already names the platform and the failed action (it is
			// wrapped by AskAIQuestion), so repeating that prefix here would
			// duplicate it; log the request ID for traceability instead.
			b.logger.Errorf("requestID=%s: %v", platformRequestID, err)

			// Record a placeholder so the caller still sees an entry for
			// this platform.
			results[platform] = &collect.CollectResult{
				Answer:    fmt.Sprintf("错误: %v", err),
				ShareLink: "",
			}
			continue
		}

		results[platform] = result
	}

	return results
}