geoGo/internal/manager/publish_manager.go

594 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package manager
import (
"context"
"fmt"
"geo/internal/biz"
"geo/internal/config"
"geo/internal/entitys"
"geo/internal/publisher"
"geo/pkg"
"geo/utils"
"geo/utils/utils_oss"
"io"
"log"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
const (
// Task status values written to and matched against publish.status.
StatusPending = 1
StatusProcessing = 2
StatusFailed = 3
StatusSuccess = 4
// Worker-count limits applied by Start: the default used when the requested
// count is not positive, and the upper bound the count is clamped to.
DefaultWorkerNum = 4
MaxWorkerNum = 100
)
// PublishManager runs the automatic publishing service: a pool of worker
// goroutines that claim due rows from the publish table and process them.
type PublishManager struct {
AutoStatus bool // true between a successful Start and the next successful Stop
Conf *config.Config // application configuration (log/doc/upload directories, OSS settings)
TokenID int // reported in worker logs and GetStatus; never assigned in this file
running bool // NOTE(review): not referenced anywhere in the visible code
mu sync.RWMutex // guards AutoStatus and workerNum (see Start, Stop, GetStatus)
stopCh chan struct{} // closed to tell workers to exit; recreated on every Start
db *utils.Db // database handle used to claim tasks and update their status
stopOnce sync.Once // reset on every Start; intended to limit closing stopCh to once per run
stopOnceMu sync.Mutex // protects stopOnce while it is being reset
publicBiz *biz.PublishBiz // used to send a notification after each status update
// Concurrency control.
workerNum int // number of concurrent workers
workerWg sync.WaitGroup // lets Stop wait for every worker to exit
}
var (
// publishManager is the process-wide instance returned by GetPublishManager.
publishManager *PublishManager
// once ensures publishManager is constructed at most one time.
once sync.Once
)
// GetPublishManager returns the process-wide PublishManager, creating it on
// the first call that passes validation.
//
// config and db must both be non-nil; otherwise an error is returned and
// nothing is created. Only the arguments of the call that actually constructs
// the instance are stored; later calls return the existing instance and their
// arguments are ignored. publicBiz is stored as given and is not validated.
func GetPublishManager(config *config.Config, db *utils.Db, publicBiz *biz.PublishBiz) (*PublishManager, error) {
	if db == nil || config == nil {
		return nil, fmt.Errorf("config和db参数不能为空")
	}

	build := func() {
		manager := new(PublishManager)
		manager.Conf = config
		manager.db = db
		manager.publicBiz = publicBiz
		manager.stopCh = make(chan struct{})
		publishManager = manager
	}
	once.Do(build)

	return publishManager, nil
}
// getTaskLogger creates a logger dedicated to one publish task.
//
// The log file is "<LogsDir>/<publishID>_<requestID>.log", opened in append
// mode (the directory is created when missing). Every line is written to both
// that file and os.Stdout. A start banner is emitted before returning.
//
// It returns the logger, the open file (the caller is responsible for closing
// it), the file path, and an error. On error the other results are zero
// values and the underlying cause is wrapped so callers can inspect it with
// errors.Is / errors.As.
func (pm *PublishManager) getTaskLogger(publishID int, requestID string) (*log.Logger, *os.File, string, error) {
	logsDir := pm.Conf.Sys.LogsDir
	if err := os.MkdirAll(logsDir, 0755); err != nil {
		return nil, nil, "", fmt.Errorf("创建日志目录失败: %w", err)
	}

	logPath := filepath.Join(logsDir, fmt.Sprintf("%d_%s.log", publishID, requestID))
	logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
	if err != nil {
		return nil, nil, "", fmt.Errorf("创建日志文件失败: %w", err)
	}

	multiWriter := io.MultiWriter(logFile, os.Stdout)
	taskLogger := log.New(multiWriter, "", log.LstdFlags|log.Lmicroseconds)

	// Print, not Printf: the separator is data, not a format string.
	separator := strings.Repeat("=", 80)
	taskLogger.Print(separator)
	taskLogger.Printf("任务开始 | RequestID: %s | 时间: %s", requestID, time.Now().Format("2006-01-02 15:04:05.000"))
	taskLogger.Print(separator)

	return taskLogger, logFile, logPath, nil
}
// Start launches the automatic publishing service with the requested number
// of concurrent workers.
//
// A non-positive workerNum falls back to DefaultWorkerNum and the count is
// capped at MaxWorkerNum. Each start installs a fresh stop channel and resets
// stopOnce so that a later Stop can close that channel.
//
// It returns false, without changing anything, when the service is already
// running, and true once the workers have been spawned.
func (pm *PublishManager) Start(workerNum int) bool {
	pm.mu.Lock()
	defer pm.mu.Unlock()

	if pm.AutoStatus {
		log.Printf("自动发布服务已在运行中")
		return false
	}

	count := workerNum
	if count <= 0 {
		count = DefaultWorkerNum
	}
	if count > MaxWorkerNum {
		count = MaxWorkerNum
	}

	pm.AutoStatus = true
	pm.workerNum = count
	pm.stopCh = make(chan struct{})

	// A sync.Once cannot be re-armed, so install a new one for this run.
	func() {
		pm.stopOnceMu.Lock()
		defer pm.stopOnceMu.Unlock()
		pm.stopOnce = sync.Once{}
	}()

	for id := 0; id < count; id++ {
		pm.workerWg.Add(1)
		go pm.workerLoop(id)
	}

	log.Printf("自动发布服务已启动worker数量=%d", count)
	return true
}
// Stop signals every worker to exit and waits up to 30 seconds for them.
//
// It returns false when the service is not running. Otherwise it closes the
// current stop channel, waits for the worker WaitGroup, and returns true even
// when the wait times out (workers that are still busy keep running until
// they next observe the closed channel).
func (pm *PublishManager) Stop() bool {
	pm.mu.Lock()
	if !pm.AutoStatus {
		pm.mu.Unlock()
		return false
	}
	pm.AutoStatus = false
	// Close the channel while still holding pm.mu. Start replaces pm.stopCh
	// under the same lock, so closing after unlocking could close the channel
	// of a newer Start instead of the one this call is stopping. The
	// AutoStatus check above already guarantees a single close per Start.
	close(pm.stopCh)
	pm.mu.Unlock()

	done := make(chan struct{})
	go func() {
		pm.workerWg.Wait()
		close(done)
	}()

	timeout := time.NewTimer(30 * time.Second)
	defer timeout.Stop()

	select {
	case <-done:
		log.Println("所有worker已正常退出")
	case <-timeout.C:
		log.Println("等待worker退出超时强制结束")
	}
	return true
}
// workerLoop is the body of one worker goroutine: it repeatedly runs
// executeOneTask in headless mode, pausing 2 seconds between rounds, until the
// stop channel that was current when the worker started is closed.
//
// workerID is only used to tag log lines. The WaitGroup counter incremented
// by Start is released when the loop returns.
func (pm *PublishManager) workerLoop(workerID int) {
	defer pm.workerWg.Done()

	// Bind to this generation's stop channel once. Start installs a new
	// channel on every run; re-reading pm.stopCh on each iteration would let
	// a worker that outlived Stop's timeout adopt the new channel and never
	// exit, and would race with Start's write.
	pm.mu.RLock()
	stopCh := pm.stopCh
	pm.mu.RUnlock()

	log.Printf("[Worker-%d] 启动tokenID=%d", workerID, pm.TokenID)

	for {
		select {
		case <-stopCh:
			log.Printf("[Worker-%d] 收到停止信号,退出", workerID)
			return
		default:
		}

		pm.executeOneTask(workerID, true)

		// Pause between rounds, but wake immediately on stop.
		select {
		case <-stopCh:
			log.Printf("[Worker-%d] 收到停止信号,退出", workerID)
			return
		case <-time.After(2 * time.Second):
		}
	}
}
// executeOneTask claims at most one due task and processes it with a
// 5-minute deadline.
//
// When claiming fails it backs off for 5 seconds; when no task is due it
// idles for 30 seconds. Both waits end early once the manager's stop channel
// is closed, so an idle worker does not hold up Stop for the full interval.
//
// workerID tags log lines; headless is forwarded to processTask.
func (pm *PublishManager) executeOneTask(workerID int, headless bool) {
	// idle waits for d or until the current stop channel is closed,
	// whichever comes first.
	idle := func(d time.Duration) {
		pm.mu.RLock()
		stopCh := pm.stopCh
		pm.mu.RUnlock()

		timer := time.NewTimer(d)
		defer timer.Stop()
		select {
		case <-stopCh:
		case <-timer.C:
		}
	}

	task, err := pm.acquireTask()
	if err != nil {
		log.Printf("[Worker-%d] 获取任务失败: %v", workerID, err)
		idle(5 * time.Second)
		return
	}
	if task == nil {
		idle(30 * time.Second)
		return
	}

	log.Printf("[Worker-%d] 开始处理任务 requestID=%s", workerID, task.RequestID)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
	defer cancel()

	// Buffered so the processing goroutine can always deliver its result and
	// exit, even after this function has given up waiting.
	resultChan := make(chan *SingleResult, 1)
	go func() {
		resultChan <- pm.processTask(ctx, task, headless)
	}()

	var result *SingleResult
	select {
	case result = <-resultChan:
		// Task finished within the deadline.
	case <-ctx.Done():
		log.Printf("[Worker-%d] 任务超时或被取消 requestID=%s, err=%v", workerID, task.RequestID, ctx.Err())
		result = &SingleResult{
			Success: false,
			Message: fmt.Sprintf("任务执行超时: %v", ctx.Err()),
		}
		// Deliberately not waiting for the goroutine: processTask may be
		// stuck. It receives ctx and is expected to react to cancellation.
		// NOTE(review): the task's database status is not updated on this
		// path; it only changes if processTask eventually returns.
	}

	switch {
	case result == nil:
		log.Printf("[Worker-%d] 任务返回空结果", workerID)
	case result.Success:
		log.Printf("[Worker-%d] 任务成功: %s", workerID, result.Message)
	default:
		log.Printf("[Worker-%d] 任务失败: %s", workerID, result.Message)
	}
}
// acquireTask claims one due pending task and returns its details.
//
// Claiming is a single UPDATE: it picks the pending row with the earliest
// publish_time that is already due and whose platform is enabled
// (plat.status = 1), switches it to StatusProcessing and stamps it with a
// freshly generated 16-character uid. The claimed row is then read back by
// that uid together with its platform columns.
//
// It returns (nil, nil) when the UPDATE affected no rows, and a non-nil error
// when either statement fails.
func (pm *PublishManager) acquireTask() (*entitys.PublishTaskDetail, error) {
	// The inner SELECT is wrapped in a derived table (tmp) before being used
	// by the UPDATE on the same table.
	const claimSQL = `
UPDATE publish p
SET p.status = ?,uid=?
WHERE p.request_id = (
SELECT request_id FROM (
SELECT p2.request_id
FROM publish p2
INNER JOIN plat pl ON p2.plat_index COLLATE utf8mb4_unicode_ci = pl.index
WHERE p2.status = ?
AND p2.publish_time <= ?
AND pl.status = 1
ORDER BY p2.publish_time ASC
LIMIT 1
) AS tmp
)
AND p.status = ?
`
	// Reads back the row stamped by claimSQL.
	const fetchSQL = `
SELECT
p.id,
p.request_id,
p.plat_index,
p.title,
p.tag,
p.user_index,
p.token_id,
p.url,
p.img,
p.publish_time,
p.status,
pl.index as plat_index_value,
pl.status as plat_status,
pl.login_url,
pl.edit_url,
pl.logined_url,
pl.desc
FROM publish p
INNER JOIN plat pl ON p.plat_index COLLATE utf8mb4_unicode_ci = pl.index AND pl.status = 1
WHERE uid=?
ORDER BY p.publish_time DESC
LIMIT 1
`

	now := time.Now().Format("2006-01-02 15:04:05")
	claimID := uuid.NewString()[:16]

	claimed := pm.db.Client.Exec(claimSQL, StatusProcessing, claimID, StatusPending, now, StatusPending)
	switch {
	case claimed.Error != nil:
		return nil, claimed.Error
	case claimed.RowsAffected == 0:
		return nil, nil
	}

	detail := new(entitys.PublishTaskDetail)
	if err := pm.db.Client.Raw(fetchSQL, claimID).Scan(detail).Error; err != nil {
		return nil, err
	}
	return detail, nil
}
// processTask runs one publish task end to end: it opens a per-task log,
// builds the task parameters, resolves the platform publisher, downloads the
// document and image, extracts the content, publishes, uploads the log to
// OSS and records the final status.
//
// ctx is checked once before any work starts and is handed to the publisher;
// headless is copied into the task parameters.
//
// Every failure path records StatusFailed through updatePublishStatus and
// returns a SingleResult with Success=false. If a panic occurs it is
// recovered, the task is marked failed, and the function returns nil, so
// callers must handle a nil result.
func (pm *PublishManager) processTask(ctx context.Context, publishData *entitys.PublishTaskDetail, headless bool) *SingleResult {
	if publishData == nil || publishData.RequestID == "" {
		return &SingleResult{Success: false, Message: "无效的任务数据"}
	}

	// Bail out early if the caller already cancelled or timed out.
	select {
	case <-ctx.Done():
		return &SingleResult{Success: false, Message: "任务被取消: " + ctx.Err().Error(), RequestId: publishData.RequestID}
	default:
	}

	taskLogger, logFile, logPath, err := pm.getTaskLogger(publishData.ID, publishData.RequestID)
	if err != nil {
		// Fall back to the process-wide logger so the task can still run.
		log.Printf("[任务 %s] 创建日志文件失败: %v使用全局日志", publishData.RequestID, err)
		taskLogger = log.Default()
	}
	if logFile != nil {
		defer logFile.Close()
	}

	defer func() {
		if r := recover(); r != nil {
			errMsg := fmt.Sprintf("任务执行发生panic: %v", r)
			taskLogger.Printf("❌ CRITICAL: %s", errMsg)
			pm.updatePublishStatus(publishData, StatusFailed, errMsg, "")
		}
	}()

	taskLogger.Printf("[任务 %s] 开始处理headless=%v", publishData.RequestID, headless)

	params, sourceUrl := pm.extractTaskParams(publishData, taskLogger)
	if params == nil {
		pm.updatePublishStatus(publishData, StatusFailed, "提取任务参数失败", "")
		return &SingleResult{Success: false, Message: "提取任务参数失败", RequestId: publishData.RequestID}
	}
	params.Headless = headless

	publisherClass := GetPublisherClass(params.PlatIndex)
	if publisherClass == nil {
		errMsg := fmt.Sprintf("不支持的平台: %s", params.PlatIndex)
		taskLogger.Printf("[任务 %s] ❌ %s", publishData.RequestID, errMsg)
		pm.updatePublishStatus(publishData, StatusFailed, errMsg, "")
		return &SingleResult{Success: false, Message: errMsg, RequestId: publishData.RequestID}
	}

	docPath, imgPath, err := pm.downloadAndPrepareFiles(params.RequestID, sourceUrl, taskLogger, publisherClass)
	if err != nil {
		errMsg := fmt.Sprintf("准备文件失败: %v", err)
		taskLogger.Printf("[任务 %s] ❌ %s", publishData.RequestID, errMsg)
		pm.updatePublishStatus(publishData, StatusFailed, errMsg, "")
		return &SingleResult{Success: false, Message: errMsg, RequestId: publishData.RequestID}
	}
	// From here on the downloaded files are removed on every exit path.
	defer pm.cleanupFiles(docPath, imgPath, taskLogger, publishData.RequestID)

	params.Content, err = pm.extractContent(docPath, publisherClass, taskLogger, publishData.RequestID)
	if err != nil {
		errMsg := fmt.Sprintf("提取文档内容失败: %v", err)
		taskLogger.Printf("[任务 %s] ❌ %s", publishData.RequestID, errMsg)
		pm.updatePublishStatus(publishData, StatusFailed, errMsg, "")
		return &SingleResult{Success: false, Message: errMsg, RequestId: publishData.RequestID}
	}
	params.ImagePath = imgPath
	params.SourcePath = docPath

	pub := publisherClass.InitMethod(ctx, params, pm.Conf, taskLogger)
	taskLogger.Printf("[任务 %s] 开始执行发布...", publishData.RequestID)
	success, message := pub.PublishNote()

	// Print, not Printf: the separator is data, not a format string.
	taskLogger.Print(strings.Repeat("=", 80))
	taskLogger.Printf("任务结束 | RequestID: %s | 结果: %v", publishData.RequestID, success)
	res := &SingleResult{Success: success, Message: message, RequestId: publishData.RequestID}

	// The log is uploaded under two object keys; only the URL of the second
	// upload is stored with the task.
	// NOTE(review): confirm whether the "log/" copy is still needed.
	logSubdirKey := fmt.Sprintf("%slog/%d_%s.log", pm.Conf.Oss.FilePath, publishData.ID, publishData.RequestID)
	if _, uploadErr := pm.uploadToOss(ctx, logPath, logSubdirKey); uploadErr != nil {
		taskLogger.Printf("日志上传失败: %s: %v", logSubdirKey, uploadErr)
	}
	logKey := fmt.Sprintf("%s%d_%s.log", pm.Conf.Oss.FilePath, publishData.ID, publishData.RequestID)
	url, err := pm.uploadToOss(ctx, logPath, logKey)
	if err != nil {
		taskLogger.Printf("日志上传失败: %s: %v", logKey, err)
	}

	if success {
		taskLogger.Printf("[任务 %s] ✅ 发布成功: %s", publishData.RequestID, message)
		pm.updatePublishStatus(publishData, StatusSuccess, message, url)
	} else {
		taskLogger.Printf("[任务 %s] ❌ 发布失败: %s", publishData.RequestID, message)
		pm.updatePublishStatus(publishData, StatusFailed, message, url)
	}
	return res
}
// uploadToOss reads the local file at filePath into memory and uploads its
// bytes to OSS under the object key path, using a client built from pm.Conf.
//
// It returns the URL reported by the client's UploadBytes call. If creating
// the client or reading the file fails, the URL is empty and that error is
// returned. ctx is accepted but not used by this implementation.
func (pm *PublishManager) uploadToOss(ctx context.Context, filePath string, path string) (url string, err error) {
	ossClient, err := utils_oss.NewClient(pm.Conf)
	if err != nil {
		return "", err
	}

	content, err := os.ReadFile(filePath)
	if err != nil {
		return "", err
	}

	url, err = ossClient.UploadBytes(path, content)
	return url, err
}
// RetryTask re-runs the task identified by requestID for the given tokenId
// with headless mode turned off.
//
// It returns a failed SingleResult when requestID is empty or when the task
// cannot be loaded (lookup error or no matching row). A nil result from
// processTask is replaced by a generic failure so callers always receive a
// non-nil value.
func (pm *PublishManager) RetryTask(ctx context.Context, tokenId int, requestID string) *SingleResult {
	if requestID == "" {
		return &SingleResult{Success: false, Message: "requestID不能为空"}
	}

	task, lookupErr := pm.GetTaskByRequestID(requestID, tokenId)
	if task == nil || lookupErr != nil {
		return &SingleResult{Success: false, Message: "任务不存在"}
	}

	log.Printf("[重试] 开始重试任务 requestID=%s非无头模式", requestID)

	if outcome := pm.processTask(ctx, task, false); outcome != nil {
		return outcome
	}
	return &SingleResult{Success: false, Message: "系统故障", RequestId: requestID}
}
// GetTaskByRequestID loads a publish task, joined with its enabled platform
// row (plat.status = 1), by request ID and token ID.
//
// It returns an error when requestID is empty or the query fails, and
// (nil, nil) when the query produced no task (the scanned RequestID is empty).
func (pm *PublishManager) GetTaskByRequestID(requestID string, tokenId int) (*entitys.PublishTaskDetail, error) {
	if requestID == "" {
		return nil, fmt.Errorf("requestID不能为空")
	}

	const lookupSQL = `
SELECT
p.id,
p.request_id,
p.token_id,
p.plat_index,
p.title,
p.tag,
p.user_index,
p.url,
p.img,
p.publish_time,
p.status,
pl.index as plat_index_value,
pl.status as plat_status,
pl.login_url,
pl.edit_url,
pl.logined_url,
pl.desc
FROM publish p
INNER JOIN plat pl ON p.plat_index COLLATE utf8mb4_unicode_ci = pl.index AND pl.status = 1
WHERE p.request_id = ? AND token_id=?
`

	found := new(entitys.PublishTaskDetail)
	if err := pm.db.GetOneToStruct(lookupSQL, found, requestID, tokenId); err != nil {
		return nil, err
	}

	// An empty RequestID after scanning means no row matched.
	if found.RequestID == "" {
		return nil, nil
	}
	return found, nil
}
// GetStatus returns a snapshot of the manager state, read under the read
// lock, with the keys "auto_status", "worker_num" and "token_id".
func (pm *PublishManager) GetStatus() map[string]interface{} {
	pm.mu.RLock()
	defer pm.mu.RUnlock()

	snapshot := make(map[string]interface{}, 3)
	snapshot["auto_status"] = pm.AutoStatus
	snapshot["worker_num"] = pm.workerNum
	snapshot["token_id"] = pm.TokenID
	return snapshot
}
// extractTaskParams converts a task row into the publisher's TaskParams and
// the pair of source URLs (document and image) that still need downloading.
//
// The raw tag string is parsed with pkg.ParseTags; both the raw and parsed
// forms are kept in the params. Task details and parsed tags are written to
// taskLogger. Both return values are always non-nil.
func (pm *PublishManager) extractTaskParams(publishData *entitys.PublishTaskDetail, taskLogger *log.Logger) (*publisher.TaskParams, *fileUrl) {
	requestID := publishData.RequestID

	taskLogger.Printf("[任务 %s] 任务详情 - 平台:%s标题%s用户%s", requestID, publishData.PlatIndex, publishData.Title, publishData.UserIndex)

	parsedTags := pkg.ParseTags(publishData.Tag)
	taskLogger.Printf("[任务 %s] 标签解析完成: %v", requestID, parsedTags)

	sources := &fileUrl{
		url:    publishData.URL,
		imgURL: publishData.Img,
	}
	taskParams := &publisher.TaskParams{
		RequestID:   requestID,
		PlatIndex:   publishData.PlatIndex,
		Title:       publishData.Title,
		TagRaw:      publishData.Tag,
		UserIndex:   publishData.UserIndex,
		Tags:        parsedTags,
		PublishData: publishData,
	}
	return taskParams, sources
}
// downloadAndPrepareFiles downloads the task's document into Conf.Sys.DocsDir
// and its image into Conf.Sys.UploadDir. When the publisher has Type == 1 and
// WordContainImg set, the image is additionally copied into the document via
// pkg.CopyImageToDoc.
//
// On success it returns the local document and image paths; the caller owns
// those files and must remove them. On any failure it returns empty paths and
// an error wrapping the cause, after deleting whatever had already been
// downloaded, so nothing is left behind for the caller to clean up.
func (pm *PublishManager) downloadAndPrepareFiles(requestId string, params *fileUrl, taskLogger *log.Logger, publishClass *publisher.PublisherValue) (docPath, imgPath string, err error) {
	taskLogger.Printf("[任务 %s] 开始下载文档...", requestId)
	docPath, err = pkg.DownloadFile(params.url, pm.Conf.Sys.DocsDir, requestId)
	if err != nil {
		return "", "", fmt.Errorf("下载文档失败: %w", err)
	}
	taskLogger.Printf("[任务 %s] ✅ 文档下载成功: %s", requestId, docPath)

	taskLogger.Printf("[任务 %s] 开始下载图片...", requestId)
	imgPath, err = pkg.DownloadImage(params.imgURL, requestId, pm.Conf.Sys.UploadDir)
	if err != nil {
		if docPath != "" {
			pkg.DeleteFile(docPath)
		}
		return "", "", fmt.Errorf("下载图片失败: %w", err)
	}
	taskLogger.Printf("[任务 %s] ✅ 图片下载成功: %s", requestId, imgPath)

	if publishClass.Type == 1 && publishClass.WordContainImg {
		if err = pkg.CopyImageToDoc(docPath, imgPath); err != nil {
			// The caller only schedules cleanup after a nil error, so both
			// downloads must be removed here or they would leak.
			if docPath != "" {
				pkg.DeleteFile(docPath)
			}
			if imgPath != "" {
				pkg.DeleteFile(imgPath)
			}
			return "", "", fmt.Errorf("复制图片到文档失败: %w", err)
		}
	}
	return docPath, imgPath, nil
}
// cleanupFiles deletes the temporary document and image files of a task and
// logs each deletion to taskLogger. Empty paths are skipped.
func (pm *PublishManager) cleanupFiles(docPath, imgPath string, taskLogger *log.Logger, requestID string) {
	// remove deletes path when it is set and reports whether it did so.
	remove := func(path string) bool {
		if path == "" {
			return false
		}
		pkg.DeleteFile(path)
		return true
	}

	if remove(docPath) {
		taskLogger.Printf("[任务 %s] 已删除文档文件: %s", requestID, docPath)
	}
	if remove(imgPath) {
		taskLogger.Printf("[任务 %s] 已删除图片文件: %s", requestID, imgPath)
	}
}
// extractContent returns the text of the downloaded document for publishers
// whose Type is 1, using pkg.ExtractWordContent with the publisher's
// ContentFormat. For every other Type it returns an empty string and no
// error without touching the file.
func (pm *PublishManager) extractContent(docPath string, publisherClass *publisher.PublisherValue, taskLogger *log.Logger, requestID string) (string, error) {
	if publisherClass.Type == 1 {
		taskLogger.Printf("[任务 %s] 开始提取文档内容...", requestID)

		text, extractErr := pkg.ExtractWordContent(docPath, publisherClass.ContentFormat)
		if extractErr != nil {
			return "", extractErr
		}

		taskLogger.Printf("[任务 %s] ✅ 内容提取成功,长度: %d", requestID, len(text))
		return text, nil
	}
	return "", nil
}
// updatePublishStatus persists the outcome of a task and notifies the
// business layer.
//
// It writes status and log_url to the publish row identified by
// publishData.ID; msg is written as well when message is non-empty, and is
// left untouched otherwise. A notification carrying the token ID, status and
// message is then sent through publicBiz (when one is configured), whether or
// not the database update succeeded.
//
// It returns an error when publishData is nil or has a zero ID, or when the
// database update fails. Notification errors are ignored.
func (pm *PublishManager) updatePublishStatus(publishData *entitys.PublishTaskDetail, status int, message string, log_url string) error {
	if publishData == nil || publishData.ID == 0 {
		return fmt.Errorf("id不能为空")
	}

	var err error
	if message != "" {
		_, err = pm.db.Execute("UPDATE publish SET status = ?, msg = ?,log_url=? WHERE id=?", status, message, log_url, publishData.ID)
	} else {
		// Argument order must follow the placeholders: status, log_url, id.
		_, err = pm.db.Execute("UPDATE publish SET status = ?,log_url=? WHERE id=?", status, log_url, publishData.ID)
	}
	if err != nil {
		log.Printf("更新发布状态失败: id=%d, status=%d, error=%v", publishData.ID, status, err)
	}

	// GetPublishManager does not require publicBiz, so it may be nil.
	if pm.publicBiz != nil {
		// Best-effort: a failed notification must not mask the update result.
		_ = pm.publicBiz.Notify(&entitys.NotifyData{
			TokenId: publishData.TokenID,
			Type:    entitys.NotifyTypePublish,
			Status:  status,
			Msg:     message,
		})
	}
	return err
}
// GetPublisherClass looks up the publisher registered for platIndex in
// publisher.PublisherMap.
//
// It returns nil when platIndex is empty or has no entry; a missing entry is
// also logged.
func GetPublisherClass(platIndex string) *publisher.PublisherValue {
	if platIndex == "" {
		return nil
	}

	registered, found := publisher.PublisherMap[platIndex]
	if !found {
		log.Printf("未找到平台 %s 对应的发布器", platIndex)
		return nil
	}
	return registered
}
// SingleResult is the outcome of processing one publish task.
type SingleResult struct {
Success bool `json:"success"` // true when the task completed successfully
Message string `json:"message"` // result or failure description
RequestId string `json:"request_id"` // request ID of the task; left empty on some early failures
}
// fileUrl is an internal helper holding the remote locations of a task's
// source files before they are downloaded.
type fileUrl struct {
url string // document URL (taken from the task's URL field)
imgURL string // image URL (taken from the task's Img field)
}