diff --git a/cmd/server/main.go b/cmd/server/main.go index d765735..baca30d 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -13,6 +13,7 @@ func main() { configPath := flag.String("config", "./config/config_test.yaml", "Path to configuration file") onBot := flag.String("bot", "", "bot start") flag.Parse() + ctx := context.Background() bc, err := config.LoadConfig(*configPath) if err != nil { log.Fatalf("加载配置失败: %v", err) @@ -25,7 +26,9 @@ func main() { defer func() { cleanup() }() - app.DingBotServer.Run(context.Background(), *onBot) - + //钉钉机器人 + app.DingBotServer.Run(ctx, *onBot) + //定时任务 + app.Cron.Run(ctx) log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port))) } diff --git a/go.mod b/go.mod index 694ee17..fa8498d 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/faabiosr/cachego v0.26.0 github.com/fastwego/dingding v1.0.0-beta.4 github.com/gabriel-vasile/mimetype v1.4.11 - github.com/go-kratos/kratos/v2 v2.9.1 + github.com/go-kratos/kratos/v2 v2.9.2 github.com/go-playground/locales v0.14.1 github.com/go-playground/universal-translator v0.18.1 github.com/go-playground/validator/v10 v10.20.0 @@ -39,6 +39,7 @@ require ( ) require ( + dario.cat/mergo v1.0.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect github.com/alibabacloud-go/alibabacloud-gateway-spi v0.0.5 // indirect github.com/alibabacloud-go/debug v1.0.1 // indirect @@ -90,6 +91,7 @@ require ( github.com/richardlehane/mscfb v1.0.4 // indirect github.com/richardlehane/msoleps v1.0.4 // indirect github.com/rivo/uniseg v0.2.0 // indirect + github.com/robfig/cron/v3 v3.0.1 // indirect github.com/sagikazarmark/locafero v0.3.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect diff --git a/go.sum b/go.sum index b753e6a..b0d2c51 100644 --- a/go.sum +++ b/go.sum @@ -35,6 +35,8 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo= +dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= +dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= @@ -193,6 +195,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-kratos/kratos/v2 v2.9.1 h1:EGif6/S/aK/RCR5clIbyhioTNyoSrii3FC118jG40Z0= github.com/go-kratos/kratos/v2 v2.9.1/go.mod h1:a1MQLjMhIh7R0kcJS9SzJYR43BRI7EPzzN0J1Ksu2bA= +github.com/go-kratos/kratos/v2 v2.9.2 h1:px8GJQBeLpquDKQWQ9zohEWiLA8n4D/pv7aH3asvUvo= +github.com/go-kratos/kratos/v2 v2.9.2/go.mod h1:Jc7jaeYd4RAPjetun2C+oFAOO7HNMHTT/Z4LxpuEDJM= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -375,6 +379,8 @@ github.com/richardlehane/msoleps v1.0.4 h1:WuESlvhX3gH2IHcd8UqyCuFY5yiq/GR/yqaSM github.com/richardlehane/msoleps v1.0.4/go.mod h1:BWev5JBpU9Ko2WAgmZEuiz4/u3ZYTKbjLycmwiWUfWg= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= diff --git a/internal/biz/ding_talk_bot.go b/internal/biz/ding_talk_bot.go index b8976df..0cc4090 100644 --- a/internal/biz/ding_talk_bot.go +++ b/internal/biz/ding_talk_bot.go @@ -470,6 +470,24 @@ func (d *DingTalkBotBiz) HandleStreamRes(ctx context.Context, data *chatbot.BotC SenderStaffId: data.SenderStaffId, Title: data.Text.Content, }) + return +} + +func (d *DingTalkBotBiz) GetReportLists(ctx context.Context) (contentChan chan string, err error) { + contentChan = make(chan string, 10) + defer close(contentChan) + contentChan <- "截止今日23点利润亏损合计:127917.0866元,亏损500元以上的分销商和产品金额如下图:" + contentChan <- "![图片](https://lsxdmgoss.oss-cn-chengdu.aliyuncs.com/MarketingSaaS/image/V2/other/shanghu.png)" + + return +} + +func (d *DingTalkBotBiz) GetGroupInfo(ctx context.Context, groupId int) (group model.AiBotGroup, err error) { + + cond := builder.NewCond() + cond = cond.And(builder.Eq{"group_id": groupId}) + cond = cond.And(builder.Eq{"status": constants.Enable}) + err = d.botGroupImpl.GetOneBySearchToStrut(&cond, &group) return } diff --git a/internal/biz/handle/dingtalk/send_card.go.bak1 b/internal/biz/handle/dingtalk/send_card.go.bak1 deleted file mode 100644 index 9fb1e8d..0000000 --- a/internal/biz/handle/dingtalk/send_card.go.bak1 +++ /dev/null @@ -1,280 +0,0 @@ -package dingtalk - -import ( - "ai_scheduler/internal/data/constants" - "context" - "encoding/json" - "errors" - "fmt" - "strings" - "sync" - "time" - - openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client" - dingtalkcard_1_0 "github.com/alibabacloud-go/dingtalk/card_1_0" - dingtalkim_1_0 "github.com/alibabacloud-go/dingtalk/im_1_0" - util "github.com/alibabacloud-go/tea-utils/v2/service" - "github.com/alibabacloud-go/tea/tea" - "github.com/gofiber/fiber/v2/log" - "github.com/google/uuid" -) - -const DefaultInterval = 100 * time.Millisecond -const HeardBeatX = 50 - -type SendCardClient struct { - Auth *Auth - CardClient *sync.Map - mu sync.RWMutex // 保护 CardClient 的并发访问 - logger log.AllLogger // 日志记录 - botOption *Bot -} - -func NewSendCardClient(auth *Auth, logger log.AllLogger) *SendCardClient { - return &SendCardClient{ - Auth: auth, - CardClient: &sync.Map{}, - logger: logger, - botOption: &Bot{}, - } -} - -// initClient 初始化或复用 DingTalk 客户端 -func (s *SendCardClient) initClient(robotCode string) (*dingtalkcard_1_0.Client, error) { - if client, ok := s.CardClient.Load(robotCode); ok { - return client.(*dingtalkcard_1_0.Client), nil - } - s.botOption.BotCode = robotCode - config := &openapi.Config{ - Protocol: tea.String("https"), - RegionId: tea.String("central"), - } - client, err := dingtalkcard_1_0.NewClient(config) - if err != nil { - s.logger.Error("failed to init DingTalk client") - return nil, fmt.Errorf("init client failed: %w", err) - } - - s.CardClient.Store(robotCode, client) - return client, nil -} - -func (s *SendCardClient) NewCard(ctx context.Context, cardSend *CardSend) error { - // 参数校验 - if (len(cardSend.ContentSlice) == 0 || cardSend.ContentSlice == nil) && cardSend.ContentChannel == nil { - return errors.New("卡片内容不能为空") - } - if cardSend.UpdateInterval == 0 { - cardSend.UpdateInterval = DefaultInterval // 默认更新间隔 - } - if cardSend.Title == "" { - cardSend.Title = "钉钉卡片" - } - //替换标题 - cardSend.Template = constants.CardTemp(strings.Replace(string(cardSend.Template), "${title}", cardSend.Title, 1)) - // 初始化客户端 - client, err := s.initClient(cardSend.RobotCode) - if err != nil { - return fmt.Errorf("初始化client失败: %w", err) - } - - // 生成卡片实例ID - cardInstanceId, err := uuid.NewUUID() - if err != nil { - return fmt.Errorf("创建uuid失败: %w", err) - } - - // 构建初始请求 - request, err := s.buildBaseRequest(cardSend, cardInstanceId.String()) - if err != nil { - return fmt.Errorf("请求失败: %w", err) - } - - // 发送初始卡片 - if _, err := s.SendInteractiveCard(ctx, request, cardSend.RobotCode, client); err != nil { - return fmt.Errorf("发送初始卡片失败: %w", err) - } - - // 处理切片内容(同步) - if len(cardSend.ContentSlice) > 0 { - if err := s.processContentSlice(ctx, cardSend, cardInstanceId.String(), client); err != nil { - return fmt.Errorf("内容同步失败: %w", err) - } - } - - // 处理通道内容(异步) - if cardSend.ContentChannel != nil { - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - s.processContentChannel(ctx, cardSend, cardInstanceId.String(), client) - }() - wg.Wait() - } - - return nil -} - -// buildBaseRequest 构建基础请求 -func (s *SendCardClient) buildBaseRequest(cardSend *CardSend, cardInstanceId string) (*dingtalkcard_1_0.StreamingUpdateRequest, error) { - cardData := fmt.Sprintf(string(cardSend.Template), "") // 初始空内容 - request := &dingtalkcard_1_0.StreamingUpdateRequest{ - OutTrackId: tea.String("your-out-track-id"), - Guid: tea.String("0F714542-0AFC-2B0E-CF14-E2D39F5BFFE8"), - Key: tea.String("your-ai-param"), - Content: tea.String("test"), - IsFull: tea.Bool(false), - IsFinalize: tea.Bool(false), - IsError: tea.Bool(false), - } - - switch cardSend.ConversationType { - case constants.ConversationTypeGroup: - request.SetOpenConversationId(cardSend.ConversationId) - case constants.ConversationTypeSingle: - receiver, err := json.Marshal(map[string]string{"userId": cardSend.SenderStaffId}) - if err != nil { - return nil, fmt.Errorf("数据整理失败: %w", err) - } - request.SetSingleChatReceiver(string(receiver)) - default: - return nil, errors.New("未知的聊天场景") - } - - return request, nil -} - -// processContentChannel 处理通道内容(异步更新) -func (s *SendCardClient) processContentChannel(ctx context.Context, cardSend *CardSend, cardInstanceId string, client *dingtalkim_1_0.Client) { - defer func() { - if r := recover(); r != nil { - s.logger.Error("panic in processContentChannel") - } - }() - - ticker := time.NewTicker(cardSend.UpdateInterval) - defer ticker.Stop() - heartbeatTicker := time.NewTicker(time.Duration(HeardBeatX) * DefaultInterval) - defer heartbeatTicker.Stop() - - var ( - contentBuilder strings.Builder - lastUpdate time.Time - ) - for { - - select { - case content, ok := <-cardSend.ContentChannel: - if !ok { - // 通道关闭,发送最终内容 - if contentBuilder.Len() > 0 { - if err := s.updateCardContent(ctx, cardSend, cardInstanceId, contentBuilder.String(), client); err != nil { - s.logger.Errorf("更新卡片失败1:%s", err.Error()) - } - } - return - } - contentBuilder.WriteString(content) - if contentBuilder.Len() > 0 { - if err := s.updateCardContent(ctx, cardSend, cardInstanceId, contentBuilder.String(), client); err != nil { - s.logger.Errorf("更新卡片失败2:%s", err.Error()) - } - } - lastUpdate = time.Now() - - case <-heartbeatTicker.C: - if time.Now().Unix()-lastUpdate.Unix() >= HeardBeatX { - return - } - - case <-ctx.Done(): - s.logger.Info("context canceled, stop channel processing") - return - } - } - -} - -// processContentSlice 处理切片内容(同步更新) -func (s *SendCardClient) processContentSlice(ctx context.Context, cardSend *CardSend, cardInstanceId string, client *dingtalkim_1_0.Client) error { - var contentBuilder strings.Builder - for _, content := range cardSend.ContentSlice { - contentBuilder.WriteString(content) - err := s.updateCardRequest(ctx, &UpdateCardRequest{ - Template: string(cardSend.Template), - Content: contentBuilder.String(), - Client: client, - RobotCode: cardSend.RobotCode, - CardInstanceId: cardInstanceId, - }) - if err != nil { - return fmt.Errorf("更新卡片失败: %w", err) - } - time.Sleep(cardSend.UpdateInterval) // 控制更新频率 - } - return nil -} - -// updateCardContent 封装卡片更新逻辑 -func (s *SendCardClient) updateCardContent(ctx context.Context, cardSend *CardSend, cardInstanceId, content string, client *dingtalkim_1_0.Client) error { - err := s.updateCardRequest(ctx, &UpdateCardRequest{ - Template: string(cardSend.Template), - Content: content, - Client: client, - RobotCode: cardSend.RobotCode, - CardInstanceId: cardInstanceId, - }) - - return err -} - -func (s *SendCardClient) updateCardRequest(ctx context.Context, updateCardRequest *UpdateCardRequest) error { - - updateRequest := &dingtalkim_1_0.UpdateRobotInteractiveCardRequest{ - CardBizId: tea.String(updateCardRequest.CardInstanceId), - CardData: tea.String(fmt.Sprintf(updateCardRequest.Template, updateCardRequest.Content)), - } - _, err := s.UpdateInteractiveCard(ctx, updateRequest, updateCardRequest.RobotCode, updateCardRequest.Client) - return err -} - -// UpdateInteractiveCard 更新交互卡片(封装错误处理) -func (s *SendCardClient) UpdateInteractiveCard(ctx context.Context, request *dingtalkim_1_0.UpdateRobotInteractiveCardRequest, robotCode string, client *dingtalkim_1_0.Client) (*dingtalkim_1_0.UpdateRobotInteractiveCardResponse, error) { - authInfo, err := s.Auth.GetTokenFromBotOption(ctx, WithBot(s.botOption)) - if err != nil { - return nil, fmt.Errorf("get token failed: %w", err) - } - - headers := &dingtalkim_1_0.UpdateRobotInteractiveCardHeaders{ - XAcsDingtalkAccessToken: tea.String(authInfo.AccessToken), - } - - response, err := client.UpdateRobotInteractiveCardWithOptions(request, headers, &util.RuntimeOptions{}) - if err != nil { - return nil, fmt.Errorf("API call failed: %w,request:%v", err, request.String()) - } - return response, nil -} - -// SendInteractiveCard 发送交互卡片(封装错误处理) -func (s *SendCardClient) SendInteractiveCard(ctx context.Context, request *dingtalkim_1_0.SendRobotInteractiveCardRequest, robotCode string, client *dingtalkim_1_0.Client) (res *dingtalkim_1_0.SendRobotInteractiveCardResponse, err error) { - err = s.Auth.GetBotConfigFromModel(s.botOption) - if err != nil { - return nil, fmt.Errorf("初始化bot失败: %w", err) - } - authInfo, err := s.Auth.GetTokenFromBotOption(ctx, WithBot(s.botOption)) - if err != nil { - return nil, fmt.Errorf("get token failed: %w", err) - } - - headers := &dingtalkim_1_0.SendRobotInteractiveCardHeaders{ - XAcsDingtalkAccessToken: tea.String(authInfo.AccessToken), - } - - response, err := client.SendRobotInteractiveCardWithOptions(request, headers, &util.RuntimeOptions{}) - if err != nil { - return nil, fmt.Errorf("API call failed: %w", err) - } - return response, nil -} diff --git a/internal/config/config.go b/internal/config/config.go index 441d09d..35f2115 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,10 +1,10 @@ package config import ( + "ai_scheduler/pkg" "fmt" - "time" - "github.com/spf13/viper" + "time" ) // Config 应用配置 @@ -222,10 +222,32 @@ func LoadConfig(configPath string) (*Config, error) { } // 解析配置 - var config Config - if err := viper.Unmarshal(&config); err != nil { + var bc Config + if err := viper.Unmarshal(&bc); err != nil { return nil, fmt.Errorf("failed to unmarshal config: %w", err) } - return &config, nil + return &bc, nil +} + +func LoadConfigWithTest() (*Config, error) { + var bc Config + modularDir, err := pkg.GetModuleDir() + if err != nil { + return nil, err + } + viper.SetConfigFile(modularDir + "/config/config_test.yaml") + viper.SetConfigType("yaml") + // 读取配置文件 + if err := viper.ReadInConfig(); err != nil { + return nil, fmt.Errorf("failed to read config file: %w", err) + } + // 解析配置 + + if err := viper.Unmarshal(&bc); err != nil { + return nil, fmt.Errorf("failed to unmarshal config: %w", err) + } + + return &bc, nil + } diff --git a/internal/pkg/channel_pool.go b/internal/pkg/channel_pool.go deleted file mode 100644 index eda85fa..0000000 --- a/internal/pkg/channel_pool.go +++ /dev/null @@ -1,75 +0,0 @@ -package pkg - -import ( - "ai_scheduler/internal/config" - "ai_scheduler/internal/entitys" - "sync" -) - -type SafeChannelPool struct { - pool chan chan entitys.ResponseData // 存储空闲 channel 的队列 - bufSize int // channel 缓冲大小 - mu sync.Mutex - closed bool -} - -func NewSafeChannelPool(c *config.Config) (*SafeChannelPool, func()) { - pool := &SafeChannelPool{ - pool: make(chan chan entitys.ResponseData, c.Sys.ChannelPoolLen), - bufSize: c.Sys.ChannelPoolSize, - } - - cleanup := pool.Close - return pool, cleanup -} - -// 从池中获取 channel(若无空闲则创建新 channel) -func (p *SafeChannelPool) Get() chan entitys.ResponseData { - p.mu.Lock() - defer p.mu.Unlock() - - if p.closed { - return make(chan entitys.ResponseData, p.bufSize) - } - - select { - case ch := <-p.pool: // 从池中取 - return ch - default: // 池为空,创建新 channel - return make(chan entitys.ResponseData, p.bufSize) - } -} - -// 将 channel 放回池中(必须确保 channel 已清空!) -func (p *SafeChannelPool) Put(ch chan entitys.ResponseData) { - p.mu.Lock() - defer p.mu.Unlock() - - if p.closed { - return - } - - // 清空 channel(防止复用时读取旧数据) - go func() { - for range ch { - // 丢弃所有数据(或根据业务需求处理) - } - }() - - select { - case p.pool <- ch: // 尝试放回池中 - default: // 池已满,直接关闭 channel(避免泄漏) - close(ch) - } - return -} - -// 关闭池(释放所有资源) -func (p *SafeChannelPool) Close() { - p.mu.Lock() - defer p.mu.Unlock() - - p.closed = true - close(p.pool) // 关闭池队列 - // 需额外逻辑关闭所有内部 channel(此处简化) -} diff --git a/internal/pkg/func.go b/internal/pkg/func.go index d07b202..4d9232b 100644 --- a/internal/pkg/func.go +++ b/internal/pkg/func.go @@ -6,8 +6,6 @@ import ( "errors" "fmt" "net/url" - "os" - "path/filepath" "reflect" "strconv" "strings" @@ -197,64 +195,6 @@ func StructToMapUsingJsoniter(obj interface{}) (map[string]string, error) { return result, nil } -func GetModuleDir() (string, error) { - dir, err := os.Getwd() - if err != nil { - return "", err - } - - for { - modPath := filepath.Join(dir, "go.mod") - if _, err := os.Stat(modPath); err == nil { - return dir, nil // 找到 go.mod - } - - // 向上查找父目录 - parent := filepath.Dir(dir) - if parent == dir { - break // 到达根目录,未找到 - } - dir = parent - } - - return "", fmt.Errorf("go.mod not found in current directory or parents") -} - -// GetCacheDir 用于获取缓存目录路径 -// 如果缓存目录不存在,则会自动创建 -// 返回值: -// - string: 缓存目录的路径 -// - error: 如果获取模块目录失败或创建缓存目录失败,则返回错误信息 -func GetCacheDir() (string, error) { - // 获取模块目录 - modDir, err := GetModuleDir() - if err != nil { - return "", err - } - // 拼接缓存目录路径 - path := fmt.Sprintf("%s/cache", modDir) - // 创建目录(包括所有必要的父目录),权限设置为0755 - err = os.MkdirAll(path, 0755) - if err != nil { - return "", fmt.Errorf("创建目录失败: %w", err) - } - // 返回成功创建的缓存目录路径 - return path, nil -} - -func GetTmplDir() (string, error) { - modDir, err := GetModuleDir() - if err != nil { - return "", err - } - path := fmt.Sprintf("%s/tmpl", modDir) - err = os.MkdirAll(path, 0755) - if err != nil { - return "", fmt.Errorf("创建目录失败: %w", err) - } - return path, nil -} - // 通用结构体转 Query 参数 func StructToQuery(obj interface{}) (url.Values, error) { values := url.Values{} diff --git a/internal/pkg/provider_set.go b/internal/pkg/provider_set.go index f8fadac..603dedc 100644 --- a/internal/pkg/provider_set.go +++ b/internal/pkg/provider_set.go @@ -15,7 +15,7 @@ var ProviderSetClient = wire.NewSet( utils_langchain.NewUtilLangChain, utils_ollama.NewClient, utils_vllm.NewClient, - NewSafeChannelPool, + dingtalk.NewOldClient, dingtalk.NewContactClient, dingtalk.NewNotableClient, diff --git a/internal/server/cron.go b/internal/server/cron.go new file mode 100644 index 0000000..76c5739 --- /dev/null +++ b/internal/server/cron.go @@ -0,0 +1,93 @@ +package server + +import ( + "ai_scheduler/internal/services" + "context" + + "github.com/gofiber/fiber/v2/log" + "github.com/robfig/cron/v3" +) + +type CronServer struct { + Cron *cron.Cron + jobs []*cronJob + log log.AllLogger + cronService *services.CronService + ctx context.Context +} + +type cronJob struct { + EntryId int32 + Func func(context.Context) error + Name string + Schedule string +} + +func NewCronServer( + log log.AllLogger, + cronService *services.CronService, +) *CronServer { + return &CronServer{ + Cron: cron.New(), + log: log, + cronService: cronService, + ctx: context.Background(), + } +} + +func (c *CronServer) InitJobs(ctx context.Context) { + // 创建一个可用于所有定时任务的上下文(可以取消的上下文) + c.ctx = ctx + c.jobs = []*cronJob{ + { + Func: c.cronService.CronReportSend, + Name: "直连天下报表推送", + Schedule: "@every 60s", + }, + } +} + +func (c *CronServer) Run(ctx context.Context) { + // 先初始化任务 + if c.jobs == nil { + c.InitJobs(ctx) + } + + for i, job := range c.jobs { + // 复制变量到闭包内,避免闭包变量捕获问题 + job := job + jobID := i + 1 + _, err := c.Cron.AddFunc(job.Schedule, func() { + c.log.Infof("任务[%d]:%s开始执行", jobID, job.Name) + + defer func() { + if r := recover(); r != nil { + c.log.Errorf("任务[%d]:%s执行时发生panic: %v", jobID, job.Name, r) + } + c.log.Infof("任务[%d]:%s执行结束", jobID, job.Name) + }() + + // 为每次执行创建新的上下文 + ctx := context.Background() + err := job.Func(ctx) + if err != nil { + c.log.Errorf("任务[%d]:%s执行失败: %s", jobID, job.Name, err.Error()) + } + }) + if err != nil { + c.log.Errorf("添加任务失败:%s", err.Error()) + } + } + + // 启动cron调度器 + c.Cron.Start() + c.log.Info("Cron调度器已启动") +} + +// Stop 停止cron调度器 +func (c *CronServer) Stop() { + if c.Cron != nil { + c.Cron.Stop() + c.log.Info("Cron调度器已停止") + } +} diff --git a/internal/server/provider_set.go b/internal/server/provider_set.go index d5cef3d..08bc37b 100644 --- a/internal/server/provider_set.go +++ b/internal/server/provider_set.go @@ -9,4 +9,5 @@ var ProviderSetServer = wire.NewSet( NewHTTPServer, ProvideAllDingBotServices, NewDingTalkBotServer, + NewCronServer, ) diff --git a/internal/server/server.go b/internal/server/server.go index 02c8f84..b455eb4 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -10,18 +10,14 @@ type Servers struct { cfg *config.Config HttpServer *fiber.App DingBotServer *DingTalkBotServer + Cron *CronServer } -func NewServers(cfg *config.Config, fiber *fiber.App, DingBotServer *DingTalkBotServer) *Servers { +func NewServers(cfg *config.Config, fiber *fiber.App, DingBotServer *DingTalkBotServer, cron *CronServer) *Servers { return &Servers{ HttpServer: fiber, cfg: cfg, DingBotServer: DingBotServer, + Cron: cron, } } - -//func DingBotServerInit(clientId string, clientSecret string, cfg *config.Config, handler *do.Handle, do *do.Do) (cli *client.StreamClient) { -// cli = client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(clientId, clientSecret))) -// cli.RegisterChatBotCallbackRouter(services.NewDingBotService(cfg, handler, do).OnChatBotMessageReceived) -// return -//} diff --git a/internal/services/cron.go b/internal/services/cron.go new file mode 100644 index 0000000..5f624d1 --- /dev/null +++ b/internal/services/cron.go @@ -0,0 +1,43 @@ +package services + +import ( + "ai_scheduler/internal/biz" + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/constants" + "context" + + "gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/chatbot" +) + +type CronService struct { + config *config.Config + dingTalkBotBiz *biz.DingTalkBotBiz +} + +func NewCronService(config *config.Config, dingTalkBotBiz *biz.DingTalkBotBiz) *CronService { + return &CronService{ + config: config, + dingTalkBotBiz: dingTalkBotBiz, + } +} + +func (d *CronService) CronReportSend(ctx context.Context) error { + reportChan, err := d.dingTalkBotBiz.GetReportLists(ctx) + if err != nil { + return err + } + groupId := 23 + groupInfo, err := d.dingTalkBotBiz.GetGroupInfo(ctx, groupId) + if err != nil { + return err + } + err = d.dingTalkBotBiz.HandleStreamRes(ctx, &chatbot.BotCallbackDataModel{ + RobotCode: groupInfo.RobotCode, + ConversationType: constants.ConversationTypeGroup, + ConversationId: groupInfo.ConversationID, + Text: chatbot.BotCallbackDataTextModel{ + Content: "报表", + }, + }, reportChan) + return nil +} diff --git a/internal/services/dtalk_bot.go b/internal/services/dtalk_bot.go index b71e40b..4635b63 100644 --- a/internal/services/dtalk_bot.go +++ b/internal/services/dtalk_bot.go @@ -3,6 +3,7 @@ package services import ( "ai_scheduler/internal/biz" "ai_scheduler/internal/config" + "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" "context" "log" @@ -135,3 +136,24 @@ func (d *DingBotService) runBackgroundTasks(ctx context.Context, data *chatbot.B return nil } + +func (d *DingBotService) CronReportSend(ctx context.Context) error { + reportChan, err := d.dingTalkBotBiz.GetReportLists(ctx) + if err != nil { + return err + } + groupId := 23 + groupInfo, err := d.dingTalkBotBiz.GetGroupInfo(ctx, groupId) + if err != nil { + return err + } + err = d.dingTalkBotBiz.HandleStreamRes(ctx, &chatbot.BotCallbackDataModel{ + RobotCode: groupInfo.RobotCode, + ConversationType: constants.ConversationTypeGroup, + ConversationId: groupInfo.ConversationID, + Text: chatbot.BotCallbackDataTextModel{ + Content: "报表", + }, + }, reportChan) + return nil +} diff --git a/internal/services/dtalk_bot.go.bak b/internal/services/dtalk_bot.go.bak deleted file mode 100644 index 75c2c7f..0000000 --- a/internal/services/dtalk_bot.go.bak +++ /dev/null @@ -1,130 +0,0 @@ -package services - -import ( - "ai_scheduler/internal/biz" - "log" - "sync" - "time" - - "ai_scheduler/internal/config" - "ai_scheduler/internal/entitys" - "context" - - "gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/chatbot" -) - -type DingBotService struct { - config *config.Config - dingTalkBotBiz *biz.DingTalkBotBiz -} - -func NewDingBotService(config *config.Config, DingTalkBotBiz *biz.DingTalkBotBiz) *DingBotService { - return &DingBotService{config: config, dingTalkBotBiz: DingTalkBotBiz} -} - -func (d *DingBotService) GetServiceCfg() ([]entitys.DingTalkBot, error) { - return d.dingTalkBotBiz.GetDingTalkBotCfgList() -} - -func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) { - var ( - lastErr error - chat []string - streamWG sync.WaitGroup - resChan = make(chan string, 100) // 缓冲通道防止阻塞 - ) - - // 初始化请求 - requireData, err := d.dingTalkBotBiz.InitRequire(ctx, data) - if err != nil { - return nil, err - } - - // 创建子上下文用于控制goroutine生命周期 - subCtx, cancel := context.WithCancel(ctx) - defer cancel() - - // 启动流式处理goroutine - streamWG.Add(1) - go func() { - defer streamWG.Done() - err = d.dingTalkBotBiz.HandleStreamRes(subCtx, data, resChan) - if err != nil { - return - } - }() - - // 启动业务处理goroutine - done := make(chan error, 1) - go func() { - done <- d.dingTalkBotBiz.Do(subCtx, requireData) - }() - - // 主处理循环 - for { - select { - case <-ctx.Done(): - lastErr = ctx.Err() - goto cleanup - - case resp, ok := <-requireData.Ch: - if !ok { - goto cleanup - } - - // 处理不同类型响应 - switch resp.Type { - case entitys.ResponseLog: - // 忽略日志类型 - continue - - //case entitys.ResponseText, entitys.ResponseJson: - // chat = append(chat, resp.Content) - // if err := d.dingTalkBotBiz.ReplyText(ctx, data.SessionWebhook, resp.Content); err != nil { - // log.Printf("处理非流响应失败: %v", err) - // lastErr = err - // } - - default: - chat = append(chat, resp.Content) - select { - case resChan <- resp.Content: - case <-ctx.Done(): - lastErr = ctx.Err() - goto cleanup - } - } - } - } - -cleanup: - streamWG.Wait() - // 关闭流式通道 - close(resChan) - - // 保存历史记录 - if saveErr := d.dingTalkBotBiz.SaveHis(ctx, requireData, chat); saveErr != nil { - log.Printf("保存历史记录失败: %v", saveErr) - if lastErr == nil { - lastErr = saveErr - } - } - - // 等待业务处理完成(带超时) - select { - case err := <-done: - if err != nil { - log.Printf("业务处理失败: %v", err) - if lastErr == nil { - lastErr = err - } - } - case <-time.After(3 * time.Second): // 增加超时时间 - log.Println("警告:等待业务处理超时,可能发生goroutine泄漏") - } - - if lastErr != nil { - return nil, lastErr - } - return []byte("success"), nil -} diff --git a/internal/services/dtalk_bot_test.go b/internal/services/dtalk_bot_test.go new file mode 100644 index 0000000..2b8223f --- /dev/null +++ b/internal/services/dtalk_bot_test.go @@ -0,0 +1,105 @@ +package services + +import ( + "ai_scheduler/internal/biz" + "ai_scheduler/internal/biz/do" + dingtalk2 "ai_scheduler/internal/biz/handle/dingtalk" + "ai_scheduler/internal/biz/llm_service" + "ai_scheduler/internal/biz/tools_regis" + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/domain/component" + "ai_scheduler/internal/domain/component/callback" + "ai_scheduler/internal/domain/repo" + "ai_scheduler/internal/domain/workflow" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/dingtalk" + "ai_scheduler/internal/pkg/utils_ollama" + "ai_scheduler/internal/pkg/utils_vllm" + + "ai_scheduler/internal/tools" + "ai_scheduler/utils" + "context" + "testing" + + "github.com/gofiber/fiber/v2/log" +) + +func Test_Report(t *testing.T) { + run() + a := dingBotService.CronReportSend(context.Background()) + t.Log(a) +} + +var ( + configConfig *config.Config + err error + dingBotService *DingBotService +) + +// run 函数是程序的入口函数,负责初始化和配置各个组件 +func run() { + // 加载测试配置 + configConfig, err = config.LoadConfigWithTest() + // 初始化数据库连接 + db, _ := utils.NewGormDb(configConfig) + // 初始化各种实现层组件 + sysImpl := impl.NewSysImpl(db) + taskImpl := impl.NewTaskImpl(db) + chatHisImpl := impl.NewChatHisImpl(db) + sessionImpl := impl.NewSessionImpl(db) + botConfigImpl := impl.NewBotConfigImpl(db) + botGroupImpl := impl.NewBotGroupImpl(db) + botUserImpl := impl.NewBotUserImpl(db) + // 初始化Do业务对象 + doDo := do.NewDo(sysImpl, taskImpl, chatHisImpl, configConfig) + // 初始化Ollama客户端 + client, _, _ := utils_ollama.NewClient(configConfig) + // 初始化vLLM客户端 + utils_vllmClient, _, _ := utils_vllm.NewClient(configConfig) + // 初始化Redis数据库连接 + rdb := utils.NewRdb(configConfig) + // 初始化仓库层 + repos := repo.NewRepos(sessionImpl, rdb) + // 初始化包级别的Redis连接 + pkgRdb := pkg.NewRdb(configConfig) + + // 初始化机器人工具实现层 + botToolsImpl := impl.NewBotToolsImpl(db) + // 初始化机器人部门实现层 + botDeptImpl := impl.NewBotDeptImpl(db) + // 初始化Redis管理器 + redisManager := callback.NewRedisManager(pkgRdb) + // 初始化组件 + components := component.NewComponents(redisManager) + // 初始化工作流注册表 + registry := workflow.NewRegistry(configConfig, client, repos, components) + // 初始化钉钉旧版客户端 + oldClient := dingtalk.NewOldClient(configConfig) + // 初始化Ollama服务 + ollamaService := llm_service.NewOllamaGenerate(client, utils_vllmClient, configConfig, chatHisImpl) + // 初始化工具管理器 + manager := tools.NewManager(configConfig, client) + // 初始化钉钉联系人客户端 + contactClient, _ := dingtalk.NewContactClient(configConfig) + // 初始化钉钉记事本客户端 + notableClient, _ := dingtalk.NewNotableClient(configConfig) + // 初始化工具注册 + toolRegis := tools_regis.NewToolsRegis(botToolsImpl) + // 初始化机器人聊天历史实现层 + botChatHisImpl := impl.NewBotChatHisImpl(db) + // 初始化钉钉认证 + auth := dingtalk2.NewAuth(configConfig, rdb, botConfigImpl) + // 初始化部门服务 + dept := dingtalk2.NewDept(botDeptImpl, auth) + // 初始化用户服务 + user := dingtalk2.NewUser(botUserImpl, auth, dept) + // 初始化发送卡片客户端 + sendCardClient := dingtalk2.NewSendCardClient(auth, log.DefaultLogger()) + // 初始化处理器 + handle := do.NewHandle(ollamaService, manager, configConfig, sessionImpl, registry, oldClient, contactClient, notableClient) + // 初始化钉钉机器人业务逻辑 + dingTalkBotBiz := biz.NewDingTalkBotBiz(doDo, handle, botConfigImpl, botGroupImpl, user, toolRegis, botChatHisImpl, manager, configConfig, sendCardClient) + // 初始化钉钉机器人服务 + dingBotService = NewDingBotService(configConfig, dingTalkBotBiz) +} diff --git a/internal/services/provider_set.go b/internal/services/provider_set.go index 375a886..55eed7a 100644 --- a/internal/services/provider_set.go +++ b/internal/services/provider_set.go @@ -14,4 +14,5 @@ var ProviderSetServices = wire.NewSet( NewDingBotService, NewHistoryService, NewCapabilityService, + NewCronService, ) diff --git a/internal/tools/bbxt/bbxt.go b/internal/tools/bbxt/bbxt.go index d7d00ef..77a5f29 100644 --- a/internal/tools/bbxt/bbxt.go +++ b/internal/tools/bbxt/bbxt.go @@ -1,14 +1,10 @@ package bbxt import ( - "ai_scheduler/internal/pkg" + "ai_scheduler/pkg" "fmt" - "reflect" "sort" "time" - - "github.com/gofiber/fiber/v2/log" - "github.com/xuri/excelize/v2" ) type BbxtTools struct { @@ -91,7 +87,18 @@ func (b *BbxtTools) StatisOursProductLossSumTotal(ct []string) (err error) { reseller.ProductLoss[info.OursProductId] = productLoss } } + + // 按经销商总亏损排序 + resellers := make([]*ResellerLoss, 0, len(resellerMap)) for _, v := range resellerMap { + resellers = append(resellers, v) + } + sort.Slice(resellers, func(i, j int) bool { + return resellers[i].Total < resellers[j].Total + }) + + // 构建分组 + for _, v := range resellers { if v.Total <= -100 { total = append(total, []string{ fmt.Sprintf("%s", v.ResellerName), @@ -110,172 +117,8 @@ func (b *BbxtTools) StatisOursProductLossSumTotal(ct []string) (err error) { if len(gt) > 0 { filePath := b.cacheDir + "/kshj_gt" + fmt.Sprintf("%d", time.Now().Unix()) + ".xlsx" - err = b.resellerDetailFillExcel(b.excelTempDir+"/"+"kshj_gt.xlsx", filePath, gt) + // err = b.resellerDetailFillExcel(b.excelTempDir+"/"+"kshj_gt.xlsx", filePath, gt) + err = b.resellerDetailFillExcelV2(b.excelTempDir+"/"+"kshj_gt.xlsx", filePath, gt) } return err } - -// 最简单的通用函数 -func (b *BbxtTools) SimpleFillExcel(templatePath, outputPath string, dataSlice interface{}) error { - // 1. 打开模板 - f, err := excelize.OpenFile(templatePath) - if err != nil { - return err - } - defer f.Close() - - sheet := f.GetSheetName(0) - - // 1.1 获取第二行模板样式 - resellerTplRow := 2 - styleIDReseller, err := f.GetCellStyle(sheet, fmt.Sprintf("A%d", resellerTplRow)) - if err != nil { - log.Errorf("获取分销商总计样式失败: %v", err) - styleIDReseller = 0 - } - // 1.2 获取分销商总计行高 - rowHeightReseller, err := f.GetRowHeight(sheet, resellerTplRow) - if err != nil { - log.Errorf("获取分销商总计行高失败: %v", err) - rowHeightReseller = 31 // 默认高度 - } - - // 2. 反射获取切片数据 - v := reflect.ValueOf(dataSlice) - if v.Kind() != reflect.Slice { - return fmt.Errorf("dataSlice must be a slice") - } - - // 3. 从第2行开始填充 - row := 2 - for i := 0; i < v.Len(); i++ { - item := v.Index(i).Interface() - currentRow := row + i - - // 4. 将item转换为一行数据 - var rowData []interface{} - - // 如果是切片 - if reflect.TypeOf(item).Kind() == reflect.Slice { - itemV := reflect.ValueOf(item) - for j := 0; j < itemV.Len(); j++ { - rowData = append(rowData, itemV.Index(j).Interface()) - } - } else if reflect.TypeOf(item).Kind() == reflect.Struct { - itemV := reflect.ValueOf(item) - for j := 0; j < itemV.NumField(); j++ { - if itemV.Field(j).CanInterface() { - rowData = append(rowData, itemV.Field(j).Interface()) - } - } - } else { - rowData = []interface{}{item} - } - // 4.1 设置行高 - f.SetRowHeight(sheet, currentRow, rowHeightReseller) - - // 5. 填充到Excel - for col, value := range rowData { - cell := fmt.Sprintf("%c%d", 'A'+col, currentRow) - f.SetCellValue(sheet, cell, value) - } - - // 5.1 使用第二行模板样式 - if styleIDReseller != 0 { - f.SetCellStyle(sheet, fmt.Sprintf("A%d", currentRow), fmt.Sprintf("B%d", currentRow), styleIDReseller) - } - } - - // 6. 保存 - return f.SaveAs(outputPath) -} - -// 分销商负利润详情填充excel -// 1.使用模板文件作为输出文件 -// 2.分销商总计使用第二行样式(宽高、背景、颜色等) -// 3.商品详情使用第三行样式(宽高、背景、颜色等) -// 4.保存为新文件 -func (b *BbxtTools) resellerDetailFillExcel(templatePath, outputPath string, dataSlice []*ResellerLoss) error { - // 1. 读取模板 - f, err := excelize.OpenFile(templatePath) - if err != nil { - return err - } - defer f.Close() - - sheet := f.GetSheetName(0) - - // 获取模板样式1:第二行-分销商总计 - resellerTplRow := 2 - styleIDReseller, err := f.GetCellStyle(sheet, fmt.Sprintf("A%d", resellerTplRow)) - if err != nil { - log.Errorf("获取分销商总计样式失败: %v", err) - styleIDReseller = 0 - } - rowHeightReseller, err := f.GetRowHeight(sheet, resellerTplRow) - if err != nil { - log.Errorf("获取分销商总计行高失败: %v", err) - rowHeightReseller = 31 // 默认高度 - } - // 获取模板样式2:第三行-产品亏损明细 - productTplRow := 3 - styleIDProduct, err := f.GetCellStyle(sheet, fmt.Sprintf("A%d", productTplRow)) - if err != nil { - log.Errorf("获取商品详情样式失败: %v", err) - styleIDProduct = 0 - } - rowHeightProduct, err := f.GetRowHeight(sheet, productTplRow) - if err != nil { - log.Errorf("获取商品详情行高失败: %v", err) - rowHeightProduct = 25 // 默认高度 - } - - currentRow := 2 - - for _, reseller := range dataSlice { - // 3. 填充经销商数据 (ResellerName, Total) - // 设置行高 - f.SetRowHeight(sheet, currentRow, rowHeightReseller) - - // 设置单元格值 - f.SetCellValue(sheet, fmt.Sprintf("A%d", currentRow), reseller.ResellerName) - f.SetCellValue(sheet, fmt.Sprintf("B%d", currentRow), reseller.Total) - - // 应用样式 - if styleIDReseller != 0 { - f.SetCellStyle(sheet, fmt.Sprintf("A%d", currentRow), fmt.Sprintf("B%d", currentRow), styleIDReseller) - } - - currentRow++ - - // 4. 填充产品亏损明细 - // 先对 ProductLoss 进行排序 - var products []ProductLoss - for _, p := range reseller.ProductLoss { - products = append(products, p) - } - // 按 Loss 升序排序 (亏损越多越靠前,负数越小) - sort.Slice(products, func(i, j int) bool { - return products[i].Loss < products[j].Loss - }) - - for _, p := range products { - // 设置行高 - f.SetRowHeight(sheet, currentRow, rowHeightProduct) - - // 设置单元格值 - f.SetCellValue(sheet, fmt.Sprintf("A%d", currentRow), p.ProductName) - f.SetCellValue(sheet, fmt.Sprintf("B%d", currentRow), p.Loss) - - // 应用样式 - if styleIDProduct != 0 { - f.SetCellStyle(sheet, fmt.Sprintf("A%d", currentRow), fmt.Sprintf("B%d", currentRow), styleIDProduct) - } - - currentRow++ - } - } - - // 6. 保存 - return f.SaveAs(outputPath) -} diff --git a/internal/tools/bbxt/bbxt_test.go b/internal/tools/bbxt/bbxt_test.go index 628c675..b6b2275 100644 --- a/internal/tools/bbxt/bbxt_test.go +++ b/internal/tools/bbxt/bbxt_test.go @@ -1,13 +1,16 @@ package bbxt -import "testing" +import ( + "testing" + "time" +) func Test_StatisOursProductLossSumApiTotal(t *testing.T) { o, err := NewBbxtTools() if err != nil { panic(err) } - err = o.StatisOursProductLossSumTotal([]string{"2025-12-28+00:00:00", "2025-12-28+23:59:59.999"}) + err = o.DailyReport(time.Date(2025, 12, 30, 0, 0, 0, 0, time.Local)) t.Log(err) diff --git a/internal/tools/bbxt/excel.go b/internal/tools/bbxt/excel.go new file mode 100644 index 0000000..4ba932d --- /dev/null +++ b/internal/tools/bbxt/excel.go @@ -0,0 +1,415 @@ +package bbxt + +import ( + "bytes" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "reflect" + "sort" + + "github.com/go-kratos/kratos/v2/log" + "github.com/xuri/excelize/v2" +) + +// 最简单的通用函数 +func (b *BbxtTools) SimpleFillExcel(templatePath, outputPath string, dataSlice interface{}) error { + // 1. 打开模板 + f, err := excelize.OpenFile(templatePath) + if err != nil { + return err + } + defer f.Close() + + sheet := f.GetSheetName(0) + + // 1.1 获取第二行模板样式 + resellerTplRow := 2 + styleIDReseller, err := f.GetCellStyle(sheet, fmt.Sprintf("A%d", resellerTplRow)) + if err != nil { + log.Errorf("获取分销商总计样式失败: %v", err) + styleIDReseller = 0 + } + // 1.2 获取分销商总计行高 + rowHeightReseller, err := f.GetRowHeight(sheet, resellerTplRow) + if err != nil { + log.Errorf("获取分销商总计行高失败: %v", err) + rowHeightReseller = 31 // 默认高度 + } + + // 2. 反射获取切片数据 + v := reflect.ValueOf(dataSlice) + if v.Kind() != reflect.Slice { + return fmt.Errorf("dataSlice must be a slice") + } + + // 3. 从第2行开始填充 + row := 2 + for i := 0; i < v.Len(); i++ { + item := v.Index(i).Interface() + currentRow := row + i + + // 4. 将item转换为一行数据 + var rowData []interface{} + + // 如果是切片 + if reflect.TypeOf(item).Kind() == reflect.Slice { + itemV := reflect.ValueOf(item) + for j := 0; j < itemV.Len(); j++ { + rowData = append(rowData, itemV.Index(j).Interface()) + } + } else if reflect.TypeOf(item).Kind() == reflect.Struct { + itemV := reflect.ValueOf(item) + for j := 0; j < itemV.NumField(); j++ { + if itemV.Field(j).CanInterface() { + rowData = append(rowData, itemV.Field(j).Interface()) + } + } + } else { + rowData = []interface{}{item} + } + // 4.1 设置行高 + f.SetRowHeight(sheet, currentRow, rowHeightReseller) + + // 5. 填充到Excel + for col, value := range rowData { + cell := fmt.Sprintf("%c%d", 'A'+col, currentRow) + f.SetCellValue(sheet, cell, value) + } + + // 5.1 使用第二行模板样式 + if styleIDReseller != 0 { + f.SetCellStyle(sheet, fmt.Sprintf("A%d", currentRow), fmt.Sprintf("B%d", currentRow), styleIDReseller) + } + } + + // 6. 保存 + return f.SaveAs(outputPath) +} + +// 分销商负利润详情填充excel +// 1.使用模板文件作为输出文件 +// 2.分销商总计使用第二行样式(宽高、背景、颜色等) +// 3.商品详情使用第三行样式(宽高、背景、颜色等) +// 4.保存为新文件 +func (b *BbxtTools) resellerDetailFillExcel(templatePath, outputPath string, dataSlice []*ResellerLoss) error { + // 1. 读取模板 + f, err := excelize.OpenFile(templatePath) + if err != nil { + return err + } + defer f.Close() + + sheet := f.GetSheetName(0) + + // 获取模板样式1:第二行-分销商总计 + resellerTplRow := 2 + styleIDReseller, err := f.GetCellStyle(sheet, fmt.Sprintf("A%d", resellerTplRow)) + if err != nil { + log.Errorf("获取分销商总计样式失败: %v", err) + styleIDReseller = 0 + } + rowHeightReseller, err := f.GetRowHeight(sheet, resellerTplRow) + if err != nil { + log.Errorf("获取分销商总计行高失败: %v", err) + rowHeightReseller = 31 // 默认高度 + } + // 获取模板样式2:第三行-产品亏损明细 + productTplRow := 3 + styleIDProduct, err := f.GetCellStyle(sheet, fmt.Sprintf("A%d", productTplRow)) + if err != nil { + log.Errorf("获取商品详情样式失败: %v", err) + styleIDProduct = 0 + } + rowHeightProduct, err := f.GetRowHeight(sheet, productTplRow) + if err != nil { + log.Errorf("获取商品详情行高失败: %v", err) + rowHeightProduct = 25 // 默认高度 + } + + currentRow := 2 + + for _, reseller := range dataSlice { + // 3. 填充经销商数据 (ResellerName, Total) + // 设置行高 + f.SetRowHeight(sheet, currentRow, rowHeightReseller) + + // 设置单元格值 + f.SetCellValue(sheet, fmt.Sprintf("A%d", currentRow), reseller.ResellerName) + f.SetCellValue(sheet, fmt.Sprintf("B%d", currentRow), reseller.Total) + + // 应用样式 + if styleIDReseller != 0 { + f.SetCellStyle(sheet, fmt.Sprintf("A%d", currentRow), fmt.Sprintf("B%d", currentRow), styleIDReseller) + } + + currentRow++ + + // 4. 填充产品亏损明细 + // 先对 ProductLoss 进行排序 + var products []ProductLoss + for _, p := range reseller.ProductLoss { + products = append(products, p) + } + // 按 Loss 升序排序 (亏损越多越靠前,负数越小) + sort.Slice(products, func(i, j int) bool { + return products[i].Loss < products[j].Loss + }) + + for _, p := range products { + // 设置行高 + f.SetRowHeight(sheet, currentRow, rowHeightProduct) + + // 设置单元格值 + f.SetCellValue(sheet, fmt.Sprintf("A%d", currentRow), fmt.Sprintf("·%s", p.ProductName)) + f.SetCellValue(sheet, fmt.Sprintf("B%d", currentRow), p.Loss) + + // 应用样式 + if styleIDProduct != 0 { + f.SetCellStyle(sheet, fmt.Sprintf("A%d", currentRow), fmt.Sprintf("B%d", currentRow), styleIDProduct) + } + + currentRow++ + } + } + + // buffer, err := f.WriteToBuffer() + // if err != nil { + // return err + // } + + // return buffer.Bytes(), nil + + // 6. 保存 + return f.SaveAs(outputPath) +} + +// 分销商负利润详情填充excel-V2 +// 1.使用模板文件作为输出文件,从第二行开始填充 +// 2.整体为3列:1.分销商名称(以ResellerName为分组,分销商名称列使用的样式为) 2.商品名称(p.ProductName) 3.亏损金额(p.Loss) +// 3.分销商名称列使用的样式为 A2;商品名称、亏损金额使用的样式为 B2、C2;样式包括宽高、背景、颜色等 +// 4.以ResellerName分组,合并单元格 +// 5.在文件末尾使用“合计”,合计行样式为模板第四行 +// 6.保存为新文件 +func (b *BbxtTools) resellerDetailFillExcelV2(templatePath, outputPath string, dataSlice []*ResellerLoss) error { + // 1. 读取模板 + f, err := excelize.OpenFile(templatePath) + if err != nil { + return err + } + defer f.Close() + + sheet := f.GetSheetName(0) + + // ---------------- 样式获取 ---------------- + // 模板第2行:数据行样式 + tplRowData := 2 + styleA2, err := f.GetCellStyle(sheet, fmt.Sprintf("A%d", tplRowData)) + if err != nil { + styleA2 = 0 + } + // B2和C2通常样式一致,这里取B2作为明细列样式 + styleB2, err := f.GetCellStyle(sheet, fmt.Sprintf("B%d", tplRowData)) + if err != nil { + styleB2 = 0 + } + styleC2, err := f.GetCellStyle(sheet, fmt.Sprintf("C%d", tplRowData)) + if err != nil { + styleC2 = 0 + } + + rowHeightData, err := f.GetRowHeight(sheet, tplRowData) + if err != nil { + rowHeightData = 20 + } + + // 模板第4行:合计行样式 + tplRowTotal := 4 + styleTotalA, err := f.GetCellStyle(sheet, fmt.Sprintf("A%d", tplRowTotal)) + if err != nil { + styleTotalA = 0 + } + styleTotalB, err := f.GetCellStyle(sheet, fmt.Sprintf("B%d", tplRowTotal)) + if err != nil { + styleTotalB = 0 + } + styleTotalC, err := f.GetCellStyle(sheet, fmt.Sprintf("C%d", tplRowTotal)) + if err != nil { + styleTotalC = 0 + } + rowHeightTotal, err := f.GetRowHeight(sheet, tplRowTotal) + if err != nil { + rowHeightTotal = 30 + } + // ---------------------------------------- + + currentRow := 2 + totalLoss := 0.0 + + for _, reseller := range dataSlice { + // 排序 ProductLoss + var products []ProductLoss + for _, p := range reseller.ProductLoss { + products = append(products, p) + } + sort.Slice(products, func(i, j int) bool { + return products[i].Loss < products[j].Loss + }) + + startRow := currentRow + + // 填充该经销商的所有产品 + for _, p := range products { + // 设置行高 + f.SetRowHeight(sheet, currentRow, rowHeightData) + + // 设置值 + f.SetCellValue(sheet, fmt.Sprintf("A%d", currentRow), reseller.ResellerName) + f.SetCellValue(sheet, fmt.Sprintf("B%d", currentRow), p.ProductName) + f.SetCellValue(sheet, fmt.Sprintf("C%d", currentRow), p.Loss) + + // 设置样式 + if styleA2 != 0 { + f.SetCellStyle(sheet, fmt.Sprintf("A%d", currentRow), fmt.Sprintf("A%d", currentRow), styleA2) + } + if styleB2 != 0 { + f.SetCellStyle(sheet, fmt.Sprintf("B%d", currentRow), fmt.Sprintf("B%d", currentRow), styleB2) + } + if styleC2 != 0 { + f.SetCellStyle(sheet, fmt.Sprintf("C%d", currentRow), fmt.Sprintf("C%d", currentRow), styleC2) + } + + totalLoss += p.Loss + currentRow++ + } + + endRow := currentRow - 1 + // 合并单元格 (如果多于1行) + if endRow > startRow { + f.MergeCell(sheet, fmt.Sprintf("A%d", startRow), fmt.Sprintf("A%d", endRow)) + } + } + + // ---------------- 填充合计行 ---------------- + // 设置行高 + f.SetRowHeight(sheet, currentRow, rowHeightTotal) + + f.SetCellValue(sheet, fmt.Sprintf("A%d", currentRow), "合计") + // B列留空,C列填充总亏损 + f.SetCellValue(sheet, fmt.Sprintf("C%d", currentRow), totalLoss) + + // 设置合计行样式 + if styleTotalA != 0 { + f.SetCellStyle(sheet, fmt.Sprintf("A%d", currentRow), fmt.Sprintf("A%d", currentRow), styleTotalA) + } + if styleTotalB != 0 { + f.SetCellStyle(sheet, fmt.Sprintf("B%d", currentRow), fmt.Sprintf("B%d", currentRow), styleTotalB) + } + if styleTotalC != 0 { + f.SetCellStyle(sheet, fmt.Sprintf("C%d", currentRow), fmt.Sprintf("C%d", currentRow), styleTotalC) + } + // 取消合并合计行的A、B列 + // f.MergeCell(sheet, fmt.Sprintf("A%d", currentRow), fmt.Sprintf("B%d", currentRow)) + + excelBytes, err := f.WriteToBuffer() + if err != nil { + return fmt.Errorf("write to bytes failed: %v", err) + } + + picBytes, err := b.excel2picPy(templatePath, excelBytes.Bytes()) + if err != nil { + return fmt.Errorf("excel2picPy failed: %v", err) + } + b.SavePic("temp.png", picBytes) + + // 6. 保存 + return f.SaveAs(outputPath) +} + +// excel2picPy 将excel转换为图片python +// python 接口如下: +// curl --location --request POST 'http://192.168.6.109:8010/api/v1/convert' \ +// --header 'Content-Type: multipart/form-data; boundary=--------------------------952147881043913664015069' \ +// --form 'file=@"C:\\Users\\Administrator\\Downloads\\销售同比分析2025-12-29 0-12点.xlsx"' \ +// --form 'sheet_name="销售同比分析"' +func (b *BbxtTools) excel2picPy(templatePath string, excelBytes []byte) ([]byte, error) { + // 1. 获取 Sheet Name + // 尝试从 excelBytes 解析,如果失败则使用默认值 "Sheet1" + sheetName := "Sheet1" + f, err := excelize.OpenReader(bytes.NewReader(excelBytes)) + if err == nil { + sheetName = f.GetSheetName(0) + if sheetName == "" { + sheetName = "Sheet1" + } + f.Close() + } + + // 2. 构造 Multipart 请求 + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + // 添加文件字段 + // 使用 templatePath 的文件名作为上传文件名,如果没有则用 default.xlsx + filename := "default.xlsx" + if templatePath != "" { + filename = filepath.Base(templatePath) + } + + part, err := writer.CreateFormFile("file", filename) + if err != nil { + return nil, fmt.Errorf("create form file failed: %v", err) + } + if _, err = part.Write(excelBytes); err != nil { + return nil, fmt.Errorf("write file part failed: %v", err) + } + + // 添加 sheet_name 字段 + if err = writer.WriteField("sheet_name", sheetName); err != nil { + return nil, fmt.Errorf("write field sheet_name failed: %v", err) + } + + if err = writer.Close(); err != nil { + return nil, fmt.Errorf("close writer failed: %v", err) + } + + // 3. 发送 HTTP POST 请求 + url := "http://192.168.6.109:8010/api/v1/convert" + req, err := http.NewRequest("POST", url, body) + if err != nil { + return nil, fmt.Errorf("create request failed: %v", err) + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("send request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("api request failed with status: %d, body: %s", resp.StatusCode, string(respBody)) + } + + // 4. 读取响应 Body (图片内容) + picBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response body failed: %v", err) + } + + return picBytes, nil +} + +// SavePic 保存图片到本地 +func (b *BbxtTools) SavePic(outputPath string, picBytes []byte) error { + dir := filepath.Dir(outputPath) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("create directory failed: %v", err) + } + return os.WriteFile(outputPath, picBytes, 0644) +} diff --git a/pkg/func.go b/pkg/func.go new file mode 100644 index 0000000..006f16f --- /dev/null +++ b/pkg/func.go @@ -0,0 +1,65 @@ +package pkg + +import ( + "fmt" + "os" + "path/filepath" +) + +func GetModuleDir() (string, error) { + dir, err := os.Getwd() + if err != nil { + return "", err + } + + for { + modPath := filepath.Join(dir, "go.mod") + if _, err := os.Stat(modPath); err == nil { + return dir, nil // 找到 go.mod + } + + // 向上查找父目录 + parent := filepath.Dir(dir) + if parent == dir { + break // 到达根目录,未找到 + } + dir = parent + } + + return "", fmt.Errorf("go.mod not found in current directory or parents") +} + +// GetCacheDir 用于获取缓存目录路径 +// 如果缓存目录不存在,则会自动创建 +// 返回值: +// - string: 缓存目录的路径 +// - error: 如果获取模块目录失败或创建缓存目录失败,则返回错误信息 +func GetCacheDir() (string, error) { + // 获取模块目录 + modDir, err := GetModuleDir() + if err != nil { + return "", err + } + // 拼接缓存目录路径 + path := fmt.Sprintf("%s/cache", modDir) + // 创建目录(包括所有必要的父目录),权限设置为0755 + err = os.MkdirAll(path, 0755) + if err != nil { + return "", fmt.Errorf("创建目录失败: %w", err) + } + // 返回成功创建的缓存目录路径 + return path, nil +} + +func GetTmplDir() (string, error) { + modDir, err := GetModuleDir() + if err != nil { + return "", err + } + path := fmt.Sprintf("%s/tmpl", modDir) + err = os.MkdirAll(path, 0755) + if err != nil { + return "", fmt.Errorf("创建目录失败: %w", err) + } + return path, nil +} diff --git a/tmpl/excel_temp/kshj_gt.xlsx b/tmpl/excel_temp/kshj_gt.xlsx index c826635..e270a67 100755 Binary files a/tmpl/excel_temp/kshj_gt.xlsx and b/tmpl/excel_temp/kshj_gt.xlsx differ diff --git a/tmpl/excel_temp/kshj_gt.xlsx.v1 b/tmpl/excel_temp/kshj_gt.xlsx.v1 new file mode 100755 index 0000000..6933f63 Binary files /dev/null and b/tmpl/excel_temp/kshj_gt.xlsx.v1 differ