package service

import (
	"context"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"geo/internal/ai_tool"
	"geo/internal/biz"
	"geo/internal/collect"
	"geo/internal/config"
	"geo/internal/data/impl"
	"geo/internal/data/model"
	"geo/internal/entitys"
	"geo/pkg"
	"geo/tmpl/dataTemp"
	"geo/tmpl/errcode"

	"github.com/gofiber/fiber/v2"
	"xorm.io/builder"
)

// CollectService exposes the HTTP handlers for collect records: listing them,
// starting a collection run across AI platforms, and generating an analysis
// report for a finished run.
type CollectService struct {
	cfg         *config.Config
	collectBiz  *biz.CollectBiz
	collect     *impl.CollectImpl
	collectTask *impl.CollectTaskImpl
	productBiz  *biz.ProductBiz
	authBiz     *biz.AuthBiz
}

// NewCollectService builds a CollectService from its dependencies.
func NewCollectService(
	cfg *config.Config,
	collectBiz *biz.CollectBiz,
	collect *impl.CollectImpl,
	collectTask *impl.CollectTaskImpl,
	authBiz *biz.AuthBiz,
	productBiz *biz.ProductBiz,
) *CollectService {
	return &CollectService{
		cfg:         cfg,
		collectBiz:  collectBiz,
		collect:     collect,
		collectTask: collectTask,
		authBiz:     authBiz,
		productBiz:  productBiz,
	}
}

// CollectWithTasks is a collect record together with its per-platform tasks.
type CollectWithTasks struct {
	model.Collect
	Tasks []model.CollectTask `json:"tasks"`
}

// CollectList validates the access token, pages through collect records
// (optionally filtered by product) and returns them as JSON together with the
// tasks that belong to each record.
func (c *CollectService) CollectList(ctx *fiber.Ctx, req *entitys.CollectListRequest) error {
	_, err := c.authBiz.ValidateAccessToken(ctx.UserContext(), req.AccessToken)
	if err != nil {
		return err
	}

	// Build the query condition.
	cond := builder.NewCond()
	if req.ProductId > 0 {
		cond = cond.And(builder.Eq{"product_id": req.ProductId})
	}

	// Fetch the requested page of collect records, newest first.
	var collects []model.Collect
	pageBo, err := c.collect.GetListToStruct(ctx.UserContext(), &cond, &dataTemp.ReqPageBo{Page: req.Page, Limit: req.Limit}, &collects, "created_at DESC")
	if err != nil {
		return err
	}

	// Replace the stored platform indexes with display names and gather the
	// collect codes needed for the batched task lookup.
	collectCodes := make([]string, 0, len(collects))
	for i := range collects {
		indexes := strings.Split(collects[i].Platform, ",")
		names := make([]string, 0, len(indexes))
		for _, p := range indexes {
			// Fall back to the raw identifier for platforms that are no
			// longer registered instead of reading a missing map entry.
			name := p
			if info, ok := collect.CollectorMap[p]; ok {
				name = info.Name
			}
			names = append(names, name)
		}
		collects[i].Platform = strings.Join(names, ",")
		collectCodes = append(collectCodes, collects[i].CollectCode)
	}

	// Load every related collect_task in a single query. A failure here is
	// logged only, so the list is still returned (without tasks).
	var tasks []model.CollectTask
	if len(collectCodes) > 0 {
		taskCond := builder.NewCond().And(builder.In("collect_code", collectCodes))
		if _, err := c.collectTask.GetListToStruct(ctx.UserContext(), &taskCond, nil, &tasks, "created_at ASC"); err != nil {
			log.Printf("查询 collect_task 失败: %v", err)
		}
	}

	// Group tasks by collect_code.
	tasksMap := make(map[string][]model.CollectTask)
	for _, task := range tasks {
		tasksMap[task.CollectCode] = append(tasksMap[task.CollectCode], task)
	}

	// Assemble the response payload.
	result := make([]*CollectWithTasks, 0, len(collects))
	for i := range collects {
		result = append(result, &CollectWithTasks{
			Collect: collects[i],
			Tasks:   tasksMap[collects[i].CollectCode],
		})
	}

	return ctx.JSON(fiber.Map{
		"list":     result,
		"total":    pageBo.Total,
		"page":     req.Page,
		"pageSize": req.Limit,
	})
}

// GetCollectPlatForms returns the name, index and icon of every entry in
// collect.CollectorMap as a JSON array.
func (c *CollectService) GetCollectPlatForms(ctx *fiber.Ctx) error {
	list := make([]map[string]string, 0, len(collect.CollectorMap))
	for _, v := range collect.CollectorMap {
		list = append(list, map[string]string{
			"name":  v.Name,
			"index": v.Platform,
			"icon":  v.Icon,
		})
	}
	return ctx.JSON(list)
}

// Collect validates the access token, stores a new collect record with
// status 0 and starts the collection run in a background goroutine. The HTTP
// response is sent immediately, before the run finishes.
func (c *CollectService) Collect(ctx *fiber.Ctx, req *entitys.ProductCollectRequest) error {
	_, err := c.authBiz.ValidateAccessToken(ctx.UserContext(), req.AccessToken)
	if err != nil {
		return err
	}

	collectCode := fmt.Sprintf("C%d_%d", req.ProductId, time.Now().UnixNano())
	collectData := &model.Collect{
		CollectCode: collectCode,
		ProductID:   req.ProductId,
		Keywords:    strings.Join(req.Keywords, ","),
		Platform:    strings.Join(req.PlatformIndex, ","),
		Question:    req.Question,
		CreatedAt:   time.Now(),
		Status:      0, // 0: processing
	}
	if err := c.collect.Add(ctx.UserContext(), collectData); err != nil {
		return err
	}

	// Run the collection asynchronously; it outlives this request.
	go c.doCollectAsync(collectCode, req.PlatformIndex, req.Question, req.Keywords)

	return ctx.JSON(fiber.Map{"message": "收录生成中"})
}

// updateCollectByCode writes fields to the collect row identified by
// collectCode and logs (rather than drops) any update error. It uses a
// background context because it is called from work that outlives the
// originating HTTP request.
func (c *CollectService) updateCollectByCode(collectCode string, fields map[string]interface{}) {
	if err := c.collect.UpdateByKey(context.Background(), "collect_code", collectCode, fields); err != nil {
		log.Printf("更新 collect 失败 [%s]: %v", collectCode, err)
	}
}

// doCollectAsync runs the collection for every platform concurrently and
// keeps the status/progress columns of the collect row up to date: status 1
// and progress 10 at start, progress 10..100 as platforms finish, then
// status 2 and progress 100 at the end. If the coordinator or any platform
// worker panics, status 3 is written instead.
func (c *CollectService) doCollectAsync(collectCode string, platforms []string, question string, keywords []string) {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("任务panic [%s]: %v", collectCode, r)
			c.updateCollectByCode(collectCode, map[string]interface{}{"status": 3})
		}
	}()

	totalPlatforms := len(platforms)
	if totalPlatforms == 0 {
		c.updateCollectByCode(collectCode, map[string]interface{}{"status": 2, "progress": 100})
		return
	}

	// Initial progress.
	c.updateCollectByCode(collectCode, map[string]interface{}{"status": 1, "progress": 10})

	var (
		completed int32
		panicked  int32
		wg        sync.WaitGroup
	)
	for _, platIndex := range platforms {
		wg.Add(1)
		go func(platIndex string) {
			defer wg.Done()
			// recover only works in the goroutine that panicked, so each
			// worker needs its own handler; without it a single failing
			// platform would take the whole process down.
			defer func() {
				if r := recover(); r != nil {
					atomic.StoreInt32(&panicked, 1)
					log.Printf("任务panic [%s]: %v", collectCode, r)
				}
			}()

			c.processOnePlatform(collectCode, platIndex, question, keywords)

			// Count this platform as finished and publish the new progress.
			// Intermediate writes from different workers are not ordered;
			// the final write after wg.Wait settles the value.
			done := atomic.AddInt32(&completed, 1)
			progress := 10 + int(float64(done)/float64(totalPlatforms)*90)
			c.updateCollectByCode(collectCode, map[string]interface{}{"progress": progress})
		}(platIndex)
	}
	wg.Wait()

	if atomic.LoadInt32(&panicked) != 0 {
		c.updateCollectByCode(collectCode, map[string]interface{}{"status": 3})
		return
	}

	// All platforms finished.
	c.updateCollectByCode(collectCode, map[string]interface{}{"status": 2, "progress": 100})
}

// processOnePlatform asks one AI platform the question and stores the answer
// as a collect_task row. It gives up waiting after 120 seconds. Failures are
// logged and otherwise ignored so that other platforms are unaffected.
func (c *CollectService) processOnePlatform(collectCode, platIndex, question string, keywords []string) {
	// Upper bound on how long we wait for the platform.
	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	// Skip platforms that are not registered.
	platformName, exist := collect.CollectorMap[platIndex]
	if !exist {
		log.Printf("未知平台: %s", platIndex)
		return
	}

	type result struct {
		data *collect.CollectResult
		err  error
	}
	// Buffered so the worker can always deliver its result and exit, even
	// when nobody is listening any more because of the timeout.
	resultChan := make(chan result, 1)

	go func() {
		// Turn a panic inside the collector into an ordinary error; it
		// would otherwise crash the process because no caller can recover
		// a panic raised in this goroutine.
		defer func() {
			if r := recover(); r != nil {
				resultChan <- result{err: fmt.Errorf("panic: %v", r)}
			}
		}()
		requestID := fmt.Sprintf("%s_%s", collectCode, platIndex)
		res, err := c.collectBiz.AskAIQuestion(platIndex, requestID, question, true, keywords)
		resultChan <- result{data: res, err: err}
	}()

	// Wait for the result or the timeout. AskAIQuestion takes no context, so
	// on timeout the call keeps running in the background and its result is
	// discarded.
	select {
	case <-ctx.Done():
		log.Printf("平台 %s 超时", platformName.Name)
		return
	case res := <-resultChan:
		if res.err != nil {
			log.Printf("平台 %s 失败: %v", platformName.Name, res.err)
			return
		}
		if res.data == nil {
			log.Printf("平台 %s 失败: %v", platformName.Name, "empty result")
			return
		}

		// Persist the answer. IsExposure is stored as 2 when exposed and 1
		// otherwise.
		ise := 1
		if res.data.IsExposure {
			ise = 2
		}
		task := &model.CollectTask{
			CollectCode:     collectCode,
			AiPlatformIndex: platIndex,
			ContentHTML:     res.data.Answer,
			ShareURL:        res.data.ShareLink,
			CreatedAt:       time.Now(),
			UpdatedAt:       time.Now(),
			IsExposure:      int32(ise),
			Status:          1,
		}
		if err := c.collectTask.Add(context.Background(), task); err != nil {
			log.Printf("保存失败: %v", err)
			return
		}
		log.Printf("平台 %s 完成", platformName.Name)
	}
}

// CollectAan generates the analysis report for one collect record: it builds
// a prompt from the record and its tasks, sends it to the Hsyq bot, writes the
// answer to a temporary markdown file, converts that to docx, uploads the
// docx and stores the resulting URL with report_status 3.
//
// NOTE(review): report_status is set to 2 before generation and is left at 2
// when a later step fails; confirm whether a failure value should be written.
func (c *CollectService) CollectAan(ctx *fiber.Ctx, req *entitys.CollectAnaRequest) error {
	userCtx := ctx.UserContext()

	var collectInfo model.Collect
	if err := c.collect.GetByKey(userCtx, c.collect.PrimaryKey(), req.CollectId, &collectInfo); err != nil {
		return err
	}
	if collectInfo.ID == 0 {
		return errcode.NotFound("为找到收录计划")
	}

	// Mark the report as being generated. Previously this error was
	// overwritten unseen; it is logged but not fatal, so report generation
	// still proceeds as before.
	if err := c.collect.UpdateByKey(userCtx, c.collect.PrimaryKey(), req.CollectId, map[string]interface{}{"report_status": 2}); err != nil {
		log.Printf("更新 report_status 失败: %v", err)
	}

	var tasks []model.CollectTask
	taskCond := builder.NewCond().And(builder.In("collect_code", collectInfo.CollectCode))
	if _, err := c.collectTask.GetListToStruct(userCtx, &taskCond, nil, &tasks, "created_at ASC"); err != nil {
		log.Printf("查询 collect_task 失败: %v", err)
	}

	mes := c.collectBiz.CreateAndPrompt(userCtx, &collectInfo, tasks)
	content, err := ai_tool.NewHsyq().RequestHsyqBot(userCtx, c.cfg.Hsyq.ApiKey, c.cfg.AiBot.CollectInfo, mes)
	if err != nil {
		return err
	}
	if content == nil {
		return fmt.Errorf("写入文件失败: %w", fmt.Errorf("empty report content"))
	}

	// The question is user input: strip path separators so it cannot point
	// outside MdDir or name a non-existent sub-directory.
	safeQuestion := strings.NewReplacer("/", "_", "\\", "_").Replace(collectInfo.Question)
	fileBaseName := fmt.Sprintf("%s分析报告_%d", safeQuestion, time.Now().UnixNano())
	fileName := fmt.Sprintf("%s.md", fileBaseName)
	mdAbs := filepath.Join(c.cfg.Sys.MdDir, fileName)

	// Create and fill the temporary markdown file. Cleanup is registered
	// only after creation succeeded.
	file, err := os.Create(mdAbs)
	if err != nil {
		return fmt.Errorf("创建文件失败: %w", err)
	}
	defer os.Remove(mdAbs)

	// Close before converting so the converter sees the complete file and
	// a failing Close is not silently lost.
	_, writeErr := file.WriteString(*content)
	closeErr := file.Close()
	if writeErr != nil {
		return fmt.Errorf("写入文件失败: %w", writeErr)
	}
	if closeErr != nil {
		return fmt.Errorf("写入文件失败: %w", closeErr)
	}

	docxPath, err := pkg.Md2wordFix(mdAbs, c.cfg.Sys.MdDir, nil)
	if err != nil {
		return err
	}
	docxName := fmt.Sprintf("%s.docx", fileBaseName)
	docxAbs := filepath.Join(docxPath, docxName)
	defer os.Remove(docxAbs)

	fileByte, err := pkg.ReadDocxToBytes(docxAbs)
	if err != nil {
		return err
	}
	url, err := c.productBiz.SourceUpload(userCtx, fileByte, docxName)
	if err != nil {
		return fmt.Errorf("上传文件失败: %w", err)
	}

	// Store the report URL and mark the report as finished.
	return c.collect.UpdateByKey(userCtx, c.collect.PrimaryKey(), req.CollectId, map[string]interface{}{"end_file": url, "report_status": 3})
}