87 lines
2.3 KiB
Go
87 lines
2.3 KiB
Go
package biz
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"geo/internal/collect"
|
|
"geo/internal/config"
|
|
"log"
|
|
)
|
|
|
|
// CollectBiz AI收集业务层
|
|
type CollectBiz struct {
|
|
manager *collect.CollectManager
|
|
config *config.Config
|
|
logger *log.Logger
|
|
}
|
|
|
|
// NewCollectBiz 创建AI收集业务实例
|
|
func NewCollectBiz(ctx context.Context, cfg *config.Config, logger *log.Logger) *CollectBiz {
|
|
manager := collect.NewCollectManager(ctx, cfg, logger)
|
|
return &CollectBiz{
|
|
manager: manager,
|
|
config: cfg,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// AskAIQuestion 向指定AI平台提问
|
|
// platform: 平台类型 (wenxin, deepseek, doubao, qianwen)
|
|
// userIndex: 用户索引
|
|
// platIndex: 平台索引
|
|
// requestID: 请求ID
|
|
// question: 问题内容
|
|
// headless: 是否无头模式
|
|
func (b *CollectBiz) AskAIQuestion(platform string, userIndex, platIndex, requestID, question string, headless bool) (string, error) {
|
|
params := &collect.CollectParams{
|
|
Headless: headless,
|
|
UserIndex: userIndex,
|
|
PlatIndex: platIndex,
|
|
RequestID: requestID,
|
|
Platform: platform,
|
|
}
|
|
|
|
answer, err := b.manager.AskQuestion(platform, params, question)
|
|
if err != nil {
|
|
return "", fmt.Errorf("向%s提问失败: %w", platform, err)
|
|
}
|
|
|
|
return answer, nil
|
|
}
|
|
|
|
// WaitAILogin 等待AI平台登录
|
|
func (b *CollectBiz) WaitAILogin(platform string, userIndex, platIndex, requestID string, headless bool) (bool, string) {
|
|
params := &collect.CollectParams{
|
|
Headless: headless,
|
|
UserIndex: userIndex,
|
|
PlatIndex: platIndex,
|
|
RequestID: requestID,
|
|
Platform: platform,
|
|
}
|
|
|
|
return b.manager.WaitLogin(platform, params)
|
|
}
|
|
|
|
// ListAIPlatforms 列出所有支持的AI平台
|
|
func (b *CollectBiz) ListAIPlatforms() []string {
|
|
return b.manager.ListPlatforms()
|
|
}
|
|
|
|
// AskMultipleAI 向多个AI平台提问并收集答案
|
|
func (b *CollectBiz) AskMultipleAI(platforms []string, userIndex, requestID, question string, headless bool) map[string]string {
|
|
results := make(map[string]string)
|
|
|
|
for _, platform := range platforms {
|
|
platIndex := platform // 默认使用platform作为platIndex
|
|
answer, err := b.AskAIQuestion(platform, userIndex, platIndex, requestID+"_"+platform, question, headless)
|
|
if err != nil {
|
|
b.logger.Printf("向%s提问失败: %v", platform, err)
|
|
results[platform] = fmt.Sprintf("错误: %v", err)
|
|
} else {
|
|
results[platform] = answer
|
|
}
|
|
}
|
|
|
|
return results
|
|
}
|