Merge remote-tracking branch 'origin/v3' into feature/v3-fzy
This commit is contained in:
commit
42bc6ffd87
|
|
@ -2,7 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
|
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
func main() {
|
||||
configPath := flag.String("config", "./config/config_test.yaml", "Path to configuration file")
|
||||
onBot := flag.String("bot", "", "bot start")
|
||||
flag.Parse()
|
||||
bc, err := config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
|
|
@ -23,8 +24,8 @@ func main() {
|
|||
}
|
||||
defer func() {
|
||||
cleanup()
|
||||
|
||||
}()
|
||||
app.DingBotServer.Run(context.Background(), *onBot)
|
||||
|
||||
log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port)))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -86,3 +86,9 @@ default_prompt:
|
|||
# 权限配置
|
||||
permissionConfig:
|
||||
permission_url: "http://api.test.user.1688sup.cn:8001/v1/menu/myCodes?systemId="
|
||||
|
||||
|
||||
#ding_talk_bots:
|
||||
# public:
|
||||
# client_id: "dingchg59zwwvmuuvldx",
|
||||
# client_secret: "ZwetAnRiTQobNFVlNrshRagSMAJIFpBAepWkWI7on7Tt_o617KHtTjBLp8fQfplz",
|
||||
|
|
@ -14,8 +14,10 @@ fi
|
|||
|
||||
CONFIG_FILE="config/config.yaml"
|
||||
BRANCH="master"
|
||||
BOT="ALL"
|
||||
if [ "$MODE" = "dev" ]; then
|
||||
CONFIG_FILE="config/config_test.yaml"
|
||||
BOT="zltx"
|
||||
BRANCH="test"
|
||||
fi
|
||||
|
||||
|
|
@ -33,6 +35,6 @@ docker run -itd \
|
|||
-e "OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-http://host.docker.internal:11434}" \
|
||||
-e "MODE=${MODE}" \
|
||||
-p 8090:8090 \
|
||||
"${CONTAINER_NAME}" ./server --config "./${CONFIG_FILE}"
|
||||
"${CONTAINER_NAME}" ./server --config "./${CONFIG_FILE}" --bot "./${BOT}"
|
||||
|
||||
docker logs -f ${CONTAINER_NAME}
|
||||
4
go.mod
4
go.mod
|
|
@ -13,6 +13,7 @@ require (
|
|||
github.com/emirpasic/gods v1.18.1
|
||||
github.com/faabiosr/cachego v0.26.0
|
||||
github.com/fastwego/dingding v1.0.0-beta.4
|
||||
github.com/gabriel-vasile/mimetype v1.4.11
|
||||
github.com/go-kratos/kratos/v2 v2.9.1
|
||||
github.com/go-playground/locales v0.14.1
|
||||
github.com/go-playground/universal-translator v0.18.1
|
||||
|
|
@ -22,6 +23,7 @@ require (
|
|||
github.com/google/uuid v1.6.0
|
||||
github.com/google/wire v0.7.0
|
||||
github.com/ollama/ollama v0.12.7
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||
github.com/redis/go-redis/v9 v9.16.0
|
||||
github.com/spf13/viper v1.17.0
|
||||
github.com/tmc/langchaingo v0.1.13
|
||||
|
|
@ -56,9 +58,9 @@ require (
|
|||
github.com/evanphx/json-patch v0.5.2 // indirect
|
||||
github.com/fasthttp/websocket v1.5.3 // indirect
|
||||
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||
github.com/goph/emperror v0.17.2 // indirect
|
||||
github.com/gorilla/websocket v1.5.0 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
|
|
|
|||
8
go.sum
8
go.sum
|
|
@ -174,8 +174,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo
|
|||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
|
||||
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gabriel-vasile/mimetype v1.4.11 h1:AQvxbp830wPhHTqc1u7nzoLT+ZFxGY7emj5DR5DYFik=
|
||||
github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||
github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY=
|
||||
github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ=
|
||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI=
|
||||
|
|
@ -267,6 +267,8 @@ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORR
|
|||
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
||||
github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g=
|
||||
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||
|
|
@ -342,6 +344,8 @@ github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1ls
|
|||
github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
|
||||
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8=
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
|
|
|
|||
|
|
@ -26,10 +26,18 @@ func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl, taskRepo *impl.TaskImpl) *C
|
|||
|
||||
// 查询会话历史
|
||||
func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]entitys.ChatHisQueryResponse, error) {
|
||||
chats, err := s.chatHiRepo.FindAll(
|
||||
|
||||
con := []impl.CondFunc{
|
||||
s.chatHiRepo.WithSessionId(query.SessionID),
|
||||
s.chatHiRepo.PaginateScope(query.Page, query.PageSize),
|
||||
s.chatHiRepo.OrderByDesc("his_id"),
|
||||
}
|
||||
if query.HisID > 0 {
|
||||
con = append(con, s.chatHiRepo.WithHisId(query.HisID))
|
||||
}
|
||||
|
||||
chats, err := s.chatHiRepo.FindAll(
|
||||
con...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -0,0 +1,142 @@
|
|||
package biz
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/biz/do"
|
||||
"ai_scheduler/internal/data/constants"
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg/mapstructure"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// AiRouterBiz 智能路由服务
|
||||
type DingTalkBotBiz struct {
|
||||
do *do.Do
|
||||
handle *do.Handle
|
||||
botConfigImpl *impl.BotConfigImpl
|
||||
replier *chatbot.ChatbotReplier
|
||||
log log.Logger
|
||||
}
|
||||
|
||||
// NewDingTalkBotBiz
|
||||
func NewDingTalkBotBiz(
|
||||
do *do.Do,
|
||||
handle *do.Handle,
|
||||
botConfigImpl *impl.BotConfigImpl,
|
||||
) *DingTalkBotBiz {
|
||||
return &DingTalkBotBiz{
|
||||
do: do,
|
||||
handle: handle,
|
||||
botConfigImpl: botConfigImpl,
|
||||
replier: chatbot.NewChatbotReplier(),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DingTalkBotBiz) GetDingTalkBotCfgList() (dingBotList []entitys.DingTalkBot, err error) {
|
||||
botConfig := make([]model.AiBotConfig, 0)
|
||||
|
||||
cond := builder.NewCond()
|
||||
cond = cond.And(builder.Eq{"status": constants.Enable})
|
||||
cond = cond.And(builder.Eq{"bot_type": constants.BotTypeDingTalk})
|
||||
err = d.botConfigImpl.GetRangeToMapStruct(&cond, &botConfig)
|
||||
for _, v := range botConfig {
|
||||
var config entitys.DingTalkBot
|
||||
err = mapstructure.Decode(v, &config)
|
||||
if err != nil {
|
||||
d.log.Info("初始化“%s”失败:%s", v.BotName, err.Error())
|
||||
}
|
||||
dingBotList = append(dingBotList, config)
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallbackDataModel) (requireData *entitys.RequireDataDingTalkBot, err error) {
|
||||
requireData = &entitys.RequireDataDingTalkBot{
|
||||
Req: data,
|
||||
Ch: make(chan entitys.Response, 2),
|
||||
}
|
||||
entitys.ResLog(requireData.Ch, "recognize_start", "收到消息,正在处理中,请稍等")
|
||||
|
||||
requireData.Sys, err = d.do.GetSysInfoForDingTalkBot(requireData)
|
||||
requireData.Tasks, err = d.do.GetTasks(requireData.Sys.SysID)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *DingTalkBotBiz) Recognize(ctx context.Context, bot *chatbot.BotCallbackDataModel) (match entitys.Match, err error) {
|
||||
|
||||
return d.handle.Recognize(ctx, nil, &do.WithDingTalkBot{})
|
||||
}
|
||||
|
||||
func (d *DingTalkBotBiz) HandleRes(ctx context.Context, data *chatbot.BotCallbackDataModel, resp entitys.Response) error {
|
||||
switch resp.Type {
|
||||
case entitys.ResponseText:
|
||||
return d.replyText(ctx, data.SessionWebhook, resp.Content)
|
||||
case entitys.ResponseStream:
|
||||
return d.replySteam(ctx, data.SessionWebhook, resp.Content)
|
||||
case entitys.ResponseImg:
|
||||
return d.replyImg(ctx, data.SessionWebhook, resp.Content)
|
||||
case entitys.ResponseFile:
|
||||
return d.replyFile(ctx, data.SessionWebhook, resp.Content)
|
||||
case entitys.ResponseMarkdown:
|
||||
return d.replyMarkdown(ctx, data.SessionWebhook, resp.Content)
|
||||
case entitys.ResponseActionCard:
|
||||
return d.replyActionCard(ctx, data.SessionWebhook, resp.Content)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DingTalkBotBiz) replySteam(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||
msg := content
|
||||
if len(arg) > 0 {
|
||||
msg = fmt.Sprintf(content, arg)
|
||||
}
|
||||
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||
}
|
||||
|
||||
func (d *DingTalkBotBiz) replyText(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||
msg := content
|
||||
if len(arg) > 0 {
|
||||
msg = fmt.Sprintf(content, arg)
|
||||
}
|
||||
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||
}
|
||||
|
||||
func (d *DingTalkBotBiz) replyImg(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||
msg := content
|
||||
if len(arg) > 0 {
|
||||
msg = fmt.Sprintf(content, arg)
|
||||
}
|
||||
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||
}
|
||||
|
||||
func (d *DingTalkBotBiz) replyFile(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||
msg := content
|
||||
if len(arg) > 0 {
|
||||
msg = fmt.Sprintf(content, arg)
|
||||
}
|
||||
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||
}
|
||||
|
||||
func (d *DingTalkBotBiz) replyMarkdown(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||
msg := content
|
||||
if len(arg) > 0 {
|
||||
msg = fmt.Sprintf(content, arg)
|
||||
}
|
||||
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||
}
|
||||
|
||||
func (d *DingTalkBotBiz) replyActionCard(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||
msg := content
|
||||
if len(arg) > 0 {
|
||||
msg = fmt.Sprintf(content, arg)
|
||||
}
|
||||
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||
}
|
||||
|
|
@ -79,6 +79,20 @@ func (d *Do) DataAuth(ctx context.Context, client *gateway.Client, requireData *
|
|||
return nil
|
||||
}
|
||||
|
||||
func (d *Do) DataAuthForBot(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) (err error) {
|
||||
|
||||
// 2. 加载系统信息
|
||||
if err = d.loadSystemInfo(ctx, client, requireData); err != nil {
|
||||
return fmt.Errorf("获取系统信息失败: %w", err)
|
||||
}
|
||||
|
||||
// 3. 加载任务列表
|
||||
if err = d.loadTaskList(ctx, client, requireData); err != nil {
|
||||
return fmt.Errorf("获取任务列表失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 提取数据验证为单独函数
|
||||
func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.RequireData) error {
|
||||
requireData.Session = client.GetSession()
|
||||
|
|
@ -102,7 +116,7 @@ func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.Req
|
|||
// 获取系统信息的辅助函数
|
||||
func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error {
|
||||
if sysInfo := client.GetSysInfo(); sysInfo == nil {
|
||||
sys, err := d.getSysInfo(requireData)
|
||||
sys, err := d.GetSysInfo(requireData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -117,7 +131,7 @@ func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, require
|
|||
// 获取任务列表的辅助函数
|
||||
func (d *Do) loadTaskList(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error {
|
||||
if taskInfo := client.GetTasks(); len(taskInfo) == 0 {
|
||||
tasks, err := d.getTasks(requireData.Sys.SysID)
|
||||
tasks, err := d.GetTasks(requireData.Sys.SysID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -200,7 +214,16 @@ func (d *Do) getRequireData() (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (d *Do) getSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) {
|
||||
func (d *Do) GetSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) {
|
||||
cond := builder.NewCond()
|
||||
cond = cond.And(builder.Eq{"app_key": requireData.Key})
|
||||
cond = cond.And(builder.IsNull{"delete_at"})
|
||||
cond = cond.And(builder.Eq{"status": 1})
|
||||
err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Do) GetSysInfoForDingTalkBot(requireData *entitys.RequireDataDingTalkBot) (sysInfo model.AiSy, err error) {
|
||||
cond := builder.NewCond()
|
||||
cond = cond.And(builder.Eq{"app_key": requireData.Key})
|
||||
cond = cond.And(builder.IsNull{"delete_at"})
|
||||
|
|
@ -219,7 +242,7 @@ func (d *Do) getSessionChatHis(requireData *entitys.RequireData) (his []model.Ai
|
|||
return
|
||||
}
|
||||
|
||||
func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
||||
func (d *Do) GetTasks(sysId int32) (tasks []model.AiTask, err error) {
|
||||
|
||||
cond := builder.NewCond()
|
||||
cond = cond.And(builder.Eq{"sys_id": sysId})
|
||||
|
|
|
|||
|
|
@ -45,23 +45,24 @@ func NewHandle(
|
|||
}
|
||||
}
|
||||
|
||||
func (r *Handle) Recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
||||
entitys.ResLog(requireData.Ch, "recognize_start", "准备意图识别")
|
||||
|
||||
func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (match entitys.Match, err error) {
|
||||
entitys.ResLog(rec.Ch, "recognize_start", "准备意图识别")
|
||||
prompt, err := promptProcessor.CreatePrompt(ctx, rec)
|
||||
//意图识别
|
||||
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData)
|
||||
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, &entitys.ToolSelect{
|
||||
Prompt: prompt,
|
||||
Tools: rec.Tasks,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
entitys.ResLog(requireData.Ch, "recognize", recognizeMsg)
|
||||
entitys.ResLog(requireData.Ch, "recognize_end", "意图识别结束")
|
||||
entitys.ResLog(rec.Ch, "recognize", recognizeMsg)
|
||||
entitys.ResLog(rec.Ch, "recognize_end", "意图识别结束")
|
||||
|
||||
var match entitys.Match
|
||||
if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil {
|
||||
err = errors.SysErr("数据结构错误:%v", err.Error())
|
||||
return
|
||||
}
|
||||
requireData.Match = &match
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,126 @@
|
|||
package do
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/biz/handle"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg"
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type PromptOption interface {
|
||||
CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error)
|
||||
}
|
||||
|
||||
type WithSys struct {
|
||||
}
|
||||
|
||||
func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) {
|
||||
var (
|
||||
prompt = make([]api.Message, 0) // 初始化一个空的api.Message切片
|
||||
)
|
||||
// 获取用户内容,如果出错则直接返回错误
|
||||
content, err := f.getUserContent(ctx, rec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 构建提示消息列表,包含系统提示、助手回复和用户内容
|
||||
mes = append(prompt, api.Message{
|
||||
Role: "system", // 系统角色
|
||||
Content: rec.SystemPrompt, // 系统提示内容
|
||||
}, api.Message{
|
||||
Role: "assistant", // 助手角色
|
||||
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(rec.ChatHis), // 助手回复内容
|
||||
}, api.Message{
|
||||
Role: "user", // 用户角色
|
||||
Content: content.String(), // 用户输入内容
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) {
|
||||
var hasFile bool
|
||||
if rec.UserContent.File != nil && len(rec.UserContent.File) > 0 {
|
||||
hasFile = true
|
||||
}
|
||||
content.WriteString(rec.UserContent.Text)
|
||||
if hasFile {
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
if len(rec.UserContent.Tag) > 0 {
|
||||
content.WriteString("\n")
|
||||
content.WriteString("### 工具必须使用:")
|
||||
content.WriteString(rec.UserContent.Tag)
|
||||
}
|
||||
|
||||
if len(rec.ChatHis.Messages) > 0 {
|
||||
content.WriteString("### 引用历史聊天记录:\n")
|
||||
content.WriteString(pkg.JsonStringIgonErr(rec.ChatHis))
|
||||
}
|
||||
|
||||
if hasFile {
|
||||
content.WriteString("\n")
|
||||
content.WriteString("### 文件内容:\n")
|
||||
for _, file := range rec.UserContent.File {
|
||||
handle.HandleRecognizeFile(file)
|
||||
}
|
||||
|
||||
//...do something with file
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type WithDingTalkBot struct {
|
||||
}
|
||||
|
||||
func (f *WithDingTalkBot) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) {
|
||||
var (
|
||||
prompt = make([]api.Message, 0) // 初始化一个空的api.Message切片
|
||||
)
|
||||
// 获取用户内容,如果出错则直接返回错误
|
||||
content, err := f.getUserContent(ctx, rec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 构建提示消息列表,包含系统提示、助手回复和用户内容
|
||||
mes = append(prompt, api.Message{
|
||||
Role: "system", // 系统角色
|
||||
Content: rec.SystemPrompt, // 系统提示内容
|
||||
}, api.Message{
|
||||
Role: "assistant", // 助手角色
|
||||
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(rec.ChatHis), // 助手回复内容
|
||||
}, api.Message{
|
||||
Role: "user", // 用户角色
|
||||
Content: content.String(), // 用户输入内容
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (f *WithDingTalkBot) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) {
|
||||
var hasFile bool
|
||||
if rec.UserContent.File != nil && len(rec.UserContent.File) > 0 {
|
||||
hasFile = true
|
||||
}
|
||||
content.WriteString(rec.UserContent.Text)
|
||||
if hasFile {
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
if len(rec.UserContent.Tag) > 0 {
|
||||
content.WriteString("\n")
|
||||
content.WriteString("### 工具必须使用:")
|
||||
content.WriteString(rec.UserContent.Tag)
|
||||
}
|
||||
|
||||
if len(rec.ChatHis.Messages) > 0 {
|
||||
content.WriteString("### 引用历史聊天记录:\n")
|
||||
content.WriteString(pkg.JsonStringIgonErr(rec.ChatHis))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
package handle
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/data/constants"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg/l_request"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
)
|
||||
|
||||
// HandleRecognizeFile 这里的目的是无论将什么类型的file都转为二进制格式
|
||||
// 判断文件大小
|
||||
// 判断文件类型
|
||||
// 判断文件是否合法
|
||||
func HandleRecognizeFile(files *entitys.RecognizeFile) {
|
||||
//Todo 仲云
|
||||
return
|
||||
}
|
||||
|
||||
// 下载文件并返回二进制数据、MIME 类型
|
||||
func downloadFile(fileUrl string) (fileBytes []byte, contentType string, err error) {
|
||||
if len(fileUrl) == 0 {
|
||||
return
|
||||
}
|
||||
req := l_request.Request{
|
||||
Method: "GET",
|
||||
Url: fileUrl,
|
||||
Headers: map[string]string{
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
|
||||
"Accept": "image/webp,image/apng,image/*,*/*;q=0.8",
|
||||
},
|
||||
}
|
||||
res, err := req.Send()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var ex bool
|
||||
if contentType, ex = res.Headers["Content-Type"]; !ex {
|
||||
err = errors.New("Content-Type不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
err = fmt.Errorf("server returned non-200 status: %d", res.StatusCode)
|
||||
}
|
||||
fileBytes = res.Content
|
||||
|
||||
return fileBytes, contentType, nil
|
||||
}
|
||||
|
||||
// detectFileType 判断文件类型
|
||||
func detectFileType(file io.ReadSeeker, filename string) constants.FileType {
|
||||
// 1. 读取文件头检测 MIME
|
||||
buffer := make([]byte, 512)
|
||||
n, _ := file.Read(buffer)
|
||||
file.Seek(0, io.SeekStart) // 重置读取位置
|
||||
|
||||
detectedMIME := mimetype.Detect(buffer[:n]).String()
|
||||
for fileType, items := range constants.FileTypeMappings {
|
||||
for _, item := range items {
|
||||
if !strings.HasPrefix(item, ".") && item == detectedMIME {
|
||||
return fileType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 备用:通过扩展名检测
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
for fileType, items := range constants.FileTypeMappings {
|
||||
for _, item := range items {
|
||||
if strings.HasPrefix(item, ".") && item == ext {
|
||||
return fileType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return constants.FileTypeUnknown
|
||||
}
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
package handle
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/tools"
|
||||
)
|
||||
|
||||
type Handle struct {
|
||||
toolManager *tools.Manager
|
||||
conf *config.Config
|
||||
}
|
||||
|
||||
func NewHandle(
|
||||
toolManager *tools.Manager,
|
||||
conf *config.Config,
|
||||
|
||||
) *Handle {
|
||||
return &Handle{
|
||||
toolManager: toolManager,
|
||||
conf: conf,
|
||||
}
|
||||
}
|
||||
|
|
@ -1,9 +1,7 @@
|
|||
package llm_service
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/data/constants"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
|
@ -20,48 +18,3 @@ func buildSystemPrompt(prompt string) string {
|
|||
|
||||
return prompt
|
||||
}
|
||||
|
||||
func buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
|
||||
for _, item := range his {
|
||||
if len(chatHis.SessionId) == 0 {
|
||||
chatHis.SessionId = item.SessionID
|
||||
}
|
||||
chatHis.Messages = append(chatHis.Messages, []entitys.HisMessage{
|
||||
{
|
||||
Role: constants.RoleUser,
|
||||
Content: item.Ques,
|
||||
Timestamp: item.CreateAt.Format(time.DateTime),
|
||||
},
|
||||
{
|
||||
Role: constants.RoleAssistant,
|
||||
Content: item.Ans,
|
||||
Timestamp: item.CreateAt.Format(time.DateTime),
|
||||
},
|
||||
}...)
|
||||
}
|
||||
chatHis.Context = entitys.HisContext{
|
||||
UserLanguage: "zh-CN",
|
||||
SystemMode: "technical_support",
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func BuildChatHisMessage(his []model.AiChatHi) (chatHis []entitys.HisMessage) {
|
||||
for _, item := range his {
|
||||
|
||||
chatHis = append(chatHis, []entitys.HisMessage{
|
||||
{
|
||||
Role: constants.RoleUser,
|
||||
Content: item.Ques,
|
||||
Timestamp: item.CreateAt.Format(time.DateTime),
|
||||
},
|
||||
{
|
||||
Role: constants.RoleAssistant,
|
||||
Content: item.Ans,
|
||||
Timestamp: item.CreateAt.Format(time.DateTime),
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,87 +1,76 @@
|
|||
package llm_service
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg"
|
||||
"ai_scheduler/internal/pkg/utils_langchain"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/tmc/langchaingo/llms"
|
||||
)
|
||||
|
||||
type LangChainService struct {
|
||||
client *utils_langchain.UtilLangChain
|
||||
}
|
||||
|
||||
func NewLangChainGenerate(
|
||||
client *utils_langchain.UtilLangChain,
|
||||
) *LangChainService {
|
||||
|
||||
return &LangChainService{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) {
|
||||
prompt := r.getPrompt(sysInfo, history, userInput, tasks)
|
||||
AgentClient := r.client.Get()
|
||||
defer r.client.Put(AgentClient)
|
||||
match, err := AgentClient.Llm.GenerateContent(
|
||||
ctx, // 使用可取消的上下文
|
||||
prompt,
|
||||
llms.WithJSONMode(),
|
||||
)
|
||||
msg = match.Choices[0].Content
|
||||
return
|
||||
}
|
||||
|
||||
func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
||||
var (
|
||||
prompt = make([]llms.MessageContent, 0)
|
||||
)
|
||||
prompt = append(prompt, llms.MessageContent{
|
||||
Role: llms.ChatMessageTypeSystem,
|
||||
Parts: []llms.ContentPart{
|
||||
llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)),
|
||||
},
|
||||
}, llms.MessageContent{
|
||||
Role: llms.ChatMessageTypeTool,
|
||||
Parts: []llms.ContentPart{
|
||||
llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))),
|
||||
},
|
||||
}, llms.MessageContent{
|
||||
Role: llms.ChatMessageTypeTool,
|
||||
Parts: []llms.ContentPart{
|
||||
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
|
||||
},
|
||||
}, llms.MessageContent{
|
||||
Role: llms.ChatMessageTypeHuman,
|
||||
Parts: []llms.ContentPart{
|
||||
llms.TextPart(reqInput),
|
||||
},
|
||||
})
|
||||
return prompt
|
||||
}
|
||||
|
||||
func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool {
|
||||
taskPrompt := make([]llms.Tool, 0)
|
||||
for _, task := range tasks {
|
||||
var taskConfig entitys.TaskConfig
|
||||
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
taskPrompt = append(taskPrompt, llms.Tool{
|
||||
Type: "function",
|
||||
Function: &llms.FunctionDefinition{
|
||||
Name: task.Index,
|
||||
Description: task.Desc,
|
||||
Parameters: taskConfig.Param,
|
||||
},
|
||||
})
|
||||
|
||||
}
|
||||
return taskPrompt
|
||||
}
|
||||
//type LangChainService struct {
|
||||
// client *utils_langchain.UtilLangChain
|
||||
//}
|
||||
//
|
||||
//func NewLangChainGenerate(
|
||||
// client *utils_langchain.UtilLangChain,
|
||||
//) *LangChainService {
|
||||
//
|
||||
// return &LangChainService{
|
||||
// client: client,
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) {
|
||||
// prompt := r.getPrompt(sysInfo, history, userInput, tasks)
|
||||
// AgentClient := r.client.Get()
|
||||
// defer r.client.Put(AgentClient)
|
||||
// match, err := AgentClient.Llm.GenerateContent(
|
||||
// ctx, // 使用可取消的上下文
|
||||
// prompt,
|
||||
// llms.WithJSONMode(),
|
||||
// )
|
||||
// msg = match.Choices[0].Content
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
||||
// var (
|
||||
// prompt = make([]llms.MessageContent, 0)
|
||||
// )
|
||||
// prompt = append(prompt, llms.MessageContent{
|
||||
// Role: llms.ChatMessageTypeSystem,
|
||||
// Parts: []llms.ContentPart{
|
||||
// llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)),
|
||||
// },
|
||||
// }, llms.MessageContent{
|
||||
// Role: llms.ChatMessageTypeTool,
|
||||
// Parts: []llms.ContentPart{
|
||||
// llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))),
|
||||
// },
|
||||
// }, llms.MessageContent{
|
||||
// Role: llms.ChatMessageTypeTool,
|
||||
// Parts: []llms.ContentPart{
|
||||
// llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
|
||||
// },
|
||||
// }, llms.MessageContent{
|
||||
// Role: llms.ChatMessageTypeHuman,
|
||||
// Parts: []llms.ContentPart{
|
||||
// llms.TextPart(reqInput),
|
||||
// },
|
||||
// })
|
||||
// return prompt
|
||||
//}
|
||||
//
|
||||
//func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool {
|
||||
// taskPrompt := make([]llms.Tool, 0)
|
||||
// for _, task := range tasks {
|
||||
// var taskConfig entitys.TaskConfig
|
||||
// err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
||||
// if err != nil {
|
||||
// continue
|
||||
// }
|
||||
// taskPrompt = append(taskPrompt, llms.Tool{
|
||||
// Type: "function",
|
||||
// Function: &llms.FunctionDefinition{
|
||||
// Name: task.Index,
|
||||
// Description: task.Desc,
|
||||
// Parameters: taskConfig.Param,
|
||||
// },
|
||||
// })
|
||||
//
|
||||
// }
|
||||
// return taskPrompt
|
||||
//}
|
||||
|
|
|
|||
|
|
@ -3,19 +3,15 @@ package llm_service
|
|||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg"
|
||||
"ai_scheduler/internal/pkg/utils_ollama"
|
||||
"ai_scheduler/internal/pkg/utils_vllm"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type OllamaService struct {
|
||||
|
|
@ -39,14 +35,10 @@ func NewOllamaGenerate(
|
|||
}
|
||||
}
|
||||
|
||||
func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
|
||||
prompt, err := r.getPrompt(ctx, requireData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
toolDefinitions := r.registerToolsOllama(requireData.Tasks)
|
||||
func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSelect) (msg string, err error) {
|
||||
|
||||
match, err := r.client.ToolSelect(ctx, prompt, toolDefinitions)
|
||||
toolDefinitions := r.registerToolsOllama(req.Tools)
|
||||
match, err := r.client.ToolSelect(ctx, req.Prompt, toolDefinitions)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -70,130 +62,63 @@ func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entity
|
|||
return
|
||||
}
|
||||
|
||||
func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) {
|
||||
//func (r *OllamaService) RecognizeWithImg(ctx context.Context, imgByte []api.ImageData, ch chan entitys.Response) (desc api.GenerateResponse, err error) {
|
||||
// if imgByte == nil {
|
||||
// return
|
||||
// }
|
||||
// entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
|
||||
//
|
||||
// desc, err = r.client.Generation(ctx, &api.GenerateRequest{
|
||||
// Model: r.config.Ollama.VlModel,
|
||||
// Stream: new(bool),
|
||||
// System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
||||
// Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||
// Images: requireData.ImgByte,
|
||||
// KeepAlive: &api.Duration{Duration: 3600 * time.Second},
|
||||
// //Think: &api.ThinkValue{Value: false},
|
||||
// })
|
||||
// if err != nil {
|
||||
// return
|
||||
// }
|
||||
// entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
|
||||
// return
|
||||
//}
|
||||
|
||||
var (
|
||||
prompt = make([]api.Message, 0)
|
||||
)
|
||||
content, err := r.getUserContent(ctx, requireData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prompt = append(prompt, api.Message{
|
||||
Role: "system",
|
||||
Content: buildSystemPrompt(requireData.Sys.SysPrompt),
|
||||
}, api.Message{
|
||||
Role: "assistant",
|
||||
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(buildAssistant(requireData.Histories)),
|
||||
}, api.Message{
|
||||
Role: "user",
|
||||
Content: content,
|
||||
})
|
||||
//func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
|
||||
// if requireData.ImgByte == nil {
|
||||
// return
|
||||
// }
|
||||
// entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
|
||||
//
|
||||
// outMsg, err := r.vllmClient.RecognizeWithImg(ctx,
|
||||
// r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
||||
// r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||
// requireData.ImgUrls,
|
||||
// )
|
||||
// if err != nil {
|
||||
// return api.GenerateResponse{}, err
|
||||
// }
|
||||
//
|
||||
// desc = api.GenerateResponse{
|
||||
// Response: outMsg.Content,
|
||||
// }
|
||||
//
|
||||
// entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
|
||||
// return
|
||||
//}
|
||||
|
||||
return prompt, nil
|
||||
}
|
||||
|
||||
func (r *OllamaService) getUserContent(ctx context.Context, requireData *entitys.RequireData) (string, error) {
|
||||
var content strings.Builder
|
||||
content.WriteString(requireData.Req.Text)
|
||||
if len(requireData.ImgByte) > 0 {
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
if len(requireData.Req.Tags) > 0 {
|
||||
content.WriteString("\n")
|
||||
content.WriteString("### 工具必须使用:")
|
||||
content.WriteString(requireData.Req.Tags)
|
||||
}
|
||||
|
||||
if len(requireData.ImgByte) > 0 {
|
||||
// desc, err := r.RecognizeWithImg(ctx, requireData)
|
||||
desc, err := r.RecognizeWithImgVllm(ctx, requireData)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
content.WriteString("### 上传图片解析内容:\n")
|
||||
content.WriteString(requireData.Req.Tags)
|
||||
content.WriteString(desc.Response)
|
||||
}
|
||||
|
||||
if requireData.Req.MarkHis > 0 {
|
||||
var his model.AiChatHi
|
||||
cond := builder.NewCond()
|
||||
cond = cond.And(builder.Eq{"his_id": requireData.Req.MarkHis})
|
||||
err := r.chatHis.GetOneBySearchToStrut(&cond, &his)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
content.WriteString("### 引用历史聊天记录:\n")
|
||||
content.WriteString(pkg.JsonStringIgonErr(BuildChatHisMessage([]model.AiChatHi{his})))
|
||||
}
|
||||
return content.String(), nil
|
||||
}
|
||||
|
||||
func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
|
||||
if requireData.ImgByte == nil {
|
||||
return
|
||||
}
|
||||
entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
|
||||
|
||||
desc, err = r.client.Generation(ctx, &api.GenerateRequest{
|
||||
Model: r.config.Ollama.VlModel,
|
||||
Stream: new(bool),
|
||||
System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
||||
Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||
Images: requireData.ImgByte,
|
||||
KeepAlive: &api.Duration{Duration: 3600 * time.Second},
|
||||
//Think: &api.ThinkValue{Value: false},
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
|
||||
return
|
||||
}
|
||||
|
||||
func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
|
||||
if requireData.ImgByte == nil {
|
||||
return
|
||||
}
|
||||
entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
|
||||
|
||||
outMsg, err := r.vllmClient.RecognizeWithImg(ctx,
|
||||
r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
||||
r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||
requireData.ImgUrls,
|
||||
)
|
||||
if err != nil {
|
||||
return api.GenerateResponse{}, err
|
||||
}
|
||||
|
||||
desc = api.GenerateResponse{
|
||||
Response: outMsg.Content,
|
||||
}
|
||||
|
||||
entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
|
||||
return
|
||||
}
|
||||
|
||||
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
|
||||
func (r *OllamaService) registerToolsOllama(tasks []entitys.RegistrationTask) []api.Tool {
|
||||
taskPrompt := make([]api.Tool, 0)
|
||||
for _, task := range tasks {
|
||||
var taskConfig entitys.TaskConfigDetail
|
||||
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
taskPrompt = append(taskPrompt, api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: task.Index,
|
||||
Name: task.Name,
|
||||
Description: task.Desc,
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: taskConfig.Param.Type,
|
||||
Required: taskConfig.Param.Required,
|
||||
Properties: taskConfig.Param.Properties,
|
||||
Type: task.TaskConfigDetail.Param.Type,
|
||||
Required: task.TaskConfigDetail.Param.Required,
|
||||
Properties: task.TaskConfigDetail.Param.Properties,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package biz
|
|||
|
||||
import (
|
||||
"ai_scheduler/internal/biz/do"
|
||||
"ai_scheduler/internal/biz/handle"
|
||||
"ai_scheduler/internal/biz/llm_service"
|
||||
|
||||
"github.com/google/wire"
|
||||
|
|
@ -12,10 +11,11 @@ var ProviderSetBiz = wire.NewSet(
|
|||
NewAiRouterBiz,
|
||||
NewSessionBiz,
|
||||
NewChatHistoryBiz,
|
||||
llm_service.NewLangChainGenerate,
|
||||
//llm_service.NewLangChainGenerate,
|
||||
llm_service.NewOllamaGenerate,
|
||||
handle.NewHandle,
|
||||
//handle.NewHandle,
|
||||
do.NewDo,
|
||||
do.NewHandle,
|
||||
NewTaskBiz,
|
||||
NewDingTalkBotBiz,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package biz
|
|||
import (
|
||||
"ai_scheduler/internal/biz/do"
|
||||
"ai_scheduler/internal/gateway"
|
||||
"context"
|
||||
|
||||
"ai_scheduler/internal/entitys"
|
||||
|
||||
|
|
@ -54,9 +55,15 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS
|
|||
log.Errorf("数据验证和收集失败: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
//组装意图识别
|
||||
rec, err := r.SetRec(ctx, requireData)
|
||||
if err != nil {
|
||||
log.Errorf("组装意图识别失败: %s", err.Error())
|
||||
return
|
||||
}
|
||||
//意图识别
|
||||
if err = r.handle.Recognize(ctx, requireData); err != nil {
|
||||
requireData.Match, err = r.handle.Recognize(ctx, &rec, &do.WithSys{})
|
||||
if err != nil {
|
||||
log.Errorf("意图识别失败: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
|
@ -68,3 +75,8 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *AiRouterBiz) SetRec(ctx context.Context, requireData *entitys.RequireData) (match entitys.Recognize, err error) {
|
||||
//TODO 叙平
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,17 +9,17 @@ import (
|
|||
|
||||
// Config 应用配置
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Ollama OllamaConfig `mapstructure:"ollama"`
|
||||
Vllm VllmConfig `mapstructure:"vllm"`
|
||||
Sys SysConfig `mapstructure:"sys"`
|
||||
Tools ToolsConfig `mapstructure:"tools"`
|
||||
Logging LoggingConfig `mapstructure:"logging"`
|
||||
Redis Redis `mapstructure:"redis"`
|
||||
DB DB `mapstructure:"db"`
|
||||
DefaultPrompt SysPrompt `mapstructure:"default_prompt"`
|
||||
PermissionConfig PermissionConfig `mapstructure:"permissionConfig"`
|
||||
// LLM *LLM `mapstructure:"llm"`
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Ollama OllamaConfig `mapstructure:"ollama"`
|
||||
Vllm VllmConfig `mapstructure:"vllm"`
|
||||
Sys SysConfig `mapstructure:"sys"`
|
||||
Tools ToolsConfig `mapstructure:"tools"`
|
||||
Logging LoggingConfig `mapstructure:"logging"`
|
||||
Redis Redis `mapstructure:"redis"`
|
||||
DB DB `mapstructure:"db"`
|
||||
DefaultPrompt SysPrompt `mapstructure:"default_prompt"`
|
||||
PermissionConfig PermissionConfig `mapstructure:"permissionConfig"`
|
||||
// DingTalkBots map[string]*DingTalkBot `mapstructure:"ding_talk_bots"`
|
||||
}
|
||||
|
||||
type SysPrompt struct {
|
||||
|
|
@ -52,18 +52,18 @@ type ServerConfig struct {
|
|||
|
||||
// OllamaConfig Ollama配置
|
||||
type OllamaConfig struct {
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
Model string `mapstructure:"model"`
|
||||
GenerateModel string `mapstructure:"generate_model"`
|
||||
VlModel string `mapstructure:"vl_model"`
|
||||
Timeout time.Duration `mapstructure:"timeout"`
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
Model string `mapstructure:"model"`
|
||||
GenerateModel string `mapstructure:"generate_model"`
|
||||
VlModel string `mapstructure:"vl_model"`
|
||||
Timeout time.Duration `mapstructure:"timeout"`
|
||||
}
|
||||
|
||||
type VllmConfig struct {
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
VlModel string `mapstructure:"vl_model"`
|
||||
Timeout time.Duration `mapstructure:"timeout"`
|
||||
Level string `mapstructure:"level"`
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
VlModel string `mapstructure:"vl_model"`
|
||||
Timeout time.Duration `mapstructure:"timeout"`
|
||||
Level string `mapstructure:"level"`
|
||||
}
|
||||
|
||||
type Redis struct {
|
||||
|
|
|
|||
|
|
@ -3,5 +3,31 @@ package constants
|
|||
type BotTools string
|
||||
|
||||
const (
|
||||
BotToolsBugOptimizationSubmit = "bug_optimization_submit" // 系统的bug/优化建议
|
||||
BotToolsBugOptimizationSubmit BotTools = "bug_optimization_submit" // 系统的bug/优化建议
|
||||
)
|
||||
|
||||
type ChatStyle int
|
||||
|
||||
const (
|
||||
ChatStyleNormal ChatStyle = 1 //正常
|
||||
ChatStyleSerious ChatStyle = 2 //严肃
|
||||
ChatStyleGentle ChatStyle = 3 //温柔
|
||||
ChatStyleArrogance ChatStyle = 4 //傲慢
|
||||
ChatStyleCute ChatStyle = 5 //可爱
|
||||
ChatStyleAngry ChatStyle = 6 //愤怒
|
||||
)
|
||||
|
||||
var ChatStyleMap = map[ChatStyle]string{
|
||||
ChatStyleNormal: "正常",
|
||||
ChatStyleSerious: "严肃",
|
||||
ChatStyleGentle: "温柔",
|
||||
ChatStyleArrogance: "傲慢",
|
||||
ChatStyleCute: "可爱",
|
||||
ChatStyleAngry: "愤怒",
|
||||
}
|
||||
|
||||
type BotType int
|
||||
|
||||
const (
|
||||
BotTypeDingTalk BotType = 1 // 系统的bug/优化建议
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,38 @@
|
|||
package constants
|
||||
|
||||
type FileType string
|
||||
|
||||
const (
|
||||
FileTypeUnknown FileType = "unknown"
|
||||
FileTypeImage FileType = "image"
|
||||
//FileTypeVideo FileType = "video"
|
||||
FileTypeExcel FileType = "excel"
|
||||
FileTypeWord FileType = "word"
|
||||
FileTypeTxt FileType = "txt"
|
||||
FileTypePDF FileType = "pdf"
|
||||
)
|
||||
|
||||
var FileTypeMappings = map[FileType][]string{
|
||||
FileTypeImage: {
|
||||
"image/jpeg", "image/png", "image/gif", "image/webp", "image/svg+xml",
|
||||
".jpg", ".jpeg", ".png", ".gif", ".webp", ".svg",
|
||||
},
|
||||
FileTypeExcel: {
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
".xls", ".xlsx",
|
||||
},
|
||||
FileTypeWord: {
|
||||
"application/msword",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
".doc", ".docx",
|
||||
},
|
||||
FileTypePDF: {
|
||||
"application/pdf",
|
||||
".pdf",
|
||||
},
|
||||
FileTypeTxt: {
|
||||
"text/plain",
|
||||
".txt",
|
||||
},
|
||||
}
|
||||
|
|
@ -22,7 +22,7 @@ BaseModel 是一个泛型结构体,用于封装GORM数据库通用操作。
|
|||
// 定义受支持的PO类型集合(可根据需要扩展), 只有包含表结构才能使用BaseModel,避免使用出现问题
|
||||
type PO interface {
|
||||
model.AiChatHi |
|
||||
model.AiSy | model.AiSession | model.AiTask | model.AiBot
|
||||
model.AiSy | model.AiSession | model.AiTask | model.AiBotConfig
|
||||
}
|
||||
|
||||
type BaseModel[P PO] struct {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
package impl
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/tmpl/dataTemp"
|
||||
"ai_scheduler/utils"
|
||||
)
|
||||
|
||||
type BotChatHisImpl struct {
|
||||
dataTemp.DataTemp
|
||||
}
|
||||
|
||||
func NewBotChatHisImpl(db *utils.Db) *BotChatHisImpl {
|
||||
return &BotChatHisImpl{*dataTemp.NewDataTemp(db, new(model.AiBotChatHi))}
|
||||
}
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
package impl
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/tmpl/dataTemp"
|
||||
"ai_scheduler/utils"
|
||||
)
|
||||
|
||||
type BotConfigImpl struct {
|
||||
dataTemp.DataTemp
|
||||
}
|
||||
|
||||
func NewBotConfigImpl(db *utils.Db) *BotConfigImpl {
|
||||
return &BotConfigImpl{
|
||||
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotConfig)),
|
||||
}
|
||||
}
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
package impl
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/tmpl/dataTemp"
|
||||
"ai_scheduler/utils"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type BotImpl struct {
|
||||
dataTemp.DataTemp
|
||||
BaseRepository[model.AiBot]
|
||||
}
|
||||
|
||||
func NewBotImpl(db *utils.Db) *BotImpl {
|
||||
return &BotImpl{
|
||||
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBot)),
|
||||
BaseRepository: NewBaseModel[model.AiBot](db.Client),
|
||||
}
|
||||
}
|
||||
|
||||
// WithSysId 系统id
|
||||
func (s *BotImpl) WithSysId(sysId interface{}) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("sys_id = ?", sysId)
|
||||
}
|
||||
}
|
||||
|
|
@ -4,4 +4,10 @@ import (
|
|||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatHisImpl)
|
||||
var ProviderImpl = wire.NewSet(
|
||||
NewSessionImpl,
|
||||
NewSysImpl,
|
||||
NewTaskImpl,
|
||||
NewChatHisImpl,
|
||||
NewBotConfigImpl,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,27 @@
|
|||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const TableNameAiBotChatHi = "ai_bot_chat_his"
|
||||
|
||||
// AiBotChatHi mapped from table <ai_bot_chat_his>
|
||||
type AiBotChatHi struct {
|
||||
HisID int64 `gorm:"column:his_id;primaryKey;autoIncrement:true" json:"his_id"`
|
||||
SessionID string `gorm:"column:session_id;not null" json:"session_id"`
|
||||
Role string `gorm:"column:role;not null;comment:system系统输出,assistant助手输出,user用户输入" json:"role"` // system系统输出,assistant助手输出,user用户输入
|
||||
Content string `gorm:"column:content;not null" json:"content"`
|
||||
Files string `gorm:"column:files;not null" json:"files"`
|
||||
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName AiBotChatHi's table name
|
||||
func (*AiBotChatHi) TableName() string {
|
||||
return TableNameAiBotChatHi
|
||||
}
|
||||
|
|
@ -8,13 +8,13 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
const TableNameAiBot = "ai_bot"
|
||||
const TableNameAiBotConfig = "ai_bot_config"
|
||||
|
||||
// AiBot mapped from table <ai_bot>
|
||||
type AiBot struct {
|
||||
// AiBotConfig mapped from table <ai_bot_config>
|
||||
type AiBotConfig struct {
|
||||
BotID int32 `gorm:"column:bot_id;primaryKey;autoIncrement:true" json:"bot_id"`
|
||||
SysID int32 `gorm:"column:sys_id;not null" json:"sys_id"`
|
||||
BotType int32 `gorm:"column:bot_type" json:"bot_type"`
|
||||
BotType int32 `gorm:"column:bot_type;not null;default:1" json:"bot_type"`
|
||||
BotName string `gorm:"column:bot_name;not null" json:"bot_name"`
|
||||
BotConfig string `gorm:"column:bot_config;not null" json:"bot_config"`
|
||||
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
|
||||
|
|
@ -23,7 +23,7 @@ type AiBot struct {
|
|||
DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"`
|
||||
}
|
||||
|
||||
// TableName AiBot's table name
|
||||
func (*AiBot) TableName() string {
|
||||
return TableNameAiBot
|
||||
// TableName AiBotConfig's table name
|
||||
func (*AiBotConfig) TableName() string {
|
||||
return TableNameAiBotConfig
|
||||
}
|
||||
|
|
@ -1,7 +1,30 @@
|
|||
package entitys
|
||||
|
||||
type BotType int
|
||||
import (
|
||||
"ai_scheduler/internal/data/model"
|
||||
|
||||
const (
|
||||
BugAndQuesDingTalk BotType = iota + 1
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
)
|
||||
|
||||
type RequireDataDingTalkBot struct {
|
||||
Session string
|
||||
Key string
|
||||
Sys model.AiSy
|
||||
Histories []model.AiChatHi
|
||||
SessionInfo model.AiSession
|
||||
Tasks []model.AiTask
|
||||
Match *Match
|
||||
Req *chatbot.BotCallbackDataModel
|
||||
Auth string
|
||||
Ch chan Response
|
||||
KnowledgeConf KnowledgeBaseRequest
|
||||
ImgByte []api.ImageData
|
||||
ImgUrls []string
|
||||
}
|
||||
|
||||
type DingTalkBot struct {
|
||||
BotIndex string
|
||||
ClientId string
|
||||
ClientSecret string
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ type ChatHisLog struct {
|
|||
}
|
||||
|
||||
type ChatHistQuery struct {
|
||||
HisID int64 `json:"his_id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
package entitys
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type ToolSelect struct {
|
||||
Prompt []api.Message
|
||||
Tools []RegistrationTask
|
||||
}
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
package entitys
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/data/constants"
|
||||
)
|
||||
|
||||
type Recognize struct {
|
||||
SystemPrompt string
|
||||
UserContent *RecognizeUserContent
|
||||
ChatHis ChatHis
|
||||
Tasks []RegistrationTask
|
||||
Ch chan Response
|
||||
}
|
||||
|
||||
type RegistrationTask struct {
|
||||
Name string
|
||||
Desc string
|
||||
TaskConfigDetail TaskConfigDetail
|
||||
}
|
||||
|
||||
type RecognizeUserContent struct {
|
||||
Text string
|
||||
File []*RecognizeFile
|
||||
ActionCardUrl string
|
||||
Tag string
|
||||
}
|
||||
|
||||
type FileData []byte
|
||||
|
||||
type RecognizeFile struct {
|
||||
File []FileData // 文件数据(二进制格式)
|
||||
FileUrl string // 文件下载链接
|
||||
FileType constants.Caller // 文件类型(文件类型,能填最好填,可以跳过一层判断)
|
||||
}
|
||||
|
|
@ -3,22 +3,25 @@ package entitys
|
|||
import (
|
||||
"ai_scheduler/internal/gateway"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/gofiber/websocket/v2"
|
||||
)
|
||||
|
||||
type ResponseType string
|
||||
|
||||
const (
|
||||
ResponseJson ResponseType = "json"
|
||||
ResponseLoading ResponseType = "loading"
|
||||
ResponseEnd ResponseType = "end"
|
||||
ResponseStream ResponseType = "stream"
|
||||
ResponseText ResponseType = "txt"
|
||||
ResponseImg ResponseType = "img"
|
||||
ResponseFile ResponseType = "file"
|
||||
ResponseErr ResponseType = "error"
|
||||
ResponseLog ResponseType = "log"
|
||||
ResponseAuth ResponseType = "auth"
|
||||
ResponseJson ResponseType = "json"
|
||||
ResponseLoading ResponseType = "loading"
|
||||
ResponseEnd ResponseType = "end"
|
||||
ResponseStream ResponseType = "stream"
|
||||
ResponseText ResponseType = "txt"
|
||||
ResponseImg ResponseType = "img"
|
||||
ResponseFile ResponseType = "file"
|
||||
ResponseErr ResponseType = "error"
|
||||
ResponseLog ResponseType = "log"
|
||||
ResponseAuth ResponseType = "auth"
|
||||
ResponseMarkdown ResponseType = "markdown"
|
||||
ResponseActionCard ResponseType = "actionCard"
|
||||
)
|
||||
|
||||
func ResLog(ch chan Response, index string, content string) {
|
||||
|
|
@ -46,6 +49,9 @@ func ResJson(ch chan Response, index string, content string) {
|
|||
}
|
||||
|
||||
func ResEnd(ch chan Response, index string, content string) {
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
ch <- Response{
|
||||
Index: index,
|
||||
Content: content,
|
||||
|
|
@ -54,6 +60,9 @@ func ResEnd(ch chan Response, index string, content string) {
|
|||
}
|
||||
|
||||
func ResText(ch chan Response, index string, content string) {
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
ch <- Response{
|
||||
Index: index,
|
||||
Content: content,
|
||||
|
|
@ -62,6 +71,9 @@ func ResText(ch chan Response, index string, content string) {
|
|||
}
|
||||
|
||||
func ResLoading(ch chan Response, index string, content string) {
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
ch <- Response{
|
||||
Index: index,
|
||||
Content: content,
|
||||
|
|
@ -69,6 +81,9 @@ func ResLoading(ch chan Response, index string, content string) {
|
|||
}
|
||||
}
|
||||
func ResError(ch chan Response, index string, content string) {
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
ch <- Response{
|
||||
Index: index,
|
||||
Content: content,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package entitys
|
|||
import (
|
||||
"ai_scheduler/internal/data/constants"
|
||||
"ai_scheduler/internal/data/model"
|
||||
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
|
|
@ -151,7 +150,7 @@ type RequireData struct {
|
|||
SessionInfo model.AiSession
|
||||
Tasks []model.AiTask
|
||||
Task model.AiTask
|
||||
Match *Match
|
||||
Match Match
|
||||
Req *ChatSockRequest
|
||||
Auth string
|
||||
Ch chan Response
|
||||
|
|
|
|||
|
|
@ -0,0 +1,75 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/services"
|
||||
"context"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/log"
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
|
||||
)
|
||||
|
||||
type DingBotServiceInterface interface {
|
||||
GetServiceCfg() ([]entitys.DingTalkBot, error)
|
||||
OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error)
|
||||
}
|
||||
|
||||
type DingTalkBotServer struct {
|
||||
Clients map[string]*client.StreamClient
|
||||
}
|
||||
|
||||
// NewDingTalkBotServer 批量注册钉钉客户端cli
|
||||
// 这里支持两种方式,一种是完全独立service,一种是直接用现成的service
|
||||
// 独立的service,在本页的ProvideAllDingBotServices方法进行注册
|
||||
// 现成的service参考services->dtalk_bot.go
|
||||
// 具体使用请根据实际业务需求
|
||||
func NewDingTalkBotServer(
|
||||
services []DingBotServiceInterface,
|
||||
) *DingTalkBotServer {
|
||||
clients := make(map[string]*client.StreamClient)
|
||||
for _, service := range services {
|
||||
serviceConfigs, err := service.GetServiceCfg()
|
||||
for _, serviceConf := range serviceConfigs {
|
||||
if serviceConf.ClientId == "" || serviceConf.ClientSecret == "" {
|
||||
continue
|
||||
}
|
||||
cli := DingBotServerInit(serviceConf.ClientId, serviceConf.ClientSecret, service)
|
||||
if cli == nil {
|
||||
log.Info("%s客户端初始失败:%s", serviceConf.BotIndex, err.Error())
|
||||
continue
|
||||
}
|
||||
clients[serviceConf.BotIndex] = cli
|
||||
}
|
||||
}
|
||||
return &DingTalkBotServer{
|
||||
Clients: clients,
|
||||
}
|
||||
}
|
||||
|
||||
func ProvideAllDingBotServices(
|
||||
dingBotSvc *services.DingBotService,
|
||||
) []DingBotServiceInterface {
|
||||
return []DingBotServiceInterface{dingBotSvc}
|
||||
}
|
||||
|
||||
func (d *DingTalkBotServer) Run(ctx context.Context, botIndex string) {
|
||||
for name, cli := range d.Clients {
|
||||
if botIndex != "All" {
|
||||
if name != botIndex {
|
||||
continue
|
||||
}
|
||||
}
|
||||
err := cli.Start(ctx)
|
||||
if err != nil {
|
||||
log.Info("%s启动失败", name)
|
||||
continue
|
||||
}
|
||||
log.Info("%s启动成功", name)
|
||||
}
|
||||
}
|
||||
func DingBotServerInit(clientId string, clientSecret string, service DingBotServiceInterface) (cli *client.StreamClient) {
|
||||
cli = client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(clientId, clientSecret)))
|
||||
cli.RegisterChatBotCallbackRouter(service.OnChatBotMessageReceived)
|
||||
return
|
||||
}
|
||||
|
|
@ -1,5 +1,12 @@
|
|||
package server
|
||||
|
||||
import "github.com/google/wire"
|
||||
import (
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderSetServer = wire.NewSet(NewServers, NewHTTPServer)
|
||||
var ProviderSetServer = wire.NewSet(
|
||||
NewServers,
|
||||
NewHTTPServer,
|
||||
ProvideAllDingBotServices,
|
||||
NewDingTalkBotServer,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,27 @@
|
|||
package server
|
||||
|
||||
import "github.com/gofiber/fiber/v2"
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type Servers struct {
|
||||
HttpServer *fiber.App
|
||||
cfg *config.Config
|
||||
HttpServer *fiber.App
|
||||
DingBotServer *DingTalkBotServer
|
||||
}
|
||||
|
||||
func NewServers(fiber *fiber.App) *Servers {
|
||||
func NewServers(cfg *config.Config, fiber *fiber.App, DingBotServer *DingTalkBotServer) *Servers {
|
||||
return &Servers{
|
||||
HttpServer: fiber,
|
||||
HttpServer: fiber,
|
||||
cfg: cfg,
|
||||
DingBotServer: DingBotServer,
|
||||
}
|
||||
}
|
||||
|
||||
//func DingBotServerInit(clientId string, clientSecret string, cfg *config.Config, handler *do.Handle, do *do.Do) (cli *client.StreamClient) {
|
||||
// cli = client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(clientId, clientSecret)))
|
||||
// cli.RegisterChatBotCallbackRouter(services.NewDingBotService(cfg, handler, do).OnChatBotMessageReceived)
|
||||
// return
|
||||
//}
|
||||
|
|
|
|||
|
|
@ -212,7 +212,7 @@ func (s *CallbackService) sendStreamLog(sessionID string, content string) {
|
|||
}
|
||||
|
||||
streamLog := entitys.Response{
|
||||
Index: constants.BotToolsBugOptimizationSubmit,
|
||||
Index: string(constants.BotToolsBugOptimizationSubmit),
|
||||
Content: content,
|
||||
Type: entitys.ResponseLog,
|
||||
}
|
||||
|
|
@ -227,7 +227,7 @@ func (s *CallbackService) sendStreamTxt(sessionID string, content string) {
|
|||
}
|
||||
|
||||
streamLog := entitys.Response{
|
||||
Index: constants.BotToolsBugOptimizationSubmit,
|
||||
Index: string(constants.BotToolsBugOptimizationSubmit),
|
||||
Content: content,
|
||||
Type: entitys.ResponseText,
|
||||
}
|
||||
|
|
@ -242,7 +242,7 @@ func (s *CallbackService) sendStreamLoading(sessionID string, content string) {
|
|||
}
|
||||
|
||||
streamLog := entitys.Response{
|
||||
Index: constants.BotToolsBugOptimizationSubmit,
|
||||
Index: string(constants.BotToolsBugOptimizationSubmit),
|
||||
Content: content,
|
||||
Type: entitys.ResponseLoading,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,67 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/biz"
|
||||
"ai_scheduler/internal/config"
|
||||
|
||||
"ai_scheduler/internal/entitys"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
)
|
||||
|
||||
type DingBotService struct {
|
||||
config *config.Config
|
||||
|
||||
dingTalkBotBiz *biz.DingTalkBotBiz
|
||||
}
|
||||
|
||||
func NewDingBotService(config *config.Config, DingTalkBotBiz *biz.DingTalkBotBiz) *DingBotService {
|
||||
return &DingBotService{config: config, dingTalkBotBiz: DingTalkBotBiz}
|
||||
}
|
||||
|
||||
func (d *DingBotService) GetServiceCfg() ([]entitys.DingTalkBot, error) {
|
||||
return d.dingTalkBotBiz.GetDingTalkBotCfgList()
|
||||
}
|
||||
|
||||
func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error) {
|
||||
requireData, err := d.dingTalkBotBiz.InitRequire(ctx, data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer close(requireData.Ch)
|
||||
//if match, _err := d.dingTalkBotBiz.Recognize(ctx, data); _err != nil {
|
||||
// requireData.Ch <- entitys.Response{
|
||||
// Type: entitys.ResponseEnd,
|
||||
// Content: fmt.Sprintf("处理消息时出错: %s", _err.Error()),
|
||||
// }
|
||||
//}
|
||||
////向下传递
|
||||
//if err = d.dingTalkBotBiz.HandleMatch(ctx, nil, requireData); err != nil {
|
||||
// requireData.Ch <- entitys.Response{
|
||||
// Type: entitys.ResponseEnd,
|
||||
// Content: fmt.Sprintf("匹配失败: %v", err),
|
||||
// }
|
||||
//}
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case resp, ok := <-requireData.Ch:
|
||||
if !ok {
|
||||
return []byte("success"), nil // 通道关闭,处理完成
|
||||
}
|
||||
|
||||
if resp.Type == entitys.ResponseLog {
|
||||
return
|
||||
}
|
||||
if err := d.dingTalkBotBiz.HandleRes(ctx, data, resp); err != nil {
|
||||
return nil, fmt.Errorf("回复失败: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -6,4 +6,10 @@ import (
|
|||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService, NewHistoryService)
|
||||
var ProviderSetServices = wire.NewSet(
|
||||
NewChatService,
|
||||
NewSessionService, gateway.NewGateway,
|
||||
NewTaskService,
|
||||
NewCallbackService,
|
||||
NewDingBotService,
|
||||
NewHistoryService)
|
||||
|
|
|
|||
|
|
@ -2,22 +2,18 @@ package tools_bot
|
|||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/data/constants"
|
||||
errors "ai_scheduler/internal/data/error"
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg/utils_ollama"
|
||||
"context"
|
||||
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"context"
|
||||
)
|
||||
|
||||
type BotTool struct {
|
||||
config *config.Config
|
||||
llm *utils_ollama.Client
|
||||
sessionImpl *impl.SessionImpl
|
||||
taskMap map[string]string // task_id -> session_id
|
||||
// zltxOrderAfterSaleTool tools.ZltxOrderAfterSaleTool
|
||||
taskMap map[string]string
|
||||
}
|
||||
|
||||
// NewBotTool 创建直连天下订单详情工具
|
||||
|
|
@ -27,12 +23,5 @@ func NewBotTool(config *config.Config, llm *utils_ollama.Client, sessionImpl *im
|
|||
|
||||
// Execute 执行直连天下订单详情查询
|
||||
func (w *BotTool) Execute(ctx context.Context, toolName string, requireData *entitys.RequireData) (err error) {
|
||||
switch toolName {
|
||||
case constants.BotToolsBugOptimizationSubmit:
|
||||
err = w.BugOptimizationSubmit(ctx, requireData)
|
||||
default:
|
||||
log.Errorf("未知的工具类型:%s", toolName)
|
||||
err = errors.ParamErr("未知的工具类型:%s", toolName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue