geoGo/internal/service/collect.go

219 lines
5.5 KiB
Go

package service
import (
	"context"
	"fmt"
	"log"
	"sort"
	"strings"
	"sync"
	"time"

	"geo/internal/biz"
	"geo/internal/collect"
	"geo/internal/config"
	"geo/internal/data/impl"
	"geo/internal/data/model"
	"geo/internal/entitys"
	"geo/tmpl/dataTemp"

	"github.com/gofiber/fiber/v2"
	"xorm.io/builder"
)
// CollectService is the collection service: it serves the collect-related
// HTTP handlers in this file (listing collect records with their tasks,
// listing the available platforms, and starting a new collection run).
type CollectService struct {
	// cfg is the application configuration (not referenced in the part of
	// this file shown for review).
	cfg *config.Config
	// collectBiz performs the per-platform AI question (AskAIQuestion).
	collectBiz *biz.CollectBiz
	// collect is the data-access object for collect records
	// (GetListToStruct, Add, UpdateByKey).
	collect *impl.CollectImpl
	// collectTask is the data-access object for collect_task rows.
	collectTask *impl.CollectTaskImpl
	// authBiz validates request access tokens.
	authBiz *biz.AuthBiz
}
// NewCollectService builds a CollectService from its dependencies.
// It stores each argument as-is; no validation or copying is performed, and
// nil arguments are accepted (they will only fail when the corresponding
// dependency is used).
func NewCollectService(
	cfg *config.Config,
	collectBiz *biz.CollectBiz,
	collect *impl.CollectImpl,
	collectTask *impl.CollectTaskImpl,
	authBiz *biz.AuthBiz,
) *CollectService {
	svc := new(CollectService)
	svc.cfg = cfg
	svc.collectBiz = collectBiz
	svc.collect = collect
	svc.collectTask = collectTask
	svc.authBiz = authBiz
	return svc
}
// CollectWithTasks is a collect record together with its task list.
// It is the element type of the "list" returned by CollectList.
//
// model.Collect is embedded without a tag, so with encoding/json its exported
// fields appear directly in the same JSON object alongside "tasks".
type CollectWithTasks struct {
	model.Collect
	// Tasks holds the collect_task rows whose CollectCode matches the embedded
	// record; it is nil (JSON null) when the record has no tasks.
	Tasks []model.CollectTask `json:"tasks"`
}
// CollectList responds with one page of collect records, each joined with its
// collect_task rows.
//
// The request's access token is validated first and any validation error is
// returned unchanged. When req.ProductId > 0 the records are filtered by
// product_id. Records are ordered by created_at DESC and tasks by
// created_at ASC. Each record's comma-separated Platform indices are replaced
// by display names from collect.CollectorMap (see collectPlatformNames).
//
// Loading the records is required: its error is returned. Loading the tasks is
// best effort: a failure is logged and the records are returned without tasks.
//
// Response JSON: {"list": [...], "total": n, "page": req.Page, "pageSize": req.Limit}.
func (c *CollectService) CollectList(ctx *fiber.Ctx, req *entitys.CollectListRequest) error {
	if _, err := c.authBiz.ValidateAccessToken(ctx.UserContext(), req.AccessToken); err != nil {
		return err
	}

	// Build the query condition.
	cond := builder.NewCond()
	if req.ProductId > 0 {
		cond = cond.And(builder.Eq{"product_id": req.ProductId})
	}

	// Query one page of collect records.
	var collects []model.Collect
	pageBo, err := c.collect.GetListToStruct(ctx.UserContext(), &cond, &dataTemp.ReqPageBo{Page: req.Page, Limit: req.Limit}, &collects, "created_at DESC")
	if err != nil {
		return err
	}

	// Translate platform indices to names and gather the collect codes.
	// Index-based loop: the records are modified in place and not copied.
	collectCodes := make([]string, 0, len(collects))
	for i := range collects {
		collects[i].Platform = collectPlatformNames(collects[i].Platform)
		collectCodes = append(collectCodes, collects[i].CollectCode)
	}

	// Fetch all related collect_task rows in a single query.
	var tasks []model.CollectTask
	if len(collectCodes) > 0 {
		taskCond := builder.NewCond().And(builder.In("collect_code", collectCodes))
		if _, err := c.collectTask.GetListToStruct(ctx.UserContext(), &taskCond, nil, &tasks, "created_at ASC"); err != nil {
			// Best effort: the list is still useful without task details.
			log.Printf("查询 collect_task 失败: %v", err)
		}
	}

	// Group tasks by collect_code.
	tasksByCode := make(map[string][]model.CollectTask, len(collectCodes))
	for _, task := range tasks {
		tasksByCode[task.CollectCode] = append(tasksByCode[task.CollectCode], task)
	}

	// Assemble the response. The loop variable is named record so it does not
	// shadow the imported collect package.
	result := make([]*CollectWithTasks, 0, len(collects))
	for _, record := range collects {
		result = append(result, &CollectWithTasks{
			Collect: record,
			Tasks:   tasksByCode[record.CollectCode],
		})
	}

	return ctx.JSON(fiber.Map{
		"list":     result,
		"total":    pageBo.Total,
		"page":     req.Page,
		"pageSize": req.Limit,
	})
}

// collectPlatformNames converts a comma-separated list of platform indices
// (the format stored in model.Collect.Platform) into a comma-separated list of
// display names from collect.CollectorMap.
//
// An index that is not registered in the map is kept verbatim. Without that
// fallback an unknown index would silently become an empty name, or panic if
// the map stores pointers.
func collectPlatformNames(indices string) string {
	parts := strings.Split(indices, ",")
	for i, idx := range parts {
		if info, ok := collect.CollectorMap[idx]; ok {
			parts[i] = info.Name
		}
	}
	return strings.Join(parts, ",")
}
// GetCollectPlatForms responds with every platform registered in
// collect.CollectorMap, as a JSON array of {"name", "index", "icon"} objects.
//
// Entries are ordered by their CollectorMap key (lexicographic string order).
// Ranging over a Go map directly yields a random order, which made the list
// order change from request to request.
func (c *CollectService) GetCollectPlatForms(ctx *fiber.Ctx) error {
	keys := make([]string, 0, len(collect.CollectorMap))
	for k := range collect.CollectorMap {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	list := make([]map[string]string, 0, len(keys))
	for _, k := range keys {
		v := collect.CollectorMap[k]
		list = append(list, map[string]string{
			"name":  v.Name,
			"index": v.Platform,
			"icon":  v.Icon,
		})
	}
	return ctx.JSON(list)
}
// Collect creates a collection job for req.ProductId and starts it in the
// background.
//
// After validating req.AccessToken, it inserts a model.Collect row keyed by a
// generated code of the form "C<productId>_<unix nanoseconds>". It then launches
// doCollectAsync on its own goroutine and immediately responds with
// {"message": "收录生成中"} ("collection in progress"). A token-validation or
// insert error is returned as-is, and no background work is started.
//
// NOTE(review): req.PlatformIndex and req.Question are handed to a goroutine
// that outlives this handler. That is only safe if they do not alias Fiber's
// reusable request buffers. It is fine for encoding/json-decoded bodies;
// confirm how req is parsed.
func (c *CollectService) Collect(ctx *fiber.Ctx, req *entitys.ProductCollectRequest) error {
	userCtx := ctx.UserContext()
	if _, err := c.authBiz.ValidateAccessToken(userCtx, req.AccessToken); err != nil {
		return err
	}

	code := fmt.Sprintf("C%d_%d", req.ProductId, time.Now().UnixNano())

	record := new(model.Collect)
	record.CollectCode = code
	record.ProductID = req.ProductId
	record.Keywords = strings.Join(req.Keywords, ",")
	record.Platform = strings.Join(req.PlatformIndex, ",")
	record.Question = req.Question
	record.CreatedAt = time.Now()

	if err := c.collect.Add(userCtx, record); err != nil {
		return err
	}

	// Fire and forget: the HTTP response does not wait for the collection run.
	go c.doCollectAsync(code, req.PlatformIndex, req.Question)

	return ctx.JSON(fiber.Map{"message": "收录生成中"})
}
// doCollectAsync runs one collection job in the background.
//
// It asks each platform in platforms the given question concurrently, one
// goroutine per platform. It saves one model.CollectTask for every successful
// answer, then marks the collect row identified by collectCode with status 2.
//
// Per-platform failures (unknown index, AskAIQuestion error, panic) are logged
// and only that platform is skipped. The status update runs unconditionally,
// even when every platform failed or saving failed.
//
// The 240s deadline bounds only the database write. AskAIQuestion takes no
// context, so the collection phase cannot be cancelled from here. The deadline
// is therefore started after all platforms have finished. Starting it up front
// (as before) let a slow run hand an already-expired context to the insert,
// which dropped every result.
func (c *CollectService) doCollectAsync(collectCode string, platforms []string, question string) {
	defer func() {
		// The meaning of status 2 is defined outside this file; the update's
		// result is not checked here.
		c.collect.UpdateByKey(context.Background(), "collect_code", collectCode, map[string]interface{}{"status": 2})
	}()

	var (
		wg    sync.WaitGroup
		mu    sync.Mutex
		tasks = make([]*model.CollectTask, 0, len(platforms))
	)
	for _, platIndex := range platforms {
		wg.Add(1)
		go func(platIndex string) {
			defer wg.Done()
			// A panic in a background goroutine would crash the whole process;
			// contain it to this platform.
			defer func() {
				if r := recover(); r != nil {
					log.Printf("平台 %s 收集时发生 panic: %v", platIndex, r)
				}
			}()

			task := c.collectOnPlatform(collectCode, platIndex, question)
			if task == nil {
				return
			}
			mu.Lock()
			tasks = append(tasks, task)
			mu.Unlock()
		}(platIndex)
	}
	wg.Wait()

	if len(tasks) == 0 {
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), 240*time.Second)
	defer cancel()
	if err := c.collectTask.Add(ctx, tasks); err != nil {
		log.Printf("保存收集任务失败: %v", err)
	}
}

// collectOnPlatform asks a single platform the question and converts the
// answer into an unsaved model.CollectTask. It returns nil, after logging, when
// platIndex is not registered in collect.CollectorMap or AskAIQuestion fails.
func (c *CollectService) collectOnPlatform(collectCode, platIndex, question string) *model.CollectTask {
	platform, ok := collect.CollectorMap[platIndex]
	if !ok {
		// %s, not %d: platIndex is a string.
		log.Printf("未知的平台索引: %s", platIndex)
		return nil
	}

	requestID := fmt.Sprintf("%s_%s", collectCode, platIndex)
	result, err := c.collectBiz.AskAIQuestion(platIndex, requestID, question, true)
	if err != nil {
		// Log the platform's display name rather than the whole map value.
		log.Printf("平台 %s 收集失败: %v", platform.Name, err)
		return nil
	}

	// IsExposure is persisted as 2 when result.IsExposure is true, otherwise 1.
	exposure := int32(1)
	if result.IsExposure {
		exposure = 2
	}
	now := time.Now()
	return &model.CollectTask{
		CollectCode:     collectCode,
		AiPlatformIndex: platIndex,
		ContentHTML:     result.Answer,
		ShareURL:        result.ShareLink,
		CreatedAt:       now,
		UpdatedAt:       now,
		IsExposure:      exposure,
		Status:          1,
	}
}