package service import ( "context" "fmt" "geo/internal/biz" "geo/internal/collect" "geo/internal/config" "geo/internal/data/impl" "geo/internal/data/model" "geo/internal/entitys" "geo/tmpl/dataTemp" "github.com/gofiber/fiber/v2" "log" "strings" "sync" "time" "xorm.io/builder" ) // CollectService 收集服务 type CollectService struct { cfg *config.Config collectBiz *biz.CollectBiz collect *impl.CollectImpl collectTask *impl.CollectTaskImpl authBiz *biz.AuthBiz } // NewCollectService 创建收集服务 func NewCollectService( cfg *config.Config, collectBiz *biz.CollectBiz, collect *impl.CollectImpl, collectTask *impl.CollectTaskImpl, authBiz *biz.AuthBiz, ) *CollectService { return &CollectService{ cfg: cfg, collectBiz: collectBiz, collect: collect, collectTask: collectTask, authBiz: authBiz, } } // CollectWithTasks 收集记录及其任务列表 type CollectWithTasks struct { model.Collect Tasks []model.CollectTask `json:"tasks"` } // CollectList 获取收集列表及对应的任务详情 func (c *CollectService) CollectList(ctx *fiber.Ctx, req *entitys.CollectListRequest) error { _, err := c.authBiz.ValidateAccessToken(ctx.UserContext(), req.AccessToken) if err != nil { return err } // 构建查询条件 cond := builder.NewCond() if req.ProductId > 0 { cond = cond.And(builder.Eq{"product_id": req.ProductId}) } // 查询 collect 列表 var collects []model.Collect pageBo, err := c.collect.GetListToStruct(ctx.UserContext(), &cond, &dataTemp.ReqPageBo{Page: req.Page, Limit: req.Limit}, &collects, "created_at DESC") if err != nil { return err } // 提取所有的 collect_code collectCodes := make([]string, 0, len(collects)) for k, colt := range collects { var platName = make([]string, 0) plat := strings.Split(colt.Platform, ",") for _, p := range plat { name := collect.CollectorMap[p].Name platName = append(platName, name) } collects[k].Platform = strings.Join(platName, ",") collectCodes = append(collectCodes, colt.CollectCode) } // 批量查询所有相关的 collect_task var tasks []model.CollectTask if len(collectCodes) > 0 { taskCond := builder.NewCond().And(builder.In("collect_code", collectCodes)) _, err := c.collectTask.GetListToStruct(ctx.UserContext(), &taskCond, nil, &tasks, "created_at ASC") if err != nil { log.Printf("查询 collect_task 失败: %v", err) } } // 按 collect_code 分组 tasks tasksMap := make(map[string][]model.CollectTask) for _, task := range tasks { tasksMap[task.CollectCode] = append(tasksMap[task.CollectCode], task) } // 组装返回数据 result := make([]*CollectWithTasks, 0, len(collects)) for _, collect := range collects { cwt := &CollectWithTasks{ Collect: collect, Tasks: tasksMap[collect.CollectCode], } result = append(result, cwt) } return ctx.JSON(fiber.Map{ "list": result, "total": pageBo.Total, "page": req.Page, "pageSize": req.Limit, }) } func (c *CollectService) GetCollectPlatForms(ctx *fiber.Ctx) error { var list = make([]map[string]string, 0, len(collect.CollectorMap)) for _, v := range collect.CollectorMap { list = append(list, map[string]string{ "name": v.Name, "index": v.Platform, "icon": v.Icon, }) } return ctx.JSON(list) } // Collect 创建收集任务 func (c *CollectService) Collect(ctx *fiber.Ctx, req *entitys.ProductCollectRequest) error { _, err := c.authBiz.ValidateAccessToken(ctx.UserContext(), req.AccessToken) if err != nil { return err } collectCode := fmt.Sprintf("C%d_%d", req.ProductId, time.Now().UnixNano()) collectData := &model.Collect{ CollectCode: collectCode, ProductID: req.ProductId, Keywords: strings.Join(req.Keywords, ","), Platform: strings.Join(req.PlatformIndex, ","), Question: req.Question, CreatedAt: time.Now(), } err = c.collect.Add(ctx.UserContext(), collectData) if err != nil { return err } go c.doCollectAsync(collectCode, req.PlatformIndex, req.Question, req.Keywords) return ctx.JSON(fiber.Map{"message": "收录生成中"}) } // doCollectAsync 异步执行收集任务 func (c *CollectService) doCollectAsync(collectCode string, platforms []string, question string, keywords []string) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*240) defer cancel() defer func() { c.collect.UpdateByKey(context.Background(), "collect_code", collectCode, map[string]interface{}{"status": 2}) }() var wg sync.WaitGroup var mu sync.Mutex tasks := make([]*model.CollectTask, 0, len(platforms)) for _, platIndex := range platforms { wg.Add(1) go func(platIndex string) { defer wg.Done() platformName, exist := collect.CollectorMap[platIndex] if !exist { log.Printf("未知的平台索引: %s", platIndex) return } requestID := fmt.Sprintf("%s_%s", collectCode, platIndex) result, err := c.collectBiz.AskAIQuestion(platIndex, requestID, question, keywords, true) if err != nil { log.Printf("平台 %s 收集失败: %v", platformName.Name, err) return } ise := 1 if result.IsExposure { ise = 2 } task := &model.CollectTask{ CollectCode: collectCode, AiPlatformIndex: platIndex, ContentHTML: result.Answer, ShareURL: result.ShareLink, CreatedAt: time.Now(), UpdatedAt: time.Now(), IsExposure: int32(ise), Status: 1, } mu.Lock() tasks = append(tasks, task) mu.Unlock() }(platIndex) } wg.Wait() if len(tasks) > 0 { if err := c.collectTask.Add(ctx, tasks); err != nil { log.Printf("保存收集任务失败: %v", err) } } }