结构优化与功能增强

This commit is contained in:
renzhiyuan 2025-11-11 09:26:18 +08:00
parent e30a79b681
commit 7076d6a918
13 changed files with 577 additions and 525 deletions

1
README2.md Normal file

File diff suppressed because one or more lines are too long

View File

@ -11,6 +11,7 @@ import (
"ai_scheduler/internal/server" "ai_scheduler/internal/server"
"ai_scheduler/internal/services" "ai_scheduler/internal/services"
"ai_scheduler/internal/tools" "ai_scheduler/internal/tools"
"ai_scheduler/internal/tools_bot"
"ai_scheduler/utils" "ai_scheduler/utils"
"github.com/gofiber/fiber/v2/log" "github.com/gofiber/fiber/v2/log"
@ -27,6 +28,7 @@ func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), erro
biz.ProviderSetBiz, biz.ProviderSetBiz,
impl.ProviderImpl, impl.ProviderImpl,
utils.ProviderUtils, utils.ProviderUtils,
tools_bot.ProviderSetBotTools,
)) ))
} }

248
internal/biz/do/ctx.go Normal file
View File

@ -0,0 +1,248 @@
package do
import (
"ai_scheduler/internal/config"
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/tmpl/dataTemp"
"context"
"fmt"
"strconv"
"strings"
"time"
"gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/websocket/v2"
"xorm.io/builder"
)
type Do struct {
Ctx *entitys.RequireData
sessionImpl *impl.SessionImpl
sysImpl *impl.SysImpl
taskImpl *impl.TaskImpl
hisImpl *impl.ChatImpl
conf *config.Config
}
func NewDo(
sysImpl *impl.SysImpl,
taskImpl *impl.TaskImpl,
hisImpl *impl.ChatImpl,
conf *config.Config,
) *Do {
return &Do{
conf: conf,
sysImpl: sysImpl,
hisImpl: hisImpl,
taskImpl: taskImpl,
}
}
func (d *Do) InitCtx(req *entitys.ChatSockRequest) *Do {
d.Ctx = &entitys.RequireData{
Req: req,
}
return d
}
func (d *Do) DataAuth(c *websocket.Conn) (err error) {
d.Ctx.Session = c.Query("x-session", "")
if len(d.Ctx.Session) == 0 {
err = errors.SessionNotFound
return
}
d.Ctx.Auth = c.Query("x-authorization", "")
if len(d.Ctx.Auth) == 0 {
err = errors.AuthNotFound
return
}
d.Ctx.Key = c.Query("x-app-key", "")
if len(d.Ctx.Key) == 0 {
err = errors.KeyNotFound
return
}
d.Ctx.Sys, err = d.getSysInfo()
if err != nil {
err = errors.SysErr("获取系统信息失败:%v", err.Error())
return
}
d.Ctx.Histories, err = d.getSessionChatHis()
if err != nil {
err = errors.SysErr("获取历史记录失败:%v", err.Error())
return
}
d.Ctx.Tasks, err = d.getTasks(d.Ctx.Sys.SysID)
if err != nil {
err = errors.SysErr("获取任务列表失败:%v", err.Error())
return
}
if err = d.getImgData(); err != nil {
return
}
return
}
func (d *Do) MakeCh(c *websocket.Conn) (ctx context.Context, deferFunc func()) {
d.Ctx.Ch = make(chan entitys.Response)
ctx, cancel := context.WithCancel(context.Background())
done := d.startMessageHandler(ctx, c, d.hisImpl)
return ctx, func() {
close(d.Ctx.Ch) //关闭主通道
<-done // 等待消息处理完成
cancel()
}
}
func (d *Do) getImgData() (err error) {
if len(d.Ctx.Req.Img) == 0 {
return
}
imgs := strings.Split(d.Ctx.Req.Img, ",")
if len(imgs) == 0 {
return
}
if err = pkg.ValidateImageURL(d.Ctx.Req.Img); err != nil {
return err
}
for k, img := range imgs {
baseErr := "获取第" + strconv.Itoa(k+1) + "张图片失败:"
entitys.ResLog(d.Ctx.Ch, "", "获取第"+strconv.Itoa(k+1)+"张图片")
req := l_request.Request{
Method: "GET",
Url: img,
}
res, _err := req.Send()
if _err != nil {
entitys.ResLog(d.Ctx.Ch, "", baseErr+_err.Error())
continue
}
if _, ex := res.Headers["Content-Type"]; !ex {
entitys.ResLog(d.Ctx.Ch, "", baseErr+"Content-Type不存在")
continue
}
if !strings.HasPrefix(res.Headers["Content-Type"], "image/") {
entitys.ResLog(d.Ctx.Ch, "", baseErr+"expected image content")
continue
}
d.Ctx.ImgByte = append(d.Ctx.ImgByte, res.Content)
d.Ctx.ImgUrls = append(d.Ctx.ImgUrls, img)
}
return
}
func (d *Do) getRequireData() (err error) {
return
}
func (d *Do) getSysInfo() (sysInfo model.AiSy, err error) {
cond := builder.NewCond()
cond = cond.And(builder.Eq{"app_key": d.Ctx.Key})
cond = cond.And(builder.IsNull{"delete_at"})
cond = cond.And(builder.Eq{"status": 1})
err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
return
}
func (d *Do) getSessionChatHis() (his []model.AiChatHi, err error) {
cond := builder.NewCond()
cond = cond.And(builder.Eq{"session_id": d.Ctx.Session})
_, err = d.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: d.conf.Sys.SessionLen}, &his, "his_id desc")
return
}
func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) {
cond := builder.NewCond()
cond = cond.And(builder.Eq{"sys_id": sysId})
cond = cond.And(builder.IsNull{"delete_at"})
cond = cond.And(builder.Eq{"status": 1})
_, err = d.taskImpl.GetListToStruct(&cond, nil, &tasks, "")
return
}
// startMessageHandler 启动独立的消息处理协程
func (d *Do) startMessageHandler(
ctx context.Context,
c *websocket.Conn,
hisImpl *impl.ChatImpl,
) <-chan struct{} {
done := make(chan struct{})
var chat []string
go func() {
defer func() {
close(done)
// 保存历史记录
var his = []*model.AiChatHi{
{
SessionID: d.Ctx.Session,
Role: "user",
Content: d.Ctx.Req.Text, // 用户输入在外部处理
},
}
if len(chat) > 0 {
his = append(his, &model.AiChatHi{
SessionID: d.Ctx.Session,
Role: "assistant",
Content: strings.Join(chat, ""),
})
}
for _, hi := range his {
hisImpl.Add(hi)
}
}()
for v := range d.Ctx.Ch { // 自动检测通道关闭
if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
log.Errorf("Send error: %v", err)
return
}
if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
chat = append(chat, v.Content)
}
}
}()
return done
}
// 辅助函数:带超时的 WebSocket 发送
func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error {
sendCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
done := make(chan error, 1)
go func() {
defer func() {
if r := recover(); r != nil {
done <- fmt.Errorf("panic in MsgSend: %v", r)
}
close(done)
}()
// 如果 MsgSend 阻塞,这里会卡住
err := entitys.MsgSend(c, data)
done <- err
}()
select {
case err := <-done:
return err
case <-sendCtx.Done():
return sendCtx.Err()
}
}

254
internal/biz/do/handle.go Normal file
View File

@ -0,0 +1,254 @@
package do
import (
"ai_scheduler/internal/biz/llm_service"
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants"
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/l_request"
"ai_scheduler/internal/pkg/mapstructure"
"ai_scheduler/internal/tools"
"ai_scheduler/internal/tools_bot"
"context"
"encoding/json"
"fmt"
"strings"
)
type Handle struct {
Ollama *llm_service.OllamaService
toolManager *tools.Manager
Bot *tools_bot.BotTool
conf *config.Config
sessionImpl *impl.SessionImpl
}
func NewHandle(
Ollama *llm_service.OllamaService,
toolManager *tools.Manager,
conf *config.Config,
sessionImpl *impl.SessionImpl,
dTalkBot *tools_bot.BotTool,
) *Handle {
return &Handle{
Ollama: Ollama,
toolManager: toolManager,
conf: conf,
sessionImpl: sessionImpl,
Bot: dTalkBot,
}
}
func (r *Handle) Recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
entitys.ResLog(requireData.Ch, "", "准备意图识别")
//意图识别
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData)
if err != nil {
return
}
entitys.ResLog(requireData.Ch, "", recognizeMsg)
entitys.ResLog(requireData.Ch, "", "意图识别结束")
var match entitys.Match
if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil {
err = errors.SysErr("数据结构错误:%v", err.Error())
return
}
requireData.Match = &match
return
}
func (r *Handle) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
return
}
func (r *Handle) HandleMatch(ctx context.Context, requireData *entitys.RequireData) (err error) {
if !requireData.Match.IsMatch {
if len(requireData.Match.Chat) != 0 {
entitys.ResText(requireData.Ch, "", requireData.Match.Chat)
} else {
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
}
return
}
var pointTask *model.AiTask
for _, task := range requireData.Tasks {
if task.Index == requireData.Match.Index {
pointTask = &task
break
}
}
if pointTask == nil || pointTask.Index == "other" {
return r.OtherTask(ctx, requireData)
}
switch pointTask.Type {
case constants.TaskTypeApi:
return r.handleApiTask(ctx, requireData, pointTask)
case constants.TaskTypeFunc:
return r.handleTask(ctx, requireData, pointTask)
case constants.TaskTypeKnowle:
return r.handleKnowle(ctx, requireData, pointTask)
case constants.TaskTypeBot:
return r.handleBot(ctx, requireData, pointTask)
default:
return r.handleOtherTask(ctx, requireData)
}
}
func (r *Handle) OtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
return
}
func (r *Handle) handleBot(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
var configData entitys.ConfigDataTool
err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil {
return
}
err = r.Bot.Execute(ctx, configData.Tool, requireData)
if err != nil {
return
}
return
}
func (r *Handle) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
var configData entitys.ConfigDataTool
err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil {
return
}
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
if err != nil {
return
}
return
}
// 知识库
func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
var (
configData entitys.ConfigDataTool
sessionIdKnowledge string
query string
host string
)
err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil {
return
}
// 通过session 找到知识库session
var has bool
if len(requireData.Session) == 0 {
return errors.SessionNotFound
}
requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session))
if err != nil {
return
} else if !has {
return errors.SessionNotFound
}
// 找到知识库的host
{
tool, exists := r.toolManager.GetTool(configData.Tool)
if !exists {
return fmt.Errorf("tool not found: %s", configData.Tool)
}
if knowledgeTool, ok := tool.(*tools.KnowledgeBaseTool); !ok {
return fmt.Errorf("未找到知识库Tool: %s", configData.Tool)
} else {
host = knowledgeTool.GetConfig().BaseURL
}
}
// 知识库的session为空请求知识库获取, 并绑定
if requireData.SessionInfo.KnowlegeSessionID == "" {
// 请求知识库
if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil {
return
}
// 绑定知识库session下次可以使用
requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge
if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil {
return
}
}
// 用户输入解析
var ok bool
input := make(map[string]string)
if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil {
return
}
if query, ok = input["query"]; !ok {
return fmt.Errorf("query不能为空")
}
requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{
Session: requireData.SessionInfo.KnowlegeSessionID,
ApiKey: requireData.Sys.KnowlegeTenantKey,
Query: query,
}
// 执行工具
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
if err != nil {
return
}
return
}
func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
var (
request l_request.Request
requestParam map[string]interface{}
)
err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam)
if err != nil {
return
}
request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth)
for k, v := range requestParam {
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
}
var configData entitys.ConfigDataHttp
err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil {
return
}
err = mapstructure.Decode(configData.Request, &request)
if err != nil {
return
}
if len(request.Url) == 0 {
err = errors.NewBusinessErr(422, "api地址获取失败")
return
}
res, err := request.Send()
if err != nil {
return
}
entitys.ResJson(requireData.Ch, "", pkg.JsonStringIgonErr(res.Text))
return
}

View File

@ -2,9 +2,7 @@ package handle
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/tools" "ai_scheduler/internal/tools"
"context"
) )
type Handle struct { type Handle struct {
@ -22,8 +20,3 @@ func NewHandle(
conf: conf, conf: conf,
} }
} }
func (r *Handle) OtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
return
}

View File

@ -10,6 +10,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"strings" "strings"
"time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@ -73,7 +74,7 @@ func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.Requ
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(buildAssistant(requireData.Histories)), Content: "### 聊天记录:" + pkg.JsonStringIgonErr(buildAssistant(requireData.Histories)),
}, api.Message{ }, api.Message{
Role: "user", Role: "user",
Content: requireData.UserInput, Content: requireData.Req.Text,
//Images: requireData.ImgByte, //Images: requireData.ImgByte,
}) })
@ -86,7 +87,7 @@ func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.Requ
imgs.WriteString("### 用户上传图片解析内容:\n") imgs.WriteString("### 用户上传图片解析内容:\n")
prompt = append(prompt, api.Message{ prompt = append(prompt, api.Message{
Role: "user", Role: "image_desc",
Content: "" + desc.Response, Content: "" + desc.Response,
}) })
} }
@ -102,9 +103,10 @@ func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entit
desc, err = r.client.Generation(ctx, &api.GenerateRequest{ desc, err = r.client.Generation(ctx, &api.GenerateRequest{
Model: r.config.Ollama.VlModel, Model: r.config.Ollama.VlModel,
Stream: new(bool), Stream: new(bool),
System: "提取出图片中的文字以及重要信息", System: "完整提取出图片中的文字以及重要信息,并对用户的需求进行预测",
Prompt: requireData.UserInput, Prompt: "完整提取出图片中的文字以及重要信息,并对用户的需求进行预测", //requireData.Req.Text,
Images: requireData.ImgByte, Images: requireData.ImgByte,
KeepAlive: &api.Duration{Duration: 3600 * time.Second},
}) })
if err != nil { if err != nil {
return return

View File

@ -1,6 +1,7 @@
package biz package biz
import ( import (
"ai_scheduler/internal/biz/do"
"ai_scheduler/internal/biz/handle" "ai_scheduler/internal/biz/handle"
"ai_scheduler/internal/biz/llm_service" "ai_scheduler/internal/biz/llm_service"
@ -14,4 +15,6 @@ var ProviderSetBiz = wire.NewSet(
llm_service.NewLangChainGenerate, llm_service.NewLangChainGenerate,
llm_service.NewOllamaGenerate, llm_service.NewOllamaGenerate,
handle.NewHandle, handle.NewHandle,
do.NewDo,
do.NewHandle,
) )

View File

@ -1,507 +1,55 @@
package biz package biz
import ( import (
"ai_scheduler/internal/biz/handle" "ai_scheduler/internal/biz/do"
"ai_scheduler/internal/biz/llm_service"
"ai_scheduler/internal/config" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/data/constants"
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/mapstructure"
"ai_scheduler/internal/tools"
"ai_scheduler/tmpl/dataTemp"
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/fiber/v2/log" "github.com/gofiber/fiber/v2/log"
"github.com/gofiber/websocket/v2" "github.com/gofiber/websocket/v2"
"xorm.io/builder"
) )
// AiRouterBiz 智能路由服务 // AiRouterBiz 智能路由服务
type AiRouterBiz struct { type AiRouterBiz struct {
toolManager *tools.Manager do *do.Do
sessionImpl *impl.SessionImpl handle *do.Handle
sysImpl *impl.SysImpl
taskImpl *impl.TaskImpl
hisImpl *impl.ChatImpl
conf *config.Config
rds *pkg.Rdb
langChain *llm_service.LangChainService
Ollama *llm_service.OllamaService
handle *handle.Handle
} }
// NewRouterService 创建路由服务 // NewAiRouterBiz 创建路由服务
func NewAiRouterBiz( func NewAiRouterBiz(
sessionImpl *impl.SessionImpl, do *do.Do,
sysImpl *impl.SysImpl, handle *do.Handle,
taskImpl *impl.TaskImpl,
hisImpl *impl.ChatImpl,
conf *config.Config,
langChain *llm_service.LangChainService,
Ollama *llm_service.OllamaService,
handle *handle.Handle,
) *AiRouterBiz { ) *AiRouterBiz {
return &AiRouterBiz{ return &AiRouterBiz{
do: do,
handle: handle, handle: handle,
sessionImpl: sessionImpl,
conf: conf,
sysImpl: sysImpl,
hisImpl: hisImpl,
taskImpl: taskImpl,
langChain: langChain,
Ollama: Ollama,
} }
} }
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) { func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
//必要数据验证和获取 //必要数据验证和获取
var requireData entitys.RequireData dos := r.do.InitCtx(req)
err = r.dataAuth(c, &requireData)
if err != nil { //初始化通道/上下文
ctx, clearFunc := dos.MakeCh(c)
defer clearFunc()
//数据验证和收集
if err = dos.DataAuth(c); err != nil {
log.Errorf("数据验证和收集失败: %s", err.Error())
return return
} }
//初始化通道/上下文 //初始化通道/上下文
requireData.Ch = make(chan entitys.Response) if err = r.handle.Recognize(ctx, dos.Ctx); err != nil {
ctx, cancel := context.WithCancel(context.Background()) log.Errorf("意图识别失败: %s", err.Error())
// 启动独立的消息处理协程
done := r.startMessageHandler(ctx, c, &requireData, req.Text)
defer func() {
close(requireData.Ch) //关闭主通道
<-done // 等待消息处理完成
cancel()
}()
//获取图片信息
err = r.getImgData(req.Img, &requireData)
if err != nil {
log.Errorf("GetImgData error: %v", err)
return return
} }
//获取文字信息
err = r.getRequireData(req.Text, &requireData)
if err != nil {
log.Errorf("SQL error: %v", err)
return
}
//意图识别
err = r.recognize(ctx, &requireData)
if err != nil {
log.Errorf("LLM error: %v", err)
return
}
//向下传递 //向下传递
if err = r.handleMatch(ctx, &requireData); err != nil { if err = r.handle.HandleMatch(ctx, dos.Ctx); err != nil {
log.Errorf("Handle error: %v", err) log.Errorf("任务处理失败: %s", err.Error())
return
}
return
}
// startMessageHandler 启动独立的消息处理协程
func (r *AiRouterBiz) startMessageHandler(
ctx context.Context,
c *websocket.Conn,
requireData *entitys.RequireData,
userInput string,
) <-chan struct{} {
done := make(chan struct{})
var chat []string
go func() {
defer func() {
close(done)
// 保存历史记录
var his = []*model.AiChatHi{
{
SessionID: requireData.Session,
Role: "user",
Content: userInput, // 用户输入在外部处理
},
}
if len(chat) > 0 {
his = append(his, &model.AiChatHi{
SessionID: requireData.Session,
Role: "assistant",
Content: strings.Join(chat, ""),
})
}
for _, hi := range his {
r.hisImpl.Add(hi)
}
}()
for v := range requireData.Ch { // 自动检测通道关闭
if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
log.Errorf("Send error: %v", err)
return
}
if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
chat = append(chat, v.Content)
}
}
}()
return done
}
// 辅助函数:带超时的 WebSocket 发送
func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error {
sendCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
done := make(chan error, 1)
go func() {
defer func() {
if r := recover(); r != nil {
done <- fmt.Errorf("panic in MsgSend: %v", r)
}
close(done)
}()
// 如果 MsgSend 阻塞,这里会卡住
err := entitys.MsgSend(c, data)
done <- err
}()
select {
case err := <-done:
return err
case <-sendCtx.Done():
return sendCtx.Err()
}
}
func (r *AiRouterBiz) getImgData(imgUrl string, requireData *entitys.RequireData) (err error) {
if len(imgUrl) == 0 {
return
}
imgs := strings.Split(imgUrl, ",")
if len(imgs) == 0 {
return
}
if err = pkg.ValidateImageURL(imgUrl); err != nil {
return err
}
for k, img := range imgs {
baseErr := "获取第" + strconv.Itoa(k+1) + "张图片失败:"
entitys.ResLog(requireData.Ch, "", "获取第"+strconv.Itoa(k+1)+"张图片")
req := l_request.Request{
Method: "GET",
Url: img,
}
res, _err := req.Send()
if _err != nil {
entitys.ResLog(requireData.Ch, "", baseErr+_err.Error())
continue
}
if _, ex := res.Headers["Content-Type"]; !ex {
entitys.ResLog(requireData.Ch, "", baseErr+"Content-Type不存在")
continue
}
if !strings.HasPrefix(res.Headers["Content-Type"], "image/") {
entitys.ResLog(requireData.Ch, "", baseErr+"expected image content")
continue
}
requireData.ImgByte = append(requireData.ImgByte, res.Content)
requireData.ImgUrls = append(requireData.ImgUrls, img)
}
return
}
func (r *AiRouterBiz) recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
entitys.ResLog(requireData.Ch, "", "准备意图识别")
//意图识别
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData)
if err != nil {
return
}
entitys.ResLog(requireData.Ch, "", recognizeMsg)
entitys.ResLog(requireData.Ch, "", "意图识别结束")
var match entitys.Match
if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil {
err = errors.SysErr("数据结构错误:%v", err.Error())
return
}
requireData.Match = &match
return
}
func (r *AiRouterBiz) getRequireData(userInput string, requireData *entitys.RequireData) (err error) {
requireData.Sys, err = r.getSysInfo(requireData.Key)
if err != nil {
err = errors.SysErr("获取系统信息失败:%v", err.Error())
return
}
requireData.Histories, err = r.getSessionChatHis(requireData.Session)
if err != nil {
err = errors.SysErr("获取历史记录失败:%v", err.Error())
return
}
requireData.Tasks, err = r.getTasks(requireData.Sys.SysID)
if err != nil {
err = errors.SysErr("获取任务列表失败:%v", err.Error())
return
}
requireData.UserInput = userInput
if len(requireData.UserInput) == 0 {
err = errors.SysErr("获取用户输入失败")
return
}
if len(requireData.UserInput) == 0 {
err = errors.SysErr("获取用户输入失败")
return return
} }
return return
} }
func (r *AiRouterBiz) dataAuth(c *websocket.Conn, requireData *entitys.RequireData) (err error) {
requireData.Session = c.Query("x-session", "")
if len(requireData.Session) == 0 {
err = errors.SessionNotFound
return
}
requireData.Auth = c.Query("x-authorization", "")
if len(requireData.Auth) == 0 {
err = errors.AuthNotFound
return
}
requireData.Key = c.Query("x-app-key", "")
if len(requireData.Key) == 0 {
err = errors.KeyNotFound
return
}
return
}
func (r *AiRouterBiz) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
return
}
func (r *AiRouterBiz) handleMatch(ctx context.Context, requireData *entitys.RequireData) (err error) {
if !requireData.Match.IsMatch {
if len(requireData.Match.Chat) != 0 {
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
} else {
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
}
return
}
var pointTask *model.AiTask
for _, task := range requireData.Tasks {
if task.Index == requireData.Match.Index {
pointTask = &task
break
}
}
if pointTask == nil || pointTask.Index == "other" {
return r.handle.OtherTask(ctx, requireData)
}
switch pointTask.Type {
case constants.TaskTypeApi:
return r.handleApiTask(ctx, requireData, pointTask)
case constants.TaskTypeFunc:
return r.handleTask(ctx, requireData, pointTask)
case constants.TaskTypeKnowle:
return r.handleKnowle(ctx, requireData, pointTask)
case constants.TaskTypeBot:
return r.handleBot(ctx, requireData, pointTask)
default:
return r.handleOtherTask(ctx, requireData)
}
}
func (r *AiRouterBiz) handleBot(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
var configData entitys.ConfigDataTool
err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil {
return
}
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
if err != nil {
return
}
return
}
func (r *AiRouterBiz) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
var configData entitys.ConfigDataTool
err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil {
return
}
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
if err != nil {
return
}
return
}
// 知识库
func (r *AiRouterBiz) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
var (
configData entitys.ConfigDataTool
sessionIdKnowledge string
query string
host string
)
err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil {
return
}
// 通过session 找到知识库session
var has bool
if len(requireData.Session) == 0 {
return errors.SessionNotFound
}
requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session))
if err != nil {
return
} else if !has {
return errors.SessionNotFound
}
// 找到知识库的host
{
tool, exists := r.toolManager.GetTool(configData.Tool)
if !exists {
return fmt.Errorf("tool not found: %s", configData.Tool)
}
if knowledgeTool, ok := tool.(*tools.KnowledgeBaseTool); !ok {
return fmt.Errorf("未找到知识库Tool: %s", configData.Tool)
} else {
host = knowledgeTool.GetConfig().BaseURL
}
}
// 知识库的session为空请求知识库获取, 并绑定
if requireData.SessionInfo.KnowlegeSessionID == "" {
// 请求知识库
if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil {
return
}
// 绑定知识库session下次可以使用
requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge
if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil {
return
}
}
// 用户输入解析
var ok bool
input := make(map[string]string)
if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil {
return
}
if query, ok = input["query"]; !ok {
return fmt.Errorf("query不能为空")
}
requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{
Session: requireData.SessionInfo.KnowlegeSessionID,
ApiKey: requireData.Sys.KnowlegeTenantKey,
Query: query,
}
// 执行工具
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
if err != nil {
return
}
return
}
func (r *AiRouterBiz) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
var (
request l_request.Request
requestParam map[string]interface{}
)
err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam)
if err != nil {
return
}
request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth)
for k, v := range requestParam {
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
}
var configData entitys.ConfigDataHttp
err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil {
return
}
err = mapstructure.Decode(configData.Request, &request)
if err != nil {
return
}
if len(request.Url) == 0 {
err = errors.NewBusinessErr(422, "api地址获取失败")
return
}
res, err := request.Send()
if err != nil {
return
}
entitys.ResJson(requireData.Ch, "", pkg.JsonStringIgonErr(res.Text))
return
}
func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
cond := builder.NewCond()
cond = cond.And(builder.Eq{"session_id": sessionId})
_, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his, "his_id desc")
return
}
func (r *AiRouterBiz) getSysInfo(appKey string) (sysInfo model.AiSy, err error) {
cond := builder.NewCond()
cond = cond.And(builder.Eq{"app_key": appKey})
cond = cond.And(builder.IsNull{"delete_at"})
cond = cond.And(builder.Eq{"status": 1})
err = r.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
return
}
func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
cond := builder.NewCond()
cond = cond.And(builder.Eq{"sys_id": sysId})
cond = cond.And(builder.IsNull{"delete_at"})
cond = cond.And(builder.Eq{"status": 1})
_, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks, "")
return
}

View File

@ -0,0 +1,7 @@
package constants
type BotTools string
const (
BotToolsBugOptimizationSubmit = "bug_optimization_submit" // 系统的bug/优化建议
)

View File

@ -147,7 +147,7 @@ type RequireData struct {
SessionInfo model.AiSession SessionInfo model.AiSession
Tasks []model.AiTask Tasks []model.AiTask
Match *Match Match *Match
UserInput string Req *ChatSockRequest
Auth string Auth string
Ch chan Response Ch chan Response
KnowledgeConf KnowledgeBaseRequest KnowledgeConf KnowledgeBaseRequest

View File

@ -164,7 +164,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat
}, },
{ {
Role: "user", Role: "user",
Content: requireData.UserInput, Content: requireData.Req.Text,
}, },
}, w.Name(), "") }, w.Name(), "")
if err != nil { if err != nil {

View File

@ -2,53 +2,38 @@ package tools_bot
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants"
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/utils_ollama" "ai_scheduler/internal/pkg/utils_ollama"
"context" "context"
"github.com/gofiber/fiber/v2/log"
) )
type BotTool struct { type BotTool struct {
config config.ToolConfig config *config.Config
llm *utils_ollama.Client llm *utils_ollama.Client
} }
// NewBotTool 创建直连天下订单详情工具 // NewBotTool 创建直连天下订单详情工具
func NewBotTool(config config.ToolConfig, llm *utils_ollama.Client) *BotTool { func NewBotTool(config *config.Config, llm *utils_ollama.Client) *BotTool {
return &BotTool{config: config, llm: llm} return &BotTool{config: config, llm: llm}
} }
// Name 返回工具名称
func (w *BotTool) Name() string {
return "DingTalkBotTool"
}
// Description 返回工具描述
func (w *BotTool) Description() string {
return "钉钉机器人调用"
}
// Definition 返回工具定义
func (w *BotTool) Definition() entitys.ToolDefinition {
return entitys.ToolDefinition{
Type: "function",
Function: entitys.FunctionDef{
Name: w.Name(),
Description: w.Description(),
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"number": map[string]interface{}{
"type": "string",
"description": "订单编号/流水号",
},
},
"required": []string{"number"},
},
},
}
}
// Execute 执行直连天下订单详情查询 // Execute 执行直连天下订单详情查询
func (w *BotTool) Execute(ctx context.Context, requireData *entitys.RequireData) (err error) { func (w *BotTool) Execute(ctx context.Context, toolName string, requireData *entitys.RequireData) (err error) {
switch toolName {
case constants.BotToolsBugOptimizationSubmit:
err = w.BugOptimizationSubmit(ctx, requireData)
default:
log.Errorf("未知的工具类型:%s", toolName)
err = errors.ParamErr("未知的工具类型:%s", toolName)
}
return
}
func (w *BotTool) BugOptimizationSubmit(ctx context.Context, requireData *entitys.RequireData) (err error) {
return return
} }

View File

@ -0,0 +1,9 @@
package tools_bot
import (
"github.com/google/wire"
)
var ProviderSetBotTools = wire.NewSet(
NewBotTool,
)