Merge branch 'ryz' into v3
# Conflicts: # go.mod # go.sum # internal/biz/llm_service/ollama.go # internal/config/config.go # internal/entitys/types.go # internal/services/provider_set.go
This commit is contained in:
commit
782006b8d3
|
|
@ -2,7 +2,6 @@ package main
|
|||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
|
|
@ -23,8 +22,8 @@ func main() {
|
|||
}
|
||||
defer func() {
|
||||
cleanup()
|
||||
|
||||
}()
|
||||
|
||||
//app.DingBotServer.Run(context.Background())
|
||||
//app.DingBotServer.RunBots(app.DingBotServer.BotServices)
|
||||
log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port)))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -86,3 +86,9 @@ default_prompt:
|
|||
# 权限配置
|
||||
permissionConfig:
|
||||
permission_url: "http://api.test.user.1688sup.cn:8001/v1/menu/myCodes?systemId="
|
||||
|
||||
|
||||
ding_talk_bots:
|
||||
public:
|
||||
client_id: "dingchg59zwwvmuuvldx",
|
||||
client_secret: "ZwetAnRiTQobNFVlNrshRagSMAJIFpBAepWkWI7on7Tt_o617KHtTjBLp8fQfplz",
|
||||
3
go.mod
3
go.mod
|
|
@ -13,6 +13,7 @@ require (
|
|||
github.com/emirpasic/gods v1.18.1
|
||||
github.com/faabiosr/cachego v0.26.0
|
||||
github.com/fastwego/dingding v1.0.0-beta.4
|
||||
github.com/gabriel-vasile/mimetype v1.4.11
|
||||
github.com/go-kratos/kratos/v2 v2.9.1
|
||||
github.com/go-playground/locales v0.14.1
|
||||
github.com/go-playground/universal-translator v0.18.1
|
||||
|
|
@ -22,6 +23,7 @@ require (
|
|||
github.com/google/uuid v1.6.0
|
||||
github.com/google/wire v0.7.0
|
||||
github.com/ollama/ollama v0.12.7
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||
github.com/redis/go-redis/v9 v9.16.0
|
||||
github.com/spf13/viper v1.17.0
|
||||
github.com/tmc/langchaingo v0.1.13
|
||||
|
|
@ -59,6 +61,7 @@ require (
|
|||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||
github.com/goph/emperror v0.17.2 // indirect
|
||||
github.com/gorilla/websocket v1.5.0 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
|
|
|
|||
6
go.sum
6
go.sum
|
|
@ -176,6 +176,8 @@ github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4
|
|||
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gabriel-vasile/mimetype v1.4.11 h1:AQvxbp830wPhHTqc1u7nzoLT+ZFxGY7emj5DR5DYFik=
|
||||
github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||
github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY=
|
||||
github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ=
|
||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI=
|
||||
|
|
@ -267,6 +269,8 @@ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORR
|
|||
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
||||
github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g=
|
||||
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||
|
|
@ -342,6 +346,8 @@ github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1ls
|
|||
github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
|
||||
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8=
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
|
|
|
|||
|
|
@ -0,0 +1,42 @@
|
|||
package biz
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/biz/do"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"context"
|
||||
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
)
|
||||
|
||||
// AiRouterBiz 智能路由服务
|
||||
type DingTalkBotBiz struct {
|
||||
do *do.Do
|
||||
handle *do.Handle
|
||||
}
|
||||
|
||||
// NewDingTalkBotBiz
|
||||
func NewDingTalkBotBiz(
|
||||
do *do.Do,
|
||||
handle *do.Handle,
|
||||
) *DingTalkBotBiz {
|
||||
return &DingTalkBotBiz{
|
||||
do: do,
|
||||
handle: handle,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallbackDataModel) (requireData *entitys.RequireDataDingTalkBot, err error) {
|
||||
requireData = &entitys.RequireDataDingTalkBot{
|
||||
Req: data,
|
||||
Ch: make(chan entitys.Response, 2),
|
||||
}
|
||||
entitys.ResLog(requireData.Ch, "recognize_start", "收到消息,正在处理中,请稍等")
|
||||
|
||||
requireData.Sys, err = d.do.GetSysInfoForDingTalkBot(requireData)
|
||||
requireData.Tasks, err = d.do.GetTasks(requireData.Sys.SysID)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *DingTalkBotBiz) Recognize(ctx context.Context, rec *entitys.Recognize, ch chan entitys.Response) (match *entitys.Match, err error) {
|
||||
return d.handle.RecognizeBot(ctx, rec, ch)
|
||||
}
|
||||
|
|
@ -79,6 +79,20 @@ func (d *Do) DataAuth(ctx context.Context, client *gateway.Client, requireData *
|
|||
return nil
|
||||
}
|
||||
|
||||
func (d *Do) DataAuthForBot(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) (err error) {
|
||||
|
||||
// 2. 加载系统信息
|
||||
if err = d.loadSystemInfo(ctx, client, requireData); err != nil {
|
||||
return fmt.Errorf("获取系统信息失败: %w", err)
|
||||
}
|
||||
|
||||
// 3. 加载任务列表
|
||||
if err = d.loadTaskList(ctx, client, requireData); err != nil {
|
||||
return fmt.Errorf("获取任务列表失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 提取数据验证为单独函数
|
||||
func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.RequireData) error {
|
||||
requireData.Session = client.GetSession()
|
||||
|
|
@ -102,7 +116,7 @@ func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.Req
|
|||
// 获取系统信息的辅助函数
|
||||
func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error {
|
||||
if sysInfo := client.GetSysInfo(); sysInfo == nil {
|
||||
sys, err := d.getSysInfo(requireData)
|
||||
sys, err := d.GetSysInfo(requireData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -117,7 +131,7 @@ func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, require
|
|||
// 获取任务列表的辅助函数
|
||||
func (d *Do) loadTaskList(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error {
|
||||
if taskInfo := client.GetTasks(); len(taskInfo) == 0 {
|
||||
tasks, err := d.getTasks(requireData.Sys.SysID)
|
||||
tasks, err := d.GetTasks(requireData.Sys.SysID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -200,7 +214,16 @@ func (d *Do) getRequireData() (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (d *Do) getSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) {
|
||||
func (d *Do) GetSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) {
|
||||
cond := builder.NewCond()
|
||||
cond = cond.And(builder.Eq{"app_key": requireData.Key})
|
||||
cond = cond.And(builder.IsNull{"delete_at"})
|
||||
cond = cond.And(builder.Eq{"status": 1})
|
||||
err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Do) GetSysInfoForDingTalkBot(requireData *entitys.RequireDataDingTalkBot) (sysInfo model.AiSy, err error) {
|
||||
cond := builder.NewCond()
|
||||
cond = cond.And(builder.Eq{"app_key": requireData.Key})
|
||||
cond = cond.And(builder.IsNull{"delete_at"})
|
||||
|
|
@ -219,7 +242,7 @@ func (d *Do) getSessionChatHis(requireData *entitys.RequireData) (his []model.Ai
|
|||
return
|
||||
}
|
||||
|
||||
func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
||||
func (d *Do) GetTasks(sysId int32) (tasks []model.AiTask, err error) {
|
||||
|
||||
cond := builder.NewCond()
|
||||
cond = cond.And(builder.Eq{"sys_id": sysId})
|
||||
|
|
|
|||
|
|
@ -46,23 +46,24 @@ func NewHandle(
|
|||
}
|
||||
}
|
||||
|
||||
func (r *Handle) Recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
||||
entitys.ResLog(requireData.Ch, "recognize_start", "准备意图识别")
|
||||
|
||||
func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (match entitys.Match, err error) {
|
||||
entitys.ResLog(rec.Ch, "recognize_start", "准备意图识别")
|
||||
prompt, err := promptProcessor.CreatePrompt(ctx, rec)
|
||||
//意图识别
|
||||
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData)
|
||||
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, &entitys.ToolSelect{
|
||||
Prompt: prompt,
|
||||
Tools: rec.Tasks,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
entitys.ResLog(requireData.Ch, "recognize", recognizeMsg)
|
||||
entitys.ResLog(requireData.Ch, "recognize_end", "意图识别结束")
|
||||
entitys.ResLog(rec.Ch, "recognize", recognizeMsg)
|
||||
entitys.ResLog(rec.Ch, "recognize_end", "意图识别结束")
|
||||
|
||||
var match entitys.Match
|
||||
if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil {
|
||||
err = errors.SysErr("数据结构错误:%v", err.Error())
|
||||
return
|
||||
}
|
||||
requireData.Match = &match
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,70 @@
|
|||
package do
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg"
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type PromptOption interface {
|
||||
CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error)
|
||||
}
|
||||
|
||||
type WithSys struct {
|
||||
}
|
||||
|
||||
func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) {
|
||||
var (
|
||||
prompt = make([]api.Message, 0) // 初始化一个空的api.Message切片
|
||||
)
|
||||
// 获取用户内容,如果出错则直接返回错误
|
||||
content, err := f.getUserContent(ctx, rec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 构建提示消息列表,包含系统提示、助手回复和用户内容
|
||||
mes = append(prompt, api.Message{
|
||||
Role: "system", // 系统角色
|
||||
Content: rec.SystemPrompt, // 系统提示内容
|
||||
}, api.Message{
|
||||
Role: "assistant", // 助手角色
|
||||
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(rec.ChatHis), // 助手回复内容
|
||||
}, api.Message{
|
||||
Role: "user", // 用户角色
|
||||
Content: content.String(), // 用户输入内容
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) {
|
||||
var hasFile bool
|
||||
if len(rec.UserContent.FileUrl) > 0 || rec.UserContent.File != nil {
|
||||
hasFile = true
|
||||
}
|
||||
content.WriteString(rec.UserContent.Text)
|
||||
if hasFile {
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
if len(rec.UserContent.Tag) > 0 {
|
||||
content.WriteString("\n")
|
||||
content.WriteString("### 工具必须使用:")
|
||||
content.WriteString(rec.UserContent.Tag)
|
||||
}
|
||||
|
||||
if len(rec.ChatHis.Messages) > 0 {
|
||||
content.WriteString("### 引用历史聊天记录:\n")
|
||||
content.WriteString(pkg.JsonStringIgonErr(rec.ChatHis))
|
||||
}
|
||||
|
||||
if hasFile {
|
||||
content.WriteString("\n")
|
||||
content.WriteString("### 文件内容:\n")
|
||||
hand.WriteString(rec.UserContent.FileUrl, rec.UserContent.FileUrl)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
package handle
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/data/constants"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg/l_request"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
)
|
||||
|
||||
// HandleRecognizeFile 这里的目的是无论将什么类型的file都转为二进制格式
|
||||
// 判断文件大小
|
||||
// 判断文件类型
|
||||
// 判断文件是否合法
|
||||
func HandleRecognizeFile(file *entitys.RecognizeFile) {
|
||||
//Todo 仲云
|
||||
return
|
||||
}
|
||||
|
||||
// 下载文件并返回二进制数据、MIME 类型
|
||||
func downloadFile(fileUrl string) (fileBytes []byte, contentType string, err error) {
|
||||
if len(fileUrl) == 0 {
|
||||
return
|
||||
}
|
||||
req := l_request.Request{
|
||||
Method: "GET",
|
||||
Url: fileUrl,
|
||||
Headers: map[string]string{
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
|
||||
"Accept": "image/webp,image/apng,image/*,*/*;q=0.8",
|
||||
},
|
||||
}
|
||||
res, err := req.Send()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var ex bool
|
||||
if contentType, ex = res.Headers["Content-Type"]; !ex {
|
||||
err = errors.New("Content-Type不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
err = fmt.Errorf("server returned non-200 status: %d", res.StatusCode)
|
||||
}
|
||||
fileBytes = res.Content
|
||||
|
||||
return fileBytes, contentType, nil
|
||||
}
|
||||
|
||||
// detectFileType 判断文件类型
|
||||
func detectFileType(file io.ReadSeeker, filename string) constants.FileType {
|
||||
// 1. 读取文件头检测 MIME
|
||||
buffer := make([]byte, 512)
|
||||
n, _ := file.Read(buffer)
|
||||
file.Seek(0, io.SeekStart) // 重置读取位置
|
||||
|
||||
detectedMIME := mimetype.Detect(buffer[:n]).String()
|
||||
for fileType, items := range constants.FileTypeMappings {
|
||||
for _, item := range items {
|
||||
if !strings.HasPrefix(item, ".") && item == detectedMIME {
|
||||
return fileType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 备用:通过扩展名检测
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
for fileType, items := range constants.FileTypeMappings {
|
||||
for _, item := range items {
|
||||
if strings.HasPrefix(item, ".") && item == ext {
|
||||
return fileType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return constants.FileTypeUnknown
|
||||
}
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
package handle
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/tools"
|
||||
)
|
||||
|
||||
type Handle struct {
|
||||
toolManager *tools.Manager
|
||||
conf *config.Config
|
||||
}
|
||||
|
||||
func NewHandle(
|
||||
toolManager *tools.Manager,
|
||||
conf *config.Config,
|
||||
|
||||
) *Handle {
|
||||
return &Handle{
|
||||
toolManager: toolManager,
|
||||
conf: conf,
|
||||
}
|
||||
}
|
||||
|
|
@ -1,9 +1,7 @@
|
|||
package llm_service
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/data/constants"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
|
@ -20,48 +18,3 @@ func buildSystemPrompt(prompt string) string {
|
|||
|
||||
return prompt
|
||||
}
|
||||
|
||||
func buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
|
||||
for _, item := range his {
|
||||
if len(chatHis.SessionId) == 0 {
|
||||
chatHis.SessionId = item.SessionID
|
||||
}
|
||||
chatHis.Messages = append(chatHis.Messages, []entitys.HisMessage{
|
||||
{
|
||||
Role: constants.RoleUser,
|
||||
Content: item.Ques,
|
||||
Timestamp: item.CreateAt.Format(time.DateTime),
|
||||
},
|
||||
{
|
||||
Role: constants.RoleAssistant,
|
||||
Content: item.Ans,
|
||||
Timestamp: item.CreateAt.Format(time.DateTime),
|
||||
},
|
||||
}...)
|
||||
}
|
||||
chatHis.Context = entitys.HisContext{
|
||||
UserLanguage: "zh-CN",
|
||||
SystemMode: "technical_support",
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func BuildChatHisMessage(his []model.AiChatHi) (chatHis []entitys.HisMessage) {
|
||||
for _, item := range his {
|
||||
|
||||
chatHis = append(chatHis, []entitys.HisMessage{
|
||||
{
|
||||
Role: constants.RoleUser,
|
||||
Content: item.Ques,
|
||||
Timestamp: item.CreateAt.Format(time.DateTime),
|
||||
},
|
||||
{
|
||||
Role: constants.RoleAssistant,
|
||||
Content: item.Ans,
|
||||
Timestamp: item.CreateAt.Format(time.DateTime),
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,87 +1,76 @@
|
|||
package llm_service
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg"
|
||||
"ai_scheduler/internal/pkg/utils_langchain"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/tmc/langchaingo/llms"
|
||||
)
|
||||
|
||||
type LangChainService struct {
|
||||
client *utils_langchain.UtilLangChain
|
||||
}
|
||||
|
||||
func NewLangChainGenerate(
|
||||
client *utils_langchain.UtilLangChain,
|
||||
) *LangChainService {
|
||||
|
||||
return &LangChainService{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) {
|
||||
prompt := r.getPrompt(sysInfo, history, userInput, tasks)
|
||||
AgentClient := r.client.Get()
|
||||
defer r.client.Put(AgentClient)
|
||||
match, err := AgentClient.Llm.GenerateContent(
|
||||
ctx, // 使用可取消的上下文
|
||||
prompt,
|
||||
llms.WithJSONMode(),
|
||||
)
|
||||
msg = match.Choices[0].Content
|
||||
return
|
||||
}
|
||||
|
||||
func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
||||
var (
|
||||
prompt = make([]llms.MessageContent, 0)
|
||||
)
|
||||
prompt = append(prompt, llms.MessageContent{
|
||||
Role: llms.ChatMessageTypeSystem,
|
||||
Parts: []llms.ContentPart{
|
||||
llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)),
|
||||
},
|
||||
}, llms.MessageContent{
|
||||
Role: llms.ChatMessageTypeTool,
|
||||
Parts: []llms.ContentPart{
|
||||
llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))),
|
||||
},
|
||||
}, llms.MessageContent{
|
||||
Role: llms.ChatMessageTypeTool,
|
||||
Parts: []llms.ContentPart{
|
||||
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
|
||||
},
|
||||
}, llms.MessageContent{
|
||||
Role: llms.ChatMessageTypeHuman,
|
||||
Parts: []llms.ContentPart{
|
||||
llms.TextPart(reqInput),
|
||||
},
|
||||
})
|
||||
return prompt
|
||||
}
|
||||
|
||||
func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool {
|
||||
taskPrompt := make([]llms.Tool, 0)
|
||||
for _, task := range tasks {
|
||||
var taskConfig entitys.TaskConfig
|
||||
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
taskPrompt = append(taskPrompt, llms.Tool{
|
||||
Type: "function",
|
||||
Function: &llms.FunctionDefinition{
|
||||
Name: task.Index,
|
||||
Description: task.Desc,
|
||||
Parameters: taskConfig.Param,
|
||||
},
|
||||
})
|
||||
|
||||
}
|
||||
return taskPrompt
|
||||
}
|
||||
//type LangChainService struct {
|
||||
// client *utils_langchain.UtilLangChain
|
||||
//}
|
||||
//
|
||||
//func NewLangChainGenerate(
|
||||
// client *utils_langchain.UtilLangChain,
|
||||
//) *LangChainService {
|
||||
//
|
||||
// return &LangChainService{
|
||||
// client: client,
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) {
|
||||
// prompt := r.getPrompt(sysInfo, history, userInput, tasks)
|
||||
// AgentClient := r.client.Get()
|
||||
// defer r.client.Put(AgentClient)
|
||||
// match, err := AgentClient.Llm.GenerateContent(
|
||||
// ctx, // 使用可取消的上下文
|
||||
// prompt,
|
||||
// llms.WithJSONMode(),
|
||||
// )
|
||||
// msg = match.Choices[0].Content
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
||||
// var (
|
||||
// prompt = make([]llms.MessageContent, 0)
|
||||
// )
|
||||
// prompt = append(prompt, llms.MessageContent{
|
||||
// Role: llms.ChatMessageTypeSystem,
|
||||
// Parts: []llms.ContentPart{
|
||||
// llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)),
|
||||
// },
|
||||
// }, llms.MessageContent{
|
||||
// Role: llms.ChatMessageTypeTool,
|
||||
// Parts: []llms.ContentPart{
|
||||
// llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))),
|
||||
// },
|
||||
// }, llms.MessageContent{
|
||||
// Role: llms.ChatMessageTypeTool,
|
||||
// Parts: []llms.ContentPart{
|
||||
// llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
|
||||
// },
|
||||
// }, llms.MessageContent{
|
||||
// Role: llms.ChatMessageTypeHuman,
|
||||
// Parts: []llms.ContentPart{
|
||||
// llms.TextPart(reqInput),
|
||||
// },
|
||||
// })
|
||||
// return prompt
|
||||
//}
|
||||
//
|
||||
//func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool {
|
||||
// taskPrompt := make([]llms.Tool, 0)
|
||||
// for _, task := range tasks {
|
||||
// var taskConfig entitys.TaskConfig
|
||||
// err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
||||
// if err != nil {
|
||||
// continue
|
||||
// }
|
||||
// taskPrompt = append(taskPrompt, llms.Tool{
|
||||
// Type: "function",
|
||||
// Function: &llms.FunctionDefinition{
|
||||
// Name: task.Index,
|
||||
// Description: task.Desc,
|
||||
// Parameters: taskConfig.Param,
|
||||
// },
|
||||
// })
|
||||
//
|
||||
// }
|
||||
// return taskPrompt
|
||||
//}
|
||||
|
|
|
|||
|
|
@ -39,14 +39,14 @@ func NewOllamaGenerate(
|
|||
}
|
||||
}
|
||||
|
||||
func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
|
||||
prompt, err := r.getPrompt(ctx, requireData)
|
||||
func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSelect) (msg string, err error) {
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
toolDefinitions := r.registerToolsOllama(requireData.Tasks)
|
||||
toolDefinitions := r.registerToolsOllama(req.Tools)
|
||||
|
||||
match, err := r.client.ToolSelect(ctx, prompt, toolDefinitions)
|
||||
match, err := r.client.ToolSelect(ctx, rec.Prompt, toolDefinitions)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -133,6 +133,8 @@ func (r *OllamaService) getUserContent(ctx context.Context, requireData *entitys
|
|||
|
||||
func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
|
||||
if requireData.ImgByte == nil {
|
||||
func (r *OllamaService) RecognizeWithImg(ctx context.Context, imgByte []api.ImageData, ch chan entitys.Response) (desc api.GenerateResponse, err error) {
|
||||
if imgByte == nil {
|
||||
return
|
||||
}
|
||||
entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package biz
|
|||
|
||||
import (
|
||||
"ai_scheduler/internal/biz/do"
|
||||
"ai_scheduler/internal/biz/handle"
|
||||
"ai_scheduler/internal/biz/llm_service"
|
||||
|
||||
"github.com/google/wire"
|
||||
|
|
@ -12,10 +11,11 @@ var ProviderSetBiz = wire.NewSet(
|
|||
NewAiRouterBiz,
|
||||
NewSessionBiz,
|
||||
NewChatHistoryBiz,
|
||||
llm_service.NewLangChainGenerate,
|
||||
//llm_service.NewLangChainGenerate,
|
||||
llm_service.NewOllamaGenerate,
|
||||
handle.NewHandle,
|
||||
//handle.NewHandle,
|
||||
do.NewDo,
|
||||
do.NewHandle,
|
||||
NewTaskBiz,
|
||||
NewDingTalkBotBiz,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package biz
|
|||
import (
|
||||
"ai_scheduler/internal/biz/do"
|
||||
"ai_scheduler/internal/gateway"
|
||||
"context"
|
||||
|
||||
"ai_scheduler/internal/entitys"
|
||||
|
||||
|
|
@ -54,9 +55,15 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS
|
|||
log.Errorf("数据验证和收集失败: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
//组装意图识别
|
||||
rec, err := r.SetRec(ctx, requireData)
|
||||
if err != nil {
|
||||
log.Errorf("组装意图识别失败: %s", err.Error())
|
||||
return
|
||||
}
|
||||
//意图识别
|
||||
if err = r.handle.Recognize(ctx, requireData); err != nil {
|
||||
requireData.Match, err = r.handle.Recognize(ctx, &rec, &do.WithSys{})
|
||||
if err != nil {
|
||||
log.Errorf("意图识别失败: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
|
@ -68,3 +75,8 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *AiRouterBiz) SetRec(ctx context.Context, requireData *entitys.RequireData) (match entitys.Recognize, err error) {
|
||||
//TODO 叙平
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,7 +19,12 @@ type Config struct {
|
|||
DB DB `mapstructure:"db"`
|
||||
DefaultPrompt SysPrompt `mapstructure:"default_prompt"`
|
||||
PermissionConfig PermissionConfig `mapstructure:"permissionConfig"`
|
||||
// LLM *LLM `mapstructure:"llm"`
|
||||
DingTalkBots map[string]*DingTalkBot `mapstructure:"ding_talk_bots"`
|
||||
}
|
||||
|
||||
type DingTalkBot struct {
|
||||
ClientId string `mapstructure:"client_id"`
|
||||
ClientSecret string `mapstructure:"client_secret"`
|
||||
}
|
||||
|
||||
type SysPrompt struct {
|
||||
|
|
|
|||
|
|
@ -3,5 +3,25 @@ package constants
|
|||
type BotTools string
|
||||
|
||||
const (
|
||||
BotToolsBugOptimizationSubmit = "bug_optimization_submit" // 系统的bug/优化建议
|
||||
BotToolsBugOptimizationSubmit BotTools = "bug_optimization_submit" // 系统的bug/优化建议
|
||||
)
|
||||
|
||||
type ChatStyle int
|
||||
|
||||
const (
|
||||
ChatStyleNormal ChatStyle = 1 //正常
|
||||
ChatStyleSerious ChatStyle = 2 //严肃
|
||||
ChatStyleGentle ChatStyle = 3 //温柔
|
||||
ChatStyleArrogance ChatStyle = 4 //傲慢
|
||||
ChatStyleCute ChatStyle = 5 //可爱
|
||||
ChatStyleAngry ChatStyle = 6 //愤怒
|
||||
)
|
||||
|
||||
var ChatStyleMap = map[ChatStyle]string{
|
||||
ChatStyleNormal: "正常",
|
||||
ChatStyleSerious: "严肃",
|
||||
ChatStyleGentle: "温柔",
|
||||
ChatStyleArrogance: "傲慢",
|
||||
ChatStyleCute: "可爱",
|
||||
ChatStyleAngry: "愤怒",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,38 @@
|
|||
package constants
|
||||
|
||||
type FileType string
|
||||
|
||||
const (
|
||||
FileTypeUnknown FileType = "unknown"
|
||||
FileTypeImage FileType = "image"
|
||||
//FileTypeVideo FileType = "video"
|
||||
FileTypeExcel FileType = "excel"
|
||||
FileTypeWord FileType = "word"
|
||||
FileTypeTxt FileType = "txt"
|
||||
FileTypePDF FileType = "pdf"
|
||||
)
|
||||
|
||||
var FileTypeMappings = map[FileType][]string{
|
||||
FileTypeImage: {
|
||||
"image/jpeg", "image/png", "image/gif", "image/webp", "image/svg+xml",
|
||||
".jpg", ".jpeg", ".png", ".gif", ".webp", ".svg",
|
||||
},
|
||||
FileTypeExcel: {
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
".xls", ".xlsx",
|
||||
},
|
||||
FileTypeWord: {
|
||||
"application/msword",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
".doc", ".docx",
|
||||
},
|
||||
FileTypePDF: {
|
||||
"application/pdf",
|
||||
".pdf",
|
||||
},
|
||||
FileTypeTxt: {
|
||||
"text/plain",
|
||||
".txt",
|
||||
},
|
||||
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
package impl
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/tmpl/dataTemp"
|
||||
"ai_scheduler/utils"
|
||||
)
|
||||
|
||||
type BotChatHisImpl struct {
|
||||
dataTemp.DataTemp
|
||||
}
|
||||
|
||||
func NewBotChatHisImpl(db *utils.Db) *TaskImpl {
|
||||
return &TaskImpl{*dataTemp.NewDataTemp(db, new(model.AiBotChatHi))}
|
||||
}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const TableNameAiBotChatHi = "ai_bot_chat_his"
|
||||
|
||||
// AiBotChatHi mapped from table <ai_bot_chat_his>
|
||||
type AiBotChatHi struct {
|
||||
HisID int64 `gorm:"column:his_id;primaryKey;autoIncrement:true" json:"his_id"`
|
||||
SessionID string `gorm:"column:session_id;not null" json:"session_id"`
|
||||
Role string `gorm:"column:role;not null;comment:system系统输出,assistant助手输出,user用户输入" json:"role"` // system系统输出,assistant助手输出,user用户输入
|
||||
Content string `gorm:"column:content;not null" json:"content"`
|
||||
Files string `gorm:"column:files;not null" json:"files"`
|
||||
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName AiBotChatHi's table name
|
||||
func (*AiBotChatHi) TableName() string {
|
||||
return TableNameAiBotChatHi
|
||||
}
|
||||
|
|
@ -1,7 +1,24 @@
|
|||
package entitys
|
||||
|
||||
type BotType int
|
||||
import (
|
||||
"ai_scheduler/internal/data/model"
|
||||
|
||||
const (
|
||||
BugAndQuesDingTalk BotType = iota + 1
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
)
|
||||
|
||||
type RequireDataDingTalkBot struct {
|
||||
Session string
|
||||
Key string
|
||||
Sys model.AiSy
|
||||
Histories []model.AiChatHi
|
||||
SessionInfo model.AiSession
|
||||
Tasks []model.AiTask
|
||||
Match *Match
|
||||
Req *chatbot.BotCallbackDataModel
|
||||
Auth string
|
||||
Ch chan Response
|
||||
KnowledgeConf KnowledgeBaseRequest
|
||||
ImgByte []api.ImageData
|
||||
ImgUrls []string
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
package entitys
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type ToolSelect struct {
|
||||
Prompt []api.Message
|
||||
Tools []RegistrationTask
|
||||
}
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
package entitys
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/data/constants"
|
||||
)
|
||||
|
||||
type Recognize struct {
|
||||
SystemPrompt string
|
||||
UserContent *RecognizeUserContent
|
||||
ChatHis ChatHis
|
||||
Tasks []RegistrationTask
|
||||
Ch chan Response
|
||||
}
|
||||
|
||||
type RegistrationTask struct {
|
||||
}
|
||||
|
||||
type RecognizeUserContent struct {
|
||||
Text string
|
||||
File *RecognizeFile
|
||||
ActionCardUrl string
|
||||
Tag string
|
||||
}
|
||||
|
||||
type FileData []byte
|
||||
|
||||
type RecognizeFile struct {
|
||||
File []FileData // 文件数据(二进制格式)
|
||||
FileUrl string // 文件下载链接
|
||||
FileType constants.Caller // 文件类型(文件类型,能填最好填,可以跳过一层判断)
|
||||
}
|
||||
|
|
@ -3,22 +3,25 @@ package entitys
|
|||
import (
|
||||
"ai_scheduler/internal/gateway"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/gofiber/websocket/v2"
|
||||
)
|
||||
|
||||
type ResponseType string
|
||||
|
||||
const (
|
||||
ResponseJson ResponseType = "json"
|
||||
ResponseLoading ResponseType = "loading"
|
||||
ResponseEnd ResponseType = "end"
|
||||
ResponseStream ResponseType = "stream"
|
||||
ResponseText ResponseType = "txt"
|
||||
ResponseImg ResponseType = "img"
|
||||
ResponseFile ResponseType = "file"
|
||||
ResponseErr ResponseType = "error"
|
||||
ResponseLog ResponseType = "log"
|
||||
ResponseAuth ResponseType = "auth"
|
||||
ResponseJson ResponseType = "json"
|
||||
ResponseLoading ResponseType = "loading"
|
||||
ResponseEnd ResponseType = "end"
|
||||
ResponseStream ResponseType = "stream"
|
||||
ResponseText ResponseType = "txt"
|
||||
ResponseImg ResponseType = "img"
|
||||
ResponseFile ResponseType = "file"
|
||||
ResponseErr ResponseType = "error"
|
||||
ResponseLog ResponseType = "log"
|
||||
ResponseAuth ResponseType = "auth"
|
||||
ResponseMarkdown ResponseType = "markdown"
|
||||
ResponseActionCard ResponseType = "actionCard"
|
||||
)
|
||||
|
||||
func ResLog(ch chan Response, index string, content string) {
|
||||
|
|
@ -46,6 +49,9 @@ func ResJson(ch chan Response, index string, content string) {
|
|||
}
|
||||
|
||||
func ResEnd(ch chan Response, index string, content string) {
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
ch <- Response{
|
||||
Index: index,
|
||||
Content: content,
|
||||
|
|
@ -54,6 +60,9 @@ func ResEnd(ch chan Response, index string, content string) {
|
|||
}
|
||||
|
||||
func ResText(ch chan Response, index string, content string) {
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
ch <- Response{
|
||||
Index: index,
|
||||
Content: content,
|
||||
|
|
@ -62,6 +71,9 @@ func ResText(ch chan Response, index string, content string) {
|
|||
}
|
||||
|
||||
func ResLoading(ch chan Response, index string, content string) {
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
ch <- Response{
|
||||
Index: index,
|
||||
Content: content,
|
||||
|
|
@ -69,6 +81,9 @@ func ResLoading(ch chan Response, index string, content string) {
|
|||
}
|
||||
}
|
||||
func ResError(ch chan Response, index string, content string) {
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
ch <- Response{
|
||||
Index: index,
|
||||
Content: content,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package entitys
|
|||
import (
|
||||
"ai_scheduler/internal/data/constants"
|
||||
"ai_scheduler/internal/data/model"
|
||||
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
|
|
@ -151,7 +150,7 @@ type RequireData struct {
|
|||
SessionInfo model.AiSession
|
||||
Tasks []model.AiTask
|
||||
Task model.AiTask
|
||||
Match *Match
|
||||
Match Match
|
||||
Req *ChatSockRequest
|
||||
Auth string
|
||||
Ch chan Response
|
||||
|
|
|
|||
|
|
@ -0,0 +1,65 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/services"
|
||||
"context"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/log"
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
|
||||
)
|
||||
|
||||
type DingBotServiceInterface interface {
|
||||
GetServiceCfg(cfg map[string]*config.DingTalkBot) (*config.DingTalkBot, string)
|
||||
OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error)
|
||||
}
|
||||
|
||||
type DingTalkBotServer struct {
|
||||
Clients []*client.StreamClient
|
||||
}
|
||||
|
||||
func NewDingTalkBotServer(
|
||||
cfg *config.Config,
|
||||
services []DingBotServiceInterface,
|
||||
) *DingTalkBotServer {
|
||||
clients := make([]*client.StreamClient, 0)
|
||||
for _, service := range services {
|
||||
serviceConf, index := service.GetServiceCfg(cfg.DingTalkBots)
|
||||
if serviceConf == nil {
|
||||
log.Info("未找到%s配置", index)
|
||||
continue
|
||||
}
|
||||
cli := DingBotServerInit(serviceConf.ClientId, serviceConf.ClientSecret, service)
|
||||
if cli == nil {
|
||||
log.Info("%s客户端初始失败", index)
|
||||
continue
|
||||
}
|
||||
clients = append(clients, cli)
|
||||
}
|
||||
return &DingTalkBotServer{
|
||||
Clients: clients,
|
||||
}
|
||||
}
|
||||
|
||||
func ProvideAllDingBotServices(
|
||||
dingBotSvc *services.DingBotService,
|
||||
) []DingBotServiceInterface {
|
||||
return []DingBotServiceInterface{dingBotSvc}
|
||||
}
|
||||
|
||||
func (d *DingTalkBotServer) Run(ctx context.Context) {
|
||||
for name, cli := range d.Clients {
|
||||
err := cli.Start(ctx)
|
||||
if err != nil {
|
||||
log.Info("%s启动失败", name)
|
||||
continue
|
||||
}
|
||||
log.Info("%s启动成功", name)
|
||||
}
|
||||
}
|
||||
func DingBotServerInit(clientId string, clientSecret string, service DingBotServiceInterface) (cli *client.StreamClient) {
|
||||
cli = client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(clientId, clientSecret)))
|
||||
cli.RegisterChatBotCallbackRouter(service.OnChatBotMessageReceived)
|
||||
return
|
||||
}
|
||||
|
|
@ -1,5 +1,12 @@
|
|||
package server
|
||||
|
||||
import "github.com/google/wire"
|
||||
import (
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderSetServer = wire.NewSet(NewServers, NewHTTPServer)
|
||||
var ProviderSetServer = wire.NewSet(
|
||||
NewServers,
|
||||
NewHTTPServer,
|
||||
ProvideAllDingBotServices,
|
||||
NewDingTalkBotServer,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,27 @@
|
|||
package server
|
||||
|
||||
import "github.com/gofiber/fiber/v2"
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type Servers struct {
|
||||
HttpServer *fiber.App
|
||||
cfg *config.Config
|
||||
HttpServer *fiber.App
|
||||
DingBotServer *DingTalkBotServer
|
||||
}
|
||||
|
||||
func NewServers(fiber *fiber.App) *Servers {
|
||||
func NewServers(cfg *config.Config, fiber *fiber.App, DingBotServer *DingTalkBotServer) *Servers {
|
||||
return &Servers{
|
||||
HttpServer: fiber,
|
||||
HttpServer: fiber,
|
||||
cfg: cfg,
|
||||
DingBotServer: DingBotServer,
|
||||
}
|
||||
}
|
||||
|
||||
//func DingBotServerInit(clientId string, clientSecret string, cfg *config.Config, handler *do.Handle, do *do.Do) (cli *client.StreamClient) {
|
||||
// cli = client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(clientId, clientSecret)))
|
||||
// cli.RegisterChatBotCallbackRouter(services.NewDingBotService(cfg, handler, do).OnChatBotMessageReceived)
|
||||
// return
|
||||
//}
|
||||
|
|
|
|||
|
|
@ -212,7 +212,7 @@ func (s *CallbackService) sendStreamLog(sessionID string, content string) {
|
|||
}
|
||||
|
||||
streamLog := entitys.Response{
|
||||
Index: constants.BotToolsBugOptimizationSubmit,
|
||||
Index: string(constants.BotToolsBugOptimizationSubmit),
|
||||
Content: content,
|
||||
Type: entitys.ResponseLog,
|
||||
}
|
||||
|
|
@ -227,7 +227,7 @@ func (s *CallbackService) sendStreamTxt(sessionID string, content string) {
|
|||
}
|
||||
|
||||
streamLog := entitys.Response{
|
||||
Index: constants.BotToolsBugOptimizationSubmit,
|
||||
Index: string(constants.BotToolsBugOptimizationSubmit),
|
||||
Content: content,
|
||||
Type: entitys.ResponseText,
|
||||
}
|
||||
|
|
@ -242,7 +242,7 @@ func (s *CallbackService) sendStreamLoading(sessionID string, content string) {
|
|||
}
|
||||
|
||||
streamLog := entitys.Response{
|
||||
Index: constants.BotToolsBugOptimizationSubmit,
|
||||
Index: string(constants.BotToolsBugOptimizationSubmit),
|
||||
Content: content,
|
||||
Type: entitys.ResponseLoading,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,135 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/biz"
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
)
|
||||
|
||||
type DingBotService struct {
|
||||
config *config.Config
|
||||
replier *chatbot.ChatbotReplier
|
||||
env string
|
||||
DingTalkBotBiz *biz.DingTalkBotBiz
|
||||
}
|
||||
|
||||
func NewDingBotService(config *config.Config, DingTalkBotBiz *biz.DingTalkBotBiz) *DingBotService {
|
||||
return &DingBotService{config: config, replier: chatbot.NewChatbotReplier(), env: "public", DingTalkBotBiz: DingTalkBotBiz}
|
||||
}
|
||||
|
||||
func (d *DingBotService) GetServiceCfg(cfg map[string]*config.DingTalkBot) (*config.DingTalkBot, string) {
|
||||
return cfg[d.env], d.env
|
||||
}
|
||||
|
||||
func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error) {
|
||||
requireData, err := d.DingTalkBotBiz.InitRequire(ctx, data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer close(requireData.Ch)
|
||||
|
||||
if recognizeErr := d.DingTalkBotBiz.Recognize(ctx, requireData); recognizeErr != nil {
|
||||
requireData.Ch <- entitys.Response{
|
||||
Type: entitys.ResponseEnd,
|
||||
Content: fmt.Sprintf("处理消息时出错: %v", recognizeErr),
|
||||
}
|
||||
}
|
||||
//向下传递
|
||||
if err = d.handle.HandleMatch(ctx, nil, requireData); err != nil {
|
||||
requireData.Ch <- entitys.Response{
|
||||
Type: entitys.ResponseEnd,
|
||||
Content: fmt.Sprintf("匹配失败: %v", err),
|
||||
}
|
||||
}
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case resp, ok := <-requireData.Ch:
|
||||
if !ok {
|
||||
return []byte("success"), nil // 通道关闭,处理完成
|
||||
}
|
||||
|
||||
if resp.Type == entitys.ResponseLog {
|
||||
return
|
||||
}
|
||||
if err := d.handleRes(ctx, data, resp); err != nil {
|
||||
return nil, fmt.Errorf("回复失败: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (d *DingBotService) handleRes(ctx context.Context, data *chatbot.BotCallbackDataModel, resp entitys.Response) error {
|
||||
switch resp.Type {
|
||||
case entitys.ResponseText:
|
||||
return d.replyText(ctx, data.SessionWebhook, resp.Content)
|
||||
case entitys.ResponseStream:
|
||||
return d.replySteam(ctx, data.SessionWebhook, resp.Content)
|
||||
case entitys.ResponseImg:
|
||||
return d.replyImg(ctx, data.SessionWebhook, resp.Content)
|
||||
case entitys.ResponseFile:
|
||||
return d.replyFile(ctx, data.SessionWebhook, resp.Content)
|
||||
case entitys.ResponseMarkdown:
|
||||
return d.replyMarkdown(ctx, data.SessionWebhook, resp.Content)
|
||||
case entitys.ResponseActionCard:
|
||||
return d.replyActionCard(ctx, data.SessionWebhook, resp.Content)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DingBotService) replyText(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||
msg := content
|
||||
if len(arg) > 0 {
|
||||
msg = fmt.Sprintf(content, arg)
|
||||
}
|
||||
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||
}
|
||||
|
||||
func (d *DingBotService) replySteam(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||
msg := content
|
||||
if len(arg) > 0 {
|
||||
msg = fmt.Sprintf(content, arg)
|
||||
}
|
||||
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||
}
|
||||
|
||||
func (d *DingBotService) replyImg(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||
msg := content
|
||||
if len(arg) > 0 {
|
||||
msg = fmt.Sprintf(content, arg)
|
||||
}
|
||||
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||
}
|
||||
|
||||
func (d *DingBotService) replyFile(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||
msg := content
|
||||
if len(arg) > 0 {
|
||||
msg = fmt.Sprintf(content, arg)
|
||||
}
|
||||
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||
}
|
||||
|
||||
func (d *DingBotService) replyMarkdown(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||
msg := content
|
||||
if len(arg) > 0 {
|
||||
msg = fmt.Sprintf(content, arg)
|
||||
}
|
||||
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||
}
|
||||
|
||||
func (d *DingBotService) replyActionCard(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||
msg := content
|
||||
if len(arg) > 0 {
|
||||
msg = fmt.Sprintf(content, arg)
|
||||
}
|
||||
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||
}
|
||||
|
|
@ -2,22 +2,18 @@ package tools_bot
|
|||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/data/constants"
|
||||
errors "ai_scheduler/internal/data/error"
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg/utils_ollama"
|
||||
"context"
|
||||
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"context"
|
||||
)
|
||||
|
||||
type BotTool struct {
|
||||
config *config.Config
|
||||
llm *utils_ollama.Client
|
||||
sessionImpl *impl.SessionImpl
|
||||
taskMap map[string]string // task_id -> session_id
|
||||
// zltxOrderAfterSaleTool tools.ZltxOrderAfterSaleTool
|
||||
taskMap map[string]string
|
||||
}
|
||||
|
||||
// NewBotTool 创建直连天下订单详情工具
|
||||
|
|
@ -27,12 +23,5 @@ func NewBotTool(config *config.Config, llm *utils_ollama.Client, sessionImpl *im
|
|||
|
||||
// Execute 执行直连天下订单详情查询
|
||||
func (w *BotTool) Execute(ctx context.Context, toolName string, requireData *entitys.RequireData) (err error) {
|
||||
switch toolName {
|
||||
case constants.BotToolsBugOptimizationSubmit:
|
||||
err = w.BugOptimizationSubmit(ctx, requireData)
|
||||
default:
|
||||
log.Errorf("未知的工具类型:%s", toolName)
|
||||
err = errors.ParamErr("未知的工具类型:%s", toolName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue