133 lines
3.1 KiB
Go
133 lines
3.1 KiB
Go
package services
|
||
|
||
import (
|
||
"ai_scheduler/internal/biz"
|
||
"ai_scheduler/internal/config"
|
||
"ai_scheduler/internal/entitys"
|
||
"context"
|
||
"log"
|
||
"sync"
|
||
"time"
|
||
|
||
"gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/chatbot"
|
||
"golang.org/x/sync/errgroup"
|
||
)
|
||
|
||
// DingBotService is the service-layer entry point for the DingTalk chatbot.
// It exposes the bot configuration list and handles incoming chatbot
// callbacks, delegating the actual work to the DingTalk bot business layer.
type DingBotService struct {
	// config is the application configuration supplied at construction.
	// NOTE(review): not read by the methods visible next to this type — confirm it is still needed.
	config *config.Config

	// dingTalkBotBiz is the business layer that the service methods delegate to.
	dingTalkBotBiz *biz.DingTalkBotBiz
}
|
||
|
||
func NewDingBotService(config *config.Config, dingTalkBotBiz *biz.DingTalkBotBiz) *DingBotService {
|
||
return &DingBotService{
|
||
config: config,
|
||
dingTalkBotBiz: dingTalkBotBiz,
|
||
}
|
||
}
|
||
|
||
func (d *DingBotService) GetServiceCfg() ([]entitys.DingTalkBot, error) {
|
||
return d.dingTalkBotBiz.GetDingTalkBotCfgList()
|
||
}
|
||
|
||
func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) {
|
||
requireData, err := d.dingTalkBotBiz.InitRequire(ctx, data)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 启动后台任务(独立生命周期,带超时控制)
|
||
go func() {
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||
defer cancel()
|
||
if err := d.runBackgroundTasks(ctx, data, requireData); err != nil {
|
||
log.Printf("后台任务执行失败: %v", err)
|
||
}
|
||
}()
|
||
|
||
return []byte("success"), nil
|
||
}
|
||
|
||
func (d *DingBotService) runBackgroundTasks(ctx context.Context, data *chatbot.BotCallbackDataModel, requireData *entitys.RequireDataDingTalkBot) error {
|
||
g, ctx := errgroup.WithContext(ctx)
|
||
var (
|
||
chat []string
|
||
chatMu sync.Mutex
|
||
resChan = make(chan string, 10)
|
||
)
|
||
|
||
// 1. 流式处理协程
|
||
g.Go(func() error {
|
||
defer func() {
|
||
// 确保通道最终关闭
|
||
close(resChan)
|
||
}()
|
||
return d.dingTalkBotBiz.HandleStreamRes(ctx, data, resChan)
|
||
})
|
||
|
||
// 2. 业务处理协程(负责关闭requireData.Ch)
|
||
g.Go(func() error {
|
||
// 在完成时关闭通道
|
||
defer close(requireData.Ch)
|
||
return d.dingTalkBotBiz.Do(ctx, requireData)
|
||
})
|
||
|
||
// 3. 结果收集协程(修改后的版本)
|
||
resultDone := make(chan struct{})
|
||
g.Go(func() error {
|
||
// 使用defer确保通道关闭
|
||
defer close(resultDone)
|
||
|
||
// 处理通道中的数据
|
||
for {
|
||
select {
|
||
case resp, ok := <-requireData.Ch:
|
||
if !ok {
|
||
return nil // 通道已关闭,正常退出
|
||
}
|
||
if resp.Type != entitys.ResponseLog {
|
||
chatMu.Lock()
|
||
chat = append(chat, resp.Content)
|
||
chatMu.Unlock()
|
||
|
||
select {
|
||
case resChan <- resp.Content:
|
||
case <-ctx.Done():
|
||
return ctx.Err()
|
||
}
|
||
}
|
||
case <-ctx.Done():
|
||
return ctx.Err() // 上下文取消,提前退出
|
||
}
|
||
}
|
||
})
|
||
|
||
// 4. 统一关闭通道的协程(只关闭resChan)
|
||
g.Go(func() error {
|
||
<-resultDone
|
||
// resChan已在流式处理协程关闭
|
||
return nil
|
||
})
|
||
|
||
// 5. 历史记录保存协程
|
||
g.Go(func() error {
|
||
<-resultDone
|
||
chatMu.Lock()
|
||
savedChat := make([]string, len(chat))
|
||
copy(savedChat, chat)
|
||
chatMu.Unlock()
|
||
|
||
if err := d.dingTalkBotBiz.SaveHis(ctx, requireData, savedChat); err != nil {
|
||
log.Printf("保存历史记录失败: %v", err)
|
||
return err
|
||
}
|
||
return nil
|
||
})
|
||
|
||
// 阻塞直到所有协程完成或出错
|
||
if err := g.Wait(); err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|