package manager import ( "context" "fmt" "geo/internal/biz" "geo/internal/config" "geo/internal/entitys" "geo/internal/publisher" "geo/pkg" "geo/utils" "io" "log" "os" "path/filepath" "strings" "sync" "time" "github.com/google/uuid" ) const ( // 任务状态常量 StatusPending = 1 StatusProcessing = 2 StatusFailed = 3 StatusSuccess = 4 // 默认并发worker数量 DefaultWorkerNum = 4 MaxWorkerNum = 100 ) // PublishManager 发布管理器 type PublishManager struct { AutoStatus bool Conf *config.Config TokenID int running bool mu sync.RWMutex stopCh chan struct{} db *utils.Db stopOnce sync.Once stopOnceMu sync.Mutex publicBiz *biz.PublishBiz // 并发控制 workerNum int // 并发worker数量 workerWg sync.WaitGroup // 等待所有worker退出 } var ( publishManager *PublishManager once sync.Once ) // GetPublishManager 获取单例实例 func GetPublishManager(config *config.Config, db *utils.Db, publicBiz *biz.PublishBiz) (*PublishManager, error) { if config == nil || db == nil { return nil, fmt.Errorf("config和db参数不能为空") } once.Do(func() { publishManager = &PublishManager{ AutoStatus: false, Conf: config, stopCh: make(chan struct{}), db: db, publicBiz: publicBiz, } }) return publishManager, nil } // getTaskLogger 获取任务专属日志记录器 func (pm *PublishManager) getTaskLogger(requestID string) (*log.Logger, *os.File, error) { if requestID == "" { return nil, nil, fmt.Errorf("requestID不能为空") } logsDir := pm.Conf.Sys.LogsDir if logsDir == "" { logsDir = "./logs" } if err := os.MkdirAll(logsDir, 0755); err != nil { return nil, nil, fmt.Errorf("创建日志目录失败: %v", err) } logPath := filepath.Join(logsDir, fmt.Sprintf("%s.log", requestID)) logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) if err != nil { return nil, nil, fmt.Errorf("创建日志文件失败: %v", err) } multiWriter := io.MultiWriter(logFile, os.Stdout) taskLogger := log.New(multiWriter, "", log.LstdFlags|log.Lmicroseconds) taskLogger.Printf(strings.Repeat("=", 80)) taskLogger.Printf("任务开始 | RequestID: %s | 时间: %s", requestID, time.Now().Format("2006-01-02 15:04:05.000")) taskLogger.Printf(strings.Repeat("=", 80)) return taskLogger, logFile, nil } // Start 启动自动发布(支持并发worker数量) func (pm *PublishManager) Start(tokenID int, workerNum int) bool { pm.mu.Lock() defer pm.mu.Unlock() if pm.AutoStatus { log.Printf("自动发布服务已在运行中,tokenID=%d", tokenID) return false } if workerNum <= 0 { workerNum = DefaultWorkerNum } if workerNum > MaxWorkerNum { workerNum = MaxWorkerNum } pm.TokenID = tokenID pm.AutoStatus = true pm.workerNum = workerNum pm.stopCh = make(chan struct{}) // ✅ 修改:重置 sync.Once(关键修复) pm.stopOnceMu.Lock() pm.stopOnce = sync.Once{} pm.stopOnceMu.Unlock() for i := 0; i < workerNum; i++ { pm.workerWg.Add(1) go pm.workerLoop(i) } log.Printf("自动发布服务已启动,tokenID=%d,worker数量=%d", tokenID, workerNum) return true } func (pm *PublishManager) Stop() bool { pm.mu.Lock() if !pm.AutoStatus { pm.mu.Unlock() return false } pm.AutoStatus = false pm.mu.Unlock() // ✅ 修改:使用互斥锁保护 Once 调用 pm.stopOnceMu.Lock() pm.stopOnce.Do(func() { close(pm.stopCh) }) pm.stopOnceMu.Unlock() done := make(chan struct{}) go func() { pm.workerWg.Wait() close(done) }() select { case <-done: log.Println("所有worker已正常退出") case <-time.After(30 * time.Second): log.Println("等待worker退出超时,强制结束") } return true } // workerLoop 单个worker循环 func (pm *PublishManager) workerLoop(workerID int) { defer pm.workerWg.Done() log.Printf("[Worker-%d] 启动,tokenID=%d", workerID, pm.TokenID) for { select { case <-pm.stopCh: log.Printf("[Worker-%d] 收到停止信号,退出", workerID) return default: pm.executeOneTask(workerID, true) time.Sleep(2 * time.Second) } } } // executeOneTask 执行单个任务 func (pm *PublishManager) executeOneTask(workerID int, headless bool) { task, err := pm.acquireTask() if err != nil { log.Printf("[Worker-%d] 获取任务失败: %v", workerID, err) time.Sleep(5 * time.Second) return } if task == nil { time.Sleep(30 * time.Second) return } log.Printf("[Worker-%d] 开始处理任务 requestID=%s", workerID, task.RequestID) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // 使用 channel 接收结果,避免 goroutine 泄漏 resultChan := make(chan *SingleResult, 1) go func() { resultChan <- pm.processTask(ctx, task, headless) }() // 监听 ctx 超时 var result *SingleResult select { case result = <-resultChan: // 任务正常完成 case <-ctx.Done(): // 超时或取消 log.Printf("[Worker-%d] 任务超时或被取消 requestID=%s, err=%v", workerID, task.RequestID, ctx.Err()) result = &SingleResult{ Success: false, Message: fmt.Sprintf("任务执行超时: %v", ctx.Err()), } // 注意:这里不能等待 goroutine 结束,因为 processTask 可能卡住 // cancel() 会在 defer 中调用,但 processTask 内部需要响应 ctx.Done() } if result == nil { log.Printf("[Worker-%d] 任务返回空结果", workerID) } else if result.Success { log.Printf("[Worker-%d] 任务成功: %s", workerID, result.Message) } else { log.Printf("[Worker-%d] 任务失败: %s", workerID, result.Message) } } // acquireTask 原子获取一个待发布任务(使用 GORM 事务 + FOR UPDATE SKIP LOCKED) func (pm *PublishManager) acquireTask() (*entitys.PublishTaskDetail, error) { currentTime := time.Now().Format("2006-01-02 15:04:05") uid := uuid.NewString()[:16] // 使用子查询先找出要更新的 request_id updateSQL := ` UPDATE publish p SET p.status = ?,uid=? WHERE p.request_id = ( SELECT request_id FROM ( SELECT p2.request_id FROM publish p2 INNER JOIN plat pl ON p2.plat_index COLLATE utf8mb4_unicode_ci = pl.index WHERE p2.token_id = ? AND p2.status = ? AND p2.publish_time <= ? AND pl.status = 1 ORDER BY p2.publish_time ASC LIMIT 1 ) AS tmp ) AND p.status = ? ` result := pm.db.Client.Exec(updateSQL, StatusProcessing, uid, pm.TokenID, StatusPending, currentTime, StatusPending) if result.Error != nil { return nil, result.Error } if result.RowsAffected == 0 { return nil, nil } // 查询刚更新的任务 var task entitys.PublishTaskDetail querySQL := ` SELECT p.id, p.request_id, p.plat_index, p.title, p.tag, p.user_index, p.url, p.img, p.publish_time, p.status, pl.index as plat_index_value, pl.status as plat_status, pl.login_url, pl.edit_url, pl.logined_url, pl.desc FROM publish p INNER JOIN plat pl ON p.plat_index COLLATE utf8mb4_unicode_ci = pl.index AND pl.status = 1 WHERE uid=? AND p.token_id = ? ORDER BY p.publish_time DESC LIMIT 1 ` err := pm.db.Client.Raw(querySQL, uid, pm.TokenID).Scan(&task).Error if err != nil { return nil, err } return &task, nil } // processTask 处理单个任务 func (pm *PublishManager) processTask(ctx context.Context, publishData *entitys.PublishTaskDetail, headless bool) *SingleResult { if publishData == nil || publishData.RequestID == "" { return &SingleResult{Success: false, Message: "无效的任务数据"} } // 检查 context 是否已取消 select { case <-ctx.Done(): return &SingleResult{Success: false, Message: "任务被取消: " + ctx.Err().Error(), RequestId: publishData.RequestID} default: } taskLogger, logFile, err := pm.getTaskLogger(publishData.RequestID) if err != nil { log.Printf("[任务 %s] 创建日志文件失败: %v,使用全局日志", publishData.RequestID, err) taskLogger = log.Default() } if logFile != nil { defer logFile.Close() } defer func() { if r := recover(); r != nil { errMsg := fmt.Sprintf("任务执行发生panic: %v", r) taskLogger.Printf("❌ CRITICAL: %s", errMsg) pm.updatePublishStatus(publishData, StatusFailed, errMsg) } }() taskLogger.Printf("[任务 %s] 开始处理(headless=%v)", publishData.RequestID, headless) params, sourceUrl := pm.extractTaskParams(publishData, taskLogger) if params == nil { pm.updatePublishStatus(publishData, StatusFailed, "提取任务参数失败") return &SingleResult{Success: false, Message: "提取任务参数失败", RequestId: publishData.RequestID} } params.Headless = headless publisherClass := GetPublisherClass(params.PlatIndex) if publisherClass == nil { errMsg := fmt.Sprintf("不支持的平台: %s", params.PlatIndex) taskLogger.Printf("[任务 %s] ❌ %s", publishData.RequestID, errMsg) pm.updatePublishStatus(publishData, StatusFailed, errMsg) return &SingleResult{Success: false, Message: errMsg, RequestId: publishData.RequestID} } docPath, imgPath, err := pm.downloadAndPrepareFiles(params.RequestID, sourceUrl, taskLogger, publisherClass) if err != nil { errMsg := fmt.Sprintf("准备文件失败: %v", err) taskLogger.Printf("[任务 %s] ❌ %s", publishData.RequestID, errMsg) pm.updatePublishStatus(publishData, StatusFailed, errMsg) return &SingleResult{Success: false, Message: errMsg, RequestId: publishData.RequestID} } defer pm.cleanupFiles(docPath, imgPath, taskLogger, publishData.RequestID) params.Content, err = pm.extractContent(docPath, publisherClass, taskLogger, publishData.RequestID) if err != nil { errMsg := fmt.Sprintf("提取文档内容失败: %v", err) taskLogger.Printf("[任务 %s] ❌ %s", publishData.RequestID, errMsg) pm.updatePublishStatus(publishData, StatusFailed, errMsg) return &SingleResult{Success: false, Message: errMsg, RequestId: publishData.RequestID} } params.ImagePath = imgPath params.SourcePath = docPath pub := publisherClass.InitMethod(ctx, params, pm.Conf, taskLogger) taskLogger.Printf("[任务 %s] 开始执行发布...", publishData.RequestID) success, message := pub.PublishNote() if success { taskLogger.Printf("[任务 %s] ✅ 发布成功: %s", publishData.RequestID, message) pm.updatePublishStatus(publishData, StatusSuccess, message) } else { taskLogger.Printf("[任务 %s] ❌ 发布失败: %s", publishData.RequestID, message) pm.updatePublishStatus(publishData, StatusFailed, message) } taskLogger.Printf(strings.Repeat("=", 80)) taskLogger.Printf("任务结束 | RequestID: %s | 结果: %v", publishData.RequestID, success) return &SingleResult{Success: success, Message: message, RequestId: publishData.RequestID} } // RetryTask 重试任务(非无头模式) func (pm *PublishManager) RetryTask(ctx context.Context, tokenId int, requestID string) *SingleResult { if requestID == "" { return &SingleResult{Success: false, Message: "requestID不能为空"} } publishData, err := pm.GetTaskByRequestID(requestID, tokenId) if err != nil || publishData == nil { return &SingleResult{Success: false, Message: "任务不存在"} } log.Printf("[重试] 开始重试任务 requestID=%s(非无头模式)", requestID) result := pm.processTask(ctx, publishData, false) if result == nil { result = &SingleResult{Success: false, Message: "系统故障", RequestId: requestID} } return result } // GetTaskByRequestID 根据RequestID获取任务 func (pm *PublishManager) GetTaskByRequestID(requestID string, tokenId int) (*entitys.PublishTaskDetail, error) { if requestID == "" { return nil, fmt.Errorf("requestID不能为空") } sql := ` SELECT p.id, p.request_id, p.token_id, p.plat_index, p.title, p.tag, p.user_index, p.url, p.img, p.publish_time, p.status, pl.index as plat_index_value, pl.status as plat_status, pl.login_url, pl.edit_url, pl.logined_url, pl.desc FROM publish p INNER JOIN plat pl ON p.plat_index COLLATE utf8mb4_unicode_ci = pl.index AND pl.status = 1 WHERE p.request_id = ? AND token_id=? ` var task entitys.PublishTaskDetail err := pm.db.GetOneToStruct(sql, &task, requestID, tokenId) if err != nil { return nil, err } if task.RequestID == "" { return nil, nil } return &task, nil } // GetStatus 获取状态 func (pm *PublishManager) GetStatus() map[string]interface{} { pm.mu.RLock() defer pm.mu.RUnlock() return map[string]interface{}{ "auto_status": pm.AutoStatus, "worker_num": pm.workerNum, "token_id": pm.TokenID, } } // extractTaskParams 提取任务参数 func (pm *PublishManager) extractTaskParams(publishData *entitys.PublishTaskDetail, taskLogger *log.Logger) (*publisher.TaskParams, *fileUrl) { taskLogger.Printf("[任务 %s] 任务详情 - 平台:%s,标题:%s,用户:%s", publishData.RequestID, publishData.PlatIndex, publishData.Title, publishData.UserIndex) tags := pkg.ParseTags(publishData.Tag) taskLogger.Printf("[任务 %s] 标签解析完成: %v", publishData.RequestID, tags) return &publisher.TaskParams{ RequestID: publishData.RequestID, PlatIndex: publishData.PlatIndex, Title: publishData.Title, TagRaw: publishData.Tag, UserIndex: publishData.UserIndex, Tags: tags, PublishData: publishData, }, &fileUrl{ url: publishData.URL, imgURL: publishData.Img, } } // downloadAndPrepareFiles 下载文档和图片 func (pm *PublishManager) downloadAndPrepareFiles(requestId string, params *fileUrl, taskLogger *log.Logger, publishClass *publisher.PublisherValue) (docPath, imgPath string, err error) { taskLogger.Printf("[任务 %s] 开始下载文档...", requestId) docPath, err = pkg.DownloadFile(params.url, pm.Conf.Sys.DocsDir, requestId) if err != nil { return "", "", fmt.Errorf("下载文档失败: %v", err) } taskLogger.Printf("[任务 %s] ✅ 文档下载成功: %s", requestId, docPath) taskLogger.Printf("[任务 %s] 开始下载图片...", requestId) imgPath, err = pkg.DownloadImage(params.imgURL, requestId, pm.Conf.Sys.UploadDir) if err != nil { if docPath != "" { pkg.DeleteFile(docPath) } return "", "", fmt.Errorf("下载图片失败: %v", err) } taskLogger.Printf("[任务 %s] ✅ 图片下载成功: %s", requestId, imgPath) if publishClass.Type == 1 && publishClass.WordContainImg { if err = pkg.CopyImageToDoc(docPath, imgPath); err != nil { return docPath, imgPath, fmt.Errorf("复制图片到文档失败: %v", err) } } return docPath, imgPath, nil } // cleanupFiles 清理临时文件 func (pm *PublishManager) cleanupFiles(docPath, imgPath string, taskLogger *log.Logger, requestID string) { if docPath != "" { pkg.DeleteFile(docPath) taskLogger.Printf("[任务 %s] 已删除文档文件: %s", requestID, docPath) } if imgPath != "" { pkg.DeleteFile(imgPath) taskLogger.Printf("[任务 %s] 已删除图片文件: %s", requestID, imgPath) } } // extractContent 提取文档内容 func (pm *PublishManager) extractContent(docPath string, publisherClass *publisher.PublisherValue, taskLogger *log.Logger, requestID string) (string, error) { if publisherClass.Type != 1 { return "", nil } taskLogger.Printf("[任务 %s] 开始提取文档内容...", requestID) content, err := pkg.ExtractWordContent(docPath, publisherClass.ContentFormat) if err != nil { return "", err } taskLogger.Printf("[任务 %s] ✅ 内容提取成功,长度: %d", requestID, len(content)) return content, nil } // updatePublishStatus 更新发布状态 func (pm *PublishManager) updatePublishStatus(publishData *entitys.PublishTaskDetail, status int, message string) error { if publishData.ID == 0 { return fmt.Errorf("id不能为空") } var err error if message != "" { _, err = pm.db.Execute("UPDATE publish SET status = ?, msg = ? WHERE id=?", status, message, publishData.ID) } else { _, err = pm.db.Execute("UPDATE publish SET status = ? WHERE id=?", status, publishData.ID) } if err != nil { log.Printf("更新发布状态失败: id=%s, status=%d, error=%v", publishData.ID, status, err) } _ = pm.publicBiz.Notify(&entitys.NotifyData{ TokenId: publishData.TokenID, Type: entitys.NotifyTypePublish, Status: status, Msg: message, }) return err } // GetPublisherClass 获取发布器类 func GetPublisherClass(platIndex string) *publisher.PublisherValue { if platIndex == "" { return nil } if publisherClass, exists := publisher.PublisherMap[platIndex]; exists { return publisherClass } log.Printf("未找到平台 %s 对应的发布器", platIndex) return nil } // SingleResult 单次任务结果 type SingleResult struct { Success bool `json:"success"` Message string `json:"message"` RequestId string `json:"request_id"` } // fileUrl 内部辅助结构体 type fileUrl struct { url string imgURL string }