This commit is contained in:
renzhiyuan 2026-04-23 02:58:25 +08:00
parent 8e77619a37
commit 5fb8e3d2ee
17 changed files with 496 additions and 31 deletions

View File

@ -40,10 +40,11 @@ func InitializeApp(configConfig *config.Config, allLogger log.AllLogger) (*serve
authBiz := biz.NewAuthBiz(configConfig, tokenImpl, userImpl) authBiz := biz.NewAuthBiz(configConfig, tokenImpl, userImpl)
appService := service.NewAppService(configConfig, tokenImpl, userImpl, platImpl, publishBiz, authBiz, loginRelationImpl, productBiz) appService := service.NewAppService(configConfig, tokenImpl, userImpl, platImpl, publishBiz, authBiz, loginRelationImpl, productBiz)
loginService := service.NewLoginService(configConfig, publishBiz, authBiz) loginService := service.NewLoginService(configConfig, publishBiz, authBiz)
publishService := service.NewPublishService(configConfig, publishBiz, authBiz, db)
productService := service.NewProductService(configConfig, productImpl, authBiz, productBiz)
aiBiz := biz.NewAiBiz(platImpl, articleTypeImpl) aiBiz := biz.NewAiBiz(platImpl, articleTypeImpl)
publishService := service.NewPublishService(configConfig, publishBiz, authBiz, db, aiBiz)
productService := service.NewProductService(configConfig, productImpl, authBiz, productBiz, aiBiz)
productSourceService := service.NewProductSourceService(configConfig, productImpl, authBiz, aiBiz, productBiz, productSourceImpl, publishBiz, articleTypeImpl) productSourceService := service.NewProductSourceService(configConfig, productImpl, authBiz, aiBiz, productBiz, productSourceImpl, publishBiz, articleTypeImpl)
appModule := router.NewAppModule(configConfig, appService, loginService, publishService, productService, productSourceService) appModule := router.NewAppModule(configConfig, appService, loginService, publishService, productService, productSourceService)
routerServer := router.NewRouterServer(appModule) routerServer := router.NewRouterServer(appModule)

120
internal/ai_tool/collect.go Normal file
View File

@ -0,0 +1,120 @@
package ai_tool
import "geo/pkg"
type Collect struct {
apikey string
}
func NewCollect(apikey string) *Collect {
return &Collect{apikey: apikey}
}
type PlatForm struct {
Index int `json:"index"`
PlatFormName string `json:"name"`
Icon string `json:"icon"`
}
var PlatFormList = []PlatForm{
{
Index: 1,
PlatFormName: "deepseek",
Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/deepseek.png",
},
{
Index: 2,
PlatFormName: "豆包",
Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/doubao.png",
},
{
Index: 3,
PlatFormName: "元宝",
Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/yuanbao.png",
},
{
Index: 41,
PlatFormName: "千问",
Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/qianwen.png",
},
{
Index: 5,
PlatFormName: "文心一言",
Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/wenxin.png",
},
//{
// Index: 6,
// PlatFormName: "纳米",
//},
//{
// Index: 7,
// PlatFormName: "kimi",
//}, {
// Index: 8,
// PlatFormName: "智普",
//},
}
type CreateReq struct {
// apikey
APIKey string `json:"api_key"`
// 品牌词,多个用英文逗号隔开
Keywords string `json:"keywords"`
// 平台1-deepseek2-豆包3-元宝4-千问5-文心一言6-纳米7-kimi8-智普
Platform int64 `json:"platform"`
// 问题
Question string `json:"question"`
// 建议填第三方的用户id。方便查单
ThirdID *string `json:"third_id,omitempty"`
}
type CreateRes struct {
Code int `json:"code"`
Msg string `json:"msg"`
Time string `json:"time"`
Data struct {
RequestId string `json:"request_id"`
} `json:"data"`
}
func (s *Collect) Create(data *CreateReq) (*CreateRes, error) {
url := "http://8.138.187.158:8082/api/geo/add_shoulu"
data.APIKey = s.apikey
mapInfo, err := pkg.StructToMap(data)
if err != nil {
return nil, err
}
var res CreateRes
err = pkg.PostMultipart(url, mapInfo, &res)
return &res, err
}
type CheckTaskRes struct {
Code int `json:"code"`
Msg string `json:"msg"`
Time string `json:"time"`
Data struct {
ShouluDate string `json:"shoulu_date"`
Platform int `json:"platform"`
ScriptTime int `json:"script_time"`
Question string `json:"question"`
HitWord string `json:"hit_word"`
Status int `json:"status"`
ShareUrl string `json:"share_url"`
KeywordRes string `json:"keyword_res"`
ImgUrl string `json:"img_url"`
RequestId string `json:"request_id"`
} `json:"data"`
}
func (s *Collect) CheckTask(requestId string) (*CheckTaskRes, error) {
url := "http://8.138.187.158:8082/api/geo/check_task"
request := map[string]interface{}{
"request_id": requestId,
"api_key": s.apikey,
}
var res CheckTaskRes
err := pkg.PostMultipart(url, request, &res)
return &res, err
}

View File

@ -2,6 +2,9 @@ package ai_tool
import ( import (
"context" "context"
"encoding/json"
"geo/pkg"
"geo/tmpl/errcode"
"sync" "sync"
"time" "time"
@ -166,3 +169,23 @@ func (h *Hsyq) RequestHsyqBot(ctx context.Context, key string, botId string, mes
log.Info("token用量", resp.Usage.TotalTokens) log.Info("token用量", resp.Usage.TotalTokens)
return resp.Choices[0].Message.Content.StringValue, nil return resp.Choices[0].Message.Content.StringValue, nil
} }
func (h *Hsyq) RequestHsyqBotToJson(ctx context.Context, key string, botId string, message []*model.ChatCompletionMessage, point interface{}) error {
content, err := h.RequestHsyqBot(ctx, key, botId, message)
if err != nil {
return err
}
contentByte := []byte(*content)
if err := json.Unmarshal(contentByte, &point); err != nil {
contentStr, err := pkg.JsonRepair(*content)
if err != nil {
return errcode.SysErr("生成失败,请重试")
}
if err := json.Unmarshal([]byte(contentStr), point); err != nil {
return errcode.SysErr("生成失败,请重试")
}
}
return nil
}

View File

@ -31,7 +31,7 @@ func (a *AiBiz) CreateArticlePrompt(ctx context.Context, data *entitys.BotChat)
}, },
}, },
} }
return mes
//var plats []*model.Plat //var plats []*model.Plat
//cond := builder.NewCond(). //cond := builder.NewCond().
// And(builder.Eq{"plat_type": 1}). // And(builder.Eq{"plat_type": 1}).
@ -85,3 +85,15 @@ func (a *AiBiz) CreateArticlePrompt(ctx context.Context, data *entitys.BotChat)
//} //}
return mes return mes
} }
func (a *AiBiz) CreateProjectInfoPrompt(ctx context.Context, content string) []*volmodle.ChatCompletionMessage {
mes := []*volmodle.ChatCompletionMessage{
{
Role: volmodle.ChatMessageRoleUser,
Content: &volmodle.ChatCompletionMessageContent{
StringValue: volcengine.String(pkg.JsonStringIgonErr(content)),
},
},
}
return mes
}

View File

@ -12,6 +12,17 @@ type Config struct {
Sys Sys `mapstructure:"sys"` Sys Sys `mapstructure:"sys"`
Hsyq Hsyq `mapstructure:"hsyq"` Hsyq Hsyq `mapstructure:"hsyq"`
Oss Oss `mapstructure:"oss"` Oss Oss `mapstructure:"oss"`
AiBot AiBot `mapstructure:"ai_bot"`
collect Collect `mapstructure:"collect"`
}
type Collect struct {
ApiKey string `mapstructure:"api_key"`
}
type AiBot struct {
Article string `mapstructure:"article"`
ProductInfo string `mapstructure:"product_info"`
} }
type Oss struct { type Oss struct {
@ -69,10 +80,10 @@ func LoadConfig() (*Config, error) {
}, //root:lansexiongdi6,@tcp(47.97.27.195:3306)/geo?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai }, //root:lansexiongdi6,@tcp(47.97.27.195:3306)/geo?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai
DB: DB{ DB: DB{
Driver: "mysql", Driver: "mysql",
Source: "root:Renzhiyuan123.@tcp(8.137.19.81:3306)/geo?charset=utf8mb4&parseTime=true&loc=Local", Source: "root:Renzhiyuan123.@tcp(8.137.19.81:3306)/geo?charset=utf8mb4&parseTime=true",
}, },
Sys: Sys{ Sys: Sys{
AutoPublishWorkers: 5, AutoPublishWorkers: 2,
MaxConcurrent: 1, MaxConcurrent: 1,
TaskTimeout: 200, TaskTimeout: 200,
SessionTimeout: 300, SessionTimeout: 300,
@ -98,5 +109,12 @@ func LoadConfig() (*Config, error) {
Endpoint: "https://oss-cn-hangzhou.aliyuncs.com", Endpoint: "https://oss-cn-hangzhou.aliyuncs.com",
FilePath: "geo/", FilePath: "geo/",
}, },
AiBot: AiBot{
Article: "bot-20260413000114-8bw62",
ProductInfo: "bot-20260422010906-hvtbd",
},
Collect: Collect{
ApiKey: "sk_7bac5df901aa8933a238fcfec363f4a0",
},
}, nil }, nil
} }

View File

@ -0,0 +1,16 @@
package entitys
type ProductInfo struct {
Name string `json:"name" validate:"required" zh:"产品名称"`
Industry string `json:"industry" validate:"required" zh:"所属行业"`
Type string `json:"type" validate:"required" zh:"产品类型"`
ProductOrService string `json:"product_or_service" validate:"required" zh:"主营业务"`
Advantages string `json:"advantages" zh:"核心优势"`
Story string `json:"story" zh:"发展故事"`
Problem string `json:"problem" zh:"解决痛点"`
Background string `json:"background" zh:"信任背书"`
Case string `json:"case" zh:"品牌案例"`
Other string `json:"other" zh:"其他信息"`
ServiceScope string `json:"service_scope" zh:"服务范围"`
TargetAudience string `json:"target_audience" zh:"目标客户群体"`
}

View File

@ -194,4 +194,11 @@ type (
Plat []string `json:"plat" validate:"required" zh:"平台"` Plat []string `json:"plat" validate:"required" zh:"平台"`
PublishTime string `json:"publish_time" validate:"required" zh:"发布时间"` PublishTime string `json:"publish_time" validate:"required" zh:"发布时间"`
} }
ProductCollectRequest struct {
AccessToken string `json:"access_token" validate:"required" zh:"access_token"`
Keywords []string `json:"keywords" validate:"required" zh:"关键词"`
Platform []int64 `json:"platform" validate:"required" zh:"平台"`
Question string `json:"question" validate:"required" zh:"问题"`
}
) )

View File

@ -370,9 +370,8 @@ func (pm *PublishManager) processTask(ctx context.Context, publishData *entitys.
taskLogger.Printf("任务结束 | RequestID: %s | 结果: %v", publishData.RequestID, success) taskLogger.Printf("任务结束 | RequestID: %s | 结果: %v", publishData.RequestID, success)
res := &SingleResult{Success: success, Message: message, RequestId: publishData.RequestID} res := &SingleResult{Success: success, Message: message, RequestId: publishData.RequestID}
pm.uploadToOss(ctx, logPath, fmt.Sprintf("%slog/%d_%s.log", pm.Conf.Oss.FilePath, publishData.ID, publishData.RequestID))
url, err := pm.uploadToOss(ctx, logPath, fmt.Sprintf("%s%d_%s.log", pm.Conf.Oss.FilePath, publishData.ID, publishData.RequestID)) url, err := pm.uploadToOss(ctx, logPath, fmt.Sprintf("%slog/%d_%s.log", pm.Conf.Oss.FilePath, publishData.ID, publishData.RequestID))
if err != nil { if err != nil {
taskLogger.Printf("日志上传失败") taskLogger.Printf("日志上传失败")
} }

View File

@ -33,7 +33,7 @@ var PublisherMap = map[string]*PublisherValue{
ContentFormat: "text", ContentFormat: "text",
ImgNeed: 3, ImgNeed: 3,
Type: 1, Type: 1,
WordContainImg: false, WordContainImg: true,
}, },
"bjh": { "bjh": {
Name: "百家号", Name: "百家号",

View File

@ -407,13 +407,13 @@ func (p *XiaohongshuPublisher) clickPublish() error {
func (p *XiaohongshuPublisher) waitForPublishResult() (bool, string) { func (p *XiaohongshuPublisher) waitForPublishResult() (bool, string) {
p.LogInfo("等待发布结果...") p.LogInfo("等待发布结果...")
// 检查URL是否包含success // 检查URL是否包含success
for attempt := 0; attempt < 30; attempt++ {
info, err := p.Page.Info() info, err := p.Page.Info()
if err == nil && strings.Contains(info.URL, "success") { if err == nil && strings.Contains(info.URL, "success") {
p.LogInfo(fmt.Sprintf("发布成功URL包含success: %s", info.URL)) p.LogInfo(fmt.Sprintf("发布成功URL包含success: %s", info.URL))
return true, "发布成功" return true, "发布成功"
} }
for attempt := 0; attempt < 30; attempt++ {
// 检查是否出现失败提示 // 检查是否出现失败提示
exist, toastDiv, err := p.Page.Has(".creator-publish-toast") exist, toastDiv, err := p.Page.Has(".creator-publish-toast")
if err == nil && exist { if err == nil && exist {

View File

@ -320,29 +320,29 @@ func (p *ZhihuPublisher) clickPublish() error {
} }
p.LogInfo("已点击发布按钮") p.LogInfo("已点击发布按钮")
p.Sleep(3)
return nil return nil
} }
func (p *ZhihuPublisher) waitForPublishResult(timeout int) (bool, string) { func (p *ZhihuPublisher) waitForPublishResult(timeout int) (bool, string) {
p.LogInfo("等待发布结果...") p.LogInfo("等待发布结果...")
p.SleepMs(500)
startTime := time.Now() startTime := time.Now()
for time.Since(startTime) < time.Duration(timeout)*time.Second { for time.Since(startTime) < time.Duration(timeout)*time.Second {
currentURL := p.GetCurrentURL() currentURL := p.GetCurrentURL()
if !strings.Contains(currentURL, "/edit") {
p.LogInfo(fmt.Sprintf("发布成功URL: %s", currentURL))
return true, "发布成功"
}
// 检查失败弹窗 // 检查失败弹窗
exist, failedDiv, _ := p.Page.HasX(".Notification-textSection") exist, failedDiv, _ := p.Page.Has(".Notification-textSection")
if exist { if exist {
failedReason, _ := failedDiv.Text() failedReason, _ := failedDiv.Text()
p.LogInfo(fmt.Sprintf("发布失败: %s", failedReason)) p.LogInfo(fmt.Sprintf("发布失败: %s", failedReason))
return false, failedReason return false, failedReason
} }
if !strings.Contains(currentURL, "/edit") {
p.LogInfo(fmt.Sprintf("发布成功URL: %s", currentURL))
return true, "发布成功"
}
p.SleepMs(1000) p.SleepMs(1000)
} }

View File

@ -47,12 +47,14 @@ func (m *AppModule) Register(router fiber.Router) {
router.Post("/publish_off", vali(m.publishService.PublishOff, &entitys.PublishOffRequest{})) router.Post("/publish_off", vali(m.publishService.PublishOff, &entitys.PublishOffRequest{}))
router.Post("/publish_status", vali(m.publishService.PublishStatus, &entitys.PublishStatusRequest{})) router.Post("/publish_status", vali(m.publishService.PublishStatus, &entitys.PublishStatusRequest{}))
router.Post("/publish_execute_retry", vali(m.publishService.PublishExecuteRetry, &entitys.PublishExecuteRetryRequest{})) router.Post("/publish_execute_retry", vali(m.publishService.PublishExecuteRetry, &entitys.PublishExecuteRetryRequest{}))
router.Post("/publish_execute_retry/local", vali(m.publishService.PublishExecuteRetryLocal, &entitys.PublishExecuteRetryRequest{}))
router.Post("/get_publish_list", vali(m.publishService.GetPublishList, &entitys.GetPublishListRequest{})) router.Post("/get_publish_list", vali(m.publishService.GetPublishList, &entitys.GetPublishListRequest{}))
router.Post("/login_platform", vali(m.loginService.LoginPlatform, &entitys.LoginPlatformRequest{})) router.Post("/login_platform", vali(m.loginService.LoginPlatform, &entitys.LoginPlatformRequest{}))
router.Post("/logout_platform", vali(m.loginService.LogoutPlatform, &entitys.LogoutPlatformRequest{})) router.Post("/logout_platform", vali(m.loginService.LogoutPlatform, &entitys.LogoutPlatformRequest{}))
router.Get("/logs/:publish_id/:request_id", m.loginService.Log) router.Get("/logs/:publish_id/:request_id", m.loginService.Log)
router.Post("/product/add", vali(m.productService.Add, &entitys.CreateProductRequest{})) router.Post("/product/add", vali(m.productService.Add, &entitys.CreateProductRequest{}))
router.Post("/product/info/create/docx", m.productService.CreateProductInfoByDocx)
router.Post("/product/list", vali(m.productService.List, &entitys.ProductListRequest{})) router.Post("/product/list", vali(m.productService.List, &entitys.ProductListRequest{}))
router.Post("/product/detail", vali(m.productService.Detail, &entitys.ProductDetailRequest{})) router.Post("/product/detail", vali(m.productService.Detail, &entitys.ProductDetailRequest{}))
router.Post("/product/update", vali(m.productService.Update, &entitys.ProductUpdateRequest{})) router.Post("/product/update", vali(m.productService.Update, &entitys.ProductUpdateRequest{}))

View File

@ -57,10 +57,12 @@ func (a *AppService) LoginApp(c *fiber.Ctx, req *entitys.LoginAppRequest) error
And(builder.Eq{"status": 1}) And(builder.Eq{"status": 1})
tokenInfo := &model.Token{} tokenInfo := &model.Token{}
err := a.tokenImpl.GetOneBySearchStruct(c.UserContext(), &cond, tokenInfo) err := a.tokenImpl.GetOneBySearchStruct(c.UserContext(), &cond, tokenInfo)
if err != nil || tokenInfo == nil { if err != nil {
return errcode.Forbidden("密钥无效") return errcode.Forbidden("密钥无效")
} }
if tokenInfo.ID == 0 {
return errcode.NotFound("密钥无效")
}
accessToken := pkg.GenerateUUID() accessToken := pkg.GenerateUUID()
err = a.tokenImpl.UpdateByKey(c.UserContext(), a.tokenImpl.PrimaryKey(), tokenInfo.ID, &model.Token{ err = a.tokenImpl.UpdateByKey(c.UserContext(), a.tokenImpl.PrimaryKey(), tokenInfo.ID, &model.Token{
AccessToken: accessToken, AccessToken: accessToken,

View File

@ -1,6 +1,7 @@
package service package service
import ( import (
"geo/internal/ai_tool"
"geo/internal/biz" "geo/internal/biz"
"geo/internal/config" "geo/internal/config"
"geo/internal/data/impl" "geo/internal/data/impl"
@ -12,6 +13,8 @@ import (
"github.com/go-viper/mapstructure/v2" "github.com/go-viper/mapstructure/v2"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"io" "io"
"os"
"path/filepath"
"xorm.io/builder" "xorm.io/builder"
) )
@ -20,14 +23,16 @@ type ProductService struct {
productImpl *impl.ProductImpl productImpl *impl.ProductImpl
authBiz *biz.AuthBiz authBiz *biz.AuthBiz
productBiz *biz.ProductBiz productBiz *biz.ProductBiz
aiBiz *biz.AiBiz
} }
func NewProductService(cfg *config.Config, ProductImpl *impl.ProductImpl, authBiz *biz.AuthBiz, productBiz *biz.ProductBiz) *ProductService { func NewProductService(cfg *config.Config, ProductImpl *impl.ProductImpl, authBiz *biz.AuthBiz, productBiz *biz.ProductBiz, aiBiz *biz.AiBiz) *ProductService {
return &ProductService{ return &ProductService{
cfg: cfg, cfg: cfg,
productImpl: ProductImpl, productImpl: ProductImpl,
authBiz: authBiz, authBiz: authBiz,
productBiz: productBiz, productBiz: productBiz,
aiBiz: aiBiz,
} }
} }
@ -152,3 +157,77 @@ func (p *ProductService) ImgUpload(c *fiber.Ctx) error {
} }
return pkg.HandleResponse(c, fiber.Map{"url": url}) return pkg.HandleResponse(c, fiber.Map{"url": url})
} }
func (p *ProductService) CreateProductInfoByDocx(c *fiber.Ctx) error {
access := c.FormValue("access_token", "")
if access == "" {
return errcode.ParamErr("access_token未找到")
}
userIndex := c.FormValue("user_index", "")
if userIndex == "" {
return errcode.ParamErr("user_index未找到")
}
// 验证token
_, err := p.authBiz.ValidateAccessToken(c.UserContext(), access)
if err != nil {
return err
}
fileHeader, err := c.FormFile("file")
if err != nil {
return errcode.ParamErr("未找到上传文件")
}
file, err := fileHeader.Open()
if err != nil {
return errcode.ParamErrf("无法打开文件:%s", err.Error())
}
defer file.Close()
// 创建临时文件
tempFile, err := os.CreateTemp("", "upload_*.docx")
if err != nil {
return errcode.ParamErrf("创建临时文件失败:%s", err.Error())
}
defer tempFile.Close()
// 获取临时文件的绝对路径
absPath, err := filepath.Abs(tempFile.Name())
if err != nil {
return errcode.ParamErrf("获取文件绝对路径失败:%s", err.Error())
}
// 将上传的文件内容复制到临时文件
_, err = io.Copy(tempFile, file)
if err != nil {
return errcode.ParamErrf("保存文件失败:%s", err.Error())
}
// 确保文件内容已写入磁盘
err = tempFile.Sync()
if err != nil {
return errcode.ParamErrf("同步文件失败:%s", err.Error())
}
markContent, err := pkg.ExtractWordContent(absPath, "markdown")
if err != nil {
return err
}
var productInfo entitys.ProductInfo
mes := p.aiBiz.CreateProjectInfoPrompt(c.UserContext(), markContent)
err = ai_tool.NewHsyq().RequestHsyqBotToJson(c.UserContext(), p.cfg.Hsyq.ApiKey, p.cfg.AiBot.ProductInfo, mes, &productInfo)
if err != nil {
return err
}
return pkg.HandleResponse(c, productInfo)
}
func (p *ProductService) Collect(c *fiber.Ctx, req *entitys.ProductCollectRequest) error {
_, err := p.authBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
if err != nil {
return err
}
return pkg.HandleResponse(c, productInfo)
}

View File

@ -76,7 +76,7 @@ func (p *ProductSourceService) Create(c *fiber.Ctx, req *entitys.ProductSourceCr
BrandInfo: &brandInfo, BrandInfo: &brandInfo,
} }
mes := p.aiBiz.CreateArticlePrompt(c.UserContext(), BotChatMes) mes := p.aiBiz.CreateArticlePrompt(c.UserContext(), BotChatMes)
content, err := ai_tool.NewHsyq().RequestHsyqBot(c.UserContext(), p.cfg.Hsyq.ApiKey, "bot-20260413000114-8bw62", mes) content, err := ai_tool.NewHsyq().RequestHsyqBot(c.UserContext(), p.cfg.Hsyq.ApiKey, p.cfg.AiBot.Article, mes)
if err != nil { if err != nil {
return err return err
} }

View File

@ -25,6 +25,7 @@ type PublishService struct {
publishBiz *biz.PublishBiz publishBiz *biz.PublishBiz
db *utils.Db db *utils.Db
authBiz *biz.AuthBiz authBiz *biz.AuthBiz
aiBiz *biz.AiBiz
} }
func NewPublishService( func NewPublishService(
@ -32,12 +33,14 @@ func NewPublishService(
publishBiz *biz.PublishBiz, publishBiz *biz.PublishBiz,
authBiz *biz.AuthBiz, authBiz *biz.AuthBiz,
db *utils.Db, db *utils.Db,
aiBiz *biz.AiBiz,
) *PublishService { ) *PublishService {
return &PublishService{ return &PublishService{
cfg: cfg, cfg: cfg,
publishBiz: publishBiz, publishBiz: publishBiz,
db: db, db: db,
authBiz: authBiz, authBiz: authBiz,
aiBiz: aiBiz,
} }
} }
@ -150,6 +153,17 @@ func (s *PublishService) PublishStatus(c *fiber.Ctx, req *entitys.PublishStatusR
//} //}
func (s *PublishService) PublishExecuteRetry(c *fiber.Ctx, req *entitys.PublishExecuteRetryRequest) error { func (s *PublishService) PublishExecuteRetry(c *fiber.Ctx, req *entitys.PublishExecuteRetryRequest) error {
_, err := s.authBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
if err != nil {
return err
}
err = s.publishBiz.UpdatePublishStatus(c.UserContext(), req.RequestID, 1, "")
return err
}
func (s *PublishService) PublishExecuteRetryLocal(c *fiber.Ctx, req *entitys.PublishExecuteRetryRequest) error {
tokenInfo, err := s.authBiz.ValidateAccessToken(c.UserContext(), req.AccessToken) tokenInfo, err := s.authBiz.ValidateAccessToken(c.UserContext(), req.AccessToken)
if err != nil { if err != nil {
return err return err

172
pkg/Multi.go Normal file
View File

@ -0,0 +1,172 @@
package pkg
import (
"bytes"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"time"
)
// PostMultipart 发送 multipart/form-data 请求
// url: 请求地址
// data: map[string]interface{} 类型的数据,支持以下类型:
// - 普通字段string, int, float64, bool 等
// - 文件字段:*os.File, []byte, 或实现了 io.Reader 接口的类型,需配合文件名使用
// - 文件路径:使用 "file_path" 字段标记map[string]interface{}{"field_name": map[string]interface{}{"file_path": "/path/to/file"}}
// - 简化文件map[string]interface{}{"field_name": map[string]interface{}{"file": fileObj, "filename": "custom.png"}}
//
// result: 响应 JSON 将解析到此指针对象
func PostMultipart(url string, data map[string]interface{}, result interface{}) error {
// 创建缓冲区和 multipart writer
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
// 遍历所有字段
for fieldName, fieldValue := range data {
switch v := fieldValue.(type) {
case string:
// 普通字符串字段
if err := writer.WriteField(fieldName, v); err != nil {
writer.Close()
return fmt.Errorf("写入字段 %s 失败: %w", fieldName, err)
}
case int, int32, int64, float32, float64, bool:
// 基本类型转字符串
if err := writer.WriteField(fieldName, fmt.Sprintf("%v", v)); err != nil {
writer.Close()
return fmt.Errorf("写入字段 %s 失败: %w", fieldName, err)
}
case *os.File:
// *os.File 类型,自动获取文件名
if err := addFilePart(writer, fieldName, v.Name(), v); err != nil {
writer.Close()
return err
}
case []byte:
// 字节数组,需要提供文件名
return fmt.Errorf("字段 %s 为 []byte 类型,请使用 map[string]interface{} 格式提供文件名: {\"data\": 字节内容, \"filename\": \"文件名\"}", fieldName)
case map[string]interface{}:
// 复杂文件描述
if err := handleFileMap(writer, fieldName, v); err != nil {
writer.Close()
return err
}
default:
// 其他类型转为字符串
if err := writer.WriteField(fieldName, fmt.Sprintf("%v", v)); err != nil {
writer.Close()
return fmt.Errorf("写入字段 %s 失败: %w", fieldName, err)
}
}
}
// 关闭 writer 以写入结束边界
if err := writer.Close(); err != nil {
return fmt.Errorf("关闭 multipart writer 失败: %w", err)
}
// 创建请求
req, err := http.NewRequest("POST", url, body)
if err != nil {
return fmt.Errorf("创建请求失败: %w", err)
}
// 设置 Content-Type包含边界分隔符
req.Header.Set("Content-Type", writer.FormDataContentType())
// 发送请求
client := &http.Client{
Timeout: 30 * time.Second,
}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("请求失败: %w", err)
}
defer resp.Body.Close()
// 读取响应体
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("读取响应失败: %w", err)
}
// 检查 HTTP 状态码
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("HTTP 错误: %d, 响应: %s", resp.StatusCode, string(respBody))
}
// 解析 JSON 到 result 指针
if err := json.Unmarshal(respBody, result); err != nil {
return fmt.Errorf("JSON 解析失败: %w, 原始响应: %s", err, string(respBody))
}
return nil
}
// addFilePart 添加文件部分
func addFilePart(writer *multipart.Writer, fieldName, filename string, reader io.Reader) error {
part, err := writer.CreateFormFile(fieldName, filename)
if err != nil {
return fmt.Errorf("创建文件字段 %s 失败: %w", fieldName, err)
}
if _, err := io.Copy(part, reader); err != nil {
return fmt.Errorf("复制文件 %s 内容失败: %w", fieldName, err)
}
return nil
}
// handleFileMap 处理文件映射
func handleFileMap(writer *multipart.Writer, fieldName string, fileMap map[string]interface{}) error {
// 支持两种格式:
// 1. {"file_path": "/path/to/file"}
// 2. {"file": io.Reader, "filename": "custom_name.ext"}
// 格式1: 文件路径
if filePath, ok := fileMap["file_path"].(string); ok {
file, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("打开文件 %s 失败: %w", filePath, err)
}
defer file.Close()
return addFilePart(writer, fieldName, filepath.Base(filePath), file)
}
// 格式2: 直接提供文件和文件名
if fileReader, ok := fileMap["file"]; ok {
// 获取文件名
filename := "file"
if fn, ok := fileMap["filename"].(string); ok && fn != "" {
filename = fn
}
var reader io.Reader
switch r := fileReader.(type) {
case io.Reader:
reader = r
case []byte:
reader = bytes.NewReader(r)
case *os.File:
reader = r
if filename == "file" {
filename = filepath.Base(r.Name())
}
default:
return fmt.Errorf("文件字段 %s 不支持的类型: %T", fieldName, fileReader)
}
return addFilePart(writer, fieldName, filename, reader)
}
return fmt.Errorf("文件字段 %s 格式错误,需要 file_path 或 file+filename", fieldName)
}