From 945f3ff4fc6d1b18c6a01e2d17706c366b5ec44c Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Tue, 9 Dec 2025 14:25:13 +0800 Subject: [PATCH 1/7] =?UTF-8?q?feat:=E4=BC=9A=E8=AF=9D=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/chat_history.go | 10 +++++++++- internal/entitys/chat_history.go | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/internal/biz/chat_history.go b/internal/biz/chat_history.go index 58266e8..c3912a4 100644 --- a/internal/biz/chat_history.go +++ b/internal/biz/chat_history.go @@ -26,10 +26,18 @@ func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl, taskRepo *impl.TaskImpl) *C // 查询会话历史 func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]entitys.ChatHisQueryResponse, error) { - chats, err := s.chatHiRepo.FindAll( + + con := []impl.CondFunc{ s.chatHiRepo.WithSessionId(query.SessionID), s.chatHiRepo.PaginateScope(query.Page, query.PageSize), s.chatHiRepo.OrderByDesc("his_id"), + } + if query.HisID > 0 { + con = append(con, s.chatHiRepo.WithHisId(query.HisID)) + } + + chats, err := s.chatHiRepo.FindAll( + con..., ) if err != nil { return nil, err diff --git a/internal/entitys/chat_history.go b/internal/entitys/chat_history.go index 019a8c6..b50148e 100644 --- a/internal/entitys/chat_history.go +++ b/internal/entitys/chat_history.go @@ -19,6 +19,7 @@ type ChatHisLog struct { } type ChatHistQuery struct { + HisID int64 `json:"his_id"` SessionID string `json:"session_id"` Page int `json:"page"` PageSize int `json:"page_size"` From d509a18d442b3bd8346605458ca833aceb820dde Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Tue, 9 Dec 2025 16:32:13 +0800 Subject: [PATCH 2/7] feat: add DingTalk bot integration --- cmd/server/main.go | 5 +- config/config_test.yaml | 6 + go.mod | 11 +- go.sum | 6 + internal/biz/ding_talk_bot.go | 42 ++++++ internal/biz/do/ctx.go | 31 +++- internal/biz/do/handle.go | 20 +-- internal/biz/do/prompt.go | 70 +++++++++ internal/biz/handle/file.go | 14 ++ internal/biz/handle/handle.go | 22 --- internal/biz/llm_service/common.go | 47 ------ internal/biz/llm_service/langchain.go | 159 ++++++++++----------- internal/biz/llm_service/ollama.go | 79 ++-------- internal/biz/provider_set.go | 6 +- internal/biz/router.go | 16 ++- internal/config/config.go | 25 ++-- internal/data/constants/bot.go | 22 ++- internal/data/constants/file.go | 75 ++++++++++ internal/data/impl/bot_chat_history.go | 15 ++ internal/data/model/ai_bot_chat_his.gen.go | 27 ++++ internal/entitys/bot.go | 23 ++- internal/entitys/ollama.go | 10 ++ internal/entitys/recognize.go | 31 ++++ internal/entitys/response.go | 35 +++-- internal/entitys/types.go | 3 +- internal/server/ding_talk_bot.go | 65 +++++++++ internal/server/provider_set.go | 11 +- internal/server/server.go | 22 ++- internal/services/callback.go | 6 +- internal/services/dtalk_bot.go | 135 +++++++++++++++++ internal/services/provider_set.go | 9 +- internal/tools_bot/dtalk_bot.go | 15 +- 32 files changed, 765 insertions(+), 298 deletions(-) create mode 100644 internal/biz/ding_talk_bot.go create mode 100644 internal/biz/do/prompt.go create mode 100644 internal/biz/handle/file.go delete mode 100644 internal/biz/handle/handle.go create mode 100644 internal/data/constants/file.go create mode 100644 internal/data/impl/bot_chat_history.go create mode 100644 internal/data/model/ai_bot_chat_his.gen.go create mode 100644 internal/entitys/ollama.go create mode 100644 internal/entitys/recognize.go create mode 100644 internal/server/ding_talk_bot.go create mode 100644 internal/services/dtalk_bot.go diff --git a/cmd/server/main.go b/cmd/server/main.go index f0b0214..a6607f5 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -2,7 +2,6 @@ package main import ( "ai_scheduler/internal/config" - "flag" "fmt" @@ -23,8 +22,8 @@ func main() { } defer func() { cleanup() - }() - + //app.DingBotServer.Run(context.Background()) + //app.DingBotServer.RunBots(app.DingBotServer.BotServices) log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port))) } diff --git a/config/config_test.yaml b/config/config_test.yaml index 2929dd7..eed1fa5 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -81,3 +81,9 @@ default_prompt: # 权限配置 permissionConfig: permission_url: "http://api.test.user.1688sup.cn:8001/v1/menu/myCodes?systemId=" + + +ding_talk_bots: + public: + client_id: "dingchg59zwwvmuuvldx", + client_secret: "ZwetAnRiTQobNFVlNrshRagSMAJIFpBAepWkWI7on7Tt_o617KHtTjBLp8fQfplz", \ No newline at end of file diff --git a/go.mod b/go.mod index 6b08e4a..c78edda 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module ai_scheduler -go 1.24.0 - -toolchain go1.24.7 +go 1.24.1 require ( gitea.cdlsxd.cn/self-tools/l_request v1.0.8 @@ -13,18 +11,18 @@ require ( github.com/emirpasic/gods v1.18.1 github.com/faabiosr/cachego v0.26.0 github.com/fastwego/dingding v1.0.0-beta.4 + github.com/gabriel-vasile/mimetype v1.4.11 github.com/go-kratos/kratos/v2 v2.9.1 github.com/gofiber/fiber/v2 v2.52.9 github.com/gofiber/websocket/v2 v2.2.1 github.com/google/uuid v1.6.0 github.com/google/wire v0.7.0 github.com/ollama/ollama v0.12.7 + github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/redis/go-redis/v9 v9.16.0 github.com/spf13/viper v1.17.0 github.com/tmc/langchaingo v0.1.13 google.golang.org/grpc v1.64.0 - google.golang.org/protobuf v1.34.1 - gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.6.0 gorm.io/gorm v1.31.0 xorm.io/builder v0.3.13 @@ -46,6 +44,7 @@ require ( github.com/fasthttp/websocket v1.5.3 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/gorilla/websocket v1.5.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect @@ -82,5 +81,7 @@ require ( golang.org/x/sys v0.31.0 // indirect golang.org/x/text v0.23.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect + google.golang.org/protobuf v1.34.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 29b4ace..e1306fe 100644 --- a/go.sum +++ b/go.sum @@ -142,6 +142,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= +github.com/gabriel-vasile/mimetype v1.4.11 h1:AQvxbp830wPhHTqc1u7nzoLT+ZFxGY7emj5DR5DYFik= +github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -217,6 +219,8 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= @@ -273,6 +277,8 @@ github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108 github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8= +github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/internal/biz/ding_talk_bot.go b/internal/biz/ding_talk_bot.go new file mode 100644 index 0000000..09de5d4 --- /dev/null +++ b/internal/biz/ding_talk_bot.go @@ -0,0 +1,42 @@ +package biz + +import ( + "ai_scheduler/internal/biz/do" + "ai_scheduler/internal/entitys" + "context" + + "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" +) + +// AiRouterBiz 智能路由服务 +type DingTalkBotBiz struct { + do *do.Do + handle *do.Handle +} + +// NewDingTalkBotBiz +func NewDingTalkBotBiz( + do *do.Do, + handle *do.Handle, +) *DingTalkBotBiz { + return &DingTalkBotBiz{ + do: do, + handle: handle, + } +} + +func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallbackDataModel) (requireData *entitys.RequireDataDingTalkBot, err error) { + requireData = &entitys.RequireDataDingTalkBot{ + Req: data, + Ch: make(chan entitys.Response, 2), + } + entitys.ResLog(requireData.Ch, "recognize_start", "收到消息,正在处理中,请稍等") + + requireData.Sys, err = d.do.GetSysInfoForDingTalkBot(requireData) + requireData.Tasks, err = d.do.GetTasks(requireData.Sys.SysID) + return +} + +func (d *DingTalkBotBiz) Recognize(ctx context.Context, rec *entitys.Recognize, ch chan entitys.Response) (match *entitys.Match, err error) { + return d.handle.RecognizeBot(ctx, rec, ch) +} diff --git a/internal/biz/do/ctx.go b/internal/biz/do/ctx.go index ed8749d..eeefbeb 100644 --- a/internal/biz/do/ctx.go +++ b/internal/biz/do/ctx.go @@ -81,6 +81,20 @@ func (d *Do) DataAuth(ctx context.Context, client *gateway.Client, requireData * return nil } +func (d *Do) DataAuthForBot(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) (err error) { + + // 2. 加载系统信息 + if err = d.loadSystemInfo(ctx, client, requireData); err != nil { + return fmt.Errorf("获取系统信息失败: %w", err) + } + + // 3. 加载任务列表 + if err = d.loadTaskList(ctx, client, requireData); err != nil { + return fmt.Errorf("获取任务列表失败: %w", err) + } + return nil +} + // 提取数据验证为单独函数 func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.RequireData) error { requireData.Session = client.GetSession() @@ -104,7 +118,7 @@ func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.Req // 获取系统信息的辅助函数 func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error { if sysInfo := client.GetSysInfo(); sysInfo == nil { - sys, err := d.getSysInfo(requireData) + sys, err := d.GetSysInfo(requireData) if err != nil { return err } @@ -119,7 +133,7 @@ func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, require // 获取任务列表的辅助函数 func (d *Do) loadTaskList(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error { if taskInfo := client.GetTasks(); len(taskInfo) == 0 { - tasks, err := d.getTasks(requireData.Sys.SysID) + tasks, err := d.GetTasks(requireData.Sys.SysID) if err != nil { return err } @@ -202,7 +216,16 @@ func (d *Do) getRequireData() (err error) { return } -func (d *Do) getSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) { +func (d *Do) GetSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) { + cond := builder.NewCond() + cond = cond.And(builder.Eq{"app_key": requireData.Key}) + cond = cond.And(builder.IsNull{"delete_at"}) + cond = cond.And(builder.Eq{"status": 1}) + err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo) + return +} + +func (d *Do) GetSysInfoForDingTalkBot(requireData *entitys.RequireDataDingTalkBot) (sysInfo model.AiSy, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"app_key": requireData.Key}) cond = cond.And(builder.IsNull{"delete_at"}) @@ -221,7 +244,7 @@ func (d *Do) getSessionChatHis(requireData *entitys.RequireData) (his []model.Ai return } -func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) { +func (d *Do) GetTasks(sysId int32) (tasks []model.AiTask, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"sys_id": sysId}) diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index e014d25..162005f 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -17,8 +17,9 @@ import ( "context" "encoding/json" "fmt" - "gorm.io/gorm/utils" "strings" + + "gorm.io/gorm/utils" ) type Handle struct { @@ -45,23 +46,24 @@ func NewHandle( } } -func (r *Handle) Recognize(ctx context.Context, requireData *entitys.RequireData) (err error) { - entitys.ResLog(requireData.Ch, "recognize_start", "准备意图识别") - +func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (match entitys.Match, err error) { + entitys.ResLog(rec.Ch, "recognize_start", "准备意图识别") + prompt, err := promptProcessor.CreatePrompt(ctx, rec) //意图识别 - recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData) + recognizeMsg, err := r.Ollama.IntentRecognize(ctx, &entitys.ToolSelect{ + Prompt: prompt, + Tools: rec.Tasks, + }) if err != nil { return } - entitys.ResLog(requireData.Ch, "recognize", recognizeMsg) - entitys.ResLog(requireData.Ch, "recognize_end", "意图识别结束") + entitys.ResLog(rec.Ch, "recognize", recognizeMsg) + entitys.ResLog(rec.Ch, "recognize_end", "意图识别结束") - var match entitys.Match if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil { err = errors.SysErr("数据结构错误:%v", err.Error()) return } - requireData.Match = &match return } diff --git a/internal/biz/do/prompt.go b/internal/biz/do/prompt.go new file mode 100644 index 0000000..cb1ca18 --- /dev/null +++ b/internal/biz/do/prompt.go @@ -0,0 +1,70 @@ +package do + +import ( + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "context" + "strings" + + "github.com/ollama/ollama/api" +) + +type PromptOption interface { + CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) +} + +type WithSys struct { +} + +func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) { + var ( + prompt = make([]api.Message, 0) // 初始化一个空的api.Message切片 + ) + // 获取用户内容,如果出错则直接返回错误 + content, err := f.getUserContent(ctx, rec) + if err != nil { + return nil, err + } + // 构建提示消息列表,包含系统提示、助手回复和用户内容 + mes = append(prompt, api.Message{ + Role: "system", // 系统角色 + Content: rec.SystemPrompt, // 系统提示内容 + }, api.Message{ + Role: "assistant", // 助手角色 + Content: "### 聊天记录:" + pkg.JsonStringIgonErr(rec.ChatHis), // 助手回复内容 + }, api.Message{ + Role: "user", // 用户角色 + Content: content.String(), // 用户输入内容 + }) + + return +} + +func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) { + var hasFile bool + if len(rec.UserContent.FileUrl) > 0 || rec.UserContent.File != nil { + hasFile = true + } + content.WriteString(rec.UserContent.Text) + if hasFile { + content.WriteString("\n") + } + + if len(rec.UserContent.Tag) > 0 { + content.WriteString("\n") + content.WriteString("### 工具必须使用:") + content.WriteString(rec.UserContent.Tag) + } + + if len(rec.ChatHis.Messages) > 0 { + content.WriteString("### 引用历史聊天记录:\n") + content.WriteString(pkg.JsonStringIgonErr(rec.ChatHis)) + } + + if hasFile { + content.WriteString("\n") + content.WriteString("### 文件内容:\n") + hand.WriteString(rec.UserContent.FileUrl, rec.UserContent.FileUrl) + } + return +} diff --git a/internal/biz/handle/file.go b/internal/biz/handle/file.go new file mode 100644 index 0000000..3c9d79b --- /dev/null +++ b/internal/biz/handle/file.go @@ -0,0 +1,14 @@ +package handle + +import ( + "ai_scheduler/internal/entitys" +) + +// HandleRecognizeFile 这里的目的是无论将什么类型的file都转为二进制格式 +// 判断文件大小 +// 判断文件类型 +// 判断文件是否合法 +func HandleRecognizeFile(file *entitys.RecognizeFile) { + //Todo 仲云 + return +} diff --git a/internal/biz/handle/handle.go b/internal/biz/handle/handle.go deleted file mode 100644 index 391b6eb..0000000 --- a/internal/biz/handle/handle.go +++ /dev/null @@ -1,22 +0,0 @@ -package handle - -import ( - "ai_scheduler/internal/config" - "ai_scheduler/internal/tools" -) - -type Handle struct { - toolManager *tools.Manager - conf *config.Config -} - -func NewHandle( - toolManager *tools.Manager, - conf *config.Config, - -) *Handle { - return &Handle{ - toolManager: toolManager, - conf: conf, - } -} diff --git a/internal/biz/llm_service/common.go b/internal/biz/llm_service/common.go index 1d62ed7..c3ffe46 100644 --- a/internal/biz/llm_service/common.go +++ b/internal/biz/llm_service/common.go @@ -1,9 +1,7 @@ package llm_service import ( - "ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/model" - "ai_scheduler/internal/entitys" "context" "time" ) @@ -20,48 +18,3 @@ func buildSystemPrompt(prompt string) string { return prompt } - -func buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) { - for _, item := range his { - if len(chatHis.SessionId) == 0 { - chatHis.SessionId = item.SessionID - } - chatHis.Messages = append(chatHis.Messages, []entitys.HisMessage{ - { - Role: constants.RoleUser, - Content: item.Ques, - Timestamp: item.CreateAt.Format(time.DateTime), - }, - { - Role: constants.RoleAssistant, - Content: item.Ans, - Timestamp: item.CreateAt.Format(time.DateTime), - }, - }...) - } - chatHis.Context = entitys.HisContext{ - UserLanguage: "zh-CN", - SystemMode: "technical_support", - } - return -} - -func BuildChatHisMessage(his []model.AiChatHi) (chatHis []entitys.HisMessage) { - for _, item := range his { - - chatHis = append(chatHis, []entitys.HisMessage{ - { - Role: constants.RoleUser, - Content: item.Ques, - Timestamp: item.CreateAt.Format(time.DateTime), - }, - { - Role: constants.RoleAssistant, - Content: item.Ans, - Timestamp: item.CreateAt.Format(time.DateTime), - }, - }...) - } - - return -} diff --git a/internal/biz/llm_service/langchain.go b/internal/biz/llm_service/langchain.go index 63f8815..a3216e8 100644 --- a/internal/biz/llm_service/langchain.go +++ b/internal/biz/llm_service/langchain.go @@ -1,87 +1,76 @@ package llm_service -import ( - "ai_scheduler/internal/data/model" - "ai_scheduler/internal/entitys" - "ai_scheduler/internal/pkg" - "ai_scheduler/internal/pkg/utils_langchain" - "context" - "encoding/json" - - "github.com/tmc/langchaingo/llms" -) - -type LangChainService struct { - client *utils_langchain.UtilLangChain -} - -func NewLangChainGenerate( - client *utils_langchain.UtilLangChain, -) *LangChainService { - - return &LangChainService{ - client: client, - } -} - -func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) { - prompt := r.getPrompt(sysInfo, history, userInput, tasks) - AgentClient := r.client.Get() - defer r.client.Put(AgentClient) - match, err := AgentClient.Llm.GenerateContent( - ctx, // 使用可取消的上下文 - prompt, - llms.WithJSONMode(), - ) - msg = match.Choices[0].Content - return -} - -func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { - var ( - prompt = make([]llms.MessageContent, 0) - ) - prompt = append(prompt, llms.MessageContent{ - Role: llms.ChatMessageTypeSystem, - Parts: []llms.ContentPart{ - llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)), - }, - }, llms.MessageContent{ - Role: llms.ChatMessageTypeTool, - Parts: []llms.ContentPart{ - llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))), - }, - }, llms.MessageContent{ - Role: llms.ChatMessageTypeTool, - Parts: []llms.ContentPart{ - llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))), - }, - }, llms.MessageContent{ - Role: llms.ChatMessageTypeHuman, - Parts: []llms.ContentPart{ - llms.TextPart(reqInput), - }, - }) - return prompt -} - -func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool { - taskPrompt := make([]llms.Tool, 0) - for _, task := range tasks { - var taskConfig entitys.TaskConfig - err := json.Unmarshal([]byte(task.Config), &taskConfig) - if err != nil { - continue - } - taskPrompt = append(taskPrompt, llms.Tool{ - Type: "function", - Function: &llms.FunctionDefinition{ - Name: task.Index, - Description: task.Desc, - Parameters: taskConfig.Param, - }, - }) - - } - return taskPrompt -} +//type LangChainService struct { +// client *utils_langchain.UtilLangChain +//} +// +//func NewLangChainGenerate( +// client *utils_langchain.UtilLangChain, +//) *LangChainService { +// +// return &LangChainService{ +// client: client, +// } +//} +// +//func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) { +// prompt := r.getPrompt(sysInfo, history, userInput, tasks) +// AgentClient := r.client.Get() +// defer r.client.Put(AgentClient) +// match, err := AgentClient.Llm.GenerateContent( +// ctx, // 使用可取消的上下文 +// prompt, +// llms.WithJSONMode(), +// ) +// msg = match.Choices[0].Content +// return +//} +// +//func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { +// var ( +// prompt = make([]llms.MessageContent, 0) +// ) +// prompt = append(prompt, llms.MessageContent{ +// Role: llms.ChatMessageTypeSystem, +// Parts: []llms.ContentPart{ +// llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)), +// }, +// }, llms.MessageContent{ +// Role: llms.ChatMessageTypeTool, +// Parts: []llms.ContentPart{ +// llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))), +// }, +// }, llms.MessageContent{ +// Role: llms.ChatMessageTypeTool, +// Parts: []llms.ContentPart{ +// llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))), +// }, +// }, llms.MessageContent{ +// Role: llms.ChatMessageTypeHuman, +// Parts: []llms.ContentPart{ +// llms.TextPart(reqInput), +// }, +// }) +// return prompt +//} +// +//func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool { +// taskPrompt := make([]llms.Tool, 0) +// for _, task := range tasks { +// var taskConfig entitys.TaskConfig +// err := json.Unmarshal([]byte(task.Config), &taskConfig) +// if err != nil { +// continue +// } +// taskPrompt = append(taskPrompt, llms.Tool{ +// Type: "function", +// Function: &llms.FunctionDefinition{ +// Name: task.Index, +// Description: task.Desc, +// Parameters: taskConfig.Param, +// }, +// }) +// +// } +// return taskPrompt +//} diff --git a/internal/biz/llm_service/ollama.go b/internal/biz/llm_service/ollama.go index 6527a3b..238e002 100644 --- a/internal/biz/llm_service/ollama.go +++ b/internal/biz/llm_service/ollama.go @@ -35,14 +35,14 @@ func NewOllamaGenerate( } } -func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) { - prompt, err := r.getPrompt(ctx, requireData) +func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSelect) (msg string, err error) { + if err != nil { return } - toolDefinitions := r.registerToolsOllama(requireData.Tasks) + toolDefinitions := r.registerToolsOllama(req.Tools) - match, err := r.client.ToolSelect(ctx, prompt, toolDefinitions) + match, err := r.client.ToolSelect(ctx, rec.Prompt, toolDefinitions) if err != nil { return } @@ -64,87 +64,28 @@ func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entity msg = match.Message.Content return + } -func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) { - - var ( - prompt = make([]api.Message, 0) - ) - content, err := r.getUserContent(ctx, requireData) - if err != nil { - return nil, err - } - prompt = append(prompt, api.Message{ - Role: "system", - Content: buildSystemPrompt(requireData.Sys.SysPrompt), - }, api.Message{ - Role: "assistant", - Content: "### 聊天记录:" + pkg.JsonStringIgonErr(buildAssistant(requireData.Histories)), - }, api.Message{ - Role: "user", - Content: content, - }) - - return prompt, nil -} - -func (r *OllamaService) getUserContent(ctx context.Context, requireData *entitys.RequireData) (string, error) { - var content strings.Builder - content.WriteString(requireData.Req.Text) - if len(requireData.ImgByte) > 0 { - content.WriteString("\n") - } - - if len(requireData.Req.Tags) > 0 { - content.WriteString("\n") - content.WriteString("### 工具必须使用:") - content.WriteString(requireData.Req.Tags) - } - - if len(requireData.ImgByte) > 0 { - desc, err := r.RecognizeWithImg(ctx, requireData) - if err != nil { - return "", err - } - content.WriteString("### 上传图片解析内容:\n") - content.WriteString(requireData.Req.Tags) - content.WriteString(desc.Response) - } - - if requireData.Req.MarkHis > 0 { - var his model.AiChatHi - cond := builder.NewCond() - cond = cond.And(builder.Eq{"his_id": requireData.Req.MarkHis}) - err := r.chatHis.GetOneBySearchToStrut(&cond, &his) - if err != nil { - return "", err - } - content.WriteString("### 引用历史聊天记录:\n") - content.WriteString(pkg.JsonStringIgonErr(BuildChatHisMessage([]model.AiChatHi{his}))) - } - return content.String(), nil -} - -func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) { - if requireData.ImgByte == nil { +func (r *OllamaService) RecognizeWithImg(ctx context.Context, imgByte []api.ImageData, ch chan entitys.Response) (desc api.GenerateResponse, err error) { + if imgByte == nil { return } - entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...") + entitys.ResLog(ch, "recognize_img_start", "图片识别中...") desc, err = r.client.Generation(ctx, &api.GenerateRequest{ Model: r.config.Ollama.VlModel, Stream: new(bool), System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt, Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt, - Images: requireData.ImgByte, + Images: imgByte, KeepAlive: &api.Duration{Duration: 3600 * time.Second}, //Think: &api.ThinkValue{Value: false}, }) if err != nil { return } - entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) + entitys.ResLog(ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) return } diff --git a/internal/biz/provider_set.go b/internal/biz/provider_set.go index cc3de0a..aefa3ce 100644 --- a/internal/biz/provider_set.go +++ b/internal/biz/provider_set.go @@ -2,7 +2,6 @@ package biz import ( "ai_scheduler/internal/biz/do" - "ai_scheduler/internal/biz/handle" "ai_scheduler/internal/biz/llm_service" "github.com/google/wire" @@ -12,10 +11,11 @@ var ProviderSetBiz = wire.NewSet( NewAiRouterBiz, NewSessionBiz, NewChatHistoryBiz, - llm_service.NewLangChainGenerate, + //llm_service.NewLangChainGenerate, llm_service.NewOllamaGenerate, - handle.NewHandle, + //handle.NewHandle, do.NewDo, do.NewHandle, NewTaskBiz, + NewDingTalkBotBiz, ) diff --git a/internal/biz/router.go b/internal/biz/router.go index 6dcc233..335ddcc 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -3,6 +3,7 @@ package biz import ( "ai_scheduler/internal/biz/do" "ai_scheduler/internal/gateway" + "context" "ai_scheduler/internal/entitys" @@ -56,9 +57,15 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS log.Errorf("数据验证和收集失败: %s", err.Error()) return } - + //组装意图识别 + rec, err := r.SetRec(ctx, requireData) + if err != nil { + log.Errorf("组装意图识别失败: %s", err.Error()) + return + } //意图识别 - if err = r.handle.Recognize(ctx, requireData); err != nil { + requireData.Match, err = r.handle.Recognize(ctx, &rec, &do.WithSys{}) + if err != nil { log.Errorf("意图识别失败: %s", err.Error()) return } @@ -70,3 +77,8 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS } return } + +func (r *AiRouterBiz) SetRec(ctx context.Context, requireData *entitys.RequireData) (match entitys.Recognize, err error) { + //TODO 叙平 + return +} diff --git a/internal/config/config.go b/internal/config/config.go index da25c39..815dca0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,16 +9,21 @@ import ( // Config 应用配置 type Config struct { - Server ServerConfig `mapstructure:"server"` - Ollama OllamaConfig `mapstructure:"ollama"` - Sys SysConfig `mapstructure:"sys"` - Tools ToolsConfig `mapstructure:"tools"` - Logging LoggingConfig `mapstructure:"logging"` - Redis Redis `mapstructure:"redis"` - DB DB `mapstructure:"db"` - DefaultPrompt SysPrompt `mapstructure:"default_prompt"` - PermissionConfig PermissionConfig `mapstructure:"permissionConfig"` - // LLM *LLM `mapstructure:"llm"` + Server ServerConfig `mapstructure:"server"` + Ollama OllamaConfig `mapstructure:"ollama"` + Sys SysConfig `mapstructure:"sys"` + Tools ToolsConfig `mapstructure:"tools"` + Logging LoggingConfig `mapstructure:"logging"` + Redis Redis `mapstructure:"redis"` + DB DB `mapstructure:"db"` + DefaultPrompt SysPrompt `mapstructure:"default_prompt"` + PermissionConfig PermissionConfig `mapstructure:"permissionConfig"` + DingTalkBots map[string]*DingTalkBot `mapstructure:"ding_talk_bots"` +} + +type DingTalkBot struct { + ClientId string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` } type SysPrompt struct { diff --git a/internal/data/constants/bot.go b/internal/data/constants/bot.go index 6dc6680..bfe004c 100644 --- a/internal/data/constants/bot.go +++ b/internal/data/constants/bot.go @@ -3,5 +3,25 @@ package constants type BotTools string const ( - BotToolsBugOptimizationSubmit = "bug_optimization_submit" // 系统的bug/优化建议 + BotToolsBugOptimizationSubmit BotTools = "bug_optimization_submit" // 系统的bug/优化建议 ) + +type ChatStyle int + +const ( + ChatStyleNormal ChatStyle = 1 //正常 + ChatStyleSerious ChatStyle = 2 //严肃 + ChatStyleGentle ChatStyle = 3 //温柔 + ChatStyleArrogance ChatStyle = 4 //傲慢 + ChatStyleCute ChatStyle = 5 //可爱 + ChatStyleAngry ChatStyle = 6 //愤怒 +) + +var ChatStyleMap = map[ChatStyle]string{ + ChatStyleNormal: "正常", + ChatStyleSerious: "严肃", + ChatStyleGentle: "温柔", + ChatStyleArrogance: "傲慢", + ChatStyleCute: "可爱", + ChatStyleAngry: "愤怒", +} diff --git a/internal/data/constants/file.go b/internal/data/constants/file.go new file mode 100644 index 0000000..a418959 --- /dev/null +++ b/internal/data/constants/file.go @@ -0,0 +1,75 @@ +package constants + +import ( + "github.com/gabriel-vasile/mimetype" + "io" + "strings" +) + +type FileType string + +const ( + FileTypeUnknown FileType = "unknown" + FileTypeImage FileType = "image" + //FileTypeVideo FileType = "video" + FileTypeExcel FileType = "excel" + FileTypeWord FileType = "word" + FileTypeTxt FileType = "txt" + FileTypePDF FileType = "pdf" +) + +var FileTypeMappings = map[FileType][]string{ + FileTypeImage: { + "image/jpeg", "image/png", "image/gif", "image/webp", "image/svg+xml", + ".jpg", ".jpeg", ".png", ".gif", ".webp", ".svg", + }, + FileTypeExcel: { + "application/vnd.ms-excel", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ".xls", ".xlsx", + }, + FileTypeWord: { + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ".doc", ".docx", + }, + FileTypePDF: { + "application/pdf", + ".pdf", + }, + FileTypeTxt: { + "text/plain", + ".txt", + }, +} + +func DetectFileType(file io.ReadSeeker) (FileType, error) { + // 读取文件头(512字节足够检测大多数类型) + buffer := make([]byte, 512) + n, err := file.Read(buffer) + if err != nil && err != io.EOF { + return FileTypeUnknown, err + } + + // 重置读取位置 + if _, err := file.Seek(0, io.SeekStart); err != nil { + return FileTypeUnknown, err + } + + // 检测 MIME 类型 + detectedMIME := mimetype.Detect(buffer[:n]).String() + + // 遍历映射表,匹配 MIME 或扩展名 + for fileType, mimesOrExts := range FileTypeMappings { + for _, item := range mimesOrExts { + if strings.HasPrefix(item, ".") { + continue // 跳过扩展名(仅 MIME 检测) + } + if item == detectedMIME { + return fileType, nil + } + } + } + + return FileTypeUnknown, nil +} diff --git a/internal/data/impl/bot_chat_history.go b/internal/data/impl/bot_chat_history.go new file mode 100644 index 0000000..3da5184 --- /dev/null +++ b/internal/data/impl/bot_chat_history.go @@ -0,0 +1,15 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" +) + +type BotChatHisImpl struct { + dataTemp.DataTemp +} + +func NewBotChatHisImpl(db *utils.Db) *TaskImpl { + return &TaskImpl{*dataTemp.NewDataTemp(db, new(model.AiBotChatHi))} +} diff --git a/internal/data/model/ai_bot_chat_his.gen.go b/internal/data/model/ai_bot_chat_his.gen.go new file mode 100644 index 0000000..1e4bfff --- /dev/null +++ b/internal/data/model/ai_bot_chat_his.gen.go @@ -0,0 +1,27 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiBotChatHi = "ai_bot_chat_his" + +// AiBotChatHi mapped from table +type AiBotChatHi struct { + HisID int64 `gorm:"column:his_id;primaryKey;autoIncrement:true" json:"his_id"` + SessionID string `gorm:"column:session_id;not null" json:"session_id"` + Role string `gorm:"column:role;not null;comment:system系统输出,assistant助手输出,user用户输入" json:"role"` // system系统输出,assistant助手输出,user用户输入 + Content string `gorm:"column:content;not null" json:"content"` + Files string `gorm:"column:files;not null" json:"files"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` + UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` +} + +// TableName AiBotChatHi's table name +func (*AiBotChatHi) TableName() string { + return TableNameAiBotChatHi +} diff --git a/internal/entitys/bot.go b/internal/entitys/bot.go index 226fbc8..0a99113 100644 --- a/internal/entitys/bot.go +++ b/internal/entitys/bot.go @@ -1,7 +1,24 @@ package entitys -type BotType int +import ( + "ai_scheduler/internal/data/model" -const ( - BugAndQuesDingTalk BotType = iota + 1 + "github.com/ollama/ollama/api" + "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" ) + +type RequireDataDingTalkBot struct { + Session string + Key string + Sys model.AiSy + Histories []model.AiChatHi + SessionInfo model.AiSession + Tasks []model.AiTask + Match *Match + Req *chatbot.BotCallbackDataModel + Auth string + Ch chan Response + KnowledgeConf KnowledgeBaseRequest + ImgByte []api.ImageData + ImgUrls []string +} diff --git a/internal/entitys/ollama.go b/internal/entitys/ollama.go new file mode 100644 index 0000000..7297dab --- /dev/null +++ b/internal/entitys/ollama.go @@ -0,0 +1,10 @@ +package entitys + +import ( + "github.com/ollama/ollama/api" +) + +type ToolSelect struct { + Prompt []api.Message + Tools []RegistrationTask +} diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go new file mode 100644 index 0000000..ae32b89 --- /dev/null +++ b/internal/entitys/recognize.go @@ -0,0 +1,31 @@ +package entitys + +import ( + "ai_scheduler/internal/data/constants" +) + +type Recognize struct { + SystemPrompt string + UserContent *RecognizeUserContent + ChatHis ChatHis + Tasks []RegistrationTask + Ch chan Response +} + +type RegistrationTask struct { +} + +type RecognizeUserContent struct { + Text string + File *RecognizeFile + ActionCardUrl string + Tag string +} + +type FileData []byte + +type RecognizeFile struct { + File []FileData // 文件数据(二进制格式) + FileUrl string // 文件下载链接 + FileType constants.Caller // 文件类型(文件类型,能填最好填,可以跳过一层判断) +} diff --git a/internal/entitys/response.go b/internal/entitys/response.go index cdadc98..d99127f 100644 --- a/internal/entitys/response.go +++ b/internal/entitys/response.go @@ -2,22 +2,25 @@ package entitys import ( "encoding/json" + "github.com/gofiber/websocket/v2" ) type ResponseType string const ( - ResponseJson ResponseType = "json" - ResponseLoading ResponseType = "loading" - ResponseEnd ResponseType = "end" - ResponseStream ResponseType = "stream" - ResponseText ResponseType = "txt" - ResponseImg ResponseType = "img" - ResponseFile ResponseType = "file" - ResponseErr ResponseType = "error" - ResponseLog ResponseType = "log" - ResponseAuth ResponseType = "auth" + ResponseJson ResponseType = "json" + ResponseLoading ResponseType = "loading" + ResponseEnd ResponseType = "end" + ResponseStream ResponseType = "stream" + ResponseText ResponseType = "txt" + ResponseImg ResponseType = "img" + ResponseFile ResponseType = "file" + ResponseErr ResponseType = "error" + ResponseLog ResponseType = "log" + ResponseAuth ResponseType = "auth" + ResponseMarkdown ResponseType = "markdown" + ResponseActionCard ResponseType = "actionCard" ) func ResLog(ch chan Response, index string, content string) { @@ -45,6 +48,9 @@ func ResJson(ch chan Response, index string, content string) { } func ResEnd(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, @@ -53,6 +59,9 @@ func ResEnd(ch chan Response, index string, content string) { } func ResText(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, @@ -61,6 +70,9 @@ func ResText(ch chan Response, index string, content string) { } func ResLoading(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, @@ -68,6 +80,9 @@ func ResLoading(ch chan Response, index string, content string) { } } func ResError(ch chan Response, index string, content string) { + if ch == nil { + return + } ch <- Response{ Index: index, Content: content, diff --git a/internal/entitys/types.go b/internal/entitys/types.go index ece5ceb..cdc7986 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -3,7 +3,6 @@ package entitys import ( "ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/model" - "context" "encoding/json" @@ -150,7 +149,7 @@ type RequireData struct { Histories []model.AiChatHi SessionInfo model.AiSession Tasks []model.AiTask - Match *Match + Match Match Req *ChatSockRequest Auth string Ch chan Response diff --git a/internal/server/ding_talk_bot.go b/internal/server/ding_talk_bot.go new file mode 100644 index 0000000..cda0e1c --- /dev/null +++ b/internal/server/ding_talk_bot.go @@ -0,0 +1,65 @@ +package server + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/services" + "context" + + "github.com/go-kratos/kratos/v2/log" + "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" + "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" +) + +type DingBotServiceInterface interface { + GetServiceCfg(cfg map[string]*config.DingTalkBot) (*config.DingTalkBot, string) + OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error) +} + +type DingTalkBotServer struct { + Clients []*client.StreamClient +} + +func NewDingTalkBotServer( + cfg *config.Config, + services []DingBotServiceInterface, +) *DingTalkBotServer { + clients := make([]*client.StreamClient, 0) + for _, service := range services { + serviceConf, index := service.GetServiceCfg(cfg.DingTalkBots) + if serviceConf == nil { + log.Info("未找到%s配置", index) + continue + } + cli := DingBotServerInit(serviceConf.ClientId, serviceConf.ClientSecret, service) + if cli == nil { + log.Info("%s客户端初始失败", index) + continue + } + clients = append(clients, cli) + } + return &DingTalkBotServer{ + Clients: clients, + } +} + +func ProvideAllDingBotServices( + dingBotSvc *services.DingBotService, +) []DingBotServiceInterface { + return []DingBotServiceInterface{dingBotSvc} +} + +func (d *DingTalkBotServer) Run(ctx context.Context) { + for name, cli := range d.Clients { + err := cli.Start(ctx) + if err != nil { + log.Info("%s启动失败", name) + continue + } + log.Info("%s启动成功", name) + } +} +func DingBotServerInit(clientId string, clientSecret string, service DingBotServiceInterface) (cli *client.StreamClient) { + cli = client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(clientId, clientSecret))) + cli.RegisterChatBotCallbackRouter(service.OnChatBotMessageReceived) + return +} diff --git a/internal/server/provider_set.go b/internal/server/provider_set.go index dd6b4b4..d5cef3d 100644 --- a/internal/server/provider_set.go +++ b/internal/server/provider_set.go @@ -1,5 +1,12 @@ package server -import "github.com/google/wire" +import ( + "github.com/google/wire" +) -var ProviderSetServer = wire.NewSet(NewServers, NewHTTPServer) +var ProviderSetServer = wire.NewSet( + NewServers, + NewHTTPServer, + ProvideAllDingBotServices, + NewDingTalkBotServer, +) diff --git a/internal/server/server.go b/internal/server/server.go index 488ef38..02c8f84 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,13 +1,27 @@ package server -import "github.com/gofiber/fiber/v2" +import ( + "ai_scheduler/internal/config" + + "github.com/gofiber/fiber/v2" +) type Servers struct { - HttpServer *fiber.App + cfg *config.Config + HttpServer *fiber.App + DingBotServer *DingTalkBotServer } -func NewServers(fiber *fiber.App) *Servers { +func NewServers(cfg *config.Config, fiber *fiber.App, DingBotServer *DingTalkBotServer) *Servers { return &Servers{ - HttpServer: fiber, + HttpServer: fiber, + cfg: cfg, + DingBotServer: DingBotServer, } } + +//func DingBotServerInit(clientId string, clientSecret string, cfg *config.Config, handler *do.Handle, do *do.Do) (cli *client.StreamClient) { +// cli = client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(clientId, clientSecret))) +// cli.RegisterChatBotCallbackRouter(services.NewDingBotService(cfg, handler, do).OnChatBotMessageReceived) +// return +//} diff --git a/internal/services/callback.go b/internal/services/callback.go index 415f7cb..b0697c0 100644 --- a/internal/services/callback.go +++ b/internal/services/callback.go @@ -212,7 +212,7 @@ func (s *CallbackService) sendStreamLog(sessionID string, content string) { } streamLog := entitys.Response{ - Index: constants.BotToolsBugOptimizationSubmit, + Index: string(constants.BotToolsBugOptimizationSubmit), Content: content, Type: entitys.ResponseLog, } @@ -227,7 +227,7 @@ func (s *CallbackService) sendStreamTxt(sessionID string, content string) { } streamLog := entitys.Response{ - Index: constants.BotToolsBugOptimizationSubmit, + Index: string(constants.BotToolsBugOptimizationSubmit), Content: content, Type: entitys.ResponseText, } @@ -242,7 +242,7 @@ func (s *CallbackService) sendStreamLoading(sessionID string, content string) { } streamLog := entitys.Response{ - Index: constants.BotToolsBugOptimizationSubmit, + Index: string(constants.BotToolsBugOptimizationSubmit), Content: content, Type: entitys.ResponseLoading, } diff --git a/internal/services/dtalk_bot.go b/internal/services/dtalk_bot.go new file mode 100644 index 0000000..2855471 --- /dev/null +++ b/internal/services/dtalk_bot.go @@ -0,0 +1,135 @@ +package services + +import ( + "ai_scheduler/internal/biz" + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "context" + "fmt" + + "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" +) + +type DingBotService struct { + config *config.Config + replier *chatbot.ChatbotReplier + env string + DingTalkBotBiz *biz.DingTalkBotBiz +} + +func NewDingBotService(config *config.Config, DingTalkBotBiz *biz.DingTalkBotBiz) *DingBotService { + return &DingBotService{config: config, replier: chatbot.NewChatbotReplier(), env: "public", DingTalkBotBiz: DingTalkBotBiz} +} + +func (d *DingBotService) GetServiceCfg(cfg map[string]*config.DingTalkBot) (*config.DingTalkBot, string) { + return cfg[d.env], d.env +} + +func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error) { + requireData, err := d.DingTalkBotBiz.InitRequire(ctx, data) + if err != nil { + return + } + go func() { + defer close(requireData.Ch) + + if recognizeErr := d.DingTalkBotBiz.Recognize(ctx, requireData); recognizeErr != nil { + requireData.Ch <- entitys.Response{ + Type: entitys.ResponseEnd, + Content: fmt.Sprintf("处理消息时出错: %v", recognizeErr), + } + } + //向下传递 + if err = d.handle.HandleMatch(ctx, nil, requireData); err != nil { + requireData.Ch <- entitys.Response{ + Type: entitys.ResponseEnd, + Content: fmt.Sprintf("匹配失败: %v", err), + } + } + }() + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case resp, ok := <-requireData.Ch: + if !ok { + return []byte("success"), nil // 通道关闭,处理完成 + } + + if resp.Type == entitys.ResponseLog { + return + } + if err := d.handleRes(ctx, data, resp); err != nil { + return nil, fmt.Errorf("回复失败: %w", err) + } + } + } + return +} + +func (d *DingBotService) handleRes(ctx context.Context, data *chatbot.BotCallbackDataModel, resp entitys.Response) error { + switch resp.Type { + case entitys.ResponseText: + return d.replyText(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseStream: + return d.replySteam(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseImg: + return d.replyImg(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseFile: + return d.replyFile(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseMarkdown: + return d.replyMarkdown(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseActionCard: + return d.replyActionCard(ctx, data.SessionWebhook, resp.Content) + default: + return nil + } +} + +func (d *DingBotService) replyText(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingBotService) replySteam(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingBotService) replyImg(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingBotService) replyFile(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingBotService) replyMarkdown(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingBotService) replyActionCard(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} diff --git a/internal/services/provider_set.go b/internal/services/provider_set.go index 0e1284b..9610ed6 100644 --- a/internal/services/provider_set.go +++ b/internal/services/provider_set.go @@ -6,4 +6,11 @@ import ( "github.com/google/wire" ) -var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService) +var ProviderSetServices = wire.NewSet( + NewChatService, + NewSessionService, + gateway.NewGateway, + NewTaskService, + NewCallbackService, + NewDingBotService, +) diff --git a/internal/tools_bot/dtalk_bot.go b/internal/tools_bot/dtalk_bot.go index 0e9525a..2eae1d5 100644 --- a/internal/tools_bot/dtalk_bot.go +++ b/internal/tools_bot/dtalk_bot.go @@ -2,22 +2,18 @@ package tools_bot import ( "ai_scheduler/internal/config" - "ai_scheduler/internal/data/constants" - errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/utils_ollama" - "context" - "github.com/gofiber/fiber/v2/log" + "context" ) type BotTool struct { config *config.Config llm *utils_ollama.Client sessionImpl *impl.SessionImpl - taskMap map[string]string // task_id -> session_id - // zltxOrderAfterSaleTool tools.ZltxOrderAfterSaleTool + taskMap map[string]string } // NewBotTool 创建直连天下订单详情工具 @@ -27,12 +23,5 @@ func NewBotTool(config *config.Config, llm *utils_ollama.Client, sessionImpl *im // Execute 执行直连天下订单详情查询 func (w *BotTool) Execute(ctx context.Context, toolName string, requireData *entitys.RequireData) (err error) { - switch toolName { - case constants.BotToolsBugOptimizationSubmit: - err = w.BugOptimizationSubmit(ctx, requireData) - default: - log.Errorf("未知的工具类型:%s", toolName) - err = errors.ParamErr("未知的工具类型:%s", toolName) - } return } From f4980f1f22277baa18e2307bf4e94bd22b1d8d2b Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Tue, 9 Dec 2025 16:45:38 +0800 Subject: [PATCH 3/7] feat: add DingTalk bot integration --- internal/biz/handle/file.go | 70 +++++++++++++++++++++++++++++++++ internal/data/constants/file.go | 37 ----------------- 2 files changed, 70 insertions(+), 37 deletions(-) diff --git a/internal/biz/handle/file.go b/internal/biz/handle/file.go index 3c9d79b..9fd4fae 100644 --- a/internal/biz/handle/file.go +++ b/internal/biz/handle/file.go @@ -1,7 +1,17 @@ package handle import ( + "ai_scheduler/internal/data/constants" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/l_request" + "errors" + "fmt" + "io" + "net/http" + "path/filepath" + "strings" + + "github.com/gabriel-vasile/mimetype" ) // HandleRecognizeFile 这里的目的是无论将什么类型的file都转为二进制格式 @@ -12,3 +22,63 @@ func HandleRecognizeFile(file *entitys.RecognizeFile) { //Todo 仲云 return } + +// 下载文件并返回二进制数据、MIME 类型 +func downloadFile(fileUrl string) (fileBytes []byte, contentType string, err error) { + if len(fileUrl) == 0 { + return + } + req := l_request.Request{ + Method: "GET", + Url: fileUrl, + Headers: map[string]string{ + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", + "Accept": "image/webp,image/apng,image/*,*/*;q=0.8", + }, + } + res, err := req.Send() + if err != nil { + return + } + var ex bool + if contentType, ex = res.Headers["Content-Type"]; !ex { + err = errors.New("Content-Type不存在") + return + } + + if res.StatusCode != http.StatusOK { + err = fmt.Errorf("server returned non-200 status: %d", res.StatusCode) + } + fileBytes = res.Content + + return fileBytes, contentType, nil +} + +// detectFileType 判断文件类型 +func detectFileType(file io.ReadSeeker, filename string) constants.FileType { + // 1. 读取文件头检测 MIME + buffer := make([]byte, 512) + n, _ := file.Read(buffer) + file.Seek(0, io.SeekStart) // 重置读取位置 + + detectedMIME := mimetype.Detect(buffer[:n]).String() + for fileType, items := range constants.FileTypeMappings { + for _, item := range items { + if !strings.HasPrefix(item, ".") && item == detectedMIME { + return fileType + } + } + } + + // 2. 备用:通过扩展名检测 + ext := strings.ToLower(filepath.Ext(filename)) + for fileType, items := range constants.FileTypeMappings { + for _, item := range items { + if strings.HasPrefix(item, ".") && item == ext { + return fileType + } + } + } + + return constants.FileTypeUnknown +} diff --git a/internal/data/constants/file.go b/internal/data/constants/file.go index a418959..132ad0e 100644 --- a/internal/data/constants/file.go +++ b/internal/data/constants/file.go @@ -1,11 +1,5 @@ package constants -import ( - "github.com/gabriel-vasile/mimetype" - "io" - "strings" -) - type FileType string const ( @@ -42,34 +36,3 @@ var FileTypeMappings = map[FileType][]string{ ".txt", }, } - -func DetectFileType(file io.ReadSeeker) (FileType, error) { - // 读取文件头(512字节足够检测大多数类型) - buffer := make([]byte, 512) - n, err := file.Read(buffer) - if err != nil && err != io.EOF { - return FileTypeUnknown, err - } - - // 重置读取位置 - if _, err := file.Seek(0, io.SeekStart); err != nil { - return FileTypeUnknown, err - } - - // 检测 MIME 类型 - detectedMIME := mimetype.Detect(buffer[:n]).String() - - // 遍历映射表,匹配 MIME 或扩展名 - for fileType, mimesOrExts := range FileTypeMappings { - for _, item := range mimesOrExts { - if strings.HasPrefix(item, ".") { - continue // 跳过扩展名(仅 MIME 检测) - } - if item == detectedMIME { - return fileType, nil - } - } - } - - return FileTypeUnknown, nil -} From 586ad6a124417ade86ba02608b41f1a4b49a8c31 Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Tue, 9 Dec 2025 17:24:41 +0800 Subject: [PATCH 4/7] refactor: optimize ollama service and update dependencies --- go.mod | 1 - go.sum | 2 - internal/biz/llm_service/ollama.go | 177 ++++++++--------------------- internal/entitys/recognize.go | 3 + 4 files changed, 53 insertions(+), 130 deletions(-) diff --git a/go.mod b/go.mod index 48d0741..073c4eb 100644 --- a/go.mod +++ b/go.mod @@ -58,7 +58,6 @@ require ( github.com/evanphx/json-patch v0.5.2 // indirect github.com/fasthttp/websocket v1.5.3 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect - github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/goph/emperror v0.17.2 // indirect github.com/gorilla/websocket v1.5.0 // indirect diff --git a/go.sum b/go.sum index bbb5c7a..c1b06f6 100644 --- a/go.sum +++ b/go.sum @@ -174,8 +174,6 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= -github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= -github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gabriel-vasile/mimetype v1.4.11 h1:AQvxbp830wPhHTqc1u7nzoLT+ZFxGY7emj5DR5DYFik= github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= diff --git a/internal/biz/llm_service/ollama.go b/internal/biz/llm_service/ollama.go index 8c127cd..7333de5 100644 --- a/internal/biz/llm_service/ollama.go +++ b/internal/biz/llm_service/ollama.go @@ -3,19 +3,15 @@ package llm_service import ( "ai_scheduler/internal/config" "ai_scheduler/internal/data/impl" - "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/utils_ollama" "ai_scheduler/internal/pkg/utils_vllm" "context" - "encoding/json" + "errors" - "strings" - "time" "github.com/ollama/ollama/api" - "xorm.io/builder" ) type OllamaService struct { @@ -41,12 +37,8 @@ func NewOllamaGenerate( func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSelect) (msg string, err error) { - if err != nil { - return - } toolDefinitions := r.registerToolsOllama(req.Tools) - - match, err := r.client.ToolSelect(ctx, rec.Prompt, toolDefinitions) + match, err := r.client.ToolSelect(ctx, req.Prompt, toolDefinitions) if err != nil { return } @@ -70,132 +62,63 @@ func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSe return } -func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) { +//func (r *OllamaService) RecognizeWithImg(ctx context.Context, imgByte []api.ImageData, ch chan entitys.Response) (desc api.GenerateResponse, err error) { +// if imgByte == nil { +// return +// } +// entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...") +// +// desc, err = r.client.Generation(ctx, &api.GenerateRequest{ +// Model: r.config.Ollama.VlModel, +// Stream: new(bool), +// System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt, +// Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt, +// Images: requireData.ImgByte, +// KeepAlive: &api.Duration{Duration: 3600 * time.Second}, +// //Think: &api.ThinkValue{Value: false}, +// }) +// if err != nil { +// return +// } +// entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) +// return +//} - var ( - prompt = make([]api.Message, 0) - ) - content, err := r.getUserContent(ctx, requireData) - if err != nil { - return nil, err - } - prompt = append(prompt, api.Message{ - Role: "system", - Content: buildSystemPrompt(requireData.Sys.SysPrompt), - }, api.Message{ - Role: "assistant", - Content: "### 聊天记录:" + pkg.JsonStringIgonErr(buildAssistant(requireData.Histories)), - }, api.Message{ - Role: "user", - Content: content, - }) +//func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) { +// if requireData.ImgByte == nil { +// return +// } +// entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...") +// +// outMsg, err := r.vllmClient.RecognizeWithImg(ctx, +// r.config.DefaultPrompt.ImgRecognize.SystemPrompt, +// r.config.DefaultPrompt.ImgRecognize.UserPrompt, +// requireData.ImgUrls, +// ) +// if err != nil { +// return api.GenerateResponse{}, err +// } +// +// desc = api.GenerateResponse{ +// Response: outMsg.Content, +// } +// +// entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) +// return +//} - return prompt, nil -} - -func (r *OllamaService) getUserContent(ctx context.Context, requireData *entitys.RequireData) (string, error) { - var content strings.Builder - content.WriteString(requireData.Req.Text) - if len(requireData.ImgByte) > 0 { - content.WriteString("\n") - } - - if len(requireData.Req.Tags) > 0 { - content.WriteString("\n") - content.WriteString("### 工具必须使用:") - content.WriteString(requireData.Req.Tags) - } - - if len(requireData.ImgByte) > 0 { - // desc, err := r.RecognizeWithImg(ctx, requireData) - desc, err := r.RecognizeWithImgVllm(ctx, requireData) - if err != nil { - return "", err - } - content.WriteString("### 上传图片解析内容:\n") - content.WriteString(requireData.Req.Tags) - content.WriteString(desc.Response) - } - - if requireData.Req.MarkHis > 0 { - var his model.AiChatHi - cond := builder.NewCond() - cond = cond.And(builder.Eq{"his_id": requireData.Req.MarkHis}) - err := r.chatHis.GetOneBySearchToStrut(&cond, &his) - if err != nil { - return "", err - } - content.WriteString("### 引用历史聊天记录:\n") - content.WriteString(pkg.JsonStringIgonErr(BuildChatHisMessage([]model.AiChatHi{his}))) - } - return content.String(), nil -} - -func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) { - if requireData.ImgByte == nil { -func (r *OllamaService) RecognizeWithImg(ctx context.Context, imgByte []api.ImageData, ch chan entitys.Response) (desc api.GenerateResponse, err error) { - if imgByte == nil { - return - } - entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...") - - desc, err = r.client.Generation(ctx, &api.GenerateRequest{ - Model: r.config.Ollama.VlModel, - Stream: new(bool), - System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt, - Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt, - Images: requireData.ImgByte, - KeepAlive: &api.Duration{Duration: 3600 * time.Second}, - //Think: &api.ThinkValue{Value: false}, - }) - if err != nil { - return - } - entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) - return -} - -func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) { - if requireData.ImgByte == nil { - return - } - entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...") - - outMsg, err := r.vllmClient.RecognizeWithImg(ctx, - r.config.DefaultPrompt.ImgRecognize.SystemPrompt, - r.config.DefaultPrompt.ImgRecognize.UserPrompt, - requireData.ImgUrls, - ) - if err != nil { - return api.GenerateResponse{}, err - } - - desc = api.GenerateResponse{ - Response: outMsg.Content, - } - - entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response) - return -} - -func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool { +func (r *OllamaService) registerToolsOllama(tasks []entitys.RegistrationTask) []api.Tool { taskPrompt := make([]api.Tool, 0) for _, task := range tasks { - var taskConfig entitys.TaskConfigDetail - err := json.Unmarshal([]byte(task.Config), &taskConfig) - if err != nil { - continue - } - taskPrompt = append(taskPrompt, api.Tool{ Type: "function", Function: api.ToolFunction{ - Name: task.Index, + Name: task.Name, Description: task.Desc, Parameters: api.ToolFunctionParameters{ - Type: taskConfig.Param.Type, - Required: taskConfig.Param.Required, - Properties: taskConfig.Param.Properties, + Type: task.TaskConfigDetail.Param.Type, + Required: task.TaskConfigDetail.Param.Required, + Properties: task.TaskConfigDetail.Param.Properties, }, }, }) diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go index ae32b89..6d692eb 100644 --- a/internal/entitys/recognize.go +++ b/internal/entitys/recognize.go @@ -13,6 +13,9 @@ type Recognize struct { } type RegistrationTask struct { + Name string + Desc string + TaskConfigDetail TaskConfigDetail } type RecognizeUserContent struct { From e6c142f3a125f5619de592486f8026c2a0f39bff Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Wed, 10 Dec 2025 11:29:54 +0800 Subject: [PATCH 5/7] =?UTF-8?q?fix:=E4=BF=AE=E5=A4=8D=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_test.yaml | 8 +- internal/biz/ding_talk_bot.go | 112 +++++++++++++++++- internal/biz/do/prompt.go | 63 +++++++++- internal/config/config.go | 45 ++++--- internal/data/constants/bot.go | 6 + internal/data/impl/base.go | 2 +- internal/data/impl/bot_chat_history.go | 4 +- internal/data/impl/bot_config.go | 17 +++ internal/data/impl/bot_impl.go | 28 ----- internal/data/impl/provider_set.go | 8 +- .../{ai_bot.gen.go => ai_bot_config.gen.go} | 14 +-- internal/entitys/bot.go | 5 + internal/server/ding_talk_bot.go | 31 +++-- internal/services/dtalk_bot.go | 112 ++++-------------- internal/services/provider_set.go | 8 +- 15 files changed, 283 insertions(+), 180 deletions(-) create mode 100644 internal/data/impl/bot_config.go delete mode 100644 internal/data/impl/bot_impl.go rename internal/data/model/{ai_bot.gen.go => ai_bot_config.gen.go} (72%) diff --git a/config/config_test.yaml b/config/config_test.yaml index baedada..4ebccb8 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -88,7 +88,7 @@ permissionConfig: permission_url: "http://api.test.user.1688sup.cn:8001/v1/menu/myCodes?systemId=" -ding_talk_bots: - public: - client_id: "dingchg59zwwvmuuvldx", - client_secret: "ZwetAnRiTQobNFVlNrshRagSMAJIFpBAepWkWI7on7Tt_o617KHtTjBLp8fQfplz", \ No newline at end of file +#ding_talk_bots: +# public: +# client_id: "dingchg59zwwvmuuvldx", +# client_secret: "ZwetAnRiTQobNFVlNrshRagSMAJIFpBAepWkWI7on7Tt_o617KHtTjBLp8fQfplz", \ No newline at end of file diff --git a/internal/biz/ding_talk_bot.go b/internal/biz/ding_talk_bot.go index 09de5d4..f2abea1 100644 --- a/internal/biz/ding_talk_bot.go +++ b/internal/biz/ding_talk_bot.go @@ -2,29 +2,61 @@ package biz import ( "ai_scheduler/internal/biz/do" + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/mapstructure" "context" + "fmt" + "github.com/gofiber/fiber/v2/log" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" + "xorm.io/builder" ) // AiRouterBiz 智能路由服务 type DingTalkBotBiz struct { - do *do.Do - handle *do.Handle + do *do.Do + handle *do.Handle + botConfigImpl *impl.BotConfigImpl + replier *chatbot.ChatbotReplier + log log.Logger } // NewDingTalkBotBiz func NewDingTalkBotBiz( do *do.Do, handle *do.Handle, + botConfigImpl *impl.BotConfigImpl, ) *DingTalkBotBiz { return &DingTalkBotBiz{ - do: do, - handle: handle, + do: do, + handle: handle, + botConfigImpl: botConfigImpl, + replier: chatbot.NewChatbotReplier(), } } +func (d *DingTalkBotBiz) GetDingTalkBotCfgList() (dingBotList []entitys.DingTalkBot, err error) { + botConfig := make([]model.AiBotConfig, 0) + + cond := builder.NewCond() + cond = cond.And(builder.Eq{"status": constants.Enable}) + cond = cond.And(builder.Eq{"bot_type": constants.BotTypeDingTalk}) + err = d.botConfigImpl.GetRangeToMapStruct(&cond, &botConfig) + for _, v := range botConfig { + var config entitys.DingTalkBot + err = mapstructure.Decode(v, &config) + if err != nil { + d.log.Info("初始化“%s”失败:%s", v.BotName, err.Error()) + } + dingBotList = append(dingBotList, config) + } + return + +} + func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallbackDataModel) (requireData *entitys.RequireDataDingTalkBot, err error) { requireData = &entitys.RequireDataDingTalkBot{ Req: data, @@ -37,6 +69,74 @@ func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallb return } -func (d *DingTalkBotBiz) Recognize(ctx context.Context, rec *entitys.Recognize, ch chan entitys.Response) (match *entitys.Match, err error) { - return d.handle.RecognizeBot(ctx, rec, ch) +func (d *DingTalkBotBiz) Recognize(ctx context.Context, bot *chatbot.BotCallbackDataModel) (match entitys.Match, err error) { + + return d.handle.Recognize(ctx, nil, &do.WithDingTalkBot{}) +} + +func (d *DingTalkBotBiz) HandleRes(ctx context.Context, data *chatbot.BotCallbackDataModel, resp entitys.Response) error { + switch resp.Type { + case entitys.ResponseText: + return d.replyText(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseStream: + return d.replySteam(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseImg: + return d.replyImg(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseFile: + return d.replyFile(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseMarkdown: + return d.replyMarkdown(ctx, data.SessionWebhook, resp.Content) + case entitys.ResponseActionCard: + return d.replyActionCard(ctx, data.SessionWebhook, resp.Content) + default: + return nil + } +} + +func (d *DingTalkBotBiz) replySteam(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingTalkBotBiz) replyText(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingTalkBotBiz) replyImg(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingTalkBotBiz) replyFile(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingTalkBotBiz) replyMarkdown(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) +} + +func (d *DingTalkBotBiz) replyActionCard(ctx context.Context, SessionWebhook string, content string, arg ...string) error { + msg := content + if len(arg) > 0 { + msg = fmt.Sprintf(content, arg) + } + return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) } diff --git a/internal/biz/do/prompt.go b/internal/biz/do/prompt.go index cb1ca18..f25cf95 100644 --- a/internal/biz/do/prompt.go +++ b/internal/biz/do/prompt.go @@ -1,6 +1,7 @@ package do import ( + "ai_scheduler/internal/biz/handle" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" "context" @@ -42,7 +43,7 @@ func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) { var hasFile bool - if len(rec.UserContent.FileUrl) > 0 || rec.UserContent.File != nil { + if rec.UserContent.File != nil && (len(rec.UserContent.File.FileUrl) > 0 || rec.UserContent.File.File != nil) { hasFile = true } content.WriteString(rec.UserContent.Text) @@ -64,7 +65,65 @@ func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (c if hasFile { content.WriteString("\n") content.WriteString("### 文件内容:\n") - hand.WriteString(rec.UserContent.FileUrl, rec.UserContent.FileUrl) + handle.HandleRecognizeFile(rec.UserContent.File) + //...do something with file + } + return +} + +type WithDingTalkBot struct { +} + +func (f *WithDingTalkBot) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) { + var ( + prompt = make([]api.Message, 0) // 初始化一个空的api.Message切片 + ) + // 获取用户内容,如果出错则直接返回错误 + content, err := f.getUserContent(ctx, rec) + if err != nil { + return nil, err + } + // 构建提示消息列表,包含系统提示、助手回复和用户内容 + mes = append(prompt, api.Message{ + Role: "system", // 系统角色 + Content: rec.SystemPrompt, // 系统提示内容 + }, api.Message{ + Role: "assistant", // 助手角色 + Content: "### 聊天记录:" + pkg.JsonStringIgonErr(rec.ChatHis), // 助手回复内容 + }, api.Message{ + Role: "user", // 用户角色 + Content: content.String(), // 用户输入内容 + }) + + return +} + +func (f *WithDingTalkBot) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) { + var hasFile bool + if rec.UserContent.File != nil && (len(rec.UserContent.File.FileUrl) > 0 || rec.UserContent.File.File != nil) { + hasFile = true + } + content.WriteString(rec.UserContent.Text) + if hasFile { + content.WriteString("\n") + } + + if len(rec.UserContent.Tag) > 0 { + content.WriteString("\n") + content.WriteString("### 工具必须使用:") + content.WriteString(rec.UserContent.Tag) + } + + if len(rec.ChatHis.Messages) > 0 { + content.WriteString("### 引用历史聊天记录:\n") + content.WriteString(pkg.JsonStringIgonErr(rec.ChatHis)) + } + + if hasFile { + content.WriteString("\n") + content.WriteString("### 文件内容:\n") + handle.HandleRecognizeFile(rec.UserContent.File) + //...do something with file } return } diff --git a/internal/config/config.go b/internal/config/config.go index 85b8be8..160275f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,22 +9,17 @@ import ( // Config 应用配置 type Config struct { - Server ServerConfig `mapstructure:"server"` - Ollama OllamaConfig `mapstructure:"ollama"` - Vllm VllmConfig `mapstructure:"vllm"` - Sys SysConfig `mapstructure:"sys"` - Tools ToolsConfig `mapstructure:"tools"` - Logging LoggingConfig `mapstructure:"logging"` - Redis Redis `mapstructure:"redis"` - DB DB `mapstructure:"db"` - DefaultPrompt SysPrompt `mapstructure:"default_prompt"` - PermissionConfig PermissionConfig `mapstructure:"permissionConfig"` - DingTalkBots map[string]*DingTalkBot `mapstructure:"ding_talk_bots"` -} - -type DingTalkBot struct { - ClientId string `mapstructure:"client_id"` - ClientSecret string `mapstructure:"client_secret"` + Server ServerConfig `mapstructure:"server"` + Ollama OllamaConfig `mapstructure:"ollama"` + Vllm VllmConfig `mapstructure:"vllm"` + Sys SysConfig `mapstructure:"sys"` + Tools ToolsConfig `mapstructure:"tools"` + Logging LoggingConfig `mapstructure:"logging"` + Redis Redis `mapstructure:"redis"` + DB DB `mapstructure:"db"` + DefaultPrompt SysPrompt `mapstructure:"default_prompt"` + PermissionConfig PermissionConfig `mapstructure:"permissionConfig"` + // DingTalkBots map[string]*DingTalkBot `mapstructure:"ding_talk_bots"` } type SysPrompt struct { @@ -57,18 +52,18 @@ type ServerConfig struct { // OllamaConfig Ollama配置 type OllamaConfig struct { - BaseURL string `mapstructure:"base_url"` - Model string `mapstructure:"model"` - GenerateModel string `mapstructure:"generate_model"` - VlModel string `mapstructure:"vl_model"` - Timeout time.Duration `mapstructure:"timeout"` + BaseURL string `mapstructure:"base_url"` + Model string `mapstructure:"model"` + GenerateModel string `mapstructure:"generate_model"` + VlModel string `mapstructure:"vl_model"` + Timeout time.Duration `mapstructure:"timeout"` } type VllmConfig struct { - BaseURL string `mapstructure:"base_url"` - VlModel string `mapstructure:"vl_model"` - Timeout time.Duration `mapstructure:"timeout"` - Level string `mapstructure:"level"` + BaseURL string `mapstructure:"base_url"` + VlModel string `mapstructure:"vl_model"` + Timeout time.Duration `mapstructure:"timeout"` + Level string `mapstructure:"level"` } type Redis struct { diff --git a/internal/data/constants/bot.go b/internal/data/constants/bot.go index bfe004c..2b24cfe 100644 --- a/internal/data/constants/bot.go +++ b/internal/data/constants/bot.go @@ -25,3 +25,9 @@ var ChatStyleMap = map[ChatStyle]string{ ChatStyleCute: "可爱", ChatStyleAngry: "愤怒", } + +type BotType int + +const ( + BotTypeDingTalk BotType = 1 // 系统的bug/优化建议 +) diff --git a/internal/data/impl/base.go b/internal/data/impl/base.go index 7aee0b7..b7abf1a 100644 --- a/internal/data/impl/base.go +++ b/internal/data/impl/base.go @@ -22,7 +22,7 @@ BaseModel 是一个泛型结构体,用于封装GORM数据库通用操作。 // 定义受支持的PO类型集合(可根据需要扩展), 只有包含表结构才能使用BaseModel,避免使用出现问题 type PO interface { model.AiChatHi | - model.AiSy | model.AiSession | model.AiTask | model.AiBot + model.AiSy | model.AiSession | model.AiTask | model.AiBotConfig } type BaseModel[P PO] struct { diff --git a/internal/data/impl/bot_chat_history.go b/internal/data/impl/bot_chat_history.go index 3da5184..7e4de3c 100644 --- a/internal/data/impl/bot_chat_history.go +++ b/internal/data/impl/bot_chat_history.go @@ -10,6 +10,6 @@ type BotChatHisImpl struct { dataTemp.DataTemp } -func NewBotChatHisImpl(db *utils.Db) *TaskImpl { - return &TaskImpl{*dataTemp.NewDataTemp(db, new(model.AiBotChatHi))} +func NewBotChatHisImpl(db *utils.Db) *BotChatHisImpl { + return &BotChatHisImpl{*dataTemp.NewDataTemp(db, new(model.AiBotChatHi))} } diff --git a/internal/data/impl/bot_config.go b/internal/data/impl/bot_config.go new file mode 100644 index 0000000..2c98ffb --- /dev/null +++ b/internal/data/impl/bot_config.go @@ -0,0 +1,17 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" +) + +type BotConfigImpl struct { + dataTemp.DataTemp +} + +func NewBotConfigImpl(db *utils.Db) *BotConfigImpl { + return &BotConfigImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotConfig)), + } +} diff --git a/internal/data/impl/bot_impl.go b/internal/data/impl/bot_impl.go deleted file mode 100644 index e76e8de..0000000 --- a/internal/data/impl/bot_impl.go +++ /dev/null @@ -1,28 +0,0 @@ -package impl - -import ( - "ai_scheduler/internal/data/model" - "ai_scheduler/tmpl/dataTemp" - "ai_scheduler/utils" - - "gorm.io/gorm" -) - -type BotImpl struct { - dataTemp.DataTemp - BaseRepository[model.AiBot] -} - -func NewBotImpl(db *utils.Db) *BotImpl { - return &BotImpl{ - DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBot)), - BaseRepository: NewBaseModel[model.AiBot](db.Client), - } -} - -// WithSysId 系统id -func (s *BotImpl) WithSysId(sysId interface{}) CondFunc { - return func(db *gorm.DB) *gorm.DB { - return db.Where("sys_id = ?", sysId) - } -} diff --git a/internal/data/impl/provider_set.go b/internal/data/impl/provider_set.go index 1284f11..53d389c 100644 --- a/internal/data/impl/provider_set.go +++ b/internal/data/impl/provider_set.go @@ -4,4 +4,10 @@ import ( "github.com/google/wire" ) -var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatHisImpl) +var ProviderImpl = wire.NewSet( + NewSessionImpl, + NewSysImpl, + NewTaskImpl, + NewChatHisImpl, + NewBotConfigImpl, +) diff --git a/internal/data/model/ai_bot.gen.go b/internal/data/model/ai_bot_config.gen.go similarity index 72% rename from internal/data/model/ai_bot.gen.go rename to internal/data/model/ai_bot_config.gen.go index ffe73c2..5c4f27a 100644 --- a/internal/data/model/ai_bot.gen.go +++ b/internal/data/model/ai_bot_config.gen.go @@ -8,13 +8,13 @@ import ( "time" ) -const TableNameAiBot = "ai_bot" +const TableNameAiBotConfig = "ai_bot_config" -// AiBot mapped from table -type AiBot struct { +// AiBotConfig mapped from table +type AiBotConfig struct { BotID int32 `gorm:"column:bot_id;primaryKey;autoIncrement:true" json:"bot_id"` SysID int32 `gorm:"column:sys_id;not null" json:"sys_id"` - BotType int32 `gorm:"column:bot_type" json:"bot_type"` + BotType int32 `gorm:"column:bot_type;not null;default:1" json:"bot_type"` BotName string `gorm:"column:bot_name;not null" json:"bot_name"` BotConfig string `gorm:"column:bot_config;not null" json:"bot_config"` CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` @@ -23,7 +23,7 @@ type AiBot struct { DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"` } -// TableName AiBot's table name -func (*AiBot) TableName() string { - return TableNameAiBot +// TableName AiBotConfig's table name +func (*AiBotConfig) TableName() string { + return TableNameAiBotConfig } diff --git a/internal/entitys/bot.go b/internal/entitys/bot.go index 0a99113..fc24764 100644 --- a/internal/entitys/bot.go +++ b/internal/entitys/bot.go @@ -22,3 +22,8 @@ type RequireDataDingTalkBot struct { ImgByte []api.ImageData ImgUrls []string } + +type DingTalkBot struct { + ClientId string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` +} diff --git a/internal/server/ding_talk_bot.go b/internal/server/ding_talk_bot.go index cda0e1c..1d6c4a4 100644 --- a/internal/server/ding_talk_bot.go +++ b/internal/server/ding_talk_bot.go @@ -1,7 +1,7 @@ package server import ( - "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" "ai_scheduler/internal/services" "context" @@ -11,7 +11,7 @@ import ( ) type DingBotServiceInterface interface { - GetServiceCfg(cfg map[string]*config.DingTalkBot) (*config.DingTalkBot, string) + GetServiceCfg() ([]entitys.DingTalkBot, error) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error) } @@ -19,23 +19,28 @@ type DingTalkBotServer struct { Clients []*client.StreamClient } +// NewDingTalkBotServer 批量注册钉钉客户端cli +// 这里支持两种方式,一种是完全独立service,一种是直接用现成的service +// 独立的service,在本页的ProvideAllDingBotServices方法进行注册 +// 现成的service参考services->dtalk_bot.go +// 具体使用请根据实际业务需求 func NewDingTalkBotServer( - cfg *config.Config, services []DingBotServiceInterface, ) *DingTalkBotServer { clients := make([]*client.StreamClient, 0) for _, service := range services { - serviceConf, index := service.GetServiceCfg(cfg.DingTalkBots) - if serviceConf == nil { - log.Info("未找到%s配置", index) - continue + serviceConfigs, index := service.GetServiceCfg() + for _, serviceConf := range serviceConfigs { + if serviceConf.ClientId == "" || serviceConf.ClientSecret == "" { + continue + } + cli := DingBotServerInit(serviceConf.ClientId, serviceConf.ClientSecret, service) + if cli == nil { + log.Info("%s客户端初始失败", index) + continue + } + clients = append(clients, cli) } - cli := DingBotServerInit(serviceConf.ClientId, serviceConf.ClientSecret, service) - if cli == nil { - log.Info("%s客户端初始失败", index) - continue - } - clients = append(clients, cli) } return &DingTalkBotServer{ Clients: clients, diff --git a/internal/services/dtalk_bot.go b/internal/services/dtalk_bot.go index 2855471..864f5d7 100644 --- a/internal/services/dtalk_bot.go +++ b/internal/services/dtalk_bot.go @@ -3,6 +3,7 @@ package services import ( "ai_scheduler/internal/biz" "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" "context" "fmt" @@ -11,41 +12,39 @@ import ( ) type DingBotService struct { - config *config.Config - replier *chatbot.ChatbotReplier - env string - DingTalkBotBiz *biz.DingTalkBotBiz + config *config.Config + + dingTalkBotBiz *biz.DingTalkBotBiz } func NewDingBotService(config *config.Config, DingTalkBotBiz *biz.DingTalkBotBiz) *DingBotService { - return &DingBotService{config: config, replier: chatbot.NewChatbotReplier(), env: "public", DingTalkBotBiz: DingTalkBotBiz} + return &DingBotService{config: config, dingTalkBotBiz: DingTalkBotBiz} } -func (d *DingBotService) GetServiceCfg(cfg map[string]*config.DingTalkBot) (*config.DingTalkBot, string) { - return cfg[d.env], d.env +func (d *DingBotService) GetServiceCfg() ([]entitys.DingTalkBot, error) { + return d.dingTalkBotBiz.GetDingTalkBotCfgList() } func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error) { - requireData, err := d.DingTalkBotBiz.InitRequire(ctx, data) + requireData, err := d.dingTalkBotBiz.InitRequire(ctx, data) if err != nil { return } go func() { defer close(requireData.Ch) - - if recognizeErr := d.DingTalkBotBiz.Recognize(ctx, requireData); recognizeErr != nil { - requireData.Ch <- entitys.Response{ - Type: entitys.ResponseEnd, - Content: fmt.Sprintf("处理消息时出错: %v", recognizeErr), - } - } - //向下传递 - if err = d.handle.HandleMatch(ctx, nil, requireData); err != nil { - requireData.Ch <- entitys.Response{ - Type: entitys.ResponseEnd, - Content: fmt.Sprintf("匹配失败: %v", err), - } - } + //if match, _err := d.dingTalkBotBiz.Recognize(ctx, data); _err != nil { + // requireData.Ch <- entitys.Response{ + // Type: entitys.ResponseEnd, + // Content: fmt.Sprintf("处理消息时出错: %s", _err.Error()), + // } + //} + ////向下传递 + //if err = d.dingTalkBotBiz.HandleMatch(ctx, nil, requireData); err != nil { + // requireData.Ch <- entitys.Response{ + // Type: entitys.ResponseEnd, + // Content: fmt.Sprintf("匹配失败: %v", err), + // } + //} }() for { select { @@ -59,77 +58,10 @@ func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *cha if resp.Type == entitys.ResponseLog { return } - if err := d.handleRes(ctx, data, resp); err != nil { + if err := d.dingTalkBotBiz.HandleRes(ctx, data, resp); err != nil { return nil, fmt.Errorf("回复失败: %w", err) } } } return } - -func (d *DingBotService) handleRes(ctx context.Context, data *chatbot.BotCallbackDataModel, resp entitys.Response) error { - switch resp.Type { - case entitys.ResponseText: - return d.replyText(ctx, data.SessionWebhook, resp.Content) - case entitys.ResponseStream: - return d.replySteam(ctx, data.SessionWebhook, resp.Content) - case entitys.ResponseImg: - return d.replyImg(ctx, data.SessionWebhook, resp.Content) - case entitys.ResponseFile: - return d.replyFile(ctx, data.SessionWebhook, resp.Content) - case entitys.ResponseMarkdown: - return d.replyMarkdown(ctx, data.SessionWebhook, resp.Content) - case entitys.ResponseActionCard: - return d.replyActionCard(ctx, data.SessionWebhook, resp.Content) - default: - return nil - } -} - -func (d *DingBotService) replyText(ctx context.Context, SessionWebhook string, content string, arg ...string) error { - msg := content - if len(arg) > 0 { - msg = fmt.Sprintf(content, arg) - } - return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) -} - -func (d *DingBotService) replySteam(ctx context.Context, SessionWebhook string, content string, arg ...string) error { - msg := content - if len(arg) > 0 { - msg = fmt.Sprintf(content, arg) - } - return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) -} - -func (d *DingBotService) replyImg(ctx context.Context, SessionWebhook string, content string, arg ...string) error { - msg := content - if len(arg) > 0 { - msg = fmt.Sprintf(content, arg) - } - return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) -} - -func (d *DingBotService) replyFile(ctx context.Context, SessionWebhook string, content string, arg ...string) error { - msg := content - if len(arg) > 0 { - msg = fmt.Sprintf(content, arg) - } - return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) -} - -func (d *DingBotService) replyMarkdown(ctx context.Context, SessionWebhook string, content string, arg ...string) error { - msg := content - if len(arg) > 0 { - msg = fmt.Sprintf(content, arg) - } - return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) -} - -func (d *DingBotService) replyActionCard(ctx context.Context, SessionWebhook string, content string, arg ...string) error { - msg := content - if len(arg) > 0 { - msg = fmt.Sprintf(content, arg) - } - return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg)) -} diff --git a/internal/services/provider_set.go b/internal/services/provider_set.go index 668c7fe..867eb11 100644 --- a/internal/services/provider_set.go +++ b/internal/services/provider_set.go @@ -6,4 +6,10 @@ import ( "github.com/google/wire" ) -var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService, NewHistoryService) +var ProviderSetServices = wire.NewSet( + NewChatService, + NewSessionService, gateway.NewGateway, + NewTaskService, + NewCallbackService, + NewDingBotService, + NewHistoryService) From 567874848bb256106d9c590b13c9fbe5370b1640 Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Wed, 10 Dec 2025 11:47:39 +0800 Subject: [PATCH 6/7] =?UTF-8?q?fix:=E4=BF=AE=E5=A4=8D=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/entitys/recognize.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go index 6d692eb..fc8b55f 100644 --- a/internal/entitys/recognize.go +++ b/internal/entitys/recognize.go @@ -20,7 +20,7 @@ type RegistrationTask struct { type RecognizeUserContent struct { Text string - File *RecognizeFile + File []*RecognizeFile ActionCardUrl string Tag string } From d7b39f4d60a8db83c8d20c4fcc0371bc8effda5e Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Wed, 10 Dec 2025 14:19:35 +0800 Subject: [PATCH 7/7] =?UTF-8?q?fix:=E4=BF=AE=E5=A4=8D=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/server/main.go | 6 ++++-- deploy.sh | 4 +++- internal/biz/do/prompt.go | 15 ++++++--------- internal/biz/handle/file.go | 2 +- internal/entitys/bot.go | 5 +++-- internal/server/ding_talk_bot.go | 17 +++++++++++------ 6 files changed, 28 insertions(+), 21 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index a6607f5..d765735 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -2,6 +2,7 @@ package main import ( "ai_scheduler/internal/config" + "context" "flag" "fmt" @@ -10,6 +11,7 @@ import ( func main() { configPath := flag.String("config", "./config/config_test.yaml", "Path to configuration file") + onBot := flag.String("bot", "", "bot start") flag.Parse() bc, err := config.LoadConfig(*configPath) if err != nil { @@ -23,7 +25,7 @@ func main() { defer func() { cleanup() }() - //app.DingBotServer.Run(context.Background()) - //app.DingBotServer.RunBots(app.DingBotServer.BotServices) + app.DingBotServer.Run(context.Background(), *onBot) + log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port))) } diff --git a/deploy.sh b/deploy.sh index f02e162..2766253 100644 --- a/deploy.sh +++ b/deploy.sh @@ -14,8 +14,10 @@ fi CONFIG_FILE="config/config.yaml" BRANCH="master" +BOT="ALL" if [ "$MODE" = "dev" ]; then CONFIG_FILE="config/config_test.yaml" + BOT="zltx" BRANCH="test" fi @@ -33,6 +35,6 @@ docker run -itd \ -e "OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-http://host.docker.internal:11434}" \ -e "MODE=${MODE}" \ -p 8090:8090 \ - "${CONTAINER_NAME}" ./server --config "./${CONFIG_FILE}" + "${CONTAINER_NAME}" ./server --config "./${CONFIG_FILE}" --bot "./${BOT}" docker logs -f ${CONTAINER_NAME} \ No newline at end of file diff --git a/internal/biz/do/prompt.go b/internal/biz/do/prompt.go index f25cf95..ea82f84 100644 --- a/internal/biz/do/prompt.go +++ b/internal/biz/do/prompt.go @@ -43,7 +43,7 @@ func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) { var hasFile bool - if rec.UserContent.File != nil && (len(rec.UserContent.File.FileUrl) > 0 || rec.UserContent.File.File != nil) { + if rec.UserContent.File != nil && len(rec.UserContent.File) > 0 { hasFile = true } content.WriteString(rec.UserContent.Text) @@ -65,7 +65,10 @@ func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (c if hasFile { content.WriteString("\n") content.WriteString("### 文件内容:\n") - handle.HandleRecognizeFile(rec.UserContent.File) + for _, file := range rec.UserContent.File { + handle.HandleRecognizeFile(file) + } + //...do something with file } return @@ -100,7 +103,7 @@ func (f *WithDingTalkBot) CreatePrompt(ctx context.Context, rec *entitys.Recogni func (f *WithDingTalkBot) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) { var hasFile bool - if rec.UserContent.File != nil && (len(rec.UserContent.File.FileUrl) > 0 || rec.UserContent.File.File != nil) { + if rec.UserContent.File != nil && len(rec.UserContent.File) > 0 { hasFile = true } content.WriteString(rec.UserContent.Text) @@ -119,11 +122,5 @@ func (f *WithDingTalkBot) getUserContent(ctx context.Context, rec *entitys.Recog content.WriteString(pkg.JsonStringIgonErr(rec.ChatHis)) } - if hasFile { - content.WriteString("\n") - content.WriteString("### 文件内容:\n") - handle.HandleRecognizeFile(rec.UserContent.File) - //...do something with file - } return } diff --git a/internal/biz/handle/file.go b/internal/biz/handle/file.go index 9fd4fae..e9332aa 100644 --- a/internal/biz/handle/file.go +++ b/internal/biz/handle/file.go @@ -18,7 +18,7 @@ import ( // 判断文件大小 // 判断文件类型 // 判断文件是否合法 -func HandleRecognizeFile(file *entitys.RecognizeFile) { +func HandleRecognizeFile(files *entitys.RecognizeFile) { //Todo 仲云 return } diff --git a/internal/entitys/bot.go b/internal/entitys/bot.go index fc24764..568822c 100644 --- a/internal/entitys/bot.go +++ b/internal/entitys/bot.go @@ -24,6 +24,7 @@ type RequireDataDingTalkBot struct { } type DingTalkBot struct { - ClientId string `mapstructure:"client_id"` - ClientSecret string `mapstructure:"client_secret"` + BotIndex string + ClientId string + ClientSecret string } diff --git a/internal/server/ding_talk_bot.go b/internal/server/ding_talk_bot.go index 1d6c4a4..9a63812 100644 --- a/internal/server/ding_talk_bot.go +++ b/internal/server/ding_talk_bot.go @@ -16,7 +16,7 @@ type DingBotServiceInterface interface { } type DingTalkBotServer struct { - Clients []*client.StreamClient + Clients map[string]*client.StreamClient } // NewDingTalkBotServer 批量注册钉钉客户端cli @@ -27,19 +27,19 @@ type DingTalkBotServer struct { func NewDingTalkBotServer( services []DingBotServiceInterface, ) *DingTalkBotServer { - clients := make([]*client.StreamClient, 0) + clients := make(map[string]*client.StreamClient) for _, service := range services { - serviceConfigs, index := service.GetServiceCfg() + serviceConfigs, err := service.GetServiceCfg() for _, serviceConf := range serviceConfigs { if serviceConf.ClientId == "" || serviceConf.ClientSecret == "" { continue } cli := DingBotServerInit(serviceConf.ClientId, serviceConf.ClientSecret, service) if cli == nil { - log.Info("%s客户端初始失败", index) + log.Info("%s客户端初始失败:%s", serviceConf.BotIndex, err.Error()) continue } - clients = append(clients, cli) + clients[serviceConf.BotIndex] = cli } } return &DingTalkBotServer{ @@ -53,8 +53,13 @@ func ProvideAllDingBotServices( return []DingBotServiceInterface{dingBotSvc} } -func (d *DingTalkBotServer) Run(ctx context.Context) { +func (d *DingTalkBotServer) Run(ctx context.Context, botIndex string) { for name, cli := range d.Clients { + if botIndex != "All" { + if name != botIndex { + continue + } + } err := cli.Start(ctx) if err != nil { log.Info("%s启动失败", name)