Merge branch 'v3' into feature/fzy/refine

fuzhongyun 2025-12-10 16:20:51 +08:00
commit b243d03226
42 changed files with 1094 additions and 386 deletions

View File

@ -2,7 +2,7 @@ package main
import (
"ai_scheduler/internal/config"
"context"
"flag"
"fmt"
@ -11,6 +11,7 @@ import (
func main() {
configPath := flag.String("config", "./config/config_test.yaml", "Path to configuration file")
onBot := flag.String("bot", "", "bot start")
flag.Parse()
bc, err := config.LoadConfig(*configPath)
if err != nil {
@ -23,8 +24,8 @@ func main() {
}
defer func() {
cleanup()
}()
app.DingBotServer.Run(context.Background(), *onBot)
log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port)))
}

View File

@ -136,3 +136,8 @@ llm:
max_tokens: 4096
stream: true
#ding_talk_bots:
# public:
# client_id: "dingchg59zwwvmuuvldx",
# client_secret: "ZwetAnRiTQobNFVlNrshRagSMAJIFpBAepWkWI7on7Tt_o617KHtTjBLp8fQfplz",

View File

@ -14,8 +14,10 @@ fi
CONFIG_FILE="config/config.yaml"
BRANCH="master"
BOT="ALL"
if [ "$MODE" = "dev" ]; then
CONFIG_FILE="config/config_test.yaml"
BOT="zltx"
BRANCH="test"
fi
@ -33,6 +35,6 @@ docker run -itd \
-e "OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-http://host.docker.internal:11434}" \
-e "MODE=${MODE}" \
-p 8090:8090 \
"${CONTAINER_NAME}" ./server --config "./${CONFIG_FILE}"
"${CONTAINER_NAME}" ./server --config "./${CONFIG_FILE}" --bot "./${BOT}"
docker logs -f ${CONTAINER_NAME}

6
go.mod
View File

@ -1,6 +1,6 @@
module ai_scheduler
go 1.24.0
go 1.24.7
require (
gitea.cdlsxd.cn/self-tools/l_request v1.0.8
@ -14,6 +14,7 @@ require (
github.com/emirpasic/gods v1.18.1
github.com/faabiosr/cachego v0.26.0
github.com/fastwego/dingding v1.0.0-beta.4
github.com/gabriel-vasile/mimetype v1.4.11
github.com/go-kratos/kratos/v2 v2.9.1
github.com/go-playground/locales v0.14.1
github.com/go-playground/universal-translator v0.18.1
@ -23,6 +24,7 @@ require (
github.com/google/uuid v1.6.0
github.com/google/wire v0.7.0
github.com/ollama/ollama v0.12.7
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/redis/go-redis/v9 v9.16.0
github.com/spf13/viper v1.17.0
github.com/tmc/langchaingo v0.1.13
@ -58,9 +60,9 @@ require (
github.com/evanphx/json-patch v0.5.2 // indirect
github.com/fasthttp/websocket v1.5.3 // indirect
github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/goph/emperror v0.17.2 // indirect
github.com/gorilla/websocket v1.5.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect

8
go.sum
View File

@ -178,8 +178,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gabriel-vasile/mimetype v1.4.11 h1:AQvxbp830wPhHTqc1u7nzoLT+ZFxGY7emj5DR5DYFik=
github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY=
github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ=
github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI=
@ -271,6 +271,8 @@ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORR
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g=
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
@ -346,6 +348,8 @@ github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1ls
github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8=
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

View File

@ -26,10 +26,18 @@ func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl, taskRepo *impl.TaskImpl) *C
// List queries the chat history of a session
func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]entitys.ChatHisQueryResponse, error) {
chats, err := s.chatHiRepo.FindAll(
con := []impl.CondFunc{
s.chatHiRepo.WithSessionId(query.SessionID),
s.chatHiRepo.PaginateScope(query.Page, query.PageSize),
s.chatHiRepo.OrderByDesc("his_id"),
}
if query.HisID > 0 {
con = append(con, s.chatHiRepo.WithHisId(query.HisID))
}
chats, err := s.chatHiRepo.FindAll(
con...,
)
if err != nil {
return nil, err

View File

@ -0,0 +1,142 @@
package biz
import (
"ai_scheduler/internal/biz/do"
"ai_scheduler/internal/data/constants"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/mapstructure"
"context"
"fmt"
"github.com/gofiber/fiber/v2/log"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
"xorm.io/builder"
)
// DingTalkBotBiz handles DingTalk bot messages
type DingTalkBotBiz struct {
do *do.Do
handle *do.Handle
botConfigImpl *impl.BotConfigImpl
replier *chatbot.ChatbotReplier
log log.Logger
}
// NewDingTalkBotBiz
func NewDingTalkBotBiz(
do *do.Do,
handle *do.Handle,
botConfigImpl *impl.BotConfigImpl,
) *DingTalkBotBiz {
return &DingTalkBotBiz{
do: do,
handle: handle,
botConfigImpl: botConfigImpl,
replier: chatbot.NewChatbotReplier(),
}
}
func (d *DingTalkBotBiz) GetDingTalkBotCfgList() (dingBotList []entitys.DingTalkBot, err error) {
botConfig := make([]model.AiBotConfig, 0)
cond := builder.NewCond()
cond = cond.And(builder.Eq{"status": constants.Enable})
cond = cond.And(builder.Eq{"bot_type": constants.BotTypeDingTalk})
err = d.botConfigImpl.GetRangeToMapStruct(&cond, &botConfig)
for _, v := range botConfig {
var config entitys.DingTalkBot
err = mapstructure.Decode(v, &config)
if err != nil {
d.log.Info("初始化“%s”失败:%s", v.BotName, err.Error())
}
dingBotList = append(dingBotList, config)
}
return
}
func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallbackDataModel) (requireData *entitys.RequireDataDingTalkBot, err error) {
requireData = &entitys.RequireDataDingTalkBot{
Req: data,
Ch: make(chan entitys.Response, 2),
}
entitys.ResLog(requireData.Ch, "recognize_start", "收到消息,正在处理中,请稍等")
requireData.Sys, err = d.do.GetSysInfoForDingTalkBot(requireData)
requireData.Tasks, err = d.do.GetTasks(requireData.Sys.SysID)
return
}
func (d *DingTalkBotBiz) Recognize(ctx context.Context, bot *chatbot.BotCallbackDataModel) (match entitys.Match, err error) {
return d.handle.Recognize(ctx, nil, &do.WithDingTalkBot{})
}
func (d *DingTalkBotBiz) HandleRes(ctx context.Context, data *chatbot.BotCallbackDataModel, resp entitys.Response) error {
switch resp.Type {
case entitys.ResponseText:
return d.replyText(ctx, data.SessionWebhook, resp.Content)
case entitys.ResponseStream:
return d.replySteam(ctx, data.SessionWebhook, resp.Content)
case entitys.ResponseImg:
return d.replyImg(ctx, data.SessionWebhook, resp.Content)
case entitys.ResponseFile:
return d.replyFile(ctx, data.SessionWebhook, resp.Content)
case entitys.ResponseMarkdown:
return d.replyMarkdown(ctx, data.SessionWebhook, resp.Content)
case entitys.ResponseActionCard:
return d.replyActionCard(ctx, data.SessionWebhook, resp.Content)
default:
return nil
}
}
func (d *DingTalkBotBiz) replySteam(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
msg := content
if len(arg) > 0 {
msg = fmt.Sprintf(content, arg)
}
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
}
func (d *DingTalkBotBiz) replyText(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
msg := content
if len(arg) > 0 {
msg = fmt.Sprintf(content, arg)
}
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
}
func (d *DingTalkBotBiz) replyImg(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
msg := content
if len(arg) > 0 {
msg = fmt.Sprintf(content, arg)
}
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
}
func (d *DingTalkBotBiz) replyFile(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
msg := content
if len(arg) > 0 {
msg = fmt.Sprintf(content, arg)
}
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
}
func (d *DingTalkBotBiz) replyMarkdown(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
msg := content
if len(arg) > 0 {
msg = fmt.Sprintf(content, arg)
}
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
}
func (d *DingTalkBotBiz) replyActionCard(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
msg := content
if len(arg) > 0 {
msg = fmt.Sprintf(content, arg)
}
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
}
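
Note on the reply helpers above: each one passes the whole `arg` slice to fmt.Sprintf as a single operand, so a call with formatting arguments renders as "[a b c]" rather than substituting each value. A minimal sketch of a variadic-safe formatter (not part of this commit; the helper name is hypothetical):

func formatReplyMsg(content string, arg ...string) string {
	if len(arg) == 0 {
		return content
	}
	vals := make([]interface{}, len(arg)) // widen each string so Sprintf substitutes them individually
	for i, a := range arg {
		vals[i] = a
	}
	return fmt.Sprintf(content, vals...)
}

Each replyXxx body could then shrink to a single call such as d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(formatReplyMsg(content, arg...))).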

View File

@ -79,6 +79,20 @@ func (d *Do) DataAuth(ctx context.Context, client *gateway.Client, requireData *
return nil
}
func (d *Do) DataAuthForBot(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) (err error) {
// load system info
if err = d.loadSystemInfo(ctx, client, requireData); err != nil {
return fmt.Errorf("获取系统信息失败: %w", err)
}
// load task list
if err = d.loadTaskList(ctx, client, requireData); err != nil {
return fmt.Errorf("获取任务列表失败: %w", err)
}
return nil
}
// client data validation extracted into its own helper
func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.RequireData) error {
requireData.Session = client.GetSession()
@ -102,7 +116,7 @@ func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.Req
// helper that loads system info
func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error {
if sysInfo := client.GetSysInfo(); sysInfo == nil {
sys, err := d.getSysInfo(requireData)
sys, err := d.GetSysInfo(requireData)
if err != nil {
return err
}
@ -117,7 +131,7 @@ func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, require
// helper that loads the task list
func (d *Do) loadTaskList(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error {
if taskInfo := client.GetTasks(); len(taskInfo) == 0 {
tasks, err := d.getTasks(requireData.Sys.SysID)
tasks, err := d.GetTasks(requireData.Sys.SysID)
if err != nil {
return err
}
@ -200,7 +214,16 @@ func (d *Do) getRequireData() (err error) {
return
}
func (d *Do) getSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) {
func (d *Do) GetSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) {
cond := builder.NewCond()
cond = cond.And(builder.Eq{"app_key": requireData.Key})
cond = cond.And(builder.IsNull{"delete_at"})
cond = cond.And(builder.Eq{"status": 1})
err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
return
}
func (d *Do) GetSysInfoForDingTalkBot(requireData *entitys.RequireDataDingTalkBot) (sysInfo model.AiSy, err error) {
cond := builder.NewCond()
cond = cond.And(builder.Eq{"app_key": requireData.Key})
cond = cond.And(builder.IsNull{"delete_at"})
@ -219,7 +242,7 @@ func (d *Do) getSessionChatHis(requireData *entitys.RequireData) (his []model.Ai
return
}
func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) {
func (d *Do) GetTasks(sysId int32) (tasks []model.AiTask, err error) {
cond := builder.NewCond()
cond = cond.And(builder.Eq{"sys_id": sysId})

View File

@ -9,7 +9,6 @@ import (
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/gateway"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/l_request"
"ai_scheduler/internal/pkg/mapstructure"
"ai_scheduler/internal/tools"
@ -46,23 +45,24 @@ func NewHandle(
}
}
func (r *Handle) Recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
entitys.ResLog(requireData.Ch, "recognize_start", "准备意图识别")
func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (match entitys.Match, err error) {
entitys.ResLog(rec.Ch, "recognize_start", "准备意图识别")
prompt, err := promptProcessor.CreatePrompt(ctx, rec)
// intent recognition
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData)
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, &entitys.ToolSelect{
Prompt: prompt,
Tools: rec.Tasks,
})
if err != nil {
return
}
entitys.ResLog(requireData.Ch, "recognize", recognizeMsg)
entitys.ResLog(requireData.Ch, "recognize_end", "意图识别结束")
entitys.ResLog(rec.Ch, "recognize", recognizeMsg)
entitys.ResLog(rec.Ch, "recognize_end", "意图识别结束")
var match entitys.Match
if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil {
err = errors.SysErr("数据结构错误:%v", err.Error())
return
}
requireData.Match = &match
return
}
@ -269,7 +269,7 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require
if err != nil {
return
}
entitys.ResJson(requireData.Ch, "", pkg.JsonStringIgonErr(res.Text))
entitys.ResJson(requireData.Ch, "", res.Text)
return
}

126
internal/biz/do/prompt.go Normal file
View File

@ -0,0 +1,126 @@
package do
import (
"ai_scheduler/internal/biz/handle"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"context"
"strings"
"github.com/ollama/ollama/api"
)
type PromptOption interface {
CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error)
}
type WithSys struct {
}
func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) {
var (
prompt = make([]api.Message, 0) // start with an empty api.Message slice
)
// get the user content; return the error directly if it fails
content, err := f.getUserContent(ctx, rec)
if err != nil {
return nil, err
}
// build the prompt message list: system prompt, assistant context, then the user content
mes = append(prompt, api.Message{
Role: "system", // system role
Content: rec.SystemPrompt, // system prompt content
}, api.Message{
Role: "assistant", // assistant role
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(rec.ChatHis), // assistant context (chat history)
}, api.Message{
Role: "user", // user role
Content: content.String(), // user input content
})
return
}
func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) {
var hasFile bool
if rec.UserContent.File != nil && len(rec.UserContent.File) > 0 {
hasFile = true
}
content.WriteString(rec.UserContent.Text)
if hasFile {
content.WriteString("\n")
}
if len(rec.UserContent.Tag) > 0 {
content.WriteString("\n")
content.WriteString("### 工具必须使用:")
content.WriteString(rec.UserContent.Tag)
}
if len(rec.ChatHis.Messages) > 0 {
content.WriteString("### 引用历史聊天记录:\n")
content.WriteString(pkg.JsonStringIgonErr(rec.ChatHis))
}
if hasFile {
content.WriteString("\n")
content.WriteString("### 文件内容:\n")
for _, file := range rec.UserContent.File {
handle.HandleRecognizeFile(file)
}
//...do something with file
}
return
}
type WithDingTalkBot struct {
}
func (f *WithDingTalkBot) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) {
var (
prompt = make([]api.Message, 0) // start with an empty api.Message slice
)
// get the user content; return the error directly if it fails
content, err := f.getUserContent(ctx, rec)
if err != nil {
return nil, err
}
// build the prompt message list: system prompt, assistant context, then the user content
mes = append(prompt, api.Message{
Role: "system", // system role
Content: rec.SystemPrompt, // system prompt content
}, api.Message{
Role: "assistant", // assistant role
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(rec.ChatHis), // assistant context (chat history)
}, api.Message{
Role: "user", // user role
Content: content.String(), // user input content
})
return
}
func (f *WithDingTalkBot) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) {
var hasFile bool
if rec.UserContent.File != nil && len(rec.UserContent.File) > 0 {
hasFile = true
}
content.WriteString(rec.UserContent.Text)
if hasFile {
content.WriteString("\n")
}
if len(rec.UserContent.Tag) > 0 {
content.WriteString("\n")
content.WriteString("### 工具必须使用:")
content.WriteString(rec.UserContent.Tag)
}
if len(rec.ChatHis.Messages) > 0 {
content.WriteString("### 引用历史聊天记录:\n")
content.WriteString(pkg.JsonStringIgonErr(rec.ChatHis))
}
return
}
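
PromptOption is the extension point in this file: each channel supplies its own message-building strategy and the router picks one per appKey (see SetRec further down in this commit). A minimal extra implementation, purely illustrative and not part of the commit, could look like:

type withPlainText struct{} // hypothetical example implementation

func (withPlainText) CreatePrompt(ctx context.Context, rec *entitys.Recognize) ([]api.Message, error) {
	// build the smallest possible prompt: the system prompt plus the raw user text
	return []api.Message{
		{Role: "system", Content: rec.SystemPrompt},
		{Role: "user", Content: rec.UserContent.Text},
	}, nil
}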

View File

@ -0,0 +1,84 @@
package handle
import (
"ai_scheduler/internal/data/constants"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/l_request"
"errors"
"fmt"
"io"
"net/http"
"path/filepath"
"strings"
"github.com/gabriel-vasile/mimetype"
)
// HandleRecognizeFile normalizes any incoming file type into binary form, then
// checks the file size,
// the file type,
// and whether the file is valid.
func HandleRecognizeFile(files *entitys.RecognizeFile) {
//TODO(仲云): implement
return
}
// downloadFile fetches a file and returns its bytes and MIME type
func downloadFile(fileUrl string) (fileBytes []byte, contentType string, err error) {
if len(fileUrl) == 0 {
return
}
req := l_request.Request{
Method: "GET",
Url: fileUrl,
Headers: map[string]string{
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
"Accept": "image/webp,image/apng,image/*,*/*;q=0.8",
},
}
res, err := req.Send()
if err != nil {
return
}
var ex bool
if contentType, ex = res.Headers["Content-Type"]; !ex {
err = errors.New("Content-Type不存在")
return
}
if res.StatusCode != http.StatusOK {
return nil, "", fmt.Errorf("server returned non-200 status: %d", res.StatusCode)
}
fileBytes = res.Content
return fileBytes, contentType, nil
}
// detectFileType determines the file type
func detectFileType(file io.ReadSeeker, filename string) constants.FileType {
// 1. read the file header and detect its MIME type
buffer := make([]byte, 512)
n, _ := file.Read(buffer)
file.Seek(0, io.SeekStart) // reset the read position
detectedMIME := mimetype.Detect(buffer[:n]).String()
for fileType, items := range constants.FileTypeMappings {
for _, item := range items {
if !strings.HasPrefix(item, ".") && item == detectedMIME {
return fileType
}
}
}
// 2. fallback: detect by file extension
ext := strings.ToLower(filepath.Ext(filename))
for fileType, items := range constants.FileTypeMappings {
for _, item := range items {
if strings.HasPrefix(item, ".") && item == ext {
return fileType
}
}
}
return constants.FileTypeUnknown
}
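
HandleRecognizeFile is still a stub (the TODO above), while the two helpers in this file already cover download and type detection. A hypothetical sketch of how they might be wired together: the function name, the 20 MB cap, and the exact RecognizeFile fields are assumptions, and it would additionally need the "bytes" import.

func handleRecognizeFileSketch(f *entitys.RecognizeFile) error {
	data, _, err := downloadFile(f.FileUrl)
	if err != nil {
		return err
	}
	const maxSize = 20 << 20 // assumed size limit, not from this commit
	if len(data) > maxSize {
		return fmt.Errorf("file too large: %d bytes", len(data))
	}
	// detect the type from the content, falling back to the URL's extension
	if t := detectFileType(bytes.NewReader(data), f.FileUrl); t != constants.FileTypeUnknown {
		f.FileType = t
	}
	f.File = append(f.File, entitys.FileData(data))
	return nil
}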

View File

@ -1,22 +0,0 @@
package handle
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/tools"
)
type Handle struct {
toolManager *tools.Manager
conf *config.Config
}
func NewHandle(
toolManager *tools.Manager,
conf *config.Config,
) *Handle {
return &Handle{
toolManager: toolManager,
conf: conf,
}
}

View File

@ -1,9 +1,7 @@
package llm_service
import (
"ai_scheduler/internal/data/constants"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"context"
"time"
)
@ -20,48 +18,3 @@ func buildSystemPrompt(prompt string) string {
return prompt
}
func buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
for _, item := range his {
if len(chatHis.SessionId) == 0 {
chatHis.SessionId = item.SessionID
}
chatHis.Messages = append(chatHis.Messages, []entitys.HisMessage{
{
Role: constants.RoleUser,
Content: item.Ques,
Timestamp: item.CreateAt.Format(time.DateTime),
},
{
Role: constants.RoleAssistant,
Content: item.Ans,
Timestamp: item.CreateAt.Format(time.DateTime),
},
}...)
}
chatHis.Context = entitys.HisContext{
UserLanguage: "zh-CN",
SystemMode: "technical_support",
}
return
}
func BuildChatHisMessage(his []model.AiChatHi) (chatHis []entitys.HisMessage) {
for _, item := range his {
chatHis = append(chatHis, []entitys.HisMessage{
{
Role: constants.RoleUser,
Content: item.Ques,
Timestamp: item.CreateAt.Format(time.DateTime),
},
{
Role: constants.RoleAssistant,
Content: item.Ans,
Timestamp: item.CreateAt.Format(time.DateTime),
},
}...)
}
return
}

View File

@ -1,87 +1,76 @@
package llm_service
import (
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_langchain"
"context"
"encoding/json"
"github.com/tmc/langchaingo/llms"
)
type LangChainService struct {
client *utils_langchain.UtilLangChain
}
func NewLangChainGenerate(
client *utils_langchain.UtilLangChain,
) *LangChainService {
return &LangChainService{
client: client,
}
}
func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) {
prompt := r.getPrompt(sysInfo, history, userInput, tasks)
AgentClient := r.client.Get()
defer r.client.Put(AgentClient)
match, err := AgentClient.Llm.GenerateContent(
ctx, // 使用可取消的上下文
prompt,
llms.WithJSONMode(),
)
msg = match.Choices[0].Content
return
}
func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
var (
prompt = make([]llms.MessageContent, 0)
)
prompt = append(prompt, llms.MessageContent{
Role: llms.ChatMessageTypeSystem,
Parts: []llms.ContentPart{
llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)),
},
}, llms.MessageContent{
Role: llms.ChatMessageTypeTool,
Parts: []llms.ContentPart{
llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))),
},
}, llms.MessageContent{
Role: llms.ChatMessageTypeTool,
Parts: []llms.ContentPart{
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
},
}, llms.MessageContent{
Role: llms.ChatMessageTypeHuman,
Parts: []llms.ContentPart{
llms.TextPart(reqInput),
},
})
return prompt
}
func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool {
taskPrompt := make([]llms.Tool, 0)
for _, task := range tasks {
var taskConfig entitys.TaskConfig
err := json.Unmarshal([]byte(task.Config), &taskConfig)
if err != nil {
continue
}
taskPrompt = append(taskPrompt, llms.Tool{
Type: "function",
Function: &llms.FunctionDefinition{
Name: task.Index,
Description: task.Desc,
Parameters: taskConfig.Param,
},
})
}
return taskPrompt
}
//type LangChainService struct {
// client *utils_langchain.UtilLangChain
//}
//
//func NewLangChainGenerate(
// client *utils_langchain.UtilLangChain,
//) *LangChainService {
//
// return &LangChainService{
// client: client,
// }
//}
//
//func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) {
// prompt := r.getPrompt(sysInfo, history, userInput, tasks)
// AgentClient := r.client.Get()
// defer r.client.Put(AgentClient)
// match, err := AgentClient.Llm.GenerateContent(
// ctx, // 使用可取消的上下文
// prompt,
// llms.WithJSONMode(),
// )
// msg = match.Choices[0].Content
// return
//}
//
//func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
// var (
// prompt = make([]llms.MessageContent, 0)
// )
// prompt = append(prompt, llms.MessageContent{
// Role: llms.ChatMessageTypeSystem,
// Parts: []llms.ContentPart{
// llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)),
// },
// }, llms.MessageContent{
// Role: llms.ChatMessageTypeTool,
// Parts: []llms.ContentPart{
// llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))),
// },
// }, llms.MessageContent{
// Role: llms.ChatMessageTypeTool,
// Parts: []llms.ContentPart{
// llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
// },
// }, llms.MessageContent{
// Role: llms.ChatMessageTypeHuman,
// Parts: []llms.ContentPart{
// llms.TextPart(reqInput),
// },
// })
// return prompt
//}
//
//func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool {
// taskPrompt := make([]llms.Tool, 0)
// for _, task := range tasks {
// var taskConfig entitys.TaskConfig
// err := json.Unmarshal([]byte(task.Config), &taskConfig)
// if err != nil {
// continue
// }
// taskPrompt = append(taskPrompt, llms.Tool{
// Type: "function",
// Function: &llms.FunctionDefinition{
// Name: task.Index,
// Description: task.Desc,
// Parameters: taskConfig.Param,
// },
// })
//
// }
// return taskPrompt
//}

View File

@ -3,19 +3,15 @@ package llm_service
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_ollama"
"ai_scheduler/internal/pkg/utils_vllm"
"context"
"encoding/json"
"errors"
"strings"
"time"
"github.com/ollama/ollama/api"
"xorm.io/builder"
)
type OllamaService struct {
@ -39,14 +35,10 @@ func NewOllamaGenerate(
}
}
func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
prompt, err := r.getPrompt(ctx, requireData)
if err != nil {
return
}
toolDefinitions := r.registerToolsOllama(requireData.Tasks)
func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSelect) (msg string, err error) {
match, err := r.client.ToolSelect(ctx, prompt, toolDefinitions)
toolDefinitions := r.registerToolsOllama(req.Tools)
match, err := r.client.ToolSelect(ctx, req.Prompt, toolDefinitions)
if err != nil {
return
}
@ -70,130 +62,63 @@ func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entity
return
}
func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) {
//func (r *OllamaService) RecognizeWithImg(ctx context.Context, imgByte []api.ImageData, ch chan entitys.Response) (desc api.GenerateResponse, err error) {
// if imgByte == nil {
// return
// }
// entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
//
// desc, err = r.client.Generation(ctx, &api.GenerateRequest{
// Model: r.config.Ollama.VlModel,
// Stream: new(bool),
// System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
// Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
// Images: requireData.ImgByte,
// KeepAlive: &api.Duration{Duration: 3600 * time.Second},
// //Think: &api.ThinkValue{Value: false},
// })
// if err != nil {
// return
// }
// entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
// return
//}
var (
prompt = make([]api.Message, 0)
)
content, err := r.getUserContent(ctx, requireData)
if err != nil {
return nil, err
}
prompt = append(prompt, api.Message{
Role: "system",
Content: buildSystemPrompt(requireData.Sys.SysPrompt),
}, api.Message{
Role: "assistant",
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(buildAssistant(requireData.Histories)),
}, api.Message{
Role: "user",
Content: content,
})
//func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
// if requireData.ImgByte == nil {
// return
// }
// entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
//
// outMsg, err := r.vllmClient.RecognizeWithImg(ctx,
// r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
// r.config.DefaultPrompt.ImgRecognize.UserPrompt,
// requireData.ImgUrls,
// )
// if err != nil {
// return api.GenerateResponse{}, err
// }
//
// desc = api.GenerateResponse{
// Response: outMsg.Content,
// }
//
// entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
// return
//}
return prompt, nil
}
func (r *OllamaService) getUserContent(ctx context.Context, requireData *entitys.RequireData) (string, error) {
var content strings.Builder
content.WriteString(requireData.Req.Text)
if len(requireData.ImgByte) > 0 {
content.WriteString("\n")
}
if len(requireData.Req.Tags) > 0 {
content.WriteString("\n")
content.WriteString("### 工具必须使用:")
content.WriteString(requireData.Req.Tags)
}
if len(requireData.ImgByte) > 0 {
// desc, err := r.RecognizeWithImg(ctx, requireData)
desc, err := r.RecognizeWithImgVllm(ctx, requireData)
if err != nil {
return "", err
}
content.WriteString("### 上传图片解析内容:\n")
content.WriteString(requireData.Req.Tags)
content.WriteString(desc.Response)
}
if requireData.Req.MarkHis > 0 {
var his model.AiChatHi
cond := builder.NewCond()
cond = cond.And(builder.Eq{"his_id": requireData.Req.MarkHis})
err := r.chatHis.GetOneBySearchToStrut(&cond, &his)
if err != nil {
return "", err
}
content.WriteString("### 引用历史聊天记录:\n")
content.WriteString(pkg.JsonStringIgonErr(BuildChatHisMessage([]model.AiChatHi{his})))
}
return content.String(), nil
}
func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
if requireData.ImgByte == nil {
return
}
entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
desc, err = r.client.Generation(ctx, &api.GenerateRequest{
Model: r.config.Ollama.VlModel,
Stream: new(bool),
System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
Images: requireData.ImgByte,
KeepAlive: &api.Duration{Duration: 3600 * time.Second},
//Think: &api.ThinkValue{Value: false},
})
if err != nil {
return
}
entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
return
}
func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
if requireData.ImgByte == nil {
return
}
entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
outMsg, err := r.vllmClient.RecognizeWithImg(ctx,
r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
r.config.DefaultPrompt.ImgRecognize.UserPrompt,
requireData.ImgUrls,
)
if err != nil {
return api.GenerateResponse{}, err
}
desc = api.GenerateResponse{
Response: outMsg.Content,
}
entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
return
}
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
func (r *OllamaService) registerToolsOllama(tasks []entitys.RegistrationTask) []api.Tool {
taskPrompt := make([]api.Tool, 0)
for _, task := range tasks {
var taskConfig entitys.TaskConfigDetail
err := json.Unmarshal([]byte(task.Config), &taskConfig)
if err != nil {
continue
}
taskPrompt = append(taskPrompt, api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: task.Index,
Name: task.Name,
Description: task.Desc,
Parameters: api.ToolFunctionParameters{
Type: taskConfig.Param.Type,
Required: taskConfig.Param.Required,
Properties: taskConfig.Param.Properties,
Type: task.TaskConfigDetail.Param.Type,
Required: task.TaskConfigDetail.Param.Required,
Properties: task.TaskConfigDetail.Param.Properties,
},
},
})

View File

@ -2,7 +2,6 @@ package biz
import (
"ai_scheduler/internal/biz/do"
"ai_scheduler/internal/biz/handle"
"ai_scheduler/internal/biz/llm_service"
"github.com/google/wire"
@ -12,10 +11,11 @@ var ProviderSetBiz = wire.NewSet(
NewAiRouterBiz,
NewSessionBiz,
NewChatHistoryBiz,
llm_service.NewLangChainGenerate,
//llm_service.NewLangChainGenerate,
llm_service.NewOllamaGenerate,
handle.NewHandle,
//handle.NewHandle,
do.NewDo,
do.NewHandle,
NewTaskBiz,
NewDingTalkBotBiz,
)

View File

@ -2,7 +2,12 @@ package biz
import (
"ai_scheduler/internal/biz/do"
"ai_scheduler/internal/data/constants"
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/gateway"
"context"
"encoding/json"
"time"
"ai_scheduler/internal/entitys"
@ -54,9 +59,15 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS
log.Errorf("数据验证和收集失败: %s", err.Error())
return
}
// assemble the recognition input
rec, sys, err := r.SetRec(ctx, requireData)
if err != nil {
log.Errorf("组装意图识别失败: %s", err.Error())
return
}
// intent recognition
if err = r.handle.Recognize(ctx, requireData); err != nil {
requireData.Match, err = r.handle.Recognize(ctx, &rec, sys)
if err != nil {
log.Errorf("意图识别失败: %s", err.Error())
return
}
@ -68,3 +79,107 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS
}
return
}
func (r *AiRouterBiz) SetRec(ctx context.Context, requireData *entitys.RequireData) (match entitys.Recognize, sys do.PromptOption, err error) {
// nil-parameter check
if requireData == nil || requireData.Req == nil {
return match, sys, errors.NewBusinessErr(500, "请求参数为空")
}
// pick a system-prompt builder per appKey
switch requireData.Sys.AppKey {
default:
sys = &do.WithSys{}
}
// 1. system prompt
match.SystemPrompt = requireData.Sys.SysPrompt
// 2. user input and file handling
match.UserContent, err = r.buildUserContent(requireData)
if err != nil {
log.Errorf("构建用户内容失败: %s", err.Error())
return
}
// 3. chat history, built only when there is any
if len(requireData.Histories) > 0 {
match.ChatHis = r.buildChatHistory(requireData)
}
// 4. task list, with slice capacity preallocated
if len(requireData.Tasks) > 0 {
match.Tasks = make([]entitys.RegistrationTask, 0, len(requireData.Tasks))
for _, task := range requireData.Tasks {
taskConfig := entitys.TaskConfigDetail{}
// use a local error so a bad task config does not leak into the named return
if uerr := json.Unmarshal([]byte(task.Config), &taskConfig); uerr != nil {
log.Errorf("解析任务配置失败: %s, 任务ID: %s", uerr.Error(), task.Index)
continue // skip this task instead of failing the whole request
}
match.Tasks = append(match.Tasks, entitys.RegistrationTask{
Name: task.Index,
Desc: task.Desc,
TaskConfigDetail: taskConfig, // reuse the parsed config instead of rebuilding it
})
}
}
match.Ch = requireData.Ch
return
}
// buildUserContent builds the user content
func (r *AiRouterBiz) buildUserContent(requireData *entitys.RequireData) (*entitys.RecognizeUserContent, error) {
// preallocate capacity for at most two files (File and Img)
files := make([]*entitys.RecognizeFile, 0, 2)
// collect the file and image URLs
fileUrls := []string{requireData.Req.File, requireData.Req.Img}
for _, url := range fileUrls {
if url != "" {
files = append(files, &entitys.RecognizeFile{FileUrl: url})
}
}
// assemble and return the user content
return &entitys.RecognizeUserContent{
Text: requireData.Req.Text,
File: files,
ActionCardUrl: "", // TODO: action-card support to be implemented later
Tag: requireData.Req.Tags,
}, nil
}
// buildChatHistory builds the chat history
func (r *AiRouterBiz) buildChatHistory(requireData *entitys.RequireData) entitys.ChatHis {
// preallocate message capacity: each history row yields two messages
messages := make([]entitys.HisMessage, 0, len(requireData.Histories)*2)
// build the message list
for _, h := range requireData.Histories {
// user message
messages = append(messages, entitys.HisMessage{
Role: constants.RoleUser, // user role
Content: h.Ques, // the user's question
Timestamp: h.CreateAt.Format(time.DateTime),
})
// assistant message
messages = append(messages, entitys.HisMessage{
Role: constants.RoleAssistant, // assistant role
Content: h.Ans, // the assistant's answer
Timestamp: h.CreateAt.Format(time.DateTime),
})
}
// assemble the chat-history context
return entitys.ChatHis{
SessionId: requireData.Session,
Messages: messages,
Context: entitys.HisContext{
UserLanguage: constants.LangZhCN, // default: Simplified Chinese
SystemMode: constants.SystemModeTechnicalSupport, // default: technical-support mode
},
}
}

View File

@ -20,6 +20,7 @@ type Config struct {
DefaultPrompt SysPrompt `mapstructure:"default_prompt"`
PermissionConfig PermissionConfig `mapstructure:"permissionConfig"`
LLM LLM `mapstructure:"llm"`
// DingTalkBots map[string]*DingTalkBot `mapstructure:"ding_talk_bots"`
}
type SysPrompt struct {

View File

@ -3,5 +3,31 @@ package constants
type BotTools string
const (
BotToolsBugOptimizationSubmit = "bug_optimization_submit" // 系统的bug/优化建议
BotToolsBugOptimizationSubmit BotTools = "bug_optimization_submit" // 系统的bug/优化建议
)
type ChatStyle int
const (
ChatStyleNormal ChatStyle = 1 // normal
ChatStyleSerious ChatStyle = 2 // serious
ChatStyleGentle ChatStyle = 3 // gentle
ChatStyleArrogance ChatStyle = 4 // arrogant
ChatStyleCute ChatStyle = 5 // cute
ChatStyleAngry ChatStyle = 6 // angry
)
var ChatStyleMap = map[ChatStyle]string{
ChatStyleNormal: "正常",
ChatStyleSerious: "严肃",
ChatStyleGentle: "温柔",
ChatStyleArrogance: "傲慢",
ChatStyleCute: "可爱",
ChatStyleAngry: "愤怒",
}
type BotType int
const (
BotTypeDingTalk BotType = 1 // DingTalk bot
)

View File

@ -13,6 +13,14 @@ const (
// 分页默认条数
ChatHistoryLimit = 10
// language
LangZhCN = "zh-CN" // Simplified Chinese
// system modes
SystemModeDefault = "default" // default mode
SystemModeTechnicalSupport = "technical_support" // technical-support mode
)
func (c Caller) String() string {

View File

@ -0,0 +1,38 @@
package constants
type FileType string
const (
FileTypeUnknown FileType = "unknown"
FileTypeImage FileType = "image"
//FileTypeVideo FileType = "video"
FileTypeExcel FileType = "excel"
FileTypeWord FileType = "word"
FileTypeTxt FileType = "txt"
FileTypePDF FileType = "pdf"
)
var FileTypeMappings = map[FileType][]string{
FileTypeImage: {
"image/jpeg", "image/png", "image/gif", "image/webp", "image/svg+xml",
".jpg", ".jpeg", ".png", ".gif", ".webp", ".svg",
},
FileTypeExcel: {
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
".xls", ".xlsx",
},
FileTypeWord: {
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
".doc", ".docx",
},
FileTypePDF: {
"application/pdf",
".pdf",
},
FileTypeTxt: {
"text/plain",
".txt",
},
}

View File

@ -22,7 +22,7 @@ BaseModel 是一个泛型结构体用于封装GORM数据库通用操作。
// The set of supported PO types; extend as needed. Only table-backed models may use BaseModel, to avoid misuse.
type PO interface {
model.AiChatHi |
model.AiSy | model.AiSession | model.AiTask | model.AiBot
model.AiSy | model.AiSession | model.AiTask | model.AiBotConfig
}
type BaseModel[P PO] struct {

View File

@ -0,0 +1,15 @@
package impl
import (
"ai_scheduler/internal/data/model"
"ai_scheduler/tmpl/dataTemp"
"ai_scheduler/utils"
)
type BotChatHisImpl struct {
dataTemp.DataTemp
}
func NewBotChatHisImpl(db *utils.Db) *BotChatHisImpl {
return &BotChatHisImpl{*dataTemp.NewDataTemp(db, new(model.AiBotChatHi))}
}

View File

@ -0,0 +1,17 @@
package impl
import (
"ai_scheduler/internal/data/model"
"ai_scheduler/tmpl/dataTemp"
"ai_scheduler/utils"
)
type BotConfigImpl struct {
dataTemp.DataTemp
}
func NewBotConfigImpl(db *utils.Db) *BotConfigImpl {
return &BotConfigImpl{
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotConfig)),
}
}

View File

@ -1,28 +0,0 @@
package impl
import (
"ai_scheduler/internal/data/model"
"ai_scheduler/tmpl/dataTemp"
"ai_scheduler/utils"
"gorm.io/gorm"
)
type BotImpl struct {
dataTemp.DataTemp
BaseRepository[model.AiBot]
}
func NewBotImpl(db *utils.Db) *BotImpl {
return &BotImpl{
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBot)),
BaseRepository: NewBaseModel[model.AiBot](db.Client),
}
}
// WithSysId 系统id
func (s *BotImpl) WithSysId(sysId interface{}) CondFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Where("sys_id = ?", sysId)
}
}

View File

@ -4,4 +4,10 @@ import (
"github.com/google/wire"
)
var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatHisImpl)
var ProviderImpl = wire.NewSet(
NewSessionImpl,
NewSysImpl,
NewTaskImpl,
NewChatHisImpl,
NewBotConfigImpl,
)

View File

@ -0,0 +1,27 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model
import (
"time"
)
const TableNameAiBotChatHi = "ai_bot_chat_his"
// AiBotChatHi mapped from table <ai_bot_chat_his>
type AiBotChatHi struct {
HisID int64 `gorm:"column:his_id;primaryKey;autoIncrement:true" json:"his_id"`
SessionID string `gorm:"column:session_id;not null" json:"session_id"`
Role string `gorm:"column:role;not null;comment:system系统输出assistant助手输出,user用户输入" json:"role"` // system系统输出assistant助手输出,user用户输入
Content string `gorm:"column:content;not null" json:"content"`
Files string `gorm:"column:files;not null" json:"files"`
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"`
}
// TableName AiBotChatHi's table name
func (*AiBotChatHi) TableName() string {
return TableNameAiBotChatHi
}

View File

@ -8,13 +8,13 @@ import (
"time"
)
const TableNameAiBot = "ai_bot"
const TableNameAiBotConfig = "ai_bot_config"
// AiBot mapped from table <ai_bot>
type AiBot struct {
// AiBotConfig mapped from table <ai_bot_config>
type AiBotConfig struct {
BotID int32 `gorm:"column:bot_id;primaryKey;autoIncrement:true" json:"bot_id"`
SysID int32 `gorm:"column:sys_id;not null" json:"sys_id"`
BotType int32 `gorm:"column:bot_type" json:"bot_type"`
BotType int32 `gorm:"column:bot_type;not null;default:1" json:"bot_type"`
BotName string `gorm:"column:bot_name;not null" json:"bot_name"`
BotConfig string `gorm:"column:bot_config;not null" json:"bot_config"`
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
@ -23,7 +23,7 @@ type AiBot struct {
DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"`
}
// TableName AiBot's table name
func (*AiBot) TableName() string {
return TableNameAiBot
// TableName AiBotConfig's table name
func (*AiBotConfig) TableName() string {
return TableNameAiBotConfig
}

View File

@ -1,7 +1,30 @@
package entitys
type BotType int
import (
"ai_scheduler/internal/data/model"
const (
BugAndQuesDingTalk BotType = iota + 1
"github.com/ollama/ollama/api"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
)
type RequireDataDingTalkBot struct {
Session string
Key string
Sys model.AiSy
Histories []model.AiChatHi
SessionInfo model.AiSession
Tasks []model.AiTask
Match *Match
Req *chatbot.BotCallbackDataModel
Auth string
Ch chan Response
KnowledgeConf KnowledgeBaseRequest
ImgByte []api.ImageData
ImgUrls []string
}
type DingTalkBot struct {
BotIndex string
ClientId string
ClientSecret string
}

View File

@ -19,6 +19,7 @@ type ChatHisLog struct {
}
type ChatHistQuery struct {
HisID int64 `json:"his_id"`
SessionID string `json:"session_id"`
Page int `json:"page"`
PageSize int `json:"page_size"`

View File

@ -0,0 +1,10 @@
package entitys
import (
"github.com/ollama/ollama/api"
)
type ToolSelect struct {
Prompt []api.Message
Tools []RegistrationTask
}

View File

@ -0,0 +1,34 @@
package entitys
import (
"ai_scheduler/internal/data/constants"
)
type Recognize struct {
SystemPrompt string // system prompt content
UserContent *RecognizeUserContent // user input content
ChatHis ChatHis // session chat history
Tasks []RegistrationTask
Ch chan Response
}
type RegistrationTask struct {
Name string
Desc string
TaskConfigDetail TaskConfigDetail
}
type RecognizeUserContent struct {
Text string // user-entered text
File []*RecognizeFile // attached files
ActionCardUrl string // action-card URL
Tag string // tool tag
}
type FileData []byte
type RecognizeFile struct {
File []FileData // file data (binary)
FileUrl string // file download URL
FileType constants.FileType // file type; fill it in when known to skip a detection step
}

View File

@ -3,22 +3,25 @@ package entitys
import (
"ai_scheduler/internal/gateway"
"encoding/json"
"github.com/gofiber/websocket/v2"
)
type ResponseType string
const (
ResponseJson ResponseType = "json"
ResponseLoading ResponseType = "loading"
ResponseEnd ResponseType = "end"
ResponseStream ResponseType = "stream"
ResponseText ResponseType = "txt"
ResponseImg ResponseType = "img"
ResponseFile ResponseType = "file"
ResponseErr ResponseType = "error"
ResponseLog ResponseType = "log"
ResponseAuth ResponseType = "auth"
ResponseJson ResponseType = "json"
ResponseLoading ResponseType = "loading"
ResponseEnd ResponseType = "end"
ResponseStream ResponseType = "stream"
ResponseText ResponseType = "txt"
ResponseImg ResponseType = "img"
ResponseFile ResponseType = "file"
ResponseErr ResponseType = "error"
ResponseLog ResponseType = "log"
ResponseAuth ResponseType = "auth"
ResponseMarkdown ResponseType = "markdown"
ResponseActionCard ResponseType = "actionCard"
)
func ResLog(ch chan Response, index string, content string) {
@ -46,6 +49,9 @@ func ResJson(ch chan Response, index string, content string) {
}
func ResEnd(ch chan Response, index string, content string) {
if ch == nil {
return
}
ch <- Response{
Index: index,
Content: content,
@ -54,6 +60,9 @@ func ResEnd(ch chan Response, index string, content string) {
}
func ResText(ch chan Response, index string, content string) {
if ch == nil {
return
}
ch <- Response{
Index: index,
Content: content,
@ -62,6 +71,9 @@ func ResText(ch chan Response, index string, content string) {
}
func ResLoading(ch chan Response, index string, content string) {
if ch == nil {
return
}
ch <- Response{
Index: index,
Content: content,
@ -69,6 +81,9 @@ func ResLoading(ch chan Response, index string, content string) {
}
}
func ResError(ch chan Response, index string, content string) {
if ch == nil {
return
}
ch <- Response{
Index: index,
Content: content,

View File

@ -3,7 +3,6 @@ package entitys
import (
"ai_scheduler/internal/data/constants"
"ai_scheduler/internal/data/model"
"context"
"encoding/json"
@ -139,8 +138,8 @@ type HisMessage struct {
}
type HisContext struct {
UserLanguage string `json:"user_language"`
SystemMode string `json:"system_mode"`
UserLanguage string `json:"user_language"` // user language
SystemMode string `json:"system_mode"` // system mode
}
type RequireData struct {
@ -151,7 +150,7 @@ type RequireData struct {
SessionInfo model.AiSession
Tasks []model.AiTask
Task model.AiTask
Match *Match
Match Match
Req *ChatSockRequest
Auth string
Ch chan Response

View File

@ -0,0 +1,75 @@
package server
import (
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/services"
"context"
"github.com/go-kratos/kratos/v2/log"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
)
type DingBotServiceInterface interface {
GetServiceCfg() ([]entitys.DingTalkBot, error)
OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error)
}
type DingTalkBotServer struct {
Clients map[string]*client.StreamClient
}
// NewDingTalkBotServer registers the DingTalk stream clients in bulk.
// Two approaches are supported: a fully independent service, or reuse of an existing service.
// Independent services are registered in ProvideAllDingBotServices in this file;
// for an existing service, see services -> dtalk_bot.go.
// Choose whichever fits the actual business need (a minimal independent-service sketch follows this file's diff).
func NewDingTalkBotServer(
services []DingBotServiceInterface,
) *DingTalkBotServer {
clients := make(map[string]*client.StreamClient)
for _, service := range services {
serviceConfigs, err := service.GetServiceCfg()
if err != nil {
log.Errorf("failed to load bot config: %s", err.Error())
continue
}
for _, serviceConf := range serviceConfigs {
if serviceConf.ClientId == "" || serviceConf.ClientSecret == "" {
continue
}
cli := DingBotServerInit(serviceConf.ClientId, serviceConf.ClientSecret, service)
if cli == nil {
log.Infof("%s客户端初始失败", serviceConf.BotIndex)
continue
}
clients[serviceConf.BotIndex] = cli
}
}
return &DingTalkBotServer{
Clients: clients,
}
}
func ProvideAllDingBotServices(
dingBotSvc *services.DingBotService,
) []DingBotServiceInterface {
return []DingBotServiceInterface{dingBotSvc}
}
func (d *DingTalkBotServer) Run(ctx context.Context, botIndex string) {
for name, cli := range d.Clients {
if botIndex != "All" {
if name != botIndex {
continue
}
}
err := cli.Start(ctx)
if err != nil {
log.Info("%s启动失败", name)
continue
}
log.Info("%s启动成功", name)
}
}
func DingBotServerInit(clientId string, clientSecret string, service DingBotServiceInterface) (cli *client.StreamClient) {
cli = client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(clientId, clientSecret)))
cli.RegisterChatBotCallbackRouter(service.OnChatBotMessageReceived)
return
}
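
As noted in the comments on NewDingTalkBotServer, a bot can also be backed by a fully independent service. A minimal, purely hypothetical implementation of DingBotServiceInterface (placeholder credentials, echo-only behaviour) would be registered by adding it to the slice returned from ProvideAllDingBotServices:

type echoBotService struct{} // hypothetical stand-alone service

func (echoBotService) GetServiceCfg() ([]entitys.DingTalkBot, error) {
	// <client-id> and <client-secret> are placeholders; real values come from ai_bot_config
	return []entitys.DingTalkBot{{BotIndex: "echo", ClientId: "<client-id>", ClientSecret: "<client-secret>"}}, nil
}

func (echoBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) {
	// echo the incoming text straight back through the session webhook
	err := chatbot.NewChatbotReplier().SimpleReplyText(ctx, data.SessionWebhook, []byte(data.Text.Content))
	return []byte("success"), err
}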

View File

@ -1,5 +1,12 @@
package server
import "github.com/google/wire"
import (
"github.com/google/wire"
)
var ProviderSetServer = wire.NewSet(NewServers, NewHTTPServer)
var ProviderSetServer = wire.NewSet(
NewServers,
NewHTTPServer,
ProvideAllDingBotServices,
NewDingTalkBotServer,
)

View File

@ -1,13 +1,27 @@
package server
import "github.com/gofiber/fiber/v2"
import (
"ai_scheduler/internal/config"
"github.com/gofiber/fiber/v2"
)
type Servers struct {
HttpServer *fiber.App
cfg *config.Config
HttpServer *fiber.App
DingBotServer *DingTalkBotServer
}
func NewServers(fiber *fiber.App) *Servers {
func NewServers(cfg *config.Config, fiber *fiber.App, DingBotServer *DingTalkBotServer) *Servers {
return &Servers{
HttpServer: fiber,
HttpServer: fiber,
cfg: cfg,
DingBotServer: DingBotServer,
}
}
//func DingBotServerInit(clientId string, clientSecret string, cfg *config.Config, handler *do.Handle, do *do.Do) (cli *client.StreamClient) {
// cli = client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(clientId, clientSecret)))
// cli.RegisterChatBotCallbackRouter(services.NewDingBotService(cfg, handler, do).OnChatBotMessageReceived)
// return
//}

View File

@ -212,7 +212,7 @@ func (s *CallbackService) sendStreamLog(sessionID string, content string) {
}
streamLog := entitys.Response{
Index: constants.BotToolsBugOptimizationSubmit,
Index: string(constants.BotToolsBugOptimizationSubmit),
Content: content,
Type: entitys.ResponseLog,
}
@ -227,7 +227,7 @@ func (s *CallbackService) sendStreamTxt(sessionID string, content string) {
}
streamLog := entitys.Response{
Index: constants.BotToolsBugOptimizationSubmit,
Index: string(constants.BotToolsBugOptimizationSubmit),
Content: content,
Type: entitys.ResponseText,
}
@ -242,7 +242,7 @@ func (s *CallbackService) sendStreamLoading(sessionID string, content string) {
}
streamLog := entitys.Response{
Index: constants.BotToolsBugOptimizationSubmit,
Index: string(constants.BotToolsBugOptimizationSubmit),
Content: content,
Type: entitys.ResponseLoading,
}

View File

@ -0,0 +1,67 @@
package services
import (
"ai_scheduler/internal/biz"
"ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"context"
"fmt"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
)
type DingBotService struct {
config *config.Config
dingTalkBotBiz *biz.DingTalkBotBiz
}
func NewDingBotService(config *config.Config, DingTalkBotBiz *biz.DingTalkBotBiz) *DingBotService {
return &DingBotService{config: config, dingTalkBotBiz: DingTalkBotBiz}
}
func (d *DingBotService) GetServiceCfg() ([]entitys.DingTalkBot, error) {
return d.dingTalkBotBiz.GetDingTalkBotCfgList()
}
func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error) {
requireData, err := d.dingTalkBotBiz.InitRequire(ctx, data)
if err != nil {
return
}
go func() {
defer close(requireData.Ch)
//if match, _err := d.dingTalkBotBiz.Recognize(ctx, data); _err != nil {
// requireData.Ch <- entitys.Response{
// Type: entitys.ResponseEnd,
// Content: fmt.Sprintf("处理消息时出错: %s", _err.Error()),
// }
//}
////向下传递
//if err = d.dingTalkBotBiz.HandleMatch(ctx, nil, requireData); err != nil {
// requireData.Ch <- entitys.Response{
// Type: entitys.ResponseEnd,
// Content: fmt.Sprintf("匹配失败: %v", err),
// }
//}
}()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case resp, ok := <-requireData.Ch:
if !ok {
return []byte("success"), nil // channel closed, processing finished
}
if resp.Type == entitys.ResponseLog {
continue // skip log-type responses instead of ending the handler
}
if err := d.dingTalkBotBiz.HandleRes(ctx, data, resp); err != nil {
return nil, fmt.Errorf("回复失败: %w", err)
}
}
}
return
}
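
The goroutine body above is still commented out, so nothing reaches requireData.Ch beyond the initial log entry written in InitRequire. A hypothetical producer that would feed the consumer loop (the function and its wiring are illustrative only):

func produceEcho(ch chan entitys.Response, data *chatbot.BotCallbackDataModel) {
	defer close(ch) // closing the channel ends the select loop in OnChatBotMessageReceived
	entitys.ResText(ch, "echo", data.Text.Content)
}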

View File

@ -6,4 +6,10 @@ import (
"github.com/google/wire"
)
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService, NewHistoryService)
var ProviderSetServices = wire.NewSet(
NewChatService,
NewSessionService, gateway.NewGateway,
NewTaskService,
NewCallbackService,
NewDingBotService,
NewHistoryService)

View File

@ -59,6 +59,7 @@ func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition {
// Execute 执行知识库查询
func (k *KnowledgeBaseTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
entitys.ResLoading(requireData.Ch, k.Name(), "正在为您搜索相关信息")
return k.chat(requireData)

View File

@ -2,22 +2,18 @@ package tools_bot
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants"
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/utils_ollama"
"context"
"github.com/gofiber/fiber/v2/log"
"context"
)
type BotTool struct {
config *config.Config
llm *utils_ollama.Client
sessionImpl *impl.SessionImpl
taskMap map[string]string // task_id -> session_id
// zltxOrderAfterSaleTool tools.ZltxOrderAfterSaleTool
taskMap map[string]string
}
// NewBotTool creates the bot tool
@ -27,12 +23,5 @@ func NewBotTool(config *config.Config, llm *utils_ollama.Client, sessionImpl *im
// Execute dispatches the requested bot tool
func (w *BotTool) Execute(ctx context.Context, toolName string, requireData *entitys.RequireData) (err error) {
switch toolName {
case constants.BotToolsBugOptimizationSubmit:
err = w.BugOptimizationSubmit(ctx, requireData)
default:
log.Errorf("未知的工具类型:%s", toolName)
err = errors.ParamErr("未知的工具类型:%s", toolName)
}
return
}