diff --git a/cmd/server/wire_gen.go b/cmd/server/wire_gen.go index bdcbccb..70c8ad6 100644 --- a/cmd/server/wire_gen.go +++ b/cmd/server/wire_gen.go @@ -7,6 +7,7 @@ package main import ( + "context" "geo/internal/biz" "geo/internal/config" "geo/internal/data/impl" @@ -45,9 +46,11 @@ func InitializeApp(configConfig *config.Config, allLogger log.AllLogger) (*serve publishService := service.NewPublishService(configConfig, publishBiz, authBiz, db, aiBiz) collectImpl := impl.NewCollectImpl(db) collectTaskImpl := impl.NewCollectTaskImpl(db) - productService := service.NewProductService(configConfig, productImpl, authBiz, productBiz, aiBiz, collectImpl, collectTaskImpl) + collectBiz := biz.NewCollectBiz(context.Background(), configConfig, allLogger) + productService := service.NewProductService(configConfig, productImpl, authBiz, productBiz, aiBiz) + collectService := service.NewCollectService(configConfig, collectBiz, collectImpl, collectTaskImpl, authBiz) productSourceService := service.NewProductSourceService(configConfig, productImpl, authBiz, aiBiz, productBiz, productSourceImpl, publishBiz, articleTypeImpl) - appModule := router.NewAppModule(configConfig, appService, loginService, publishService, productService, productSourceService) + appModule := router.NewAppModule(configConfig, appService, loginService, publishService, productService, productSourceService, collectService) routerServer := router.NewRouterServer(appModule) app := server.NewHTTPServer(routerServer) servers := server.NewServers(configConfig, app) diff --git a/example_test.go b/example_test.go index 4a95e1d..32990f3 100644 --- a/example_test.go +++ b/example_test.go @@ -4,15 +4,14 @@ import ( "context" "geo/internal/collect" "geo/internal/config" - "log" - "os" + "github.com/gofiber/fiber/v2/log" "testing" ) var ( - cfg, _ = config.LoadConfig() - logger = log.New(os.Stdout, "", log.LstdFlags) - manager = collect.NewCollectManager(context.Background(), cfg, logger) + cfg, _ = config.LoadConfig() + + manager = collect.NewCollectManager(context.Background(), cfg, log.DefaultLogger()) ) // TestCollectManager_Basic 测试收集管理器的基本功能 @@ -29,8 +28,6 @@ func TestCollectManager_Basic(t *testing.T) { for _, platform := range platforms { params := &collect.CollectParams{ Headless: true, - UserIndex: "test_user", - PlatIndex: platform, RequestID: "test_req", Platform: platform, } @@ -57,8 +54,6 @@ func TestWenxinCollector_WaitLogin(t *testing.T) { params := &collect.CollectParams{ Headless: false, // 显示浏览器窗口以便扫码登录 - UserIndex: "test_user", - PlatIndex: "wenxin", RequestID: "test_wenxin_login_001", Platform: "wenxin", } @@ -87,8 +82,6 @@ func TestWenxinCollector_AskQuestion(t *testing.T) { // 设置收集参数 params := &collect.CollectParams{ Headless: false, // 显示浏览器以便调试 - UserIndex: "test_user", - PlatIndex: "wenxin", RequestID: "test_wenxin_001", Platform: "wenxin", } @@ -98,16 +91,17 @@ func TestWenxinCollector_AskQuestion(t *testing.T) { t.Logf("向文心一言提问: %s", question) // 调用管理器提问并获取答案 - answer, err := manager.AskQuestion("wenxin", params, question) + result, err := manager.AskQuestion("wenxin", params, question) if err != nil { t.Errorf("提问失败: %v", err) return } - t.Logf("获取到答案:\n%s", answer) + t.Logf("获取到答案:\n%s", result.Answer) + t.Logf("分享链接: %s", result.ShareLink) // 验证答案非空 - if len(answer) == 0 { + if len(result.Answer) == 0 { t.Error("答案为空") } } diff --git a/internal/ai_tool/collect.go b/internal/ai_tool/collect.go index 41953ee..acc0bac 100644 --- a/internal/ai_tool/collect.go +++ b/internal/ai_tool/collect.go @@ -19,45 +19,6 @@ type PlatForm struct { Icon string `json:"icon"` } -var PlatFormList = []PlatForm{ - { - Index: 1, - PlatFormName: "deepseek", - Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/deepseek.png", - }, - { - Index: 2, - PlatFormName: "豆包", - Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/doubao.png", - }, - { - Index: 3, - PlatFormName: "元宝", - Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/yuanbao.png", - }, - { - Index: 41, - PlatFormName: "千问", - Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/qianwen.png", - }, - { - Index: 5, - PlatFormName: "文心一言", - Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/wenxin.png", - }, - //{ - // Index: 6, - // PlatFormName: "纳米", - //}, - //{ - // Index: 7, - // PlatFormName: "kimi", - //}, { - // Index: 8, - // PlatFormName: "智普", - //}, -} - type CreateReq struct { // 品牌词,多个用英文逗号隔开 ApiKey string `json:"api_key"` diff --git a/internal/biz/ai_collect.go b/internal/biz/ai_collect.go index 6bcff65..94d51d2 100644 --- a/internal/biz/ai_collect.go +++ b/internal/biz/ai_collect.go @@ -5,18 +5,19 @@ import ( "fmt" "geo/internal/collect" "geo/internal/config" - "log" + + "github.com/gofiber/fiber/v2/log" ) // CollectBiz AI收集业务层 type CollectBiz struct { manager *collect.CollectManager config *config.Config - logger *log.Logger + logger log.AllLogger } // NewCollectBiz 创建AI收集业务实例 -func NewCollectBiz(ctx context.Context, cfg *config.Config, logger *log.Logger) *CollectBiz { +func NewCollectBiz(ctx context.Context, cfg *config.Config, logger log.AllLogger) *CollectBiz { manager := collect.NewCollectManager(ctx, cfg, logger) return &CollectBiz{ manager: manager, @@ -27,34 +28,28 @@ func NewCollectBiz(ctx context.Context, cfg *config.Config, logger *log.Logger) // AskAIQuestion 向指定AI平台提问 // platform: 平台类型 (wenxin, deepseek, doubao, qianwen) -// userIndex: 用户索引 -// platIndex: 平台索引 // requestID: 请求ID // question: 问题内容 // headless: 是否无头模式 -func (b *CollectBiz) AskAIQuestion(platform string, userIndex, platIndex, requestID, question string, headless bool) (string, error) { +func (b *CollectBiz) AskAIQuestion(platform string, requestID, question string, headless bool) (*collect.CollectResult, error) { params := &collect.CollectParams{ Headless: headless, - UserIndex: userIndex, - PlatIndex: platIndex, RequestID: requestID, Platform: platform, } - answer, err := b.manager.AskQuestion(platform, params, question) + result, err := b.manager.AskQuestion(platform, params, question) if err != nil { - return "", fmt.Errorf("向%s提问失败: %w", platform, err) + return nil, fmt.Errorf("向%s提问失败: %w", platform, err) } - return answer, nil + return result, nil } // WaitAILogin 等待AI平台登录 -func (b *CollectBiz) WaitAILogin(platform string, userIndex, platIndex, requestID string, headless bool) (bool, string) { +func (b *CollectBiz) WaitAILogin(platform string, requestID string, headless bool) (bool, string) { params := &collect.CollectParams{ Headless: headless, - UserIndex: userIndex, - PlatIndex: platIndex, RequestID: requestID, Platform: platform, } @@ -68,17 +63,22 @@ func (b *CollectBiz) ListAIPlatforms() []string { } // AskMultipleAI 向多个AI平台提问并收集答案 -func (b *CollectBiz) AskMultipleAI(platforms []string, userIndex, requestID, question string, headless bool) map[string]string { - results := make(map[string]string) +func (b *CollectBiz) AskMultipleAI(platforms []string, requestID, question string, headless bool) map[string]*collect.CollectResult { + results := make(map[string]*collect.CollectResult) for _, platform := range platforms { - platIndex := platform // 默认使用platform作为platIndex - answer, err := b.AskAIQuestion(platform, userIndex, platIndex, requestID+"_"+platform, question, headless) + // 为每个平台生成唯一的 requestID + platformRequestID := requestID + "_" + platform + result, err := b.AskAIQuestion(platform, platformRequestID, question, headless) if err != nil { - b.logger.Printf("向%s提问失败: %v", platform, err) - results[platform] = fmt.Sprintf("错误: %v", err) + b.logger.Errorf("向%s提问失败: %v", platform, err) + // 创建一个包含错误信息的结果 + results[platform] = &collect.CollectResult{ + Answer: fmt.Sprintf("错误: %v", err), + ShareLink: "", + } } else { - results[platform] = answer + results[platform] = result } } diff --git a/internal/biz/provider_set.go b/internal/biz/provider_set.go index 168fd63..e53c1f1 100644 --- a/internal/biz/provider_set.go +++ b/internal/biz/provider_set.go @@ -9,4 +9,5 @@ var ProviderSetBiz = wire.NewSet( NewAuthBiz, NewAiBiz, NewProductBiz, + NewCollectBiz, ) diff --git a/internal/collect/base.go b/internal/collect/base.go index 8782080..91c160f 100644 --- a/internal/collect/base.go +++ b/internal/collect/base.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "geo/internal/config" - "log" "os" "path/filepath" "time" @@ -13,22 +12,20 @@ import ( "github.com/go-rod/rod" "github.com/go-rod/rod/lib/launcher" "github.com/go-rod/rod/lib/proto" + "github.com/gofiber/fiber/v2/log" ) // BaseCollector 基础收集器结构 type BaseCollector struct { ctx context.Context Headless bool - UserIndex string - PlatIndex string RequestID string Platform string Browser *rod.Browser Page *rod.Page - Logger *log.Logger - LogFile *os.File + Logger log.AllLogger LoginURL string ChatURL string @@ -41,53 +38,34 @@ type BaseCollector struct { } // NewBaseCollector 构造函数 -func NewBaseCollector(ctx context.Context, params *CollectParams, config *config.Config, logger *log.Logger) *BaseCollector { - var baseLogger *log.Logger - var logFile *os.File +func NewBaseCollector(ctx context.Context, params *CollectParams, config *config.Config, logger log.AllLogger) *BaseCollector { + var baseLogger log.AllLogger if logger != nil { baseLogger = logger - logFile = nil } else { - logsDir := config.Sys.LogsDir - if logsDir == "" { - logsDir = "./logs" - } - os.MkdirAll(logsDir, 0755) - logFile, _ = os.Create(filepath.Join(logsDir, fmt.Sprintf("collect_%s_%s.log", params.RequestID, params.Platform))) - baseLogger = log.New(logFile, "", log.LstdFlags) + baseLogger = log.DefaultLogger() } base := &BaseCollector{ ctx: ctx, Headless: params.Headless, - UserIndex: params.UserIndex, - PlatIndex: params.PlatIndex, RequestID: params.RequestID, Platform: params.Platform, Logger: baseLogger, - LogFile: logFile, config: config, MaxRetries: 3, RetryDelay: 200, } - base.CookiesFile = filepath.Join(base.cookiesDir(), params.PlatIndex+".json") + // Cookie文件按平台区分,而不是按用户索引 + base.CookiesFile = filepath.Join(base.cookiesDir(), params.Platform+".json") return base } -// cookiesDir 获取cookie目录 -func (b *BaseCollector) cookiesDir() string { - dir := filepath.Join(b.config.Sys.CookiesDir, b.UserIndex) - os.MkdirAll(dir, 0755) - return dir -} - // SetupDriver 初始化浏览器驱动 func (b *BaseCollector) SetupDriver() error { - b.LogInfo("初始化浏览器...") - - userDataDir := filepath.Join(b.config.Sys.ChromeDataDir, b.UserIndex, b.RequestID+fmt.Sprintf("___%d", time.Now().UnixNano())) + userDataDir := filepath.Join(b.config.Sys.ChromeDataDir, b.Platform, b.RequestID+fmt.Sprintf("___%d", time.Now().UnixNano())) os.MkdirAll(userDataDir, 0755) l := launcher.New(). @@ -108,15 +86,9 @@ func (b *BaseCollector) SetupDriver() error { l.Delete("headless") } - l.UserDataDir(userDataDir) - l.Set("window-size", "1920,1080") - - // 设置中文语言环境 l.Set("lang", "zh-CN") l.Set("accept-lang", "zh-CN,zh;q=0.9,en;q=0.8") l.Set("force-device-scale-factor", "1") - - // 设置时区为中国 l.Set("timezone", "Asia/Shanghai") url, err := l.Launch() @@ -125,14 +97,12 @@ func (b *BaseCollector) SetupDriver() error { } b.Browser = rod.New().Context(b.ctx).ControlURL(url).MustConnect() - - // 创建新页面 b.Page = b.Browser.MustPage() return nil } -// Close 关闭浏览器和日志文件 +// Close 关闭浏览器 func (b *BaseCollector) Close() { if b.Page != nil { b.Page.Close() @@ -140,9 +110,6 @@ func (b *BaseCollector) Close() { if b.Browser != nil { b.Browser.Close() } - if b.LogFile != nil { - b.LogFile.Close() - } } // SaveCookies 保存cookies @@ -213,12 +180,12 @@ func (b *BaseCollector) WaitForElementClickable(selector string, timeout int) (* // JSClick JavaScript点击元素 func (b *BaseCollector) JSClick(element *rod.Element) error { if element == nil { - b.Logger.Printf("element is nil") + b.Logger.Warn("element is nil") return fmt.Errorf("element is nil") } err := element.Click(proto.InputMouseButtonLeft, 1) if err != nil { - b.Logger.Printf("click fail: " + err.Error()) + b.Logger.Errorf("click fail: %v", err) } return err } @@ -252,25 +219,25 @@ func (b *BaseCollector) SleepMs(milliseconds int) { // LogInfo 记录信息日志 func (b *BaseCollector) LogInfo(message string) { - b.Logger.Printf("📌 %s", message) + b.Logger.Infof("📌 %s", message) } // LogInfof 格式化记录信息日志 func (b *BaseCollector) LogInfof(format string, args ...interface{}) { - b.Logger.Printf("📌 "+format, args...) + b.Logger.Infof("📌 "+format, args...) } // LogError 记录错误日志 func (b *BaseCollector) LogError(message string) { - b.Logger.Printf("❌ %s", message) + b.Logger.Errorf("❌ %s", message) } // LogStep 记录步骤日志 func (b *BaseCollector) LogStep(stepName string, success bool, message string) { if success { - b.Logger.Printf("✅ %s: 成功 %s", stepName, message) + b.Logger.Infof("✅ %s: 成功 %s", stepName, message) } else { - b.Logger.Printf("❌ %s: 失败 %s", stepName, message) + b.Logger.Errorf("❌ %s: 失败 %s", stepName, message) } } @@ -300,8 +267,8 @@ func (b *BaseCollector) WaitLogin() (bool, string) { } // AskQuestion 提问并获取答案(需要子类实现) -func (b *BaseCollector) AskQuestion(question string) (string, error) { - return "", fmt.Errorf("需要实现") +func (b *BaseCollector) AskQuestion(question string) (*CollectResult, error) { + return nil, fmt.Errorf("需要实现") } // InitPage 初始化页面 @@ -333,3 +300,10 @@ func (b *BaseCollector) SafeElement(selector string) (*rod.Element, error) { } return b.Page.Element(selector) } + +// cookiesDir 获取cookie目录 - 按平台区分 +func (b *BaseCollector) cookiesDir() string { + dir := filepath.Join(b.config.Sys.CookiesDir, b.Platform) + os.MkdirAll(dir, 0755) + return dir +} diff --git a/internal/collect/deepseek.go b/internal/collect/deepseek.go index c839534..a0d5e92 100644 --- a/internal/collect/deepseek.go +++ b/internal/collect/deepseek.go @@ -4,12 +4,12 @@ import ( "context" "fmt" "geo/internal/config" - "log" "strings" "time" "github.com/go-rod/rod" "github.com/go-rod/rod/lib/proto" + "github.com/gofiber/fiber/v2/log" ) // DeepseekCollector DeepSeek收集器 @@ -18,7 +18,7 @@ type DeepseekCollector struct { } // NewDeepseekCollector 创建DeepSeek收集器 -func NewDeepseekCollector(ctx context.Context, params *CollectParams, cfg *config.Config, logger *log.Logger) CollectorInterface { +func NewDeepseekCollector(ctx context.Context, params *CollectParams, cfg *config.Config, logger log.AllLogger) CollectorInterface { collector := &DeepseekCollector{ BaseCollector: NewBaseCollector(ctx, params, cfg, logger), } @@ -54,81 +54,65 @@ func (c *DeepseekCollector) CheckLoginStatus() bool { // WaitLogin 等待登录 func (c *DeepseekCollector) WaitLogin() (bool, string) { - c.LogInfo("开始等待DeepSeek登录...") - if err := c.SetupDriver(); err != nil { return false, fmt.Sprintf("浏览器启动失败: %v", err) } defer c.Close() - // 访问聊天页面 c.Page.MustNavigate(c.ChatURL) c.Sleep(3) - // 检查是否已登录 if c.CheckLoginStatus() { c.SaveCookies() - c.LogInfo("已有登录状态") return true, "already_logged_in" } - c.LogInfo("未检测到登录状态,请登录账号...") - - // 等待用户手动登录,最多300秒 for i := 0; i < 300; i++ { if c.CheckLoginStatus() { + c.Sleep(2) c.SaveCookies() - c.LogInfo("登录成功") return true, "login_success" } - time.Sleep(1 * time.Second) } - return false, "登录超时,请检查网络或账号状态" + return false, "登录超时" } // AskQuestion 提问并获取答案 -func (c *DeepseekCollector) AskQuestion(question string) (string, error) { - c.LogInfo(fmt.Sprintf("开始向DeepSeek提问: %s", question)) - - // 初始化浏览器 +func (c *DeepseekCollector) AskQuestion(question string) (*CollectResult, error) { if err := c.SetupDriver(); err != nil { - return "", fmt.Errorf("浏览器启动失败: %v", err) + return nil, fmt.Errorf("浏览器启动失败: %v", err) } defer c.Close() - // 初始化页面 if err := c.InitPage(); err != nil { - return "", fmt.Errorf("页面初始化失败,请先调用WaitLogin登录: %v", err) + return nil, fmt.Errorf("页面初始化失败: %v", err) } c.Sleep(3) - // 输入问题 if err := c.inputQuestion(question); err != nil { - return "", fmt.Errorf("输入问题失败: %v", err) + return nil, fmt.Errorf("输入问题失败: %v", err) } - // 点击发送 if err := c.clickSendButton(); err != nil { - return "", fmt.Errorf("点击发送按钮失败: %v", err) + return nil, fmt.Errorf("点击发送按钮失败: %v", err) } - // 等待并获取答案 answer, err := c.waitForAnswer() if err != nil { - return "", fmt.Errorf("获取答案失败: %v", err) + return nil, fmt.Errorf("获取答案失败: %v", err) } - c.LogInfo(fmt.Sprintf("成功获取DeepSeek答案,长度: %d 字符", len(answer))) - return answer, nil + return &CollectResult{ + Answer: answer, + ShareLink: "", + }, nil } // inputQuestion 输入问题 func (c *DeepseekCollector) inputQuestion(question string) error { - c.LogInfo("输入问题到DeepSeek...") - // DeepSeek的输入框选择器 inputSelectors := []string{ "textarea[placeholder*='输入']", @@ -145,7 +129,6 @@ func (c *DeepseekCollector) inputQuestion(question string) error { for _, selector := range inputSelectors { inputBox, err = c.WaitForElementVisible(selector, 10) if err == nil && inputBox != nil { - c.LogInfo(fmt.Sprintf("找到输入框: %s", selector)) break } } @@ -162,7 +145,7 @@ func (c *DeepseekCollector) inputQuestion(question string) error { // 清空输入框 if err := c.ClearInput(inputBox); err != nil { - c.LogInfo(fmt.Sprintf("清空输入框失败: %v", err)) + // Ignore clear error } c.SleepMs(300) @@ -171,7 +154,6 @@ func (c *DeepseekCollector) inputQuestion(question string) error { inputBox.Input(question) } - c.LogInfo(fmt.Sprintf("问题已输入")) c.SleepMs(1000) return nil @@ -179,8 +161,6 @@ func (c *DeepseekCollector) inputQuestion(question string) error { // clickSendButton 点击发送按钮 func (c *DeepseekCollector) clickSendButton() error { - c.LogInfo("点击发送按钮...") - // 发送按钮选择器 sendSelectors := []string{ "button[class*='send']", @@ -198,7 +178,6 @@ func (c *DeepseekCollector) clickSendButton() error { for _, selector := range sendSelectors { sendBtn, err = c.WaitForElementClickable(selector, 5) if err == nil && sendBtn != nil { - c.LogInfo(fmt.Sprintf("找到发送按钮: %s", selector)) break } } @@ -218,7 +197,6 @@ func (c *DeepseekCollector) clickSendButton() error { return fmt.Errorf("点击发送按钮失败: %v", err) } - c.LogInfo("已点击发送按钮") c.SleepMs(2000) return nil @@ -226,8 +204,6 @@ func (c *DeepseekCollector) clickSendButton() error { // waitForAnswer 等待并获取答案 func (c *DeepseekCollector) waitForAnswer() (string, error) { - c.LogInfo("等待DeepSeek回答...") - timeout := 120 // 最大等待时间(秒) startTime := time.Now() lastAnswerLength := 0 @@ -262,7 +238,6 @@ func (c *DeepseekCollector) waitForAnswer() (string, error) { currentLength := len(text) if currentLength == lastAnswerLength && currentLength > 10 { // 答案不再增长,认为已完成 - c.LogInfo("获取到完整答案") return strings.TrimSpace(text), nil } lastAnswerLength = currentLength diff --git a/internal/collect/doubao.go b/internal/collect/doubao.go index a2b81dd..98c437c 100644 --- a/internal/collect/doubao.go +++ b/internal/collect/doubao.go @@ -4,12 +4,12 @@ import ( "context" "fmt" "geo/internal/config" - "log" "strings" "time" "github.com/go-rod/rod" "github.com/go-rod/rod/lib/proto" + "github.com/gofiber/fiber/v2/log" ) // DoubaoCollector 豆包收集器 @@ -18,7 +18,7 @@ type DoubaoCollector struct { } // NewDoubaoCollector 创建豆包收集器 -func NewDoubaoCollector(ctx context.Context, params *CollectParams, cfg *config.Config, logger *log.Logger) CollectorInterface { +func NewDoubaoCollector(ctx context.Context, params *CollectParams, cfg *config.Config, logger log.AllLogger) CollectorInterface { collector := &DoubaoCollector{ BaseCollector: NewBaseCollector(ctx, params, cfg, logger), } @@ -54,82 +54,65 @@ func (c *DoubaoCollector) CheckLoginStatus() bool { // WaitLogin 等待登录 func (c *DoubaoCollector) WaitLogin() (bool, string) { - c.LogInfo("开始等待豆包登录...") - if err := c.SetupDriver(); err != nil { return false, fmt.Sprintf("浏览器启动失败: %v", err) } defer c.Close() - // 访问豆包首页 c.Page.MustNavigate(c.LoginURL) c.Sleep(3) - // 检查是否已登录 if c.CheckLoginStatus() { c.SaveCookies() - c.LogInfo("已有登录状态") return true, "already_logged_in" } - c.LogInfo("请登录豆包账号...") - - // 等待用户手动登录,最多300秒 for i := 0; i < 300; i++ { if c.CheckLoginStatus() { + c.Sleep(2) c.SaveCookies() - c.LogInfo("登录成功") return true, "login_success" } - time.Sleep(1 * time.Second) } - return false, "登录超时,请检查网络或账号状态" + return false, "登录超时" } // AskQuestion 提问并获取答案 -func (c *DoubaoCollector) AskQuestion(question string) (string, error) { - c.LogInfo(fmt.Sprintf("开始向豆包提问: %s", question)) - - // 初始化浏览器 +func (c *DoubaoCollector) AskQuestion(question string) (*CollectResult, error) { if err := c.SetupDriver(); err != nil { - return "", fmt.Errorf("浏览器启动失败: %v", err) + return nil, fmt.Errorf("浏览器启动失败: %v", err) } defer c.Close() - // 初始化页面 if err := c.InitPage(); err != nil { - return "", fmt.Errorf("页面初始化失败,请先调用WaitLogin登录: %v", err) + return nil, fmt.Errorf("页面初始化失败: %v", err) } c.Sleep(3) - // 输入问题 if err := c.inputQuestion(question); err != nil { - return "", fmt.Errorf("输入问题失败: %v", err) + return nil, fmt.Errorf("输入问题失败: %v", err) } - // 点击发送 if err := c.clickSendButton(); err != nil { - return "", fmt.Errorf("点击发送按钮失败: %v", err) + return nil, fmt.Errorf("点击发送按钮失败: %v", err) } - // 等待并获取答案 answer, err := c.waitForAnswer() if err != nil { - return "", fmt.Errorf("获取答案失败: %v", err) + return nil, fmt.Errorf("获取答案失败: %v", err) } - c.LogInfo(fmt.Sprintf("成功获取豆包答案,长度: %d 字符", len(answer))) - return answer, nil + return &CollectResult{ + Answer: answer, + ShareLink: "", + }, nil } // inputQuestion 输入问题 func (c *DoubaoCollector) inputQuestion(question string) error { - c.LogInfo("输入问题到豆包...") - - // 豆包的输入框选择器 inputSelectors := []string{ "textarea[placeholder*='输入']", "textarea[placeholder*='问']", @@ -146,7 +129,6 @@ func (c *DoubaoCollector) inputQuestion(question string) error { for _, selector := range inputSelectors { inputBox, err = c.WaitForElementVisible(selector, 10) if err == nil && inputBox != nil { - c.LogInfo(fmt.Sprintf("找到输入框: %s", selector)) break } } @@ -155,24 +137,20 @@ func (c *DoubaoCollector) inputQuestion(question string) error { return fmt.Errorf("未找到输入框") } - // 点击获取焦点 if err := inputBox.Click(proto.InputMouseButtonLeft, 1); err != nil { return fmt.Errorf("点击输入框失败: %v", err) } c.SleepMs(500) - // 清空输入框 if err := c.ClearInput(inputBox); err != nil { - c.LogInfo(fmt.Sprintf("清空输入框失败: %v", err)) + // Ignore clear error } c.SleepMs(300) - // 输入问题 if err := c.SetInputValue(inputBox, question); err != nil { inputBox.Input(question) } - c.LogInfo(fmt.Sprintf("问题已输入")) c.SleepMs(1000) return nil @@ -180,9 +158,6 @@ func (c *DoubaoCollector) inputQuestion(question string) error { // clickSendButton 点击发送按钮 func (c *DoubaoCollector) clickSendButton() error { - c.LogInfo("点击发送按钮...") - - // 发送按钮选择器 sendSelectors := []string{ "button[class*='send']", "button[class*='submit']", @@ -199,13 +174,11 @@ func (c *DoubaoCollector) clickSendButton() error { for _, selector := range sendSelectors { sendBtn, err = c.WaitForElementClickable(selector, 5) if err == nil && sendBtn != nil { - c.LogInfo(fmt.Sprintf("找到发送按钮: %s", selector)) break } } if sendBtn == nil { - // 尝试查找发送图标 sendBtn, err = c.Page.Element("button svg") if err != nil { return fmt.Errorf("未找到发送按钮") @@ -214,12 +187,10 @@ func (c *DoubaoCollector) clickSendButton() error { c.SleepMs(500) - // 点击发送按钮 if err := c.JSClick(sendBtn); err != nil { return fmt.Errorf("点击发送按钮失败: %v", err) } - c.LogInfo("已点击发送按钮") c.SleepMs(2000) return nil @@ -227,14 +198,11 @@ func (c *DoubaoCollector) clickSendButton() error { // waitForAnswer 等待并获取答案 func (c *DoubaoCollector) waitForAnswer() (string, error) { - c.LogInfo("等待豆包回答...") - - timeout := 120 // 最大等待时间(秒) + timeout := 120 startTime := time.Now() lastAnswerLength := 0 for time.Since(startTime).Seconds() < float64(timeout) { - // 查找答案区域 answerSelectors := []string{ ".message-content", ".response-text", @@ -247,24 +215,19 @@ func (c *DoubaoCollector) waitForAnswer() (string, error) { for _, selector := range answerSelectors { answerElements, err := c.Page.Elements(selector) if err == nil && len(answerElements) > 0 { - // 获取最后一个答案元素 lastAnswer := answerElements[len(answerElements)-1] visible, _ := lastAnswer.Visible() if visible { text, err := lastAnswer.Text() if err == nil && len(strings.TrimSpace(text)) > 0 { - // 检查是否正在生成 isGenerating := strings.Contains(text, "正在") || strings.Contains(text, "思考中") || strings.Contains(text, "typing") if !isGenerating { - // 检查答案是否还在增长 currentLength := len(text) if currentLength == lastAnswerLength && currentLength > 10 { - // 答案不再增长,认为已完成 - c.LogInfo("获取到完整答案") return strings.TrimSpace(text), nil } lastAnswerLength = currentLength diff --git a/internal/collect/interface.go b/internal/collect/interface.go index 7318985..3221d14 100644 --- a/internal/collect/interface.go +++ b/internal/collect/interface.go @@ -3,7 +3,8 @@ package collect import ( "context" "geo/internal/config" - "log" + + "github.com/gofiber/fiber/v2/log" ) // CollectorInterface AI平台收集器接口 @@ -11,7 +12,13 @@ type CollectorInterface interface { // WaitLogin 等待登录 WaitLogin() (bool, string) // AskQuestion 提问并获取答案 - AskQuestion(question string) (string, error) + AskQuestion(question string) (*CollectResult, error) +} + +// CollectResult 收集结果 +type CollectResult struct { + Answer string `json:"answer"` // AI回答内容 + ShareLink string `json:"share_link"` // 分享链接 } // NewCollector 创建收集器的工厂函数类型 @@ -19,22 +26,22 @@ type NewCollector func( ctx context.Context, param *CollectParams, cfg *config.Config, - logger *log.Logger) CollectorInterface + logger log.AllLogger) CollectorInterface // CollectorValue 收集器配置信息 type CollectorValue struct { - Name string // 平台名称 - InitMethod NewCollector // 初始化方法 - Platform string // 平台标识: wenxin, deepseek, doubao, qianwen + Name string // 平台名称 + InitMethod NewCollector // 初始化方法 + Platform string // 平台标识: wenxin, deepseek, doubao, qianwen + Icon string } // CollectParams 收集任务参数 type CollectParams struct { - Headless bool // 是否无头模式 - UserIndex string // 用户索引 - PlatIndex string // 平台索引 - RequestID string // 请求ID - Platform string // 平台类型 + Headless bool // 是否无头模式 + RequestID string // 请求ID + Platform string // 平台类型: wenxin, deepseek, doubao, qianwen + } // CollectorMap 收集器注册表 @@ -43,20 +50,24 @@ var CollectorMap = map[string]*CollectorValue{ Name: "文心一言", InitMethod: NewWenxinCollector, Platform: "wenxin", + Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/wenxin.png", }, "deepseek": { Name: "DeepSeek", InitMethod: NewDeepseekCollector, Platform: "deepseek", + Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/deepseek.png", }, "doubao": { Name: "豆包", InitMethod: NewDoubaoCollector, Platform: "doubao", + Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/doubao.png", }, "qianwen": { Name: "通义千问", InitMethod: NewQianwenCollector, Platform: "qianwen", + Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/qianwen.png", }, } diff --git a/internal/collect/manager.go b/internal/collect/manager.go index 9b58e52..7a4daa9 100644 --- a/internal/collect/manager.go +++ b/internal/collect/manager.go @@ -4,18 +4,19 @@ import ( "context" "fmt" "geo/internal/config" - "log" + + "github.com/gofiber/fiber/v2/log" ) // CollectManager 收集管理器 type CollectManager struct { ctx context.Context config *config.Config - logger *log.Logger + logger log.AllLogger } // NewCollectManager 创建收集管理器 -func NewCollectManager(ctx context.Context, cfg *config.Config, logger *log.Logger) *CollectManager { +func NewCollectManager(ctx context.Context, cfg *config.Config, logger log.AllLogger) *CollectManager { return &CollectManager{ ctx: ctx, config: cfg, @@ -39,10 +40,10 @@ func (m *CollectManager) GetCollector(platform string, params *CollectParams) (C } // AskQuestion 向指定AI平台提问 -func (m *CollectManager) AskQuestion(platform string, params *CollectParams, question string) (string, error) { +func (m *CollectManager) AskQuestion(platform string, params *CollectParams, question string) (*CollectResult, error) { collector, err := m.GetCollector(platform, params) if err != nil { - return "", err + return nil, err } return collector.AskQuestion(question) diff --git a/internal/collect/qianwen.go b/internal/collect/qianwen.go index fd6db9a..d760e1f 100644 --- a/internal/collect/qianwen.go +++ b/internal/collect/qianwen.go @@ -4,12 +4,12 @@ import ( "context" "fmt" "geo/internal/config" - "log" "strings" "time" "github.com/go-rod/rod" "github.com/go-rod/rod/lib/proto" + "github.com/gofiber/fiber/v2/log" ) // QianwenCollector 通义千问收集器 @@ -18,7 +18,7 @@ type QianwenCollector struct { } // NewQianwenCollector 创建通义千问收集器 -func NewQianwenCollector(ctx context.Context, params *CollectParams, cfg *config.Config, logger *log.Logger) CollectorInterface { +func NewQianwenCollector(ctx context.Context, params *CollectParams, cfg *config.Config, logger log.AllLogger) CollectorInterface { collector := &QianwenCollector{ BaseCollector: NewBaseCollector(ctx, params, cfg, logger), } @@ -54,81 +54,65 @@ func (c *QianwenCollector) CheckLoginStatus() bool { // WaitLogin 等待登录 func (c *QianwenCollector) WaitLogin() (bool, string) { - c.LogInfo("开始等待通义千问登录...") - if err := c.SetupDriver(); err != nil { return false, fmt.Sprintf("浏览器启动失败: %v", err) } defer c.Close() - // 访问通义千问页面 c.Page.MustNavigate(c.ChatURL) c.Sleep(3) - // 检查是否已登录 if c.CheckLoginStatus() { c.SaveCookies() - c.LogInfo("已有登录状态") return true, "already_logged_in" } - c.LogInfo("请登录阿里云账号...") - - // 等待用户手动登录,最多300秒 for i := 0; i < 300; i++ { if c.CheckLoginStatus() { + c.Sleep(2) c.SaveCookies() - c.LogInfo("登录成功") return true, "login_success" } - time.Sleep(1 * time.Second) } - return false, "登录超时,请检查网络或账号状态" + return false, "登录超时" } // AskQuestion 提问并获取答案 -func (c *QianwenCollector) AskQuestion(question string) (string, error) { - c.LogInfo(fmt.Sprintf("开始向通义千问提问: %s", question)) - - // 初始化浏览器 +func (c *QianwenCollector) AskQuestion(question string) (*CollectResult, error) { if err := c.SetupDriver(); err != nil { - return "", fmt.Errorf("浏览器启动失败: %v", err) + return nil, fmt.Errorf("浏览器启动失败: %v", err) } defer c.Close() - // 初始化页面 if err := c.InitPage(); err != nil { - return "", fmt.Errorf("页面初始化失败,请先调用WaitLogin登录: %v", err) + return nil, fmt.Errorf("页面初始化失败: %v", err) } c.Sleep(3) - // 输入问题 if err := c.inputQuestion(question); err != nil { - return "", fmt.Errorf("输入问题失败: %v", err) + return nil, fmt.Errorf("输入问题失败: %v", err) } - // 点击发送 if err := c.clickSendButton(); err != nil { - return "", fmt.Errorf("点击发送按钮失败: %v", err) + return nil, fmt.Errorf("点击发送按钮失败: %v", err) } - // 等待并获取答案 answer, err := c.waitForAnswer() if err != nil { - return "", fmt.Errorf("获取答案失败: %v", err) + return nil, fmt.Errorf("获取答案失败: %v", err) } - c.LogInfo(fmt.Sprintf("成功获取通义千问答案,长度: %d 字符", len(answer))) - return answer, nil + return &CollectResult{ + Answer: answer, + ShareLink: "", + }, nil } // inputQuestion 输入问题 func (c *QianwenCollector) inputQuestion(question string) error { - c.LogInfo("输入问题到通义千问...") - // 通义千问的输入框选择器 inputSelectors := []string{ "textarea[placeholder*='输入']", @@ -147,7 +131,6 @@ func (c *QianwenCollector) inputQuestion(question string) error { for _, selector := range inputSelectors { inputBox, err = c.WaitForElementVisible(selector, 10) if err == nil && inputBox != nil { - c.LogInfo(fmt.Sprintf("找到输入框: %s", selector)) break } } @@ -164,7 +147,7 @@ func (c *QianwenCollector) inputQuestion(question string) error { // 清空输入框 if err := c.ClearInput(inputBox); err != nil { - c.LogInfo(fmt.Sprintf("清空输入框失败: %v", err)) + // Ignore clear error } c.SleepMs(300) @@ -173,7 +156,6 @@ func (c *QianwenCollector) inputQuestion(question string) error { inputBox.Input(question) } - c.LogInfo(fmt.Sprintf("问题已输入")) c.SleepMs(1000) return nil @@ -181,8 +163,6 @@ func (c *QianwenCollector) inputQuestion(question string) error { // clickSendButton 点击发送按钮 func (c *QianwenCollector) clickSendButton() error { - c.LogInfo("点击发送按钮...") - // 发送按钮选择器 sendSelectors := []string{ "button[class*='send']", @@ -201,7 +181,6 @@ func (c *QianwenCollector) clickSendButton() error { for _, selector := range sendSelectors { sendBtn, err = c.WaitForElementClickable(selector, 5) if err == nil && sendBtn != nil { - c.LogInfo(fmt.Sprintf("找到发送按钮: %s", selector)) break } } @@ -221,7 +200,6 @@ func (c *QianwenCollector) clickSendButton() error { return fmt.Errorf("点击发送按钮失败: %v", err) } - c.LogInfo("已点击发送按钮") c.SleepMs(2000) return nil @@ -229,8 +207,6 @@ func (c *QianwenCollector) clickSendButton() error { // waitForAnswer 等待并获取答案 func (c *QianwenCollector) waitForAnswer() (string, error) { - c.LogInfo("等待通义千问回答...") - timeout := 120 // 最大等待时间(秒) startTime := time.Now() lastAnswerLength := 0 @@ -268,7 +244,6 @@ func (c *QianwenCollector) waitForAnswer() (string, error) { currentLength := len(text) if currentLength == lastAnswerLength && currentLength > 10 { // 答案不再增长,认为已完成 - c.LogInfo("获取到完整答案") return strings.TrimSpace(text), nil } lastAnswerLength = currentLength diff --git a/internal/collect/wenxin.go b/internal/collect/wenxin.go index 65e4e84..0902f91 100644 --- a/internal/collect/wenxin.go +++ b/internal/collect/wenxin.go @@ -4,14 +4,15 @@ import ( "context" "fmt" "geo/internal/config" - "log" + + "github.com/gofiber/fiber/v2/log" + "regexp" "strings" "time" "github.com/atotto/clipboard" "github.com/go-rod/rod" "github.com/go-rod/rod/lib/proto" - "regexp" ) // Source 文章引用来源结构体 @@ -28,7 +29,7 @@ type WenxinCollector struct { } // NewWenxinCollector 创建文心一言收集器 -func NewWenxinCollector(ctx context.Context, params *CollectParams, cfg *config.Config, logger *log.Logger) CollectorInterface { +func NewWenxinCollector(ctx context.Context, params *CollectParams, cfg *config.Config, logger log.AllLogger) CollectorInterface { collector := &WenxinCollector{ BaseCollector: NewBaseCollector(ctx, params, cfg, logger), } @@ -46,37 +47,11 @@ func (c *WenxinCollector) SetupDriver() error { return err } - // 通过 JavaScript 设置 navigator.language 为中文 - jsCode := ` - (function() { - Object.defineProperty(navigator, 'language', { - get: function() { return 'zh-CN'; }, - configurable: true - }); - Object.defineProperty(navigator, 'languages', { - get: function() { return ['zh-CN', 'zh', 'en']; }, - configurable: true - }); - })(); - ` - - if _, err := c.Page.Eval(jsCode); err != nil { - c.LogInfo(fmt.Sprintf("设置语言失败: %v", err)) - } else { - c.LogInfo("已设置浏览器语言为中文 (zh-CN)") - } - return nil } // CheckLoginStatus 检查登录状态 func (c *WenxinCollector) CheckLoginStatus() bool { - currentURL := c.GetCurrentURL() - - // 如果在登录页面,说明未登录 - if strings.Contains(currentURL, "passport.baidu.com") { - return false - } // 检查页面上是否存在内容为"登录"或"Login"的button,如果存在说明未登录 loginButtons, err := c.Page.Elements("button") @@ -97,117 +72,69 @@ func (c *WenxinCollector) CheckLoginStatus() bool { // WaitLogin 等待登录 func (c *WenxinCollector) WaitLogin() (bool, string) { - c.LogInfo("开始等待文心一言登录...") - if err := c.SetupDriver(); err != nil { return false, fmt.Sprintf("浏览器启动失败: %v", err) } defer c.Close() - // 访问聊天页面 c.Page.MustNavigate(c.ChatURL) c.Sleep(3) - // 检查是否已登录 if c.CheckLoginStatus() { c.SaveCookies() - c.LogInfo("已有登录状态") return true, "already_logged_in" } - c.LogInfo("检测到未登录,请在当前页面完成登录(扫码或输入账号密码)...") - - // 不跳转页面,在当前页面循环检查登录按钮是否存在 // 最多等待300秒 - for i := 0; i < 5000; i++ { - // 检查页面上是否还存在"登录"或"Login"按钮 - loginButtonExists := false - buttons, err := c.Page.Elements("button") - if err == nil { - for _, btn := range buttons { - text, _ := btn.Text() - trimmedText := strings.TrimSpace(text) - if trimmedText == "登录" || trimmedText == "Login" { - loginButtonExists = true - break - } - } - } - - // 如果登录按钮不存在,说明已登录 - if !loginButtonExists { - c.Sleep(2) // 等待页面稳定 + for i := 0; i < 300; i++ { + if c.CheckLoginStatus() { + c.Sleep(2) c.SaveCookies() - c.LogInfo("登录成功:登录按钮已消失") return true, "login_success" } - - // 每秒检查一次 time.Sleep(1 * time.Second) - - // 每30秒输出一次提示 - if i > 0 && i%30 == 0 { - c.LogInfo(fmt.Sprintf("等待登录中... 已等待 %d 秒", i)) - } } - return false, "登录超时,请检查网络或账号状态" + return false, "登录超时" } // AskQuestion 提问并获取答案 -func (c *WenxinCollector) AskQuestion(question string) (string, error) { - c.LogInfo(fmt.Sprintf("开始提问: %s", question)) - - // 初始化浏览器 +func (c *WenxinCollector) AskQuestion(question string) (*CollectResult, error) { if err := c.SetupDriver(); err != nil { - return "", fmt.Errorf("浏览器启动失败: %v", err) + return nil, fmt.Errorf("浏览器启动失败: %v", err) } defer c.Close() - //初始化页面(加载cookies和检查登录) if err := c.InitPage(); err != nil { - return "", fmt.Errorf("页面初始化失败,请先调用WaitLogin登录: %v", err) + return nil, fmt.Errorf("页面初始化失败: %v", err) } - // 等待页面完全加载 c.Sleep(3) - // 查找输入框并输入问题 if err := c.inputQuestion(question); err != nil { - return "", fmt.Errorf("输入问题失败: %v", err) + return nil, fmt.Errorf("输入问题失败: %v", err) } - // 点击发送按钮 if err := c.clickSendButton(); err != nil { - return "", fmt.Errorf("点击发送按钮失败: %v", err) + return nil, fmt.Errorf("点击发送按钮失败: %v", err) } - // 等待并获取答案 answer, err := c.waitForAnswer() if err != nil { - return "", fmt.Errorf("获取答案失败: %v", err) + return nil, fmt.Errorf("获取答案失败: %v", err) } - c.LogInfo(fmt.Sprintf("成功获取答案,长度: %d 字符", len(answer))) - // 获取分享链接 - _, shareErr := c.getShareLink() - if shareErr != nil { - c.LogInfo(fmt.Sprintf("分享链接获取状态: %v", shareErr)) + shareLink := "" + link, _ := c.getShareLink() + if link != "" { + shareLink = link } - // 获取引用来源 - sources, sourcesErr := c.GetSources() - if sourcesErr != nil { - c.LogInfo(fmt.Sprintf("引用来源获取失败: %v", sourcesErr)) - } else if len(sources) > 0 { - c.LogInfo(fmt.Sprintf("成功获取 %d 个引用来源", len(sources))) - for i, source := range sources { - c.LogInfo(fmt.Sprintf(" [%d] 标题: %s, 来源: %s, URL: %s", i+1, source.Title, source.PlatformName, source.Url)) - } - } - - return answer, nil + return &CollectResult{ + Answer: answer, + ShareLink: shareLink, + }, nil } // inputQuestion 输入问题 @@ -411,7 +338,7 @@ func (c *WenxinCollector) waitForAnswer() (string, error) { htmlContent, err := answerElem.HTML() if err == nil && len(strings.TrimSpace(htmlContent)) > 30 { // 清理HTML标签,只保留纯文本 - answerText = CleanHTMLTags(htmlContent) + answerText = CleanDivTags(htmlContent) c.LogInfo(fmt.Sprintf("找到答案容器,清理后文本长度: %d", len(answerText))) } else { // 如果HTML获取失败,尝试获取文本 @@ -557,61 +484,105 @@ func (c *WenxinCollector) getShareLink() (string, error) { } c.LogInfo("✓ 点击成功") - c.SleepMs(2000) // 等待弹窗出现 + c.SleepMs(3000) // 等待弹窗出现 c.Screenshot("after_share_icon_click") - // 步骤3: 在弹窗中查找shareContainer的div + // 步骤3: 在弹窗中查找shareContainer的div(带重试机制) c.LogInfo("步骤3: 查找包含'shareContainer'的div元素...") var shareContainerDiv *rod.Element + maxRetries := 5 + retryDelay := 1000 // 每次重试间隔1秒 - // 重新获取所有div元素 - allDivs, err = c.Page.Elements("div") - if err != nil { - return "", fmt.Errorf("获取页面div元素失败: %v", err) - } + for attempt := 1; attempt <= maxRetries; attempt++ { + c.LogInfo(fmt.Sprintf("第 %d/%d 次尝试查找shareContainer...", attempt, maxRetries)) - c.LogInfo(fmt.Sprintf("在 %d 个div元素中查找包含'shareContainer'的class", len(allDivs))) + // 重新获取所有div元素 + allDivs, err = c.Page.Elements("div") + if err != nil { + c.LogInfo(fmt.Sprintf("获取页面div元素失败: %v", err)) + if attempt < maxRetries { + c.SleepMs(retryDelay) + continue + } + return "", fmt.Errorf("获取页面div元素失败: %v", err) + } - for _, elem := range allDivs { - classAttr, _ := elem.Attribute("class") - if classAttr != nil && strings.Contains(strings.ToLower(*classAttr), "sharecontainer") { - tagName, _ := elem.Property("tagName") - c.LogInfo(fmt.Sprintf("✓ 找到shareContainer容器: tag=%s, class=%s", tagName.Str(), *classAttr)) - shareContainerDiv = elem - break + c.LogInfo(fmt.Sprintf("在 %d 个div元素中查找包含'shareContainer'的class", len(allDivs))) + + for _, elem := range allDivs { + classAttr, _ := elem.Attribute("class") + if classAttr != nil && strings.Contains(strings.ToLower(*classAttr), "sharecontainer") { + tagName, _ := elem.Property("tagName") + c.LogInfo(fmt.Sprintf("✓ 找到shareContainer容器: tag=%s, class=%s", tagName.Str(), *classAttr)) + shareContainerDiv = elem + break + } + } + + if shareContainerDiv != nil { + break // 找到了,退出重试循环 + } + + // 没找到,等待后重试 + if attempt < maxRetries { + c.LogInfo(fmt.Sprintf("未找到shareContainer,%d毫秒后重试...", retryDelay)) + c.SleepMs(retryDelay) } } if shareContainerDiv == nil { - return "", fmt.Errorf("未找到包含'shareContainer' class的div元素") + c.Screenshot("share_container_not_found") + return "", fmt.Errorf("经过 %d 次重试仍未找到包含'shareContainer' class的div元素", maxRetries) } - // 步骤4: 在shareContainer内查找genLink的button + // 步骤4: 在shareContainer内查找genLink的button(带重试机制) c.LogInfo("步骤4: 在shareContainer容器内查找包含'genLink'的button...") var genLinkBtn *rod.Element + maxRetries = 3 + retryDelay = 800 - buttons, err := shareContainerDiv.Elements("button") - if err != nil { - return "", fmt.Errorf("获取button元素失败: %v", err) - } + for attempt := 1; attempt <= maxRetries; attempt++ { + c.LogInfo(fmt.Sprintf("第 %d/%d 次尝试查找genLink按钮...", attempt, maxRetries)) - c.LogInfo(fmt.Sprintf("在 %d 个button元素中查找包含'genLink'的class", len(buttons))) + buttons, err := shareContainerDiv.Elements("button") + if err != nil { + c.LogInfo(fmt.Sprintf("获取button元素失败: %v", err)) + if attempt < maxRetries { + c.SleepMs(retryDelay) + continue + } + return "", fmt.Errorf("获取button元素失败: %v", err) + } - for _, elem := range buttons { - classAttr, _ := elem.Attribute("class") - if classAttr != nil && strings.Contains(strings.ToLower(*classAttr), "genlink") { - tagName, _ := elem.Property("tagName") - text, _ := elem.Text() - c.LogInfo(fmt.Sprintf("✓ 找到genLink按钮: tag=%s, class=%s, text=%s", tagName.Str(), *classAttr, strings.TrimSpace(text))) - genLinkBtn = elem - break + c.LogInfo(fmt.Sprintf("在 %d 个button元素中查找包含'genLink'的class", len(buttons))) + + for _, elem := range buttons { + classAttr, _ := elem.Attribute("class") + if classAttr != nil && strings.Contains(strings.ToLower(*classAttr), "genlink") { + tagName, _ := elem.Property("tagName") + text, _ := elem.Text() + c.LogInfo(fmt.Sprintf("✓ 找到genLink按钮: tag=%s, class=%s, text=%s", tagName.Str(), *classAttr, strings.TrimSpace(text))) + genLinkBtn = elem + break + } + } + + if genLinkBtn != nil { + break // 找到了,退出重试循环 + } + + // 没找到,等待后重试 + if attempt < maxRetries { + c.LogInfo(fmt.Sprintf("未找到genLink按钮,%d毫秒后重试...", retryDelay)) + c.SleepMs(retryDelay) } } if genLinkBtn == nil { - return "", fmt.Errorf("在shareContainer容器内未找到包含'genLink' class的button") + c.Screenshot("genlink_button_not_found") + return "", fmt.Errorf("经过 %d 次重试仍未在shareContainer容器内找到包含'genLink' class的button", maxRetries) } // 滚动到按钮位置 diff --git a/internal/entitys/request.go b/internal/entitys/request.go index 97b0006..5ffbf88 100644 --- a/internal/entitys/request.go +++ b/internal/entitys/request.go @@ -196,10 +196,30 @@ type ( } ProductCollectRequest struct { - AccessToken string `json:"access_token" validate:"required" zh:"access_token"` - Keywords []string `json:"keywords" validate:"required" zh:"关键词"` - Platform []int `json:"platform" validate:"required" zh:"平台"` - Question string `json:"question" validate:"required" zh:"问题"` - ProductId int32 `json:"product_id" validate:"required" zh:"项目Id"` + AccessToken string `json:"access_token" validate:"required" zh:"access_token"` + Keywords []string `json:"keywords" validate:"required" zh:"关键词"` + PlatformIndex []string `json:"platform_index" validate:"required" zh:"平台"` + Question string `json:"question" validate:"required" zh:"问题"` + ProductId int32 `json:"product_id" validate:"required" zh:"项目Id"` + } + + CollectListRequest struct { + AccessToken string `json:"access_token" validate:"required" zh:"access_token"` + ProductId int32 `json:"product_id" validate:"required" zh:"项目Id"` + Page int `json:"page" zh:"页码"` + Limit int `json:"limit" zh:"每页数量"` + } + + // PageRequest 分页请求 + PageRequest struct { + AccessToken string `json:"access_token" validate:"required" zh:"access_token"` + Page int `json:"page" zh:"页码"` + Limit int `json:"limit" zh:"每页数量"` + } + + // ReqPageBo 分页参数 + ReqPageBo struct { + Page int `json:"page"` + Limit int `json:"limit"` } ) diff --git a/internal/server/router/app.go b/internal/server/router/app.go index 7f44479..2b997ac 100644 --- a/internal/server/router/app.go +++ b/internal/server/router/app.go @@ -15,6 +15,7 @@ type AppModule struct { publishService *service.PublishService productService *service.ProductService productSourceService *service.ProductSourceService + collectService *service.CollectService } func NewAppModule( @@ -24,6 +25,7 @@ func NewAppModule( publishService *service.PublishService, productService *service.ProductService, productSourceService *service.ProductSourceService, + collectService *service.CollectService, ) *AppModule { return &AppModule{ cfg: cfg, @@ -32,6 +34,7 @@ func NewAppModule( publishService: publishService, productService: productService, productSourceService: productSourceService, + collectService: collectService, } } @@ -59,7 +62,6 @@ func (m *AppModule) Register(router fiber.Router) { router.Post("/product/detail", vali(m.productService.Detail, &entitys.ProductDetailRequest{})) router.Post("/product/update", vali(m.productService.Update, &entitys.ProductUpdateRequest{})) router.Post("/product/del", vali(m.productService.Del, &entitys.ProductDelRequest{})) - router.Post("/product/collect", vali(m.productService.Collect, &entitys.ProductCollectRequest{})) router.Post("/img/upload", m.productService.ImgUpload) router.Post("/plat/list", vali(m.appService.PlatList, &entitys.PlatListRequest{})) @@ -70,4 +72,8 @@ func (m *AppModule) Register(router fiber.Router) { router.Post("/product/source/update", vali(m.productSourceService.Update, &entitys.ProductSourceUpdateRequest{})) router.Post("/product/source/del", vali(m.productSourceService.Del, &entitys.ProductSourceDelRequest{})) router.Post("/product/source/publish", vali(m.productSourceService.Publish, &entitys.ProductPublishRequest{})) + + router.Post("/collect/create", vali(m.collectService.Collect, &entitys.ProductCollectRequest{})) + router.Get("/collect/platforms", m.collectService.GetCollectPlatForms) + router.Post("/collect/list", vali(m.collectService.CollectList, &entitys.CollectListRequest{})) } diff --git a/internal/service/collect.go b/internal/service/collect.go new file mode 100644 index 0000000..88266e8 --- /dev/null +++ b/internal/service/collect.go @@ -0,0 +1,204 @@ +package service + +import ( + "context" + "fmt" + "geo/internal/biz" + "geo/internal/collect" + "geo/internal/config" + "geo/internal/data/impl" + "geo/internal/data/model" + "geo/internal/entitys" + "geo/tmpl/dataTemp" + "github.com/gofiber/fiber/v2" + "log" + "strings" + "sync" + "time" + + "xorm.io/builder" +) + +// CollectService 收集服务 +type CollectService struct { + cfg *config.Config + collectBiz *biz.CollectBiz + collect *impl.CollectImpl + collectTask *impl.CollectTaskImpl + authBiz *biz.AuthBiz +} + +// NewCollectService 创建收集服务 +func NewCollectService( + cfg *config.Config, + collectBiz *biz.CollectBiz, + collect *impl.CollectImpl, + collectTask *impl.CollectTaskImpl, + authBiz *biz.AuthBiz, +) *CollectService { + return &CollectService{ + cfg: cfg, + collectBiz: collectBiz, + collect: collect, + collectTask: collectTask, + authBiz: authBiz, + } +} + +// CollectList 获取收集列表及对应的任务详情 +func (c *CollectService) CollectList(ctx *fiber.Ctx, req *entitys.CollectListRequest) error { + _, err := c.authBiz.ValidateAccessToken(ctx.UserContext(), req.AccessToken) + if err != nil { + return err + } + + // 构建查询条件 + cond := builder.NewCond() + if req.ProductId > 0 { + cond = cond.And(builder.Eq{"product_id": req.ProductId}) + } + + // 查询 collect 列表 + var collects []model.Collect + pageBo, err := c.collect.GetListToStruct(ctx.UserContext(), &cond, &dataTemp.ReqPageBo{Page: req.Page, Limit: req.Limit}, &collects, "created_at DESC") + if err != nil { + return err + } + + // 提取所有的 collect_code + collectCodes := make([]string, 0, len(collects)) + for _, collect := range collects { + collectCodes = append(collectCodes, collect.CollectCode) + } + + // 批量查询所有相关的 collect_task + var tasks []model.CollectTask + if len(collectCodes) > 0 { + taskCond := builder.NewCond().And(builder.In("collect_code", collectCodes)) + _, err := c.collectTask.GetListToStruct(ctx.UserContext(), &taskCond, nil, &tasks, "created_at ASC") + if err != nil { + log.Printf("查询 collect_task 失败: %v", err) + } + } + + // 按 collect_code 分组 tasks + tasksMap := make(map[string][]model.CollectTask) + for _, task := range tasks { + tasksMap[task.CollectCode] = append(tasksMap[task.CollectCode], task) + } + + // 组装返回数据 + result := make([]*CollectWithTasks, 0, len(collects)) + for _, collect := range collects { + cwt := &CollectWithTasks{ + Collect: collect, + Tasks: tasksMap[collect.CollectCode], + } + result = append(result, cwt) + } + + return ctx.JSON(fiber.Map{ + "list": result, + "total": pageBo.Total, + "page": req.Page, + "pageSize": req.Limit, + }) +} + +// CollectWithTasks 收集记录及其任务列表 +type CollectWithTasks struct { + model.Collect + Tasks []model.CollectTask `json:"tasks"` +} + +func (c *CollectService) GetCollectPlatForms(ctx *fiber.Ctx) error { + var list = make([]map[string]string, 0, len(collect.CollectorMap)) + for _, v := range collect.CollectorMap { + list = append(list, map[string]string{ + "name": v.Name, + "index": v.Platform, + "icon": v.Icon, + }) + } + return ctx.JSON(list) +} + +// Collect 创建收集任务 +func (c *CollectService) Collect(ctx *fiber.Ctx, req *entitys.ProductCollectRequest) error { + _, err := c.authBiz.ValidateAccessToken(ctx.UserContext(), req.AccessToken) + if err != nil { + return err + } + + collectCode := fmt.Sprintf("C%d_%d", req.ProductId, time.Now().UnixNano()) + collectData := &model.Collect{ + CollectCode: collectCode, + ProductID: req.ProductId, + Keywords: strings.Join(req.Keywords, ","), + Platform: strings.Join(req.PlatformIndex, ","), + Question: req.Question, + CreatedAt: time.Now(), + } + + err = c.collect.Add(ctx.UserContext(), collectData) + if err != nil { + return err + } + + go c.doCollectAsync(collectCode, req.PlatformIndex, req.Question) + + return ctx.JSON(fiber.Map{"message": "收录生成中"}) +} + +// doCollectAsync 异步执行收集任务 +func (c *CollectService) doCollectAsync(collectCode string, platforms []string, question string) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*240) + defer cancel() + + var wg sync.WaitGroup + var mu sync.Mutex + tasks := make([]*model.CollectTask, 0, len(platforms)) + + for _, platIndex := range platforms { + wg.Add(1) + + go func(platIndex string) { + defer wg.Done() + + platformName, exist := collect.CollectorMap[platIndex] + if !exist { + log.Printf("未知的平台索引: %d", platIndex) + return + } + + requestID := fmt.Sprintf("%s_%s", collectCode, platIndex) + result, err := c.collectBiz.AskAIQuestion(platIndex, requestID, question, true) + if err != nil { + log.Printf("平台 %s 收集失败: %v", platformName, err) + return + } + + task := &model.CollectTask{ + CollectCode: collectCode, + AiPlatformIndex: platIndex, + ContentHTML: result.Answer, + ShareURL: result.ShareLink, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + Status: 1, + } + + mu.Lock() + tasks = append(tasks, task) + mu.Unlock() + }(platIndex) + } + + wg.Wait() + + if len(tasks) > 0 { + if err := c.collectTask.Add(ctx, tasks); err != nil { + log.Printf("保存收集任务失败: %v", err) + } + } +} diff --git a/internal/service/product.go b/internal/service/product.go index 9e8ab19..678ca28 100644 --- a/internal/service/product.go +++ b/internal/service/product.go @@ -1,8 +1,6 @@ package service import ( - "context" - "fmt" "geo/internal/ai_tool" "geo/internal/biz" "geo/internal/config" @@ -13,14 +11,8 @@ import ( "geo/tmpl/dataTemp" "geo/tmpl/errcode" "io" - "log" "os" "path/filepath" - "runtime/debug" - "strconv" - "strings" - "sync" - "time" "github.com/go-viper/mapstructure/v2" "github.com/gofiber/fiber/v2" @@ -33,8 +25,6 @@ type ProductService struct { authBiz *biz.AuthBiz productBiz *biz.ProductBiz aiBiz *biz.AiBiz - collect *impl.CollectImpl - collectTask *impl.CollectTaskImpl } func NewProductService( @@ -43,8 +33,6 @@ func NewProductService( authBiz *biz.AuthBiz, productBiz *biz.ProductBiz, aiBiz *biz.AiBiz, - collect *impl.CollectImpl, - collectTask *impl.CollectTaskImpl, ) *ProductService { return &ProductService{ cfg: cfg, @@ -52,8 +40,6 @@ func NewProductService( authBiz: authBiz, productBiz: productBiz, aiBiz: aiBiz, - collect: collect, - collectTask: collectTask, } } @@ -243,357 +229,3 @@ func (p *ProductService) CreateProductInfoByDocx(c *fiber.Ctx) error { return pkg.HandleResponse(c, productInfo) } - -type CollectInfo struct { - AIPlatformIndex string `json:"ai_platform_index"` - ContentHtml string `json:"content_html"` - ShareUrl string `json:"share_url"` - Source []Source `json:"source"` -} - -type Source struct { - Title string `json:"name"` - Url string `json:"url"` - PlatformName string `json:"platform"` - PlatformIcon string `json:"Platform_icon"` -} - -func (p *ProductService) Collect(c *fiber.Ctx, req *entitys.ProductCollectRequest) error { - log.Printf("[DEBUG] ========== 请求开始 ==========") - log.Printf("[DEBUG] 请求时间: %s", time.Now().Format("2006-01-02 15:04:05.000")) - log.Printf("[Collect] 开始处理收集请求, ProductID: %d, Platforms: %v, Keywords: %v", - req.ProductId, req.Platform, req.Keywords) - - _, err := p.authBiz.ValidateAccessToken(c.UserContext(), req.AccessToken) - if err != nil { - log.Printf("[Collect] 验证token失败, ProductID: %d, Error: %v", req.ProductId, err) - return err - } - - _, err = p.productBiz.GetProduct(c.UserContext(), req.ProductId) - if err != nil { - log.Printf("[Collect] 获取产品信息失败, ProductID: %d, Error: %v", req.ProductId, err) - return err - } - - platformStr := make([]string, len(req.Platform)) - for i, s := range req.Platform { - platformStr[i] = strconv.Itoa(s) - } - - collectCode := fmt.Sprintf("C%d_%d", req.ProductId, time.Now().UnixNano()) - collectData := &model.Collect{ - CollectCode: collectCode, - ProductID: req.ProductId, - Keywords: strings.Join(req.Keywords, ","), - Platform: strings.Join(platformStr, ","), - Question: req.Question, - CreatedAt: time.Now(), - } - - log.Printf("[Collect] 创建收集记录, CollectCode: %s, ProductID: %d", collectCode, req.ProductId) - - err = p.collect.Add(c.UserContext(), collectData) - if err != nil { - log.Printf("[Collect] 保存收集记录失败, CollectCode: %s, Error: %v", collectCode, err) - return err - } - - log.Printf("[Collect] ✅ 启动异步收集任务, CollectCode: %s, Platforms: %v", collectCode, req.Platform) - - go func() { - // 记录 goroutine 启动时间 - startTime := time.Now() - log.Printf("[Goroutine] 异步任务启动, CollectCode: %s, 启动时间: %s", collectCode, startTime.Format("15:04:05.000")) - - // 使用独立 context,避免请求结束后任务被取消 - ctx, cancel := context.WithTimeout(context.Background(), time.Second*240) - - // 监控 context 取消 - go func() { - <-ctx.Done() - log.Printf("[Goroutine] ❌ Context被取消! CollectCode: %s, 原因: %v, 耗时: %v", - collectCode, ctx.Err(), time.Since(startTime)) - }() - - defer func() { - if r := recover(); r != nil { - log.Printf("[Goroutine] ❌ PANIC: %v\nStack: %s", r, debug.Stack()) - } - log.Printf("[Goroutine] 异步任务结束, CollectCode: %s, 总耗时: %v", collectCode, time.Since(startTime)) - cancel() - log.Printf("[Goroutine] 已调用 cancel(), CollectCode: %s", collectCode) - }() - - log.Printf("[Goroutine] 准备调用 doCollect, CollectCode: %s", collectCode) - p.doCollect(ctx, collectData, req.Platform) - log.Printf("[Goroutine] doCollect 已返回, CollectCode: %s", collectCode) - }() - - log.Printf("[DEBUG] ========== 请求返回 ==========") - return pkg.HandleResponse(c, "收录生成中") -} - -func (p *ProductService) doCollect(ctx context.Context, collectData *model.Collect, platforms []int) { - collectCode := collectData.CollectCode - startTime := time.Now() - - log.Printf("[doCollect] ========== 开始执行 ==========") - log.Printf("[doCollect] CollectCode: %s, Platforms: %v", collectCode, platforms) - log.Printf("[doCollect] Context状态: %v, 超时时间: %v", ctx.Err(), time.Second*240) - - // 监控 context - go func() { - <-ctx.Done() - log.Printf("[doCollect] ⚠️ 检测到Context取消! CollectCode: %s, 原因: %v, 已执行时间: %v", - collectCode, ctx.Err(), time.Since(startTime)) - }() - - collectClient := ai_tool.NewCollect(p.cfg.Collect.ApiKey) - log.Printf("[doCollect] 已创建 collectClient") - - var wg sync.WaitGroup - resCh := make(chan *model.CollectTask, len(platforms)) - log.Printf("[doCollect] 创建 channel, 容量: %d", len(platforms)) - - // 启动监控 goroutine - monitorStart := time.Now() - - // 启动所有平台的任务 - log.Printf("[doCollect] 启动 %d 个平台任务", len(platforms)) - for i, plat := range platforms { - log.Printf("[doCollect] 启动任务 #%d, Platform: %d", i+1, plat) - wg.Add(1) - go p.processPlatform(ctx, &wg, collectClient, collectData, plat, resCh, i+1) - } - - go func() { - log.Printf("[Monitor] 监控goroutine启动, CollectCode: %s", collectCode) - wg.Wait() - log.Printf("[Monitor] ✅ 所有任务完成, 准备关闭channel, 等待时间: %v", time.Since(monitorStart)) - close(resCh) - log.Printf("[Monitor] Channel已关闭") - }() - - // 收集结果 - 添加超时保护 - log.Printf("[doCollect] 开始等待结果...") - var datas []*model.CollectTask - taskCount := 0 - - // 设置一个最大等待时间 - waitTimeout := time.After(250 * time.Second) - - for { - select { - case task, ok := <-resCh: - if !ok { - log.Printf("[doCollect] Channel已关闭, 收集到 %d 条结果", len(datas)) - goto SAVE - } - datas = append(datas, task) - taskCount++ - log.Printf("[doCollect] ✅ 收到结果 #%d, Platform: %d, RequestID: %s, ScriptTime: %d", - taskCount, task.Platform, task.RequestID, task.ScriptTime) - - case <-waitTimeout: - log.Printf("[doCollect] ⚠️ 等待超时 250秒, 强制退出, 已收集: %d/%d", taskCount, len(platforms)) - goto SAVE - - case <-ctx.Done(): - log.Printf("[doCollect] ❌ Context取消, 强制退出, 已收集: %d/%d, 原因: %v", - taskCount, len(platforms), ctx.Err()) - goto SAVE - } - } - -SAVE: - log.Printf("[doCollect] 收集完成, 共 %d 条结果", len(datas)) - - // 保存结果 - if len(datas) > 0 { - log.Printf("[doCollect] 开始保存到数据库, 数量: %d", len(datas)) - saveStart := time.Now() - if err := p.collectTask.Add(ctx, datas); err != nil { - log.Printf("[doCollect] ❌ 保存失败: %v", err) - } else { - log.Printf("[doCollect] ✅ 保存成功, 耗时: %v", time.Since(saveStart)) - } - } else { - log.Printf("[doCollect] ⚠️ 没有结果需要保存") - } - - elapsed := time.Since(startTime) - log.Printf("[doCollect] ========== 结束执行, 总耗时: %v ==========", elapsed) -} - -func (p *ProductService) processPlatform(ctx context.Context, wg *sync.WaitGroup, - collectClient *ai_tool.Collect, collectData *model.Collect, plat int, - resCh chan<- *model.CollectTask, taskNum int) { - - collectCode := collectData.CollectCode - startTime := time.Now() - - log.Printf("[Platform #%d] ========== 开始 ==========", taskNum) - log.Printf("[Platform #%d] CollectCode: %s, Platform: %d", taskNum, collectCode, plat) - - // 确保 wg.Done() 一定会被调用 - defer func() { - log.Printf("[Platform #%d] 准备调用 wg.Done(), 已执行时间: %v", taskNum, time.Since(startTime)) - wg.Done() - log.Printf("[Platform #%d] 已调用 wg.Done()", taskNum) - }() - - defer func() { - if r := recover(); r != nil { - log.Printf("[Platform #%d] ❌ PANIC: %v\nStack: %s", taskNum, r, debug.Stack()) - } - log.Printf("[Platform #%d] ========== 结束, 耗时: %v ==========", taskNum, time.Since(startTime)) - }() - - // 检查 context 是否已取消 - select { - case <-ctx.Done(): - log.Printf("[Platform #%d] ❌ Context已取消, 退出执行, 原因: %v", taskNum, ctx.Err()) - return - default: - log.Printf("[Platform #%d] Context正常", taskNum) - } - - // 创建任务 - request := ai_tool.CreateReq{ - Keywords: collectData.Keywords, - Question: collectData.Question, - Platform: plat, - ThirdID: fmt.Sprintf("%s_%d", collectData.CollectCode, plat), - } - - log.Printf("[Platform #%d] 调用 Create API, Request: %+v", taskNum, request) - - createStart := time.Now() - res, err := collectClient.Create(ctx, &request) - createElapsed := time.Since(createStart) - - if err != nil { - log.Printf("[Platform #%d] ❌ Create失败, 耗时: %v, Error: %v", taskNum, createElapsed, err) - return - } - if res.Code != 1 { - log.Printf("[Platform #%d] ❌ Create返回错误码, 耗时: %v, Code: %d, Message: %s", - taskNum, createElapsed, res.Code, res.Msg) - return - } - - log.Printf("[Platform #%d] ✅ Create成功, 耗时: %v, RequestID: %s", - taskNum, createElapsed, res.Data.RequestId) - - // 轮询任务状态 - log.Printf("[Platform #%d] 开始轮询, RequestID: %s", taskNum, res.Data.RequestId) - - pollStart := time.Now() - task := p.pollTaskStatus(ctx, collectClient, res.Data.RequestId, collectData, plat, taskNum) - pollElapsed := time.Since(pollStart) - - if task != nil { - log.Printf("[Platform #%d] ✅ 轮询成功, 耗时: %v, ScriptTime: %d", - taskNum, pollElapsed, task.ScriptTime) - - // 发送结果到 channel - log.Printf("[Platform #%d] 准备发送结果到channel", taskNum) - select { - case resCh <- task: - log.Printf("[Platform #%d] ✅ 结果已发送到channel", taskNum) - case <-ctx.Done(): - log.Printf("[Platform #%d] ⚠️ Context取消, 放弃发送结果", taskNum) - return - } - } else { - log.Printf("[Platform #%d] ❌ 轮询失败, 耗时: %v, 未获取到结果", taskNum, pollElapsed) - } -} - -func (p *ProductService) pollTaskStatus(ctx context.Context, collectClient *ai_tool.Collect, - requestID string, collectData *model.Collect, plat int, taskNum int) *model.CollectTask { - - collectCode := collectData.CollectCode - startTime := time.Now() - - log.Printf("[Poll #%d] ========== 开始轮询 ==========", taskNum) - log.Printf("[Poll #%d] CollectCode: %s, Platform: %d, RequestID: %s", - taskNum, collectCode, plat, requestID) - - ticker := time.NewTicker(5 * time.Second) - defer ticker.Stop() - - errCount := 0 - const maxErrors = 5 - pollCount := 0 - - for { - select { - case <-ctx.Done(): - log.Printf("[Poll #%d] ❌ Context取消, 停止轮询, 已轮询%d次, 耗时: %v, 原因: %v", - taskNum, pollCount, time.Since(startTime), ctx.Err()) - return nil - - case <-ticker.C: - pollCount++ - log.Printf("[Poll #%d] 第 %d 次轮询, 已耗时: %v", taskNum, pollCount, time.Since(startTime)) - - checkStart := time.Now() - checkRes, err := collectClient.CheckTask(ctx, requestID) - checkElapsed := time.Since(checkStart) - - if err != nil { - errCount++ - log.Printf("[Poll #%d] ❌ 轮询失败(第%d次错误), 耗时: %v, Error: %v, 累计错误: %d/%d", - taskNum, pollCount, checkElapsed, err, errCount, maxErrors) - if errCount >= maxErrors { - log.Printf("[Poll #%d] 达到最大错误次数, 停止轮询", taskNum) - return nil - } - continue - } - - log.Printf("[Poll #%d] ✅ 轮询成功, 耗时: %v, Code: %d, Status: %d, ScriptTime: %d, ShouluDate: %s", - taskNum, checkElapsed, checkRes.Code, checkRes.Data.Status, - checkRes.Data.ScriptTime, checkRes.Data.ShouluDate) - - if checkRes.Code != 1 { - log.Printf("[Poll #%d] ❌ 返回错误码: %d", taskNum, checkRes.Code) - return nil - } - // 判断任务是否完成 - // 根据你的业务逻辑调整判断条件 - isCompleted := false - completeReason := "" - - if checkRes.Data.Status != 0 { // 假设 2 表示完成 - isCompleted = true - completeReason = fmt.Sprintf("chekcStatus=%d", checkRes.Data.Status) - } - - if isCompleted { - log.Printf("[Poll #%d] 🎉 任务完成! 原因: %s, 总轮询次数: %d, 总耗时: %v", - taskNum, completeReason, pollCount, time.Since(startTime)) - - return &model.CollectTask{ - RequestID: checkRes.Data.RequestId, - CollectCode: collectData.CollectCode, - ScriptTime: int32(checkRes.Data.ScriptTime), - Platform: int32(checkRes.Data.Platform), - CollectData: checkRes.Data.ShouluDate, - ShareURL: checkRes.Data.ShareUrl, - ImgURL: checkRes.Data.ImgUrl, - PointKeyword: checkRes.Data.HitWord, - Question: checkRes.Data.Question, - Res: pkg.JsonStringIgonErr(checkRes), - CreatedAt: time.Now(), - Status: int32(checkRes.Data.Status), - } - } - - log.Printf("[Poll #%d] 任务未完成, 继续轮询, Status=%d, ScriptTime=%d, ShouluDate=%s", - taskNum, checkRes.Data.Status, checkRes.Data.ScriptTime, checkRes.Data.ShouluDate) - } - } -} diff --git a/internal/service/provider_set.go b/internal/service/provider_set.go index af0e0c4..ebda5ff 100644 --- a/internal/service/provider_set.go +++ b/internal/service/provider_set.go @@ -10,4 +10,5 @@ var ProviderSetAppService = wire.NewSet( NewLoginService, NewProductService, NewProductSourceService, + NewCollectService, )