geoGo/internal/biz/ai_collect.go

112 lines
2.8 KiB
Go

package biz
import (
"context"
"fmt"
"geo/internal/collect"
"geo/internal/config"
"geo/internal/data/model"
"geo/pkg"
volmodle "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
"github.com/volcengine/volcengine-go-sdk/volcengine"
"strings"
"github.com/gofiber/fiber/v2/log"
)
// CollectBiz is the business-layer type for AI collection. It bundles the
// collect manager with the configuration and logger it was created with.
type CollectBiz struct {
// manager performs the platform operations (asking questions, waiting for
// login, listing platforms) on behalf of this type's methods.
manager *collect.CollectManager
// config is the configuration supplied at construction time.
config *config.Config
// logger is the logger supplied at construction time.
logger log.AllLogger
}
// NewCollectBiz builds a CollectBiz. It creates a collect manager from ctx,
// cfg and logger, and also keeps cfg and logger on the returned instance.
func NewCollectBiz(ctx context.Context, cfg *config.Config, logger log.AllLogger) *CollectBiz {
	b := &CollectBiz{config: cfg, logger: logger}
	b.manager = collect.NewCollectManager(ctx, cfg, logger)
	return b
}
// AskAIQuestion sends question to the given AI platform through the collect
// manager and returns the collected result.
//
// Parameters:
//   - platform: platform type (wenxin, deepseek, doubao, qianwen)
//   - requestID: request ID
//   - question: text of the question
//   - headless: whether to run in headless mode
//   - keywords: forwarded to the collector as CollectParams.KeyWords
//
// When the manager reports an error, it is wrapped together with the platform
// name and returned with a nil result.
func (b *CollectBiz) AskAIQuestion(platform string, requestID, question string, headless bool, keywords []string) (*collect.CollectResult, error) {
	answer, askErr := b.manager.AskQuestion(platform, &collect.CollectParams{
		Headless:  headless,
		RequestID: requestID,
		Platform:  platform,
		KeyWords:  keywords,
	}, question)
	if askErr == nil {
		return answer, nil
	}
	return nil, fmt.Errorf("向%s提问失败: %w", platform, askErr)
}
// WaitAILogin waits for login on the given AI platform by delegating to the
// collect manager's WaitLogin, and returns that call's two results unchanged.
// NOTE(review): presumably the bool reports login success and the string
// carries an accompanying message — confirm against CollectManager.WaitLogin.
func (b *CollectBiz) WaitAILogin(platform string, requestID string, headless bool) (bool, string) {
	return b.manager.WaitLogin(platform, &collect.CollectParams{
		Headless:  headless,
		RequestID: requestID,
		Platform:  platform,
	})
}
// ListAIPlatforms returns the supported AI platforms as reported by the
// collect manager.
func (b *CollectBiz) ListAIPlatforms() []string {
	platforms := b.manager.ListPlatforms()
	return platforms
}
// AnaProject is the JSON payload that CreateAndPrompt serializes into the chat
// message: one question, its keywords, and the per-platform tasks.
type AnaProject struct {
// Ques is the question text.
Ques string `json:"ques"`
// Keywords are the keywords associated with the question.
Keywords []string `json:"keywords"`
// Tasks holds one entry per collect task.
Tasks []*Task `json:"tasks"`
}
// Task is one platform's entry inside AnaProject.
type Task struct {
// ContentHtml is the content HTML copied from the collect task.
ContentHtml string `json:"content_html"`
// PlatName is the platform name looked up in collect.CollectorMap.
PlatName string `json:"plat_name"`
// IsExposure is true when the collect task's IsExposure value equals 2.
// NOTE: this JSON key is camelCase, unlike the snake_case keys of the
// sibling fields; it is part of the emitted payload, so it is left as-is.
IsExposure bool `json:"isExposure"`
}
// CreateAndPrompt builds the chat messages used to analyse one collected
// question.
//
// It assembles an AnaProject from collectInfo (question and comma-separated
// keywords) and tasks (content HTML, platform name and exposure flag per
// task), serializes it with pkg.JsonStringIgonErr, and returns the result as a
// single user-role message.
//
// Keywords are split on ",", trimmed of surrounding whitespace, and empty
// entries are dropped, so an empty Keywords string yields an empty list
// instead of a list holding one empty string.
//
// ctx is currently unused. collectInfo must be non-nil.
func (b *CollectBiz) CreateAndPrompt(ctx context.Context, collectInfo *model.Collect, tasks []model.CollectTask) []*volmodle.ChatCompletionMessage {
	// strings.Split("", ",") returns [""], so the split result is filtered
	// rather than used directly. The slice is kept non-nil so that a standard
	// JSON encoder emits [] rather than null when no keywords remain.
	rawKeywords := strings.Split(collectInfo.Keywords, ",")
	keywords := make([]string, 0, len(rawKeywords))
	for _, kw := range rawKeywords {
		if kw = strings.TrimSpace(kw); kw != "" {
			keywords = append(keywords, kw)
		}
	}

	items := make([]*Task, 0, len(tasks))
	for i := range tasks {
		// Index instead of ranging by value to avoid copying each task.
		t := &tasks[i]
		items = append(items, &Task{
			ContentHtml: t.ContentHTML,
			PlatName:    collect.CollectorMap[t.AiPlatformIndex].Name,
			// NOTE(review): 2 appears to be the stored value meaning
			// "exposed" — confirm against model.CollectTask.
			IsExposure: t.IsExposure == 2,
		})
	}

	payload := pkg.JsonStringIgonErr(&AnaProject{
		Ques:     collectInfo.Question,
		Keywords: keywords,
		Tasks:    items,
	})

	return []*volmodle.ChatCompletionMessage{
		{
			Role: volmodle.ChatMessageRoleUser,
			Content: &volmodle.ChatCompletionMessageContent{
				StringValue: volcengine.String(payload),
			},
		},
	}
}