583 lines
17 KiB
Go
583 lines
17 KiB
Go
package manager
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"geo/internal/config"
|
||
"geo/internal/entitys"
|
||
"geo/internal/publisher"
|
||
"geo/pkg"
|
||
"geo/utils"
|
||
"io"
|
||
"log"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
// Task status codes, as written to and filtered on the publish.status column.
const (
	StatusPending    = iota + 1 // 1: waiting to be published
	StatusProcessing            // 2: publishing in progress
	StatusFailed                // 3: publishing failed
	StatusSuccess               // 4: published successfully
)

// Scheduling and log-formatting constants.
const (
	// BatchInterval is the pause between two rounds of the auto-publish loop.
	BatchInterval = 30 * time.Second

	// LogSeparator is a horizontal rule for log output.
	LogSeparator = "================================================================"
)
|
||
|
||
// PublishManager coordinates publishing of tasks stored in the publish table,
// either on demand (RetryTask) or through the background auto-publish loop.
type PublishManager struct {
	// AutoStatus is true while auto-publishing is enabled; Start and Stop
	// change it while holding mu.
	AutoStatus bool
	// Conf supplies directories (LogsDir, DocsDir, UploadDir) and limits
	// (TaskTimeout, MaxConcurrent).
	Conf *config.Config
	// TokenID selects whose pending tasks the auto-publish loop picks up;
	// Start sets it while holding mu.
	TokenID int
	// running is not referenced in the visible code. NOTE(review): possibly unused.
	running bool
	mu      sync.RWMutex // read-write lock so concurrent status reads do not block each other
	// stopCh is closed to tell the auto-publish loop to exit.
	stopCh chan struct{}
	// db is the database handle used for task queries and status updates.
	db *utils.Db
	// stopOnce is a sync.Once intended to guard closing stopCh.
	// NOTE(review): a sync.Once fires at most once per manager, so it cannot
	// by itself cover repeated Start/Stop cycles.
	stopOnce sync.Once
}
|
||
|
||
var (
|
||
publishManager *PublishManager
|
||
once sync.Once
|
||
)
|
||
|
||
// GetPublishManager 获取单例实例(优化:添加nil检查)
|
||
func GetPublishManager(config *config.Config, db *utils.Db) (*PublishManager, error) {
|
||
if config == nil || db == nil {
|
||
return nil, fmt.Errorf("config和db参数不能为空")
|
||
}
|
||
|
||
once.Do(func() {
|
||
publishManager = &PublishManager{
|
||
AutoStatus: false,
|
||
Conf: config,
|
||
stopCh: make(chan struct{}),
|
||
db: db,
|
||
}
|
||
})
|
||
return publishManager, nil
|
||
}
|
||
|
||
// getTaskLogger 获取任务专属日志记录器(优化:添加参数验证和错误恢复)
|
||
func (pm *PublishManager) getTaskLogger(requestID string) (*log.Logger, *os.File, error) {
|
||
if requestID == "" {
|
||
return nil, nil, fmt.Errorf("requestID不能为空")
|
||
}
|
||
|
||
// 确定日志目录,提供默认值
|
||
logsDir := pm.Conf.Sys.LogsDir
|
||
if logsDir == "" {
|
||
logsDir = "./logs"
|
||
}
|
||
|
||
// 按日期创建子目录
|
||
dateDir := time.Now().Format("2006-01-02")
|
||
taskLogDir := filepath.Join(logsDir, "tasks", dateDir)
|
||
if err := os.MkdirAll(taskLogDir, 0755); err != nil {
|
||
return nil, nil, fmt.Errorf("创建日志目录失败: %v", err)
|
||
}
|
||
|
||
// 使用requestID作为文件名,添加随机数避免重复(可选)
|
||
logPath := filepath.Join(taskLogDir, fmt.Sprintf("%s.log", requestID))
|
||
|
||
logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("创建日志文件失败: %v", err)
|
||
}
|
||
|
||
// 创建写入器:同时写入文件和标准输出
|
||
multiWriter := io.MultiWriter(logFile, os.Stdout)
|
||
taskLogger := log.New(multiWriter, "", log.LstdFlags|log.Lmicroseconds)
|
||
|
||
// 写入任务开始分隔线
|
||
taskLogger.Printf(strings.Repeat("=", 80))
|
||
taskLogger.Printf("任务开始 | RequestID: %s | 时间: %s", requestID, time.Now().Format("2006-01-02 15:04:05.000"))
|
||
taskLogger.Printf(strings.Repeat("=", 80))
|
||
|
||
return taskLogger, logFile, nil
|
||
}
|
||
|
||
// Start 启动自动发布(优化:使用读锁检查状态,写锁修改状态)
|
||
func (pm *PublishManager) Start(tokenID int) bool {
|
||
pm.mu.Lock()
|
||
defer pm.mu.Unlock()
|
||
|
||
if pm.AutoStatus {
|
||
log.Printf("自动发布服务已在运行中,tokenID=%d", tokenID)
|
||
return false
|
||
}
|
||
|
||
pm.TokenID = tokenID
|
||
pm.AutoStatus = true
|
||
pm.stopCh = make(chan struct{}) // 重新创建stopCh
|
||
|
||
go pm.autoPublishLoop()
|
||
log.Printf("自动发布服务已启动,tokenID=%d", tokenID)
|
||
return true
|
||
}
|
||
|
||
// Stop 停止自动发布(优化:使用sync.Once确保stopCh只关闭一次)
|
||
func (pm *PublishManager) Stop() bool {
|
||
pm.mu.Lock()
|
||
defer pm.mu.Unlock()
|
||
|
||
if !pm.AutoStatus {
|
||
return false
|
||
}
|
||
|
||
pm.AutoStatus = false
|
||
pm.stopOnce.Do(func() {
|
||
close(pm.stopCh)
|
||
})
|
||
return true
|
||
}
|
||
|
||
// autoPublishLoop 自动发布循环(优化:添加退出日志)
|
||
func (pm *PublishManager) autoPublishLoop() {
|
||
log.Println("自动发布服务已启动,开始循环执行")
|
||
|
||
for {
|
||
select {
|
||
case <-pm.stopCh:
|
||
log.Println("自动发布服务已停止")
|
||
return
|
||
default:
|
||
pm.batchPublish()
|
||
time.Sleep(BatchInterval)
|
||
}
|
||
}
|
||
}
|
||
|
||
// batchPublish 批量发布
|
||
func (pm *PublishManager) batchPublish() {
|
||
if !pm.isAutoStatus() {
|
||
return
|
||
}
|
||
|
||
publishData, err := pm.getPendingPublish()
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
// 使用context实现超时控制
|
||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(pm.Conf.Sys.TaskTimeout)*time.Second)
|
||
defer cancel()
|
||
|
||
done := make(chan struct{})
|
||
go func() {
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
log.Printf("批处理发布发生 panic: %v", r)
|
||
}
|
||
}()
|
||
pm.processSingleTask(publishData)
|
||
close(done)
|
||
}()
|
||
|
||
select {
|
||
case <-done:
|
||
// 正常完成
|
||
case <-ctx.Done():
|
||
log.Printf("任务执行超时: %v", ctx.Err())
|
||
}
|
||
}
|
||
|
||
// getPendingPublish 获取待发布任务(返回结构体)
|
||
func (pm *PublishManager) getPendingPublish() (*entitys.PublishTaskDetail, error) {
|
||
currentTime := time.Now().Format("2006-01-02 15:04:05")
|
||
|
||
// SQL查询,明确指定字段
|
||
sql := `
|
||
SELECT
|
||
p.request_id,
|
||
p.plat_index,
|
||
p.title,
|
||
p.tag,
|
||
p.user_index,
|
||
p.url,
|
||
p.img,
|
||
p.publish_time,
|
||
p.status,
|
||
pl.index as plat_index_value,
|
||
pl.status as plat_status
|
||
pl.login_url,
|
||
pl.edit_url,
|
||
pl.logined_url,
|
||
pl.desc
|
||
FROM publish p
|
||
INNER JOIN plat pl ON p.plat_index = pl.index AND pl.status = 1
|
||
WHERE p.token_id = ? AND p.status = ? AND p.publish_time <= ?
|
||
ORDER BY p.publish_time ASC
|
||
LIMIT 1
|
||
`
|
||
|
||
var task entitys.PublishTaskDetail
|
||
err := pm.db.GetOneToStruct(sql, &task, pm.TokenID, StatusPending, currentTime)
|
||
if err != nil {
|
||
log.Printf("查询待发布任务失败: token_id=%d, error=%v", pm.TokenID, err)
|
||
return nil, err
|
||
}
|
||
|
||
// 检查是否为空记录(根据你的db实现,可能需要判断task.RequestID是否为空)
|
||
if task.RequestID == "" {
|
||
log.Printf("没有待发布任务: token_id=%d, current_time=%s", pm.TokenID, currentTime)
|
||
return nil, nil
|
||
}
|
||
|
||
log.Printf("获取到待发布任务: token_id=%d, request_id=%s", pm.TokenID, task.RequestID)
|
||
return &task, nil
|
||
}
|
||
|
||
// isAutoStatus 获取自动状态(优化:使用读锁)
|
||
func (pm *PublishManager) isAutoStatus() bool {
|
||
pm.mu.RLock()
|
||
defer pm.mu.RUnlock()
|
||
return pm.AutoStatus
|
||
}
|
||
|
||
// GetTaskByRequestID 根据RequestID获取任务(优化:添加缓存机制可选项)
|
||
func (pm *PublishManager) GetTaskByRequestID(requestID string) (*entitys.PublishTaskDetail, error) {
|
||
if requestID == "" {
|
||
return nil, fmt.Errorf("requestID不能为空")
|
||
}
|
||
|
||
sql := `
|
||
SELECT
|
||
p.request_id,
|
||
p.plat_index,
|
||
p.title,
|
||
p.tag,
|
||
p.user_index,
|
||
p.url,
|
||
p.img,
|
||
p.publish_time,
|
||
p.status,
|
||
pl.index as plat_index_value,
|
||
pl.status as plat_status,
|
||
pl.login_url,
|
||
pl.edit_url,
|
||
pl.logined_url,
|
||
pl.desc
|
||
FROM publish p
|
||
INNER JOIN plat pl ON p.plat_index COLLATE utf8mb4_unicode_ci = pl.index AND pl.status = 1
|
||
WHERE p.request_id = ?
|
||
`
|
||
var task entitys.PublishTaskDetail
|
||
err := pm.db.GetOneToStruct(sql, &task, requestID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
// 检查是否为空记录(根据你的db实现,可能需要判断task.RequestID是否为空)
|
||
if task.RequestID == "" {
|
||
return nil, nil
|
||
}
|
||
|
||
return &task, nil
|
||
}
|
||
|
||
// SingleResult is the outcome of processing (or retrying) one publish task.
type SingleResult struct {
	Success   bool   `json:"success"`    // true only when the publisher reported success
	Message   string `json:"message"`    // validation/error text, or the publisher's message
	RequestId string `json:"request_id"` // task request ID; may be empty for early validation failures
}
|
||
|
||
func (pm *PublishManager) processSingleTask(publishData *entitys.PublishTaskDetail) (result *SingleResult) {
|
||
if publishData == nil {
|
||
return &SingleResult{
|
||
Success: false,
|
||
Message: "publishData不能为空",
|
||
}
|
||
}
|
||
if publishData.RequestID == "" {
|
||
|
||
return &SingleResult{
|
||
Success: false,
|
||
Message: "requestID不能为空",
|
||
}
|
||
}
|
||
|
||
// 获取任务专属日志
|
||
taskLogger, logFile, err := pm.getTaskLogger(publishData.RequestID)
|
||
if err != nil {
|
||
log.Printf("[任务 %s] 创建日志文件失败: %v,使用全局日志", publishData.RequestID, err)
|
||
taskLogger = log.Default()
|
||
}
|
||
if logFile != nil {
|
||
defer logFile.Close()
|
||
}
|
||
|
||
// 全局defer用于捕获panic并记录
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
errMsg := fmt.Sprintf("任务执行发生panic: %v", r)
|
||
taskLogger.Printf("❌ CRITICAL: %s", errMsg)
|
||
taskLogger.Printf(strings.Repeat("=", 80))
|
||
taskLogger.Printf("任务异常结束 | RequestID: %s | 时间: %s", publishData.RequestID, time.Now().Format("2006-01-02 15:04:05.000"))
|
||
taskLogger.Printf(strings.Repeat("=", 80))
|
||
|
||
result = &SingleResult{
|
||
Success: false,
|
||
Message: errMsg,
|
||
RequestId: publishData.RequestID,
|
||
}
|
||
return
|
||
}
|
||
}()
|
||
|
||
taskLogger.Printf("[任务 %s] 开始处理", publishData.RequestID)
|
||
|
||
// 提取任务参数
|
||
params, sourceUrl := pm.extractTaskParams(publishData, taskLogger)
|
||
if params == nil {
|
||
return &SingleResult{
|
||
Success: false,
|
||
Message: "提取任务参数失败",
|
||
RequestId: publishData.RequestID,
|
||
}
|
||
}
|
||
|
||
// 获取发布器
|
||
publisherClass := GetPublisherClass(params.PlatIndex)
|
||
if publisherClass == nil {
|
||
errMsg := fmt.Sprintf("不支持的平台: %s", params.PlatIndex)
|
||
taskLogger.Printf("[任务 %s] ❌ %s", publishData.RequestID, errMsg)
|
||
pm.updatePublishStatus(publishData.RequestID, StatusFailed, errMsg)
|
||
return &SingleResult{
|
||
Success: false,
|
||
Message: errMsg,
|
||
RequestId: publishData.RequestID,
|
||
}
|
||
}
|
||
|
||
// 更新状态为发布中
|
||
if err := pm.updatePublishStatus(publishData.RequestID, StatusProcessing, ""); err != nil {
|
||
taskLogger.Printf("[任务 %s] ❌ 更新状态失败: %v", publishData.RequestID, err)
|
||
}
|
||
|
||
// 下载并处理文档
|
||
params.SourcePath, params.ImagePath, err = pm.downloadAndPrepareFiles(params.RequestID, sourceUrl, taskLogger, publisherClass)
|
||
|
||
if err != nil {
|
||
errMsg := fmt.Sprintf("准备文件失败: %v", err)
|
||
taskLogger.Printf("[任务 %s] ❌ %s", publishData.RequestID, errMsg)
|
||
pm.updatePublishStatus(publishData.RequestID, StatusFailed, errMsg)
|
||
return &SingleResult{
|
||
Success: false,
|
||
Message: errMsg,
|
||
RequestId: publishData.RequestID,
|
||
}
|
||
}
|
||
|
||
// 确保文件清理
|
||
defer pm.cleanupFiles(params.SourcePath, params.ImagePath, taskLogger, publishData.RequestID)
|
||
|
||
// 提取内容
|
||
params.Content, err = pm.extractContent(params.SourcePath, publisherClass, taskLogger, publishData.RequestID)
|
||
if err != nil {
|
||
errMsg := fmt.Sprintf("提取文档内容失败: %v", err)
|
||
taskLogger.Printf("[任务 %s] ❌ %s", publishData.RequestID, errMsg)
|
||
pm.updatePublishStatus(publishData.RequestID, StatusFailed, errMsg)
|
||
return &SingleResult{
|
||
Success: false,
|
||
Message: errMsg,
|
||
RequestId: publishData.RequestID,
|
||
}
|
||
}
|
||
|
||
// 执行发布
|
||
taskLogger.Printf("[任务 %s] ✅ 内容提取成功,长度: %d", publishData.RequestID, len(params.Content))
|
||
taskLogger.Printf("[任务 %s] 创建发布器...", publishData.RequestID)
|
||
pub := publisherClass.InitMethod(params, pm.Conf, taskLogger)
|
||
taskLogger.Printf("[任务 %s] 创建%s发布器", publisherClass.Name, publishData.RequestID)
|
||
taskLogger.Printf("[任务 %s] 开始执行发布...", publishData.RequestID)
|
||
success, message := pub.PublishNote()
|
||
|
||
// 更新最终状态
|
||
if success {
|
||
taskLogger.Printf("[任务 %s] ✅ 发布成功: %s", publishData.RequestID, message)
|
||
pm.updatePublishStatus(publishData.RequestID, StatusSuccess, message)
|
||
} else {
|
||
taskLogger.Printf("[任务 %s] ❌ 发布失败: %s", publishData.RequestID, message)
|
||
pm.updatePublishStatus(publishData.RequestID, StatusFailed, message)
|
||
}
|
||
|
||
// 记录任务结束
|
||
taskLogger.Printf(strings.Repeat("=", 80))
|
||
taskLogger.Printf("任务结束 | RequestID: %s | 结果: %v | 时间: %s", publishData.RequestID, success, time.Now().Format("2006-01-02 15:04:05.000"))
|
||
taskLogger.Printf(strings.Repeat("=", 80))
|
||
|
||
return &SingleResult{
|
||
Success: success,
|
||
Message: message,
|
||
RequestId: publishData.RequestID,
|
||
}
|
||
}
|
||
|
||
// fileUrl holds the remote locations of a task's files. extractTaskParams
// builds it and downloadAndPrepareFiles consumes it.
type fileUrl struct {
	url    string // document URL (from PublishTaskDetail.URL), fetched with pkg.DownloadFile
	imgURL string // image URL (from PublishTaskDetail.Img), fetched with pkg.DownloadImage
}
|
||
|
||
func (pm *PublishManager) extractTaskParams(publishData *entitys.PublishTaskDetail, taskLogger *log.Logger) (*publisher.TaskParams, *fileUrl) {
|
||
|
||
taskLogger.Printf("[任务 %s] 任务详情 - 平台:%s,标题:%s,用户:%s", publishData.RequestID, publishData.PlatIndex, publishData.Title, publishData.UserIndex)
|
||
|
||
// 解析标签
|
||
tags := pkg.ParseTags(publishData.Tag)
|
||
taskLogger.Printf("[任务 %s] 标签解析完成: %v", publishData.RequestID, tags)
|
||
|
||
return &publisher.TaskParams{
|
||
RequestID: publishData.RequestID,
|
||
PlatIndex: publishData.PlatIndex,
|
||
Title: publishData.Title,
|
||
TagRaw: publishData.Tag,
|
||
UserIndex: publishData.UserIndex,
|
||
Tags: tags,
|
||
PublishData: publishData,
|
||
}, &fileUrl{
|
||
url: publishData.URL,
|
||
imgURL: publishData.Img,
|
||
}
|
||
}
|
||
|
||
func (pm *PublishManager) downloadAndPrepareFiles(requestId string, params *fileUrl, taskLogger *log.Logger, publishClass *publisher.PublisherValue) (docPath, imgPath string, err error) {
|
||
// 下载文档
|
||
taskLogger.Printf("[任务 %s] 开始下载文档...", requestId)
|
||
docPath, err = pkg.DownloadFile(params.url, pm.Conf.Sys.DocsDir, requestId)
|
||
if err != nil {
|
||
return "", "", fmt.Errorf("下载文档失败: %v", err)
|
||
}
|
||
taskLogger.Printf("[任务 %s] ✅ 文档下载成功: %s", requestId, docPath)
|
||
|
||
// 下载图片
|
||
taskLogger.Printf("[任务 %s] 开始下载图片...", requestId)
|
||
imgPath, err = pkg.DownloadImage(params.imgURL, requestId, pm.Conf.Sys.UploadDir)
|
||
if err != nil {
|
||
// 如果图片下载失败,清理已下载的文档
|
||
if docPath != "" {
|
||
pkg.DeleteFile(docPath)
|
||
}
|
||
return "", "", fmt.Errorf("下载图片失败: %v", err)
|
||
}
|
||
taskLogger.Printf("[任务 %s] ✅ 图片下载成功: %s", requestId, imgPath)
|
||
|
||
if publishClass.Type == 1 && publishClass.WordContainImg {
|
||
// 如果文档中包含图片,则将图片复制到文档目录
|
||
if err = pkg.CopyImageToDoc(docPath, imgPath); err != nil {
|
||
return docPath, imgPath, fmt.Errorf("复制图片到文档失败: %v", err)
|
||
}
|
||
}
|
||
|
||
return docPath, imgPath, nil
|
||
}
|
||
|
||
// cleanupFiles 清理临时文件
|
||
func (pm *PublishManager) cleanupFiles(docPath, imgPath string, taskLogger *log.Logger, requestID string) {
|
||
if docPath != "" {
|
||
pkg.DeleteFile(docPath)
|
||
taskLogger.Printf("[任务 %s] 已删除文档文件: %s", requestID, docPath)
|
||
}
|
||
if imgPath != "" {
|
||
pkg.DeleteFile(imgPath)
|
||
taskLogger.Printf("[任务 %s] 已删除图片文件: %s", requestID, imgPath)
|
||
}
|
||
}
|
||
|
||
// extractContent 提取文档内容
|
||
func (pm *PublishManager) extractContent(docPath string, publisherClass *publisher.PublisherValue, taskLogger *log.Logger, requestID string) (string, error) {
|
||
if publisherClass.Type != 1 {
|
||
return "", nil // 不需要提取内容
|
||
}
|
||
|
||
taskLogger.Printf("[任务 %s] 开始提取文档内容...", requestID)
|
||
content, err := pkg.ExtractWordContent(docPath, publisherClass.ContentFormat)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
taskLogger.Printf("[任务 %s] ✅ 内容提取成功,长度: %d", requestID, len(content))
|
||
return content, nil
|
||
}
|
||
|
||
// updatePublishStatus 更新发布状态(优化:添加错误处理)
|
||
func (pm *PublishManager) updatePublishStatus(requestID string, status int, message string) error {
|
||
if requestID == "" {
|
||
return fmt.Errorf("requestID不能为空")
|
||
}
|
||
|
||
var err error
|
||
if message != "" {
|
||
_, err = pm.db.Execute("UPDATE publish SET status = ?, msg = ? WHERE request_id = ?", status, message, requestID)
|
||
} else {
|
||
_, err = pm.db.Execute("UPDATE publish SET status = ? WHERE request_id = ?", status, requestID)
|
||
}
|
||
|
||
if err != nil {
|
||
log.Printf("更新发布状态失败: requestID=%s, status=%d, error=%v", requestID, status, err)
|
||
}
|
||
return err
|
||
}
|
||
|
||
// RetryTask 重试任务(优化:添加任务状态检查)
|
||
func (pm *PublishManager) RetryTask(requestID string) *SingleResult {
|
||
if requestID == "" {
|
||
return &SingleResult{
|
||
Success: false,
|
||
Message: "requestID不能为空",
|
||
}
|
||
}
|
||
|
||
publishData, err := pm.GetTaskByRequestID(requestID)
|
||
if err != nil || publishData == nil {
|
||
|
||
return &SingleResult{
|
||
Success: false,
|
||
Message: "任务不存在",
|
||
}
|
||
}
|
||
|
||
// 只允许重试失败的任务
|
||
//
|
||
//if publishData.Status != StatusFailed {
|
||
// return &SingleResult{
|
||
// Success: false,
|
||
// Message: fmt.Sprintf("只能重试失败的任务,当前状态: %d", publishData.Status),
|
||
// }
|
||
//}
|
||
|
||
return pm.processSingleTask(publishData)
|
||
}
|
||
|
||
// GetStatus 获取状态(优化:使用读锁)
|
||
func (pm *PublishManager) GetStatus() map[string]interface{} {
|
||
pm.mu.RLock()
|
||
defer pm.mu.RUnlock()
|
||
|
||
return map[string]interface{}{
|
||
"auto_status": pm.AutoStatus,
|
||
"max_concurrent": pm.Conf.Sys.MaxConcurrent,
|
||
"task_timeout": pm.Conf.Sys.TaskTimeout,
|
||
"token_id": pm.TokenID,
|
||
}
|
||
}
|
||
|
||
// GetPublisherClass 获取发布器类(优化:添加默认值处理)
|
||
func GetPublisherClass(platIndex string) *publisher.PublisherValue {
|
||
if platIndex == "" {
|
||
return nil
|
||
}
|
||
|
||
if publisherClass, exists := publisher.PublisherMap[platIndex]; exists {
|
||
return publisherClass
|
||
}
|
||
|
||
log.Printf("未找到平台 %s 对应的发布器", platIndex)
|
||
return nil
|
||
}
|