131 lines
2.8 KiB
Go
131 lines
2.8 KiB
Go
package services
|
||
|
||
import (
|
||
"ai_scheduler/internal/biz"
|
||
"log"
|
||
"sync"
|
||
"time"
|
||
|
||
"ai_scheduler/internal/config"
|
||
"ai_scheduler/internal/entitys"
|
||
"context"
|
||
|
||
"gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/chatbot"
|
||
)
|
||
|
||
type DingBotService struct {
|
||
config *config.Config
|
||
dingTalkBotBiz *biz.DingTalkBotBiz
|
||
}
|
||
|
||
func NewDingBotService(config *config.Config, DingTalkBotBiz *biz.DingTalkBotBiz) *DingBotService {
|
||
return &DingBotService{config: config, dingTalkBotBiz: DingTalkBotBiz}
|
||
}
|
||
|
||
func (d *DingBotService) GetServiceCfg() ([]entitys.DingTalkBot, error) {
|
||
return d.dingTalkBotBiz.GetDingTalkBotCfgList()
|
||
}
|
||
|
||
func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) {
|
||
var (
|
||
lastErr error
|
||
chat []string
|
||
streamWG sync.WaitGroup
|
||
resChan = make(chan string, 100) // 缓冲通道防止阻塞
|
||
)
|
||
|
||
// 初始化请求
|
||
requireData, err := d.dingTalkBotBiz.InitRequire(ctx, data)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 创建子上下文用于控制goroutine生命周期
|
||
subCtx, cancel := context.WithCancel(ctx)
|
||
defer cancel()
|
||
|
||
// 启动流式处理goroutine
|
||
streamWG.Add(1)
|
||
go func() {
|
||
defer streamWG.Done()
|
||
err = d.dingTalkBotBiz.HandleStreamRes(subCtx, data, resChan)
|
||
if err != nil {
|
||
return
|
||
}
|
||
}()
|
||
|
||
// 启动业务处理goroutine
|
||
done := make(chan error, 1)
|
||
go func() {
|
||
done <- d.dingTalkBotBiz.Do(subCtx, requireData)
|
||
}()
|
||
|
||
// 主处理循环
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
lastErr = ctx.Err()
|
||
goto cleanup
|
||
|
||
case resp, ok := <-requireData.Ch:
|
||
if !ok {
|
||
goto cleanup
|
||
}
|
||
|
||
// 处理不同类型响应
|
||
switch resp.Type {
|
||
case entitys.ResponseLog:
|
||
// 忽略日志类型
|
||
continue
|
||
|
||
//case entitys.ResponseText, entitys.ResponseJson:
|
||
// chat = append(chat, resp.Content)
|
||
// if err := d.dingTalkBotBiz.ReplyText(ctx, data.SessionWebhook, resp.Content); err != nil {
|
||
// log.Printf("处理非流响应失败: %v", err)
|
||
// lastErr = err
|
||
// }
|
||
|
||
default:
|
||
chat = append(chat, resp.Content)
|
||
select {
|
||
case resChan <- resp.Content:
|
||
case <-ctx.Done():
|
||
lastErr = ctx.Err()
|
||
goto cleanup
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
cleanup:
|
||
streamWG.Wait()
|
||
// 关闭流式通道
|
||
close(resChan)
|
||
|
||
// 保存历史记录
|
||
if saveErr := d.dingTalkBotBiz.SaveHis(ctx, requireData, chat); saveErr != nil {
|
||
log.Printf("保存历史记录失败: %v", saveErr)
|
||
if lastErr == nil {
|
||
lastErr = saveErr
|
||
}
|
||
}
|
||
|
||
// 等待业务处理完成(带超时)
|
||
select {
|
||
case err := <-done:
|
||
if err != nil {
|
||
log.Printf("业务处理失败: %v", err)
|
||
if lastErr == nil {
|
||
lastErr = err
|
||
}
|
||
}
|
||
case <-time.After(3 * time.Second): // 增加超时时间
|
||
log.Println("警告:等待业务处理超时,可能发生goroutine泄漏")
|
||
}
|
||
|
||
if lastErr != nil {
|
||
return nil, lastErr
|
||
}
|
||
return []byte("success"), nil
|
||
}
|