geoGo/internal/service/product.go

586 lines
17 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"fmt"
"geo/internal/ai_tool"
"geo/internal/biz"
"geo/internal/config"
"geo/internal/data/impl"
"geo/internal/data/model"
"geo/internal/entitys"
"geo/pkg"
"geo/tmpl/dataTemp"
"geo/tmpl/errcode"
"io"
"log"
"os"
"path/filepath"
"runtime/debug"
"strconv"
"strings"
"sync"
"time"
"github.com/go-viper/mapstructure/v2"
"github.com/gofiber/fiber/v2"
"xorm.io/builder"
)
// ProductService implements the product HTTP handlers: product CRUD, image
// upload, product-info extraction from an uploaded .docx, and asynchronous
// "collect" jobs. Its dependencies are injected via NewProductService.
type ProductService struct {
// cfg supplies the Hsyq API key, the AiBot.ProductInfo bot id and the Collect API key.
cfg *config.Config
// productImpl is the data-access layer for model.Product records.
productImpl *impl.ProductImpl
// authBiz validates access tokens (and user index + token pairs).
authBiz *biz.AuthBiz
// productBiz provides product lookup (GetProduct) and file upload (SourceUpload).
productBiz *biz.ProductBiz
// aiBiz builds the prompt sent to the AI bot for .docx extraction.
aiBiz *biz.AiBiz
// collect persists model.Collect job records.
collect *impl.CollectImpl
// collectTask persists the per-platform model.CollectTask results of a job.
collectTask *impl.CollectTaskImpl
}
// NewProductService builds a ProductService from its dependencies.
//
// Every argument is stored as-is. No nil checks are performed, so callers
// (typically a DI container) must supply fully initialised values.
func NewProductService(
	cfg *config.Config,
	productImpl *impl.ProductImpl,
	authBiz *biz.AuthBiz,
	productBiz *biz.ProductBiz,
	aiBiz *biz.AiBiz,
	collect *impl.CollectImpl,
	collectTask *impl.CollectTaskImpl,
) *ProductService {
	return &ProductService{
		cfg:         cfg,
		productImpl: productImpl,
		authBiz:     authBiz,
		productBiz:  productBiz,
		aiBiz:       aiBiz,
		collect:     collect,
		collectTask: collectTask,
	}
}
// Add inserts a new product.
//
// The request's user index and access token are validated together first.
// The request is then decoded onto a model.Product with mapstructure and
// stored. Any error from these steps is returned unchanged. On success the
// handler returns nil without writing a response body.
func (p *ProductService) Add(c *fiber.Ctx, req *entitys.CreateProductRequest) error {
	ctx := c.UserContext()
	_, _, err := p.authBiz.UserAndTokenValid(ctx, req.UserIndex, req.AccessToken)
	if err != nil {
		return err
	}

	product := model.Product{}
	if err = mapstructure.Decode(req, &product); err != nil {
		return err
	}
	return p.productImpl.Add(ctx, &product)
}
// Detail writes the product whose "id" column equals req.Id to the response.
//
// Only the access token is validated.
// NOTE(review): the lookup is not filtered by owner, so any valid token can
// read any product id. Confirm this is intended.
func (p *ProductService) Detail(c *fiber.Ctx, req *entitys.ProductDetailRequest) error {
	ctx := c.UserContext()
	_, err := p.authBiz.ValidateAccessToken(ctx, req.AccessToken)
	if err != nil {
		return err
	}

	var product model.Product
	cond := builder.NewCond().And(builder.Eq{"id": req.Id})
	if err = p.productImpl.GetOneBySearchStruct(ctx, &cond, &product); err != nil {
		return err
	}
	return pkg.HandleResponse(c, product)
}
// List writes one page of the caller's products to the response.
//
// It returns products with user_index = req.UserIndex and status = 1
// (presumably "active"; confirm against the schema). Pagination is
// normalised as follows:
//   - a page below 1 becomes 1;
//   - a page size below 1 becomes 20;
//   - a page size above 100 is capped at 100.
// The response carries the page data plus the total row count.
func (p *ProductService) List(c *fiber.Ctx, req *entitys.ProductListRequest) error {
	ctx := c.UserContext()
	_, _, err := p.authBiz.UserAndTokenValid(ctx, req.UserIndex, req.AccessToken)
	if err != nil {
		return err
	}

	page, pageSize := req.Page, req.PageSize
	if page < 1 {
		page = 1
	}
	switch {
	case pageSize < 1:
		pageSize = 20
	case pageSize > 100:
		pageSize = 100
	}

	cond := builder.NewCond().
		And(builder.Eq{"user_index": req.UserIndex}).
		And(builder.Eq{"status": 1})
	paging := &dataTemp.ReqPageBo{Page: page, Limit: pageSize}

	var products []*model.Product
	total, err := p.productImpl.GetListToStruct(ctx, &cond, paging, &products, "")
	if err != nil {
		return err
	}
	return pkg.SuccessWithPageMsg(c, products, total.Total, page, pageSize)
}
// Update overwrites the product whose primary key is req.Id with the fields
// that mapstructure decodes from req.
//
// Only the access token is validated. Errors from validation, decoding or the
// update itself are returned to the caller. On success the handler returns
// nil without writing a response body.
// NOTE(review): ownership of req.Id is not checked here. Confirm whether that
// happens elsewhere.
func (p *ProductService) Update(c *fiber.Ctx, req *entitys.ProductUpdateRequest) error {
	ctx := c.UserContext()
	_, err := p.authBiz.ValidateAccessToken(ctx, req.AccessToken)
	if err != nil {
		return err
	}
	var data model.Product
	if err = mapstructure.Decode(req, &data); err != nil {
		return err
	}
	// The update error used to be assigned and then discarded, so a failed
	// update was reported to the client as success.
	if err = p.productImpl.UpdateByKey(ctx, p.productImpl.PrimaryKey(), req.Id, &data); err != nil {
		return err
	}
	return nil
}
// Del deletes the product whose primary key is req.Id.
//
// Only the access token is validated. Errors from validation or the delete
// are returned to the caller. On success the handler returns nil without
// writing a response body.
// NOTE(review): ownership of req.Id is not checked here. Confirm whether that
// happens elsewhere.
func (p *ProductService) Del(c *fiber.Ctx, req *entitys.ProductDelRequest) error {
	ctx := c.UserContext()
	_, err := p.authBiz.ValidateAccessToken(ctx, req.AccessToken)
	if err != nil {
		return err
	}
	// The delete error used to be assigned and then discarded, so a failed
	// delete was reported to the client as success.
	if err = p.productImpl.DeleteByKey(ctx, p.productImpl.PrimaryKey(), req.Id); err != nil {
		return err
	}
	return nil
}
// ImgUpload stores an uploaded image and responds with {"url": <url>}.
//
// The request is multipart with form value "access_token" (required and
// validated) and form file "file". The file is read fully into memory and
// stored via productBiz.SourceUpload under a name generated by
// pkg.GenerateImageFileName from the original file's extension.
func (p *ProductService) ImgUpload(c *fiber.Ctx) error {
	access := c.FormValue("access_token", "")
	if access == "" {
		return errcode.ParamErr("access_token未找到")
	}
	// Validate the token before touching the upload.
	if _, err := p.authBiz.ValidateAccessToken(c.UserContext(), access); err != nil {
		return err
	}

	fileHeader, err := c.FormFile("file")
	if err != nil {
		return errcode.ParamErr("未找到上传文件")
	}
	file, err := fileHeader.Open()
	if err != nil {
		// Was "%S", which is not a fmt verb and rendered as "%!S(string=...)"
		// (assuming ParamErrf formats like fmt). Now matches the "%s" used in
		// CreateProductInfoByDocx.
		return errcode.ParamErrf("无法打开文件:%s", err.Error())
	}
	defer file.Close()

	fileBytes, err := io.ReadAll(file)
	if err != nil {
		return errcode.ParamErrf("读取文件失败:%s", err.Error())
	}

	// Derive the stored file name from the uploaded file's extension.
	ext := pkg.GetFileExtension(fileHeader.Filename)
	fileName := pkg.GenerateImageFileName(ext)

	url, err := p.productBiz.SourceUpload(c.UserContext(), fileBytes, fileName)
	if err != nil {
		return err
	}
	return pkg.HandleResponse(c, fiber.Map{"url": url})
}
// CreateProductInfoByDocx extracts product information from an uploaded .docx.
//
// Required multipart fields:
//   - "access_token": validated;
//   - "user_index": must be non-empty;
//   - "file": the document.
// The upload is staged in a temporary file and converted to markdown with
// pkg.ExtractWordContent. aiBiz turns that markdown into a prompt, which is
// sent to the configured Hsyq bot (cfg.AiBot.ProductInfo). The bot's JSON
// answer is decoded into entitys.ProductInfo and written to the response.
// The temporary file is removed before the handler returns.
func (p *ProductService) CreateProductInfoByDocx(c *fiber.Ctx) error {
	access := c.FormValue("access_token", "")
	if access == "" {
		return errcode.ParamErr("access_token未找到")
	}
	userIndex := c.FormValue("user_index", "")
	if userIndex == "" {
		return errcode.ParamErr("user_index未找到")
	}
	// Validate the token.
	// NOTE(review): userIndex is only checked for presence and is not matched
	// against the token's user. Confirm this is intended.
	if _, err := p.authBiz.ValidateAccessToken(c.UserContext(), access); err != nil {
		return err
	}

	fileHeader, err := c.FormFile("file")
	if err != nil {
		return errcode.ParamErr("未找到上传文件")
	}
	file, err := fileHeader.Open()
	if err != nil {
		return errcode.ParamErrf("无法打开文件:%s", err.Error())
	}
	defer file.Close()

	// Stage the upload on disk because pkg.ExtractWordContent takes a file path.
	tempFile, err := os.CreateTemp("", "upload_*.docx")
	if err != nil {
		return errcode.ParamErrf("创建临时文件失败:%s", err.Error())
	}
	// Close and then delete the temp file on every return path. Previously it
	// was only closed, so each request left an upload_*.docx behind in the
	// temp dir. Closing first lets the removal succeed even on platforms that
	// refuse to delete open files. Both errors are ignored because this is
	// best-effort cleanup.
	defer func() {
		_ = tempFile.Close()
		_ = os.Remove(tempFile.Name())
	}()

	absPath, err := filepath.Abs(tempFile.Name())
	if err != nil {
		return errcode.ParamErrf("获取文件绝对路径失败:%s", err.Error())
	}
	if _, err = io.Copy(tempFile, file); err != nil {
		return errcode.ParamErrf("保存文件失败:%s", err.Error())
	}
	// Flush the copied bytes to disk before the extractor opens the file by path.
	if err = tempFile.Sync(); err != nil {
		return errcode.ParamErrf("同步文件失败:%s", err.Error())
	}

	markContent, err := pkg.ExtractWordContent(absPath, "markdown")
	if err != nil {
		return err
	}

	var productInfo entitys.ProductInfo
	mes := p.aiBiz.CreateProjectInfoPrompt(c.UserContext(), markContent)
	err = ai_tool.NewHsyq().RequestHsyqBotToJson(c.UserContext(), p.cfg.Hsyq.ApiKey, p.cfg.AiBot.ProductInfo, mes, &productInfo)
	if err != nil {
		return err
	}
	return pkg.HandleResponse(c, productInfo)
}
// Collect starts an asynchronous collect job for a product and responds
// immediately with "收录生成中".
//
// Synchronously, it:
//   - validates req.AccessToken;
//   - checks that req.ProductId resolves via productBiz.GetProduct;
//   - persists a model.Collect record whose CollectCode is
//     "C<productID>_<unix-nanos>" and whose Keywords/Platform fields are the
//     comma-joined request values.
// The per-platform work (doCollect) then runs in a background goroutine
// under its own 240-second context derived from context.Background(), so the
// job is not cancelled when the HTTP request ends. Panics in that goroutine
// are recovered and logged.
// NOTE(review): record.Question comes from req.Question and is read after the
// handler returns. If req was parsed in a way that aliases Fiber's reusable
// request buffers, it should be copied. Confirm.
func (p *ProductService) Collect(c *fiber.Ctx, req *entitys.ProductCollectRequest) error {
	log.Printf("[DEBUG] ========== 请求开始 ==========")
	log.Printf("[DEBUG] 请求时间: %s", time.Now().Format("2006-01-02 15:04:05.000"))
	log.Printf("[Collect] 开始处理收集请求, ProductID: %d, Platforms: %v, Keywords: %v",
		req.ProductId, req.Platform, req.Keywords)

	reqCtx := c.UserContext()
	var err error
	if _, err = p.authBiz.ValidateAccessToken(reqCtx, req.AccessToken); err != nil {
		log.Printf("[Collect] 验证token失败, ProductID: %d, Error: %v", req.ProductId, err)
		return err
	}
	if _, err = p.productBiz.GetProduct(reqCtx, req.ProductId); err != nil {
		log.Printf("[Collect] 获取产品信息失败, ProductID: %d, Error: %v", req.ProductId, err)
		return err
	}

	platforms := req.Platform
	platformCodes := make([]string, 0, len(platforms))
	for _, plat := range platforms {
		platformCodes = append(platformCodes, strconv.Itoa(plat))
	}

	code := fmt.Sprintf("C%d_%d", req.ProductId, time.Now().UnixNano())
	record := &model.Collect{
		CollectCode: code,
		ProductID:   req.ProductId,
		Keywords:    strings.Join(req.Keywords, ","),
		Platform:    strings.Join(platformCodes, ","),
		Question:    req.Question,
		CreatedAt:   time.Now(),
	}
	log.Printf("[Collect] 创建收集记录, CollectCode: %s, ProductID: %d", code, req.ProductId)
	if err = p.collect.Add(reqCtx, record); err != nil {
		log.Printf("[Collect] 保存收集记录失败, CollectCode: %s, Error: %v", code, err)
		return err
	}

	log.Printf("[Collect] ✅ 启动异步收集任务, CollectCode: %s, Platforms: %v", code, platforms)
	go func() {
		began := time.Now()
		log.Printf("[Goroutine] 异步任务启动, CollectCode: %s, 启动时间: %s", code, began.Format("15:04:05.000"))

		// Deliberately detached from the request context: the job must outlive the HTTP request.
		jobCtx, cancel := context.WithTimeout(context.Background(), time.Second*240)

		// Debug watcher: logs once jobCtx ends, either on timeout or when cancel() runs below.
		go func() {
			<-jobCtx.Done()
			log.Printf("[Goroutine] ❌ Context被取消! CollectCode: %s, 原因: %v, 耗时: %v",
				code, jobCtx.Err(), time.Since(began))
		}()

		defer func() {
			if r := recover(); r != nil {
				log.Printf("[Goroutine] ❌ PANIC: %v\nStack: %s", r, debug.Stack())
			}
			log.Printf("[Goroutine] 异步任务结束, CollectCode: %s, 总耗时: %v", code, time.Since(began))
			cancel()
			log.Printf("[Goroutine] 已调用 cancel(), CollectCode: %s", code)
		}()

		log.Printf("[Goroutine] 准备调用 doCollect, CollectCode: %s", code)
		p.doCollect(jobCtx, record, platforms)
		log.Printf("[Goroutine] doCollect 已返回, CollectCode: %s", code)
	}()

	log.Printf("[DEBUG] ========== 请求返回 ==========")
	return pkg.HandleResponse(c, "收录生成中")
}
// doCollect runs one collect job.
//
// It starts one processPlatform goroutine per entry in platforms, gathers the
// *model.CollectTask results they send, and stores whatever was gathered with
// collectTask.Add.
//
// Gathering stops at whichever comes first:
//   - every platform goroutine has finished;
//   - ctx is done;
//   - a 250-second safety timeout, for callers whose ctx has no deadline.
// Partial results are still saved in the last two cases. Saving uses its own
// short-lived context, because ctx may already be expired at that point.
func (p *ProductService) doCollect(ctx context.Context, collectData *model.Collect, platforms []int) {
	// saveTimeout bounds the final DB write, which no longer uses ctx.
	const saveTimeout = 30 * time.Second

	collectCode := collectData.CollectCode
	startTime := time.Now()
	log.Printf("[doCollect] ========== 开始执行 ==========")
	log.Printf("[doCollect] CollectCode: %s, Platforms: %v", collectCode, platforms)

	// Log the real remaining budget rather than a hard-coded 240s that can
	// drift from the caller's timeout. 0 means ctx has no deadline.
	var budget time.Duration
	if deadline, ok := ctx.Deadline(); ok {
		budget = time.Until(deadline)
	}
	log.Printf("[doCollect] Context状态: %v, 超时时间: %v", ctx.Err(), budget)

	// Debug watcher for ctx. `done` lets it exit when doCollect returns, so it
	// cannot leak if the caller never cancels ctx, and it does not report a
	// "cancellation" that merely follows normal completion.
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-ctx.Done():
			log.Printf("[doCollect] ⚠️ 检测到Context取消! CollectCode: %s, 原因: %v, 已执行时间: %v",
				collectCode, ctx.Err(), time.Since(startTime))
		case <-done:
		}
	}()

	collectClient := ai_tool.NewCollect(p.cfg.Collect.ApiKey)
	log.Printf("[doCollect] 已创建 collectClient")

	var wg sync.WaitGroup
	// Buffered to len(platforms). Each processPlatform goroutine sends at most
	// one result, so sends never block even after the reader below has stopped.
	resCh := make(chan *model.CollectTask, len(platforms))
	log.Printf("[doCollect] 创建 channel, 容量: %d", len(platforms))

	monitorStart := time.Now()
	log.Printf("[doCollect] 启动 %d 个平台任务", len(platforms))
	for i, plat := range platforms {
		log.Printf("[doCollect] 启动任务 #%d, Platform: %d", i+1, plat)
		wg.Add(1)
		go p.processPlatform(ctx, &wg, collectClient, collectData, plat, resCh, i+1)
	}
	// Close resCh once every platform goroutine has called wg.Done, which ends the range below.
	go func() {
		log.Printf("[Monitor] 监控goroutine启动, CollectCode: %s", collectCode)
		wg.Wait()
		log.Printf("[Monitor] ✅ 所有任务完成, 准备关闭channel, 等待时间: %v", time.Since(monitorStart))
		close(resCh)
		log.Printf("[Monitor] Channel已关闭")
	}()

	log.Printf("[doCollect] 开始等待结果...")
	var datas []*model.CollectTask
	waitTimeout := time.After(250 * time.Second)
gather:
	for {
		select {
		case task, ok := <-resCh:
			if !ok {
				log.Printf("[doCollect] Channel已关闭, 收集到 %d 条结果", len(datas))
				break gather
			}
			datas = append(datas, task)
			log.Printf("[doCollect] ✅ 收到结果 #%d, Platform: %d, RequestID: %s, ScriptTime: %d",
				len(datas), task.Platform, task.RequestID, task.ScriptTime)
		case <-waitTimeout:
			log.Printf("[doCollect] ⚠️ 等待超时 250秒, 强制退出, 已收集: %d/%d", len(datas), len(platforms))
			break gather
		case <-ctx.Done():
			log.Printf("[doCollect] ❌ Context取消, 强制退出, 已收集: %d/%d, 原因: %v",
				len(datas), len(platforms), ctx.Err())
			break gather
		}
	}

	log.Printf("[doCollect] 收集完成, 共 %d 条结果", len(datas))
	if len(datas) == 0 {
		log.Printf("[doCollect] ⚠️ 没有结果需要保存")
	} else {
		// Use a fresh context for the save. When gathering ended because ctx
		// expired, writing with ctx itself would fail at once (for any
		// context-aware DB layer) and drop the partial results gathered above.
		saveCtx, cancelSave := context.WithTimeout(context.Background(), saveTimeout)
		log.Printf("[doCollect] 开始保存到数据库, 数量: %d", len(datas))
		saveStart := time.Now()
		if err := p.collectTask.Add(saveCtx, datas); err != nil {
			log.Printf("[doCollect] ❌ 保存失败: %v", err)
		} else {
			log.Printf("[doCollect] ✅ 保存成功, 耗时: %v", time.Since(saveStart))
		}
		cancelSave()
	}

	log.Printf("[doCollect] ========== 结束执行, 总耗时: %v ==========", time.Since(startTime))
}
// processPlatform runs the collect flow for a single platform. doCollect
// starts it as a goroutine.
//
// It submits a Create request with ThirdID "<CollectCode>_<plat>", polls the
// returned request ID through pollTaskStatus, and sends the completed task to
// resCh. It returns without sending anything if any of these hold:
//   - ctx is already done;
//   - Create fails or answers with a Code other than 1;
//   - polling yields no task;
//   - ctx ends before the send.
// Panics are recovered and logged, and wg.Done is always called last on exit.
// taskNum is only a log label.
func (p *ProductService) processPlatform(ctx context.Context, wg *sync.WaitGroup,
	collectClient *ai_tool.Collect, collectData *model.Collect, plat int,
	resCh chan<- *model.CollectTask, taskNum int) {
	began := time.Now()
	log.Printf("[Platform #%d] ========== 开始 ==========", taskNum)
	log.Printf("[Platform #%d] CollectCode: %s, Platform: %d", taskNum, collectData.CollectCode, plat)

	// A single deferred func. It recovers and logs the end first, then signals
	// the WaitGroup so doCollect's closer goroutine can proceed.
	defer func() {
		if r := recover(); r != nil {
			log.Printf("[Platform #%d] ❌ PANIC: %v\nStack: %s", taskNum, r, debug.Stack())
		}
		log.Printf("[Platform #%d] ========== 结束, 耗时: %v ==========", taskNum, time.Since(began))
		log.Printf("[Platform #%d] 准备调用 wg.Done(), 已执行时间: %v", taskNum, time.Since(began))
		wg.Done()
		log.Printf("[Platform #%d] 已调用 wg.Done()", taskNum)
	}()

	// Bail out early if the job was already cancelled before we started.
	if cause := ctx.Err(); cause != nil {
		log.Printf("[Platform #%d] ❌ Context已取消, 退出执行, 原因: %v", taskNum, cause)
		return
	}
	log.Printf("[Platform #%d] Context正常", taskNum)

	createReq := ai_tool.CreateReq{
		Keywords: collectData.Keywords,
		Question: collectData.Question,
		Platform: plat,
		ThirdID:  fmt.Sprintf("%s_%d", collectData.CollectCode, plat),
	}
	log.Printf("[Platform #%d] 调用 Create API, Request: %+v", taskNum, createReq)

	createStart := time.Now()
	created, err := collectClient.Create(ctx, &createReq)
	createElapsed := time.Since(createStart)
	switch {
	case err != nil:
		log.Printf("[Platform #%d] ❌ Create失败, 耗时: %v, Error: %v", taskNum, createElapsed, err)
		return
	case created.Code != 1:
		log.Printf("[Platform #%d] ❌ Create返回错误码, 耗时: %v, Code: %d, Message: %s",
			taskNum, createElapsed, created.Code, created.Msg)
		return
	}

	requestID := created.Data.RequestId
	log.Printf("[Platform #%d] ✅ Create成功, 耗时: %v, RequestID: %s",
		taskNum, createElapsed, requestID)

	log.Printf("[Platform #%d] 开始轮询, RequestID: %s", taskNum, requestID)
	pollStart := time.Now()
	result := p.pollTaskStatus(ctx, collectClient, requestID, collectData, plat, taskNum)
	pollElapsed := time.Since(pollStart)
	if result == nil {
		log.Printf("[Platform #%d] ❌ 轮询失败, 耗时: %v, 未获取到结果", taskNum, pollElapsed)
		return
	}
	log.Printf("[Platform #%d] ✅ 轮询成功, 耗时: %v, ScriptTime: %d",
		taskNum, pollElapsed, result.ScriptTime)

	log.Printf("[Platform #%d] 准备发送结果到channel", taskNum)
	select {
	case resCh <- result:
		log.Printf("[Platform #%d] ✅ 结果已发送到channel", taskNum)
	case <-ctx.Done():
		log.Printf("[Platform #%d] ⚠️ Context取消, 放弃发送结果", taskNum)
	}
}
// pollTaskStatus calls collectClient.CheckTask for requestID every 5 seconds.
// The first check happens one tick after the call.
//
// It stops and returns a *model.CollectTask as soon as a check reports a
// non-zero Status. The task takes CollectCode from collectData and stores the
// raw check response as JSON in Res.
//
// It returns nil when any of these happen:
//   - ctx is done;
//   - CheckTask has failed maxErrors (5) times in total (the counter is never
//     reset by a successful check);
//   - a check response carries a Code other than 1.
// plat and taskNum are only used for logging.
func (p *ProductService) pollTaskStatus(ctx context.Context, collectClient *ai_tool.Collect,
	requestID string, collectData *model.Collect, plat int, taskNum int) *model.CollectTask {
	const maxErrors = 5

	began := time.Now()
	log.Printf("[Poll #%d] ========== 开始轮询 ==========", taskNum)
	log.Printf("[Poll #%d] CollectCode: %s, Platform: %d, RequestID: %s",
		taskNum, collectData.CollectCode, plat, requestID)

	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	var pollCount, errCount int
	for {
		// Wait for the next tick, or give up when the job context ends.
		select {
		case <-ctx.Done():
			log.Printf("[Poll #%d] ❌ Context取消, 停止轮询, 已轮询%d次, 耗时: %v, 原因: %v",
				taskNum, pollCount, time.Since(began), ctx.Err())
			return nil
		case <-ticker.C:
		}

		pollCount++
		log.Printf("[Poll #%d] 第 %d 次轮询, 已耗时: %v", taskNum, pollCount, time.Since(began))

		checkStart := time.Now()
		checkRes, err := collectClient.CheckTask(ctx, requestID)
		checkElapsed := time.Since(checkStart)
		if err != nil {
			errCount++
			log.Printf("[Poll #%d] ❌ 轮询失败(第%d次错误), 耗时: %v, Error: %v, 累计错误: %d/%d",
				taskNum, pollCount, checkElapsed, err, errCount, maxErrors)
			if errCount >= maxErrors {
				log.Printf("[Poll #%d] 达到最大错误次数, 停止轮询", taskNum)
				return nil
			}
			continue
		}

		log.Printf("[Poll #%d] ✅ 轮询成功, 耗时: %v, Code: %d, Status: %d, ScriptTime: %d, ShouluDate: %s",
			taskNum, checkElapsed, checkRes.Code, checkRes.Data.Status,
			checkRes.Data.ScriptTime, checkRes.Data.ShouluDate)
		if checkRes.Code != 1 {
			log.Printf("[Poll #%d] ❌ 返回错误码: %d", taskNum, checkRes.Code)
			return nil
		}

		// Status 0 means "still running". Any other value is treated as finished.
		// NOTE(review): an earlier comment claimed "2 means done". Confirm which
		// statuses are really terminal (e.g. whether failures also end polling).
		if checkRes.Data.Status == 0 {
			log.Printf("[Poll #%d] 任务未完成, 继续轮询, Status=%d, ScriptTime=%d, ShouluDate=%s",
				taskNum, checkRes.Data.Status, checkRes.Data.ScriptTime, checkRes.Data.ShouluDate)
			continue
		}

		reason := fmt.Sprintf("chekcStatus=%d", checkRes.Data.Status)
		log.Printf("[Poll #%d] 🎉 任务完成! 原因: %s, 总轮询次数: %d, 总耗时: %v",
			taskNum, reason, pollCount, time.Since(began))
		return &model.CollectTask{
			RequestID:    checkRes.Data.RequestId,
			CollectCode:  collectData.CollectCode,
			ScriptTime:   int32(checkRes.Data.ScriptTime),
			Platform:     int32(checkRes.Data.Platform),
			CollectData:  checkRes.Data.ShouluDate,
			ShareURL:     checkRes.Data.ShareUrl,
			ImgURL:       checkRes.Data.ImgUrl,
			PointKeyword: checkRes.Data.HitWord,
			Question:     checkRes.Data.Question,
			Res:          pkg.JsonStringIgonErr(checkRes),
			CreatedAt:    time.Now(),
			Status:       int32(checkRes.Data.Status),
		}
	}
}