结构优化与功能增强
This commit is contained in:
parent
e30a79b681
commit
7076d6a918
File diff suppressed because one or more lines are too long
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"ai_scheduler/internal/server"
|
"ai_scheduler/internal/server"
|
||||||
"ai_scheduler/internal/services"
|
"ai_scheduler/internal/services"
|
||||||
"ai_scheduler/internal/tools"
|
"ai_scheduler/internal/tools"
|
||||||
|
"ai_scheduler/internal/tools_bot"
|
||||||
"ai_scheduler/utils"
|
"ai_scheduler/utils"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/log"
|
"github.com/gofiber/fiber/v2/log"
|
||||||
|
|
@ -27,6 +28,7 @@ func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), erro
|
||||||
biz.ProviderSetBiz,
|
biz.ProviderSetBiz,
|
||||||
impl.ProviderImpl,
|
impl.ProviderImpl,
|
||||||
utils.ProviderUtils,
|
utils.ProviderUtils,
|
||||||
|
tools_bot.ProviderSetBotTools,
|
||||||
))
|
))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,248 @@
|
||||||
|
package do
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
errors "ai_scheduler/internal/data/error"
|
||||||
|
"ai_scheduler/internal/data/impl"
|
||||||
|
"ai_scheduler/internal/data/model"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"ai_scheduler/internal/pkg"
|
||||||
|
"ai_scheduler/tmpl/dataTemp"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
|
"github.com/gofiber/fiber/v2/log"
|
||||||
|
"github.com/gofiber/websocket/v2"
|
||||||
|
|
||||||
|
"xorm.io/builder"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Do struct {
|
||||||
|
Ctx *entitys.RequireData
|
||||||
|
sessionImpl *impl.SessionImpl
|
||||||
|
sysImpl *impl.SysImpl
|
||||||
|
taskImpl *impl.TaskImpl
|
||||||
|
hisImpl *impl.ChatImpl
|
||||||
|
conf *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDo(
|
||||||
|
sysImpl *impl.SysImpl,
|
||||||
|
taskImpl *impl.TaskImpl,
|
||||||
|
hisImpl *impl.ChatImpl,
|
||||||
|
conf *config.Config,
|
||||||
|
) *Do {
|
||||||
|
return &Do{
|
||||||
|
conf: conf,
|
||||||
|
sysImpl: sysImpl,
|
||||||
|
hisImpl: hisImpl,
|
||||||
|
taskImpl: taskImpl,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Do) InitCtx(req *entitys.ChatSockRequest) *Do {
|
||||||
|
d.Ctx = &entitys.RequireData{
|
||||||
|
Req: req,
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Do) DataAuth(c *websocket.Conn) (err error) {
|
||||||
|
d.Ctx.Session = c.Query("x-session", "")
|
||||||
|
if len(d.Ctx.Session) == 0 {
|
||||||
|
err = errors.SessionNotFound
|
||||||
|
return
|
||||||
|
}
|
||||||
|
d.Ctx.Auth = c.Query("x-authorization", "")
|
||||||
|
if len(d.Ctx.Auth) == 0 {
|
||||||
|
err = errors.AuthNotFound
|
||||||
|
return
|
||||||
|
}
|
||||||
|
d.Ctx.Key = c.Query("x-app-key", "")
|
||||||
|
if len(d.Ctx.Key) == 0 {
|
||||||
|
err = errors.KeyNotFound
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
d.Ctx.Sys, err = d.getSysInfo()
|
||||||
|
if err != nil {
|
||||||
|
err = errors.SysErr("获取系统信息失败:%v", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
d.Ctx.Histories, err = d.getSessionChatHis()
|
||||||
|
if err != nil {
|
||||||
|
err = errors.SysErr("获取历史记录失败:%v", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
d.Ctx.Tasks, err = d.getTasks(d.Ctx.Sys.SysID)
|
||||||
|
if err != nil {
|
||||||
|
err = errors.SysErr("获取任务列表失败:%v", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = d.getImgData(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Do) MakeCh(c *websocket.Conn) (ctx context.Context, deferFunc func()) {
|
||||||
|
d.Ctx.Ch = make(chan entitys.Response)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := d.startMessageHandler(ctx, c, d.hisImpl)
|
||||||
|
return ctx, func() {
|
||||||
|
close(d.Ctx.Ch) //关闭主通道
|
||||||
|
<-done // 等待消息处理完成
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Do) getImgData() (err error) {
|
||||||
|
if len(d.Ctx.Req.Img) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
imgs := strings.Split(d.Ctx.Req.Img, ",")
|
||||||
|
if len(imgs) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = pkg.ValidateImageURL(d.Ctx.Req.Img); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for k, img := range imgs {
|
||||||
|
baseErr := "获取第" + strconv.Itoa(k+1) + "张图片失败:"
|
||||||
|
entitys.ResLog(d.Ctx.Ch, "", "获取第"+strconv.Itoa(k+1)+"张图片")
|
||||||
|
req := l_request.Request{
|
||||||
|
Method: "GET",
|
||||||
|
Url: img,
|
||||||
|
}
|
||||||
|
res, _err := req.Send()
|
||||||
|
if _err != nil {
|
||||||
|
entitys.ResLog(d.Ctx.Ch, "", baseErr+_err.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ex := res.Headers["Content-Type"]; !ex {
|
||||||
|
entitys.ResLog(d.Ctx.Ch, "", baseErr+":Content-Type不存在")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(res.Headers["Content-Type"], "image/") {
|
||||||
|
entitys.ResLog(d.Ctx.Ch, "", baseErr+":expected image content")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
d.Ctx.ImgByte = append(d.Ctx.ImgByte, res.Content)
|
||||||
|
d.Ctx.ImgUrls = append(d.Ctx.ImgUrls, img)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Do) getRequireData() (err error) {
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Do) getSysInfo() (sysInfo model.AiSy, err error) {
|
||||||
|
cond := builder.NewCond()
|
||||||
|
cond = cond.And(builder.Eq{"app_key": d.Ctx.Key})
|
||||||
|
cond = cond.And(builder.IsNull{"delete_at"})
|
||||||
|
cond = cond.And(builder.Eq{"status": 1})
|
||||||
|
err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Do) getSessionChatHis() (his []model.AiChatHi, err error) {
|
||||||
|
|
||||||
|
cond := builder.NewCond()
|
||||||
|
cond = cond.And(builder.Eq{"session_id": d.Ctx.Session})
|
||||||
|
|
||||||
|
_, err = d.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: d.conf.Sys.SessionLen}, &his, "his_id desc")
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
||||||
|
|
||||||
|
cond := builder.NewCond()
|
||||||
|
cond = cond.And(builder.Eq{"sys_id": sysId})
|
||||||
|
cond = cond.And(builder.IsNull{"delete_at"})
|
||||||
|
cond = cond.And(builder.Eq{"status": 1})
|
||||||
|
_, err = d.taskImpl.GetListToStruct(&cond, nil, &tasks, "")
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// startMessageHandler 启动独立的消息处理协程
|
||||||
|
func (d *Do) startMessageHandler(
|
||||||
|
ctx context.Context,
|
||||||
|
c *websocket.Conn,
|
||||||
|
hisImpl *impl.ChatImpl,
|
||||||
|
) <-chan struct{} {
|
||||||
|
done := make(chan struct{})
|
||||||
|
var chat []string
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
close(done)
|
||||||
|
// 保存历史记录
|
||||||
|
var his = []*model.AiChatHi{
|
||||||
|
{
|
||||||
|
SessionID: d.Ctx.Session,
|
||||||
|
Role: "user",
|
||||||
|
Content: d.Ctx.Req.Text, // 用户输入在外部处理
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if len(chat) > 0 {
|
||||||
|
his = append(his, &model.AiChatHi{
|
||||||
|
SessionID: d.Ctx.Session,
|
||||||
|
Role: "assistant",
|
||||||
|
Content: strings.Join(chat, ""),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, hi := range his {
|
||||||
|
hisImpl.Add(hi)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for v := range d.Ctx.Ch { // 自动检测通道关闭
|
||||||
|
if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
|
||||||
|
log.Errorf("Send error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
|
||||||
|
chat = append(chat, v.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return done
|
||||||
|
}
|
||||||
|
|
||||||
|
// 辅助函数:带超时的 WebSocket 发送
|
||||||
|
func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error {
|
||||||
|
sendCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
done <- fmt.Errorf("panic in MsgSend: %v", r)
|
||||||
|
}
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
// 如果 MsgSend 阻塞,这里会卡住
|
||||||
|
err := entitys.MsgSend(c, data)
|
||||||
|
done <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
return err
|
||||||
|
case <-sendCtx.Done():
|
||||||
|
return sendCtx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,254 @@
|
||||||
|
package do
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/biz/llm_service"
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/data/constants"
|
||||||
|
errors "ai_scheduler/internal/data/error"
|
||||||
|
"ai_scheduler/internal/data/impl"
|
||||||
|
"ai_scheduler/internal/data/model"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"ai_scheduler/internal/pkg"
|
||||||
|
"ai_scheduler/internal/pkg/l_request"
|
||||||
|
"ai_scheduler/internal/pkg/mapstructure"
|
||||||
|
"ai_scheduler/internal/tools"
|
||||||
|
"ai_scheduler/internal/tools_bot"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Handle struct {
|
||||||
|
Ollama *llm_service.OllamaService
|
||||||
|
toolManager *tools.Manager
|
||||||
|
Bot *tools_bot.BotTool
|
||||||
|
conf *config.Config
|
||||||
|
sessionImpl *impl.SessionImpl
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHandle(
|
||||||
|
Ollama *llm_service.OllamaService,
|
||||||
|
toolManager *tools.Manager,
|
||||||
|
conf *config.Config,
|
||||||
|
sessionImpl *impl.SessionImpl,
|
||||||
|
dTalkBot *tools_bot.BotTool,
|
||||||
|
) *Handle {
|
||||||
|
return &Handle{
|
||||||
|
Ollama: Ollama,
|
||||||
|
toolManager: toolManager,
|
||||||
|
conf: conf,
|
||||||
|
sessionImpl: sessionImpl,
|
||||||
|
Bot: dTalkBot,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Handle) Recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
||||||
|
entitys.ResLog(requireData.Ch, "", "准备意图识别")
|
||||||
|
|
||||||
|
//意图识别
|
||||||
|
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
entitys.ResLog(requireData.Ch, "", recognizeMsg)
|
||||||
|
entitys.ResLog(requireData.Ch, "", "意图识别结束")
|
||||||
|
|
||||||
|
var match entitys.Match
|
||||||
|
if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil {
|
||||||
|
err = errors.SysErr("数据结构错误:%v", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
requireData.Match = &match
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Handle) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
||||||
|
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Handle) HandleMatch(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
||||||
|
|
||||||
|
if !requireData.Match.IsMatch {
|
||||||
|
if len(requireData.Match.Chat) != 0 {
|
||||||
|
entitys.ResText(requireData.Ch, "", requireData.Match.Chat)
|
||||||
|
} else {
|
||||||
|
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var pointTask *model.AiTask
|
||||||
|
for _, task := range requireData.Tasks {
|
||||||
|
if task.Index == requireData.Match.Index {
|
||||||
|
pointTask = &task
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if pointTask == nil || pointTask.Index == "other" {
|
||||||
|
return r.OtherTask(ctx, requireData)
|
||||||
|
}
|
||||||
|
switch pointTask.Type {
|
||||||
|
case constants.TaskTypeApi:
|
||||||
|
return r.handleApiTask(ctx, requireData, pointTask)
|
||||||
|
case constants.TaskTypeFunc:
|
||||||
|
return r.handleTask(ctx, requireData, pointTask)
|
||||||
|
case constants.TaskTypeKnowle:
|
||||||
|
return r.handleKnowle(ctx, requireData, pointTask)
|
||||||
|
case constants.TaskTypeBot:
|
||||||
|
return r.handleBot(ctx, requireData, pointTask)
|
||||||
|
default:
|
||||||
|
return r.handleOtherTask(ctx, requireData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Handle) OtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
||||||
|
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Handle) handleBot(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
||||||
|
var configData entitys.ConfigDataTool
|
||||||
|
err = json.Unmarshal([]byte(task.Config), &configData)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = r.Bot.Execute(ctx, configData.Tool, requireData)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Handle) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
||||||
|
var configData entitys.ConfigDataTool
|
||||||
|
err = json.Unmarshal([]byte(task.Config), &configData)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 知识库
|
||||||
|
func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
||||||
|
|
||||||
|
var (
|
||||||
|
configData entitys.ConfigDataTool
|
||||||
|
sessionIdKnowledge string
|
||||||
|
query string
|
||||||
|
host string
|
||||||
|
)
|
||||||
|
err = json.Unmarshal([]byte(task.Config), &configData)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 通过session 找到知识库session
|
||||||
|
var has bool
|
||||||
|
if len(requireData.Session) == 0 {
|
||||||
|
return errors.SessionNotFound
|
||||||
|
}
|
||||||
|
requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
} else if !has {
|
||||||
|
return errors.SessionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// 找到知识库的host
|
||||||
|
{
|
||||||
|
tool, exists := r.toolManager.GetTool(configData.Tool)
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("tool not found: %s", configData.Tool)
|
||||||
|
}
|
||||||
|
|
||||||
|
if knowledgeTool, ok := tool.(*tools.KnowledgeBaseTool); !ok {
|
||||||
|
return fmt.Errorf("未找到知识库Tool: %s", configData.Tool)
|
||||||
|
} else {
|
||||||
|
host = knowledgeTool.GetConfig().BaseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// 知识库的session为空,请求知识库获取, 并绑定
|
||||||
|
if requireData.SessionInfo.KnowlegeSessionID == "" {
|
||||||
|
// 请求知识库
|
||||||
|
if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 绑定知识库session,下次可以使用
|
||||||
|
requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge
|
||||||
|
if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用户输入解析
|
||||||
|
var ok bool
|
||||||
|
input := make(map[string]string)
|
||||||
|
if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if query, ok = input["query"]; !ok {
|
||||||
|
return fmt.Errorf("query不能为空")
|
||||||
|
}
|
||||||
|
|
||||||
|
requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{
|
||||||
|
Session: requireData.SessionInfo.KnowlegeSessionID,
|
||||||
|
ApiKey: requireData.Sys.KnowlegeTenantKey,
|
||||||
|
Query: query,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行工具
|
||||||
|
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
||||||
|
var (
|
||||||
|
request l_request.Request
|
||||||
|
requestParam map[string]interface{}
|
||||||
|
)
|
||||||
|
err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth)
|
||||||
|
for k, v := range requestParam {
|
||||||
|
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
|
||||||
|
}
|
||||||
|
var configData entitys.ConfigDataHttp
|
||||||
|
err = json.Unmarshal([]byte(task.Config), &configData)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = mapstructure.Decode(configData.Request, &request)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(request.Url) == 0 {
|
||||||
|
err = errors.NewBusinessErr(422, "api地址获取失败")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res, err := request.Send()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
entitys.ResJson(requireData.Ch, "", pkg.JsonStringIgonErr(res.Text))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
@ -2,9 +2,7 @@ package handle
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/entitys"
|
|
||||||
"ai_scheduler/internal/tools"
|
"ai_scheduler/internal/tools"
|
||||||
"context"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Handle struct {
|
type Handle struct {
|
||||||
|
|
@ -22,8 +20,3 @@ func NewHandle(
|
||||||
conf: conf,
|
conf: conf,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Handle) OtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
|
||||||
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
@ -73,7 +74,7 @@ func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.Requ
|
||||||
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(buildAssistant(requireData.Histories)),
|
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(buildAssistant(requireData.Histories)),
|
||||||
}, api.Message{
|
}, api.Message{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: requireData.UserInput,
|
Content: requireData.Req.Text,
|
||||||
//Images: requireData.ImgByte,
|
//Images: requireData.ImgByte,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
@ -86,7 +87,7 @@ func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.Requ
|
||||||
imgs.WriteString("### 用户上传图片解析内容:\n")
|
imgs.WriteString("### 用户上传图片解析内容:\n")
|
||||||
|
|
||||||
prompt = append(prompt, api.Message{
|
prompt = append(prompt, api.Message{
|
||||||
Role: "user",
|
Role: "image_desc",
|
||||||
Content: "" + desc.Response,
|
Content: "" + desc.Response,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -100,11 +101,12 @@ func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entit
|
||||||
entitys.ResLog(requireData.Ch, "", "图片识别中。。。")
|
entitys.ResLog(requireData.Ch, "", "图片识别中。。。")
|
||||||
|
|
||||||
desc, err = r.client.Generation(ctx, &api.GenerateRequest{
|
desc, err = r.client.Generation(ctx, &api.GenerateRequest{
|
||||||
Model: r.config.Ollama.VlModel,
|
Model: r.config.Ollama.VlModel,
|
||||||
Stream: new(bool),
|
Stream: new(bool),
|
||||||
System: "提取出图片中的文字以及重要信息",
|
System: "完整提取出图片中的文字以及重要信息,并对用户的需求进行预测",
|
||||||
Prompt: requireData.UserInput,
|
Prompt: "完整提取出图片中的文字以及重要信息,并对用户的需求进行预测", //requireData.Req.Text,
|
||||||
Images: requireData.ImgByte,
|
Images: requireData.ImgByte,
|
||||||
|
KeepAlive: &api.Duration{Duration: 3600 * time.Second},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package biz
|
package biz
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"ai_scheduler/internal/biz/do"
|
||||||
"ai_scheduler/internal/biz/handle"
|
"ai_scheduler/internal/biz/handle"
|
||||||
"ai_scheduler/internal/biz/llm_service"
|
"ai_scheduler/internal/biz/llm_service"
|
||||||
|
|
||||||
|
|
@ -14,4 +15,6 @@ var ProviderSetBiz = wire.NewSet(
|
||||||
llm_service.NewLangChainGenerate,
|
llm_service.NewLangChainGenerate,
|
||||||
llm_service.NewOllamaGenerate,
|
llm_service.NewOllamaGenerate,
|
||||||
handle.NewHandle,
|
handle.NewHandle,
|
||||||
|
do.NewDo,
|
||||||
|
do.NewHandle,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,507 +1,55 @@
|
||||||
package biz
|
package biz
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/biz/handle"
|
"ai_scheduler/internal/biz/do"
|
||||||
"ai_scheduler/internal/biz/llm_service"
|
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/data/constants"
|
|
||||||
errors "ai_scheduler/internal/data/error"
|
|
||||||
"ai_scheduler/internal/data/impl"
|
|
||||||
"ai_scheduler/internal/data/model"
|
|
||||||
"ai_scheduler/internal/entitys"
|
|
||||||
"ai_scheduler/internal/pkg"
|
|
||||||
"ai_scheduler/internal/pkg/mapstructure"
|
|
||||||
"ai_scheduler/internal/tools"
|
|
||||||
"ai_scheduler/tmpl/dataTemp"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
|
||||||
"github.com/gofiber/fiber/v2/log"
|
"github.com/gofiber/fiber/v2/log"
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
"xorm.io/builder"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// AiRouterBiz 智能路由服务
|
// AiRouterBiz 智能路由服务
|
||||||
type AiRouterBiz struct {
|
type AiRouterBiz struct {
|
||||||
toolManager *tools.Manager
|
do *do.Do
|
||||||
sessionImpl *impl.SessionImpl
|
handle *do.Handle
|
||||||
sysImpl *impl.SysImpl
|
|
||||||
taskImpl *impl.TaskImpl
|
|
||||||
hisImpl *impl.ChatImpl
|
|
||||||
conf *config.Config
|
|
||||||
rds *pkg.Rdb
|
|
||||||
langChain *llm_service.LangChainService
|
|
||||||
Ollama *llm_service.OllamaService
|
|
||||||
handle *handle.Handle
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRouterService 创建路由服务
|
// NewAiRouterBiz 创建路由服务
|
||||||
func NewAiRouterBiz(
|
func NewAiRouterBiz(
|
||||||
sessionImpl *impl.SessionImpl,
|
do *do.Do,
|
||||||
sysImpl *impl.SysImpl,
|
handle *do.Handle,
|
||||||
taskImpl *impl.TaskImpl,
|
|
||||||
hisImpl *impl.ChatImpl,
|
|
||||||
conf *config.Config,
|
|
||||||
langChain *llm_service.LangChainService,
|
|
||||||
Ollama *llm_service.OllamaService,
|
|
||||||
handle *handle.Handle,
|
|
||||||
) *AiRouterBiz {
|
) *AiRouterBiz {
|
||||||
return &AiRouterBiz{
|
return &AiRouterBiz{
|
||||||
handle: handle,
|
do: do,
|
||||||
sessionImpl: sessionImpl,
|
handle: handle,
|
||||||
conf: conf,
|
|
||||||
sysImpl: sysImpl,
|
|
||||||
hisImpl: hisImpl,
|
|
||||||
taskImpl: taskImpl,
|
|
||||||
langChain: langChain,
|
|
||||||
Ollama: Ollama,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
|
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
|
||||||
//必要数据验证和获取
|
//必要数据验证和获取
|
||||||
var requireData entitys.RequireData
|
dos := r.do.InitCtx(req)
|
||||||
err = r.dataAuth(c, &requireData)
|
|
||||||
if err != nil {
|
//初始化通道/上下文
|
||||||
|
ctx, clearFunc := dos.MakeCh(c)
|
||||||
|
defer clearFunc()
|
||||||
|
|
||||||
|
//数据验证和收集
|
||||||
|
if err = dos.DataAuth(c); err != nil {
|
||||||
|
log.Errorf("数据验证和收集失败: %s", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//初始化通道/上下文
|
//初始化通道/上下文
|
||||||
requireData.Ch = make(chan entitys.Response)
|
if err = r.handle.Recognize(ctx, dos.Ctx); err != nil {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
log.Errorf("意图识别失败: %s", err.Error())
|
||||||
// 启动独立的消息处理协程
|
|
||||||
done := r.startMessageHandler(ctx, c, &requireData, req.Text)
|
|
||||||
defer func() {
|
|
||||||
close(requireData.Ch) //关闭主通道
|
|
||||||
<-done // 等待消息处理完成
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
|
|
||||||
//获取图片信息
|
|
||||||
err = r.getImgData(req.Img, &requireData)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("GetImgData error: %v", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//获取文字信息
|
|
||||||
err = r.getRequireData(req.Text, &requireData)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("SQL error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
//意图识别
|
|
||||||
err = r.recognize(ctx, &requireData)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("LLM error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
//向下传递
|
//向下传递
|
||||||
if err = r.handleMatch(ctx, &requireData); err != nil {
|
if err = r.handle.HandleMatch(ctx, dos.Ctx); err != nil {
|
||||||
log.Errorf("Handle error: %v", err)
|
log.Errorf("任务处理失败: %s", err.Error())
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// startMessageHandler 启动独立的消息处理协程
|
|
||||||
func (r *AiRouterBiz) startMessageHandler(
|
|
||||||
ctx context.Context,
|
|
||||||
c *websocket.Conn,
|
|
||||||
requireData *entitys.RequireData,
|
|
||||||
userInput string,
|
|
||||||
) <-chan struct{} {
|
|
||||||
done := make(chan struct{})
|
|
||||||
var chat []string
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer func() {
|
|
||||||
close(done)
|
|
||||||
// 保存历史记录
|
|
||||||
var his = []*model.AiChatHi{
|
|
||||||
{
|
|
||||||
SessionID: requireData.Session,
|
|
||||||
Role: "user",
|
|
||||||
Content: userInput, // 用户输入在外部处理
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if len(chat) > 0 {
|
|
||||||
his = append(his, &model.AiChatHi{
|
|
||||||
SessionID: requireData.Session,
|
|
||||||
Role: "assistant",
|
|
||||||
Content: strings.Join(chat, ""),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
for _, hi := range his {
|
|
||||||
r.hisImpl.Add(hi)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for v := range requireData.Ch { // 自动检测通道关闭
|
|
||||||
if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
|
|
||||||
log.Errorf("Send error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
|
|
||||||
chat = append(chat, v.Content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return done
|
|
||||||
}
|
|
||||||
|
|
||||||
// 辅助函数:带超时的 WebSocket 发送
|
|
||||||
func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error {
|
|
||||||
sendCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
done := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
done <- fmt.Errorf("panic in MsgSend: %v", r)
|
|
||||||
}
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
// 如果 MsgSend 阻塞,这里会卡住
|
|
||||||
err := entitys.MsgSend(c, data)
|
|
||||||
done <- err
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err := <-done:
|
|
||||||
return err
|
|
||||||
case <-sendCtx.Done():
|
|
||||||
return sendCtx.Err()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterBiz) getImgData(imgUrl string, requireData *entitys.RequireData) (err error) {
|
|
||||||
if len(imgUrl) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
imgs := strings.Split(imgUrl, ",")
|
|
||||||
if len(imgs) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err = pkg.ValidateImageURL(imgUrl); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
for k, img := range imgs {
|
|
||||||
baseErr := "获取第" + strconv.Itoa(k+1) + "张图片失败:"
|
|
||||||
entitys.ResLog(requireData.Ch, "", "获取第"+strconv.Itoa(k+1)+"张图片")
|
|
||||||
req := l_request.Request{
|
|
||||||
Method: "GET",
|
|
||||||
Url: img,
|
|
||||||
}
|
|
||||||
res, _err := req.Send()
|
|
||||||
if _err != nil {
|
|
||||||
entitys.ResLog(requireData.Ch, "", baseErr+_err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, ex := res.Headers["Content-Type"]; !ex {
|
|
||||||
entitys.ResLog(requireData.Ch, "", baseErr+":Content-Type不存在")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(res.Headers["Content-Type"], "image/") {
|
|
||||||
entitys.ResLog(requireData.Ch, "", baseErr+":expected image content")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
requireData.ImgByte = append(requireData.ImgByte, res.Content)
|
|
||||||
requireData.ImgUrls = append(requireData.ImgUrls, img)
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterBiz) recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
|
||||||
entitys.ResLog(requireData.Ch, "", "准备意图识别")
|
|
||||||
|
|
||||||
//意图识别
|
|
||||||
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
entitys.ResLog(requireData.Ch, "", recognizeMsg)
|
|
||||||
entitys.ResLog(requireData.Ch, "", "意图识别结束")
|
|
||||||
|
|
||||||
var match entitys.Match
|
|
||||||
if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil {
|
|
||||||
err = errors.SysErr("数据结构错误:%v", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
requireData.Match = &match
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterBiz) getRequireData(userInput string, requireData *entitys.RequireData) (err error) {
|
|
||||||
requireData.Sys, err = r.getSysInfo(requireData.Key)
|
|
||||||
if err != nil {
|
|
||||||
err = errors.SysErr("获取系统信息失败:%v", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
requireData.Histories, err = r.getSessionChatHis(requireData.Session)
|
|
||||||
if err != nil {
|
|
||||||
err = errors.SysErr("获取历史记录失败:%v", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
requireData.Tasks, err = r.getTasks(requireData.Sys.SysID)
|
|
||||||
if err != nil {
|
|
||||||
err = errors.SysErr("获取任务列表失败:%v", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
requireData.UserInput = userInput
|
|
||||||
if len(requireData.UserInput) == 0 {
|
|
||||||
err = errors.SysErr("获取用户输入失败")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(requireData.UserInput) == 0 {
|
|
||||||
err = errors.SysErr("获取用户输入失败")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AiRouterBiz) dataAuth(c *websocket.Conn, requireData *entitys.RequireData) (err error) {
|
|
||||||
requireData.Session = c.Query("x-session", "")
|
|
||||||
if len(requireData.Session) == 0 {
|
|
||||||
err = errors.SessionNotFound
|
|
||||||
return
|
|
||||||
}
|
|
||||||
requireData.Auth = c.Query("x-authorization", "")
|
|
||||||
if len(requireData.Auth) == 0 {
|
|
||||||
err = errors.AuthNotFound
|
|
||||||
return
|
|
||||||
}
|
|
||||||
requireData.Key = c.Query("x-app-key", "")
|
|
||||||
if len(requireData.Key) == 0 {
|
|
||||||
err = errors.KeyNotFound
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterBiz) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
|
||||||
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterBiz) handleMatch(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
|
||||||
|
|
||||||
if !requireData.Match.IsMatch {
|
|
||||||
if len(requireData.Match.Chat) != 0 {
|
|
||||||
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
|
|
||||||
} else {
|
|
||||||
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var pointTask *model.AiTask
|
|
||||||
for _, task := range requireData.Tasks {
|
|
||||||
if task.Index == requireData.Match.Index {
|
|
||||||
pointTask = &task
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if pointTask == nil || pointTask.Index == "other" {
|
|
||||||
return r.handle.OtherTask(ctx, requireData)
|
|
||||||
}
|
|
||||||
switch pointTask.Type {
|
|
||||||
case constants.TaskTypeApi:
|
|
||||||
return r.handleApiTask(ctx, requireData, pointTask)
|
|
||||||
case constants.TaskTypeFunc:
|
|
||||||
return r.handleTask(ctx, requireData, pointTask)
|
|
||||||
case constants.TaskTypeKnowle:
|
|
||||||
return r.handleKnowle(ctx, requireData, pointTask)
|
|
||||||
case constants.TaskTypeBot:
|
|
||||||
return r.handleBot(ctx, requireData, pointTask)
|
|
||||||
default:
|
|
||||||
return r.handleOtherTask(ctx, requireData)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterBiz) handleBot(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
|
||||||
var configData entitys.ConfigDataTool
|
|
||||||
err = json.Unmarshal([]byte(task.Config), &configData)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterBiz) handleTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
|
||||||
var configData entitys.ConfigDataTool
|
|
||||||
err = json.Unmarshal([]byte(task.Config), &configData)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 知识库
|
|
||||||
func (r *AiRouterBiz) handleKnowle(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
|
||||||
|
|
||||||
var (
|
|
||||||
configData entitys.ConfigDataTool
|
|
||||||
sessionIdKnowledge string
|
|
||||||
query string
|
|
||||||
host string
|
|
||||||
)
|
|
||||||
err = json.Unmarshal([]byte(task.Config), &configData)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 通过session 找到知识库session
|
|
||||||
var has bool
|
|
||||||
if len(requireData.Session) == 0 {
|
|
||||||
return errors.SessionNotFound
|
|
||||||
}
|
|
||||||
requireData.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(requireData.Session))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
} else if !has {
|
|
||||||
return errors.SessionNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// 找到知识库的host
|
|
||||||
{
|
|
||||||
tool, exists := r.toolManager.GetTool(configData.Tool)
|
|
||||||
if !exists {
|
|
||||||
return fmt.Errorf("tool not found: %s", configData.Tool)
|
|
||||||
}
|
|
||||||
|
|
||||||
if knowledgeTool, ok := tool.(*tools.KnowledgeBaseTool); !ok {
|
|
||||||
return fmt.Errorf("未找到知识库Tool: %s", configData.Tool)
|
|
||||||
} else {
|
|
||||||
host = knowledgeTool.GetConfig().BaseURL
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// 知识库的session为空,请求知识库获取, 并绑定
|
|
||||||
if requireData.SessionInfo.KnowlegeSessionID == "" {
|
|
||||||
// 请求知识库
|
|
||||||
if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 绑定知识库session,下次可以使用
|
|
||||||
requireData.SessionInfo.KnowlegeSessionID = sessionIdKnowledge
|
|
||||||
if err = r.sessionImpl.Update(&requireData.SessionInfo, r.sessionImpl.WithSessionId(requireData.SessionInfo.SessionID)); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 用户输入解析
|
|
||||||
var ok bool
|
|
||||||
input := make(map[string]string)
|
|
||||||
if err = json.Unmarshal([]byte(requireData.Match.Parameters), &input); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if query, ok = input["query"]; !ok {
|
|
||||||
return fmt.Errorf("query不能为空")
|
|
||||||
}
|
|
||||||
|
|
||||||
requireData.KnowledgeConf = entitys.KnowledgeBaseRequest{
|
|
||||||
Session: requireData.SessionInfo.KnowlegeSessionID,
|
|
||||||
ApiKey: requireData.Sys.KnowlegeTenantKey,
|
|
||||||
Query: query,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 执行工具
|
|
||||||
err = r.toolManager.ExecuteTool(ctx, configData.Tool, requireData)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterBiz) handleApiTask(ctx context.Context, requireData *entitys.RequireData, task *model.AiTask) (err error) {
|
|
||||||
var (
|
|
||||||
request l_request.Request
|
|
||||||
requestParam map[string]interface{}
|
|
||||||
)
|
|
||||||
err = json.Unmarshal([]byte(requireData.Match.Parameters), &requestParam)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth)
|
|
||||||
for k, v := range requestParam {
|
|
||||||
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
|
|
||||||
}
|
|
||||||
var configData entitys.ConfigDataHttp
|
|
||||||
err = json.Unmarshal([]byte(task.Config), &configData)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = mapstructure.Decode(configData.Request, &request)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(request.Url) == 0 {
|
|
||||||
err = errors.NewBusinessErr(422, "api地址获取失败")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
res, err := request.Send()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
entitys.ResJson(requireData.Ch, "", pkg.JsonStringIgonErr(res.Text))
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
|
|
||||||
|
|
||||||
cond := builder.NewCond()
|
|
||||||
cond = cond.And(builder.Eq{"session_id": sessionId})
|
|
||||||
|
|
||||||
_, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his, "his_id desc")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterBiz) getSysInfo(appKey string) (sysInfo model.AiSy, err error) {
|
|
||||||
cond := builder.NewCond()
|
|
||||||
cond = cond.And(builder.Eq{"app_key": appKey})
|
|
||||||
cond = cond.And(builder.IsNull{"delete_at"})
|
|
||||||
cond = cond.And(builder.Eq{"status": 1})
|
|
||||||
err = r.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
|
||||||
|
|
||||||
cond := builder.NewCond()
|
|
||||||
cond = cond.And(builder.Eq{"sys_id": sysId})
|
|
||||||
cond = cond.And(builder.IsNull{"delete_at"})
|
|
||||||
cond = cond.And(builder.Eq{"status": 1})
|
|
||||||
_, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks, "")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
package constants
|
||||||
|
|
||||||
|
type BotTools string
|
||||||
|
|
||||||
|
const (
|
||||||
|
BotToolsBugOptimizationSubmit = "bug_optimization_submit" // 系统的bug/优化建议
|
||||||
|
)
|
||||||
|
|
@ -147,7 +147,7 @@ type RequireData struct {
|
||||||
SessionInfo model.AiSession
|
SessionInfo model.AiSession
|
||||||
Tasks []model.AiTask
|
Tasks []model.AiTask
|
||||||
Match *Match
|
Match *Match
|
||||||
UserInput string
|
Req *ChatSockRequest
|
||||||
Auth string
|
Auth string
|
||||||
Ch chan Response
|
Ch chan Response
|
||||||
KnowledgeConf KnowledgeBaseRequest
|
KnowledgeConf KnowledgeBaseRequest
|
||||||
|
|
|
||||||
|
|
@ -164,7 +164,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: requireData.UserInput,
|
Content: requireData.Req.Text,
|
||||||
},
|
},
|
||||||
}, w.Name(), "")
|
}, w.Name(), "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -2,53 +2,38 @@ package tools_bot
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/data/constants"
|
||||||
|
errors "ai_scheduler/internal/data/error"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/pkg/utils_ollama"
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type BotTool struct {
|
type BotTool struct {
|
||||||
config config.ToolConfig
|
config *config.Config
|
||||||
llm *utils_ollama.Client
|
llm *utils_ollama.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBotTool 创建直连天下订单详情工具
|
// NewBotTool 创建直连天下订单详情工具
|
||||||
func NewBotTool(config config.ToolConfig, llm *utils_ollama.Client) *BotTool {
|
func NewBotTool(config *config.Config, llm *utils_ollama.Client) *BotTool {
|
||||||
return &BotTool{config: config, llm: llm}
|
return &BotTool{config: config, llm: llm}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Name 返回工具名称
|
|
||||||
func (w *BotTool) Name() string {
|
|
||||||
return "DingTalkBotTool"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Description 返回工具描述
|
|
||||||
func (w *BotTool) Description() string {
|
|
||||||
return "钉钉机器人调用"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Definition 返回工具定义
|
|
||||||
func (w *BotTool) Definition() entitys.ToolDefinition {
|
|
||||||
return entitys.ToolDefinition{
|
|
||||||
Type: "function",
|
|
||||||
Function: entitys.FunctionDef{
|
|
||||||
Name: w.Name(),
|
|
||||||
Description: w.Description(),
|
|
||||||
Parameters: map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"number": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "订单编号/流水号",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": []string{"number"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute 执行直连天下订单详情查询
|
// Execute 执行直连天下订单详情查询
|
||||||
func (w *BotTool) Execute(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
func (w *BotTool) Execute(ctx context.Context, toolName string, requireData *entitys.RequireData) (err error) {
|
||||||
|
switch toolName {
|
||||||
|
case constants.BotToolsBugOptimizationSubmit:
|
||||||
|
err = w.BugOptimizationSubmit(ctx, requireData)
|
||||||
|
default:
|
||||||
|
log.Errorf("未知的工具类型:%s", toolName)
|
||||||
|
err = errors.ParamErr("未知的工具类型:%s", toolName)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *BotTool) BugOptimizationSubmit(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
package tools_bot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/google/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ProviderSetBotTools = wire.NewSet(
|
||||||
|
NewBotTool,
|
||||||
|
)
|
||||||
Loading…
Reference in New Issue