geoGo/internal/manager/publish_manager.go

535 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package manager
import (
"fmt"
"geo/internal/config"
"geo/internal/entitys"
"geo/internal/publisher"
"geo/pkg"
"geo/utils"
"io"
"log"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
// Task status codes stored in the publish.status column.
const (
	StatusPending    = iota + 1 // 1: waiting to be picked up
	StatusProcessing            // 2: claimed by a worker
	StatusFailed                // 3: finished with an error
	StatusSuccess               // 4: finished successfully
)

// Limits on the number of concurrent workers.
const (
	// DefaultWorkerNum is used when the caller asks for a non-positive count.
	DefaultWorkerNum = 2
	// MaxWorkerNum is the upper bound applied to the requested count.
	MaxWorkerNum = 5
)
// PublishManager coordinates the automatic publishing workers: it starts and
// stops the worker goroutines, hands tasks to them and records task results.
type PublishManager struct {
	// AutoStatus is true while the automatic publishing service is running;
	// Start sets it and Stop clears it, both under mu.
	AutoStatus bool
	// Conf is the application configuration (log, document and upload directories
	// are read from Conf.Sys).
	Conf *config.Config
	// TokenID selects which publish rows the workers pick up.
	TokenID int
	// running is not referenced by the methods in this file.
	// NOTE(review): possibly dead — confirm before removing.
	running bool
	// mu guards AutoStatus, TokenID and workerNum in Start, Stop and GetStatus.
	mu sync.RWMutex
	// stopCh is closed to tell the workers to exit; Start replaces it with a
	// fresh channel on every start.
	stopCh chan struct{}
	// db is the database handle used for task queries and status updates.
	db *utils.Db
	// stopOnce is a one-shot guard related to shutting down stopCh.
	// NOTE(review): a sync.Once cannot be re-armed — verify restart behavior.
	stopOnce sync.Once
	// Concurrency control.
	workerNum int            // number of concurrent workers
	workerWg  sync.WaitGroup // waits for every worker to exit
}
var (
	// publishManager holds the package-wide PublishManager instance created by
	// GetPublishManager.
	publishManager *PublishManager
	// once ensures publishManager is initialized a single time.
	once sync.Once
)
// GetPublishManager returns the package-wide PublishManager, creating it on the
// first successful call.
//
// Both config and db must be non-nil, otherwise an error is returned. Only the
// arguments of the first successful call are used to build the instance; later
// calls return that same instance regardless of what they pass.
func GetPublishManager(config *config.Config, db *utils.Db) (*PublishManager, error) {
	if config != nil && db != nil {
		build := func() {
			publishManager = &PublishManager{
				AutoStatus: false,
				Conf:       config,
				stopCh:     make(chan struct{}),
				db:         db,
			}
		}
		once.Do(build)
		return publishManager, nil
	}
	return nil, fmt.Errorf("config和db参数不能为空")
}
// getTaskLogger creates a logger dedicated to a single task.
//
// The log file is <LogsDir>/tasks/<yyyy-mm-dd>/<requestID>.log, where LogsDir
// comes from the configuration and falls back to "./logs" when empty. The file
// is opened in append mode, and every line is written to both the file and
// standard output. A start banner is written before returning.
//
// The caller owns the returned *os.File and must close it. An error is returned
// when requestID is empty or when the directory or file cannot be created; the
// underlying error is wrapped so callers can inspect it with errors.Is/As.
func (pm *PublishManager) getTaskLogger(requestID string) (*log.Logger, *os.File, error) {
	if requestID == "" {
		return nil, nil, fmt.Errorf("requestID不能为空")
	}

	logsDir := pm.Conf.Sys.LogsDir
	if logsDir == "" {
		logsDir = "./logs"
	}

	// Logs are grouped into one directory per calendar day.
	dateDir := time.Now().Format("2006-01-02")
	taskLogDir := filepath.Join(logsDir, "tasks", dateDir)
	if err := os.MkdirAll(taskLogDir, 0755); err != nil {
		return nil, nil, fmt.Errorf("创建日志目录失败: %w", err)
	}

	logPath := filepath.Join(taskLogDir, fmt.Sprintf("%s.log", requestID))
	logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
	if err != nil {
		return nil, nil, fmt.Errorf("创建日志文件失败: %w", err)
	}

	multiWriter := io.MultiWriter(logFile, os.Stdout)
	taskLogger := log.New(multiWriter, "", log.LstdFlags|log.Lmicroseconds)

	// Print (not Printf) for the separator: it is data, not a format string.
	separator := strings.Repeat("=", 80)
	taskLogger.Print(separator)
	taskLogger.Printf("任务开始 | RequestID: %s | 时间: %s", requestID, time.Now().Format("2006-01-02 15:04:05.000"))
	taskLogger.Print(separator)
	return taskLogger, logFile, nil
}
// Start launches the automatic publishing service with a pool of workers.
//
// tokenID selects which tasks the workers process. workerNum is the requested
// number of workers: a non-positive value becomes DefaultWorkerNum and values
// above MaxWorkerNum are capped. It returns false, without changing anything,
// when the service is already running, and true once the workers are started.
func (pm *PublishManager) Start(tokenID int, workerNum int) bool {
	pm.mu.Lock()
	defer pm.mu.Unlock()

	if pm.AutoStatus {
		log.Printf("自动发布服务已在运行中tokenID=%d", tokenID)
		return false
	}

	// Normalize the requested pool size.
	workers := workerNum
	if workers <= 0 {
		workers = DefaultWorkerNum
	}
	if workers > MaxWorkerNum {
		workers = MaxWorkerNum
	}

	pm.TokenID = tokenID
	pm.AutoStatus = true
	pm.workerNum = workers
	pm.stopCh = make(chan struct{}) // fresh stop signal for this run

	pm.workerWg.Add(workers)
	for id := 0; id < workers; id++ {
		go pm.workerLoop(id)
	}

	log.Printf("自动发布服务已启动tokenID=%dworker数量=%d", tokenID, workers)
	return true
}
// Stop shuts down the automatic publishing service.
//
// It returns false when the service is not running. Otherwise it signals every
// worker to exit by closing the stop channel, waits up to 30 seconds for them
// to finish, and returns true (also when the wait times out).
func (pm *PublishManager) Stop() bool {
	pm.mu.Lock()
	if !pm.AutoStatus {
		pm.mu.Unlock()
		return false
	}
	pm.AutoStatus = false

	// Close the channel of the current run while still holding the lock. A
	// sync.Once is not suitable here: it fires only for the first run, so after
	// a restart the new channel would never be closed and the workers would
	// never stop. The AutoStatus check above already limits this to one close
	// per run; the receive probe additionally protects against closing twice.
	select {
	case <-pm.stopCh:
		// Already closed.
	default:
		close(pm.stopCh)
	}
	pm.mu.Unlock()

	done := make(chan struct{})
	go func() {
		pm.workerWg.Wait()
		close(done)
	}()

	timeout := time.NewTimer(30 * time.Second)
	defer timeout.Stop()

	select {
	case <-done:
		log.Println("所有worker已正常退出")
	case <-timeout.C:
		log.Println("等待worker退出超时强制结束")
	}
	return true
}
// workerLoop is the body of one worker goroutine: it repeatedly tries to
// execute a task in headless mode and pauses two seconds between attempts,
// until the stop channel is closed. It marks itself done on the manager's
// WaitGroup when it returns.
func (pm *PublishManager) workerLoop(workerID int) {
	defer pm.workerWg.Done()

	// Capture the stop channel of the run this worker belongs to. Start replaces
	// pm.stopCh on every start, so re-reading the field later could observe the
	// channel of a newer run.
	stop := pm.stopCh

	log.Printf("[Worker-%d] 启动tokenID=%d", workerID, pm.TokenID)
	for {
		select {
		case <-stop:
			log.Printf("[Worker-%d] 收到停止信号,退出", workerID)
			return
		default:
		}

		pm.executeOneTask(workerID, true)

		// Pause between attempts, but wake up immediately on a stop request.
		pause := time.NewTimer(2 * time.Second)
		select {
		case <-stop:
			pause.Stop()
			log.Printf("[Worker-%d] 收到停止信号,退出", workerID)
			return
		case <-pause.C:
		}
	}
}
// executeOneTask claims at most one pending task and processes it.
//
// When claiming fails it logs the error and sleeps 5 seconds; when there is no
// task to claim it sleeps 10 seconds. Otherwise the task is processed with the
// given headless setting and the outcome is logged.
func (pm *PublishManager) executeOneTask(workerID int, headless bool) {
	task, err := pm.acquireTask()
	if err != nil {
		log.Printf("[Worker-%d] 获取任务失败: %v", workerID, err)
		time.Sleep(5 * time.Second)
		return
	}
	if task == nil {
		time.Sleep(10 * time.Second)
		return
	}

	log.Printf("[Worker-%d] 开始处理任务 requestID=%s", workerID, task.RequestID)
	result := pm.processTask(task, headless)

	switch {
	case result == nil:
		// A nil result carries no message; dereferencing it here would panic
		// and take the whole worker goroutine (and process) down.
		log.Printf("[Worker-%d] 任务失败: 未返回处理结果 requestID=%s", workerID, task.RequestID)
	case result.Success:
		log.Printf("[Worker-%d] 任务成功: %s", workerID, result.Message)
	default:
		log.Printf("[Worker-%d] 任务失败: %s", workerID, result.Message)
	}
}
// acquireTask atomically claims one pending publish task (database transaction
// plus FOR UPDATE SKIP LOCKED).
//
// Inside a transaction it selects the oldest due task (publish_time not in the
// future) with status StatusPending for the manager's token, joined with its
// enabled platform row, and flips its status to StatusProcessing.
//
// It returns (nil, nil) when there is nothing to claim or the status update
// matched no row. Errors from the database are wrapped and returned. A panic
// raised while the transaction is open is rolled back and reported as an error
// instead of being silently turned into "no task".
func (pm *PublishManager) acquireTask() (task *entitys.PublishTaskDetail, err error) {
	currentTime := time.Now().Format("2006-01-02 15:04:05")

	// Open the transaction.
	tx := pm.db.Client.Begin()
	if tx.Error != nil {
		return nil, fmt.Errorf("开启事务失败: %w", tx.Error)
	}
	defer func() {
		if r := recover(); r != nil {
			tx.Rollback()
			task = nil
			err = fmt.Errorf("获取任务时发生panic: %v", r)
		}
	}()

	// NOTE(review): the lock clause applies to the joined plat row as well;
	// consider "FOR UPDATE OF p" if the database supports it — confirm.
	selectSQL := `
SELECT
p.request_id,
p.plat_index,
p.title,
p.tag,
p.user_index,
p.url,
p.img,
p.publish_time,
p.status,
pl.index as plat_index_value,
pl.status as plat_status,
pl.login_url,
pl.edit_url,
pl.logined_url,
pl.desc
FROM publish p
INNER JOIN plat pl ON p.plat_index COLLATE utf8mb4_unicode_ci = pl.index AND pl.status = 1
WHERE p.token_id = ? AND p.status = ? AND p.publish_time <= ?
ORDER BY p.publish_time ASC
LIMIT 1
FOR UPDATE SKIP LOCKED
`
	var row entitys.PublishTaskDetail
	if scanErr := tx.Raw(selectSQL, pm.TokenID, StatusPending, currentTime).Scan(&row).Error; scanErr != nil {
		tx.Rollback()
		return nil, fmt.Errorf("查询任务失败: %w", scanErr)
	}
	// An empty RequestID means the query matched no row.
	if row.RequestID == "" {
		tx.Rollback()
		return nil, nil
	}

	updateSQL := "UPDATE publish SET status = ? WHERE request_id = ? AND status = ?"
	updated := tx.Exec(updateSQL, StatusProcessing, row.RequestID, StatusPending)
	if updated.Error != nil {
		tx.Rollback()
		return nil, fmt.Errorf("更新任务状态失败: %w", updated.Error)
	}
	if updated.RowsAffected == 0 {
		tx.Rollback()
		return nil, nil
	}

	if commitErr := tx.Commit().Error; commitErr != nil {
		return nil, fmt.Errorf("提交事务失败: %w", commitErr)
	}
	return &row, nil
}
// processTask runs one publish task end to end: it sets up a per-task logger,
// extracts the task parameters, resolves the platform publisher, downloads the
// document and image, extracts the document content, performs the publish and
// stores the final status in the database.
//
// headless is forwarded to the publisher parameters. The returned result is
// never nil: every failure, including a recovered panic, is reported as a
// SingleResult with Success == false, and the task is marked StatusFailed.
func (pm *PublishManager) processTask(publishData *entitys.PublishTaskDetail, headless bool) (result *SingleResult) {
	if publishData == nil || publishData.RequestID == "" {
		return &SingleResult{Success: false, Message: "无效的任务数据"}
	}
	requestID := publishData.RequestID

	taskLogger, logFile, err := pm.getTaskLogger(requestID)
	if err != nil {
		// Fall back to the process-wide logger rather than failing the task.
		log.Printf("[任务 %s] 创建日志文件失败: %v使用全局日志", requestID, err)
		taskLogger = log.Default()
	}
	if logFile != nil {
		defer logFile.Close()
	}

	// Registered after the Close above, so it runs before the log file is
	// closed and can still write to it.
	defer func() {
		if r := recover(); r != nil {
			errMsg := fmt.Sprintf("任务执行发生panic: %v", r)
			taskLogger.Printf("❌ CRITICAL: %s", errMsg)
			pm.updatePublishStatus(requestID, StatusFailed, errMsg)
			result = &SingleResult{Success: false, Message: errMsg, RequestId: requestID}
		}
	}()

	// fail logs the reason, marks the task as failed and builds the result.
	fail := func(errMsg string) *SingleResult {
		taskLogger.Printf("[任务 %s] ❌ %s", requestID, errMsg)
		pm.updatePublishStatus(requestID, StatusFailed, errMsg)
		return &SingleResult{Success: false, Message: errMsg, RequestId: requestID}
	}

	taskLogger.Printf("[任务 %s] 开始处理headless=%v", requestID, headless)

	params, sourceUrl := pm.extractTaskParams(publishData, taskLogger)
	if params == nil {
		return fail("提取任务参数失败")
	}
	params.Headless = headless

	publisherClass := GetPublisherClass(params.PlatIndex)
	if publisherClass == nil {
		return fail(fmt.Sprintf("不支持的平台: %s", params.PlatIndex))
	}

	docPath, imgPath, err := pm.downloadAndPrepareFiles(params.RequestID, sourceUrl, taskLogger, publisherClass)
	if err != nil {
		return fail(fmt.Sprintf("准备文件失败: %v", err))
	}
	// From here on the downloaded files are removed however the task ends.
	defer pm.cleanupFiles(docPath, imgPath, taskLogger, requestID)

	params.Content, err = pm.extractContent(docPath, publisherClass, taskLogger, requestID)
	if err != nil {
		return fail(fmt.Sprintf("提取文档内容失败: %v", err))
	}
	params.ImagePath = imgPath
	params.SourcePath = docPath

	pub := publisherClass.InitMethod(params, pm.Conf, taskLogger)
	taskLogger.Printf("[任务 %s] 开始执行发布...", requestID)
	success, message := pub.PublishNote()
	if success {
		taskLogger.Printf("[任务 %s] ✅ 发布成功: %s", requestID, message)
		pm.updatePublishStatus(requestID, StatusSuccess, message)
	} else {
		taskLogger.Printf("[任务 %s] ❌ 发布失败: %s", requestID, message)
		pm.updatePublishStatus(requestID, StatusFailed, message)
	}

	// Print (not Printf): the separator is data, not a format string.
	taskLogger.Print(strings.Repeat("=", 80))
	taskLogger.Printf("任务结束 | RequestID: %s | 结果: %v", requestID, success)
	return &SingleResult{Success: success, Message: message, RequestId: requestID}
}
// RetryTask re-runs the task identified by requestID in non-headless mode.
//
// It returns a failed SingleResult when requestID is empty, when the lookup
// fails (the lookup error is logged and included in the message), or when no
// such task exists. Otherwise it returns the result of processing the task;
// the result is never nil.
func (pm *PublishManager) RetryTask(requestID string) *SingleResult {
	if requestID == "" {
		return &SingleResult{Success: false, Message: "requestID不能为空"}
	}

	publishData, err := pm.GetTaskByRequestID(requestID)
	if err != nil {
		// A lookup failure is not the same as a missing task; report it as such.
		log.Printf("[重试] 查询任务失败 requestID=%s: %v", requestID, err)
		return &SingleResult{Success: false, Message: fmt.Sprintf("查询任务失败: %v", err), RequestId: requestID}
	}
	if publishData == nil {
		return &SingleResult{Success: false, Message: "任务不存在"}
	}

	log.Printf("[重试] 开始重试任务 requestID=%s非无头模式", requestID)
	result := pm.processTask(publishData, false)
	if result == nil {
		result = &SingleResult{Success: false, Message: "系统故障", RequestId: requestID}
	}
	return result
}
// GetTaskByRequestID loads the publish task with the given request ID, joined
// with its enabled platform row.
//
// It returns an error when requestID is empty or the query fails, and
// (nil, nil) when no matching row is found.
func (pm *PublishManager) GetTaskByRequestID(requestID string) (*entitys.PublishTaskDetail, error) {
	if requestID == "" {
		return nil, fmt.Errorf("requestID不能为空")
	}

	const query = `
SELECT
p.request_id,
p.plat_index,
p.title,
p.tag,
p.user_index,
p.url,
p.img,
p.publish_time,
p.status,
pl.index as plat_index_value,
pl.status as plat_status,
pl.login_url,
pl.edit_url,
pl.logined_url,
pl.desc
FROM publish p
INNER JOIN plat pl ON p.plat_index COLLATE utf8mb4_unicode_ci = pl.index AND pl.status = 1
WHERE p.request_id = ?
`
	found := new(entitys.PublishTaskDetail)
	if queryErr := pm.db.GetOneToStruct(query, found, requestID); queryErr != nil {
		return nil, queryErr
	}

	// A row that was actually loaded always carries its request ID.
	if found.RequestID != "" {
		return found, nil
	}
	return nil, nil
}
// GetStatus reports the manager's current state as a map with the keys
// "auto_status", "worker_num" and "token_id". The values are read under the
// read lock.
func (pm *PublishManager) GetStatus() map[string]interface{} {
	snapshot := make(map[string]interface{}, 3)

	pm.mu.RLock()
	snapshot["auto_status"] = pm.AutoStatus
	snapshot["worker_num"] = pm.workerNum
	snapshot["token_id"] = pm.TokenID
	pm.mu.RUnlock()

	return snapshot
}
// extractTaskParams converts a task row into the publisher parameters and the
// pair of download URLs (document and image). It logs the task details and the
// parsed tags to taskLogger. Both return values are always non-nil.
func (pm *PublishManager) extractTaskParams(publishData *entitys.PublishTaskDetail, taskLogger *log.Logger) (*publisher.TaskParams, *fileUrl) {
	id := publishData.RequestID
	taskLogger.Printf("[任务 %s] 任务详情 - 平台:%s标题%s用户%s", id, publishData.PlatIndex, publishData.Title, publishData.UserIndex)

	parsedTags := pkg.ParseTags(publishData.Tag)
	taskLogger.Printf("[任务 %s] 标签解析完成: %v", id, parsedTags)

	taskParams := &publisher.TaskParams{
		RequestID:   id,
		PlatIndex:   publishData.PlatIndex,
		Title:       publishData.Title,
		TagRaw:      publishData.Tag,
		UserIndex:   publishData.UserIndex,
		Tags:        parsedTags,
		PublishData: publishData,
	}
	sources := &fileUrl{
		url:    publishData.URL,
		imgURL: publishData.Img,
	}
	return taskParams, sources
}
// downloadAndPrepareFiles downloads the task's document and image and, for
// publishers with Type == 1 and WordContainImg set, copies the image into the
// document.
//
// On success it returns the local paths of both files; the caller is then
// responsible for deleting them. On any failure it deletes whatever was already
// downloaded and returns empty paths together with a wrapped error, so a failed
// call never leaves files behind.
func (pm *PublishManager) downloadAndPrepareFiles(requestId string, params *fileUrl, taskLogger *log.Logger, publishClass *publisher.PublisherValue) (docPath, imgPath string, err error) {
	taskLogger.Printf("[任务 %s] 开始下载文档...", requestId)
	docPath, err = pkg.DownloadFile(params.url, pm.Conf.Sys.DocsDir, requestId)
	if err != nil {
		return "", "", fmt.Errorf("下载文档失败: %w", err)
	}
	taskLogger.Printf("[任务 %s] ✅ 文档下载成功: %s", requestId, docPath)

	taskLogger.Printf("[任务 %s] 开始下载图片...", requestId)
	imgPath, err = pkg.DownloadImage(params.imgURL, requestId, pm.Conf.Sys.UploadDir)
	if err != nil {
		if docPath != "" {
			pkg.DeleteFile(docPath)
		}
		return "", "", fmt.Errorf("下载图片失败: %w", err)
	}
	taskLogger.Printf("[任务 %s] ✅ 图片下载成功: %s", requestId, imgPath)

	if publishClass.Type == 1 && publishClass.WordContainImg {
		if err = pkg.CopyImageToDoc(docPath, imgPath); err != nil {
			// The caller only cleans up after a successful return, so remove
			// both downloads here to avoid leaking them.
			pm.cleanupFiles(docPath, imgPath, taskLogger, requestId)
			return "", "", fmt.Errorf("复制图片到文档失败: %w", err)
		}
	}
	return docPath, imgPath, nil
}
// cleanupFiles deletes the task's temporary document and image and logs each
// deletion to taskLogger. Empty paths are skipped.
func (pm *PublishManager) cleanupFiles(docPath, imgPath string, taskLogger *log.Logger, requestID string) {
	targets := [...]struct {
		path      string
		logFormat string
	}{
		{docPath, "[任务 %s] 已删除文档文件: %s"},
		{imgPath, "[任务 %s] 已删除图片文件: %s"},
	}

	for _, target := range targets {
		if target.path == "" {
			continue
		}
		pkg.DeleteFile(target.path)
		taskLogger.Printf(target.logFormat, requestID, target.path)
	}
}
// extractContent returns the text content of the downloaded document.
//
// Extraction only happens for publishers whose Type is 1; for every other type
// it returns an empty string and no error. Extraction errors are returned
// unchanged.
func (pm *PublishManager) extractContent(docPath string, publisherClass *publisher.PublisherValue, taskLogger *log.Logger, requestID string) (string, error) {
	if publisherClass.Type == 1 {
		taskLogger.Printf("[任务 %s] 开始提取文档内容...", requestID)

		text, extractErr := pkg.ExtractWordContent(docPath, publisherClass.ContentFormat)
		if extractErr != nil {
			return "", extractErr
		}

		taskLogger.Printf("[任务 %s] ✅ 内容提取成功,长度: %d", requestID, len(text))
		return text, nil
	}
	return "", nil
}
// updatePublishStatus stores a task's new status in the publish table.
//
// When message is non-empty the msg column is updated as well; otherwise it is
// left untouched. It returns an error when requestID is empty or the update
// fails; a failed update is also written to the global log.
func (pm *PublishManager) updatePublishStatus(requestID string, status int, message string) error {
	if requestID == "" {
		return fmt.Errorf("requestID不能为空")
	}

	var execErr error
	switch message {
	case "":
		_, execErr = pm.db.Execute("UPDATE publish SET status = ? WHERE request_id = ?", status, requestID)
	default:
		_, execErr = pm.db.Execute("UPDATE publish SET status = ?, msg = ? WHERE request_id = ?", status, message, requestID)
	}

	if execErr == nil {
		return nil
	}
	log.Printf("更新发布状态失败: requestID=%s, status=%d, error=%v", requestID, status, execErr)
	return execErr
}
// GetPublisherClass looks up the publisher registered for platIndex in
// publisher.PublisherMap. It returns nil when platIndex is empty or has no
// registered publisher; the latter case is logged.
func GetPublisherClass(platIndex string) *publisher.PublisherValue {
	if platIndex == "" {
		return nil
	}

	registered, ok := publisher.PublisherMap[platIndex]
	if !ok {
		log.Printf("未找到平台 %s 对应的发布器", platIndex)
		return nil
	}
	return registered
}
// SingleResult is the outcome of processing a single publish task.
type SingleResult struct {
	// Success reports whether the task was published successfully.
	Success bool `json:"success"`
	// Message is the success or failure message for the task.
	Message string `json:"message"`
	// RequestId is the request ID of the task this result belongs to.
	RequestId string `json:"request_id"`
}
// fileUrl is an internal helper carrying the two download locations of a task:
// url for the document and imgURL for the image.
type fileUrl struct {
	url, imgURL string
}