582 lines
17 KiB
Go
582 lines
17 KiB
Go
package manager
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"geo/internal/biz"
|
||
"geo/internal/config"
|
||
"geo/internal/entitys"
|
||
"geo/internal/publisher"
|
||
"geo/pkg"
|
||
"geo/utils"
|
||
"io"
|
||
"log"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/google/uuid"
|
||
)
|
||
|
||
// Publish task status values written to and compared against publish.status.
const (
	StatusPending    = 1
	StatusProcessing = 2
	StatusFailed     = 3
	StatusSuccess    = 4
)

// Worker-count bounds applied by Start: DefaultWorkerNum is used when the
// requested count is not positive, MaxWorkerNum caps larger requests.
const (
	DefaultWorkerNum = 4
	MaxWorkerNum     = 100
)
|
||
|
||
// PublishManager 发布管理器
|
||
type PublishManager struct {
|
||
AutoStatus bool
|
||
Conf *config.Config
|
||
TokenID int
|
||
running bool
|
||
mu sync.RWMutex
|
||
stopCh chan struct{}
|
||
db *utils.Db
|
||
stopOnce sync.Once
|
||
stopOnceMu sync.Mutex
|
||
publicBiz *biz.PublishBiz
|
||
// 并发控制
|
||
workerNum int // 并发worker数量
|
||
workerWg sync.WaitGroup // 等待所有worker退出
|
||
}
|
||
|
||
var (
|
||
publishManager *PublishManager
|
||
once sync.Once
|
||
)
|
||
|
||
// GetPublishManager 获取单例实例
|
||
func GetPublishManager(config *config.Config, db *utils.Db, publicBiz *biz.PublishBiz) (*PublishManager, error) {
|
||
if config == nil || db == nil {
|
||
return nil, fmt.Errorf("config和db参数不能为空")
|
||
}
|
||
|
||
once.Do(func() {
|
||
publishManager = &PublishManager{
|
||
AutoStatus: false,
|
||
Conf: config,
|
||
stopCh: make(chan struct{}),
|
||
db: db,
|
||
publicBiz: publicBiz,
|
||
}
|
||
})
|
||
return publishManager, nil
|
||
}
|
||
|
||
// getTaskLogger 获取任务专属日志记录器
|
||
func (pm *PublishManager) getTaskLogger(requestID string) (*log.Logger, *os.File, error) {
|
||
if requestID == "" {
|
||
return nil, nil, fmt.Errorf("requestID不能为空")
|
||
}
|
||
|
||
logsDir := pm.Conf.Sys.LogsDir
|
||
if logsDir == "" {
|
||
logsDir = "./logs"
|
||
}
|
||
|
||
if err := os.MkdirAll(logsDir, 0755); err != nil {
|
||
return nil, nil, fmt.Errorf("创建日志目录失败: %v", err)
|
||
}
|
||
|
||
logPath := filepath.Join(logsDir, fmt.Sprintf("%s.log", requestID))
|
||
|
||
logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("创建日志文件失败: %v", err)
|
||
}
|
||
|
||
multiWriter := io.MultiWriter(logFile, os.Stdout)
|
||
taskLogger := log.New(multiWriter, "", log.LstdFlags|log.Lmicroseconds)
|
||
|
||
taskLogger.Printf(strings.Repeat("=", 80))
|
||
taskLogger.Printf("任务开始 | RequestID: %s | 时间: %s", requestID, time.Now().Format("2006-01-02 15:04:05.000"))
|
||
taskLogger.Printf(strings.Repeat("=", 80))
|
||
|
||
return taskLogger, logFile, nil
|
||
}
|
||
|
||
// Start 启动自动发布(支持并发worker数量)
|
||
func (pm *PublishManager) Start(tokenID int, workerNum int) bool {
|
||
pm.mu.Lock()
|
||
defer pm.mu.Unlock()
|
||
|
||
if pm.AutoStatus {
|
||
log.Printf("自动发布服务已在运行中,tokenID=%d", tokenID)
|
||
return false
|
||
}
|
||
|
||
if workerNum <= 0 {
|
||
workerNum = DefaultWorkerNum
|
||
}
|
||
if workerNum > MaxWorkerNum {
|
||
workerNum = MaxWorkerNum
|
||
}
|
||
|
||
pm.TokenID = tokenID
|
||
pm.AutoStatus = true
|
||
pm.workerNum = workerNum
|
||
pm.stopCh = make(chan struct{})
|
||
|
||
// ✅ 修改:重置 sync.Once(关键修复)
|
||
pm.stopOnceMu.Lock()
|
||
pm.stopOnce = sync.Once{}
|
||
pm.stopOnceMu.Unlock()
|
||
|
||
for i := 0; i < workerNum; i++ {
|
||
pm.workerWg.Add(1)
|
||
go pm.workerLoop(i)
|
||
}
|
||
|
||
log.Printf("自动发布服务已启动,tokenID=%d,worker数量=%d", tokenID, workerNum)
|
||
return true
|
||
}
|
||
|
||
func (pm *PublishManager) Stop() bool {
|
||
pm.mu.Lock()
|
||
if !pm.AutoStatus {
|
||
pm.mu.Unlock()
|
||
return false
|
||
}
|
||
pm.AutoStatus = false
|
||
pm.mu.Unlock()
|
||
|
||
// ✅ 修改:使用互斥锁保护 Once 调用
|
||
pm.stopOnceMu.Lock()
|
||
pm.stopOnce.Do(func() {
|
||
close(pm.stopCh)
|
||
})
|
||
pm.stopOnceMu.Unlock()
|
||
|
||
done := make(chan struct{})
|
||
go func() {
|
||
pm.workerWg.Wait()
|
||
close(done)
|
||
}()
|
||
|
||
select {
|
||
case <-done:
|
||
log.Println("所有worker已正常退出")
|
||
case <-time.After(30 * time.Second):
|
||
log.Println("等待worker退出超时,强制结束")
|
||
}
|
||
return true
|
||
}
|
||
|
||
// workerLoop 单个worker循环
|
||
func (pm *PublishManager) workerLoop(workerID int) {
|
||
defer pm.workerWg.Done()
|
||
log.Printf("[Worker-%d] 启动,tokenID=%d", workerID, pm.TokenID)
|
||
|
||
for {
|
||
select {
|
||
case <-pm.stopCh:
|
||
log.Printf("[Worker-%d] 收到停止信号,退出", workerID)
|
||
return
|
||
default:
|
||
pm.executeOneTask(workerID, true)
|
||
time.Sleep(2 * time.Second)
|
||
}
|
||
}
|
||
}
|
||
|
||
// executeOneTask 执行单个任务
|
||
func (pm *PublishManager) executeOneTask(workerID int, headless bool) {
|
||
task, err := pm.acquireTask()
|
||
if err != nil {
|
||
log.Printf("[Worker-%d] 获取任务失败: %v", workerID, err)
|
||
time.Sleep(5 * time.Second)
|
||
return
|
||
}
|
||
if task == nil {
|
||
time.Sleep(30 * time.Second)
|
||
return
|
||
}
|
||
|
||
log.Printf("[Worker-%d] 开始处理任务 requestID=%s", workerID, task.RequestID)
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||
defer cancel()
|
||
// 使用 channel 接收结果,避免 goroutine 泄漏
|
||
resultChan := make(chan *SingleResult, 1)
|
||
|
||
go func() {
|
||
resultChan <- pm.processTask(ctx, task, headless)
|
||
}()
|
||
|
||
// 监听 ctx 超时
|
||
var result *SingleResult
|
||
select {
|
||
case result = <-resultChan:
|
||
// 任务正常完成
|
||
case <-ctx.Done():
|
||
// 超时或取消
|
||
log.Printf("[Worker-%d] 任务超时或被取消 requestID=%s, err=%v", workerID, task.RequestID, ctx.Err())
|
||
result = &SingleResult{
|
||
Success: false,
|
||
Message: fmt.Sprintf("任务执行超时: %v", ctx.Err()),
|
||
}
|
||
// 注意:这里不能等待 goroutine 结束,因为 processTask 可能卡住
|
||
// cancel() 会在 defer 中调用,但 processTask 内部需要响应 ctx.Done()
|
||
}
|
||
|
||
if result == nil {
|
||
log.Printf("[Worker-%d] 任务返回空结果", workerID)
|
||
} else if result.Success {
|
||
log.Printf("[Worker-%d] 任务成功: %s", workerID, result.Message)
|
||
} else {
|
||
log.Printf("[Worker-%d] 任务失败: %s", workerID, result.Message)
|
||
}
|
||
}
|
||
|
||
// acquireTask 原子获取一个待发布任务(使用 GORM 事务 + FOR UPDATE SKIP LOCKED)
|
||
func (pm *PublishManager) acquireTask() (*entitys.PublishTaskDetail, error) {
|
||
currentTime := time.Now().Format("2006-01-02 15:04:05")
|
||
uid := uuid.NewString()[:16]
|
||
// 使用子查询先找出要更新的 request_id
|
||
updateSQL := `
|
||
UPDATE publish p
|
||
SET p.status = ?,uid=?
|
||
WHERE p.request_id = (
|
||
SELECT request_id FROM (
|
||
SELECT p2.request_id
|
||
FROM publish p2
|
||
INNER JOIN plat pl ON p2.plat_index COLLATE utf8mb4_unicode_ci = pl.index
|
||
WHERE p2.token_id = ?
|
||
AND p2.status = ?
|
||
AND p2.publish_time <= ?
|
||
AND pl.status = 1
|
||
ORDER BY p2.publish_time ASC
|
||
LIMIT 1
|
||
) AS tmp
|
||
)
|
||
AND p.status = ?
|
||
`
|
||
|
||
result := pm.db.Client.Exec(updateSQL, StatusProcessing, uid, pm.TokenID, StatusPending, currentTime, StatusPending)
|
||
if result.Error != nil {
|
||
return nil, result.Error
|
||
}
|
||
|
||
if result.RowsAffected == 0 {
|
||
return nil, nil
|
||
}
|
||
|
||
// 查询刚更新的任务
|
||
var task entitys.PublishTaskDetail
|
||
querySQL := `
|
||
SELECT
|
||
p.id,
|
||
p.request_id,
|
||
p.plat_index,
|
||
p.title,
|
||
p.tag,
|
||
p.user_index,
|
||
p.url,
|
||
p.img,
|
||
p.publish_time,
|
||
p.status,
|
||
pl.index as plat_index_value,
|
||
pl.status as plat_status,
|
||
pl.login_url,
|
||
pl.edit_url,
|
||
pl.logined_url,
|
||
pl.desc
|
||
FROM publish p
|
||
INNER JOIN plat pl ON p.plat_index COLLATE utf8mb4_unicode_ci = pl.index AND pl.status = 1
|
||
WHERE uid=? AND p.token_id = ?
|
||
ORDER BY p.publish_time DESC
|
||
LIMIT 1
|
||
`
|
||
|
||
err := pm.db.Client.Raw(querySQL, uid, pm.TokenID).Scan(&task).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &task, nil
|
||
}
|
||
|
||
// processTask 处理单个任务
|
||
func (pm *PublishManager) processTask(ctx context.Context, publishData *entitys.PublishTaskDetail, headless bool) *SingleResult {
|
||
if publishData == nil || publishData.RequestID == "" {
|
||
return &SingleResult{Success: false, Message: "无效的任务数据"}
|
||
}
|
||
|
||
// 检查 context 是否已取消
|
||
select {
|
||
case <-ctx.Done():
|
||
return &SingleResult{Success: false, Message: "任务被取消: " + ctx.Err().Error(), RequestId: publishData.RequestID}
|
||
default:
|
||
}
|
||
|
||
taskLogger, logFile, err := pm.getTaskLogger(publishData.RequestID)
|
||
if err != nil {
|
||
log.Printf("[任务 %s] 创建日志文件失败: %v,使用全局日志", publishData.RequestID, err)
|
||
taskLogger = log.Default()
|
||
}
|
||
if logFile != nil {
|
||
defer logFile.Close()
|
||
}
|
||
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
errMsg := fmt.Sprintf("任务执行发生panic: %v", r)
|
||
taskLogger.Printf("❌ CRITICAL: %s", errMsg)
|
||
pm.updatePublishStatus(publishData, StatusFailed, errMsg)
|
||
}
|
||
}()
|
||
|
||
taskLogger.Printf("[任务 %s] 开始处理(headless=%v)", publishData.RequestID, headless)
|
||
|
||
params, sourceUrl := pm.extractTaskParams(publishData, taskLogger)
|
||
if params == nil {
|
||
pm.updatePublishStatus(publishData, StatusFailed, "提取任务参数失败")
|
||
return &SingleResult{Success: false, Message: "提取任务参数失败", RequestId: publishData.RequestID}
|
||
}
|
||
params.Headless = headless
|
||
|
||
publisherClass := GetPublisherClass(params.PlatIndex)
|
||
if publisherClass == nil {
|
||
errMsg := fmt.Sprintf("不支持的平台: %s", params.PlatIndex)
|
||
taskLogger.Printf("[任务 %s] ❌ %s", publishData.RequestID, errMsg)
|
||
pm.updatePublishStatus(publishData, StatusFailed, errMsg)
|
||
return &SingleResult{Success: false, Message: errMsg, RequestId: publishData.RequestID}
|
||
}
|
||
|
||
docPath, imgPath, err := pm.downloadAndPrepareFiles(params.RequestID, sourceUrl, taskLogger, publisherClass)
|
||
if err != nil {
|
||
errMsg := fmt.Sprintf("准备文件失败: %v", err)
|
||
taskLogger.Printf("[任务 %s] ❌ %s", publishData.RequestID, errMsg)
|
||
pm.updatePublishStatus(publishData, StatusFailed, errMsg)
|
||
return &SingleResult{Success: false, Message: errMsg, RequestId: publishData.RequestID}
|
||
}
|
||
defer pm.cleanupFiles(docPath, imgPath, taskLogger, publishData.RequestID)
|
||
|
||
params.Content, err = pm.extractContent(docPath, publisherClass, taskLogger, publishData.RequestID)
|
||
if err != nil {
|
||
errMsg := fmt.Sprintf("提取文档内容失败: %v", err)
|
||
taskLogger.Printf("[任务 %s] ❌ %s", publishData.RequestID, errMsg)
|
||
pm.updatePublishStatus(publishData, StatusFailed, errMsg)
|
||
return &SingleResult{Success: false, Message: errMsg, RequestId: publishData.RequestID}
|
||
}
|
||
params.ImagePath = imgPath
|
||
params.SourcePath = docPath
|
||
|
||
pub := publisherClass.InitMethod(ctx, params, pm.Conf, taskLogger)
|
||
taskLogger.Printf("[任务 %s] 开始执行发布...", publishData.RequestID)
|
||
success, message := pub.PublishNote()
|
||
|
||
if success {
|
||
taskLogger.Printf("[任务 %s] ✅ 发布成功: %s", publishData.RequestID, message)
|
||
pm.updatePublishStatus(publishData, StatusSuccess, message)
|
||
} else {
|
||
taskLogger.Printf("[任务 %s] ❌ 发布失败: %s", publishData.RequestID, message)
|
||
pm.updatePublishStatus(publishData, StatusFailed, message)
|
||
}
|
||
|
||
taskLogger.Printf(strings.Repeat("=", 80))
|
||
taskLogger.Printf("任务结束 | RequestID: %s | 结果: %v", publishData.RequestID, success)
|
||
return &SingleResult{Success: success, Message: message, RequestId: publishData.RequestID}
|
||
}
|
||
|
||
// RetryTask 重试任务(非无头模式)
|
||
func (pm *PublishManager) RetryTask(ctx context.Context, tokenId int, requestID string) *SingleResult {
|
||
if requestID == "" {
|
||
return &SingleResult{Success: false, Message: "requestID不能为空"}
|
||
}
|
||
|
||
publishData, err := pm.GetTaskByRequestID(requestID, tokenId)
|
||
if err != nil || publishData == nil {
|
||
return &SingleResult{Success: false, Message: "任务不存在"}
|
||
}
|
||
|
||
log.Printf("[重试] 开始重试任务 requestID=%s(非无头模式)", requestID)
|
||
result := pm.processTask(ctx, publishData, false)
|
||
if result == nil {
|
||
result = &SingleResult{Success: false, Message: "系统故障", RequestId: requestID}
|
||
}
|
||
return result
|
||
}
|
||
|
||
// GetTaskByRequestID 根据RequestID获取任务
|
||
func (pm *PublishManager) GetTaskByRequestID(requestID string, tokenId int) (*entitys.PublishTaskDetail, error) {
|
||
if requestID == "" {
|
||
return nil, fmt.Errorf("requestID不能为空")
|
||
}
|
||
|
||
sql := `
|
||
SELECT
|
||
p.id,
|
||
p.request_id,
|
||
p.token_id,
|
||
p.plat_index,
|
||
p.title,
|
||
p.tag,
|
||
p.user_index,
|
||
p.url,
|
||
p.img,
|
||
p.publish_time,
|
||
p.status,
|
||
pl.index as plat_index_value,
|
||
pl.status as plat_status,
|
||
pl.login_url,
|
||
pl.edit_url,
|
||
pl.logined_url,
|
||
pl.desc
|
||
FROM publish p
|
||
INNER JOIN plat pl ON p.plat_index COLLATE utf8mb4_unicode_ci = pl.index AND pl.status = 1
|
||
WHERE p.request_id = ? AND token_id=?
|
||
`
|
||
var task entitys.PublishTaskDetail
|
||
err := pm.db.GetOneToStruct(sql, &task, requestID, tokenId)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if task.RequestID == "" {
|
||
return nil, nil
|
||
}
|
||
return &task, nil
|
||
}
|
||
|
||
// GetStatus 获取状态
|
||
func (pm *PublishManager) GetStatus() map[string]interface{} {
|
||
pm.mu.RLock()
|
||
defer pm.mu.RUnlock()
|
||
|
||
return map[string]interface{}{
|
||
"auto_status": pm.AutoStatus,
|
||
"worker_num": pm.workerNum,
|
||
"token_id": pm.TokenID,
|
||
}
|
||
}
|
||
|
||
// extractTaskParams 提取任务参数
|
||
func (pm *PublishManager) extractTaskParams(publishData *entitys.PublishTaskDetail, taskLogger *log.Logger) (*publisher.TaskParams, *fileUrl) {
|
||
taskLogger.Printf("[任务 %s] 任务详情 - 平台:%s,标题:%s,用户:%s", publishData.RequestID, publishData.PlatIndex, publishData.Title, publishData.UserIndex)
|
||
|
||
tags := pkg.ParseTags(publishData.Tag)
|
||
taskLogger.Printf("[任务 %s] 标签解析完成: %v", publishData.RequestID, tags)
|
||
|
||
return &publisher.TaskParams{
|
||
RequestID: publishData.RequestID,
|
||
PlatIndex: publishData.PlatIndex,
|
||
Title: publishData.Title,
|
||
TagRaw: publishData.Tag,
|
||
UserIndex: publishData.UserIndex,
|
||
Tags: tags,
|
||
PublishData: publishData,
|
||
}, &fileUrl{
|
||
url: publishData.URL,
|
||
imgURL: publishData.Img,
|
||
}
|
||
}
|
||
|
||
// downloadAndPrepareFiles 下载文档和图片
|
||
func (pm *PublishManager) downloadAndPrepareFiles(requestId string, params *fileUrl, taskLogger *log.Logger, publishClass *publisher.PublisherValue) (docPath, imgPath string, err error) {
|
||
taskLogger.Printf("[任务 %s] 开始下载文档...", requestId)
|
||
docPath, err = pkg.DownloadFile(params.url, pm.Conf.Sys.DocsDir, requestId)
|
||
if err != nil {
|
||
return "", "", fmt.Errorf("下载文档失败: %v", err)
|
||
}
|
||
taskLogger.Printf("[任务 %s] ✅ 文档下载成功: %s", requestId, docPath)
|
||
|
||
taskLogger.Printf("[任务 %s] 开始下载图片...", requestId)
|
||
imgPath, err = pkg.DownloadImage(params.imgURL, requestId, pm.Conf.Sys.UploadDir)
|
||
if err != nil {
|
||
if docPath != "" {
|
||
pkg.DeleteFile(docPath)
|
||
}
|
||
return "", "", fmt.Errorf("下载图片失败: %v", err)
|
||
}
|
||
taskLogger.Printf("[任务 %s] ✅ 图片下载成功: %s", requestId, imgPath)
|
||
|
||
if publishClass.Type == 1 && publishClass.WordContainImg {
|
||
if err = pkg.CopyImageToDoc(docPath, imgPath); err != nil {
|
||
return docPath, imgPath, fmt.Errorf("复制图片到文档失败: %v", err)
|
||
}
|
||
}
|
||
return docPath, imgPath, nil
|
||
}
|
||
|
||
// cleanupFiles 清理临时文件
|
||
func (pm *PublishManager) cleanupFiles(docPath, imgPath string, taskLogger *log.Logger, requestID string) {
|
||
if docPath != "" {
|
||
pkg.DeleteFile(docPath)
|
||
taskLogger.Printf("[任务 %s] 已删除文档文件: %s", requestID, docPath)
|
||
}
|
||
if imgPath != "" {
|
||
pkg.DeleteFile(imgPath)
|
||
taskLogger.Printf("[任务 %s] 已删除图片文件: %s", requestID, imgPath)
|
||
}
|
||
}
|
||
|
||
// extractContent 提取文档内容
|
||
func (pm *PublishManager) extractContent(docPath string, publisherClass *publisher.PublisherValue, taskLogger *log.Logger, requestID string) (string, error) {
|
||
if publisherClass.Type != 1 {
|
||
return "", nil
|
||
}
|
||
taskLogger.Printf("[任务 %s] 开始提取文档内容...", requestID)
|
||
content, err := pkg.ExtractWordContent(docPath, publisherClass.ContentFormat)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
taskLogger.Printf("[任务 %s] ✅ 内容提取成功,长度: %d", requestID, len(content))
|
||
return content, nil
|
||
}
|
||
|
||
// updatePublishStatus 更新发布状态
|
||
func (pm *PublishManager) updatePublishStatus(publishData *entitys.PublishTaskDetail, status int, message string) error {
|
||
if publishData.ID == 0 {
|
||
return fmt.Errorf("id不能为空")
|
||
}
|
||
var err error
|
||
if message != "" {
|
||
_, err = pm.db.Execute("UPDATE publish SET status = ?, msg = ? WHERE id=?", status, message, publishData.ID)
|
||
} else {
|
||
_, err = pm.db.Execute("UPDATE publish SET status = ? WHERE id=?", status, publishData.ID)
|
||
}
|
||
if err != nil {
|
||
log.Printf("更新发布状态失败: id=%s, status=%d, error=%v", publishData.ID, status, err)
|
||
}
|
||
_ = pm.publicBiz.Notify(&entitys.NotifyData{
|
||
TokenId: publishData.TokenID,
|
||
Type: entitys.NotifyTypePublish,
|
||
Status: status,
|
||
Msg: message,
|
||
})
|
||
return err
|
||
}
|
||
|
||
// GetPublisherClass 获取发布器类
|
||
func GetPublisherClass(platIndex string) *publisher.PublisherValue {
|
||
if platIndex == "" {
|
||
return nil
|
||
}
|
||
if publisherClass, exists := publisher.PublisherMap[platIndex]; exists {
|
||
return publisherClass
|
||
}
|
||
log.Printf("未找到平台 %s 对应的发布器", platIndex)
|
||
return nil
|
||
}
|
||
|
||
// SingleResult is the outcome of processing one publish task.
type SingleResult struct {
	// Success tells whether the task completed successfully.
	Success bool `json:"success"`
	// Message carries the publisher's message or the failure reason.
	Message string `json:"message"`
	// RequestId is the request ID of the task; it is left empty on results
	// produced before a task was identified.
	RequestId string `json:"request_id"`
}
|
||
|
||
// fileUrl bundles the remote locations of a task's files.
type fileUrl struct {
	// url is the address of the document to download.
	url string
	// imgURL is the address of the image to download.
	imgURL string
}
|