diff --git a/cmd/server/main.go b/cmd/server/main.go index d765735..baca30d 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -13,6 +13,7 @@ func main() { configPath := flag.String("config", "./config/config_test.yaml", "Path to configuration file") onBot := flag.String("bot", "", "bot start") flag.Parse() + ctx := context.Background() bc, err := config.LoadConfig(*configPath) if err != nil { log.Fatalf("加载配置失败: %v", err) @@ -25,7 +26,9 @@ func main() { defer func() { cleanup() }() - app.DingBotServer.Run(context.Background(), *onBot) - + //钉钉机器人 + app.DingBotServer.Run(ctx, *onBot) + //定时任务 + app.Cron.Run(ctx) log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port))) } diff --git a/go.mod b/go.mod index 694ee17..fa8498d 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/faabiosr/cachego v0.26.0 github.com/fastwego/dingding v1.0.0-beta.4 github.com/gabriel-vasile/mimetype v1.4.11 - github.com/go-kratos/kratos/v2 v2.9.1 + github.com/go-kratos/kratos/v2 v2.9.2 github.com/go-playground/locales v0.14.1 github.com/go-playground/universal-translator v0.18.1 github.com/go-playground/validator/v10 v10.20.0 @@ -39,6 +39,7 @@ require ( ) require ( + dario.cat/mergo v1.0.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect github.com/alibabacloud-go/alibabacloud-gateway-spi v0.0.5 // indirect github.com/alibabacloud-go/debug v1.0.1 // indirect @@ -90,6 +91,7 @@ require ( github.com/richardlehane/mscfb v1.0.4 // indirect github.com/richardlehane/msoleps v1.0.4 // indirect github.com/rivo/uniseg v0.2.0 // indirect + github.com/robfig/cron/v3 v3.0.1 // indirect github.com/sagikazarmark/locafero v0.3.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect diff --git a/go.sum b/go.sum index b753e6a..b0d2c51 100644 --- a/go.sum +++ b/go.sum @@ -35,6 +35,8 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo= +dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= +dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= @@ -193,6 +195,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-kratos/kratos/v2 v2.9.1 h1:EGif6/S/aK/RCR5clIbyhioTNyoSrii3FC118jG40Z0= github.com/go-kratos/kratos/v2 v2.9.1/go.mod h1:a1MQLjMhIh7R0kcJS9SzJYR43BRI7EPzzN0J1Ksu2bA= +github.com/go-kratos/kratos/v2 v2.9.2 h1:px8GJQBeLpquDKQWQ9zohEWiLA8n4D/pv7aH3asvUvo= +github.com/go-kratos/kratos/v2 v2.9.2/go.mod h1:Jc7jaeYd4RAPjetun2C+oFAOO7HNMHTT/Z4LxpuEDJM= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -375,6 +379,8 @@ github.com/richardlehane/msoleps v1.0.4 h1:WuESlvhX3gH2IHcd8UqyCuFY5yiq/GR/yqaSM github.com/richardlehane/msoleps v1.0.4/go.mod h1:BWev5JBpU9Ko2WAgmZEuiz4/u3ZYTKbjLycmwiWUfWg= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= diff --git a/internal/biz/ding_talk_bot.go b/internal/biz/ding_talk_bot.go index b8976df..0cc4090 100644 --- a/internal/biz/ding_talk_bot.go +++ b/internal/biz/ding_talk_bot.go @@ -470,6 +470,24 @@ func (d *DingTalkBotBiz) HandleStreamRes(ctx context.Context, data *chatbot.BotC SenderStaffId: data.SenderStaffId, Title: data.Text.Content, }) + return +} + +func (d *DingTalkBotBiz) GetReportLists(ctx context.Context) (contentChan chan string, err error) { + contentChan = make(chan string, 10) + defer close(contentChan) + contentChan <- "截止今日23点利润亏损合计:127917.0866元,亏损500元以上的分销商和产品金额如下图:" + contentChan <- "![图片](https://lsxdmgoss.oss-cn-chengdu.aliyuncs.com/MarketingSaaS/image/V2/other/shanghu.png)" + + return +} + +func (d *DingTalkBotBiz) GetGroupInfo(ctx context.Context, groupId int) (group model.AiBotGroup, err error) { + + cond := builder.NewCond() + cond = cond.And(builder.Eq{"group_id": groupId}) + cond = cond.And(builder.Eq{"status": constants.Enable}) + err = d.botGroupImpl.GetOneBySearchToStrut(&cond, &group) return } diff --git a/internal/biz/handle/dingtalk/send_card.go.bak1 b/internal/biz/handle/dingtalk/send_card.go.bak1 deleted file mode 100644 index 9fb1e8d..0000000 --- a/internal/biz/handle/dingtalk/send_card.go.bak1 +++ /dev/null @@ -1,280 +0,0 @@ -package dingtalk - -import ( - "ai_scheduler/internal/data/constants" - "context" - "encoding/json" - "errors" - "fmt" - "strings" - "sync" - "time" - - openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client" - dingtalkcard_1_0 "github.com/alibabacloud-go/dingtalk/card_1_0" - dingtalkim_1_0 "github.com/alibabacloud-go/dingtalk/im_1_0" - util "github.com/alibabacloud-go/tea-utils/v2/service" - "github.com/alibabacloud-go/tea/tea" - "github.com/gofiber/fiber/v2/log" - "github.com/google/uuid" -) - -const DefaultInterval = 100 * time.Millisecond -const HeardBeatX = 50 - -type SendCardClient struct { - Auth *Auth - CardClient *sync.Map - mu sync.RWMutex // 保护 CardClient 的并发访问 - logger log.AllLogger // 日志记录 - botOption *Bot -} - -func NewSendCardClient(auth *Auth, logger log.AllLogger) *SendCardClient { - return &SendCardClient{ - Auth: auth, - CardClient: &sync.Map{}, - logger: logger, - botOption: &Bot{}, - } -} - -// initClient 初始化或复用 DingTalk 客户端 -func (s *SendCardClient) initClient(robotCode string) (*dingtalkcard_1_0.Client, error) { - if client, ok := s.CardClient.Load(robotCode); ok { - return client.(*dingtalkcard_1_0.Client), nil - } - s.botOption.BotCode = robotCode - config := &openapi.Config{ - Protocol: tea.String("https"), - RegionId: tea.String("central"), - } - client, err := dingtalkcard_1_0.NewClient(config) - if err != nil { - s.logger.Error("failed to init DingTalk client") - return nil, fmt.Errorf("init client failed: %w", err) - } - - s.CardClient.Store(robotCode, client) - return client, nil -} - -func (s *SendCardClient) NewCard(ctx context.Context, cardSend *CardSend) error { - // 参数校验 - if (len(cardSend.ContentSlice) == 0 || cardSend.ContentSlice == nil) && cardSend.ContentChannel == nil { - return errors.New("卡片内容不能为空") - } - if cardSend.UpdateInterval == 0 { - cardSend.UpdateInterval = DefaultInterval // 默认更新间隔 - } - if cardSend.Title == "" { - cardSend.Title = "钉钉卡片" - } - //替换标题 - cardSend.Template = constants.CardTemp(strings.Replace(string(cardSend.Template), "${title}", cardSend.Title, 1)) - // 初始化客户端 - client, err := s.initClient(cardSend.RobotCode) - if err != nil { - return fmt.Errorf("初始化client失败: %w", err) - } - - // 生成卡片实例ID - cardInstanceId, err := uuid.NewUUID() - if err != nil { - return fmt.Errorf("创建uuid失败: %w", err) - } - - // 构建初始请求 - request, err := s.buildBaseRequest(cardSend, cardInstanceId.String()) - if err != nil { - return fmt.Errorf("请求失败: %w", err) - } - - // 发送初始卡片 - if _, err := s.SendInteractiveCard(ctx, request, cardSend.RobotCode, client); err != nil { - return fmt.Errorf("发送初始卡片失败: %w", err) - } - - // 处理切片内容(同步) - if len(cardSend.ContentSlice) > 0 { - if err := s.processContentSlice(ctx, cardSend, cardInstanceId.String(), client); err != nil { - return fmt.Errorf("内容同步失败: %w", err) - } - } - - // 处理通道内容(异步) - if cardSend.ContentChannel != nil { - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - s.processContentChannel(ctx, cardSend, cardInstanceId.String(), client) - }() - wg.Wait() - } - - return nil -} - -// buildBaseRequest 构建基础请求 -func (s *SendCardClient) buildBaseRequest(cardSend *CardSend, cardInstanceId string) (*dingtalkcard_1_0.StreamingUpdateRequest, error) { - cardData := fmt.Sprintf(string(cardSend.Template), "") // 初始空内容 - request := &dingtalkcard_1_0.StreamingUpdateRequest{ - OutTrackId: tea.String("your-out-track-id"), - Guid: tea.String("0F714542-0AFC-2B0E-CF14-E2D39F5BFFE8"), - Key: tea.String("your-ai-param"), - Content: tea.String("test"), - IsFull: tea.Bool(false), - IsFinalize: tea.Bool(false), - IsError: tea.Bool(false), - } - - switch cardSend.ConversationType { - case constants.ConversationTypeGroup: - request.SetOpenConversationId(cardSend.ConversationId) - case constants.ConversationTypeSingle: - receiver, err := json.Marshal(map[string]string{"userId": cardSend.SenderStaffId}) - if err != nil { - return nil, fmt.Errorf("数据整理失败: %w", err) - } - request.SetSingleChatReceiver(string(receiver)) - default: - return nil, errors.New("未知的聊天场景") - } - - return request, nil -} - -// processContentChannel 处理通道内容(异步更新) -func (s *SendCardClient) processContentChannel(ctx context.Context, cardSend *CardSend, cardInstanceId string, client *dingtalkim_1_0.Client) { - defer func() { - if r := recover(); r != nil { - s.logger.Error("panic in processContentChannel") - } - }() - - ticker := time.NewTicker(cardSend.UpdateInterval) - defer ticker.Stop() - heartbeatTicker := time.NewTicker(time.Duration(HeardBeatX) * DefaultInterval) - defer heartbeatTicker.Stop() - - var ( - contentBuilder strings.Builder - lastUpdate time.Time - ) - for { - - select { - case content, ok := <-cardSend.ContentChannel: - if !ok { - // 通道关闭,发送最终内容 - if contentBuilder.Len() > 0 { - if err := s.updateCardContent(ctx, cardSend, cardInstanceId, contentBuilder.String(), client); err != nil { - s.logger.Errorf("更新卡片失败1:%s", err.Error()) - } - } - return - } - contentBuilder.WriteString(content) - if contentBuilder.Len() > 0 { - if err := s.updateCardContent(ctx, cardSend, cardInstanceId, contentBuilder.String(), client); err != nil { - s.logger.Errorf("更新卡片失败2:%s", err.Error()) - } - } - lastUpdate = time.Now() - - case <-heartbeatTicker.C: - if time.Now().Unix()-lastUpdate.Unix() >= HeardBeatX { - return - } - - case <-ctx.Done(): - s.logger.Info("context canceled, stop channel processing") - return - } - } - -} - -// processContentSlice 处理切片内容(同步更新) -func (s *SendCardClient) processContentSlice(ctx context.Context, cardSend *CardSend, cardInstanceId string, client *dingtalkim_1_0.Client) error { - var contentBuilder strings.Builder - for _, content := range cardSend.ContentSlice { - contentBuilder.WriteString(content) - err := s.updateCardRequest(ctx, &UpdateCardRequest{ - Template: string(cardSend.Template), - Content: contentBuilder.String(), - Client: client, - RobotCode: cardSend.RobotCode, - CardInstanceId: cardInstanceId, - }) - if err != nil { - return fmt.Errorf("更新卡片失败: %w", err) - } - time.Sleep(cardSend.UpdateInterval) // 控制更新频率 - } - return nil -} - -// updateCardContent 封装卡片更新逻辑 -func (s *SendCardClient) updateCardContent(ctx context.Context, cardSend *CardSend, cardInstanceId, content string, client *dingtalkim_1_0.Client) error { - err := s.updateCardRequest(ctx, &UpdateCardRequest{ - Template: string(cardSend.Template), - Content: content, - Client: client, - RobotCode: cardSend.RobotCode, - CardInstanceId: cardInstanceId, - }) - - return err -} - -func (s *SendCardClient) updateCardRequest(ctx context.Context, updateCardRequest *UpdateCardRequest) error { - - updateRequest := &dingtalkim_1_0.UpdateRobotInteractiveCardRequest{ - CardBizId: tea.String(updateCardRequest.CardInstanceId), - CardData: tea.String(fmt.Sprintf(updateCardRequest.Template, updateCardRequest.Content)), - } - _, err := s.UpdateInteractiveCard(ctx, updateRequest, updateCardRequest.RobotCode, updateCardRequest.Client) - return err -} - -// UpdateInteractiveCard 更新交互卡片(封装错误处理) -func (s *SendCardClient) UpdateInteractiveCard(ctx context.Context, request *dingtalkim_1_0.UpdateRobotInteractiveCardRequest, robotCode string, client *dingtalkim_1_0.Client) (*dingtalkim_1_0.UpdateRobotInteractiveCardResponse, error) { - authInfo, err := s.Auth.GetTokenFromBotOption(ctx, WithBot(s.botOption)) - if err != nil { - return nil, fmt.Errorf("get token failed: %w", err) - } - - headers := &dingtalkim_1_0.UpdateRobotInteractiveCardHeaders{ - XAcsDingtalkAccessToken: tea.String(authInfo.AccessToken), - } - - response, err := client.UpdateRobotInteractiveCardWithOptions(request, headers, &util.RuntimeOptions{}) - if err != nil { - return nil, fmt.Errorf("API call failed: %w,request:%v", err, request.String()) - } - return response, nil -} - -// SendInteractiveCard 发送交互卡片(封装错误处理) -func (s *SendCardClient) SendInteractiveCard(ctx context.Context, request *dingtalkim_1_0.SendRobotInteractiveCardRequest, robotCode string, client *dingtalkim_1_0.Client) (res *dingtalkim_1_0.SendRobotInteractiveCardResponse, err error) { - err = s.Auth.GetBotConfigFromModel(s.botOption) - if err != nil { - return nil, fmt.Errorf("初始化bot失败: %w", err) - } - authInfo, err := s.Auth.GetTokenFromBotOption(ctx, WithBot(s.botOption)) - if err != nil { - return nil, fmt.Errorf("get token failed: %w", err) - } - - headers := &dingtalkim_1_0.SendRobotInteractiveCardHeaders{ - XAcsDingtalkAccessToken: tea.String(authInfo.AccessToken), - } - - response, err := client.SendRobotInteractiveCardWithOptions(request, headers, &util.RuntimeOptions{}) - if err != nil { - return nil, fmt.Errorf("API call failed: %w", err) - } - return response, nil -} diff --git a/internal/config/config.go b/internal/config/config.go index 441d09d..35f2115 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,10 +1,10 @@ package config import ( + "ai_scheduler/pkg" "fmt" - "time" - "github.com/spf13/viper" + "time" ) // Config 应用配置 @@ -222,10 +222,32 @@ func LoadConfig(configPath string) (*Config, error) { } // 解析配置 - var config Config - if err := viper.Unmarshal(&config); err != nil { + var bc Config + if err := viper.Unmarshal(&bc); err != nil { return nil, fmt.Errorf("failed to unmarshal config: %w", err) } - return &config, nil + return &bc, nil +} + +func LoadConfigWithTest() (*Config, error) { + var bc Config + modularDir, err := pkg.GetModuleDir() + if err != nil { + return nil, err + } + viper.SetConfigFile(modularDir + "/config/config_test.yaml") + viper.SetConfigType("yaml") + // 读取配置文件 + if err := viper.ReadInConfig(); err != nil { + return nil, fmt.Errorf("failed to read config file: %w", err) + } + // 解析配置 + + if err := viper.Unmarshal(&bc); err != nil { + return nil, fmt.Errorf("failed to unmarshal config: %w", err) + } + + return &bc, nil + } diff --git a/internal/pkg/channel_pool.go b/internal/pkg/channel_pool.go deleted file mode 100644 index eda85fa..0000000 --- a/internal/pkg/channel_pool.go +++ /dev/null @@ -1,75 +0,0 @@ -package pkg - -import ( - "ai_scheduler/internal/config" - "ai_scheduler/internal/entitys" - "sync" -) - -type SafeChannelPool struct { - pool chan chan entitys.ResponseData // 存储空闲 channel 的队列 - bufSize int // channel 缓冲大小 - mu sync.Mutex - closed bool -} - -func NewSafeChannelPool(c *config.Config) (*SafeChannelPool, func()) { - pool := &SafeChannelPool{ - pool: make(chan chan entitys.ResponseData, c.Sys.ChannelPoolLen), - bufSize: c.Sys.ChannelPoolSize, - } - - cleanup := pool.Close - return pool, cleanup -} - -// 从池中获取 channel(若无空闲则创建新 channel) -func (p *SafeChannelPool) Get() chan entitys.ResponseData { - p.mu.Lock() - defer p.mu.Unlock() - - if p.closed { - return make(chan entitys.ResponseData, p.bufSize) - } - - select { - case ch := <-p.pool: // 从池中取 - return ch - default: // 池为空,创建新 channel - return make(chan entitys.ResponseData, p.bufSize) - } -} - -// 将 channel 放回池中(必须确保 channel 已清空!) -func (p *SafeChannelPool) Put(ch chan entitys.ResponseData) { - p.mu.Lock() - defer p.mu.Unlock() - - if p.closed { - return - } - - // 清空 channel(防止复用时读取旧数据) - go func() { - for range ch { - // 丢弃所有数据(或根据业务需求处理) - } - }() - - select { - case p.pool <- ch: // 尝试放回池中 - default: // 池已满,直接关闭 channel(避免泄漏) - close(ch) - } - return -} - -// 关闭池(释放所有资源) -func (p *SafeChannelPool) Close() { - p.mu.Lock() - defer p.mu.Unlock() - - p.closed = true - close(p.pool) // 关闭池队列 - // 需额外逻辑关闭所有内部 channel(此处简化) -} diff --git a/internal/pkg/func.go b/internal/pkg/func.go index d07b202..4d9232b 100644 --- a/internal/pkg/func.go +++ b/internal/pkg/func.go @@ -6,8 +6,6 @@ import ( "errors" "fmt" "net/url" - "os" - "path/filepath" "reflect" "strconv" "strings" @@ -197,64 +195,6 @@ func StructToMapUsingJsoniter(obj interface{}) (map[string]string, error) { return result, nil } -func GetModuleDir() (string, error) { - dir, err := os.Getwd() - if err != nil { - return "", err - } - - for { - modPath := filepath.Join(dir, "go.mod") - if _, err := os.Stat(modPath); err == nil { - return dir, nil // 找到 go.mod - } - - // 向上查找父目录 - parent := filepath.Dir(dir) - if parent == dir { - break // 到达根目录,未找到 - } - dir = parent - } - - return "", fmt.Errorf("go.mod not found in current directory or parents") -} - -// GetCacheDir 用于获取缓存目录路径 -// 如果缓存目录不存在,则会自动创建 -// 返回值: -// - string: 缓存目录的路径 -// - error: 如果获取模块目录失败或创建缓存目录失败,则返回错误信息 -func GetCacheDir() (string, error) { - // 获取模块目录 - modDir, err := GetModuleDir() - if err != nil { - return "", err - } - // 拼接缓存目录路径 - path := fmt.Sprintf("%s/cache", modDir) - // 创建目录(包括所有必要的父目录),权限设置为0755 - err = os.MkdirAll(path, 0755) - if err != nil { - return "", fmt.Errorf("创建目录失败: %w", err) - } - // 返回成功创建的缓存目录路径 - return path, nil -} - -func GetTmplDir() (string, error) { - modDir, err := GetModuleDir() - if err != nil { - return "", err - } - path := fmt.Sprintf("%s/tmpl", modDir) - err = os.MkdirAll(path, 0755) - if err != nil { - return "", fmt.Errorf("创建目录失败: %w", err) - } - return path, nil -} - // 通用结构体转 Query 参数 func StructToQuery(obj interface{}) (url.Values, error) { values := url.Values{} diff --git a/internal/pkg/provider_set.go b/internal/pkg/provider_set.go index f8fadac..603dedc 100644 --- a/internal/pkg/provider_set.go +++ b/internal/pkg/provider_set.go @@ -15,7 +15,7 @@ var ProviderSetClient = wire.NewSet( utils_langchain.NewUtilLangChain, utils_ollama.NewClient, utils_vllm.NewClient, - NewSafeChannelPool, + dingtalk.NewOldClient, dingtalk.NewContactClient, dingtalk.NewNotableClient, diff --git a/internal/server/cron.go b/internal/server/cron.go new file mode 100644 index 0000000..76c5739 --- /dev/null +++ b/internal/server/cron.go @@ -0,0 +1,93 @@ +package server + +import ( + "ai_scheduler/internal/services" + "context" + + "github.com/gofiber/fiber/v2/log" + "github.com/robfig/cron/v3" +) + +type CronServer struct { + Cron *cron.Cron + jobs []*cronJob + log log.AllLogger + cronService *services.CronService + ctx context.Context +} + +type cronJob struct { + EntryId int32 + Func func(context.Context) error + Name string + Schedule string +} + +func NewCronServer( + log log.AllLogger, + cronService *services.CronService, +) *CronServer { + return &CronServer{ + Cron: cron.New(), + log: log, + cronService: cronService, + ctx: context.Background(), + } +} + +func (c *CronServer) InitJobs(ctx context.Context) { + // 创建一个可用于所有定时任务的上下文(可以取消的上下文) + c.ctx = ctx + c.jobs = []*cronJob{ + { + Func: c.cronService.CronReportSend, + Name: "直连天下报表推送", + Schedule: "@every 60s", + }, + } +} + +func (c *CronServer) Run(ctx context.Context) { + // 先初始化任务 + if c.jobs == nil { + c.InitJobs(ctx) + } + + for i, job := range c.jobs { + // 复制变量到闭包内,避免闭包变量捕获问题 + job := job + jobID := i + 1 + _, err := c.Cron.AddFunc(job.Schedule, func() { + c.log.Infof("任务[%d]:%s开始执行", jobID, job.Name) + + defer func() { + if r := recover(); r != nil { + c.log.Errorf("任务[%d]:%s执行时发生panic: %v", jobID, job.Name, r) + } + c.log.Infof("任务[%d]:%s执行结束", jobID, job.Name) + }() + + // 为每次执行创建新的上下文 + ctx := context.Background() + err := job.Func(ctx) + if err != nil { + c.log.Errorf("任务[%d]:%s执行失败: %s", jobID, job.Name, err.Error()) + } + }) + if err != nil { + c.log.Errorf("添加任务失败:%s", err.Error()) + } + } + + // 启动cron调度器 + c.Cron.Start() + c.log.Info("Cron调度器已启动") +} + +// Stop 停止cron调度器 +func (c *CronServer) Stop() { + if c.Cron != nil { + c.Cron.Stop() + c.log.Info("Cron调度器已停止") + } +} diff --git a/internal/server/provider_set.go b/internal/server/provider_set.go index d5cef3d..08bc37b 100644 --- a/internal/server/provider_set.go +++ b/internal/server/provider_set.go @@ -9,4 +9,5 @@ var ProviderSetServer = wire.NewSet( NewHTTPServer, ProvideAllDingBotServices, NewDingTalkBotServer, + NewCronServer, ) diff --git a/internal/server/server.go b/internal/server/server.go index 02c8f84..b455eb4 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -10,18 +10,14 @@ type Servers struct { cfg *config.Config HttpServer *fiber.App DingBotServer *DingTalkBotServer + Cron *CronServer } -func NewServers(cfg *config.Config, fiber *fiber.App, DingBotServer *DingTalkBotServer) *Servers { +func NewServers(cfg *config.Config, fiber *fiber.App, DingBotServer *DingTalkBotServer, cron *CronServer) *Servers { return &Servers{ HttpServer: fiber, cfg: cfg, DingBotServer: DingBotServer, + Cron: cron, } } - -//func DingBotServerInit(clientId string, clientSecret string, cfg *config.Config, handler *do.Handle, do *do.Do) (cli *client.StreamClient) { -// cli = client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(clientId, clientSecret))) -// cli.RegisterChatBotCallbackRouter(services.NewDingBotService(cfg, handler, do).OnChatBotMessageReceived) -// return -//} diff --git a/internal/services/cron.go b/internal/services/cron.go new file mode 100644 index 0000000..5f624d1 --- /dev/null +++ b/internal/services/cron.go @@ -0,0 +1,43 @@ +package services + +import ( + "ai_scheduler/internal/biz" + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/constants" + "context" + + "gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/chatbot" +) + +type CronService struct { + config *config.Config + dingTalkBotBiz *biz.DingTalkBotBiz +} + +func NewCronService(config *config.Config, dingTalkBotBiz *biz.DingTalkBotBiz) *CronService { + return &CronService{ + config: config, + dingTalkBotBiz: dingTalkBotBiz, + } +} + +func (d *CronService) CronReportSend(ctx context.Context) error { + reportChan, err := d.dingTalkBotBiz.GetReportLists(ctx) + if err != nil { + return err + } + groupId := 23 + groupInfo, err := d.dingTalkBotBiz.GetGroupInfo(ctx, groupId) + if err != nil { + return err + } + err = d.dingTalkBotBiz.HandleStreamRes(ctx, &chatbot.BotCallbackDataModel{ + RobotCode: groupInfo.RobotCode, + ConversationType: constants.ConversationTypeGroup, + ConversationId: groupInfo.ConversationID, + Text: chatbot.BotCallbackDataTextModel{ + Content: "报表", + }, + }, reportChan) + return nil +} diff --git a/internal/services/dtalk_bot.go b/internal/services/dtalk_bot.go index b71e40b..4635b63 100644 --- a/internal/services/dtalk_bot.go +++ b/internal/services/dtalk_bot.go @@ -3,6 +3,7 @@ package services import ( "ai_scheduler/internal/biz" "ai_scheduler/internal/config" + "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" "context" "log" @@ -135,3 +136,24 @@ func (d *DingBotService) runBackgroundTasks(ctx context.Context, data *chatbot.B return nil } + +func (d *DingBotService) CronReportSend(ctx context.Context) error { + reportChan, err := d.dingTalkBotBiz.GetReportLists(ctx) + if err != nil { + return err + } + groupId := 23 + groupInfo, err := d.dingTalkBotBiz.GetGroupInfo(ctx, groupId) + if err != nil { + return err + } + err = d.dingTalkBotBiz.HandleStreamRes(ctx, &chatbot.BotCallbackDataModel{ + RobotCode: groupInfo.RobotCode, + ConversationType: constants.ConversationTypeGroup, + ConversationId: groupInfo.ConversationID, + Text: chatbot.BotCallbackDataTextModel{ + Content: "报表", + }, + }, reportChan) + return nil +} diff --git a/internal/services/dtalk_bot.go.bak b/internal/services/dtalk_bot.go.bak deleted file mode 100644 index 75c2c7f..0000000 --- a/internal/services/dtalk_bot.go.bak +++ /dev/null @@ -1,130 +0,0 @@ -package services - -import ( - "ai_scheduler/internal/biz" - "log" - "sync" - "time" - - "ai_scheduler/internal/config" - "ai_scheduler/internal/entitys" - "context" - - "gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/chatbot" -) - -type DingBotService struct { - config *config.Config - dingTalkBotBiz *biz.DingTalkBotBiz -} - -func NewDingBotService(config *config.Config, DingTalkBotBiz *biz.DingTalkBotBiz) *DingBotService { - return &DingBotService{config: config, dingTalkBotBiz: DingTalkBotBiz} -} - -func (d *DingBotService) GetServiceCfg() ([]entitys.DingTalkBot, error) { - return d.dingTalkBotBiz.GetDingTalkBotCfgList() -} - -func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) { - var ( - lastErr error - chat []string - streamWG sync.WaitGroup - resChan = make(chan string, 100) // 缓冲通道防止阻塞 - ) - - // 初始化请求 - requireData, err := d.dingTalkBotBiz.InitRequire(ctx, data) - if err != nil { - return nil, err - } - - // 创建子上下文用于控制goroutine生命周期 - subCtx, cancel := context.WithCancel(ctx) - defer cancel() - - // 启动流式处理goroutine - streamWG.Add(1) - go func() { - defer streamWG.Done() - err = d.dingTalkBotBiz.HandleStreamRes(subCtx, data, resChan) - if err != nil { - return - } - }() - - // 启动业务处理goroutine - done := make(chan error, 1) - go func() { - done <- d.dingTalkBotBiz.Do(subCtx, requireData) - }() - - // 主处理循环 - for { - select { - case <-ctx.Done(): - lastErr = ctx.Err() - goto cleanup - - case resp, ok := <-requireData.Ch: - if !ok { - goto cleanup - } - - // 处理不同类型响应 - switch resp.Type { - case entitys.ResponseLog: - // 忽略日志类型 - continue - - //case entitys.ResponseText, entitys.ResponseJson: - // chat = append(chat, resp.Content) - // if err := d.dingTalkBotBiz.ReplyText(ctx, data.SessionWebhook, resp.Content); err != nil { - // log.Printf("处理非流响应失败: %v", err) - // lastErr = err - // } - - default: - chat = append(chat, resp.Content) - select { - case resChan <- resp.Content: - case <-ctx.Done(): - lastErr = ctx.Err() - goto cleanup - } - } - } - } - -cleanup: - streamWG.Wait() - // 关闭流式通道 - close(resChan) - - // 保存历史记录 - if saveErr := d.dingTalkBotBiz.SaveHis(ctx, requireData, chat); saveErr != nil { - log.Printf("保存历史记录失败: %v", saveErr) - if lastErr == nil { - lastErr = saveErr - } - } - - // 等待业务处理完成(带超时) - select { - case err := <-done: - if err != nil { - log.Printf("业务处理失败: %v", err) - if lastErr == nil { - lastErr = err - } - } - case <-time.After(3 * time.Second): // 增加超时时间 - log.Println("警告:等待业务处理超时,可能发生goroutine泄漏") - } - - if lastErr != nil { - return nil, lastErr - } - return []byte("success"), nil -} diff --git a/internal/services/dtalk_bot_test.go b/internal/services/dtalk_bot_test.go new file mode 100644 index 0000000..2b8223f --- /dev/null +++ b/internal/services/dtalk_bot_test.go @@ -0,0 +1,105 @@ +package services + +import ( + "ai_scheduler/internal/biz" + "ai_scheduler/internal/biz/do" + dingtalk2 "ai_scheduler/internal/biz/handle/dingtalk" + "ai_scheduler/internal/biz/llm_service" + "ai_scheduler/internal/biz/tools_regis" + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/domain/component" + "ai_scheduler/internal/domain/component/callback" + "ai_scheduler/internal/domain/repo" + "ai_scheduler/internal/domain/workflow" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/dingtalk" + "ai_scheduler/internal/pkg/utils_ollama" + "ai_scheduler/internal/pkg/utils_vllm" + + "ai_scheduler/internal/tools" + "ai_scheduler/utils" + "context" + "testing" + + "github.com/gofiber/fiber/v2/log" +) + +func Test_Report(t *testing.T) { + run() + a := dingBotService.CronReportSend(context.Background()) + t.Log(a) +} + +var ( + configConfig *config.Config + err error + dingBotService *DingBotService +) + +// run 函数是程序的入口函数,负责初始化和配置各个组件 +func run() { + // 加载测试配置 + configConfig, err = config.LoadConfigWithTest() + // 初始化数据库连接 + db, _ := utils.NewGormDb(configConfig) + // 初始化各种实现层组件 + sysImpl := impl.NewSysImpl(db) + taskImpl := impl.NewTaskImpl(db) + chatHisImpl := impl.NewChatHisImpl(db) + sessionImpl := impl.NewSessionImpl(db) + botConfigImpl := impl.NewBotConfigImpl(db) + botGroupImpl := impl.NewBotGroupImpl(db) + botUserImpl := impl.NewBotUserImpl(db) + // 初始化Do业务对象 + doDo := do.NewDo(sysImpl, taskImpl, chatHisImpl, configConfig) + // 初始化Ollama客户端 + client, _, _ := utils_ollama.NewClient(configConfig) + // 初始化vLLM客户端 + utils_vllmClient, _, _ := utils_vllm.NewClient(configConfig) + // 初始化Redis数据库连接 + rdb := utils.NewRdb(configConfig) + // 初始化仓库层 + repos := repo.NewRepos(sessionImpl, rdb) + // 初始化包级别的Redis连接 + pkgRdb := pkg.NewRdb(configConfig) + + // 初始化机器人工具实现层 + botToolsImpl := impl.NewBotToolsImpl(db) + // 初始化机器人部门实现层 + botDeptImpl := impl.NewBotDeptImpl(db) + // 初始化Redis管理器 + redisManager := callback.NewRedisManager(pkgRdb) + // 初始化组件 + components := component.NewComponents(redisManager) + // 初始化工作流注册表 + registry := workflow.NewRegistry(configConfig, client, repos, components) + // 初始化钉钉旧版客户端 + oldClient := dingtalk.NewOldClient(configConfig) + // 初始化Ollama服务 + ollamaService := llm_service.NewOllamaGenerate(client, utils_vllmClient, configConfig, chatHisImpl) + // 初始化工具管理器 + manager := tools.NewManager(configConfig, client) + // 初始化钉钉联系人客户端 + contactClient, _ := dingtalk.NewContactClient(configConfig) + // 初始化钉钉记事本客户端 + notableClient, _ := dingtalk.NewNotableClient(configConfig) + // 初始化工具注册 + toolRegis := tools_regis.NewToolsRegis(botToolsImpl) + // 初始化机器人聊天历史实现层 + botChatHisImpl := impl.NewBotChatHisImpl(db) + // 初始化钉钉认证 + auth := dingtalk2.NewAuth(configConfig, rdb, botConfigImpl) + // 初始化部门服务 + dept := dingtalk2.NewDept(botDeptImpl, auth) + // 初始化用户服务 + user := dingtalk2.NewUser(botUserImpl, auth, dept) + // 初始化发送卡片客户端 + sendCardClient := dingtalk2.NewSendCardClient(auth, log.DefaultLogger()) + // 初始化处理器 + handle := do.NewHandle(ollamaService, manager, configConfig, sessionImpl, registry, oldClient, contactClient, notableClient) + // 初始化钉钉机器人业务逻辑 + dingTalkBotBiz := biz.NewDingTalkBotBiz(doDo, handle, botConfigImpl, botGroupImpl, user, toolRegis, botChatHisImpl, manager, configConfig, sendCardClient) + // 初始化钉钉机器人服务 + dingBotService = NewDingBotService(configConfig, dingTalkBotBiz) +} diff --git a/internal/services/provider_set.go b/internal/services/provider_set.go index 375a886..55eed7a 100644 --- a/internal/services/provider_set.go +++ b/internal/services/provider_set.go @@ -14,4 +14,5 @@ var ProviderSetServices = wire.NewSet( NewDingBotService, NewHistoryService, NewCapabilityService, + NewCronService, ) diff --git a/internal/tools/bbxt/bbxt.go b/internal/tools/bbxt/bbxt.go index 6d593db..f7d1e68 100644 --- a/internal/tools/bbxt/bbxt.go +++ b/internal/tools/bbxt/bbxt.go @@ -1,7 +1,7 @@ package bbxt import ( - "ai_scheduler/internal/pkg" + "ai_scheduler/pkg" "fmt" "sort" "time" diff --git a/pkg/func.go b/pkg/func.go new file mode 100644 index 0000000..006f16f --- /dev/null +++ b/pkg/func.go @@ -0,0 +1,65 @@ +package pkg + +import ( + "fmt" + "os" + "path/filepath" +) + +func GetModuleDir() (string, error) { + dir, err := os.Getwd() + if err != nil { + return "", err + } + + for { + modPath := filepath.Join(dir, "go.mod") + if _, err := os.Stat(modPath); err == nil { + return dir, nil // 找到 go.mod + } + + // 向上查找父目录 + parent := filepath.Dir(dir) + if parent == dir { + break // 到达根目录,未找到 + } + dir = parent + } + + return "", fmt.Errorf("go.mod not found in current directory or parents") +} + +// GetCacheDir 用于获取缓存目录路径 +// 如果缓存目录不存在,则会自动创建 +// 返回值: +// - string: 缓存目录的路径 +// - error: 如果获取模块目录失败或创建缓存目录失败,则返回错误信息 +func GetCacheDir() (string, error) { + // 获取模块目录 + modDir, err := GetModuleDir() + if err != nil { + return "", err + } + // 拼接缓存目录路径 + path := fmt.Sprintf("%s/cache", modDir) + // 创建目录(包括所有必要的父目录),权限设置为0755 + err = os.MkdirAll(path, 0755) + if err != nil { + return "", fmt.Errorf("创建目录失败: %w", err) + } + // 返回成功创建的缓存目录路径 + return path, nil +} + +func GetTmplDir() (string, error) { + modDir, err := GetModuleDir() + if err != nil { + return "", err + } + path := fmt.Sprintf("%s/tmpl", modDir) + err = os.MkdirAll(path, 0755) + if err != nil { + return "", fmt.Errorf("创建目录失败: %w", err) + } + return path, nil +}