219 lines
5.6 KiB
Go
219 lines
5.6 KiB
Go
package service
|
|
|
|
import (
	"context"
	"fmt"
	"log"
	"sort"
	"strings"
	"sync"
	"time"

	"geo/internal/biz"
	"geo/internal/collect"
	"geo/internal/config"
	"geo/internal/data/impl"
	"geo/internal/data/model"
	"geo/internal/entitys"
	"geo/tmpl/dataTemp"

	"github.com/gofiber/fiber/v2"
	"xorm.io/builder"
)
|
|
|
|
// CollectService 收集服务
|
|
type CollectService struct {
|
|
cfg *config.Config
|
|
collectBiz *biz.CollectBiz
|
|
collect *impl.CollectImpl
|
|
collectTask *impl.CollectTaskImpl
|
|
authBiz *biz.AuthBiz
|
|
}
|
|
|
|
// NewCollectService 创建收集服务
|
|
func NewCollectService(
|
|
cfg *config.Config,
|
|
collectBiz *biz.CollectBiz,
|
|
collect *impl.CollectImpl,
|
|
collectTask *impl.CollectTaskImpl,
|
|
authBiz *biz.AuthBiz,
|
|
) *CollectService {
|
|
return &CollectService{
|
|
cfg: cfg,
|
|
collectBiz: collectBiz,
|
|
collect: collect,
|
|
collectTask: collectTask,
|
|
authBiz: authBiz,
|
|
}
|
|
}
|
|
|
|
// CollectWithTasks 收集记录及其任务列表
|
|
type CollectWithTasks struct {
|
|
model.Collect
|
|
Tasks []model.CollectTask `json:"tasks"`
|
|
}
|
|
|
|
// CollectList 获取收集列表及对应的任务详情
|
|
func (c *CollectService) CollectList(ctx *fiber.Ctx, req *entitys.CollectListRequest) error {
|
|
_, err := c.authBiz.ValidateAccessToken(ctx.UserContext(), req.AccessToken)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 构建查询条件
|
|
cond := builder.NewCond()
|
|
if req.ProductId > 0 {
|
|
cond = cond.And(builder.Eq{"product_id": req.ProductId})
|
|
}
|
|
|
|
// 查询 collect 列表
|
|
var collects []model.Collect
|
|
pageBo, err := c.collect.GetListToStruct(ctx.UserContext(), &cond, &dataTemp.ReqPageBo{Page: req.Page, Limit: req.Limit}, &collects, "created_at DESC")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 提取所有的 collect_code
|
|
collectCodes := make([]string, 0, len(collects))
|
|
for k, colt := range collects {
|
|
var platName = make([]string, 0)
|
|
plat := strings.Split(colt.Platform, ",")
|
|
for _, p := range plat {
|
|
name := collect.CollectorMap[p].Name
|
|
platName = append(platName, name)
|
|
}
|
|
collects[k].Platform = strings.Join(platName, ",")
|
|
collectCodes = append(collectCodes, colt.CollectCode)
|
|
}
|
|
|
|
// 批量查询所有相关的 collect_task
|
|
var tasks []model.CollectTask
|
|
if len(collectCodes) > 0 {
|
|
taskCond := builder.NewCond().And(builder.In("collect_code", collectCodes))
|
|
_, err := c.collectTask.GetListToStruct(ctx.UserContext(), &taskCond, nil, &tasks, "created_at ASC")
|
|
if err != nil {
|
|
log.Printf("查询 collect_task 失败: %v", err)
|
|
}
|
|
}
|
|
|
|
// 按 collect_code 分组 tasks
|
|
tasksMap := make(map[string][]model.CollectTask)
|
|
for _, task := range tasks {
|
|
|
|
tasksMap[task.CollectCode] = append(tasksMap[task.CollectCode], task)
|
|
}
|
|
|
|
// 组装返回数据
|
|
result := make([]*CollectWithTasks, 0, len(collects))
|
|
for _, collect := range collects {
|
|
cwt := &CollectWithTasks{
|
|
Collect: collect,
|
|
Tasks: tasksMap[collect.CollectCode],
|
|
}
|
|
result = append(result, cwt)
|
|
}
|
|
|
|
return ctx.JSON(fiber.Map{
|
|
"list": result,
|
|
"total": pageBo.Total,
|
|
"page": req.Page,
|
|
"pageSize": req.Limit,
|
|
})
|
|
}
|
|
|
|
func (c *CollectService) GetCollectPlatForms(ctx *fiber.Ctx) error {
|
|
var list = make([]map[string]string, 0, len(collect.CollectorMap))
|
|
for _, v := range collect.CollectorMap {
|
|
list = append(list, map[string]string{
|
|
"name": v.Name,
|
|
"index": v.Platform,
|
|
"icon": v.Icon,
|
|
})
|
|
}
|
|
return ctx.JSON(list)
|
|
}
|
|
|
|
// Collect 创建收集任务
|
|
func (c *CollectService) Collect(ctx *fiber.Ctx, req *entitys.ProductCollectRequest) error {
|
|
_, err := c.authBiz.ValidateAccessToken(ctx.UserContext(), req.AccessToken)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
collectCode := fmt.Sprintf("C%d_%d", req.ProductId, time.Now().UnixNano())
|
|
collectData := &model.Collect{
|
|
CollectCode: collectCode,
|
|
ProductID: req.ProductId,
|
|
Keywords: strings.Join(req.Keywords, ","),
|
|
Platform: strings.Join(req.PlatformIndex, ","),
|
|
Question: req.Question,
|
|
CreatedAt: time.Now(),
|
|
}
|
|
|
|
err = c.collect.Add(ctx.UserContext(), collectData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
go c.doCollectAsync(collectCode, req.PlatformIndex, req.Question, req.Keywords)
|
|
|
|
return ctx.JSON(fiber.Map{"message": "收录生成中"})
|
|
}
|
|
|
|
// doCollectAsync 异步执行收集任务
|
|
func (c *CollectService) doCollectAsync(collectCode string, platforms []string, question string, keywords []string) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*240)
|
|
defer cancel()
|
|
defer func() {
|
|
c.collect.UpdateByKey(context.Background(), "collect_code", collectCode, map[string]interface{}{"status": 2})
|
|
}()
|
|
var wg sync.WaitGroup
|
|
var mu sync.Mutex
|
|
tasks := make([]*model.CollectTask, 0, len(platforms))
|
|
|
|
for _, platIndex := range platforms {
|
|
wg.Add(1)
|
|
|
|
go func(platIndex string) {
|
|
defer wg.Done()
|
|
|
|
platformName, exist := collect.CollectorMap[platIndex]
|
|
if !exist {
|
|
log.Printf("未知的平台索引: %s", platIndex)
|
|
return
|
|
}
|
|
|
|
requestID := fmt.Sprintf("%s_%s", collectCode, platIndex)
|
|
result, err := c.collectBiz.AskAIQuestion(platIndex, requestID, question, keywords, true)
|
|
if err != nil {
|
|
log.Printf("平台 %s 收集失败: %v", platformName.Name, err)
|
|
return
|
|
}
|
|
ise := 1
|
|
if result.IsExposure {
|
|
ise = 2
|
|
}
|
|
task := &model.CollectTask{
|
|
CollectCode: collectCode,
|
|
AiPlatformIndex: platIndex,
|
|
ContentHTML: result.Answer,
|
|
ShareURL: result.ShareLink,
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
IsExposure: int32(ise),
|
|
Status: 1,
|
|
}
|
|
|
|
mu.Lock()
|
|
tasks = append(tasks, task)
|
|
mu.Unlock()
|
|
}(platIndex)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
if len(tasks) > 0 {
|
|
if err := c.collectTask.Add(ctx, tasks); err != nil {
|
|
log.Printf("保存收集任务失败: %v", err)
|
|
}
|
|
}
|
|
}
|