From 5fb8e3d2ee4e877bc8a4f42e631743eb43e3ea19 Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Thu, 23 Apr 2026 02:58:25 +0800 Subject: [PATCH] 1 --- cmd/server/wire_gen.go | 7 +- internal/ai_tool/collect.go | 120 +++++++++++++++++++ internal/ai_tool/hsyq.go | 23 ++++ internal/biz/ai.go | 14 ++- internal/config/config.go | 32 ++++-- internal/entitys/product.go | 16 +++ internal/entitys/request.go | 7 ++ internal/manager/publish_manager.go | 3 +- internal/publisher/interface.go | 2 +- internal/publisher/xhs.go | 10 +- internal/publisher/zh.go | 16 +-- internal/server/router/app.go | 2 + internal/service/app.go | 6 +- internal/service/product.go | 81 ++++++++++++- internal/service/product_source.go | 2 +- internal/service/publish.go | 14 +++ pkg/Multi.go | 172 ++++++++++++++++++++++++++++ 17 files changed, 496 insertions(+), 31 deletions(-) create mode 100644 internal/ai_tool/collect.go create mode 100644 internal/entitys/product.go create mode 100644 pkg/Multi.go diff --git a/cmd/server/wire_gen.go b/cmd/server/wire_gen.go index 87350ec..a53134c 100644 --- a/cmd/server/wire_gen.go +++ b/cmd/server/wire_gen.go @@ -40,10 +40,11 @@ func InitializeApp(configConfig *config.Config, allLogger log.AllLogger) (*serve authBiz := biz.NewAuthBiz(configConfig, tokenImpl, userImpl) appService := service.NewAppService(configConfig, tokenImpl, userImpl, platImpl, publishBiz, authBiz, loginRelationImpl, productBiz) loginService := service.NewLoginService(configConfig, publishBiz, authBiz) - publishService := service.NewPublishService(configConfig, publishBiz, authBiz, db) - - productService := service.NewProductService(configConfig, productImpl, authBiz, productBiz) aiBiz := biz.NewAiBiz(platImpl, articleTypeImpl) + publishService := service.NewPublishService(configConfig, publishBiz, authBiz, db, aiBiz) + + productService := service.NewProductService(configConfig, productImpl, authBiz, productBiz, aiBiz) + productSourceService := service.NewProductSourceService(configConfig, productImpl, authBiz, aiBiz, productBiz, productSourceImpl, publishBiz, articleTypeImpl) appModule := router.NewAppModule(configConfig, appService, loginService, publishService, productService, productSourceService) routerServer := router.NewRouterServer(appModule) diff --git a/internal/ai_tool/collect.go b/internal/ai_tool/collect.go new file mode 100644 index 0000000..4116cd8 --- /dev/null +++ b/internal/ai_tool/collect.go @@ -0,0 +1,120 @@ +package ai_tool + +import "geo/pkg" + +type Collect struct { + apikey string +} + +func NewCollect(apikey string) *Collect { + return &Collect{apikey: apikey} +} + +type PlatForm struct { + Index int `json:"index"` + PlatFormName string `json:"name"` + Icon string `json:"icon"` +} + +var PlatFormList = []PlatForm{ + { + Index: 1, + PlatFormName: "deepseek", + Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/deepseek.png", + }, + { + Index: 2, + PlatFormName: "豆包", + Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/doubao.png", + }, + { + Index: 3, + PlatFormName: "元宝", + Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/yuanbao.png", + }, + { + Index: 41, + PlatFormName: "千问", + Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/qianwen.png", + }, + { + Index: 5, + PlatFormName: "文心一言", + Icon: "https://attachment-public.oss-cn-hangzhou.aliyuncs.com/geo/platform/wenxin.png", + }, + //{ + // Index: 6, + // PlatFormName: "纳米", + //}, + //{ + // Index: 7, + // PlatFormName: "kimi", + //}, { + // Index: 8, + // PlatFormName: "智普", + //}, +} + +type CreateReq struct { + // apikey + APIKey string `json:"api_key"` + // 品牌词,多个用英文逗号隔开 + Keywords string `json:"keywords"` + // 平台,1-deepseek,2-豆包,3-元宝,4-千问,5-文心一言,6-纳米,7-kimi,8-智普 + Platform int64 `json:"platform"` + // 问题 + Question string `json:"question"` + // 建议填第三方的用户id。方便查单 + ThirdID *string `json:"third_id,omitempty"` +} + +type CreateRes struct { + Code int `json:"code"` + Msg string `json:"msg"` + Time string `json:"time"` + Data struct { + RequestId string `json:"request_id"` + } `json:"data"` +} + +func (s *Collect) Create(data *CreateReq) (*CreateRes, error) { + url := "http://8.138.187.158:8082/api/geo/add_shoulu" + data.APIKey = s.apikey + mapInfo, err := pkg.StructToMap(data) + if err != nil { + return nil, err + } + var res CreateRes + err = pkg.PostMultipart(url, mapInfo, &res) + return &res, err +} + +type CheckTaskRes struct { + Code int `json:"code"` + Msg string `json:"msg"` + Time string `json:"time"` + Data struct { + ShouluDate string `json:"shoulu_date"` + Platform int `json:"platform"` + ScriptTime int `json:"script_time"` + Question string `json:"question"` + HitWord string `json:"hit_word"` + Status int `json:"status"` + ShareUrl string `json:"share_url"` + KeywordRes string `json:"keyword_res"` + ImgUrl string `json:"img_url"` + RequestId string `json:"request_id"` + } `json:"data"` +} + +func (s *Collect) CheckTask(requestId string) (*CheckTaskRes, error) { + url := "http://8.138.187.158:8082/api/geo/check_task" + request := map[string]interface{}{ + "request_id": requestId, + "api_key": s.apikey, + } + + var res CheckTaskRes + err := pkg.PostMultipart(url, request, &res) + return &res, err +} diff --git a/internal/ai_tool/hsyq.go b/internal/ai_tool/hsyq.go index fa2ed15..457b8e1 100644 --- a/internal/ai_tool/hsyq.go +++ b/internal/ai_tool/hsyq.go @@ -2,6 +2,9 @@ package ai_tool import ( "context" + "encoding/json" + "geo/pkg" + "geo/tmpl/errcode" "sync" "time" @@ -166,3 +169,23 @@ func (h *Hsyq) RequestHsyqBot(ctx context.Context, key string, botId string, mes log.Info("token用量:", resp.Usage.TotalTokens) return resp.Choices[0].Message.Content.StringValue, nil } + +func (h *Hsyq) RequestHsyqBotToJson(ctx context.Context, key string, botId string, message []*model.ChatCompletionMessage, point interface{}) error { + + content, err := h.RequestHsyqBot(ctx, key, botId, message) + if err != nil { + return err + } + contentByte := []byte(*content) + + if err := json.Unmarshal(contentByte, &point); err != nil { + contentStr, err := pkg.JsonRepair(*content) + if err != nil { + return errcode.SysErr("生成失败,请重试") + } + if err := json.Unmarshal([]byte(contentStr), point); err != nil { + return errcode.SysErr("生成失败,请重试") + } + } + return nil +} diff --git a/internal/biz/ai.go b/internal/biz/ai.go index 2c34dda..2e425e7 100644 --- a/internal/biz/ai.go +++ b/internal/biz/ai.go @@ -31,7 +31,7 @@ func (a *AiBiz) CreateArticlePrompt(ctx context.Context, data *entitys.BotChat) }, }, } - return mes + //var plats []*model.Plat //cond := builder.NewCond(). // And(builder.Eq{"plat_type": 1}). @@ -85,3 +85,15 @@ func (a *AiBiz) CreateArticlePrompt(ctx context.Context, data *entitys.BotChat) //} return mes } + +func (a *AiBiz) CreateProjectInfoPrompt(ctx context.Context, content string) []*volmodle.ChatCompletionMessage { + mes := []*volmodle.ChatCompletionMessage{ + { + Role: volmodle.ChatMessageRoleUser, + Content: &volmodle.ChatCompletionMessageContent{ + StringValue: volcengine.String(pkg.JsonStringIgonErr(content)), + }, + }, + } + return mes +} diff --git a/internal/config/config.go b/internal/config/config.go index 59bf931..4102335 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -7,11 +7,22 @@ import ( // Config 应用配置 type Config struct { - Server ServerConfig `mapstructure:"server"` - DB DB `mapstructure:"db"` - Sys Sys `mapstructure:"sys"` - Hsyq Hsyq `mapstructure:"hsyq"` - Oss Oss `mapstructure:"oss"` + Server ServerConfig `mapstructure:"server"` + DB DB `mapstructure:"db"` + Sys Sys `mapstructure:"sys"` + Hsyq Hsyq `mapstructure:"hsyq"` + Oss Oss `mapstructure:"oss"` + AiBot AiBot `mapstructure:"ai_bot"` + collect Collect `mapstructure:"collect"` +} + +type Collect struct { + ApiKey string `mapstructure:"api_key"` +} + +type AiBot struct { + Article string `mapstructure:"article"` + ProductInfo string `mapstructure:"product_info"` } type Oss struct { @@ -69,10 +80,10 @@ func LoadConfig() (*Config, error) { }, //root:lansexiongdi6,@tcp(47.97.27.195:3306)/geo?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai DB: DB{ Driver: "mysql", - Source: "root:Renzhiyuan123.@tcp(8.137.19.81:3306)/geo?charset=utf8mb4&parseTime=true&loc=Local", + Source: "root:Renzhiyuan123.@tcp(8.137.19.81:3306)/geo?charset=utf8mb4&parseTime=true", }, Sys: Sys{ - AutoPublishWorkers: 5, + AutoPublishWorkers: 2, MaxConcurrent: 1, TaskTimeout: 200, SessionTimeout: 300, @@ -98,5 +109,12 @@ func LoadConfig() (*Config, error) { Endpoint: "https://oss-cn-hangzhou.aliyuncs.com", FilePath: "geo/", }, + AiBot: AiBot{ + Article: "bot-20260413000114-8bw62", + ProductInfo: "bot-20260422010906-hvtbd", + }, + Collect: Collect{ + ApiKey: "sk_7bac5df901aa8933a238fcfec363f4a0", + }, }, nil } diff --git a/internal/entitys/product.go b/internal/entitys/product.go new file mode 100644 index 0000000..c5b3247 --- /dev/null +++ b/internal/entitys/product.go @@ -0,0 +1,16 @@ +package entitys + +type ProductInfo struct { + Name string `json:"name" validate:"required" zh:"产品名称"` + Industry string `json:"industry" validate:"required" zh:"所属行业"` + Type string `json:"type" validate:"required" zh:"产品类型"` + ProductOrService string `json:"product_or_service" validate:"required" zh:"主营业务"` + Advantages string `json:"advantages" zh:"核心优势"` + Story string `json:"story" zh:"发展故事"` + Problem string `json:"problem" zh:"解决痛点"` + Background string `json:"background" zh:"信任背书"` + Case string `json:"case" zh:"品牌案例"` + Other string `json:"other" zh:"其他信息"` + ServiceScope string `json:"service_scope" zh:"服务范围"` + TargetAudience string `json:"target_audience" zh:"目标客户群体"` +} diff --git a/internal/entitys/request.go b/internal/entitys/request.go index 376cb91..fd38f87 100644 --- a/internal/entitys/request.go +++ b/internal/entitys/request.go @@ -194,4 +194,11 @@ type ( Plat []string `json:"plat" validate:"required" zh:"平台"` PublishTime string `json:"publish_time" validate:"required" zh:"发布时间"` } + + ProductCollectRequest struct { + AccessToken string `json:"access_token" validate:"required" zh:"access_token"` + Keywords []string `json:"keywords" validate:"required" zh:"关键词"` + Platform []int64 `json:"platform" validate:"required" zh:"平台"` + Question string `json:"question" validate:"required" zh:"问题"` + } ) diff --git a/internal/manager/publish_manager.go b/internal/manager/publish_manager.go index 500e59f..780009d 100644 --- a/internal/manager/publish_manager.go +++ b/internal/manager/publish_manager.go @@ -370,9 +370,8 @@ func (pm *PublishManager) processTask(ctx context.Context, publishData *entitys. taskLogger.Printf("任务结束 | RequestID: %s | 结果: %v", publishData.RequestID, success) res := &SingleResult{Success: success, Message: message, RequestId: publishData.RequestID} - pm.uploadToOss(ctx, logPath, fmt.Sprintf("%slog/%d_%s.log", pm.Conf.Oss.FilePath, publishData.ID, publishData.RequestID)) - url, err := pm.uploadToOss(ctx, logPath, fmt.Sprintf("%s%d_%s.log", pm.Conf.Oss.FilePath, publishData.ID, publishData.RequestID)) + url, err := pm.uploadToOss(ctx, logPath, fmt.Sprintf("%slog/%d_%s.log", pm.Conf.Oss.FilePath, publishData.ID, publishData.RequestID)) if err != nil { taskLogger.Printf("日志上传失败") } diff --git a/internal/publisher/interface.go b/internal/publisher/interface.go index cfb3932..7458a5a 100644 --- a/internal/publisher/interface.go +++ b/internal/publisher/interface.go @@ -33,7 +33,7 @@ var PublisherMap = map[string]*PublisherValue{ ContentFormat: "text", ImgNeed: 3, Type: 1, - WordContainImg: false, + WordContainImg: true, }, "bjh": { Name: "百家号", diff --git a/internal/publisher/xhs.go b/internal/publisher/xhs.go index 9240674..a944d00 100644 --- a/internal/publisher/xhs.go +++ b/internal/publisher/xhs.go @@ -407,13 +407,13 @@ func (p *XiaohongshuPublisher) clickPublish() error { func (p *XiaohongshuPublisher) waitForPublishResult() (bool, string) { p.LogInfo("等待发布结果...") // 检查URL是否包含success - info, err := p.Page.Info() - if err == nil && strings.Contains(info.URL, "success") { - p.LogInfo(fmt.Sprintf("发布成功,URL包含success: %s", info.URL)) - return true, "发布成功" - } for attempt := 0; attempt < 30; attempt++ { + info, err := p.Page.Info() + if err == nil && strings.Contains(info.URL, "success") { + p.LogInfo(fmt.Sprintf("发布成功,URL包含success: %s", info.URL)) + return true, "发布成功" + } // 检查是否出现失败提示 exist, toastDiv, err := p.Page.Has(".creator-publish-toast") if err == nil && exist { diff --git a/internal/publisher/zh.go b/internal/publisher/zh.go index 6dbe57b..0399dca 100644 --- a/internal/publisher/zh.go +++ b/internal/publisher/zh.go @@ -320,29 +320,29 @@ func (p *ZhihuPublisher) clickPublish() error { } p.LogInfo("已点击发布按钮") - p.Sleep(3) + return nil } func (p *ZhihuPublisher) waitForPublishResult(timeout int) (bool, string) { p.LogInfo("等待发布结果...") - + p.SleepMs(500) startTime := time.Now() for time.Since(startTime) < time.Duration(timeout)*time.Second { currentURL := p.GetCurrentURL() - - if !strings.Contains(currentURL, "/edit") { - p.LogInfo(fmt.Sprintf("发布成功!URL: %s", currentURL)) - return true, "发布成功" - } // 检查失败弹窗 - exist, failedDiv, _ := p.Page.HasX(".Notification-textSection") + exist, failedDiv, _ := p.Page.Has(".Notification-textSection") if exist { failedReason, _ := failedDiv.Text() p.LogInfo(fmt.Sprintf("发布失败: %s", failedReason)) return false, failedReason } + if !strings.Contains(currentURL, "/edit") { + p.LogInfo(fmt.Sprintf("发布成功!URL: %s", currentURL)) + return true, "发布成功" + } + p.SleepMs(1000) } diff --git a/internal/server/router/app.go b/internal/server/router/app.go index 763232c..ea53612 100644 --- a/internal/server/router/app.go +++ b/internal/server/router/app.go @@ -47,12 +47,14 @@ func (m *AppModule) Register(router fiber.Router) { router.Post("/publish_off", vali(m.publishService.PublishOff, &entitys.PublishOffRequest{})) router.Post("/publish_status", vali(m.publishService.PublishStatus, &entitys.PublishStatusRequest{})) router.Post("/publish_execute_retry", vali(m.publishService.PublishExecuteRetry, &entitys.PublishExecuteRetryRequest{})) + router.Post("/publish_execute_retry/local", vali(m.publishService.PublishExecuteRetryLocal, &entitys.PublishExecuteRetryRequest{})) router.Post("/get_publish_list", vali(m.publishService.GetPublishList, &entitys.GetPublishListRequest{})) router.Post("/login_platform", vali(m.loginService.LoginPlatform, &entitys.LoginPlatformRequest{})) router.Post("/logout_platform", vali(m.loginService.LogoutPlatform, &entitys.LogoutPlatformRequest{})) router.Get("/logs/:publish_id/:request_id", m.loginService.Log) router.Post("/product/add", vali(m.productService.Add, &entitys.CreateProductRequest{})) + router.Post("/product/info/create/docx", m.productService.CreateProductInfoByDocx) router.Post("/product/list", vali(m.productService.List, &entitys.ProductListRequest{})) router.Post("/product/detail", vali(m.productService.Detail, &entitys.ProductDetailRequest{})) router.Post("/product/update", vali(m.productService.Update, &entitys.ProductUpdateRequest{})) diff --git a/internal/service/app.go b/internal/service/app.go index ba9fb70..dc1ad4d 100644 --- a/internal/service/app.go +++ b/internal/service/app.go @@ -57,10 +57,12 @@ func (a *AppService) LoginApp(c *fiber.Ctx, req *entitys.LoginAppRequest) error And(builder.Eq{"status": 1}) tokenInfo := &model.Token{} err := a.tokenImpl.GetOneBySearchStruct(c.UserContext(), &cond, tokenInfo) - if err != nil || tokenInfo == nil { + if err != nil { return errcode.Forbidden("密钥无效") } - + if tokenInfo.ID == 0 { + return errcode.NotFound("密钥无效") + } accessToken := pkg.GenerateUUID() err = a.tokenImpl.UpdateByKey(c.UserContext(), a.tokenImpl.PrimaryKey(), tokenInfo.ID, &model.Token{ AccessToken: accessToken, diff --git a/internal/service/product.go b/internal/service/product.go index 8f9e4a4..c4c4a66 100644 --- a/internal/service/product.go +++ b/internal/service/product.go @@ -1,6 +1,7 @@ package service import ( + "geo/internal/ai_tool" "geo/internal/biz" "geo/internal/config" "geo/internal/data/impl" @@ -12,6 +13,8 @@ import ( "github.com/go-viper/mapstructure/v2" "github.com/gofiber/fiber/v2" "io" + "os" + "path/filepath" "xorm.io/builder" ) @@ -20,14 +23,16 @@ type ProductService struct { productImpl *impl.ProductImpl authBiz *biz.AuthBiz productBiz *biz.ProductBiz + aiBiz *biz.AiBiz } -func NewProductService(cfg *config.Config, ProductImpl *impl.ProductImpl, authBiz *biz.AuthBiz, productBiz *biz.ProductBiz) *ProductService { +func NewProductService(cfg *config.Config, ProductImpl *impl.ProductImpl, authBiz *biz.AuthBiz, productBiz *biz.ProductBiz, aiBiz *biz.AiBiz) *ProductService { return &ProductService{ cfg: cfg, productImpl: ProductImpl, authBiz: authBiz, productBiz: productBiz, + aiBiz: aiBiz, } } @@ -152,3 +157,77 @@ func (p *ProductService) ImgUpload(c *fiber.Ctx) error { } return pkg.HandleResponse(c, fiber.Map{"url": url}) } + +func (p *ProductService) CreateProductInfoByDocx(c *fiber.Ctx) error { + access := c.FormValue("access_token", "") + if access == "" { + return errcode.ParamErr("access_token未找到") + } + userIndex := c.FormValue("user_index", "") + if userIndex == "" { + return errcode.ParamErr("user_index未找到") + } + // 验证token + _, err := p.authBiz.ValidateAccessToken(c.UserContext(), access) + if err != nil { + return err + } + + fileHeader, err := c.FormFile("file") + if err != nil { + return errcode.ParamErr("未找到上传文件") + } + + file, err := fileHeader.Open() + if err != nil { + return errcode.ParamErrf("无法打开文件:%s", err.Error()) + } + defer file.Close() + + // 创建临时文件 + tempFile, err := os.CreateTemp("", "upload_*.docx") + if err != nil { + return errcode.ParamErrf("创建临时文件失败:%s", err.Error()) + } + defer tempFile.Close() + + // 获取临时文件的绝对路径 + absPath, err := filepath.Abs(tempFile.Name()) + if err != nil { + return errcode.ParamErrf("获取文件绝对路径失败:%s", err.Error()) + } + + // 将上传的文件内容复制到临时文件 + _, err = io.Copy(tempFile, file) + if err != nil { + return errcode.ParamErrf("保存文件失败:%s", err.Error()) + } + + // 确保文件内容已写入磁盘 + err = tempFile.Sync() + if err != nil { + return errcode.ParamErrf("同步文件失败:%s", err.Error()) + } + markContent, err := pkg.ExtractWordContent(absPath, "markdown") + if err != nil { + return err + } + + var productInfo entitys.ProductInfo + mes := p.aiBiz.CreateProjectInfoPrompt(c.UserContext(), markContent) + err = ai_tool.NewHsyq().RequestHsyqBotToJson(c.UserContext(), p.cfg.Hsyq.ApiKey, p.cfg.AiBot.ProductInfo, mes, &productInfo) + if err != nil { + return err + } + + return pkg.HandleResponse(c, productInfo) +} + +func (p *ProductService) Collect(c *fiber.Ctx, req *entitys.ProductCollectRequest) error { + _, err := p.authBiz.ValidateAccessToken(c.UserContext(), req.AccessToken) + if err != nil { + return err + } + + return pkg.HandleResponse(c, productInfo) +} diff --git a/internal/service/product_source.go b/internal/service/product_source.go index 2d377c9..de68bd3 100644 --- a/internal/service/product_source.go +++ b/internal/service/product_source.go @@ -76,7 +76,7 @@ func (p *ProductSourceService) Create(c *fiber.Ctx, req *entitys.ProductSourceCr BrandInfo: &brandInfo, } mes := p.aiBiz.CreateArticlePrompt(c.UserContext(), BotChatMes) - content, err := ai_tool.NewHsyq().RequestHsyqBot(c.UserContext(), p.cfg.Hsyq.ApiKey, "bot-20260413000114-8bw62", mes) + content, err := ai_tool.NewHsyq().RequestHsyqBot(c.UserContext(), p.cfg.Hsyq.ApiKey, p.cfg.AiBot.Article, mes) if err != nil { return err } diff --git a/internal/service/publish.go b/internal/service/publish.go index d371177..aa8b4fa 100644 --- a/internal/service/publish.go +++ b/internal/service/publish.go @@ -25,6 +25,7 @@ type PublishService struct { publishBiz *biz.PublishBiz db *utils.Db authBiz *biz.AuthBiz + aiBiz *biz.AiBiz } func NewPublishService( @@ -32,12 +33,14 @@ func NewPublishService( publishBiz *biz.PublishBiz, authBiz *biz.AuthBiz, db *utils.Db, + aiBiz *biz.AiBiz, ) *PublishService { return &PublishService{ cfg: cfg, publishBiz: publishBiz, db: db, authBiz: authBiz, + aiBiz: aiBiz, } } @@ -150,6 +153,17 @@ func (s *PublishService) PublishStatus(c *fiber.Ctx, req *entitys.PublishStatusR //} func (s *PublishService) PublishExecuteRetry(c *fiber.Ctx, req *entitys.PublishExecuteRetryRequest) error { + _, err := s.authBiz.ValidateAccessToken(c.UserContext(), req.AccessToken) + if err != nil { + return err + } + + err = s.publishBiz.UpdatePublishStatus(c.UserContext(), req.RequestID, 1, "") + + return err +} + +func (s *PublishService) PublishExecuteRetryLocal(c *fiber.Ctx, req *entitys.PublishExecuteRetryRequest) error { tokenInfo, err := s.authBiz.ValidateAccessToken(c.UserContext(), req.AccessToken) if err != nil { return err diff --git a/pkg/Multi.go b/pkg/Multi.go new file mode 100644 index 0000000..5178466 --- /dev/null +++ b/pkg/Multi.go @@ -0,0 +1,172 @@ +package pkg + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "time" +) + +// PostMultipart 发送 multipart/form-data 请求 +// url: 请求地址 +// data: map[string]interface{} 类型的数据,支持以下类型: +// - 普通字段:string, int, float64, bool 等 +// - 文件字段:*os.File, []byte, 或实现了 io.Reader 接口的类型,需配合文件名使用 +// - 文件路径:使用 "file_path" 字段标记,如:map[string]interface{}{"field_name": map[string]interface{}{"file_path": "/path/to/file"}} +// - 简化文件:map[string]interface{}{"field_name": map[string]interface{}{"file": fileObj, "filename": "custom.png"}} +// +// result: 响应 JSON 将解析到此指针对象 +func PostMultipart(url string, data map[string]interface{}, result interface{}) error { + // 创建缓冲区和 multipart writer + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + // 遍历所有字段 + for fieldName, fieldValue := range data { + switch v := fieldValue.(type) { + case string: + // 普通字符串字段 + if err := writer.WriteField(fieldName, v); err != nil { + writer.Close() + return fmt.Errorf("写入字段 %s 失败: %w", fieldName, err) + } + + case int, int32, int64, float32, float64, bool: + // 基本类型转字符串 + if err := writer.WriteField(fieldName, fmt.Sprintf("%v", v)); err != nil { + writer.Close() + return fmt.Errorf("写入字段 %s 失败: %w", fieldName, err) + } + + case *os.File: + // *os.File 类型,自动获取文件名 + if err := addFilePart(writer, fieldName, v.Name(), v); err != nil { + writer.Close() + return err + } + + case []byte: + // 字节数组,需要提供文件名 + return fmt.Errorf("字段 %s 为 []byte 类型,请使用 map[string]interface{} 格式提供文件名: {\"data\": 字节内容, \"filename\": \"文件名\"}", fieldName) + + case map[string]interface{}: + // 复杂文件描述 + if err := handleFileMap(writer, fieldName, v); err != nil { + writer.Close() + return err + } + + default: + // 其他类型转为字符串 + if err := writer.WriteField(fieldName, fmt.Sprintf("%v", v)); err != nil { + writer.Close() + return fmt.Errorf("写入字段 %s 失败: %w", fieldName, err) + } + } + } + + // 关闭 writer 以写入结束边界 + if err := writer.Close(); err != nil { + return fmt.Errorf("关闭 multipart writer 失败: %w", err) + } + + // 创建请求 + req, err := http.NewRequest("POST", url, body) + if err != nil { + return fmt.Errorf("创建请求失败: %w", err) + } + + // 设置 Content-Type,包含边界分隔符 + req.Header.Set("Content-Type", writer.FormDataContentType()) + + // 发送请求 + client := &http.Client{ + Timeout: 30 * time.Second, + } + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("请求失败: %w", err) + } + defer resp.Body.Close() + + // 读取响应体 + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("读取响应失败: %w", err) + } + + // 检查 HTTP 状态码 + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("HTTP 错误: %d, 响应: %s", resp.StatusCode, string(respBody)) + } + + // 解析 JSON 到 result 指针 + if err := json.Unmarshal(respBody, result); err != nil { + return fmt.Errorf("JSON 解析失败: %w, 原始响应: %s", err, string(respBody)) + } + + return nil +} + +// addFilePart 添加文件部分 +func addFilePart(writer *multipart.Writer, fieldName, filename string, reader io.Reader) error { + part, err := writer.CreateFormFile(fieldName, filename) + if err != nil { + return fmt.Errorf("创建文件字段 %s 失败: %w", fieldName, err) + } + + if _, err := io.Copy(part, reader); err != nil { + return fmt.Errorf("复制文件 %s 内容失败: %w", fieldName, err) + } + return nil +} + +// handleFileMap 处理文件映射 +func handleFileMap(writer *multipart.Writer, fieldName string, fileMap map[string]interface{}) error { + // 支持两种格式: + // 1. {"file_path": "/path/to/file"} + // 2. {"file": io.Reader, "filename": "custom_name.ext"} + + // 格式1: 文件路径 + if filePath, ok := fileMap["file_path"].(string); ok { + file, err := os.Open(filePath) + if err != nil { + return fmt.Errorf("打开文件 %s 失败: %w", filePath, err) + } + defer file.Close() + return addFilePart(writer, fieldName, filepath.Base(filePath), file) + } + + // 格式2: 直接提供文件和文件名 + if fileReader, ok := fileMap["file"]; ok { + // 获取文件名 + filename := "file" + if fn, ok := fileMap["filename"].(string); ok && fn != "" { + filename = fn + } + + var reader io.Reader + switch r := fileReader.(type) { + case io.Reader: + reader = r + case []byte: + reader = bytes.NewReader(r) + case *os.File: + reader = r + if filename == "file" { + filename = filepath.Base(r.Name()) + } + default: + return fmt.Errorf("文件字段 %s 不支持的类型: %T", fieldName, fileReader) + } + + return addFilePart(writer, fieldName, filename, reader) + } + + return fmt.Errorf("文件字段 %s 格式错误,需要 file_path 或 file+filename", fieldName) +}