3232
This commit is contained in:
parent
5d75f68ab9
commit
1f79fa82ee
|
|
@ -7,6 +7,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"geo/internal/biz"
|
||||
"geo/internal/config"
|
||||
"geo/internal/data/impl"
|
||||
|
|
@ -45,9 +46,11 @@ func InitializeApp(configConfig *config.Config, allLogger log.AllLogger) (*serve
|
|||
publishService := service.NewPublishService(configConfig, publishBiz, authBiz, db, aiBiz)
|
||||
collectImpl := impl.NewCollectImpl(db)
|
||||
collectTaskImpl := impl.NewCollectTaskImpl(db)
|
||||
productService := service.NewProductService(configConfig, productImpl, authBiz, productBiz, aiBiz, collectImpl, collectTaskImpl)
|
||||
collectBiz := biz.NewCollectBiz(context.Background(), configConfig, allLogger)
|
||||
productService := service.NewProductService(configConfig, productImpl, authBiz, productBiz, aiBiz)
|
||||
collectService := service.NewCollectService(configConfig, collectBiz, collectImpl, collectTaskImpl, authBiz)
|
||||
productSourceService := service.NewProductSourceService(configConfig, productImpl, authBiz, aiBiz, productBiz, productSourceImpl, publishBiz, articleTypeImpl)
|
||||
appModule := router.NewAppModule(configConfig, appService, loginService, publishService, productService, productSourceService)
|
||||
appModule := router.NewAppModule(configConfig, appService, loginService, publishService, productService, productSourceService, collectService)
|
||||
routerServer := router.NewRouterServer(appModule)
|
||||
app := server.NewHTTPServer(routerServer)
|
||||
servers := server.NewServers(configConfig, app)
|
||||
|
|
|
|||
|
|
@ -4,15 +4,14 @@ import (
|
|||
"context"
|
||||
"geo/internal/collect"
|
||||
"geo/internal/config"
|
||||
"log"
|
||||
"os"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
cfg, _ = config.LoadConfig()
|
||||
logger = log.New(os.Stdout, "", log.LstdFlags)
|
||||
manager = collect.NewCollectManager(context.Background(), cfg, logger)
|
||||
cfg, _ = config.LoadConfig()
|
||||
|
||||
manager = collect.NewCollectManager(context.Background(), cfg, log.DefaultLogger())
|
||||
)
|
||||
|
||||
// TestCollectManager_Basic 测试收集管理器的基本功能
|
||||
|
|
@ -29,8 +28,6 @@ func TestCollectManager_Basic(t *testing.T) {
|
|||
for _, platform := range platforms {
|
||||
params := &collect.CollectParams{
|
||||
Headless: true,
|
||||
UserIndex: "test_user",
|
||||
PlatIndex: platform,
|
||||
RequestID: "test_req",
|
||||
Platform: platform,
|
||||
}
|
||||
|
|
@ -57,8 +54,6 @@ func TestWenxinCollector_WaitLogin(t *testing.T) {
|
|||
|
||||
params := &collect.CollectParams{
|
||||
Headless: false, // 显示浏览器窗口以便扫码登录
|
||||
UserIndex: "test_user",
|
||||
PlatIndex: "wenxin",
|
||||
RequestID: "test_wenxin_login_001",
|
||||
Platform: "wenxin",
|
||||
}
|
||||
|
|
@ -87,8 +82,6 @@ func TestWenxinCollector_AskQuestion(t *testing.T) {
|
|||
// 设置收集参数
|
||||
params := &collect.CollectParams{
|
||||
Headless: false, // 显示浏览器以便调试
|
||||
UserIndex: "test_user",
|
||||
PlatIndex: "wenxin",
|
||||
RequestID: "test_wenxin_001",
|
||||
Platform: "wenxin",
|
||||
}
|
||||
|
|
@ -98,16 +91,17 @@ func TestWenxinCollector_AskQuestion(t *testing.T) {
|
|||
t.Logf("向文心一言提问: %s", question)
|
||||
|
||||
// 调用管理器提问并获取答案
|
||||
answer, err := manager.AskQuestion("wenxin", params, question)
|
||||
result, err := manager.AskQuestion("wenxin", params, question)
|
||||
if err != nil {
|
||||
t.Errorf("提问失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("获取到答案:\n%s", answer)
|
||||
t.Logf("获取到答案:\n%s", result.Answer)
|
||||
t.Logf("分享链接: %s", result.ShareLink)
|
||||
|
||||
// 验证答案非空
|
||||
if len(answer) == 0 {
|
||||
if len(result.Answer) == 0 {
|
||||
t.Error("答案为空")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,45 +19,6 @@ type PlatForm struct {
|
|||
Icon string `json:"icon"`
|
||||
}
|
||||
|
||||
var PlatFormList = []PlatForm{
|
||||
{
|
||||
Index: 1,
|
||||
PlatFormName: "deepseek",
|
||||
Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/deepseek.png",
|
||||
},
|
||||
{
|
||||
Index: 2,
|
||||
PlatFormName: "豆包",
|
||||
Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/doubao.png",
|
||||
},
|
||||
{
|
||||
Index: 3,
|
||||
PlatFormName: "元宝",
|
||||
Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/yuanbao.png",
|
||||
},
|
||||
{
|
||||
Index: 41,
|
||||
PlatFormName: "千问",
|
||||
Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/qianwen.png",
|
||||
},
|
||||
{
|
||||
Index: 5,
|
||||
PlatFormName: "文心一言",
|
||||
Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/wenxin.png",
|
||||
},
|
||||
//{
|
||||
// Index: 6,
|
||||
// PlatFormName: "纳米",
|
||||
//},
|
||||
//{
|
||||
// Index: 7,
|
||||
// PlatFormName: "kimi",
|
||||
//}, {
|
||||
// Index: 8,
|
||||
// PlatFormName: "智普",
|
||||
//},
|
||||
}
|
||||
|
||||
type CreateReq struct {
|
||||
// 品牌词,多个用英文逗号隔开
|
||||
ApiKey string `json:"api_key"`
|
||||
|
|
|
|||
|
|
@ -5,18 +5,19 @@ import (
|
|||
"fmt"
|
||||
"geo/internal/collect"
|
||||
"geo/internal/config"
|
||||
"log"
|
||||
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
// CollectBiz AI收集业务层
|
||||
type CollectBiz struct {
|
||||
manager *collect.CollectManager
|
||||
config *config.Config
|
||||
logger *log.Logger
|
||||
logger log.AllLogger
|
||||
}
|
||||
|
||||
// NewCollectBiz 创建AI收集业务实例
|
||||
func NewCollectBiz(ctx context.Context, cfg *config.Config, logger *log.Logger) *CollectBiz {
|
||||
func NewCollectBiz(ctx context.Context, cfg *config.Config, logger log.AllLogger) *CollectBiz {
|
||||
manager := collect.NewCollectManager(ctx, cfg, logger)
|
||||
return &CollectBiz{
|
||||
manager: manager,
|
||||
|
|
@ -27,34 +28,28 @@ func NewCollectBiz(ctx context.Context, cfg *config.Config, logger *log.Logger)
|
|||
|
||||
// AskAIQuestion 向指定AI平台提问
|
||||
// platform: 平台类型 (wenxin, deepseek, doubao, qianwen)
|
||||
// userIndex: 用户索引
|
||||
// platIndex: 平台索引
|
||||
// requestID: 请求ID
|
||||
// question: 问题内容
|
||||
// headless: 是否无头模式
|
||||
func (b *CollectBiz) AskAIQuestion(platform string, userIndex, platIndex, requestID, question string, headless bool) (string, error) {
|
||||
func (b *CollectBiz) AskAIQuestion(platform string, requestID, question string, headless bool) (*collect.CollectResult, error) {
|
||||
params := &collect.CollectParams{
|
||||
Headless: headless,
|
||||
UserIndex: userIndex,
|
||||
PlatIndex: platIndex,
|
||||
RequestID: requestID,
|
||||
Platform: platform,
|
||||
}
|
||||
|
||||
answer, err := b.manager.AskQuestion(platform, params, question)
|
||||
result, err := b.manager.AskQuestion(platform, params, question)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("向%s提问失败: %w", platform, err)
|
||||
return nil, fmt.Errorf("向%s提问失败: %w", platform, err)
|
||||
}
|
||||
|
||||
return answer, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// WaitAILogin 等待AI平台登录
|
||||
func (b *CollectBiz) WaitAILogin(platform string, userIndex, platIndex, requestID string, headless bool) (bool, string) {
|
||||
func (b *CollectBiz) WaitAILogin(platform string, requestID string, headless bool) (bool, string) {
|
||||
params := &collect.CollectParams{
|
||||
Headless: headless,
|
||||
UserIndex: userIndex,
|
||||
PlatIndex: platIndex,
|
||||
RequestID: requestID,
|
||||
Platform: platform,
|
||||
}
|
||||
|
|
@ -68,17 +63,22 @@ func (b *CollectBiz) ListAIPlatforms() []string {
|
|||
}
|
||||
|
||||
// AskMultipleAI 向多个AI平台提问并收集答案
|
||||
func (b *CollectBiz) AskMultipleAI(platforms []string, userIndex, requestID, question string, headless bool) map[string]string {
|
||||
results := make(map[string]string)
|
||||
func (b *CollectBiz) AskMultipleAI(platforms []string, requestID, question string, headless bool) map[string]*collect.CollectResult {
|
||||
results := make(map[string]*collect.CollectResult)
|
||||
|
||||
for _, platform := range platforms {
|
||||
platIndex := platform // 默认使用platform作为platIndex
|
||||
answer, err := b.AskAIQuestion(platform, userIndex, platIndex, requestID+"_"+platform, question, headless)
|
||||
// 为每个平台生成唯一的 requestID
|
||||
platformRequestID := requestID + "_" + platform
|
||||
result, err := b.AskAIQuestion(platform, platformRequestID, question, headless)
|
||||
if err != nil {
|
||||
b.logger.Printf("向%s提问失败: %v", platform, err)
|
||||
results[platform] = fmt.Sprintf("错误: %v", err)
|
||||
b.logger.Errorf("向%s提问失败: %v", platform, err)
|
||||
// 创建一个包含错误信息的结果
|
||||
results[platform] = &collect.CollectResult{
|
||||
Answer: fmt.Sprintf("错误: %v", err),
|
||||
ShareLink: "",
|
||||
}
|
||||
} else {
|
||||
results[platform] = answer
|
||||
results[platform] = result
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,4 +9,5 @@ var ProviderSetBiz = wire.NewSet(
|
|||
NewAuthBiz,
|
||||
NewAiBiz,
|
||||
NewProductBiz,
|
||||
NewCollectBiz,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"geo/internal/config"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
|
@ -13,22 +12,20 @@ import (
|
|||
"github.com/go-rod/rod"
|
||||
"github.com/go-rod/rod/lib/launcher"
|
||||
"github.com/go-rod/rod/lib/proto"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
// BaseCollector 基础收集器结构
|
||||
type BaseCollector struct {
|
||||
ctx context.Context
|
||||
Headless bool
|
||||
UserIndex string
|
||||
PlatIndex string
|
||||
RequestID string
|
||||
Platform string
|
||||
|
||||
Browser *rod.Browser
|
||||
Page *rod.Page
|
||||
|
||||
Logger *log.Logger
|
||||
LogFile *os.File
|
||||
Logger log.AllLogger
|
||||
|
||||
LoginURL string
|
||||
ChatURL string
|
||||
|
|
@ -41,53 +38,34 @@ type BaseCollector struct {
|
|||
}
|
||||
|
||||
// NewBaseCollector 构造函数
|
||||
func NewBaseCollector(ctx context.Context, params *CollectParams, config *config.Config, logger *log.Logger) *BaseCollector {
|
||||
var baseLogger *log.Logger
|
||||
var logFile *os.File
|
||||
func NewBaseCollector(ctx context.Context, params *CollectParams, config *config.Config, logger log.AllLogger) *BaseCollector {
|
||||
var baseLogger log.AllLogger
|
||||
|
||||
if logger != nil {
|
||||
baseLogger = logger
|
||||
logFile = nil
|
||||
} else {
|
||||
logsDir := config.Sys.LogsDir
|
||||
if logsDir == "" {
|
||||
logsDir = "./logs"
|
||||
}
|
||||
os.MkdirAll(logsDir, 0755)
|
||||
logFile, _ = os.Create(filepath.Join(logsDir, fmt.Sprintf("collect_%s_%s.log", params.RequestID, params.Platform)))
|
||||
baseLogger = log.New(logFile, "", log.LstdFlags)
|
||||
baseLogger = log.DefaultLogger()
|
||||
}
|
||||
|
||||
base := &BaseCollector{
|
||||
ctx: ctx,
|
||||
Headless: params.Headless,
|
||||
UserIndex: params.UserIndex,
|
||||
PlatIndex: params.PlatIndex,
|
||||
RequestID: params.RequestID,
|
||||
Platform: params.Platform,
|
||||
Logger: baseLogger,
|
||||
LogFile: logFile,
|
||||
config: config,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 200,
|
||||
}
|
||||
|
||||
base.CookiesFile = filepath.Join(base.cookiesDir(), params.PlatIndex+".json")
|
||||
// Cookie文件按平台区分,而不是按用户索引
|
||||
base.CookiesFile = filepath.Join(base.cookiesDir(), params.Platform+".json")
|
||||
return base
|
||||
}
|
||||
|
||||
// cookiesDir 获取cookie目录
|
||||
func (b *BaseCollector) cookiesDir() string {
|
||||
dir := filepath.Join(b.config.Sys.CookiesDir, b.UserIndex)
|
||||
os.MkdirAll(dir, 0755)
|
||||
return dir
|
||||
}
|
||||
|
||||
// SetupDriver 初始化浏览器驱动
|
||||
func (b *BaseCollector) SetupDriver() error {
|
||||
b.LogInfo("初始化浏览器...")
|
||||
|
||||
userDataDir := filepath.Join(b.config.Sys.ChromeDataDir, b.UserIndex, b.RequestID+fmt.Sprintf("___%d", time.Now().UnixNano()))
|
||||
userDataDir := filepath.Join(b.config.Sys.ChromeDataDir, b.Platform, b.RequestID+fmt.Sprintf("___%d", time.Now().UnixNano()))
|
||||
os.MkdirAll(userDataDir, 0755)
|
||||
|
||||
l := launcher.New().
|
||||
|
|
@ -108,15 +86,9 @@ func (b *BaseCollector) SetupDriver() error {
|
|||
l.Delete("headless")
|
||||
}
|
||||
|
||||
l.UserDataDir(userDataDir)
|
||||
l.Set("window-size", "1920,1080")
|
||||
|
||||
// 设置中文语言环境
|
||||
l.Set("lang", "zh-CN")
|
||||
l.Set("accept-lang", "zh-CN,zh;q=0.9,en;q=0.8")
|
||||
l.Set("force-device-scale-factor", "1")
|
||||
|
||||
// 设置时区为中国
|
||||
l.Set("timezone", "Asia/Shanghai")
|
||||
|
||||
url, err := l.Launch()
|
||||
|
|
@ -125,14 +97,12 @@ func (b *BaseCollector) SetupDriver() error {
|
|||
}
|
||||
|
||||
b.Browser = rod.New().Context(b.ctx).ControlURL(url).MustConnect()
|
||||
|
||||
// 创建新页面
|
||||
b.Page = b.Browser.MustPage()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭浏览器和日志文件
|
||||
// Close 关闭浏览器
|
||||
func (b *BaseCollector) Close() {
|
||||
if b.Page != nil {
|
||||
b.Page.Close()
|
||||
|
|
@ -140,9 +110,6 @@ func (b *BaseCollector) Close() {
|
|||
if b.Browser != nil {
|
||||
b.Browser.Close()
|
||||
}
|
||||
if b.LogFile != nil {
|
||||
b.LogFile.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// SaveCookies 保存cookies
|
||||
|
|
@ -213,12 +180,12 @@ func (b *BaseCollector) WaitForElementClickable(selector string, timeout int) (*
|
|||
// JSClick JavaScript点击元素
|
||||
func (b *BaseCollector) JSClick(element *rod.Element) error {
|
||||
if element == nil {
|
||||
b.Logger.Printf("element is nil")
|
||||
b.Logger.Warn("element is nil")
|
||||
return fmt.Errorf("element is nil")
|
||||
}
|
||||
err := element.Click(proto.InputMouseButtonLeft, 1)
|
||||
if err != nil {
|
||||
b.Logger.Printf("click fail: " + err.Error())
|
||||
b.Logger.Errorf("click fail: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
@ -252,25 +219,25 @@ func (b *BaseCollector) SleepMs(milliseconds int) {
|
|||
|
||||
// LogInfo 记录信息日志
|
||||
func (b *BaseCollector) LogInfo(message string) {
|
||||
b.Logger.Printf("📌 %s", message)
|
||||
b.Logger.Infof("📌 %s", message)
|
||||
}
|
||||
|
||||
// LogInfof 格式化记录信息日志
|
||||
func (b *BaseCollector) LogInfof(format string, args ...interface{}) {
|
||||
b.Logger.Printf("📌 "+format, args...)
|
||||
b.Logger.Infof("📌 "+format, args...)
|
||||
}
|
||||
|
||||
// LogError 记录错误日志
|
||||
func (b *BaseCollector) LogError(message string) {
|
||||
b.Logger.Printf("❌ %s", message)
|
||||
b.Logger.Errorf("❌ %s", message)
|
||||
}
|
||||
|
||||
// LogStep 记录步骤日志
|
||||
func (b *BaseCollector) LogStep(stepName string, success bool, message string) {
|
||||
if success {
|
||||
b.Logger.Printf("✅ %s: 成功 %s", stepName, message)
|
||||
b.Logger.Infof("✅ %s: 成功 %s", stepName, message)
|
||||
} else {
|
||||
b.Logger.Printf("❌ %s: 失败 %s", stepName, message)
|
||||
b.Logger.Errorf("❌ %s: 失败 %s", stepName, message)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -300,8 +267,8 @@ func (b *BaseCollector) WaitLogin() (bool, string) {
|
|||
}
|
||||
|
||||
// AskQuestion 提问并获取答案(需要子类实现)
|
||||
func (b *BaseCollector) AskQuestion(question string) (string, error) {
|
||||
return "", fmt.Errorf("需要实现")
|
||||
func (b *BaseCollector) AskQuestion(question string) (*CollectResult, error) {
|
||||
return nil, fmt.Errorf("需要实现")
|
||||
}
|
||||
|
||||
// InitPage 初始化页面
|
||||
|
|
@ -333,3 +300,10 @@ func (b *BaseCollector) SafeElement(selector string) (*rod.Element, error) {
|
|||
}
|
||||
return b.Page.Element(selector)
|
||||
}
|
||||
|
||||
// cookiesDir 获取cookie目录 - 按平台区分
|
||||
func (b *BaseCollector) cookiesDir() string {
|
||||
dir := filepath.Join(b.config.Sys.CookiesDir, b.Platform)
|
||||
os.MkdirAll(dir, 0755)
|
||||
return dir
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"geo/internal/config"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-rod/rod"
|
||||
"github.com/go-rod/rod/lib/proto"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
// DeepseekCollector DeepSeek收集器
|
||||
|
|
@ -18,7 +18,7 @@ type DeepseekCollector struct {
|
|||
}
|
||||
|
||||
// NewDeepseekCollector 创建DeepSeek收集器
|
||||
func NewDeepseekCollector(ctx context.Context, params *CollectParams, cfg *config.Config, logger *log.Logger) CollectorInterface {
|
||||
func NewDeepseekCollector(ctx context.Context, params *CollectParams, cfg *config.Config, logger log.AllLogger) CollectorInterface {
|
||||
collector := &DeepseekCollector{
|
||||
BaseCollector: NewBaseCollector(ctx, params, cfg, logger),
|
||||
}
|
||||
|
|
@ -54,81 +54,65 @@ func (c *DeepseekCollector) CheckLoginStatus() bool {
|
|||
|
||||
// WaitLogin 等待登录
|
||||
func (c *DeepseekCollector) WaitLogin() (bool, string) {
|
||||
c.LogInfo("开始等待DeepSeek登录...")
|
||||
|
||||
if err := c.SetupDriver(); err != nil {
|
||||
return false, fmt.Sprintf("浏览器启动失败: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// 访问聊天页面
|
||||
c.Page.MustNavigate(c.ChatURL)
|
||||
c.Sleep(3)
|
||||
|
||||
// 检查是否已登录
|
||||
if c.CheckLoginStatus() {
|
||||
c.SaveCookies()
|
||||
c.LogInfo("已有登录状态")
|
||||
return true, "already_logged_in"
|
||||
}
|
||||
|
||||
c.LogInfo("未检测到登录状态,请登录账号...")
|
||||
|
||||
// 等待用户手动登录,最多300秒
|
||||
for i := 0; i < 300; i++ {
|
||||
if c.CheckLoginStatus() {
|
||||
c.Sleep(2)
|
||||
c.SaveCookies()
|
||||
c.LogInfo("登录成功")
|
||||
return true, "login_success"
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
return false, "登录超时,请检查网络或账号状态"
|
||||
return false, "登录超时"
|
||||
}
|
||||
|
||||
// AskQuestion 提问并获取答案
|
||||
func (c *DeepseekCollector) AskQuestion(question string) (string, error) {
|
||||
c.LogInfo(fmt.Sprintf("开始向DeepSeek提问: %s", question))
|
||||
|
||||
// 初始化浏览器
|
||||
func (c *DeepseekCollector) AskQuestion(question string) (*CollectResult, error) {
|
||||
if err := c.SetupDriver(); err != nil {
|
||||
return "", fmt.Errorf("浏览器启动失败: %v", err)
|
||||
return nil, fmt.Errorf("浏览器启动失败: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// 初始化页面
|
||||
if err := c.InitPage(); err != nil {
|
||||
return "", fmt.Errorf("页面初始化失败,请先调用WaitLogin登录: %v", err)
|
||||
return nil, fmt.Errorf("页面初始化失败: %v", err)
|
||||
}
|
||||
|
||||
c.Sleep(3)
|
||||
|
||||
// 输入问题
|
||||
if err := c.inputQuestion(question); err != nil {
|
||||
return "", fmt.Errorf("输入问题失败: %v", err)
|
||||
return nil, fmt.Errorf("输入问题失败: %v", err)
|
||||
}
|
||||
|
||||
// 点击发送
|
||||
if err := c.clickSendButton(); err != nil {
|
||||
return "", fmt.Errorf("点击发送按钮失败: %v", err)
|
||||
return nil, fmt.Errorf("点击发送按钮失败: %v", err)
|
||||
}
|
||||
|
||||
// 等待并获取答案
|
||||
answer, err := c.waitForAnswer()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取答案失败: %v", err)
|
||||
return nil, fmt.Errorf("获取答案失败: %v", err)
|
||||
}
|
||||
|
||||
c.LogInfo(fmt.Sprintf("成功获取DeepSeek答案,长度: %d 字符", len(answer)))
|
||||
return answer, nil
|
||||
return &CollectResult{
|
||||
Answer: answer,
|
||||
ShareLink: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// inputQuestion 输入问题
|
||||
func (c *DeepseekCollector) inputQuestion(question string) error {
|
||||
c.LogInfo("输入问题到DeepSeek...")
|
||||
|
||||
// DeepSeek的输入框选择器
|
||||
inputSelectors := []string{
|
||||
"textarea[placeholder*='输入']",
|
||||
|
|
@ -145,7 +129,6 @@ func (c *DeepseekCollector) inputQuestion(question string) error {
|
|||
for _, selector := range inputSelectors {
|
||||
inputBox, err = c.WaitForElementVisible(selector, 10)
|
||||
if err == nil && inputBox != nil {
|
||||
c.LogInfo(fmt.Sprintf("找到输入框: %s", selector))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
@ -162,7 +145,7 @@ func (c *DeepseekCollector) inputQuestion(question string) error {
|
|||
|
||||
// 清空输入框
|
||||
if err := c.ClearInput(inputBox); err != nil {
|
||||
c.LogInfo(fmt.Sprintf("清空输入框失败: %v", err))
|
||||
// Ignore clear error
|
||||
}
|
||||
c.SleepMs(300)
|
||||
|
||||
|
|
@ -171,7 +154,6 @@ func (c *DeepseekCollector) inputQuestion(question string) error {
|
|||
inputBox.Input(question)
|
||||
}
|
||||
|
||||
c.LogInfo(fmt.Sprintf("问题已输入"))
|
||||
c.SleepMs(1000)
|
||||
|
||||
return nil
|
||||
|
|
@ -179,8 +161,6 @@ func (c *DeepseekCollector) inputQuestion(question string) error {
|
|||
|
||||
// clickSendButton 点击发送按钮
|
||||
func (c *DeepseekCollector) clickSendButton() error {
|
||||
c.LogInfo("点击发送按钮...")
|
||||
|
||||
// 发送按钮选择器
|
||||
sendSelectors := []string{
|
||||
"button[class*='send']",
|
||||
|
|
@ -198,7 +178,6 @@ func (c *DeepseekCollector) clickSendButton() error {
|
|||
for _, selector := range sendSelectors {
|
||||
sendBtn, err = c.WaitForElementClickable(selector, 5)
|
||||
if err == nil && sendBtn != nil {
|
||||
c.LogInfo(fmt.Sprintf("找到发送按钮: %s", selector))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
@ -218,7 +197,6 @@ func (c *DeepseekCollector) clickSendButton() error {
|
|||
return fmt.Errorf("点击发送按钮失败: %v", err)
|
||||
}
|
||||
|
||||
c.LogInfo("已点击发送按钮")
|
||||
c.SleepMs(2000)
|
||||
|
||||
return nil
|
||||
|
|
@ -226,8 +204,6 @@ func (c *DeepseekCollector) clickSendButton() error {
|
|||
|
||||
// waitForAnswer 等待并获取答案
|
||||
func (c *DeepseekCollector) waitForAnswer() (string, error) {
|
||||
c.LogInfo("等待DeepSeek回答...")
|
||||
|
||||
timeout := 120 // 最大等待时间(秒)
|
||||
startTime := time.Now()
|
||||
lastAnswerLength := 0
|
||||
|
|
@ -262,7 +238,6 @@ func (c *DeepseekCollector) waitForAnswer() (string, error) {
|
|||
currentLength := len(text)
|
||||
if currentLength == lastAnswerLength && currentLength > 10 {
|
||||
// 答案不再增长,认为已完成
|
||||
c.LogInfo("获取到完整答案")
|
||||
return strings.TrimSpace(text), nil
|
||||
}
|
||||
lastAnswerLength = currentLength
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"geo/internal/config"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-rod/rod"
|
||||
"github.com/go-rod/rod/lib/proto"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
// DoubaoCollector 豆包收集器
|
||||
|
|
@ -18,7 +18,7 @@ type DoubaoCollector struct {
|
|||
}
|
||||
|
||||
// NewDoubaoCollector 创建豆包收集器
|
||||
func NewDoubaoCollector(ctx context.Context, params *CollectParams, cfg *config.Config, logger *log.Logger) CollectorInterface {
|
||||
func NewDoubaoCollector(ctx context.Context, params *CollectParams, cfg *config.Config, logger log.AllLogger) CollectorInterface {
|
||||
collector := &DoubaoCollector{
|
||||
BaseCollector: NewBaseCollector(ctx, params, cfg, logger),
|
||||
}
|
||||
|
|
@ -54,82 +54,65 @@ func (c *DoubaoCollector) CheckLoginStatus() bool {
|
|||
|
||||
// WaitLogin 等待登录
|
||||
func (c *DoubaoCollector) WaitLogin() (bool, string) {
|
||||
c.LogInfo("开始等待豆包登录...")
|
||||
|
||||
if err := c.SetupDriver(); err != nil {
|
||||
return false, fmt.Sprintf("浏览器启动失败: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// 访问豆包首页
|
||||
c.Page.MustNavigate(c.LoginURL)
|
||||
c.Sleep(3)
|
||||
|
||||
// 检查是否已登录
|
||||
if c.CheckLoginStatus() {
|
||||
c.SaveCookies()
|
||||
c.LogInfo("已有登录状态")
|
||||
return true, "already_logged_in"
|
||||
}
|
||||
|
||||
c.LogInfo("请登录豆包账号...")
|
||||
|
||||
// 等待用户手动登录,最多300秒
|
||||
for i := 0; i < 300; i++ {
|
||||
if c.CheckLoginStatus() {
|
||||
c.Sleep(2)
|
||||
c.SaveCookies()
|
||||
c.LogInfo("登录成功")
|
||||
return true, "login_success"
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
return false, "登录超时,请检查网络或账号状态"
|
||||
return false, "登录超时"
|
||||
}
|
||||
|
||||
// AskQuestion 提问并获取答案
|
||||
func (c *DoubaoCollector) AskQuestion(question string) (string, error) {
|
||||
c.LogInfo(fmt.Sprintf("开始向豆包提问: %s", question))
|
||||
|
||||
// 初始化浏览器
|
||||
func (c *DoubaoCollector) AskQuestion(question string) (*CollectResult, error) {
|
||||
if err := c.SetupDriver(); err != nil {
|
||||
return "", fmt.Errorf("浏览器启动失败: %v", err)
|
||||
return nil, fmt.Errorf("浏览器启动失败: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// 初始化页面
|
||||
if err := c.InitPage(); err != nil {
|
||||
return "", fmt.Errorf("页面初始化失败,请先调用WaitLogin登录: %v", err)
|
||||
return nil, fmt.Errorf("页面初始化失败: %v", err)
|
||||
}
|
||||
|
||||
c.Sleep(3)
|
||||
|
||||
// 输入问题
|
||||
if err := c.inputQuestion(question); err != nil {
|
||||
return "", fmt.Errorf("输入问题失败: %v", err)
|
||||
return nil, fmt.Errorf("输入问题失败: %v", err)
|
||||
}
|
||||
|
||||
// 点击发送
|
||||
if err := c.clickSendButton(); err != nil {
|
||||
return "", fmt.Errorf("点击发送按钮失败: %v", err)
|
||||
return nil, fmt.Errorf("点击发送按钮失败: %v", err)
|
||||
}
|
||||
|
||||
// 等待并获取答案
|
||||
answer, err := c.waitForAnswer()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取答案失败: %v", err)
|
||||
return nil, fmt.Errorf("获取答案失败: %v", err)
|
||||
}
|
||||
|
||||
c.LogInfo(fmt.Sprintf("成功获取豆包答案,长度: %d 字符", len(answer)))
|
||||
return answer, nil
|
||||
return &CollectResult{
|
||||
Answer: answer,
|
||||
ShareLink: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// inputQuestion 输入问题
|
||||
func (c *DoubaoCollector) inputQuestion(question string) error {
|
||||
c.LogInfo("输入问题到豆包...")
|
||||
|
||||
// 豆包的输入框选择器
|
||||
inputSelectors := []string{
|
||||
"textarea[placeholder*='输入']",
|
||||
"textarea[placeholder*='问']",
|
||||
|
|
@ -146,7 +129,6 @@ func (c *DoubaoCollector) inputQuestion(question string) error {
|
|||
for _, selector := range inputSelectors {
|
||||
inputBox, err = c.WaitForElementVisible(selector, 10)
|
||||
if err == nil && inputBox != nil {
|
||||
c.LogInfo(fmt.Sprintf("找到输入框: %s", selector))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
@ -155,24 +137,20 @@ func (c *DoubaoCollector) inputQuestion(question string) error {
|
|||
return fmt.Errorf("未找到输入框")
|
||||
}
|
||||
|
||||
// 点击获取焦点
|
||||
if err := inputBox.Click(proto.InputMouseButtonLeft, 1); err != nil {
|
||||
return fmt.Errorf("点击输入框失败: %v", err)
|
||||
}
|
||||
c.SleepMs(500)
|
||||
|
||||
// 清空输入框
|
||||
if err := c.ClearInput(inputBox); err != nil {
|
||||
c.LogInfo(fmt.Sprintf("清空输入框失败: %v", err))
|
||||
// Ignore clear error
|
||||
}
|
||||
c.SleepMs(300)
|
||||
|
||||
// 输入问题
|
||||
if err := c.SetInputValue(inputBox, question); err != nil {
|
||||
inputBox.Input(question)
|
||||
}
|
||||
|
||||
c.LogInfo(fmt.Sprintf("问题已输入"))
|
||||
c.SleepMs(1000)
|
||||
|
||||
return nil
|
||||
|
|
@ -180,9 +158,6 @@ func (c *DoubaoCollector) inputQuestion(question string) error {
|
|||
|
||||
// clickSendButton 点击发送按钮
|
||||
func (c *DoubaoCollector) clickSendButton() error {
|
||||
c.LogInfo("点击发送按钮...")
|
||||
|
||||
// 发送按钮选择器
|
||||
sendSelectors := []string{
|
||||
"button[class*='send']",
|
||||
"button[class*='submit']",
|
||||
|
|
@ -199,13 +174,11 @@ func (c *DoubaoCollector) clickSendButton() error {
|
|||
for _, selector := range sendSelectors {
|
||||
sendBtn, err = c.WaitForElementClickable(selector, 5)
|
||||
if err == nil && sendBtn != nil {
|
||||
c.LogInfo(fmt.Sprintf("找到发送按钮: %s", selector))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if sendBtn == nil {
|
||||
// 尝试查找发送图标
|
||||
sendBtn, err = c.Page.Element("button svg")
|
||||
if err != nil {
|
||||
return fmt.Errorf("未找到发送按钮")
|
||||
|
|
@ -214,12 +187,10 @@ func (c *DoubaoCollector) clickSendButton() error {
|
|||
|
||||
c.SleepMs(500)
|
||||
|
||||
// 点击发送按钮
|
||||
if err := c.JSClick(sendBtn); err != nil {
|
||||
return fmt.Errorf("点击发送按钮失败: %v", err)
|
||||
}
|
||||
|
||||
c.LogInfo("已点击发送按钮")
|
||||
c.SleepMs(2000)
|
||||
|
||||
return nil
|
||||
|
|
@ -227,14 +198,11 @@ func (c *DoubaoCollector) clickSendButton() error {
|
|||
|
||||
// waitForAnswer 等待并获取答案
|
||||
func (c *DoubaoCollector) waitForAnswer() (string, error) {
|
||||
c.LogInfo("等待豆包回答...")
|
||||
|
||||
timeout := 120 // 最大等待时间(秒)
|
||||
timeout := 120
|
||||
startTime := time.Now()
|
||||
lastAnswerLength := 0
|
||||
|
||||
for time.Since(startTime).Seconds() < float64(timeout) {
|
||||
// 查找答案区域
|
||||
answerSelectors := []string{
|
||||
".message-content",
|
||||
".response-text",
|
||||
|
|
@ -247,24 +215,19 @@ func (c *DoubaoCollector) waitForAnswer() (string, error) {
|
|||
for _, selector := range answerSelectors {
|
||||
answerElements, err := c.Page.Elements(selector)
|
||||
if err == nil && len(answerElements) > 0 {
|
||||
// 获取最后一个答案元素
|
||||
lastAnswer := answerElements[len(answerElements)-1]
|
||||
|
||||
visible, _ := lastAnswer.Visible()
|
||||
if visible {
|
||||
text, err := lastAnswer.Text()
|
||||
if err == nil && len(strings.TrimSpace(text)) > 0 {
|
||||
// 检查是否正在生成
|
||||
isGenerating := strings.Contains(text, "正在") ||
|
||||
strings.Contains(text, "思考中") ||
|
||||
strings.Contains(text, "typing")
|
||||
|
||||
if !isGenerating {
|
||||
// 检查答案是否还在增长
|
||||
currentLength := len(text)
|
||||
if currentLength == lastAnswerLength && currentLength > 10 {
|
||||
// 答案不再增长,认为已完成
|
||||
c.LogInfo("获取到完整答案")
|
||||
return strings.TrimSpace(text), nil
|
||||
}
|
||||
lastAnswerLength = currentLength
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ package collect
|
|||
import (
|
||||
"context"
|
||||
"geo/internal/config"
|
||||
"log"
|
||||
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
// CollectorInterface AI平台收集器接口
|
||||
|
|
@ -11,7 +12,13 @@ type CollectorInterface interface {
|
|||
// WaitLogin 等待登录
|
||||
WaitLogin() (bool, string)
|
||||
// AskQuestion 提问并获取答案
|
||||
AskQuestion(question string) (string, error)
|
||||
AskQuestion(question string) (*CollectResult, error)
|
||||
}
|
||||
|
||||
// CollectResult 收集结果
|
||||
type CollectResult struct {
|
||||
Answer string `json:"answer"` // AI回答内容
|
||||
ShareLink string `json:"share_link"` // 分享链接
|
||||
}
|
||||
|
||||
// NewCollector 创建收集器的工厂函数类型
|
||||
|
|
@ -19,22 +26,22 @@ type NewCollector func(
|
|||
ctx context.Context,
|
||||
param *CollectParams,
|
||||
cfg *config.Config,
|
||||
logger *log.Logger) CollectorInterface
|
||||
logger log.AllLogger) CollectorInterface
|
||||
|
||||
// CollectorValue 收集器配置信息
|
||||
type CollectorValue struct {
|
||||
Name string // 平台名称
|
||||
InitMethod NewCollector // 初始化方法
|
||||
Platform string // 平台标识: wenxin, deepseek, doubao, qianwen
|
||||
Name string // 平台名称
|
||||
InitMethod NewCollector // 初始化方法
|
||||
Platform string // 平台标识: wenxin, deepseek, doubao, qianwen
|
||||
Icon string
|
||||
}
|
||||
|
||||
// CollectParams 收集任务参数
|
||||
type CollectParams struct {
|
||||
Headless bool // 是否无头模式
|
||||
UserIndex string // 用户索引
|
||||
PlatIndex string // 平台索引
|
||||
RequestID string // 请求ID
|
||||
Platform string // 平台类型
|
||||
Headless bool // 是否无头模式
|
||||
RequestID string // 请求ID
|
||||
Platform string // 平台类型: wenxin, deepseek, doubao, qianwen
|
||||
|
||||
}
|
||||
|
||||
// CollectorMap 收集器注册表
|
||||
|
|
@ -43,20 +50,24 @@ var CollectorMap = map[string]*CollectorValue{
|
|||
Name: "文心一言",
|
||||
InitMethod: NewWenxinCollector,
|
||||
Platform: "wenxin",
|
||||
Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/wenxin.png",
|
||||
},
|
||||
"deepseek": {
|
||||
Name: "DeepSeek",
|
||||
InitMethod: NewDeepseekCollector,
|
||||
Platform: "deepseek",
|
||||
Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/deepseek.png",
|
||||
},
|
||||
"doubao": {
|
||||
Name: "豆包",
|
||||
InitMethod: NewDoubaoCollector,
|
||||
Platform: "doubao",
|
||||
Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/doubao.png",
|
||||
},
|
||||
"qianwen": {
|
||||
Name: "通义千问",
|
||||
InitMethod: NewQianwenCollector,
|
||||
Platform: "qianwen",
|
||||
Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/qianwen.png",
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,18 +4,19 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"geo/internal/config"
|
||||
"log"
|
||||
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
// CollectManager 收集管理器
|
||||
type CollectManager struct {
|
||||
ctx context.Context
|
||||
config *config.Config
|
||||
logger *log.Logger
|
||||
logger log.AllLogger
|
||||
}
|
||||
|
||||
// NewCollectManager 创建收集管理器
|
||||
func NewCollectManager(ctx context.Context, cfg *config.Config, logger *log.Logger) *CollectManager {
|
||||
func NewCollectManager(ctx context.Context, cfg *config.Config, logger log.AllLogger) *CollectManager {
|
||||
return &CollectManager{
|
||||
ctx: ctx,
|
||||
config: cfg,
|
||||
|
|
@ -39,10 +40,10 @@ func (m *CollectManager) GetCollector(platform string, params *CollectParams) (C
|
|||
}
|
||||
|
||||
// AskQuestion 向指定AI平台提问
|
||||
func (m *CollectManager) AskQuestion(platform string, params *CollectParams, question string) (string, error) {
|
||||
func (m *CollectManager) AskQuestion(platform string, params *CollectParams, question string) (*CollectResult, error) {
|
||||
collector, err := m.GetCollector(platform, params)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return collector.AskQuestion(question)
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"geo/internal/config"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-rod/rod"
|
||||
"github.com/go-rod/rod/lib/proto"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
// QianwenCollector 通义千问收集器
|
||||
|
|
@ -18,7 +18,7 @@ type QianwenCollector struct {
|
|||
}
|
||||
|
||||
// NewQianwenCollector 创建通义千问收集器
|
||||
func NewQianwenCollector(ctx context.Context, params *CollectParams, cfg *config.Config, logger *log.Logger) CollectorInterface {
|
||||
func NewQianwenCollector(ctx context.Context, params *CollectParams, cfg *config.Config, logger log.AllLogger) CollectorInterface {
|
||||
collector := &QianwenCollector{
|
||||
BaseCollector: NewBaseCollector(ctx, params, cfg, logger),
|
||||
}
|
||||
|
|
@ -54,81 +54,65 @@ func (c *QianwenCollector) CheckLoginStatus() bool {
|
|||
|
||||
// WaitLogin 等待登录
|
||||
func (c *QianwenCollector) WaitLogin() (bool, string) {
|
||||
c.LogInfo("开始等待通义千问登录...")
|
||||
|
||||
if err := c.SetupDriver(); err != nil {
|
||||
return false, fmt.Sprintf("浏览器启动失败: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// 访问通义千问页面
|
||||
c.Page.MustNavigate(c.ChatURL)
|
||||
c.Sleep(3)
|
||||
|
||||
// 检查是否已登录
|
||||
if c.CheckLoginStatus() {
|
||||
c.SaveCookies()
|
||||
c.LogInfo("已有登录状态")
|
||||
return true, "already_logged_in"
|
||||
}
|
||||
|
||||
c.LogInfo("请登录阿里云账号...")
|
||||
|
||||
// 等待用户手动登录,最多300秒
|
||||
for i := 0; i < 300; i++ {
|
||||
if c.CheckLoginStatus() {
|
||||
c.Sleep(2)
|
||||
c.SaveCookies()
|
||||
c.LogInfo("登录成功")
|
||||
return true, "login_success"
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
return false, "登录超时,请检查网络或账号状态"
|
||||
return false, "登录超时"
|
||||
}
|
||||
|
||||
// AskQuestion 提问并获取答案
|
||||
func (c *QianwenCollector) AskQuestion(question string) (string, error) {
|
||||
c.LogInfo(fmt.Sprintf("开始向通义千问提问: %s", question))
|
||||
|
||||
// 初始化浏览器
|
||||
func (c *QianwenCollector) AskQuestion(question string) (*CollectResult, error) {
|
||||
if err := c.SetupDriver(); err != nil {
|
||||
return "", fmt.Errorf("浏览器启动失败: %v", err)
|
||||
return nil, fmt.Errorf("浏览器启动失败: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// 初始化页面
|
||||
if err := c.InitPage(); err != nil {
|
||||
return "", fmt.Errorf("页面初始化失败,请先调用WaitLogin登录: %v", err)
|
||||
return nil, fmt.Errorf("页面初始化失败: %v", err)
|
||||
}
|
||||
|
||||
c.Sleep(3)
|
||||
|
||||
// 输入问题
|
||||
if err := c.inputQuestion(question); err != nil {
|
||||
return "", fmt.Errorf("输入问题失败: %v", err)
|
||||
return nil, fmt.Errorf("输入问题失败: %v", err)
|
||||
}
|
||||
|
||||
// 点击发送
|
||||
if err := c.clickSendButton(); err != nil {
|
||||
return "", fmt.Errorf("点击发送按钮失败: %v", err)
|
||||
return nil, fmt.Errorf("点击发送按钮失败: %v", err)
|
||||
}
|
||||
|
||||
// 等待并获取答案
|
||||
answer, err := c.waitForAnswer()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取答案失败: %v", err)
|
||||
return nil, fmt.Errorf("获取答案失败: %v", err)
|
||||
}
|
||||
|
||||
c.LogInfo(fmt.Sprintf("成功获取通义千问答案,长度: %d 字符", len(answer)))
|
||||
return answer, nil
|
||||
return &CollectResult{
|
||||
Answer: answer,
|
||||
ShareLink: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// inputQuestion 输入问题
|
||||
func (c *QianwenCollector) inputQuestion(question string) error {
|
||||
c.LogInfo("输入问题到通义千问...")
|
||||
|
||||
// 通义千问的输入框选择器
|
||||
inputSelectors := []string{
|
||||
"textarea[placeholder*='输入']",
|
||||
|
|
@ -147,7 +131,6 @@ func (c *QianwenCollector) inputQuestion(question string) error {
|
|||
for _, selector := range inputSelectors {
|
||||
inputBox, err = c.WaitForElementVisible(selector, 10)
|
||||
if err == nil && inputBox != nil {
|
||||
c.LogInfo(fmt.Sprintf("找到输入框: %s", selector))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
@ -164,7 +147,7 @@ func (c *QianwenCollector) inputQuestion(question string) error {
|
|||
|
||||
// 清空输入框
|
||||
if err := c.ClearInput(inputBox); err != nil {
|
||||
c.LogInfo(fmt.Sprintf("清空输入框失败: %v", err))
|
||||
// Ignore clear error
|
||||
}
|
||||
c.SleepMs(300)
|
||||
|
||||
|
|
@ -173,7 +156,6 @@ func (c *QianwenCollector) inputQuestion(question string) error {
|
|||
inputBox.Input(question)
|
||||
}
|
||||
|
||||
c.LogInfo(fmt.Sprintf("问题已输入"))
|
||||
c.SleepMs(1000)
|
||||
|
||||
return nil
|
||||
|
|
@ -181,8 +163,6 @@ func (c *QianwenCollector) inputQuestion(question string) error {
|
|||
|
||||
// clickSendButton 点击发送按钮
|
||||
func (c *QianwenCollector) clickSendButton() error {
|
||||
c.LogInfo("点击发送按钮...")
|
||||
|
||||
// 发送按钮选择器
|
||||
sendSelectors := []string{
|
||||
"button[class*='send']",
|
||||
|
|
@ -201,7 +181,6 @@ func (c *QianwenCollector) clickSendButton() error {
|
|||
for _, selector := range sendSelectors {
|
||||
sendBtn, err = c.WaitForElementClickable(selector, 5)
|
||||
if err == nil && sendBtn != nil {
|
||||
c.LogInfo(fmt.Sprintf("找到发送按钮: %s", selector))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
@ -221,7 +200,6 @@ func (c *QianwenCollector) clickSendButton() error {
|
|||
return fmt.Errorf("点击发送按钮失败: %v", err)
|
||||
}
|
||||
|
||||
c.LogInfo("已点击发送按钮")
|
||||
c.SleepMs(2000)
|
||||
|
||||
return nil
|
||||
|
|
@ -229,8 +207,6 @@ func (c *QianwenCollector) clickSendButton() error {
|
|||
|
||||
// waitForAnswer 等待并获取答案
|
||||
func (c *QianwenCollector) waitForAnswer() (string, error) {
|
||||
c.LogInfo("等待通义千问回答...")
|
||||
|
||||
timeout := 120 // 最大等待时间(秒)
|
||||
startTime := time.Now()
|
||||
lastAnswerLength := 0
|
||||
|
|
@ -268,7 +244,6 @@ func (c *QianwenCollector) waitForAnswer() (string, error) {
|
|||
currentLength := len(text)
|
||||
if currentLength == lastAnswerLength && currentLength > 10 {
|
||||
// 答案不再增长,认为已完成
|
||||
c.LogInfo("获取到完整答案")
|
||||
return strings.TrimSpace(text), nil
|
||||
}
|
||||
lastAnswerLength = currentLength
|
||||
|
|
|
|||
|
|
@ -4,14 +4,15 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"geo/internal/config"
|
||||
"log"
|
||||
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/atotto/clipboard"
|
||||
"github.com/go-rod/rod"
|
||||
"github.com/go-rod/rod/lib/proto"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// Source 文章引用来源结构体
|
||||
|
|
@ -28,7 +29,7 @@ type WenxinCollector struct {
|
|||
}
|
||||
|
||||
// NewWenxinCollector 创建文心一言收集器
|
||||
func NewWenxinCollector(ctx context.Context, params *CollectParams, cfg *config.Config, logger *log.Logger) CollectorInterface {
|
||||
func NewWenxinCollector(ctx context.Context, params *CollectParams, cfg *config.Config, logger log.AllLogger) CollectorInterface {
|
||||
collector := &WenxinCollector{
|
||||
BaseCollector: NewBaseCollector(ctx, params, cfg, logger),
|
||||
}
|
||||
|
|
@ -46,37 +47,11 @@ func (c *WenxinCollector) SetupDriver() error {
|
|||
return err
|
||||
}
|
||||
|
||||
// 通过 JavaScript 设置 navigator.language 为中文
|
||||
jsCode := `
|
||||
(function() {
|
||||
Object.defineProperty(navigator, 'language', {
|
||||
get: function() { return 'zh-CN'; },
|
||||
configurable: true
|
||||
});
|
||||
Object.defineProperty(navigator, 'languages', {
|
||||
get: function() { return ['zh-CN', 'zh', 'en']; },
|
||||
configurable: true
|
||||
});
|
||||
})();
|
||||
`
|
||||
|
||||
if _, err := c.Page.Eval(jsCode); err != nil {
|
||||
c.LogInfo(fmt.Sprintf("设置语言失败: %v", err))
|
||||
} else {
|
||||
c.LogInfo("已设置浏览器语言为中文 (zh-CN)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckLoginStatus 检查登录状态
|
||||
func (c *WenxinCollector) CheckLoginStatus() bool {
|
||||
currentURL := c.GetCurrentURL()
|
||||
|
||||
// 如果在登录页面,说明未登录
|
||||
if strings.Contains(currentURL, "passport.baidu.com") {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查页面上是否存在内容为"登录"或"Login"的button,如果存在说明未登录
|
||||
loginButtons, err := c.Page.Elements("button")
|
||||
|
|
@ -97,117 +72,69 @@ func (c *WenxinCollector) CheckLoginStatus() bool {
|
|||
|
||||
// WaitLogin 等待登录
|
||||
func (c *WenxinCollector) WaitLogin() (bool, string) {
|
||||
c.LogInfo("开始等待文心一言登录...")
|
||||
|
||||
if err := c.SetupDriver(); err != nil {
|
||||
return false, fmt.Sprintf("浏览器启动失败: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// 访问聊天页面
|
||||
c.Page.MustNavigate(c.ChatURL)
|
||||
c.Sleep(3)
|
||||
|
||||
// 检查是否已登录
|
||||
if c.CheckLoginStatus() {
|
||||
c.SaveCookies()
|
||||
c.LogInfo("已有登录状态")
|
||||
return true, "already_logged_in"
|
||||
}
|
||||
|
||||
c.LogInfo("检测到未登录,请在当前页面完成登录(扫码或输入账号密码)...")
|
||||
|
||||
// 不跳转页面,在当前页面循环检查登录按钮是否存在
|
||||
// 最多等待300秒
|
||||
for i := 0; i < 5000; i++ {
|
||||
// 检查页面上是否还存在"登录"或"Login"按钮
|
||||
loginButtonExists := false
|
||||
buttons, err := c.Page.Elements("button")
|
||||
if err == nil {
|
||||
for _, btn := range buttons {
|
||||
text, _ := btn.Text()
|
||||
trimmedText := strings.TrimSpace(text)
|
||||
if trimmedText == "登录" || trimmedText == "Login" {
|
||||
loginButtonExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果登录按钮不存在,说明已登录
|
||||
if !loginButtonExists {
|
||||
c.Sleep(2) // 等待页面稳定
|
||||
for i := 0; i < 300; i++ {
|
||||
if c.CheckLoginStatus() {
|
||||
c.Sleep(2)
|
||||
c.SaveCookies()
|
||||
c.LogInfo("登录成功:登录按钮已消失")
|
||||
return true, "login_success"
|
||||
}
|
||||
|
||||
// 每秒检查一次
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// 每30秒输出一次提示
|
||||
if i > 0 && i%30 == 0 {
|
||||
c.LogInfo(fmt.Sprintf("等待登录中... 已等待 %d 秒", i))
|
||||
}
|
||||
}
|
||||
|
||||
return false, "登录超时,请检查网络或账号状态"
|
||||
return false, "登录超时"
|
||||
}
|
||||
|
||||
// AskQuestion 提问并获取答案
|
||||
func (c *WenxinCollector) AskQuestion(question string) (string, error) {
|
||||
c.LogInfo(fmt.Sprintf("开始提问: %s", question))
|
||||
|
||||
// 初始化浏览器
|
||||
func (c *WenxinCollector) AskQuestion(question string) (*CollectResult, error) {
|
||||
if err := c.SetupDriver(); err != nil {
|
||||
return "", fmt.Errorf("浏览器启动失败: %v", err)
|
||||
return nil, fmt.Errorf("浏览器启动失败: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
//初始化页面(加载cookies和检查登录)
|
||||
if err := c.InitPage(); err != nil {
|
||||
return "", fmt.Errorf("页面初始化失败,请先调用WaitLogin登录: %v", err)
|
||||
return nil, fmt.Errorf("页面初始化失败: %v", err)
|
||||
}
|
||||
|
||||
// 等待页面完全加载
|
||||
c.Sleep(3)
|
||||
|
||||
// 查找输入框并输入问题
|
||||
if err := c.inputQuestion(question); err != nil {
|
||||
return "", fmt.Errorf("输入问题失败: %v", err)
|
||||
return nil, fmt.Errorf("输入问题失败: %v", err)
|
||||
}
|
||||
|
||||
// 点击发送按钮
|
||||
if err := c.clickSendButton(); err != nil {
|
||||
return "", fmt.Errorf("点击发送按钮失败: %v", err)
|
||||
return nil, fmt.Errorf("点击发送按钮失败: %v", err)
|
||||
}
|
||||
|
||||
// 等待并获取答案
|
||||
answer, err := c.waitForAnswer()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取答案失败: %v", err)
|
||||
return nil, fmt.Errorf("获取答案失败: %v", err)
|
||||
}
|
||||
|
||||
c.LogInfo(fmt.Sprintf("成功获取答案,长度: %d 字符", len(answer)))
|
||||
|
||||
// 获取分享链接
|
||||
_, shareErr := c.getShareLink()
|
||||
if shareErr != nil {
|
||||
c.LogInfo(fmt.Sprintf("分享链接获取状态: %v", shareErr))
|
||||
shareLink := ""
|
||||
link, _ := c.getShareLink()
|
||||
if link != "" {
|
||||
shareLink = link
|
||||
}
|
||||
|
||||
// 获取引用来源
|
||||
sources, sourcesErr := c.GetSources()
|
||||
if sourcesErr != nil {
|
||||
c.LogInfo(fmt.Sprintf("引用来源获取失败: %v", sourcesErr))
|
||||
} else if len(sources) > 0 {
|
||||
c.LogInfo(fmt.Sprintf("成功获取 %d 个引用来源", len(sources)))
|
||||
for i, source := range sources {
|
||||
c.LogInfo(fmt.Sprintf(" [%d] 标题: %s, 来源: %s, URL: %s", i+1, source.Title, source.PlatformName, source.Url))
|
||||
}
|
||||
}
|
||||
|
||||
return answer, nil
|
||||
return &CollectResult{
|
||||
Answer: answer,
|
||||
ShareLink: shareLink,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// inputQuestion 输入问题
|
||||
|
|
@ -411,7 +338,7 @@ func (c *WenxinCollector) waitForAnswer() (string, error) {
|
|||
htmlContent, err := answerElem.HTML()
|
||||
if err == nil && len(strings.TrimSpace(htmlContent)) > 30 {
|
||||
// 清理HTML标签,只保留纯文本
|
||||
answerText = CleanHTMLTags(htmlContent)
|
||||
answerText = CleanDivTags(htmlContent)
|
||||
c.LogInfo(fmt.Sprintf("找到答案容器,清理后文本长度: %d", len(answerText)))
|
||||
} else {
|
||||
// 如果HTML获取失败,尝试获取文本
|
||||
|
|
@ -557,61 +484,105 @@ func (c *WenxinCollector) getShareLink() (string, error) {
|
|||
}
|
||||
|
||||
c.LogInfo("✓ 点击成功")
|
||||
c.SleepMs(2000) // 等待弹窗出现
|
||||
c.SleepMs(3000) // 等待弹窗出现
|
||||
c.Screenshot("after_share_icon_click")
|
||||
|
||||
// 步骤3: 在弹窗中查找shareContainer的div
|
||||
// 步骤3: 在弹窗中查找shareContainer的div(带重试机制)
|
||||
c.LogInfo("步骤3: 查找包含'shareContainer'的div元素...")
|
||||
|
||||
var shareContainerDiv *rod.Element
|
||||
maxRetries := 5
|
||||
retryDelay := 1000 // 每次重试间隔1秒
|
||||
|
||||
// 重新获取所有div元素
|
||||
allDivs, err = c.Page.Elements("div")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取页面div元素失败: %v", err)
|
||||
}
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
c.LogInfo(fmt.Sprintf("第 %d/%d 次尝试查找shareContainer...", attempt, maxRetries))
|
||||
|
||||
c.LogInfo(fmt.Sprintf("在 %d 个div元素中查找包含'shareContainer'的class", len(allDivs)))
|
||||
// 重新获取所有div元素
|
||||
allDivs, err = c.Page.Elements("div")
|
||||
if err != nil {
|
||||
c.LogInfo(fmt.Sprintf("获取页面div元素失败: %v", err))
|
||||
if attempt < maxRetries {
|
||||
c.SleepMs(retryDelay)
|
||||
continue
|
||||
}
|
||||
return "", fmt.Errorf("获取页面div元素失败: %v", err)
|
||||
}
|
||||
|
||||
for _, elem := range allDivs {
|
||||
classAttr, _ := elem.Attribute("class")
|
||||
if classAttr != nil && strings.Contains(strings.ToLower(*classAttr), "sharecontainer") {
|
||||
tagName, _ := elem.Property("tagName")
|
||||
c.LogInfo(fmt.Sprintf("✓ 找到shareContainer容器: tag=%s, class=%s", tagName.Str(), *classAttr))
|
||||
shareContainerDiv = elem
|
||||
break
|
||||
c.LogInfo(fmt.Sprintf("在 %d 个div元素中查找包含'shareContainer'的class", len(allDivs)))
|
||||
|
||||
for _, elem := range allDivs {
|
||||
classAttr, _ := elem.Attribute("class")
|
||||
if classAttr != nil && strings.Contains(strings.ToLower(*classAttr), "sharecontainer") {
|
||||
tagName, _ := elem.Property("tagName")
|
||||
c.LogInfo(fmt.Sprintf("✓ 找到shareContainer容器: tag=%s, class=%s", tagName.Str(), *classAttr))
|
||||
shareContainerDiv = elem
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if shareContainerDiv != nil {
|
||||
break // 找到了,退出重试循环
|
||||
}
|
||||
|
||||
// 没找到,等待后重试
|
||||
if attempt < maxRetries {
|
||||
c.LogInfo(fmt.Sprintf("未找到shareContainer,%d毫秒后重试...", retryDelay))
|
||||
c.SleepMs(retryDelay)
|
||||
}
|
||||
}
|
||||
|
||||
if shareContainerDiv == nil {
|
||||
return "", fmt.Errorf("未找到包含'shareContainer' class的div元素")
|
||||
c.Screenshot("share_container_not_found")
|
||||
return "", fmt.Errorf("经过 %d 次重试仍未找到包含'shareContainer' class的div元素", maxRetries)
|
||||
}
|
||||
|
||||
// 步骤4: 在shareContainer内查找genLink的button
|
||||
// 步骤4: 在shareContainer内查找genLink的button(带重试机制)
|
||||
c.LogInfo("步骤4: 在shareContainer容器内查找包含'genLink'的button...")
|
||||
|
||||
var genLinkBtn *rod.Element
|
||||
maxRetries = 3
|
||||
retryDelay = 800
|
||||
|
||||
buttons, err := shareContainerDiv.Elements("button")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取button元素失败: %v", err)
|
||||
}
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
c.LogInfo(fmt.Sprintf("第 %d/%d 次尝试查找genLink按钮...", attempt, maxRetries))
|
||||
|
||||
c.LogInfo(fmt.Sprintf("在 %d 个button元素中查找包含'genLink'的class", len(buttons)))
|
||||
buttons, err := shareContainerDiv.Elements("button")
|
||||
if err != nil {
|
||||
c.LogInfo(fmt.Sprintf("获取button元素失败: %v", err))
|
||||
if attempt < maxRetries {
|
||||
c.SleepMs(retryDelay)
|
||||
continue
|
||||
}
|
||||
return "", fmt.Errorf("获取button元素失败: %v", err)
|
||||
}
|
||||
|
||||
for _, elem := range buttons {
|
||||
classAttr, _ := elem.Attribute("class")
|
||||
if classAttr != nil && strings.Contains(strings.ToLower(*classAttr), "genlink") {
|
||||
tagName, _ := elem.Property("tagName")
|
||||
text, _ := elem.Text()
|
||||
c.LogInfo(fmt.Sprintf("✓ 找到genLink按钮: tag=%s, class=%s, text=%s", tagName.Str(), *classAttr, strings.TrimSpace(text)))
|
||||
genLinkBtn = elem
|
||||
break
|
||||
c.LogInfo(fmt.Sprintf("在 %d 个button元素中查找包含'genLink'的class", len(buttons)))
|
||||
|
||||
for _, elem := range buttons {
|
||||
classAttr, _ := elem.Attribute("class")
|
||||
if classAttr != nil && strings.Contains(strings.ToLower(*classAttr), "genlink") {
|
||||
tagName, _ := elem.Property("tagName")
|
||||
text, _ := elem.Text()
|
||||
c.LogInfo(fmt.Sprintf("✓ 找到genLink按钮: tag=%s, class=%s, text=%s", tagName.Str(), *classAttr, strings.TrimSpace(text)))
|
||||
genLinkBtn = elem
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if genLinkBtn != nil {
|
||||
break // 找到了,退出重试循环
|
||||
}
|
||||
|
||||
// 没找到,等待后重试
|
||||
if attempt < maxRetries {
|
||||
c.LogInfo(fmt.Sprintf("未找到genLink按钮,%d毫秒后重试...", retryDelay))
|
||||
c.SleepMs(retryDelay)
|
||||
}
|
||||
}
|
||||
|
||||
if genLinkBtn == nil {
|
||||
return "", fmt.Errorf("在shareContainer容器内未找到包含'genLink' class的button")
|
||||
c.Screenshot("genlink_button_not_found")
|
||||
return "", fmt.Errorf("经过 %d 次重试仍未在shareContainer容器内找到包含'genLink' class的button", maxRetries)
|
||||
}
|
||||
|
||||
// 滚动到按钮位置
|
||||
|
|
|
|||
|
|
@ -196,10 +196,30 @@ type (
|
|||
}
|
||||
|
||||
ProductCollectRequest struct {
|
||||
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
|
||||
Keywords []string `json:"keywords" validate:"required" zh:"关键词"`
|
||||
Platform []int `json:"platform" validate:"required" zh:"平台"`
|
||||
Question string `json:"question" validate:"required" zh:"问题"`
|
||||
ProductId int32 `json:"product_id" validate:"required" zh:"项目Id"`
|
||||
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
|
||||
Keywords []string `json:"keywords" validate:"required" zh:"关键词"`
|
||||
PlatformIndex []string `json:"platform_index" validate:"required" zh:"平台"`
|
||||
Question string `json:"question" validate:"required" zh:"问题"`
|
||||
ProductId int32 `json:"product_id" validate:"required" zh:"项目Id"`
|
||||
}
|
||||
|
||||
CollectListRequest struct {
|
||||
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
|
||||
ProductId int32 `json:"product_id" validate:"required" zh:"项目Id"`
|
||||
Page int `json:"page" zh:"页码"`
|
||||
Limit int `json:"limit" zh:"每页数量"`
|
||||
}
|
||||
|
||||
// PageRequest 分页请求
|
||||
PageRequest struct {
|
||||
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
|
||||
Page int `json:"page" zh:"页码"`
|
||||
Limit int `json:"limit" zh:"每页数量"`
|
||||
}
|
||||
|
||||
// ReqPageBo 分页参数
|
||||
ReqPageBo struct {
|
||||
Page int `json:"page"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ type AppModule struct {
|
|||
publishService *service.PublishService
|
||||
productService *service.ProductService
|
||||
productSourceService *service.ProductSourceService
|
||||
collectService *service.CollectService
|
||||
}
|
||||
|
||||
func NewAppModule(
|
||||
|
|
@ -24,6 +25,7 @@ func NewAppModule(
|
|||
publishService *service.PublishService,
|
||||
productService *service.ProductService,
|
||||
productSourceService *service.ProductSourceService,
|
||||
collectService *service.CollectService,
|
||||
) *AppModule {
|
||||
return &AppModule{
|
||||
cfg: cfg,
|
||||
|
|
@ -32,6 +34,7 @@ func NewAppModule(
|
|||
publishService: publishService,
|
||||
productService: productService,
|
||||
productSourceService: productSourceService,
|
||||
collectService: collectService,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -59,7 +62,6 @@ func (m *AppModule) Register(router fiber.Router) {
|
|||
router.Post("/product/detail", vali(m.productService.Detail, &entitys.ProductDetailRequest{}))
|
||||
router.Post("/product/update", vali(m.productService.Update, &entitys.ProductUpdateRequest{}))
|
||||
router.Post("/product/del", vali(m.productService.Del, &entitys.ProductDelRequest{}))
|
||||
router.Post("/product/collect", vali(m.productService.Collect, &entitys.ProductCollectRequest{}))
|
||||
router.Post("/img/upload", m.productService.ImgUpload)
|
||||
|
||||
router.Post("/plat/list", vali(m.appService.PlatList, &entitys.PlatListRequest{}))
|
||||
|
|
@ -70,4 +72,8 @@ func (m *AppModule) Register(router fiber.Router) {
|
|||
router.Post("/product/source/update", vali(m.productSourceService.Update, &entitys.ProductSourceUpdateRequest{}))
|
||||
router.Post("/product/source/del", vali(m.productSourceService.Del, &entitys.ProductSourceDelRequest{}))
|
||||
router.Post("/product/source/publish", vali(m.productSourceService.Publish, &entitys.ProductPublishRequest{}))
|
||||
|
||||
router.Post("/collect/create", vali(m.collectService.Collect, &entitys.ProductCollectRequest{}))
|
||||
router.Get("/collect/platforms", m.collectService.GetCollectPlatForms)
|
||||
router.Post("/collect/list", vali(m.collectService.CollectList, &entitys.CollectListRequest{}))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,204 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geo/internal/biz"
|
||||
"geo/internal/collect"
|
||||
"geo/internal/config"
|
||||
"geo/internal/data/impl"
|
||||
"geo/internal/data/model"
|
||||
"geo/internal/entitys"
|
||||
"geo/tmpl/dataTemp"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// CollectService 收集服务
|
||||
type CollectService struct {
|
||||
cfg *config.Config
|
||||
collectBiz *biz.CollectBiz
|
||||
collect *impl.CollectImpl
|
||||
collectTask *impl.CollectTaskImpl
|
||||
authBiz *biz.AuthBiz
|
||||
}
|
||||
|
||||
// NewCollectService 创建收集服务
|
||||
func NewCollectService(
|
||||
cfg *config.Config,
|
||||
collectBiz *biz.CollectBiz,
|
||||
collect *impl.CollectImpl,
|
||||
collectTask *impl.CollectTaskImpl,
|
||||
authBiz *biz.AuthBiz,
|
||||
) *CollectService {
|
||||
return &CollectService{
|
||||
cfg: cfg,
|
||||
collectBiz: collectBiz,
|
||||
collect: collect,
|
||||
collectTask: collectTask,
|
||||
authBiz: authBiz,
|
||||
}
|
||||
}
|
||||
|
||||
// CollectList 获取收集列表及对应的任务详情
|
||||
func (c *CollectService) CollectList(ctx *fiber.Ctx, req *entitys.CollectListRequest) error {
|
||||
_, err := c.authBiz.ValidateAccessToken(ctx.UserContext(), req.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 构建查询条件
|
||||
cond := builder.NewCond()
|
||||
if req.ProductId > 0 {
|
||||
cond = cond.And(builder.Eq{"product_id": req.ProductId})
|
||||
}
|
||||
|
||||
// 查询 collect 列表
|
||||
var collects []model.Collect
|
||||
pageBo, err := c.collect.GetListToStruct(ctx.UserContext(), &cond, &dataTemp.ReqPageBo{Page: req.Page, Limit: req.Limit}, &collects, "created_at DESC")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 提取所有的 collect_code
|
||||
collectCodes := make([]string, 0, len(collects))
|
||||
for _, collect := range collects {
|
||||
collectCodes = append(collectCodes, collect.CollectCode)
|
||||
}
|
||||
|
||||
// 批量查询所有相关的 collect_task
|
||||
var tasks []model.CollectTask
|
||||
if len(collectCodes) > 0 {
|
||||
taskCond := builder.NewCond().And(builder.In("collect_code", collectCodes))
|
||||
_, err := c.collectTask.GetListToStruct(ctx.UserContext(), &taskCond, nil, &tasks, "created_at ASC")
|
||||
if err != nil {
|
||||
log.Printf("查询 collect_task 失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 按 collect_code 分组 tasks
|
||||
tasksMap := make(map[string][]model.CollectTask)
|
||||
for _, task := range tasks {
|
||||
tasksMap[task.CollectCode] = append(tasksMap[task.CollectCode], task)
|
||||
}
|
||||
|
||||
// 组装返回数据
|
||||
result := make([]*CollectWithTasks, 0, len(collects))
|
||||
for _, collect := range collects {
|
||||
cwt := &CollectWithTasks{
|
||||
Collect: collect,
|
||||
Tasks: tasksMap[collect.CollectCode],
|
||||
}
|
||||
result = append(result, cwt)
|
||||
}
|
||||
|
||||
return ctx.JSON(fiber.Map{
|
||||
"list": result,
|
||||
"total": pageBo.Total,
|
||||
"page": req.Page,
|
||||
"pageSize": req.Limit,
|
||||
})
|
||||
}
|
||||
|
||||
// CollectWithTasks 收集记录及其任务列表
|
||||
type CollectWithTasks struct {
|
||||
model.Collect
|
||||
Tasks []model.CollectTask `json:"tasks"`
|
||||
}
|
||||
|
||||
func (c *CollectService) GetCollectPlatForms(ctx *fiber.Ctx) error {
|
||||
var list = make([]map[string]string, 0, len(collect.CollectorMap))
|
||||
for _, v := range collect.CollectorMap {
|
||||
list = append(list, map[string]string{
|
||||
"name": v.Name,
|
||||
"index": v.Platform,
|
||||
"icon": v.Icon,
|
||||
})
|
||||
}
|
||||
return ctx.JSON(list)
|
||||
}
|
||||
|
||||
// Collect 创建收集任务
|
||||
func (c *CollectService) Collect(ctx *fiber.Ctx, req *entitys.ProductCollectRequest) error {
|
||||
_, err := c.authBiz.ValidateAccessToken(ctx.UserContext(), req.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
collectCode := fmt.Sprintf("C%d_%d", req.ProductId, time.Now().UnixNano())
|
||||
collectData := &model.Collect{
|
||||
CollectCode: collectCode,
|
||||
ProductID: req.ProductId,
|
||||
Keywords: strings.Join(req.Keywords, ","),
|
||||
Platform: strings.Join(req.PlatformIndex, ","),
|
||||
Question: req.Question,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
err = c.collect.Add(ctx.UserContext(), collectData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go c.doCollectAsync(collectCode, req.PlatformIndex, req.Question)
|
||||
|
||||
return ctx.JSON(fiber.Map{"message": "收录生成中"})
|
||||
}
|
||||
|
||||
// doCollectAsync 异步执行收集任务
|
||||
func (c *CollectService) doCollectAsync(collectCode string, platforms []string, question string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*240)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
tasks := make([]*model.CollectTask, 0, len(platforms))
|
||||
|
||||
for _, platIndex := range platforms {
|
||||
wg.Add(1)
|
||||
|
||||
go func(platIndex string) {
|
||||
defer wg.Done()
|
||||
|
||||
platformName, exist := collect.CollectorMap[platIndex]
|
||||
if !exist {
|
||||
log.Printf("未知的平台索引: %d", platIndex)
|
||||
return
|
||||
}
|
||||
|
||||
requestID := fmt.Sprintf("%s_%s", collectCode, platIndex)
|
||||
result, err := c.collectBiz.AskAIQuestion(platIndex, requestID, question, true)
|
||||
if err != nil {
|
||||
log.Printf("平台 %s 收集失败: %v", platformName, err)
|
||||
return
|
||||
}
|
||||
|
||||
task := &model.CollectTask{
|
||||
CollectCode: collectCode,
|
||||
AiPlatformIndex: platIndex,
|
||||
ContentHTML: result.Answer,
|
||||
ShareURL: result.ShareLink,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
Status: 1,
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
tasks = append(tasks, task)
|
||||
mu.Unlock()
|
||||
}(platIndex)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if len(tasks) > 0 {
|
||||
if err := c.collectTask.Add(ctx, tasks); err != nil {
|
||||
log.Printf("保存收集任务失败: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,8 +1,6 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geo/internal/ai_tool"
|
||||
"geo/internal/biz"
|
||||
"geo/internal/config"
|
||||
|
|
@ -13,14 +11,8 @@ import (
|
|||
"geo/tmpl/dataTemp"
|
||||
"geo/tmpl/errcode"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-viper/mapstructure/v2"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
|
@ -33,8 +25,6 @@ type ProductService struct {
|
|||
authBiz *biz.AuthBiz
|
||||
productBiz *biz.ProductBiz
|
||||
aiBiz *biz.AiBiz
|
||||
collect *impl.CollectImpl
|
||||
collectTask *impl.CollectTaskImpl
|
||||
}
|
||||
|
||||
func NewProductService(
|
||||
|
|
@ -43,8 +33,6 @@ func NewProductService(
|
|||
authBiz *biz.AuthBiz,
|
||||
productBiz *biz.ProductBiz,
|
||||
aiBiz *biz.AiBiz,
|
||||
collect *impl.CollectImpl,
|
||||
collectTask *impl.CollectTaskImpl,
|
||||
) *ProductService {
|
||||
return &ProductService{
|
||||
cfg: cfg,
|
||||
|
|
@ -52,8 +40,6 @@ func NewProductService(
|
|||
authBiz: authBiz,
|
||||
productBiz: productBiz,
|
||||
aiBiz: aiBiz,
|
||||
collect: collect,
|
||||
collectTask: collectTask,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -243,357 +229,3 @@ func (p *ProductService) CreateProductInfoByDocx(c *fiber.Ctx) error {
|
|||
|
||||
return pkg.HandleResponse(c, productInfo)
|
||||
}
|
||||
|
||||
type CollectInfo struct {
|
||||
AIPlatformIndex string `json:"ai_platform_index"`
|
||||
ContentHtml string `json:"content_html"`
|
||||
ShareUrl string `json:"share_url"`
|
||||
Source []Source `json:"source"`
|
||||
}
|
||||
|
||||
type Source struct {
|
||||
Title string `json:"name"`
|
||||
Url string `json:"url"`
|
||||
PlatformName string `json:"platform"`
|
||||
PlatformIcon string `json:"Platform_icon"`
|
||||
}
|
||||
|
||||
func (p *ProductService) Collect(c *fiber.Ctx, req *entitys.ProductCollectRequest) error {
|
||||
log.Printf("[DEBUG] ========== 请求开始 ==========")
|
||||
log.Printf("[DEBUG] 请求时间: %s", time.Now().Format("2006-01-02 15:04:05.000"))
|
||||
log.Printf("[Collect] 开始处理收集请求, ProductID: %d, Platforms: %v, Keywords: %v",
|
||||
req.ProductId, req.Platform, req.Keywords)
|
||||
|
||||
_, err := p.authBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
|
||||
if err != nil {
|
||||
log.Printf("[Collect] 验证token失败, ProductID: %d, Error: %v", req.ProductId, err)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = p.productBiz.GetProduct(c.UserContext(), req.ProductId)
|
||||
if err != nil {
|
||||
log.Printf("[Collect] 获取产品信息失败, ProductID: %d, Error: %v", req.ProductId, err)
|
||||
return err
|
||||
}
|
||||
|
||||
platformStr := make([]string, len(req.Platform))
|
||||
for i, s := range req.Platform {
|
||||
platformStr[i] = strconv.Itoa(s)
|
||||
}
|
||||
|
||||
collectCode := fmt.Sprintf("C%d_%d", req.ProductId, time.Now().UnixNano())
|
||||
collectData := &model.Collect{
|
||||
CollectCode: collectCode,
|
||||
ProductID: req.ProductId,
|
||||
Keywords: strings.Join(req.Keywords, ","),
|
||||
Platform: strings.Join(platformStr, ","),
|
||||
Question: req.Question,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
log.Printf("[Collect] 创建收集记录, CollectCode: %s, ProductID: %d", collectCode, req.ProductId)
|
||||
|
||||
err = p.collect.Add(c.UserContext(), collectData)
|
||||
if err != nil {
|
||||
log.Printf("[Collect] 保存收集记录失败, CollectCode: %s, Error: %v", collectCode, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("[Collect] ✅ 启动异步收集任务, CollectCode: %s, Platforms: %v", collectCode, req.Platform)
|
||||
|
||||
go func() {
|
||||
// 记录 goroutine 启动时间
|
||||
startTime := time.Now()
|
||||
log.Printf("[Goroutine] 异步任务启动, CollectCode: %s, 启动时间: %s", collectCode, startTime.Format("15:04:05.000"))
|
||||
|
||||
// 使用独立 context,避免请求结束后任务被取消
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*240)
|
||||
|
||||
// 监控 context 取消
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
log.Printf("[Goroutine] ❌ Context被取消! CollectCode: %s, 原因: %v, 耗时: %v",
|
||||
collectCode, ctx.Err(), time.Since(startTime))
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[Goroutine] ❌ PANIC: %v\nStack: %s", r, debug.Stack())
|
||||
}
|
||||
log.Printf("[Goroutine] 异步任务结束, CollectCode: %s, 总耗时: %v", collectCode, time.Since(startTime))
|
||||
cancel()
|
||||
log.Printf("[Goroutine] 已调用 cancel(), CollectCode: %s", collectCode)
|
||||
}()
|
||||
|
||||
log.Printf("[Goroutine] 准备调用 doCollect, CollectCode: %s", collectCode)
|
||||
p.doCollect(ctx, collectData, req.Platform)
|
||||
log.Printf("[Goroutine] doCollect 已返回, CollectCode: %s", collectCode)
|
||||
}()
|
||||
|
||||
log.Printf("[DEBUG] ========== 请求返回 ==========")
|
||||
return pkg.HandleResponse(c, "收录生成中")
|
||||
}
|
||||
|
||||
func (p *ProductService) doCollect(ctx context.Context, collectData *model.Collect, platforms []int) {
|
||||
collectCode := collectData.CollectCode
|
||||
startTime := time.Now()
|
||||
|
||||
log.Printf("[doCollect] ========== 开始执行 ==========")
|
||||
log.Printf("[doCollect] CollectCode: %s, Platforms: %v", collectCode, platforms)
|
||||
log.Printf("[doCollect] Context状态: %v, 超时时间: %v", ctx.Err(), time.Second*240)
|
||||
|
||||
// 监控 context
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
log.Printf("[doCollect] ⚠️ 检测到Context取消! CollectCode: %s, 原因: %v, 已执行时间: %v",
|
||||
collectCode, ctx.Err(), time.Since(startTime))
|
||||
}()
|
||||
|
||||
collectClient := ai_tool.NewCollect(p.cfg.Collect.ApiKey)
|
||||
log.Printf("[doCollect] 已创建 collectClient")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
resCh := make(chan *model.CollectTask, len(platforms))
|
||||
log.Printf("[doCollect] 创建 channel, 容量: %d", len(platforms))
|
||||
|
||||
// 启动监控 goroutine
|
||||
monitorStart := time.Now()
|
||||
|
||||
// 启动所有平台的任务
|
||||
log.Printf("[doCollect] 启动 %d 个平台任务", len(platforms))
|
||||
for i, plat := range platforms {
|
||||
log.Printf("[doCollect] 启动任务 #%d, Platform: %d", i+1, plat)
|
||||
wg.Add(1)
|
||||
go p.processPlatform(ctx, &wg, collectClient, collectData, plat, resCh, i+1)
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("[Monitor] 监控goroutine启动, CollectCode: %s", collectCode)
|
||||
wg.Wait()
|
||||
log.Printf("[Monitor] ✅ 所有任务完成, 准备关闭channel, 等待时间: %v", time.Since(monitorStart))
|
||||
close(resCh)
|
||||
log.Printf("[Monitor] Channel已关闭")
|
||||
}()
|
||||
|
||||
// 收集结果 - 添加超时保护
|
||||
log.Printf("[doCollect] 开始等待结果...")
|
||||
var datas []*model.CollectTask
|
||||
taskCount := 0
|
||||
|
||||
// 设置一个最大等待时间
|
||||
waitTimeout := time.After(250 * time.Second)
|
||||
|
||||
for {
|
||||
select {
|
||||
case task, ok := <-resCh:
|
||||
if !ok {
|
||||
log.Printf("[doCollect] Channel已关闭, 收集到 %d 条结果", len(datas))
|
||||
goto SAVE
|
||||
}
|
||||
datas = append(datas, task)
|
||||
taskCount++
|
||||
log.Printf("[doCollect] ✅ 收到结果 #%d, Platform: %d, RequestID: %s, ScriptTime: %d",
|
||||
taskCount, task.Platform, task.RequestID, task.ScriptTime)
|
||||
|
||||
case <-waitTimeout:
|
||||
log.Printf("[doCollect] ⚠️ 等待超时 250秒, 强制退出, 已收集: %d/%d", taskCount, len(platforms))
|
||||
goto SAVE
|
||||
|
||||
case <-ctx.Done():
|
||||
log.Printf("[doCollect] ❌ Context取消, 强制退出, 已收集: %d/%d, 原因: %v",
|
||||
taskCount, len(platforms), ctx.Err())
|
||||
goto SAVE
|
||||
}
|
||||
}
|
||||
|
||||
SAVE:
|
||||
log.Printf("[doCollect] 收集完成, 共 %d 条结果", len(datas))
|
||||
|
||||
// 保存结果
|
||||
if len(datas) > 0 {
|
||||
log.Printf("[doCollect] 开始保存到数据库, 数量: %d", len(datas))
|
||||
saveStart := time.Now()
|
||||
if err := p.collectTask.Add(ctx, datas); err != nil {
|
||||
log.Printf("[doCollect] ❌ 保存失败: %v", err)
|
||||
} else {
|
||||
log.Printf("[doCollect] ✅ 保存成功, 耗时: %v", time.Since(saveStart))
|
||||
}
|
||||
} else {
|
||||
log.Printf("[doCollect] ⚠️ 没有结果需要保存")
|
||||
}
|
||||
|
||||
elapsed := time.Since(startTime)
|
||||
log.Printf("[doCollect] ========== 结束执行, 总耗时: %v ==========", elapsed)
|
||||
}
|
||||
|
||||
func (p *ProductService) processPlatform(ctx context.Context, wg *sync.WaitGroup,
|
||||
collectClient *ai_tool.Collect, collectData *model.Collect, plat int,
|
||||
resCh chan<- *model.CollectTask, taskNum int) {
|
||||
|
||||
collectCode := collectData.CollectCode
|
||||
startTime := time.Now()
|
||||
|
||||
log.Printf("[Platform #%d] ========== 开始 ==========", taskNum)
|
||||
log.Printf("[Platform #%d] CollectCode: %s, Platform: %d", taskNum, collectCode, plat)
|
||||
|
||||
// 确保 wg.Done() 一定会被调用
|
||||
defer func() {
|
||||
log.Printf("[Platform #%d] 准备调用 wg.Done(), 已执行时间: %v", taskNum, time.Since(startTime))
|
||||
wg.Done()
|
||||
log.Printf("[Platform #%d] 已调用 wg.Done()", taskNum)
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[Platform #%d] ❌ PANIC: %v\nStack: %s", taskNum, r, debug.Stack())
|
||||
}
|
||||
log.Printf("[Platform #%d] ========== 结束, 耗时: %v ==========", taskNum, time.Since(startTime))
|
||||
}()
|
||||
|
||||
// 检查 context 是否已取消
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("[Platform #%d] ❌ Context已取消, 退出执行, 原因: %v", taskNum, ctx.Err())
|
||||
return
|
||||
default:
|
||||
log.Printf("[Platform #%d] Context正常", taskNum)
|
||||
}
|
||||
|
||||
// 创建任务
|
||||
request := ai_tool.CreateReq{
|
||||
Keywords: collectData.Keywords,
|
||||
Question: collectData.Question,
|
||||
Platform: plat,
|
||||
ThirdID: fmt.Sprintf("%s_%d", collectData.CollectCode, plat),
|
||||
}
|
||||
|
||||
log.Printf("[Platform #%d] 调用 Create API, Request: %+v", taskNum, request)
|
||||
|
||||
createStart := time.Now()
|
||||
res, err := collectClient.Create(ctx, &request)
|
||||
createElapsed := time.Since(createStart)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[Platform #%d] ❌ Create失败, 耗时: %v, Error: %v", taskNum, createElapsed, err)
|
||||
return
|
||||
}
|
||||
if res.Code != 1 {
|
||||
log.Printf("[Platform #%d] ❌ Create返回错误码, 耗时: %v, Code: %d, Message: %s",
|
||||
taskNum, createElapsed, res.Code, res.Msg)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[Platform #%d] ✅ Create成功, 耗时: %v, RequestID: %s",
|
||||
taskNum, createElapsed, res.Data.RequestId)
|
||||
|
||||
// 轮询任务状态
|
||||
log.Printf("[Platform #%d] 开始轮询, RequestID: %s", taskNum, res.Data.RequestId)
|
||||
|
||||
pollStart := time.Now()
|
||||
task := p.pollTaskStatus(ctx, collectClient, res.Data.RequestId, collectData, plat, taskNum)
|
||||
pollElapsed := time.Since(pollStart)
|
||||
|
||||
if task != nil {
|
||||
log.Printf("[Platform #%d] ✅ 轮询成功, 耗时: %v, ScriptTime: %d",
|
||||
taskNum, pollElapsed, task.ScriptTime)
|
||||
|
||||
// 发送结果到 channel
|
||||
log.Printf("[Platform #%d] 准备发送结果到channel", taskNum)
|
||||
select {
|
||||
case resCh <- task:
|
||||
log.Printf("[Platform #%d] ✅ 结果已发送到channel", taskNum)
|
||||
case <-ctx.Done():
|
||||
log.Printf("[Platform #%d] ⚠️ Context取消, 放弃发送结果", taskNum)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Platform #%d] ❌ 轮询失败, 耗时: %v, 未获取到结果", taskNum, pollElapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProductService) pollTaskStatus(ctx context.Context, collectClient *ai_tool.Collect,
|
||||
requestID string, collectData *model.Collect, plat int, taskNum int) *model.CollectTask {
|
||||
|
||||
collectCode := collectData.CollectCode
|
||||
startTime := time.Now()
|
||||
|
||||
log.Printf("[Poll #%d] ========== 开始轮询 ==========", taskNum)
|
||||
log.Printf("[Poll #%d] CollectCode: %s, Platform: %d, RequestID: %s",
|
||||
taskNum, collectCode, plat, requestID)
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
errCount := 0
|
||||
const maxErrors = 5
|
||||
pollCount := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("[Poll #%d] ❌ Context取消, 停止轮询, 已轮询%d次, 耗时: %v, 原因: %v",
|
||||
taskNum, pollCount, time.Since(startTime), ctx.Err())
|
||||
return nil
|
||||
|
||||
case <-ticker.C:
|
||||
pollCount++
|
||||
log.Printf("[Poll #%d] 第 %d 次轮询, 已耗时: %v", taskNum, pollCount, time.Since(startTime))
|
||||
|
||||
checkStart := time.Now()
|
||||
checkRes, err := collectClient.CheckTask(ctx, requestID)
|
||||
checkElapsed := time.Since(checkStart)
|
||||
|
||||
if err != nil {
|
||||
errCount++
|
||||
log.Printf("[Poll #%d] ❌ 轮询失败(第%d次错误), 耗时: %v, Error: %v, 累计错误: %d/%d",
|
||||
taskNum, pollCount, checkElapsed, err, errCount, maxErrors)
|
||||
if errCount >= maxErrors {
|
||||
log.Printf("[Poll #%d] 达到最大错误次数, 停止轮询", taskNum)
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("[Poll #%d] ✅ 轮询成功, 耗时: %v, Code: %d, Status: %d, ScriptTime: %d, ShouluDate: %s",
|
||||
taskNum, checkElapsed, checkRes.Code, checkRes.Data.Status,
|
||||
checkRes.Data.ScriptTime, checkRes.Data.ShouluDate)
|
||||
|
||||
if checkRes.Code != 1 {
|
||||
log.Printf("[Poll #%d] ❌ 返回错误码: %d", taskNum, checkRes.Code)
|
||||
return nil
|
||||
}
|
||||
// 判断任务是否完成
|
||||
// 根据你的业务逻辑调整判断条件
|
||||
isCompleted := false
|
||||
completeReason := ""
|
||||
|
||||
if checkRes.Data.Status != 0 { // 假设 2 表示完成
|
||||
isCompleted = true
|
||||
completeReason = fmt.Sprintf("chekcStatus=%d", checkRes.Data.Status)
|
||||
}
|
||||
|
||||
if isCompleted {
|
||||
log.Printf("[Poll #%d] 🎉 任务完成! 原因: %s, 总轮询次数: %d, 总耗时: %v",
|
||||
taskNum, completeReason, pollCount, time.Since(startTime))
|
||||
|
||||
return &model.CollectTask{
|
||||
RequestID: checkRes.Data.RequestId,
|
||||
CollectCode: collectData.CollectCode,
|
||||
ScriptTime: int32(checkRes.Data.ScriptTime),
|
||||
Platform: int32(checkRes.Data.Platform),
|
||||
CollectData: checkRes.Data.ShouluDate,
|
||||
ShareURL: checkRes.Data.ShareUrl,
|
||||
ImgURL: checkRes.Data.ImgUrl,
|
||||
PointKeyword: checkRes.Data.HitWord,
|
||||
Question: checkRes.Data.Question,
|
||||
Res: pkg.JsonStringIgonErr(checkRes),
|
||||
CreatedAt: time.Now(),
|
||||
Status: int32(checkRes.Data.Status),
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[Poll #%d] 任务未完成, 继续轮询, Status=%d, ScriptTime=%d, ShouluDate=%s",
|
||||
taskNum, checkRes.Data.Status, checkRes.Data.ScriptTime, checkRes.Data.ShouluDate)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,4 +10,5 @@ var ProviderSetAppService = wire.NewSet(
|
|||
NewLoginService,
|
||||
NewProductService,
|
||||
NewProductSourceService,
|
||||
NewCollectService,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue