geoGo/internal/service/collect.go

336 lines
9.2 KiB
Go

package service
import (
	"context"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gofiber/fiber/v2"
	"xorm.io/builder"

	"geo/internal/ai_tool"
	"geo/internal/biz"
	"geo/internal/collect"
	"geo/internal/config"
	"geo/internal/data/impl"
	"geo/internal/data/model"
	"geo/internal/entitys"
	"geo/pkg"
	"geo/tmpl/dataTemp"
	"geo/tmpl/errcode"
)
// CollectService is the HTTP-facing service for collect runs: it lists them,
// starts new ones in the background and builds analysis reports from their
// per-platform results.
type CollectService struct {
	// cfg is the application configuration (AI credentials, bot settings and
	// the working directory used for generated report files).
	cfg *config.Config
	// collectBiz asks the AI platforms and builds the report prompt.
	collectBiz *biz.CollectBiz
	// collect is the data-access object for collect records.
	collect *impl.CollectImpl
	// collectTask is the data-access object for per-platform task records.
	collectTask *impl.CollectTaskImpl
	// productBiz is used to upload the generated report file.
	productBiz *biz.ProductBiz
	// authBiz validates access tokens.
	authBiz *biz.AuthBiz
}
// NewCollectService builds a CollectService from the given configuration,
// business-layer helpers and data-access objects. The arguments are stored
// as-is; nothing is validated or copied.
func NewCollectService(
	cfg *config.Config,
	collectBiz *biz.CollectBiz,
	collect *impl.CollectImpl,
	collectTask *impl.CollectTaskImpl,
	authBiz *biz.AuthBiz,
	productBiz *biz.ProductBiz,
) *CollectService {
	svc := new(CollectService)
	svc.cfg = cfg
	svc.collectBiz = collectBiz
	svc.collect = collect
	svc.collectTask = collectTask
	svc.productBiz = productBiz
	svc.authBiz = authBiz
	return svc
}
// CollectWithTasks is one collect record bundled with the tasks that share
// its collect code. It is the element type of the list returned by
// CollectList.
type CollectWithTasks struct {
	// Collect is the embedded collect record.
	model.Collect
	// Tasks holds the tasks whose collect code matches the record; it is
	// serialized under the "tasks" key.
	Tasks []model.CollectTask `json:"tasks"`
}
// CollectList returns one page of collect records, newest first, each paired
// with the collect tasks that share its collect code.
//
// The access token in req is validated first. When req.ProductId is positive
// the list is restricted to that product. For every record the stored
// comma-separated platform keys are replaced by the display names found in
// collect.CollectorMap; a key that is not registered there is kept as-is
// instead of being looked up blindly.
//
// A failure while loading the tasks is only logged: the records are still
// returned, just without tasks. Any other error is returned to the caller.
// On success the JSON response carries "list", "total", "page" and
// "pageSize".
func (c *CollectService) CollectList(ctx *fiber.Ctx, req *entitys.CollectListRequest) error {
	if _, err := c.authBiz.ValidateAccessToken(ctx.UserContext(), req.AccessToken); err != nil {
		return err
	}

	// Optional filter by product.
	cond := builder.NewCond()
	if req.ProductId > 0 {
		cond = cond.And(builder.Eq{"product_id": req.ProductId})
	}

	// Load the requested page of collect records.
	var collects []model.Collect
	pageBo, err := c.collect.GetListToStruct(ctx.UserContext(), &cond, &dataTemp.ReqPageBo{Page: req.Page, Limit: req.Limit}, &collects, "created_at DESC")
	if err != nil {
		return err
	}

	// Translate platform keys to display names and gather the collect codes
	// needed for the batched task query below.
	collectCodes := make([]string, 0, len(collects))
	for i := range collects {
		keys := strings.Split(collects[i].Platform, ",")
		names := make([]string, 0, len(keys))
		for _, key := range keys {
			name := key
			if info, ok := collect.CollectorMap[key]; ok {
				name = info.Name
			}
			names = append(names, name)
		}
		collects[i].Platform = strings.Join(names, ",")
		collectCodes = append(collectCodes, collects[i].CollectCode)
	}

	// One query for the tasks of all records on this page.
	var tasks []model.CollectTask
	if len(collectCodes) > 0 {
		taskCond := builder.NewCond().And(builder.In("collect_code", collectCodes))
		if _, err := c.collectTask.GetListToStruct(ctx.UserContext(), &taskCond, nil, &tasks, "created_at ASC"); err != nil {
			// Best effort: the list is still useful without task details.
			log.Printf("查询 collect_task 失败: %v", err)
		}
	}

	// Group the tasks by collect code.
	tasksByCode := make(map[string][]model.CollectTask, len(collects))
	for _, task := range tasks {
		tasksByCode[task.CollectCode] = append(tasksByCode[task.CollectCode], task)
	}

	// Assemble the response items. The loop variable is deliberately not
	// called "collect" so it does not shadow the imported package.
	result := make([]*CollectWithTasks, 0, len(collects))
	for _, record := range collects {
		result = append(result, &CollectWithTasks{
			Collect: record,
			Tasks:   tasksByCode[record.CollectCode],
		})
	}

	return ctx.JSON(fiber.Map{
		"list":     result,
		"total":    pageBo.Total,
		"page":     req.Page,
		"pageSize": req.Limit,
	})
}
// GetCollectPlatForms responds with the platforms registered in
// collect.CollectorMap as a JSON array of objects with the keys "name",
// "index" and "icon".
//
// Go does not define the iteration order of a map, so the entries are emitted
// sorted by their map key; clients then receive the same order on every
// request.
func (c *CollectService) GetCollectPlatForms(ctx *fiber.Ctx) error {
	keys := make([]string, 0, len(collect.CollectorMap))
	for key := range collect.CollectorMap {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	list := make([]map[string]string, 0, len(keys))
	for _, key := range keys {
		info := collect.CollectorMap[key]
		list = append(list, map[string]string{
			"name":  info.Name,
			"index": info.Platform,
			"icon":  info.Icon,
		})
	}
	return ctx.JSON(list)
}
// Collect starts a new collect run.
//
// It validates the access token in req, stores a collect record with status 0
// (processing) and then hands the actual per-platform work to a background
// goroutine. The collect code is built from the product ID and the current
// Unix-nanosecond timestamp. The JSON response is sent right away and only
// confirms that generation has started.
//
// Returns the token-validation or insert error, if any.
func (c *CollectService) Collect(ctx *fiber.Ctx, req *entitys.ProductCollectRequest) error {
	if _, authErr := c.authBiz.ValidateAccessToken(ctx.UserContext(), req.AccessToken); authErr != nil {
		return authErr
	}

	code := fmt.Sprintf("C%d_%d", req.ProductId, time.Now().UnixNano())
	record := model.Collect{
		CollectCode: code,
		ProductID:   req.ProductId,
		Keywords:    strings.Join(req.Keywords, ","),
		Platform:    strings.Join(req.PlatformIndex, ","),
		Question:    req.Question,
		CreatedAt:   time.Now(),
		Status:      0, // 0: processing
	}
	if addErr := c.collect.Add(ctx.UserContext(), &record); addErr != nil {
		return addErr
	}

	// The collection itself runs detached from this request.
	go c.doCollectAsync(code, req.PlatformIndex, req.Question, req.Keywords)

	return ctx.JSON(fiber.Map{"message": "收录生成中"})
}
// doCollectAsync runs one collect job in the background: it queries every
// requested platform concurrently, keeps the record's progress column up to
// date and finally marks the record as finished.
//
// Values written to the record identified by collectCode:
//   - status 1, progress 10 when the work starts;
//   - progress 10..100 as platforms finish;
//   - status 2, progress 100 when all platforms are done (or immediately when
//     platforms is empty);
//   - status 3 when this function itself panics.
//
// A failing platform, including one that panics, is logged and does not stop
// the other platforms or prevent the job from being marked as finished.
// Failed record updates are logged; there is no caller to return them to.
func (c *CollectService) doCollectAsync(collectCode string, platforms []string, question string, keywords []string) {
	// update writes fields to this job's record and logs a failed write.
	update := func(fields map[string]interface{}) {
		if err := c.collect.UpdateByKey(context.Background(), "collect_code", collectCode, fields); err != nil {
			log.Printf("更新收录记录失败 [%s]: %v", collectCode, err)
		}
	}

	defer func() {
		if r := recover(); r != nil {
			log.Printf("任务panic [%s]: %v", collectCode, r)
			update(map[string]interface{}{"status": 3})
		}
	}()

	totalPlatforms := len(platforms)
	if totalPlatforms == 0 {
		update(map[string]interface{}{"status": 2, "progress": 100})
		return
	}

	// Initial progress.
	update(map[string]interface{}{"status": 1, "progress": 10})

	var (
		completed  int32
		progressMu sync.Mutex
		wg         sync.WaitGroup
	)
	for _, platIndex := range platforms {
		wg.Add(1)
		go func(platIndex string) {
			defer wg.Done()
			// recover only protects the goroutine it runs in, so the guard at
			// the top of doCollectAsync does not cover this worker. Without
			// its own guard a panic here would terminate the whole process.
			defer func() {
				if r := recover(); r != nil {
					log.Printf("平台任务panic [%s/%s]: %v", collectCode, platIndex, r)
				}
			}()

			c.processOnePlatform(collectCode, platIndex, question, keywords)

			// Count and write under one lock: otherwise a worker that counted
			// earlier could write its lower value after a later worker and
			// make the stored progress go backwards. A worker that panicked
			// above skips this step; the final update still sets 100.
			progressMu.Lock()
			defer progressMu.Unlock()
			done := atomic.AddInt32(&completed, 1)
			progress := 10 + int(float64(done)/float64(totalPlatforms)*90)
			update(map[string]interface{}{"progress": progress})
		}(platIndex)
	}
	wg.Wait()

	// All platforms have been processed.
	update(map[string]interface{}{"status": 2, "progress": 100})
}
// processOnePlatform asks a single AI platform the question and stores the
// answer as a CollectTask row linked to collectCode.
//
// The platform call is given 120 seconds. An unknown platform, a timeout, an
// error or panic from the platform call, an empty result and a failed insert
// are all logged, after which the function returns without storing anything.
// Nothing is reported back to the caller.
//
// The stored IsExposure value is 2 when the result reports exposure and 1
// otherwise; Status is stored as 1.
func (c *CollectService) processOnePlatform(collectCode, platIndex, question string, keywords []string) {
	// Reject unknown platforms before any other work is set up.
	platform, exist := collect.CollectorMap[platIndex]
	if !exist {
		log.Printf("未知平台: %s", platIndex)
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	type result struct {
		data *collect.CollectResult
		err  error
	}
	// Buffered so the goroutine below can always deliver its result and exit,
	// even after the select has stopped waiting because of the timeout.
	resultChan := make(chan result, 1)
	go func() {
		// A panic in the platform call would otherwise crash the process;
		// report it as an ordinary failure of this platform instead.
		defer func() {
			if r := recover(); r != nil {
				resultChan <- result{err: fmt.Errorf("panic: %v", r)}
			}
		}()
		requestID := fmt.Sprintf("%s_%s", collectCode, platIndex)
		res, err := c.collectBiz.AskAIQuestion(platIndex, requestID, question, true, keywords)
		resultChan <- result{data: res, err: err}
	}()

	// AskAIQuestion is not given a context, so on timeout only the wait is
	// abandoned; the call itself keeps running until it returns.
	select {
	case <-ctx.Done():
		log.Printf("平台 %s 超时", platform.Name)
		return
	case res := <-resultChan:
		if res.err != nil {
			log.Printf("平台 %s 失败: %v", platform.Name, res.err)
			return
		}
		if res.data == nil {
			log.Printf("平台 %s 返回空结果", platform.Name)
			return
		}

		var exposure int32 = 1
		if res.data.IsExposure {
			exposure = 2
		}
		now := time.Now()
		task := &model.CollectTask{
			CollectCode:     collectCode,
			AiPlatformIndex: platIndex,
			ContentHTML:     res.data.Answer,
			ShareURL:        res.data.ShareLink,
			CreatedAt:       now,
			UpdatedAt:       now,
			IsExposure:      exposure,
			Status:          1,
		}
		if err := c.collectTask.Add(context.Background(), task); err != nil {
			log.Printf("保存失败: %v", err)
			return
		}
		log.Printf("平台 %s 完成", platform.Name)
	}
}
// CollectAan generates the AI analysis report for one collect record and
// stores a link to it on that record.
//
// Steps: load the record by req.CollectId, set report_status to 2, load the
// record's tasks, build the prompt, ask the Hsyq bot, write the answer to a
// temporary Markdown file, convert it to .docx, upload the .docx, then save
// the returned URL in end_file together with report_status 3. Both temporary
// files are removed before returning.
//
// Returns a not-found error when the record does not exist, otherwise the
// first error hit by a later step. A failed report_status=2 write and a
// failed task query are only logged; in the latter case the report is built
// without task data. On success nothing is written to the response body here.
//
// NOTE(review): when a step fails after report_status was set to 2 the record
// keeps that value; no failure status is visible in this file to reset it to.
// NOTE(review): unlike CollectList and Collect, no access token is validated
// here; confirm that this is intended.
func (c *CollectService) CollectAan(ctx *fiber.Ctx, req *entitys.CollectAnaRequest) error {
	userCtx := ctx.UserContext()

	var collectInfo model.Collect
	if err := c.collect.GetByKey(userCtx, c.collect.PrimaryKey(), req.CollectId, &collectInfo); err != nil {
		return err
	}
	if collectInfo.ID == 0 {
		return errcode.NotFound("未找到收录计划")
	}

	// Mark the report as being generated. This error used to be overwritten
	// unchecked; it is surfaced in the log without changing the control flow.
	if err := c.collect.UpdateByKey(userCtx, c.collect.PrimaryKey(), req.CollectId, map[string]interface{}{"report_status": 2}); err != nil {
		log.Printf("更新 report_status 失败: %v", err)
	}

	var tasks []model.CollectTask
	taskCond := builder.NewCond().And(builder.In("collect_code", collectInfo.CollectCode))
	if _, err := c.collectTask.GetListToStruct(userCtx, &taskCond, nil, &tasks, "created_at ASC"); err != nil {
		log.Printf("查询 collect_task 失败: %v", err)
	}

	mes := c.collectBiz.CreateAndPrompt(userCtx, &collectInfo, tasks)
	content, err := ai_tool.NewHsyq().RequestHsyqBot(userCtx, c.cfg.Hsyq.ApiKey, c.cfg.AiBot.CollectInfo, mes)
	if err != nil {
		return err
	}
	if content == nil {
		return fmt.Errorf("生成分析报告失败: AI 返回内容为空")
	}

	// The question is user-supplied text that ends up in a file name. Replace
	// path separators and characters that common file systems reject, drop
	// control characters, and cap the length so the name stays inside MdDir
	// and within file-name length limits.
	safeQuestion := strings.Map(func(r rune) rune {
		switch r {
		case '/', '\\', ':', '*', '?', '"', '<', '>', '|':
			return '_'
		}
		if r < 0x20 || r == 0x7f {
			return -1
		}
		return r
	}, collectInfo.Question)
	if runes := []rune(safeQuestion); len(runes) > 50 {
		safeQuestion = string(runes[:50])
	}

	fileBaseName := fmt.Sprintf("%s分析报告_%d", safeQuestion, time.Now().UnixNano())
	mdAbs := filepath.Join(c.cfg.Sys.MdDir, fileBaseName+".md")

	// Write the Markdown source of the report.
	file, err := os.Create(mdAbs)
	if err != nil {
		return fmt.Errorf("创建文件失败: %w", err)
	}
	defer os.Remove(mdAbs)
	if _, err := file.WriteString(*content); err != nil {
		file.Close()
		return fmt.Errorf("写入文件失败: %w", err)
	}
	// Close before converting so the converter reads a finished file and a
	// close error is not lost.
	if err := file.Close(); err != nil {
		return fmt.Errorf("写入文件失败: %w", err)
	}

	docxPath, err := pkg.Md2wordFix(mdAbs, c.cfg.Sys.MdDir, nil)
	if err != nil {
		return err
	}
	// The converted file is expected under docxPath with the same base name.
	docxName := fileBaseName + ".docx"
	docxAbs := filepath.Join(docxPath, docxName)
	defer os.Remove(docxAbs)

	fileByte, err := pkg.ReadDocxToBytes(docxAbs)
	if err != nil {
		return err
	}
	url, err := c.productBiz.SourceUpload(userCtx, fileByte, docxName)
	if err != nil {
		return fmt.Errorf("上传文件失败: %w", err)
	}

	return c.collect.UpdateByKey(userCtx, c.collect.PrimaryKey(), req.CollectId, map[string]interface{}{"end_file": url, "report_status": 3})
}