Compare commits
3 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
f4980f1f22 | |
|
|
d509a18d44 | |
|
|
4339b6eee8 |
|
|
@ -2,7 +2,6 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
|
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
|
@ -23,8 +22,8 @@ func main() {
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
cleanup()
|
cleanup()
|
||||||
|
|
||||||
}()
|
}()
|
||||||
|
//app.DingBotServer.Run(context.Background())
|
||||||
|
//app.DingBotServer.RunBots(app.DingBotServer.BotServices)
|
||||||
log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port)))
|
log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port)))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,12 @@ server:
|
||||||
port: 8090
|
port: 8090
|
||||||
host: "0.0.0.0"
|
host: "0.0.0.0"
|
||||||
|
|
||||||
|
|
||||||
ollama:
|
ollama:
|
||||||
base_url: "http://127.0.0.1:11434"
|
base_url: "http://127.0.0.1:11434"
|
||||||
model: "qwen3-coder:480b-cloud"
|
model: "qwen3-coder:480b-cloud"
|
||||||
generate_model: "qwen3-coder:480b-cloud"
|
generate_model: "qwen3-coder:480b-cloud"
|
||||||
vl_model: "qwen2.5vl:7b"
|
vl_model: "gemini-3-pro-preview"
|
||||||
timeout: "120s"
|
timeout: "120s"
|
||||||
level: "info"
|
level: "info"
|
||||||
format: "json"
|
format: "json"
|
||||||
|
|
@ -19,6 +20,7 @@ sys:
|
||||||
channel_pool_len: 100
|
channel_pool_len: 100
|
||||||
channel_pool_size: 32
|
channel_pool_size: 32
|
||||||
llm_pool_len: 5
|
llm_pool_len: 5
|
||||||
|
heartbeat_interval: 30
|
||||||
redis:
|
redis:
|
||||||
host: 47.97.27.195:6379
|
host: 47.97.27.195:6379
|
||||||
type: node
|
type: node
|
||||||
|
|
@ -79,3 +81,9 @@ default_prompt:
|
||||||
# 权限配置
|
# 权限配置
|
||||||
permissionConfig:
|
permissionConfig:
|
||||||
permission_url: "http://api.test.user.1688sup.cn:8001/v1/menu/myCodes?systemId="
|
permission_url: "http://api.test.user.1688sup.cn:8001/v1/menu/myCodes?systemId="
|
||||||
|
|
||||||
|
|
||||||
|
ding_talk_bots:
|
||||||
|
public:
|
||||||
|
client_id: "dingchg59zwwvmuuvldx",
|
||||||
|
client_secret: "ZwetAnRiTQobNFVlNrshRagSMAJIFpBAepWkWI7on7Tt_o617KHtTjBLp8fQfplz",
|
||||||
11
go.mod
11
go.mod
|
|
@ -1,8 +1,6 @@
|
||||||
module ai_scheduler
|
module ai_scheduler
|
||||||
|
|
||||||
go 1.24.0
|
go 1.24.1
|
||||||
|
|
||||||
toolchain go1.24.7
|
|
||||||
|
|
||||||
require (
|
require (
|
||||||
gitea.cdlsxd.cn/self-tools/l_request v1.0.8
|
gitea.cdlsxd.cn/self-tools/l_request v1.0.8
|
||||||
|
|
@ -13,18 +11,18 @@ require (
|
||||||
github.com/emirpasic/gods v1.18.1
|
github.com/emirpasic/gods v1.18.1
|
||||||
github.com/faabiosr/cachego v0.26.0
|
github.com/faabiosr/cachego v0.26.0
|
||||||
github.com/fastwego/dingding v1.0.0-beta.4
|
github.com/fastwego/dingding v1.0.0-beta.4
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.11
|
||||||
github.com/go-kratos/kratos/v2 v2.9.1
|
github.com/go-kratos/kratos/v2 v2.9.1
|
||||||
github.com/gofiber/fiber/v2 v2.52.9
|
github.com/gofiber/fiber/v2 v2.52.9
|
||||||
github.com/gofiber/websocket/v2 v2.2.1
|
github.com/gofiber/websocket/v2 v2.2.1
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/google/wire v0.7.0
|
github.com/google/wire v0.7.0
|
||||||
github.com/ollama/ollama v0.12.7
|
github.com/ollama/ollama v0.12.7
|
||||||
|
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||||
github.com/redis/go-redis/v9 v9.16.0
|
github.com/redis/go-redis/v9 v9.16.0
|
||||||
github.com/spf13/viper v1.17.0
|
github.com/spf13/viper v1.17.0
|
||||||
github.com/tmc/langchaingo v0.1.13
|
github.com/tmc/langchaingo v0.1.13
|
||||||
google.golang.org/grpc v1.64.0
|
google.golang.org/grpc v1.64.0
|
||||||
google.golang.org/protobuf v1.34.1
|
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
|
||||||
gorm.io/driver/mysql v1.6.0
|
gorm.io/driver/mysql v1.6.0
|
||||||
gorm.io/gorm v1.31.0
|
gorm.io/gorm v1.31.0
|
||||||
xorm.io/builder v0.3.13
|
xorm.io/builder v0.3.13
|
||||||
|
|
@ -46,6 +44,7 @@ require (
|
||||||
github.com/fasthttp/websocket v1.5.3 // indirect
|
github.com/fasthttp/websocket v1.5.3 // indirect
|
||||||
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
||||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||||
|
github.com/gorilla/websocket v1.5.0 // indirect
|
||||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
|
|
@ -82,5 +81,7 @@ require (
|
||||||
golang.org/x/sys v0.31.0 // indirect
|
golang.org/x/sys v0.31.0 // indirect
|
||||||
golang.org/x/text v0.23.0 // indirect
|
golang.org/x/text v0.23.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect
|
||||||
|
google.golang.org/protobuf v1.34.1 // indirect
|
||||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
|
|
||||||
6
go.sum
6
go.sum
|
|
@ -142,6 +142,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo
|
||||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||||
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
|
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
|
||||||
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.11 h1:AQvxbp830wPhHTqc1u7nzoLT+ZFxGY7emj5DR5DYFik=
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||||
github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY=
|
github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY=
|
||||||
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
||||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||||
|
|
@ -217,6 +219,8 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m
|
||||||
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
|
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
|
||||||
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
||||||
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
||||||
|
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||||
|
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||||
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||||
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||||
|
|
@ -273,6 +277,8 @@ github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108
|
||||||
github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0=
|
github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0=
|
||||||
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
|
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
|
||||||
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
|
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
|
||||||
|
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8=
|
||||||
|
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
package biz
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/biz/do"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AiRouterBiz 智能路由服务
|
||||||
|
type DingTalkBotBiz struct {
|
||||||
|
do *do.Do
|
||||||
|
handle *do.Handle
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDingTalkBotBiz
|
||||||
|
func NewDingTalkBotBiz(
|
||||||
|
do *do.Do,
|
||||||
|
handle *do.Handle,
|
||||||
|
) *DingTalkBotBiz {
|
||||||
|
return &DingTalkBotBiz{
|
||||||
|
do: do,
|
||||||
|
handle: handle,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallbackDataModel) (requireData *entitys.RequireDataDingTalkBot, err error) {
|
||||||
|
requireData = &entitys.RequireDataDingTalkBot{
|
||||||
|
Req: data,
|
||||||
|
Ch: make(chan entitys.Response, 2),
|
||||||
|
}
|
||||||
|
entitys.ResLog(requireData.Ch, "recognize_start", "收到消息,正在处理中,请稍等")
|
||||||
|
|
||||||
|
requireData.Sys, err = d.do.GetSysInfoForDingTalkBot(requireData)
|
||||||
|
requireData.Tasks, err = d.do.GetTasks(requireData.Sys.SysID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DingTalkBotBiz) Recognize(ctx context.Context, rec *entitys.Recognize, ch chan entitys.Response) (match *entitys.Match, err error) {
|
||||||
|
return d.handle.RecognizeBot(ctx, rec, ch)
|
||||||
|
}
|
||||||
|
|
@ -81,6 +81,20 @@ func (d *Do) DataAuth(ctx context.Context, client *gateway.Client, requireData *
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Do) DataAuthForBot(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) (err error) {
|
||||||
|
|
||||||
|
// 2. 加载系统信息
|
||||||
|
if err = d.loadSystemInfo(ctx, client, requireData); err != nil {
|
||||||
|
return fmt.Errorf("获取系统信息失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 加载任务列表
|
||||||
|
if err = d.loadTaskList(ctx, client, requireData); err != nil {
|
||||||
|
return fmt.Errorf("获取任务列表失败: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// 提取数据验证为单独函数
|
// 提取数据验证为单独函数
|
||||||
func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.RequireData) error {
|
func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.RequireData) error {
|
||||||
requireData.Session = client.GetSession()
|
requireData.Session = client.GetSession()
|
||||||
|
|
@ -104,7 +118,7 @@ func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.Req
|
||||||
// 获取系统信息的辅助函数
|
// 获取系统信息的辅助函数
|
||||||
func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error {
|
func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error {
|
||||||
if sysInfo := client.GetSysInfo(); sysInfo == nil {
|
if sysInfo := client.GetSysInfo(); sysInfo == nil {
|
||||||
sys, err := d.getSysInfo(requireData)
|
sys, err := d.GetSysInfo(requireData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -119,7 +133,7 @@ func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, require
|
||||||
// 获取任务列表的辅助函数
|
// 获取任务列表的辅助函数
|
||||||
func (d *Do) loadTaskList(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error {
|
func (d *Do) loadTaskList(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error {
|
||||||
if taskInfo := client.GetTasks(); len(taskInfo) == 0 {
|
if taskInfo := client.GetTasks(); len(taskInfo) == 0 {
|
||||||
tasks, err := d.getTasks(requireData.Sys.SysID)
|
tasks, err := d.GetTasks(requireData.Sys.SysID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -202,7 +216,16 @@ func (d *Do) getRequireData() (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Do) getSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) {
|
func (d *Do) GetSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) {
|
||||||
|
cond := builder.NewCond()
|
||||||
|
cond = cond.And(builder.Eq{"app_key": requireData.Key})
|
||||||
|
cond = cond.And(builder.IsNull{"delete_at"})
|
||||||
|
cond = cond.And(builder.Eq{"status": 1})
|
||||||
|
err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Do) GetSysInfoForDingTalkBot(requireData *entitys.RequireDataDingTalkBot) (sysInfo model.AiSy, err error) {
|
||||||
cond := builder.NewCond()
|
cond := builder.NewCond()
|
||||||
cond = cond.And(builder.Eq{"app_key": requireData.Key})
|
cond = cond.And(builder.Eq{"app_key": requireData.Key})
|
||||||
cond = cond.And(builder.IsNull{"delete_at"})
|
cond = cond.And(builder.IsNull{"delete_at"})
|
||||||
|
|
@ -221,7 +244,7 @@ func (d *Do) getSessionChatHis(requireData *entitys.RequireData) (his []model.Ai
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
func (d *Do) GetTasks(sysId int32) (tasks []model.AiTask, err error) {
|
||||||
|
|
||||||
cond := builder.NewCond()
|
cond := builder.NewCond()
|
||||||
cond = cond.And(builder.Eq{"sys_id": sysId})
|
cond = cond.And(builder.Eq{"sys_id": sysId})
|
||||||
|
|
|
||||||
|
|
@ -17,8 +17,9 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"gorm.io/gorm/utils"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Handle struct {
|
type Handle struct {
|
||||||
|
|
@ -45,23 +46,24 @@ func NewHandle(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Handle) Recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (match entitys.Match, err error) {
|
||||||
entitys.ResLog(requireData.Ch, "recognize_start", "准备意图识别")
|
entitys.ResLog(rec.Ch, "recognize_start", "准备意图识别")
|
||||||
|
prompt, err := promptProcessor.CreatePrompt(ctx, rec)
|
||||||
//意图识别
|
//意图识别
|
||||||
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, requireData)
|
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, &entitys.ToolSelect{
|
||||||
|
Prompt: prompt,
|
||||||
|
Tools: rec.Tasks,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
entitys.ResLog(requireData.Ch, "recognize", recognizeMsg)
|
entitys.ResLog(rec.Ch, "recognize", recognizeMsg)
|
||||||
entitys.ResLog(requireData.Ch, "recognize_end", "意图识别结束")
|
entitys.ResLog(rec.Ch, "recognize_end", "意图识别结束")
|
||||||
|
|
||||||
var match entitys.Match
|
|
||||||
if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil {
|
if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil {
|
||||||
err = errors.SysErr("数据结构错误:%v", err.Error())
|
err = errors.SysErr("数据结构错误:%v", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
requireData.Match = &match
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,70 @@
|
||||||
|
package do
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"ai_scheduler/internal/pkg"
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PromptOption interface {
|
||||||
|
CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type WithSys struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) {
|
||||||
|
var (
|
||||||
|
prompt = make([]api.Message, 0) // 初始化一个空的api.Message切片
|
||||||
|
)
|
||||||
|
// 获取用户内容,如果出错则直接返回错误
|
||||||
|
content, err := f.getUserContent(ctx, rec)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// 构建提示消息列表,包含系统提示、助手回复和用户内容
|
||||||
|
mes = append(prompt, api.Message{
|
||||||
|
Role: "system", // 系统角色
|
||||||
|
Content: rec.SystemPrompt, // 系统提示内容
|
||||||
|
}, api.Message{
|
||||||
|
Role: "assistant", // 助手角色
|
||||||
|
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(rec.ChatHis), // 助手回复内容
|
||||||
|
}, api.Message{
|
||||||
|
Role: "user", // 用户角色
|
||||||
|
Content: content.String(), // 用户输入内容
|
||||||
|
})
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) {
|
||||||
|
var hasFile bool
|
||||||
|
if len(rec.UserContent.FileUrl) > 0 || rec.UserContent.File != nil {
|
||||||
|
hasFile = true
|
||||||
|
}
|
||||||
|
content.WriteString(rec.UserContent.Text)
|
||||||
|
if hasFile {
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rec.UserContent.Tag) > 0 {
|
||||||
|
content.WriteString("\n")
|
||||||
|
content.WriteString("### 工具必须使用:")
|
||||||
|
content.WriteString(rec.UserContent.Tag)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rec.ChatHis.Messages) > 0 {
|
||||||
|
content.WriteString("### 引用历史聊天记录:\n")
|
||||||
|
content.WriteString(pkg.JsonStringIgonErr(rec.ChatHis))
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasFile {
|
||||||
|
content.WriteString("\n")
|
||||||
|
content.WriteString("### 文件内容:\n")
|
||||||
|
hand.WriteString(rec.UserContent.FileUrl, rec.UserContent.FileUrl)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,84 @@
|
||||||
|
package handle
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/data/constants"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"ai_scheduler/internal/pkg/l_request"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gabriel-vasile/mimetype"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HandleRecognizeFile 这里的目的是无论将什么类型的file都转为二进制格式
|
||||||
|
// 判断文件大小
|
||||||
|
// 判断文件类型
|
||||||
|
// 判断文件是否合法
|
||||||
|
func HandleRecognizeFile(file *entitys.RecognizeFile) {
|
||||||
|
//Todo 仲云
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 下载文件并返回二进制数据、MIME 类型
|
||||||
|
func downloadFile(fileUrl string) (fileBytes []byte, contentType string, err error) {
|
||||||
|
if len(fileUrl) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req := l_request.Request{
|
||||||
|
Method: "GET",
|
||||||
|
Url: fileUrl,
|
||||||
|
Headers: map[string]string{
|
||||||
|
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
|
||||||
|
"Accept": "image/webp,image/apng,image/*,*/*;q=0.8",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
res, err := req.Send()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var ex bool
|
||||||
|
if contentType, ex = res.Headers["Content-Type"]; !ex {
|
||||||
|
err = errors.New("Content-Type不存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
err = fmt.Errorf("server returned non-200 status: %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
fileBytes = res.Content
|
||||||
|
|
||||||
|
return fileBytes, contentType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectFileType 判断文件类型
|
||||||
|
func detectFileType(file io.ReadSeeker, filename string) constants.FileType {
|
||||||
|
// 1. 读取文件头检测 MIME
|
||||||
|
buffer := make([]byte, 512)
|
||||||
|
n, _ := file.Read(buffer)
|
||||||
|
file.Seek(0, io.SeekStart) // 重置读取位置
|
||||||
|
|
||||||
|
detectedMIME := mimetype.Detect(buffer[:n]).String()
|
||||||
|
for fileType, items := range constants.FileTypeMappings {
|
||||||
|
for _, item := range items {
|
||||||
|
if !strings.HasPrefix(item, ".") && item == detectedMIME {
|
||||||
|
return fileType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 备用:通过扩展名检测
|
||||||
|
ext := strings.ToLower(filepath.Ext(filename))
|
||||||
|
for fileType, items := range constants.FileTypeMappings {
|
||||||
|
for _, item := range items {
|
||||||
|
if strings.HasPrefix(item, ".") && item == ext {
|
||||||
|
return fileType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return constants.FileTypeUnknown
|
||||||
|
}
|
||||||
|
|
@ -1,22 +0,0 @@
|
||||||
package handle
|
|
||||||
|
|
||||||
import (
|
|
||||||
"ai_scheduler/internal/config"
|
|
||||||
"ai_scheduler/internal/tools"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Handle struct {
|
|
||||||
toolManager *tools.Manager
|
|
||||||
conf *config.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHandle(
|
|
||||||
toolManager *tools.Manager,
|
|
||||||
conf *config.Config,
|
|
||||||
|
|
||||||
) *Handle {
|
|
||||||
return &Handle{
|
|
||||||
toolManager: toolManager,
|
|
||||||
conf: conf,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,9 +1,7 @@
|
||||||
package llm_service
|
package llm_service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/data/constants"
|
|
||||||
"ai_scheduler/internal/data/model"
|
"ai_scheduler/internal/data/model"
|
||||||
"ai_scheduler/internal/entitys"
|
|
||||||
"context"
|
"context"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
@ -20,48 +18,3 @@ func buildSystemPrompt(prompt string) string {
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
|
|
||||||
for _, item := range his {
|
|
||||||
if len(chatHis.SessionId) == 0 {
|
|
||||||
chatHis.SessionId = item.SessionID
|
|
||||||
}
|
|
||||||
chatHis.Messages = append(chatHis.Messages, []entitys.HisMessage{
|
|
||||||
{
|
|
||||||
Role: constants.RoleUser,
|
|
||||||
Content: item.Ques,
|
|
||||||
Timestamp: item.CreateAt.Format(time.DateTime),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: constants.RoleAssistant,
|
|
||||||
Content: item.Ans,
|
|
||||||
Timestamp: item.CreateAt.Format(time.DateTime),
|
|
||||||
},
|
|
||||||
}...)
|
|
||||||
}
|
|
||||||
chatHis.Context = entitys.HisContext{
|
|
||||||
UserLanguage: "zh-CN",
|
|
||||||
SystemMode: "technical_support",
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func BuildChatHisMessage(his []model.AiChatHi) (chatHis []entitys.HisMessage) {
|
|
||||||
for _, item := range his {
|
|
||||||
|
|
||||||
chatHis = append(chatHis, []entitys.HisMessage{
|
|
||||||
{
|
|
||||||
Role: constants.RoleUser,
|
|
||||||
Content: item.Ques,
|
|
||||||
Timestamp: item.CreateAt.Format(time.DateTime),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: constants.RoleAssistant,
|
|
||||||
Content: item.Ans,
|
|
||||||
Timestamp: item.CreateAt.Format(time.DateTime),
|
|
||||||
},
|
|
||||||
}...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -1,87 +1,76 @@
|
||||||
package llm_service
|
package llm_service
|
||||||
|
|
||||||
import (
|
//type LangChainService struct {
|
||||||
"ai_scheduler/internal/data/model"
|
// client *utils_langchain.UtilLangChain
|
||||||
"ai_scheduler/internal/entitys"
|
//}
|
||||||
"ai_scheduler/internal/pkg"
|
//
|
||||||
"ai_scheduler/internal/pkg/utils_langchain"
|
//func NewLangChainGenerate(
|
||||||
"context"
|
// client *utils_langchain.UtilLangChain,
|
||||||
"encoding/json"
|
//) *LangChainService {
|
||||||
|
//
|
||||||
"github.com/tmc/langchaingo/llms"
|
// return &LangChainService{
|
||||||
)
|
// client: client,
|
||||||
|
// }
|
||||||
type LangChainService struct {
|
//}
|
||||||
client *utils_langchain.UtilLangChain
|
//
|
||||||
}
|
//func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) {
|
||||||
|
// prompt := r.getPrompt(sysInfo, history, userInput, tasks)
|
||||||
func NewLangChainGenerate(
|
// AgentClient := r.client.Get()
|
||||||
client *utils_langchain.UtilLangChain,
|
// defer r.client.Put(AgentClient)
|
||||||
) *LangChainService {
|
// match, err := AgentClient.Llm.GenerateContent(
|
||||||
|
// ctx, // 使用可取消的上下文
|
||||||
return &LangChainService{
|
// prompt,
|
||||||
client: client,
|
// llms.WithJSONMode(),
|
||||||
}
|
// )
|
||||||
}
|
// msg = match.Choices[0].Content
|
||||||
|
// return
|
||||||
func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) {
|
//}
|
||||||
prompt := r.getPrompt(sysInfo, history, userInput, tasks)
|
//
|
||||||
AgentClient := r.client.Get()
|
//func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
||||||
defer r.client.Put(AgentClient)
|
// var (
|
||||||
match, err := AgentClient.Llm.GenerateContent(
|
// prompt = make([]llms.MessageContent, 0)
|
||||||
ctx, // 使用可取消的上下文
|
// )
|
||||||
prompt,
|
// prompt = append(prompt, llms.MessageContent{
|
||||||
llms.WithJSONMode(),
|
// Role: llms.ChatMessageTypeSystem,
|
||||||
)
|
// Parts: []llms.ContentPart{
|
||||||
msg = match.Choices[0].Content
|
// llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)),
|
||||||
return
|
// },
|
||||||
}
|
// }, llms.MessageContent{
|
||||||
|
// Role: llms.ChatMessageTypeTool,
|
||||||
func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
// Parts: []llms.ContentPart{
|
||||||
var (
|
// llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))),
|
||||||
prompt = make([]llms.MessageContent, 0)
|
// },
|
||||||
)
|
// }, llms.MessageContent{
|
||||||
prompt = append(prompt, llms.MessageContent{
|
// Role: llms.ChatMessageTypeTool,
|
||||||
Role: llms.ChatMessageTypeSystem,
|
// Parts: []llms.ContentPart{
|
||||||
Parts: []llms.ContentPart{
|
// llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
|
||||||
llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)),
|
// },
|
||||||
},
|
// }, llms.MessageContent{
|
||||||
}, llms.MessageContent{
|
// Role: llms.ChatMessageTypeHuman,
|
||||||
Role: llms.ChatMessageTypeTool,
|
// Parts: []llms.ContentPart{
|
||||||
Parts: []llms.ContentPart{
|
// llms.TextPart(reqInput),
|
||||||
llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))),
|
// },
|
||||||
},
|
// })
|
||||||
}, llms.MessageContent{
|
// return prompt
|
||||||
Role: llms.ChatMessageTypeTool,
|
//}
|
||||||
Parts: []llms.ContentPart{
|
//
|
||||||
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
|
//func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool {
|
||||||
},
|
// taskPrompt := make([]llms.Tool, 0)
|
||||||
}, llms.MessageContent{
|
// for _, task := range tasks {
|
||||||
Role: llms.ChatMessageTypeHuman,
|
// var taskConfig entitys.TaskConfig
|
||||||
Parts: []llms.ContentPart{
|
// err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
||||||
llms.TextPart(reqInput),
|
// if err != nil {
|
||||||
},
|
// continue
|
||||||
})
|
// }
|
||||||
return prompt
|
// taskPrompt = append(taskPrompt, llms.Tool{
|
||||||
}
|
// Type: "function",
|
||||||
|
// Function: &llms.FunctionDefinition{
|
||||||
func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool {
|
// Name: task.Index,
|
||||||
taskPrompt := make([]llms.Tool, 0)
|
// Description: task.Desc,
|
||||||
for _, task := range tasks {
|
// Parameters: taskConfig.Param,
|
||||||
var taskConfig entitys.TaskConfig
|
// },
|
||||||
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
// })
|
||||||
if err != nil {
|
//
|
||||||
continue
|
// }
|
||||||
}
|
// return taskPrompt
|
||||||
taskPrompt = append(taskPrompt, llms.Tool{
|
//}
|
||||||
Type: "function",
|
|
||||||
Function: &llms.FunctionDefinition{
|
|
||||||
Name: task.Index,
|
|
||||||
Description: task.Desc,
|
|
||||||
Parameters: taskConfig.Param,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
return taskPrompt
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -35,14 +35,14 @@ func NewOllamaGenerate(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
|
func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSelect) (msg string, err error) {
|
||||||
prompt, err := r.getPrompt(ctx, requireData)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
toolDefinitions := r.registerToolsOllama(requireData.Tasks)
|
toolDefinitions := r.registerToolsOllama(req.Tools)
|
||||||
|
|
||||||
match, err := r.client.ToolSelect(ctx, prompt, toolDefinitions)
|
match, err := r.client.ToolSelect(ctx, rec.Prompt, toolDefinitions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -64,86 +64,28 @@ func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entity
|
||||||
|
|
||||||
msg = match.Message.Content
|
msg = match.Message.Content
|
||||||
return
|
return
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) {
|
func (r *OllamaService) RecognizeWithImg(ctx context.Context, imgByte []api.ImageData, ch chan entitys.Response) (desc api.GenerateResponse, err error) {
|
||||||
|
if imgByte == nil {
|
||||||
var (
|
|
||||||
prompt = make([]api.Message, 0)
|
|
||||||
)
|
|
||||||
content, err := r.getUserContent(ctx, requireData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
prompt = append(prompt, api.Message{
|
|
||||||
Role: "system",
|
|
||||||
Content: buildSystemPrompt(requireData.Sys.SysPrompt),
|
|
||||||
}, api.Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(buildAssistant(requireData.Histories)),
|
|
||||||
}, api.Message{
|
|
||||||
Role: "user",
|
|
||||||
Content: content,
|
|
||||||
})
|
|
||||||
|
|
||||||
return prompt, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *OllamaService) getUserContent(ctx context.Context, requireData *entitys.RequireData) (string, error) {
|
|
||||||
var content strings.Builder
|
|
||||||
content.WriteString(requireData.Req.Text)
|
|
||||||
if len(requireData.ImgByte) > 0 {
|
|
||||||
content.WriteString("\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(requireData.Req.Tags) > 0 {
|
|
||||||
content.WriteString("\n")
|
|
||||||
content.WriteString("### 工具必须使用:")
|
|
||||||
content.WriteString(requireData.Req.Tags)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(requireData.ImgByte) > 0 {
|
|
||||||
desc, err := r.RecognizeWithImg(ctx, requireData)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
content.WriteString("### 上传图片解析内容:\n")
|
|
||||||
content.WriteString(requireData.Req.Tags)
|
|
||||||
content.WriteString(desc.Response)
|
|
||||||
}
|
|
||||||
|
|
||||||
if requireData.Req.MarkHis > 0 {
|
|
||||||
var his model.AiChatHi
|
|
||||||
cond := builder.NewCond()
|
|
||||||
cond = cond.And(builder.Eq{"his_id": requireData.Req.MarkHis})
|
|
||||||
err := r.chatHis.GetOneBySearchToStrut(&cond, &his)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
content.WriteString("### 引用历史聊天记录:\n")
|
|
||||||
content.WriteString(pkg.JsonStringIgonErr(BuildChatHisMessage([]model.AiChatHi{his})))
|
|
||||||
}
|
|
||||||
return content.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
|
|
||||||
if requireData.ImgByte == nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
|
entitys.ResLog(ch, "recognize_img_start", "图片识别中...")
|
||||||
|
|
||||||
desc, err = r.client.Generation(ctx, &api.GenerateRequest{
|
desc, err = r.client.Generation(ctx, &api.GenerateRequest{
|
||||||
Model: r.config.Ollama.VlModel,
|
Model: r.config.Ollama.VlModel,
|
||||||
Stream: new(bool),
|
Stream: new(bool),
|
||||||
System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
||||||
Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||||
Images: requireData.ImgByte,
|
Images: imgByte,
|
||||||
KeepAlive: &api.Duration{Duration: 3600 * time.Second},
|
KeepAlive: &api.Duration{Duration: 3600 * time.Second},
|
||||||
|
//Think: &api.ThinkValue{Value: false},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
|
entitys.ResLog(ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ package biz
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/biz/do"
|
"ai_scheduler/internal/biz/do"
|
||||||
"ai_scheduler/internal/biz/handle"
|
|
||||||
"ai_scheduler/internal/biz/llm_service"
|
"ai_scheduler/internal/biz/llm_service"
|
||||||
|
|
||||||
"github.com/google/wire"
|
"github.com/google/wire"
|
||||||
|
|
@ -12,10 +11,11 @@ var ProviderSetBiz = wire.NewSet(
|
||||||
NewAiRouterBiz,
|
NewAiRouterBiz,
|
||||||
NewSessionBiz,
|
NewSessionBiz,
|
||||||
NewChatHistoryBiz,
|
NewChatHistoryBiz,
|
||||||
llm_service.NewLangChainGenerate,
|
//llm_service.NewLangChainGenerate,
|
||||||
llm_service.NewOllamaGenerate,
|
llm_service.NewOllamaGenerate,
|
||||||
handle.NewHandle,
|
//handle.NewHandle,
|
||||||
do.NewDo,
|
do.NewDo,
|
||||||
do.NewHandle,
|
do.NewHandle,
|
||||||
NewTaskBiz,
|
NewTaskBiz,
|
||||||
|
NewDingTalkBotBiz,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package biz
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/biz/do"
|
"ai_scheduler/internal/biz/do"
|
||||||
"ai_scheduler/internal/gateway"
|
"ai_scheduler/internal/gateway"
|
||||||
|
"context"
|
||||||
|
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
|
|
||||||
|
|
@ -56,9 +57,15 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS
|
||||||
log.Errorf("数据验证和收集失败: %s", err.Error())
|
log.Errorf("数据验证和收集失败: %s", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
//组装意图识别
|
||||||
|
rec, err := r.SetRec(ctx, requireData)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("组装意图识别失败: %s", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
//意图识别
|
//意图识别
|
||||||
if err = r.handle.Recognize(ctx, requireData); err != nil {
|
requireData.Match, err = r.handle.Recognize(ctx, &rec, &do.WithSys{})
|
||||||
|
if err != nil {
|
||||||
log.Errorf("意图识别失败: %s", err.Error())
|
log.Errorf("意图识别失败: %s", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -70,3 +77,8 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterBiz) SetRec(ctx context.Context, requireData *entitys.RequireData) (match entitys.Recognize, err error) {
|
||||||
|
//TODO 叙平
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,16 +9,21 @@ import (
|
||||||
|
|
||||||
// Config 应用配置
|
// Config 应用配置
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Server ServerConfig `mapstructure:"server"`
|
Server ServerConfig `mapstructure:"server"`
|
||||||
Ollama OllamaConfig `mapstructure:"ollama"`
|
Ollama OllamaConfig `mapstructure:"ollama"`
|
||||||
Sys SysConfig `mapstructure:"sys"`
|
Sys SysConfig `mapstructure:"sys"`
|
||||||
Tools ToolsConfig `mapstructure:"tools"`
|
Tools ToolsConfig `mapstructure:"tools"`
|
||||||
Logging LoggingConfig `mapstructure:"logging"`
|
Logging LoggingConfig `mapstructure:"logging"`
|
||||||
Redis Redis `mapstructure:"redis"`
|
Redis Redis `mapstructure:"redis"`
|
||||||
DB DB `mapstructure:"db"`
|
DB DB `mapstructure:"db"`
|
||||||
DefaultPrompt SysPrompt `mapstructure:"default_prompt"`
|
DefaultPrompt SysPrompt `mapstructure:"default_prompt"`
|
||||||
PermissionConfig PermissionConfig `mapstructure:"permissionConfig"`
|
PermissionConfig PermissionConfig `mapstructure:"permissionConfig"`
|
||||||
// LLM *LLM `mapstructure:"llm"`
|
DingTalkBots map[string]*DingTalkBot `mapstructure:"ding_talk_bots"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DingTalkBot struct {
|
||||||
|
ClientId string `mapstructure:"client_id"`
|
||||||
|
ClientSecret string `mapstructure:"client_secret"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SysPrompt struct {
|
type SysPrompt struct {
|
||||||
|
|
@ -36,10 +41,11 @@ type LLM struct {
|
||||||
|
|
||||||
// SysConfig 系统配置
|
// SysConfig 系统配置
|
||||||
type SysConfig struct {
|
type SysConfig struct {
|
||||||
SessionLen int `mapstructure:"session_len"`
|
SessionLen int `mapstructure:"session_len"`
|
||||||
ChannelPoolLen int `mapstructure:"channel_pool_len"`
|
ChannelPoolLen int `mapstructure:"channel_pool_len"`
|
||||||
ChannelPoolSize int `mapstructure:"channel_pool_size"`
|
ChannelPoolSize int `mapstructure:"channel_pool_size"`
|
||||||
LlmPoolLen int `mapstructure:"llm_pool_len"`
|
LlmPoolLen int `mapstructure:"llm_pool_len"`
|
||||||
|
HeartbeatInterval int `mapstructure:"heartbeat_interval"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerConfig 服务器配置
|
// ServerConfig 服务器配置
|
||||||
|
|
|
||||||
|
|
@ -3,5 +3,25 @@ package constants
|
||||||
type BotTools string
|
type BotTools string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
BotToolsBugOptimizationSubmit = "bug_optimization_submit" // 系统的bug/优化建议
|
BotToolsBugOptimizationSubmit BotTools = "bug_optimization_submit" // 系统的bug/优化建议
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ChatStyle int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ChatStyleNormal ChatStyle = 1 //正常
|
||||||
|
ChatStyleSerious ChatStyle = 2 //严肃
|
||||||
|
ChatStyleGentle ChatStyle = 3 //温柔
|
||||||
|
ChatStyleArrogance ChatStyle = 4 //傲慢
|
||||||
|
ChatStyleCute ChatStyle = 5 //可爱
|
||||||
|
ChatStyleAngry ChatStyle = 6 //愤怒
|
||||||
|
)
|
||||||
|
|
||||||
|
var ChatStyleMap = map[ChatStyle]string{
|
||||||
|
ChatStyleNormal: "正常",
|
||||||
|
ChatStyleSerious: "严肃",
|
||||||
|
ChatStyleGentle: "温柔",
|
||||||
|
ChatStyleArrogance: "傲慢",
|
||||||
|
ChatStyleCute: "可爱",
|
||||||
|
ChatStyleAngry: "愤怒",
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,38 @@
|
||||||
|
package constants
|
||||||
|
|
||||||
|
type FileType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
FileTypeUnknown FileType = "unknown"
|
||||||
|
FileTypeImage FileType = "image"
|
||||||
|
//FileTypeVideo FileType = "video"
|
||||||
|
FileTypeExcel FileType = "excel"
|
||||||
|
FileTypeWord FileType = "word"
|
||||||
|
FileTypeTxt FileType = "txt"
|
||||||
|
FileTypePDF FileType = "pdf"
|
||||||
|
)
|
||||||
|
|
||||||
|
var FileTypeMappings = map[FileType][]string{
|
||||||
|
FileTypeImage: {
|
||||||
|
"image/jpeg", "image/png", "image/gif", "image/webp", "image/svg+xml",
|
||||||
|
".jpg", ".jpeg", ".png", ".gif", ".webp", ".svg",
|
||||||
|
},
|
||||||
|
FileTypeExcel: {
|
||||||
|
"application/vnd.ms-excel",
|
||||||
|
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||||
|
".xls", ".xlsx",
|
||||||
|
},
|
||||||
|
FileTypeWord: {
|
||||||
|
"application/msword",
|
||||||
|
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||||
|
".doc", ".docx",
|
||||||
|
},
|
||||||
|
FileTypePDF: {
|
||||||
|
"application/pdf",
|
||||||
|
".pdf",
|
||||||
|
},
|
||||||
|
FileTypeTxt: {
|
||||||
|
"text/plain",
|
||||||
|
".txt",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
package impl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/data/model"
|
||||||
|
"ai_scheduler/tmpl/dataTemp"
|
||||||
|
"ai_scheduler/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BotChatHisImpl struct {
|
||||||
|
dataTemp.DataTemp
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBotChatHisImpl(db *utils.Db) *TaskImpl {
|
||||||
|
return &TaskImpl{*dataTemp.NewDataTemp(db, new(model.AiBotChatHi))}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||||
|
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||||
|
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||||
|
|
||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const TableNameAiBotChatHi = "ai_bot_chat_his"
|
||||||
|
|
||||||
|
// AiBotChatHi mapped from table <ai_bot_chat_his>
|
||||||
|
type AiBotChatHi struct {
|
||||||
|
HisID int64 `gorm:"column:his_id;primaryKey;autoIncrement:true" json:"his_id"`
|
||||||
|
SessionID string `gorm:"column:session_id;not null" json:"session_id"`
|
||||||
|
Role string `gorm:"column:role;not null;comment:system系统输出,assistant助手输出,user用户输入" json:"role"` // system系统输出,assistant助手输出,user用户输入
|
||||||
|
Content string `gorm:"column:content;not null" json:"content"`
|
||||||
|
Files string `gorm:"column:files;not null" json:"files"`
|
||||||
|
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
|
||||||
|
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName AiBotChatHi's table name
|
||||||
|
func (*AiBotChatHi) TableName() string {
|
||||||
|
return TableNameAiBotChatHi
|
||||||
|
}
|
||||||
|
|
@ -1,7 +1,24 @@
|
||||||
package entitys
|
package entitys
|
||||||
|
|
||||||
type BotType int
|
import (
|
||||||
|
"ai_scheduler/internal/data/model"
|
||||||
|
|
||||||
const (
|
"github.com/ollama/ollama/api"
|
||||||
BugAndQuesDingTalk BotType = iota + 1
|
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type RequireDataDingTalkBot struct {
|
||||||
|
Session string
|
||||||
|
Key string
|
||||||
|
Sys model.AiSy
|
||||||
|
Histories []model.AiChatHi
|
||||||
|
SessionInfo model.AiSession
|
||||||
|
Tasks []model.AiTask
|
||||||
|
Match *Match
|
||||||
|
Req *chatbot.BotCallbackDataModel
|
||||||
|
Auth string
|
||||||
|
Ch chan Response
|
||||||
|
KnowledgeConf KnowledgeBaseRequest
|
||||||
|
ImgByte []api.ImageData
|
||||||
|
ImgUrls []string
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
package entitys
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ToolSelect struct {
|
||||||
|
Prompt []api.Message
|
||||||
|
Tools []RegistrationTask
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
package entitys
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/data/constants"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Recognize struct {
|
||||||
|
SystemPrompt string
|
||||||
|
UserContent *RecognizeUserContent
|
||||||
|
ChatHis ChatHis
|
||||||
|
Tasks []RegistrationTask
|
||||||
|
Ch chan Response
|
||||||
|
}
|
||||||
|
|
||||||
|
type RegistrationTask struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
type RecognizeUserContent struct {
|
||||||
|
Text string
|
||||||
|
File *RecognizeFile
|
||||||
|
ActionCardUrl string
|
||||||
|
Tag string
|
||||||
|
}
|
||||||
|
|
||||||
|
type FileData []byte
|
||||||
|
|
||||||
|
type RecognizeFile struct {
|
||||||
|
File []FileData // 文件数据(二进制格式)
|
||||||
|
FileUrl string // 文件下载链接
|
||||||
|
FileType constants.Caller // 文件类型(文件类型,能填最好填,可以跳过一层判断)
|
||||||
|
}
|
||||||
|
|
@ -2,22 +2,25 @@ package entitys
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ResponseType string
|
type ResponseType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ResponseJson ResponseType = "json"
|
ResponseJson ResponseType = "json"
|
||||||
ResponseLoading ResponseType = "loading"
|
ResponseLoading ResponseType = "loading"
|
||||||
ResponseEnd ResponseType = "end"
|
ResponseEnd ResponseType = "end"
|
||||||
ResponseStream ResponseType = "stream"
|
ResponseStream ResponseType = "stream"
|
||||||
ResponseText ResponseType = "txt"
|
ResponseText ResponseType = "txt"
|
||||||
ResponseImg ResponseType = "img"
|
ResponseImg ResponseType = "img"
|
||||||
ResponseFile ResponseType = "file"
|
ResponseFile ResponseType = "file"
|
||||||
ResponseErr ResponseType = "error"
|
ResponseErr ResponseType = "error"
|
||||||
ResponseLog ResponseType = "log"
|
ResponseLog ResponseType = "log"
|
||||||
ResponseAuth ResponseType = "auth"
|
ResponseAuth ResponseType = "auth"
|
||||||
|
ResponseMarkdown ResponseType = "markdown"
|
||||||
|
ResponseActionCard ResponseType = "actionCard"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ResLog(ch chan Response, index string, content string) {
|
func ResLog(ch chan Response, index string, content string) {
|
||||||
|
|
@ -45,6 +48,9 @@ func ResJson(ch chan Response, index string, content string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func ResEnd(ch chan Response, index string, content string) {
|
func ResEnd(ch chan Response, index string, content string) {
|
||||||
|
if ch == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
ch <- Response{
|
ch <- Response{
|
||||||
Index: index,
|
Index: index,
|
||||||
Content: content,
|
Content: content,
|
||||||
|
|
@ -53,6 +59,9 @@ func ResEnd(ch chan Response, index string, content string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func ResText(ch chan Response, index string, content string) {
|
func ResText(ch chan Response, index string, content string) {
|
||||||
|
if ch == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
ch <- Response{
|
ch <- Response{
|
||||||
Index: index,
|
Index: index,
|
||||||
Content: content,
|
Content: content,
|
||||||
|
|
@ -61,6 +70,9 @@ func ResText(ch chan Response, index string, content string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func ResLoading(ch chan Response, index string, content string) {
|
func ResLoading(ch chan Response, index string, content string) {
|
||||||
|
if ch == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
ch <- Response{
|
ch <- Response{
|
||||||
Index: index,
|
Index: index,
|
||||||
Content: content,
|
Content: content,
|
||||||
|
|
@ -68,6 +80,9 @@ func ResLoading(ch chan Response, index string, content string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func ResError(ch chan Response, index string, content string) {
|
func ResError(ch chan Response, index string, content string) {
|
||||||
|
if ch == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
ch <- Response{
|
ch <- Response{
|
||||||
Index: index,
|
Index: index,
|
||||||
Content: content,
|
Content: content,
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ package entitys
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/data/constants"
|
"ai_scheduler/internal/data/constants"
|
||||||
"ai_scheduler/internal/data/model"
|
"ai_scheduler/internal/data/model"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
|
|
@ -150,7 +149,7 @@ type RequireData struct {
|
||||||
Histories []model.AiChatHi
|
Histories []model.AiChatHi
|
||||||
SessionInfo model.AiSession
|
SessionInfo model.AiSession
|
||||||
Tasks []model.AiTask
|
Tasks []model.AiTask
|
||||||
Match *Match
|
Match Match
|
||||||
Req *ChatSockRequest
|
Req *ChatSockRequest
|
||||||
Auth string
|
Auth string
|
||||||
Ch chan Response
|
Ch chan Response
|
||||||
|
|
|
||||||
|
|
@ -3,33 +3,44 @@ package gateway
|
||||||
import (
|
import (
|
||||||
errors "ai_scheduler/internal/data/error"
|
errors "ai_scheduler/internal/data/error"
|
||||||
"ai_scheduler/internal/data/model"
|
"ai_scheduler/internal/data/model"
|
||||||
"encoding/hex"
|
"ai_scheduler/internal/pkg"
|
||||||
"fmt"
|
"context"
|
||||||
"github.com/gofiber/websocket/v2"
|
"encoding/binary"
|
||||||
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gofiber/websocket/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrConnClosed = errors.SysErr("连接不存在或已关闭")
|
ErrConnClosed = errors.SysErr("连接不存在或已关闭")
|
||||||
|
rng = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
idBuf = make([]byte, 20)
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
id string // 客户端唯一ID
|
id string // 客户端唯一ID
|
||||||
conn *websocket.Conn // WebSocket 连接
|
conn *websocket.Conn // WebSocket 连接
|
||||||
session string // 会话ID
|
session string // 会话ID
|
||||||
key string // 应用密钥
|
key string // 应用密钥
|
||||||
auth string // 用户凭证token
|
auth string // 用户凭证token
|
||||||
codes []string // 用户权限code
|
codes []string // 用户权限code
|
||||||
sysInfo *model.AiSy // 系统信息
|
sysInfo *model.AiSy // 系统信息
|
||||||
tasks []model.AiTask // 任务列表
|
tasks []model.AiTask // 任务列表
|
||||||
sysCode string // 系统编码
|
sysCode string // 系统编码
|
||||||
|
Ctx context.Context
|
||||||
|
Cancel context.CancelFunc
|
||||||
|
LastActive time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(conn *websocket.Conn) *Client {
|
func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelFunc) *Client {
|
||||||
|
|
||||||
return &Client{
|
return &Client{
|
||||||
id: generateClientID(),
|
id: generateClientID(),
|
||||||
conn: conn,
|
conn: conn,
|
||||||
|
Ctx: ctx,
|
||||||
|
Cancel: cancel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -103,12 +114,16 @@ func (c *Client) SendFunc(msg []byte) error {
|
||||||
|
|
||||||
// 生成唯一的客户端ID
|
// 生成唯一的客户端ID
|
||||||
func generateClientID() string {
|
func generateClientID() string {
|
||||||
// 使用时间戳+随机数确保唯一性
|
// 1. 时间戳
|
||||||
timestamp := time.Now().UnixNano()
|
timestamp := time.Now().UnixNano()
|
||||||
randomBytes := make([]byte, 4)
|
binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp))
|
||||||
rand.Read(randomBytes)
|
|
||||||
randomStr := hex.EncodeToString(randomBytes)
|
// 2. 随机数(4字节)
|
||||||
return fmt.Sprintf("%d%s", timestamp, randomStr)
|
binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32())
|
||||||
|
|
||||||
|
// 3. 十六进制编码
|
||||||
|
n := pkg.HexEncode(idBuf[:12], idBuf[12:])
|
||||||
|
return string(idBuf[12 : 12+n])
|
||||||
}
|
}
|
||||||
|
|
||||||
// 连接数据验证和收集
|
// 连接数据验证和收集
|
||||||
|
|
@ -136,3 +151,22 @@ func (c *Client) DataAuth() (err error) {
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) InitHeartbeat(timeoutSecond time.Duration) {
|
||||||
|
ticker := time.NewTicker(timeoutSecond * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
//*2是防止丢包,连续丢包两次,再加5s网络延迟容错
|
||||||
|
if time.Since(c.LastActive) > (timeoutSecond*2*time.Second + 5) { // 5秒容错
|
||||||
|
log.Println("Heartbeat timeout", "id", c.id)
|
||||||
|
c.conn.WriteMessage(websocket.CloseMessage, []byte("Heartbeat timeout"))
|
||||||
|
c.conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-c.Ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,9 @@ package gateway
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"log"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Gateway struct {
|
type Gateway struct {
|
||||||
|
|
@ -20,14 +22,27 @@ func NewGateway() *Gateway {
|
||||||
|
|
||||||
func (g *Gateway) AddClient(c *Client) {
|
func (g *Gateway) AddClient(c *Client) {
|
||||||
g.mu.Lock()
|
g.mu.Lock()
|
||||||
defer g.mu.Unlock()
|
defer func() {
|
||||||
|
g.mu.Unlock()
|
||||||
|
//心跳开始计时
|
||||||
|
c.LastActive = time.Now()
|
||||||
|
log.Println("client connected:", c.GetID())
|
||||||
|
log.Println("客户端已连接")
|
||||||
|
}()
|
||||||
g.clients[c.GetID()] = c
|
g.clients[c.GetID()] = c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Gateway) RemoveClient(clientID string) {
|
func (g *Gateway) Cleanup(clientID string) {
|
||||||
g.mu.Lock()
|
g.mu.Lock()
|
||||||
defer g.mu.Unlock()
|
defer func() {
|
||||||
delete(g.clients, clientID)
|
if c, ex := g.clients[clientID]; ex {
|
||||||
|
delete(g.clients, clientID)
|
||||||
|
_ = c.conn.Close()
|
||||||
|
c.Cancel()
|
||||||
|
}
|
||||||
|
g.mu.Unlock()
|
||||||
|
log.Println("client disconnected:", clientID)
|
||||||
|
}()
|
||||||
for uid, list := range g.uidMap {
|
for uid, list := range g.uidMap {
|
||||||
newList := []string{}
|
newList := []string{}
|
||||||
for _, cid := range list {
|
for _, cid := range list {
|
||||||
|
|
@ -37,6 +52,7 @@ func (g *Gateway) RemoveClient(clientID string) {
|
||||||
}
|
}
|
||||||
g.uidMap[uid] = newList
|
g.uidMap[uid] = newList
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Gateway) SendToAll(msg []byte) {
|
func (g *Gateway) SendToAll(msg []byte) {
|
||||||
|
|
@ -63,6 +79,7 @@ func (g *Gateway) BindUid(clientID, uid string) error {
|
||||||
return errors.New("client not found")
|
return errors.New("client not found")
|
||||||
}
|
}
|
||||||
g.uidMap[uid] = append(g.uidMap[uid], clientID)
|
g.uidMap[uid] = append(g.uidMap[uid], clientID)
|
||||||
|
log.Printf("bind %s -> uid:%s\n", clientID, uid)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -55,3 +55,13 @@ func ValidateImageURL(rawURL string) error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// hexEncode 将 src 的二进制数据编码为十六进制字符串,写入 dst,返回写入长度
|
||||||
|
func HexEncode(src, dst []byte) int {
|
||||||
|
const hextable = "0123456789abcdef"
|
||||||
|
for i := 0; i < len(src); i++ {
|
||||||
|
dst[i*2] = hextable[src[i]>>4]
|
||||||
|
dst[i*2+1] = hextable[src[i]&0xf]
|
||||||
|
}
|
||||||
|
return len(src) * 2
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,65 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/services"
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/go-kratos/kratos/v2/log"
|
||||||
|
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||||
|
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DingBotServiceInterface interface {
|
||||||
|
GetServiceCfg(cfg map[string]*config.DingTalkBot) (*config.DingTalkBot, string)
|
||||||
|
OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type DingTalkBotServer struct {
|
||||||
|
Clients []*client.StreamClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDingTalkBotServer(
|
||||||
|
cfg *config.Config,
|
||||||
|
services []DingBotServiceInterface,
|
||||||
|
) *DingTalkBotServer {
|
||||||
|
clients := make([]*client.StreamClient, 0)
|
||||||
|
for _, service := range services {
|
||||||
|
serviceConf, index := service.GetServiceCfg(cfg.DingTalkBots)
|
||||||
|
if serviceConf == nil {
|
||||||
|
log.Info("未找到%s配置", index)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cli := DingBotServerInit(serviceConf.ClientId, serviceConf.ClientSecret, service)
|
||||||
|
if cli == nil {
|
||||||
|
log.Info("%s客户端初始失败", index)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
clients = append(clients, cli)
|
||||||
|
}
|
||||||
|
return &DingTalkBotServer{
|
||||||
|
Clients: clients,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ProvideAllDingBotServices(
|
||||||
|
dingBotSvc *services.DingBotService,
|
||||||
|
) []DingBotServiceInterface {
|
||||||
|
return []DingBotServiceInterface{dingBotSvc}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DingTalkBotServer) Run(ctx context.Context) {
|
||||||
|
for name, cli := range d.Clients {
|
||||||
|
err := cli.Start(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Info("%s启动失败", name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Info("%s启动成功", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func DingBotServerInit(clientId string, clientSecret string, service DingBotServiceInterface) (cli *client.StreamClient) {
|
||||||
|
cli = client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(clientId, clientSecret)))
|
||||||
|
cli.RegisterChatBotCallbackRouter(service.OnChatBotMessageReceived)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
@ -1,5 +1,12 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import "github.com/google/wire"
|
import (
|
||||||
|
"github.com/google/wire"
|
||||||
|
)
|
||||||
|
|
||||||
var ProviderSetServer = wire.NewSet(NewServers, NewHTTPServer)
|
var ProviderSetServer = wire.NewSet(
|
||||||
|
NewServers,
|
||||||
|
NewHTTPServer,
|
||||||
|
ProvideAllDingBotServices,
|
||||||
|
NewDingTalkBotServer,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -82,10 +82,7 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
|
||||||
func routerSocket(app *fiber.App, chatService *services.ChatService) {
|
func routerSocket(app *fiber.App, chatService *services.ChatService) {
|
||||||
ws := app.Group("ws/v1/")
|
ws := app.Group("ws/v1/")
|
||||||
// WebSocket 路由配置
|
// WebSocket 路由配置
|
||||||
ws.Get("/chat", websocket.New(func(c *websocket.Conn) {
|
ws.Get("/chat", websocket.New(chatService.Chat, websocket.Config{
|
||||||
// 可以在这里添加握手前的中间件逻辑(如头校验)
|
|
||||||
chatService.Chat(c) // 调用实际的 Chat 处理函数
|
|
||||||
}, websocket.Config{
|
|
||||||
// 可选配置:跨域检查、最大负载大小等
|
// 可选配置:跨域检查、最大负载大小等
|
||||||
HandshakeTimeout: 10 * time.Second,
|
HandshakeTimeout: 10 * time.Second,
|
||||||
//Subprotocols: []string{"json", "msgpack"},
|
//Subprotocols: []string{"json", "msgpack"},
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,27 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import "github.com/gofiber/fiber/v2"
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
type Servers struct {
|
type Servers struct {
|
||||||
HttpServer *fiber.App
|
cfg *config.Config
|
||||||
|
HttpServer *fiber.App
|
||||||
|
DingBotServer *DingTalkBotServer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServers(fiber *fiber.App) *Servers {
|
func NewServers(cfg *config.Config, fiber *fiber.App, DingBotServer *DingTalkBotServer) *Servers {
|
||||||
return &Servers{
|
return &Servers{
|
||||||
HttpServer: fiber,
|
HttpServer: fiber,
|
||||||
|
cfg: cfg,
|
||||||
|
DingBotServer: DingBotServer,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//func DingBotServerInit(clientId string, clientSecret string, cfg *config.Config, handler *do.Handle, do *do.Do) (cli *client.StreamClient) {
|
||||||
|
// cli = client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(clientId, clientSecret)))
|
||||||
|
// cli.RegisterChatBotCallbackRouter(services.NewDingBotService(cfg, handler, do).OnChatBotMessageReceived)
|
||||||
|
// return
|
||||||
|
//}
|
||||||
|
|
|
||||||
|
|
@ -212,7 +212,7 @@ func (s *CallbackService) sendStreamLog(sessionID string, content string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
streamLog := entitys.Response{
|
streamLog := entitys.Response{
|
||||||
Index: constants.BotToolsBugOptimizationSubmit,
|
Index: string(constants.BotToolsBugOptimizationSubmit),
|
||||||
Content: content,
|
Content: content,
|
||||||
Type: entitys.ResponseLog,
|
Type: entitys.ResponseLog,
|
||||||
}
|
}
|
||||||
|
|
@ -227,7 +227,7 @@ func (s *CallbackService) sendStreamTxt(sessionID string, content string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
streamLog := entitys.Response{
|
streamLog := entitys.Response{
|
||||||
Index: constants.BotToolsBugOptimizationSubmit,
|
Index: string(constants.BotToolsBugOptimizationSubmit),
|
||||||
Content: content,
|
Content: content,
|
||||||
Type: entitys.ResponseText,
|
Type: entitys.ResponseText,
|
||||||
}
|
}
|
||||||
|
|
@ -242,7 +242,7 @@ func (s *CallbackService) sendStreamLoading(sessionID string, content string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
streamLog := entitys.Response{
|
streamLog := entitys.Response{
|
||||||
Index: constants.BotToolsBugOptimizationSubmit,
|
Index: string(constants.BotToolsBugOptimizationSubmit),
|
||||||
Content: content,
|
Content: content,
|
||||||
Type: entitys.ResponseLoading,
|
Type: entitys.ResponseLoading,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,9 +6,11 @@ import (
|
||||||
"ai_scheduler/internal/data/constants"
|
"ai_scheduler/internal/data/constants"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/gateway"
|
"ai_scheduler/internal/gateway"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log"
|
"log"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
|
|
@ -38,20 +40,6 @@ func NewChatService(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToolCallResponse 工具调用响应
|
|
||||||
type ToolCallResponse struct {
|
|
||||||
ID string `json:"id" example:"call_1"`
|
|
||||||
Type string `json:"type" example:"function"`
|
|
||||||
Function FunctionCallResponse `json:"function"`
|
|
||||||
Result interface{} `json:"result,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// FunctionCallResponse 函数调用响应
|
|
||||||
type FunctionCallResponse struct {
|
|
||||||
Name string `json:"name" example:"get_weather"`
|
|
||||||
Arguments interface{} `json:"arguments"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
|
func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
|
||||||
err := c.WriteMessage(websocket.TextMessage, []byte(content))
|
err := c.WriteMessage(websocket.TextMessage, []byte(content))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -63,15 +51,18 @@ func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
|
||||||
// Chat 处理WebSocket聊天连接
|
// Chat 处理WebSocket聊天连接
|
||||||
// 这是WebSocket处理的主入口函数
|
// 这是WebSocket处理的主入口函数
|
||||||
func (h *ChatService) Chat(c *websocket.Conn) {
|
func (h *ChatService) Chat(c *websocket.Conn) {
|
||||||
// 创建新的客户端实例
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
h.mu.Lock()
|
|
||||||
client := gateway.NewClient(c)
|
|
||||||
h.mu.Unlock()
|
|
||||||
|
|
||||||
|
// 创建新的客户端实例
|
||||||
|
client := gateway.NewClient(c, ctx, cancel)
|
||||||
|
// 心跳检测
|
||||||
|
go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval))
|
||||||
// 将客户端添加到网关管理
|
// 将客户端添加到网关管理
|
||||||
h.Gw.AddClient(client)
|
h.Gw.AddClient(client)
|
||||||
log.Println("client connected:", client.GetID())
|
// 确保在函数返回时移除客户端并关闭连接
|
||||||
log.Println("客户端已连接")
|
defer func() {
|
||||||
|
h.Gw.Cleanup(client.GetID())
|
||||||
|
}()
|
||||||
|
|
||||||
// 绑定会话ID
|
// 绑定会话ID
|
||||||
uid := c.Query("x-session")
|
uid := c.Query("x-session")
|
||||||
|
|
@ -79,7 +70,6 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
||||||
if err := h.Gw.BindUid(client.GetID(), uid); err != nil {
|
if err := h.Gw.BindUid(client.GetID(), uid); err != nil {
|
||||||
log.Println("绑定UID错误:", err)
|
log.Println("绑定UID错误:", err)
|
||||||
}
|
}
|
||||||
log.Printf("bind %s -> uid:%s\n", client.GetID(), uid)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证并收集连接数据,后续对话中会使用
|
// 验证并收集连接数据,后续对话中会使用
|
||||||
|
|
@ -89,13 +79,6 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 确保在函数返回时移除客户端并关闭连接
|
|
||||||
defer func() {
|
|
||||||
h.Gw.RemoveClient(client.GetID())
|
|
||||||
_ = c.Close()
|
|
||||||
log.Println("client disconnected:", client.GetID())
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 循环读取客户端消息
|
// 循环读取客户端消息
|
||||||
for {
|
for {
|
||||||
// 读取消息
|
// 读取消息
|
||||||
|
|
@ -104,7 +87,14 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
||||||
log.Println("读取错误:", err)
|
log.Println("读取错误:", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
//if string(message) == `{"type":"ping"}` {
|
||||||
|
// client.LastActive = time.Now()
|
||||||
|
// if err := c.WriteMessage(websocket.TextMessage, []byte(`{"type":"pong"}`)); err != nil {
|
||||||
|
// log.Println("Heartbeat response failed", "id", client.GetID(), "err", err)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// continue
|
||||||
|
//}
|
||||||
// 处理消息
|
// 处理消息
|
||||||
msg, chatType := h.handleMessageToString(c, messageType, message)
|
msg, chatType := h.handleMessageToString(c, messageType, message)
|
||||||
if chatType == constants.ConnStatusClosed {
|
if chatType == constants.ConnStatusClosed {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,135 @@
|
||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/biz"
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DingBotService struct {
|
||||||
|
config *config.Config
|
||||||
|
replier *chatbot.ChatbotReplier
|
||||||
|
env string
|
||||||
|
DingTalkBotBiz *biz.DingTalkBotBiz
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDingBotService(config *config.Config, DingTalkBotBiz *biz.DingTalkBotBiz) *DingBotService {
|
||||||
|
return &DingBotService{config: config, replier: chatbot.NewChatbotReplier(), env: "public", DingTalkBotBiz: DingTalkBotBiz}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DingBotService) GetServiceCfg(cfg map[string]*config.DingTalkBot) (*config.DingTalkBot, string) {
|
||||||
|
return cfg[d.env], d.env
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) (content []byte, err error) {
|
||||||
|
requireData, err := d.DingTalkBotBiz.InitRequire(ctx, data)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer close(requireData.Ch)
|
||||||
|
|
||||||
|
if recognizeErr := d.DingTalkBotBiz.Recognize(ctx, requireData); recognizeErr != nil {
|
||||||
|
requireData.Ch <- entitys.Response{
|
||||||
|
Type: entitys.ResponseEnd,
|
||||||
|
Content: fmt.Sprintf("处理消息时出错: %v", recognizeErr),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//向下传递
|
||||||
|
if err = d.handle.HandleMatch(ctx, nil, requireData); err != nil {
|
||||||
|
requireData.Ch <- entitys.Response{
|
||||||
|
Type: entitys.ResponseEnd,
|
||||||
|
Content: fmt.Sprintf("匹配失败: %v", err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case resp, ok := <-requireData.Ch:
|
||||||
|
if !ok {
|
||||||
|
return []byte("success"), nil // 通道关闭,处理完成
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Type == entitys.ResponseLog {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := d.handleRes(ctx, data, resp); err != nil {
|
||||||
|
return nil, fmt.Errorf("回复失败: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DingBotService) handleRes(ctx context.Context, data *chatbot.BotCallbackDataModel, resp entitys.Response) error {
|
||||||
|
switch resp.Type {
|
||||||
|
case entitys.ResponseText:
|
||||||
|
return d.replyText(ctx, data.SessionWebhook, resp.Content)
|
||||||
|
case entitys.ResponseStream:
|
||||||
|
return d.replySteam(ctx, data.SessionWebhook, resp.Content)
|
||||||
|
case entitys.ResponseImg:
|
||||||
|
return d.replyImg(ctx, data.SessionWebhook, resp.Content)
|
||||||
|
case entitys.ResponseFile:
|
||||||
|
return d.replyFile(ctx, data.SessionWebhook, resp.Content)
|
||||||
|
case entitys.ResponseMarkdown:
|
||||||
|
return d.replyMarkdown(ctx, data.SessionWebhook, resp.Content)
|
||||||
|
case entitys.ResponseActionCard:
|
||||||
|
return d.replyActionCard(ctx, data.SessionWebhook, resp.Content)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DingBotService) replyText(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||||
|
msg := content
|
||||||
|
if len(arg) > 0 {
|
||||||
|
msg = fmt.Sprintf(content, arg)
|
||||||
|
}
|
||||||
|
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DingBotService) replySteam(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||||
|
msg := content
|
||||||
|
if len(arg) > 0 {
|
||||||
|
msg = fmt.Sprintf(content, arg)
|
||||||
|
}
|
||||||
|
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DingBotService) replyImg(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||||
|
msg := content
|
||||||
|
if len(arg) > 0 {
|
||||||
|
msg = fmt.Sprintf(content, arg)
|
||||||
|
}
|
||||||
|
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DingBotService) replyFile(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||||
|
msg := content
|
||||||
|
if len(arg) > 0 {
|
||||||
|
msg = fmt.Sprintf(content, arg)
|
||||||
|
}
|
||||||
|
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DingBotService) replyMarkdown(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||||
|
msg := content
|
||||||
|
if len(arg) > 0 {
|
||||||
|
msg = fmt.Sprintf(content, arg)
|
||||||
|
}
|
||||||
|
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DingBotService) replyActionCard(ctx context.Context, SessionWebhook string, content string, arg ...string) error {
|
||||||
|
msg := content
|
||||||
|
if len(arg) > 0 {
|
||||||
|
msg = fmt.Sprintf(content, arg)
|
||||||
|
}
|
||||||
|
return d.replier.SimpleReplyText(ctx, SessionWebhook, []byte(msg))
|
||||||
|
}
|
||||||
|
|
@ -6,4 +6,11 @@ import (
|
||||||
"github.com/google/wire"
|
"github.com/google/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService)
|
var ProviderSetServices = wire.NewSet(
|
||||||
|
NewChatService,
|
||||||
|
NewSessionService,
|
||||||
|
gateway.NewGateway,
|
||||||
|
NewTaskService,
|
||||||
|
NewCallbackService,
|
||||||
|
NewDingBotService,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -2,22 +2,18 @@ package tools_bot
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/data/constants"
|
|
||||||
errors "ai_scheduler/internal/data/error"
|
|
||||||
"ai_scheduler/internal/data/impl"
|
"ai_scheduler/internal/data/impl"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/pkg/utils_ollama"
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/log"
|
"context"
|
||||||
)
|
)
|
||||||
|
|
||||||
type BotTool struct {
|
type BotTool struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
llm *utils_ollama.Client
|
llm *utils_ollama.Client
|
||||||
sessionImpl *impl.SessionImpl
|
sessionImpl *impl.SessionImpl
|
||||||
taskMap map[string]string // task_id -> session_id
|
taskMap map[string]string
|
||||||
// zltxOrderAfterSaleTool tools.ZltxOrderAfterSaleTool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBotTool 创建直连天下订单详情工具
|
// NewBotTool 创建直连天下订单详情工具
|
||||||
|
|
@ -27,12 +23,5 @@ func NewBotTool(config *config.Config, llm *utils_ollama.Client, sessionImpl *im
|
||||||
|
|
||||||
// Execute 执行直连天下订单详情查询
|
// Execute 执行直连天下订单详情查询
|
||||||
func (w *BotTool) Execute(ctx context.Context, toolName string, requireData *entitys.RequireData) (err error) {
|
func (w *BotTool) Execute(ctx context.Context, toolName string, requireData *entitys.RequireData) (err error) {
|
||||||
switch toolName {
|
|
||||||
case constants.BotToolsBugOptimizationSubmit:
|
|
||||||
err = w.BugOptimizationSubmit(ctx, requireData)
|
|
||||||
default:
|
|
||||||
log.Errorf("未知的工具类型:%s", toolName)
|
|
||||||
err = errors.ParamErr("未知的工具类型:%s", toolName)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue