336 lines
9.2 KiB
Go
336 lines
9.2 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"geo/internal/ai_tool"
|
|
"geo/internal/biz"
|
|
"geo/internal/collect"
|
|
"geo/internal/config"
|
|
"geo/internal/data/impl"
|
|
"geo/internal/data/model"
|
|
"geo/internal/entitys"
|
|
"geo/pkg"
|
|
"geo/tmpl/dataTemp"
|
|
"geo/tmpl/errcode"
|
|
"github.com/gofiber/fiber/v2"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"xorm.io/builder"
|
|
)
|
|
|
|
// CollectService 收集服务
|
|
type CollectService struct {
|
|
cfg *config.Config
|
|
collectBiz *biz.CollectBiz
|
|
collect *impl.CollectImpl
|
|
collectTask *impl.CollectTaskImpl
|
|
productBiz *biz.ProductBiz
|
|
authBiz *biz.AuthBiz
|
|
}
|
|
|
|
// NewCollectService 创建收集服务
|
|
func NewCollectService(
|
|
cfg *config.Config,
|
|
collectBiz *biz.CollectBiz,
|
|
collect *impl.CollectImpl,
|
|
collectTask *impl.CollectTaskImpl,
|
|
authBiz *biz.AuthBiz,
|
|
productBiz *biz.ProductBiz,
|
|
) *CollectService {
|
|
return &CollectService{
|
|
cfg: cfg,
|
|
collectBiz: collectBiz,
|
|
collect: collect,
|
|
collectTask: collectTask,
|
|
authBiz: authBiz,
|
|
productBiz: productBiz,
|
|
}
|
|
}
|
|
|
|
// CollectWithTasks 收集记录及其任务列表
|
|
type CollectWithTasks struct {
|
|
model.Collect
|
|
Tasks []model.CollectTask `json:"tasks"`
|
|
}
|
|
|
|
// CollectList 获取收集列表及对应的任务详情
|
|
func (c *CollectService) CollectList(ctx *fiber.Ctx, req *entitys.CollectListRequest) error {
|
|
_, err := c.authBiz.ValidateAccessToken(ctx.UserContext(), req.AccessToken)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 构建查询条件
|
|
cond := builder.NewCond()
|
|
if req.ProductId > 0 {
|
|
cond = cond.And(builder.Eq{"product_id": req.ProductId})
|
|
}
|
|
|
|
// 查询 collect 列表
|
|
var collects []model.Collect
|
|
pageBo, err := c.collect.GetListToStruct(ctx.UserContext(), &cond, &dataTemp.ReqPageBo{Page: req.Page, Limit: req.Limit}, &collects, "created_at DESC")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 提取所有的 collect_code
|
|
collectCodes := make([]string, 0, len(collects))
|
|
for k, colt := range collects {
|
|
var platName = make([]string, 0)
|
|
plat := strings.Split(colt.Platform, ",")
|
|
for _, p := range plat {
|
|
name := collect.CollectorMap[p].Name
|
|
platName = append(platName, name)
|
|
}
|
|
collects[k].Platform = strings.Join(platName, ",")
|
|
collectCodes = append(collectCodes, colt.CollectCode)
|
|
}
|
|
|
|
// 批量查询所有相关的 collect_task
|
|
var tasks []model.CollectTask
|
|
if len(collectCodes) > 0 {
|
|
taskCond := builder.NewCond().And(builder.In("collect_code", collectCodes))
|
|
_, err := c.collectTask.GetListToStruct(ctx.UserContext(), &taskCond, nil, &tasks, "created_at ASC")
|
|
if err != nil {
|
|
log.Printf("查询 collect_task 失败: %v", err)
|
|
}
|
|
}
|
|
|
|
// 按 collect_code 分组 tasks
|
|
tasksMap := make(map[string][]model.CollectTask)
|
|
for _, task := range tasks {
|
|
|
|
tasksMap[task.CollectCode] = append(tasksMap[task.CollectCode], task)
|
|
}
|
|
|
|
// 组装返回数据
|
|
result := make([]*CollectWithTasks, 0, len(collects))
|
|
for _, collect := range collects {
|
|
cwt := &CollectWithTasks{
|
|
Collect: collect,
|
|
Tasks: tasksMap[collect.CollectCode],
|
|
}
|
|
result = append(result, cwt)
|
|
}
|
|
|
|
return ctx.JSON(fiber.Map{
|
|
"list": result,
|
|
"total": pageBo.Total,
|
|
"page": req.Page,
|
|
"pageSize": req.Limit,
|
|
})
|
|
}
|
|
|
|
func (c *CollectService) GetCollectPlatForms(ctx *fiber.Ctx) error {
|
|
var list = make([]map[string]string, 0, len(collect.CollectorMap))
|
|
for _, v := range collect.CollectorMap {
|
|
list = append(list, map[string]string{
|
|
"name": v.Name,
|
|
"index": v.Platform,
|
|
"icon": v.Icon,
|
|
})
|
|
}
|
|
return ctx.JSON(list)
|
|
}
|
|
|
|
func (c *CollectService) Collect(ctx *fiber.Ctx, req *entitys.ProductCollectRequest) error {
|
|
_, err := c.authBiz.ValidateAccessToken(ctx.UserContext(), req.AccessToken)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
collectCode := fmt.Sprintf("C%d_%d", req.ProductId, time.Now().UnixNano())
|
|
collectData := &model.Collect{
|
|
CollectCode: collectCode,
|
|
ProductID: req.ProductId,
|
|
Keywords: strings.Join(req.Keywords, ","),
|
|
Platform: strings.Join(req.PlatformIndex, ","),
|
|
Question: req.Question,
|
|
CreatedAt: time.Now(),
|
|
Status: 0, // 0:处理中
|
|
}
|
|
|
|
err = c.collect.Add(ctx.UserContext(), collectData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 启动异步任务
|
|
go c.doCollectAsync(collectCode, req.PlatformIndex, req.Question, req.Keywords)
|
|
|
|
return ctx.JSON(fiber.Map{"message": "收录生成中"})
|
|
}
|
|
|
|
// doCollectAsync 异步执行收集任务(简化稳定版)
|
|
func (c *CollectService) doCollectAsync(collectCode string, platforms []string, question string, keywords []string) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
log.Printf("任务panic [%s]: %v", collectCode, r)
|
|
c.collect.UpdateByKey(context.Background(), "collect_code", collectCode, map[string]interface{}{"status": 3})
|
|
}
|
|
}()
|
|
|
|
totalPlatforms := len(platforms)
|
|
if totalPlatforms == 0 {
|
|
c.collect.UpdateByKey(context.Background(), "collect_code", collectCode, map[string]interface{}{"status": 2, "progress": 100})
|
|
return
|
|
}
|
|
|
|
// 初始化进度
|
|
c.collect.UpdateByKey(context.Background(), "collect_code", collectCode, map[string]interface{}{"status": 1, "progress": 10})
|
|
|
|
var completed int32
|
|
var wg sync.WaitGroup
|
|
|
|
for _, platIndex := range platforms {
|
|
wg.Add(1)
|
|
go func(platIndex string) {
|
|
defer wg.Done()
|
|
|
|
c.processOnePlatform(collectCode, platIndex, question, keywords)
|
|
|
|
// 原子操作增加计数
|
|
done := atomic.AddInt32(&completed, 1)
|
|
|
|
// 计算并更新进度
|
|
progress := 10 + int(float64(done)/float64(totalPlatforms)*90)
|
|
c.collect.UpdateByKey(context.Background(), "collect_code", collectCode, map[string]interface{}{"progress": progress})
|
|
}(platIndex)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
// 全部完成
|
|
c.collect.UpdateByKey(context.Background(), "collect_code", collectCode, map[string]interface{}{"status": 2, "progress": 100})
|
|
}
|
|
|
|
// processOnePlatform 处理单个平台(带超时)
|
|
func (c *CollectService) processOnePlatform(collectCode, platIndex, question string, keywords []string) {
|
|
// 创建120秒超时的context
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
// 检查平台是否存在
|
|
platformName, exist := collect.CollectorMap[platIndex]
|
|
if !exist {
|
|
log.Printf("未知平台: %s", platIndex)
|
|
return
|
|
}
|
|
|
|
// 使用channel控制超时
|
|
type result struct {
|
|
data *collect.CollectResult
|
|
err error
|
|
}
|
|
resultChan := make(chan result, 1)
|
|
|
|
go func() {
|
|
requestID := fmt.Sprintf("%s_%s", collectCode, platIndex)
|
|
res, err := c.collectBiz.AskAIQuestion(platIndex, requestID, question, true, keywords)
|
|
resultChan <- result{data: res, err: err}
|
|
}()
|
|
|
|
// 等待结果或超时
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Printf("平台 %s 超时", platformName.Name)
|
|
return
|
|
case res := <-resultChan:
|
|
if res.err != nil {
|
|
log.Printf("平台 %s 失败: %v", platformName.Name, res.err)
|
|
return
|
|
}
|
|
|
|
// 保存结果
|
|
ise := 1
|
|
if res.data.IsExposure {
|
|
ise = 2
|
|
}
|
|
task := &model.CollectTask{
|
|
CollectCode: collectCode,
|
|
AiPlatformIndex: platIndex,
|
|
ContentHTML: res.data.Answer,
|
|
ShareURL: res.data.ShareLink,
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
IsExposure: int32(ise),
|
|
Status: 1,
|
|
}
|
|
|
|
if err := c.collectTask.Add(context.Background(), task); err != nil {
|
|
log.Printf("保存失败: %v", err)
|
|
return
|
|
}
|
|
|
|
log.Printf("平台 %s 完成", platformName.Name)
|
|
}
|
|
}
|
|
|
|
func (c *CollectService) CollectAan(ctx *fiber.Ctx, req *entitys.CollectAnaRequest) error {
|
|
|
|
var collectInfo model.Collect
|
|
err := c.collect.GetByKey(ctx.UserContext(), c.collect.PrimaryKey(), req.CollectId, &collectInfo)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if collectInfo.ID == 0 {
|
|
return errcode.NotFound("为找到收录计划")
|
|
}
|
|
err = c.collect.UpdateByKey(ctx.UserContext(), c.collect.PrimaryKey(), req.CollectId, map[string]interface{}{"report_status": 2})
|
|
var tasks []model.CollectTask
|
|
taskCond := builder.NewCond().And(builder.In("collect_code", collectInfo.CollectCode))
|
|
_, err = c.collectTask.GetListToStruct(ctx.UserContext(), &taskCond, nil, &tasks, "created_at ASC")
|
|
if err != nil {
|
|
log.Printf("查询 collect_task 失败: %v", err)
|
|
}
|
|
mes := c.collectBiz.CreateAndPrompt(ctx.UserContext(), &collectInfo, tasks)
|
|
|
|
content, err := ai_tool.NewHsyq().RequestHsyqBot(ctx.UserContext(), c.cfg.Hsyq.ApiKey, c.cfg.AiBot.CollectInfo, mes)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
fileBaseName := fmt.Sprintf("%s分析报告_%d", collectInfo.Question, time.Now().UnixNano())
|
|
fileName := fmt.Sprintf("%s.md", fileBaseName)
|
|
mdAbs := filepath.Join(c.cfg.Sys.MdDir, fileName)
|
|
// 创建并写入文件
|
|
file, err := os.Create(mdAbs)
|
|
defer os.Remove(mdAbs)
|
|
if err != nil {
|
|
return fmt.Errorf("创建文件失败: %w", err)
|
|
}
|
|
defer file.Close()
|
|
if _, err := file.WriteString(*content); err != nil {
|
|
return fmt.Errorf("写入文件失败: %w", err)
|
|
}
|
|
|
|
docxPath, err := pkg.Md2wordFix(mdAbs, c.cfg.Sys.MdDir, nil)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
docxName := fmt.Sprintf("%s.docx", fileBaseName)
|
|
docxAbs := filepath.Join(docxPath, docxName)
|
|
defer os.Remove(docxAbs)
|
|
fileByte, err := pkg.ReadDocxToBytes(docxAbs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
url, err := c.productBiz.SourceUpload(ctx.UserContext(), fileByte, docxName)
|
|
if err != nil {
|
|
return fmt.Errorf("上传文件失败: %w", err)
|
|
}
|
|
err = c.collect.UpdateByKey(ctx.UserContext(), c.collect.PrimaryKey(), req.CollectId, map[string]interface{}{"end_file": url, "report_status": 3})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|