Compare commits
3 Commits
a43c19eb48
...
303dd39cb3
Author | SHA1 | Date |
---|---|---|
|
303dd39cb3 | |
|
c91e5519e9 | |
|
99c397aabf |
|
@ -11,7 +11,7 @@ import (
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
||||||
configPath := flag.String("config", "config.yaml", "Path to configuration file")
|
configPath := flag.String("config", "./config/config.yaml", "Path to configuration file")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
bc, err := config.LoadConfig(*configPath)
|
bc, err := config.LoadConfig(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -5,7 +5,7 @@ server:
|
||||||
|
|
||||||
ollama:
|
ollama:
|
||||||
base_url: "http://localhost:11434"
|
base_url: "http://localhost:11434"
|
||||||
model: "deepseek-r1:8b"
|
model: "qwen3:8b"
|
||||||
timeout: "120s"
|
timeout: "120s"
|
||||||
level: "info"
|
level: "info"
|
||||||
format: "json"
|
format: "json"
|
||||||
|
@ -13,7 +13,7 @@ ollama:
|
||||||
sys:
|
sys:
|
||||||
session_len: 3
|
session_len: 3
|
||||||
|
|
||||||
Redis:
|
redis:
|
||||||
host: 47.97.27.195:6379
|
host: 47.97.27.195:6379
|
||||||
type: node
|
type: node
|
||||||
pass: lansexiongdi@666
|
pass: lansexiongdi@666
|
||||||
|
@ -23,6 +23,6 @@ Redis:
|
||||||
maxIdleTime: 30 #每个连接最大空闲时间,如果超过了这个时间会被关闭
|
maxIdleTime: 30 #每个连接最大空闲时间,如果超过了这个时间会被关闭
|
||||||
tls: 30
|
tls: 30
|
||||||
db:
|
db:
|
||||||
DB:
|
db:
|
||||||
driver: mysql
|
driver: mysql
|
||||||
source: transfer:Lsxd@34234QW@tcp(lsxdpolar.rwlb.rds.aliyuncs.com:3306)/transfer?charset=utf8mb4&parseTime=true&
|
source: root:SD###sdf323r343@tcp(121.199.38.107:3306)/sys_ai?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai
|
||||||
|
|
5
go.mod
5
go.mod
|
@ -5,12 +5,12 @@ go 1.24.0
|
||||||
toolchain go1.24.7
|
toolchain go1.24.7
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
gitea.cdlsxd.cn/self-tools/l_request v1.0.8
|
||||||
github.com/emirpasic/gods v1.18.1
|
github.com/emirpasic/gods v1.18.1
|
||||||
github.com/go-kratos/kratos/v2 v2.9.1
|
github.com/go-kratos/kratos/v2 v2.9.1
|
||||||
github.com/gofiber/fiber/v2 v2.52.9
|
github.com/gofiber/fiber/v2 v2.52.9
|
||||||
github.com/gofiber/websocket/v2 v2.2.1
|
github.com/gofiber/websocket/v2 v2.2.1
|
||||||
github.com/google/wire v0.7.0
|
github.com/google/wire v0.7.0
|
||||||
github.com/ollama/ollama v0.11.10
|
|
||||||
github.com/redis/go-redis/v9 v9.14.0
|
github.com/redis/go-redis/v9 v9.14.0
|
||||||
github.com/spf13/viper v1.17.0
|
github.com/spf13/viper v1.17.0
|
||||||
github.com/tmc/langchaingo v0.1.13
|
github.com/tmc/langchaingo v0.1.13
|
||||||
|
@ -31,6 +31,7 @@ require (
|
||||||
github.com/fasthttp/websocket v1.5.3 // indirect
|
github.com/fasthttp/websocket v1.5.3 // indirect
|
||||||
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
||||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||||
|
github.com/google/go-cmp v0.7.0 // indirect
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
|
@ -58,8 +59,8 @@ require (
|
||||||
github.com/valyala/tcplisten v1.0.0 // indirect
|
github.com/valyala/tcplisten v1.0.0 // indirect
|
||||||
go.uber.org/atomic v1.9.0 // indirect
|
go.uber.org/atomic v1.9.0 // indirect
|
||||||
go.uber.org/multierr v1.9.0 // indirect
|
go.uber.org/multierr v1.9.0 // indirect
|
||||||
golang.org/x/crypto v0.36.0 // indirect
|
|
||||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
||||||
|
golang.org/x/net v0.38.0 // indirect
|
||||||
golang.org/x/sys v0.31.0 // indirect
|
golang.org/x/sys v0.31.0 // indirect
|
||||||
golang.org/x/text v0.23.0 // indirect
|
golang.org/x/text v0.23.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect
|
||||||
|
|
8
go.sum
8
go.sum
|
@ -38,6 +38,8 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f
|
||||||
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
||||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||||
|
gitea.cdlsxd.cn/self-tools/l_request v1.0.8 h1:FaKRql9mCVcSoaGqPeBOAruZ52slzRngQ6VRTYKNSsA=
|
||||||
|
gitea.cdlsxd.cn/self-tools/l_request v1.0.8/go.mod h1:Qf4hVXm2Eu5vOvwXk8D7U0q/aekMCkZ4Fg9wnRKlasQ=
|
||||||
gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGqTOXqu2aRi/XEQxDCBwM8yJtE6s=
|
gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGqTOXqu2aRi/XEQxDCBwM8yJtE6s=
|
||||||
gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:EXuID2Zs0pAQhH8yz+DNjUbjppKQzKFAn28TMYPB6IU=
|
gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:EXuID2Zs0pAQhH8yz+DNjUbjppKQzKFAn28TMYPB6IU=
|
||||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||||
|
@ -185,8 +187,6 @@ github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6T
|
||||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||||
github.com/ollama/ollama v0.11.10 h1:J9zaoTPwIXOrYXCRAqI7rV4cJ+FOMuQc/vBqQ5GIdWg=
|
|
||||||
github.com/ollama/ollama v0.11.10/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms=
|
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
@ -265,8 +265,6 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
|
||||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
|
||||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||||
|
@ -398,8 +396,6 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
|
|
||||||
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
|
|
||||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
package biz
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/data/impl"
|
||||||
|
"ai_scheduler/internal/data/model"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChatHistoryBiz struct {
|
||||||
|
chatRepo *impl.ChatImpl
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChatHistoryBiz(chatRepo *impl.ChatImpl) *ChatHistoryBiz {
|
||||||
|
s := &ChatHistoryBiz{
|
||||||
|
chatRepo: chatRepo,
|
||||||
|
}
|
||||||
|
go s.AsyncProcess(context.Background())
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ChatHistoryBiz) create(ctx context.Context, sessionID, role, content string) error {
|
||||||
|
chat := model.AiChatHi{
|
||||||
|
SessionID: sessionID,
|
||||||
|
Role: role,
|
||||||
|
Content: content,
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.chatRepo.Create(&chat)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加会话历史
|
||||||
|
func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error {
|
||||||
|
return s.create(ctx, chat.SessionID, chat.Role.String(), chat.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步添加会话历史
|
||||||
|
func (s *ChatHistoryBiz) AsyncCreate(ctx context.Context, chat entitys.ChatHistory) {
|
||||||
|
s.chatRepo.AsyncCreate(ctx, model.AiChatHi{
|
||||||
|
SessionID: chat.SessionID,
|
||||||
|
Role: chat.Role.String(),
|
||||||
|
Content: chat.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步处理会话历史
|
||||||
|
func (s *ChatHistoryBiz) AsyncProcess(ctx context.Context) {
|
||||||
|
s.chatRepo.AsyncProcess(ctx)
|
||||||
|
}
|
|
@ -2,4 +2,4 @@ package biz
|
||||||
|
|
||||||
import "github.com/google/wire"
|
import "github.com/google/wire"
|
||||||
|
|
||||||
var ProviderSetBiz = wire.NewSet(NewAiRouterBiz, NewSessionBiz)
|
var ProviderSetBiz = wire.NewSet(NewAiRouterBiz, NewSessionBiz, NewChatHistoryBiz)
|
||||||
|
|
|
@ -6,14 +6,15 @@ import (
|
||||||
"ai_scheduler/internal/data/impl"
|
"ai_scheduler/internal/data/impl"
|
||||||
"ai_scheduler/internal/data/model"
|
"ai_scheduler/internal/data/model"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
|
"ai_scheduler/internal/pkg"
|
||||||
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
"ai_scheduler/internal/tools"
|
"ai_scheduler/internal/tools"
|
||||||
"ai_scheduler/tmpl/dataTemp"
|
"ai_scheduler/tmpl/dataTemp"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/tmc/langchaingo/llms"
|
|
||||||
"github.com/tmc/langchaingo/llms/ollama"
|
|
||||||
"log"
|
"log"
|
||||||
"strings"
|
|
||||||
|
"github.com/tmc/langchaingo/llms"
|
||||||
|
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
"xorm.io/builder"
|
"xorm.io/builder"
|
||||||
|
@ -21,23 +22,36 @@ import (
|
||||||
|
|
||||||
// AiRouterService 智能路由服务
|
// AiRouterService 智能路由服务
|
||||||
type AiRouterService struct {
|
type AiRouterService struct {
|
||||||
aiClient entitys.AIClient
|
//aiClient entitys.AIClient
|
||||||
toolManager *tools.Manager
|
toolManager *tools.Manager
|
||||||
sessionImpl *impl.SessionImpl
|
sessionImpl *impl.SessionImpl
|
||||||
sysImpl *impl.SysImpl
|
sysImpl *impl.SysImpl
|
||||||
taskImpl *impl.TaskImpl
|
taskImpl *impl.TaskImpl
|
||||||
|
hisImpl *impl.ChatImpl
|
||||||
conf *config.Config
|
conf *config.Config
|
||||||
|
utilAgent *utils_ollama.UtilOllama
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRouterService 创建路由服务
|
// NewRouterService 创建路由服务
|
||||||
func NewAiRouterBiz(aiClient entitys.AIClient, toolManager *tools.Manager, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, taskImpl *impl.TaskImpl, conf *config.Config) entitys.RouterService {
|
func NewAiRouterBiz(
|
||||||
|
//aiClient entitys.AIClient,
|
||||||
|
toolManager *tools.Manager,
|
||||||
|
sessionImpl *impl.SessionImpl,
|
||||||
|
sysImpl *impl.SysImpl,
|
||||||
|
taskImpl *impl.TaskImpl,
|
||||||
|
hisImpl *impl.ChatImpl,
|
||||||
|
conf *config.Config,
|
||||||
|
utilAgent *utils_ollama.UtilOllama,
|
||||||
|
) entitys.RouterService {
|
||||||
return &AiRouterService{
|
return &AiRouterService{
|
||||||
aiClient: aiClient,
|
//aiClient: aiClient,
|
||||||
toolManager: toolManager,
|
toolManager: toolManager,
|
||||||
sessionImpl: sessionImpl,
|
sessionImpl: sessionImpl,
|
||||||
conf: conf,
|
conf: conf,
|
||||||
sysImpl: sysImpl,
|
sysImpl: sysImpl,
|
||||||
|
hisImpl: hisImpl,
|
||||||
taskImpl: taskImpl,
|
taskImpl: taskImpl,
|
||||||
|
utilAgent: utilAgent,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,48 +64,42 @@ func (r *AiRouterService) Route(ctx context.Context, req *entitys.ChatRequest) (
|
||||||
// Route 执行智能路由
|
// Route 执行智能路由
|
||||||
func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error {
|
func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error {
|
||||||
|
|
||||||
session := c.Headers("x-session", "")
|
session := c.Headers("X-Session", "")
|
||||||
if len(session) == 0 {
|
if len(session) == 0 {
|
||||||
return errors.SessionNotFound
|
return errors.SessionNotFound
|
||||||
}
|
}
|
||||||
auth := c.Headers("x-authorization", "")
|
auth := c.Headers("X-Authorization", "")
|
||||||
|
|
||||||
if len(auth) == 0 {
|
if len(auth) == 0 {
|
||||||
return errors.AuthNotFound
|
return errors.AuthNotFound
|
||||||
}
|
}
|
||||||
key := c.Headers("x-app-key", "")
|
key := c.Headers("X-App-Key", "")
|
||||||
if len(key) == 0 {
|
if len(key) == 0 {
|
||||||
return errors.KeyNotFound
|
return errors.KeyNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
sysInfo, err := r.getSysInfo(session)
|
sysInfo, err := r.getSysInfo(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.SysNotFound
|
return errors.SysNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
history, err := r.getSessionHis(session)
|
history, err := r.getSessionChatHis(session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.SystemError
|
return errors.SystemError
|
||||||
}
|
}
|
||||||
|
|
||||||
taskPrompt, err := r.getTasks(sysInfo.SysID)
|
task, err := r.getTasks(sysInfo.SysID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.SystemError
|
return errors.SystemError
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
toolDefinitions := r.registerTools(task)
|
||||||
messages = make([]entitys.Message, 0)
|
prompt := r.getPrompt(sysInfo, history, req.Text)
|
||||||
|
msg, err := r.utilAgent.Llm.Call(context.TODO(), pkg.JsonStringIgonErr(prompt),
|
||||||
|
llms.WithTools(toolDefinitions),
|
||||||
|
llms.WithToolChoice(llms.FunctionCallBehaviorAuto),
|
||||||
|
llms.WithJSONMode(),
|
||||||
)
|
)
|
||||||
messages = append(messages, entitys.Message{}, entitys.Message{
|
c.WriteMessage(1, []byte(msg))
|
||||||
Role: "system",
|
|
||||||
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
|
|
||||||
}, entitys.Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: r.buildIntentPrompt(history, task),
|
|
||||||
}, entitys.Message{
|
|
||||||
Role: "user",
|
|
||||||
Content: req.Text,
|
|
||||||
})
|
|
||||||
// 构建消息
|
// 构建消息
|
||||||
//messages := []entitys.Message{
|
//messages := []entitys.Message{
|
||||||
// {
|
// {
|
||||||
|
@ -174,14 +182,13 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AiRouterService) getSessionHis(sessionId string) (his []model.AiSession, err error) {
|
func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
|
||||||
|
|
||||||
cond := builder.NewCond()
|
cond := builder.NewCond()
|
||||||
cond = cond.And(builder.Eq{"session_id": sessionId})
|
cond = cond.And(builder.Eq{"session_id": sessionId})
|
||||||
cond = cond.And(builder.IsNull{"delete_at"})
|
|
||||||
cond = cond.And(builder.Eq{"status": 1})
|
|
||||||
|
|
||||||
_, err = r.sessionImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his)
|
_, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -194,22 +201,50 @@ func (r *AiRouterService) getSysInfo(appKey string) (sysInfo model.AiSy, err err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AiRouterService) getTasks(sysId int32) (taskPrompt []llms.FunctionDefinition, err error) {
|
func (r *AiRouterService) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
||||||
var tasks []model.AiTask
|
|
||||||
cond := builder.NewCond()
|
cond := builder.NewCond()
|
||||||
cond = cond.And(builder.Eq{"sys_id": sysId})
|
cond = cond.And(builder.Eq{"sys_id": sysId})
|
||||||
cond = cond.And(builder.IsNull{"delete_at"})
|
cond = cond.And(builder.IsNull{"delete_at"})
|
||||||
cond = cond.And(builder.Eq{"status": 1})
|
cond = cond.And(builder.Eq{"status": 1})
|
||||||
err = r.taskImpl.GetOneBySearchToStrut(&cond, &tasks)
|
_, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks)
|
||||||
taskPrompt = make([]llms.FunctionDefinition, len(tasks))
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool {
|
||||||
|
taskPrompt := make([]llms.Tool, len(tasks))
|
||||||
for k, task := range tasks {
|
for k, task := range tasks {
|
||||||
taskPrompt[k] = llms.FunctionDefinition{
|
var taskConfig entitys.TaskConfig
|
||||||
|
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
taskPrompt[k].Type = "function"
|
||||||
|
taskPrompt[k].Function = &llms.FunctionDefinition{
|
||||||
Name: task.Name,
|
Name: task.Name,
|
||||||
Description: task.Desc,
|
Description: task.Desc,
|
||||||
Parameters: task.Parameters,
|
Parameters: taskConfig.Param,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return taskPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message {
|
||||||
|
var (
|
||||||
|
prompt = make([]entitys.Message, 0)
|
||||||
|
)
|
||||||
|
prompt = append(prompt, entitys.Message{}, entitys.Message{
|
||||||
|
Role: "system",
|
||||||
|
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
|
||||||
|
}, entitys.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
|
||||||
|
}, entitys.Message{
|
||||||
|
Role: "user",
|
||||||
|
Content: reqInput,
|
||||||
|
})
|
||||||
|
return prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildSystemPrompt 构建系统提示词
|
// buildSystemPrompt 构建系统提示词
|
||||||
|
@ -221,28 +256,22 @@ func (r *AiRouterService) buildSystemPrompt(prompt string) string {
|
||||||
return prompt
|
return prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AiRouterService) buildIntentPrompt(his []model.AiSession, task []model.AiTask) string {
|
func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
|
||||||
prompt := `##任务
|
for _, item := range his {
|
||||||
分析用户输入,判断用户的意图类型,没有使用Markdown格式的json格式回复
|
if len(chatHis.SessionId) == 0 {
|
||||||
##意图类型
|
chatHis.SessionId = item.SessionID
|
||||||
1. product_diagnosis - 商品诊断:用户想要查询、诊断或了解商品相关信息
|
}
|
||||||
2. order_diagnosis - 订单诊断:用户想要查询、诊断或了解订单相关信息
|
chatHis.Messages = append(chatHis.Messages, entitys.HisMessage{
|
||||||
3. knowledge_qa - 知识问答:用户想要进行一般性问答或获取知识信息
|
Role: item.Role,
|
||||||
##判断规则
|
Content: item.Content,
|
||||||
1.当用户意图不够清晰且不匹配 knowledge_qa 以外意图时,使用knowledge_qa
|
Timestamp: item.CreateAt.Format("2006-01-02 15:04:05"),
|
||||||
2.当用户意图非常不清晰时使用 unknown
|
})
|
||||||
##格式要求
|
}
|
||||||
1.返回以下格式的JSON:
|
chatHis.Context = entitys.HisContext{
|
||||||
{ "intent": "product_diagnosis" | "order_diagnosis" | "knowledge_qa" | "unknown", "confidence": 0.0-1.0,"reasoning": "判断理由"}
|
UserLanguage: "zh-CN",
|
||||||
2.严格返回字符串格式,禁用markdown格式返回
|
SystemMode: "technical_support",
|
||||||
3.只返回json字符串,不包含任何其他解释性文字
|
}
|
||||||
## 用户当前的问题是:
|
return
|
||||||
{user_input}
|
|
||||||
`
|
|
||||||
|
|
||||||
prompt = strings.ReplaceAll(prompt, "{user_input}", userInput)
|
|
||||||
|
|
||||||
return prompt
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractIntent 从AI响应中提取意图
|
// extractIntent 从AI响应中提取意图
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package biz
|
package biz
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"ai_scheduler/internal/constants"
|
||||||
"ai_scheduler/internal/data/impl"
|
"ai_scheduler/internal/data/impl"
|
||||||
"ai_scheduler/internal/data/model"
|
"ai_scheduler/internal/data/model"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
|
@ -15,6 +16,7 @@ import (
|
||||||
type SessionBiz struct {
|
type SessionBiz struct {
|
||||||
sessionRepo *impl.SessionImpl
|
sessionRepo *impl.SessionImpl
|
||||||
sysRepo *impl.SysImpl
|
sysRepo *impl.SysImpl
|
||||||
|
chatRepo *impl.ChatImpl
|
||||||
|
|
||||||
conf *config.Config
|
conf *config.Config
|
||||||
}
|
}
|
||||||
|
@ -28,14 +30,19 @@ func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitSession 初始化会话 ,当天存在则返回会话,如果不存在则创建一个
|
// InitSession 初始化会话 ,当天存在则返回会话,如果不存在则创建一个
|
||||||
func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRequest) (sessionId string, err error) {
|
func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRequest) (result *entitys.SessionInitResponse, err error) {
|
||||||
|
|
||||||
// 获取系统配置
|
// 获取系统配置
|
||||||
sysConfig, has, err := s.sysRepo.FindOne(s.sysRepo.WithSysId(req.SysId))
|
sysConfig, has, err := s.sysRepo.FindOne(s.sysRepo.WithSysId(req.SysId))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return
|
||||||
} else if !has {
|
} else if !has {
|
||||||
return "", fmt.Errorf("sys not found")
|
err = fmt.Errorf("sys not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result = &entitys.SessionInitResponse{
|
||||||
|
Chat: make([]entitys.ChatHistory, 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取 当天的session
|
// 获取 当天的session
|
||||||
|
@ -46,20 +53,56 @@ func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRe
|
||||||
s.sysRepo.WithSysId(sysConfig.SysID), // 条件:系统ID
|
s.sysRepo.WithSysId(sysConfig.SysID), // 条件:系统ID
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return
|
||||||
} else if !has {
|
} else if !has {
|
||||||
// 不存在,创建一个
|
// 不存在,创建一个
|
||||||
session = model.AiSession{
|
session = model.AiSession{
|
||||||
SysID: sysConfig.SysID,
|
SysID: sysConfig.SysID,
|
||||||
SessionID: utils.UUID(),
|
SessionID: utils.UUID(),
|
||||||
|
UserID: req.UserId,
|
||||||
}
|
}
|
||||||
err = s.sessionRepo.Create(&session)
|
err = s.sessionRepo.Create(&session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
chat := entitys.ChatHistory{
|
||||||
|
SessionID: session.SessionID,
|
||||||
|
Role: constants.RoleSystem,
|
||||||
|
Content: sysConfig.Prologue,
|
||||||
|
}
|
||||||
|
result.Chat = append(result.Chat, chat)
|
||||||
|
|
||||||
|
// 开场白写入会话历史
|
||||||
|
s.chatRepo.AsyncCreate(ctx, model.AiChatHi{
|
||||||
|
SessionID: chat.SessionID,
|
||||||
|
Role: chat.Role.String(),
|
||||||
|
Content: chat.Content,
|
||||||
|
})
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// 存在,返回会话历史
|
||||||
|
var chatList []model.AiChatHi
|
||||||
|
chatList, err = s.chatRepo.FindAll(
|
||||||
|
s.chatRepo.WithSessionId(session.SessionID), // 条件:会话ID
|
||||||
|
s.chatRepo.OrderByDesc("create_at"), // 排序:按创建时间降序
|
||||||
|
s.chatRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为 entitys.ChatHistory 类型
|
||||||
|
for _, chat := range chatList {
|
||||||
|
result.Chat = append(result.Chat, entitys.ChatHistory{
|
||||||
|
SessionID: chat.SessionID,
|
||||||
|
Role: constants.Caller(chat.Role),
|
||||||
|
Content: chat.Content,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return session.SessionID, nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionList 会话列表
|
// SessionList 会话列表
|
||||||
|
|
|
@ -14,8 +14,13 @@ type Config struct {
|
||||||
Sys SysConfig `mapstructure:"sys"`
|
Sys SysConfig `mapstructure:"sys"`
|
||||||
Tools ToolsConfig `mapstructure:"tools"`
|
Tools ToolsConfig `mapstructure:"tools"`
|
||||||
Logging LoggingConfig `mapstructure:"logging"`
|
Logging LoggingConfig `mapstructure:"logging"`
|
||||||
Redis *Redis `protobuf:"bytes,1,opt,name=Redis,proto3" json:"Redis,omitempty"`
|
Redis Redis `mapstructure:"redis"`
|
||||||
DB *DB `protobuf:"bytes,3,opt,name=TransDB,proto3" json:"TransDB,omitempty"`
|
DB DB `mapstructure:"db"`
|
||||||
|
// LLM *LLM `mapstructure:"llm"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LLM struct {
|
||||||
|
Model string `mapstructure:"model"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SysConfig 系统配置
|
// SysConfig 系统配置
|
||||||
|
@ -37,24 +42,24 @@ type OllamaConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Redis struct {
|
type Redis struct {
|
||||||
Host string `protobuf:"bytes,1,opt,name=host,proto3" json:"host,omitempty"`
|
Host string `mapstructure:"host"`
|
||||||
Type string `protobuf:"bytes,2,opt,name=type,proto3" json:"type,omitempty"`
|
Type string `mapstructure:"type"`
|
||||||
Pass string `protobuf:"bytes,3,opt,name=pass,proto3" json:"pass,omitempty"`
|
Pass string `mapstructure:"pass"`
|
||||||
Key string `protobuf:"bytes,4,opt,name=key,proto3" json:"key,omitempty"`
|
Key string `mapstructure:"key"`
|
||||||
Tls int32 `protobuf:"varint,5,opt,name=tls,proto3" json:"tls,omitempty"`
|
Tls int32 `mapstructure:"tls"`
|
||||||
Db int32 `protobuf:"varint,6,opt,name=db,proto3" json:"db,omitempty"`
|
Db int32 `mapstructure:"db"`
|
||||||
MaxIdle int32 `protobuf:"varint,7,opt,name=maxIdle,proto3" json:"maxIdle,omitempty"`
|
MaxIdle int32 `mapstructure:"maxIdle"`
|
||||||
PoolSize int32 `protobuf:"varint,8,opt,name=poolSize,proto3" json:"poolSize,omitempty"`
|
PoolSize int32 `mapstructure:"poolSize"`
|
||||||
MaxIdleTime int32 `protobuf:"varint,9,opt,name=maxIdleTime,proto3" json:"maxIdleTime,omitempty"`
|
MaxIdleTime int32 `mapstructure:"maxIdleTime"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type DB struct {
|
type DB struct {
|
||||||
Driver string `protobuf:"bytes,1,opt,name=driver,proto3" json:"driver,omitempty"`
|
Driver string `mapstructure:"driver"`
|
||||||
Source string `protobuf:"bytes,2,opt,name=source,proto3" json:"source,omitempty"`
|
Source string `mapstructure:"source"`
|
||||||
MaxIdle int32 `protobuf:"varint,3,opt,name=maxIdle,proto3" json:"maxIdle,omitempty"`
|
MaxIdle int32 `mapstructure:"maxIdle"`
|
||||||
MaxOpen int32 `protobuf:"varint,4,opt,name=maxOpen,proto3" json:"maxOpen,omitempty"`
|
MaxOpen int32 `mapstructure:"maxOpen"`
|
||||||
MaxLifetime int32 `protobuf:"varint,5,opt,name=maxLifetime,proto3" json:"maxLifetime,omitempty"`
|
MaxLifetime int32 `mapstructure:"maxLifetime"`
|
||||||
IsDebug bool `protobuf:"varint,6,opt,name=isDebug,proto3" json:"isDebug,omitempty"`
|
IsDebug bool `mapstructure:"isDebug"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToolsConfig 工具配置
|
// ToolsConfig 工具配置
|
||||||
|
|
|
@ -5,6 +5,14 @@ type Caller string
|
||||||
const (
|
const (
|
||||||
CallerZltx Caller = "zltx" // 直连天下
|
CallerZltx Caller = "zltx" // 直连天下
|
||||||
CallerHyt Caller = "hyt" // 货易通
|
CallerHyt Caller = "hyt" // 货易通
|
||||||
|
|
||||||
|
// 角色, 系统角色,用户角色
|
||||||
|
RoleSystem Caller = "system" // 系统角色
|
||||||
|
RoleUser Caller = "user" // 用户角色
|
||||||
|
RoleAssistant Caller = "assistant" // 助手角色
|
||||||
|
|
||||||
|
// 分页默认条数
|
||||||
|
ChatHistoryLimit = 10
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c Caller) String() string {
|
func (c Caller) String() string {
|
||||||
|
|
|
@ -7,3 +7,10 @@ const (
|
||||||
ConnStatusNormal
|
ConnStatusNormal
|
||||||
ConnStatusIgnore
|
ConnStatusIgnore
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type TaskType int32
|
||||||
|
|
||||||
|
const (
|
||||||
|
TaskTypeApi ConnStatus = iota + 1
|
||||||
|
TaskTypeKnowle
|
||||||
|
)
|
||||||
|
|
|
@ -52,6 +52,7 @@ type BaseRepository[P PO] interface {
|
||||||
WithId(id interface{}) CondFunc // 查询id
|
WithId(id interface{}) CondFunc // 查询id
|
||||||
WithStatus(status int) CondFunc // 查询status
|
WithStatus(status int) CondFunc // 查询status
|
||||||
GetDb() *gorm.DB // 获取数据库连接
|
GetDb() *gorm.DB // 获取数据库连接
|
||||||
|
WithLimit(limit int) CondFunc // 限制返回条数
|
||||||
}
|
}
|
||||||
|
|
||||||
// PaginationResult 分页查询结果
|
// PaginationResult 分页查询结果
|
||||||
|
@ -207,3 +208,9 @@ func (this *BaseModel[P]) WithStatus(status int) CondFunc {
|
||||||
func (this *BaseModel[P]) GetDb() *gorm.DB {
|
func (this *BaseModel[P]) GetDb() *gorm.DB {
|
||||||
return this.Db
|
return this.Db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (this *BaseModel[P]) WithLimit(limit int) CondFunc {
|
||||||
|
return func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Limit(limit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
package impl
|
||||||
|
|
||||||
|
//import (
|
||||||
|
// "ai_scheduler/internal/data/model"
|
||||||
|
// "ai_scheduler/tmpl/dataTemp"
|
||||||
|
// "ai_scheduler/utils"
|
||||||
|
//)
|
||||||
|
|
||||||
|
//type ChatHisImpl struct {
|
||||||
|
// dataTemp.DataTemp
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//func NewChatHisImpl(db *utils.Db) *ChatHisImpl {
|
||||||
|
// return &ChatHisImpl{*dataTemp.NewDataTemp(db, new(model.AiChatHi))}
|
||||||
|
//}
|
|
@ -0,0 +1,56 @@
|
||||||
|
package impl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/data/model"
|
||||||
|
"ai_scheduler/tmpl/dataTemp"
|
||||||
|
"ai_scheduler/utils"
|
||||||
|
"context"
|
||||||
|
"github.com/gofiber/fiber/v2/log"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChatImpl struct {
|
||||||
|
dataTemp.DataTemp
|
||||||
|
BaseRepository[model.AiChatHi]
|
||||||
|
chatChannel chan model.AiChatHi
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChatImpl(db *utils.Db) *ChatImpl {
|
||||||
|
return &ChatImpl{
|
||||||
|
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiChatHi)),
|
||||||
|
BaseRepository: NewBaseModel[model.AiChatHi](db.Client),
|
||||||
|
chatChannel: make(chan model.AiChatHi, 100),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSessionId 条件:会话ID
|
||||||
|
func (impl *ChatImpl) WithSessionId(sessionId interface{}) CondFunc {
|
||||||
|
return func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Where("session_id = ?", sessionId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步添加会话历史
|
||||||
|
func (impl *ChatImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) {
|
||||||
|
impl.chatChannel <- chat
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步处理会话历史
|
||||||
|
func (impl *ChatImpl) AsyncProcess(ctx context.Context) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case chat := <-impl.chatChannel:
|
||||||
|
log.Infof("ChatHistoryAsyncProcess chat: %v", chat)
|
||||||
|
if err := impl.Create(&chat); err != nil {
|
||||||
|
log.Errorf("ChatHistoryAsyncProcess err: %v", err)
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Infof("ChatHistoryAsyncProcess ctx done")
|
||||||
|
return
|
||||||
|
// 定时打印通道大小
|
||||||
|
case <-time.After(time.Second * 5):
|
||||||
|
log.Infof("ChatHistoryAsyncProcess channel len: %d", len(impl.chatChannel))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,4 +4,4 @@ import (
|
||||||
"github.com/google/wire"
|
"github.com/google/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl)
|
var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatImpl)
|
||||||
|
|
|
@ -10,13 +10,13 @@ import (
|
||||||
|
|
||||||
type SessionImpl struct {
|
type SessionImpl struct {
|
||||||
dataTemp.DataTemp
|
dataTemp.DataTemp
|
||||||
BaseModel[model.AiSession]
|
BaseRepository[model.AiSession]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSessionImpl(db *utils.Db) *SessionImpl {
|
func NewSessionImpl(db *utils.Db) *SessionImpl {
|
||||||
return &SessionImpl{
|
return &SessionImpl{
|
||||||
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiSession)),
|
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiSession)),
|
||||||
BaseModel: BaseModel[model.AiSession]{},
|
BaseRepository: NewBaseModel[model.AiSession](db.Client),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,3 +40,9 @@ func (s *SessionImpl) WithSysId(sysId interface{}) CondFunc {
|
||||||
return db.Where("sys_id = ?", sysId)
|
return db.Where("sys_id = ?", sysId)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (impl *SessionImpl) WithSessionId(sessionId interface{}) CondFunc {
|
||||||
|
return func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Where("session_id = ?", sessionId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ const TableNameAiChatHi = "ai_chat_his"
|
||||||
type AiChatHi struct {
|
type AiChatHi struct {
|
||||||
HisID int64 `gorm:"column:his_id;primaryKey" json:"his_id"`
|
HisID int64 `gorm:"column:his_id;primaryKey" json:"his_id"`
|
||||||
SessionID string `gorm:"column:session_id;not null" json:"session_id"`
|
SessionID string `gorm:"column:session_id;not null" json:"session_id"`
|
||||||
Role string `gorm:"column:role;not null" json:"role"`
|
Role string `gorm:"column:role;not null;comment:system系统输出,assistant助手输出,user用户输入" json:"role"` // system系统输出,assistant助手输出,user用户输入
|
||||||
Content string `gorm:"column:content;not null" json:"content"`
|
Content string `gorm:"column:content;not null" json:"content"`
|
||||||
CreateAt *time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
|
CreateAt *time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
package entitys
|
||||||
|
|
||||||
|
import "ai_scheduler/internal/constants"
|
||||||
|
|
||||||
|
type ChatHistory struct {
|
||||||
|
SessionID string `json:"session_id"`
|
||||||
|
Role constants.Caller `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
|
@ -5,6 +5,10 @@ type SessionInitRequest struct {
|
||||||
UserId string `json:"user_id"`
|
UserId string `json:"user_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SessionInitResponse struct {
|
||||||
|
Chat []ChatHistory `json:"chat"`
|
||||||
|
}
|
||||||
|
|
||||||
type SessionListRequest struct {
|
type SessionListRequest struct {
|
||||||
SysId string `json:"sys_id"`
|
SysId string `json:"sys_id"`
|
||||||
UserId string `json:"user_id"`
|
UserId string `json:"user_id"`
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -79,8 +80,29 @@ type Message struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Func struct {
|
type FuncApi struct {
|
||||||
Parameters struct{} `json:"parameters"`
|
Param interface{} `json:"param"`
|
||||||
|
Request l_request.Request `json:"request"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TaskConfig struct {
|
||||||
|
Param interface{} `json:"param"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatHis struct {
|
||||||
|
SessionId string `json:"session_id"`
|
||||||
|
Messages []HisMessage `json:"messages"`
|
||||||
|
Context HisContext `json:"context"`
|
||||||
|
}
|
||||||
|
type HisMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Timestamp string `json:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HisContext struct {
|
||||||
|
UserLanguage string `json:"user_language"`
|
||||||
|
SystemMode string `json:"system_mode"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RouterService 路由服务接口
|
// RouterService 路由服务接口
|
||||||
|
|
|
@ -1 +1,8 @@
|
||||||
package pkg
|
package pkg
|
||||||
|
|
||||||
|
import "encoding/json"
|
||||||
|
|
||||||
|
func JsonStringIgonErr(data interface{}) string {
|
||||||
|
dataByte, _ := json.Marshal(data)
|
||||||
|
return string(dataByte)
|
||||||
|
}
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
package pkg
|
package pkg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/pkg/ollama"
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
|
|
||||||
"github.com/google/wire"
|
"github.com/google/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ProviderSetClient = wire.NewSet(
|
var ProviderSetClient = wire.NewSet(
|
||||||
NewRdb,
|
NewRdb,
|
||||||
NewGormDb,
|
NewGormDb,
|
||||||
ollama.NewClient,
|
utils_ollama.NewUtilOllama,
|
||||||
)
|
)
|
||||||
|
|
|
@ -32,7 +32,7 @@ func NewRdb(c *config.Config) *Rdb {
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildRdb 构建redis client
|
// buildRdb 构建redis client
|
||||||
func buildRdb(c *config.Redis) *redis.Client {
|
func buildRdb(c config.Redis) *redis.Client {
|
||||||
|
|
||||||
rdb := redis.NewClient(&redis.Options{
|
rdb := redis.NewClient(&redis.Options{
|
||||||
Addr: c.Host,
|
Addr: c.Host,
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func DBConn(c *config.DB) (*gorm.DB, func()) {
|
func DBConn(c config.DB) (*gorm.DB, func()) {
|
||||||
mysqlConn, err := sql.Open(c.Driver, c.Source)
|
mysqlConn, err := sql.Open(c.Driver, c.Source)
|
||||||
gormDB, err := gorm.Open(
|
gormDB, err := gorm.Open(
|
||||||
mysql.New(mysql.Config{Conn: mysqlConn}),
|
mysql.New(mysql.Config{Conn: mysqlConn}),
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package ollama
|
package utils_ollama
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
|
@ -9,6 +9,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/tmc/langchaingo/llms/ollama"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client Ollama客户端适配器
|
// Client Ollama客户端适配器
|
|
@ -0,0 +1,47 @@
|
||||||
|
package utils_ollama
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2/log"
|
||||||
|
"github.com/tmc/langchaingo/llms/ollama"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UtilOllama struct {
|
||||||
|
Llm *ollama.LLM
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
|
||||||
|
llm, err := ollama.New(
|
||||||
|
ollama.WithModel(c.Ollama.Model),
|
||||||
|
ollama.WithHTTPClient(http.DefaultClient),
|
||||||
|
ollama.WithServerURL(getUrl(c)),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal(err)
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UtilOllama{
|
||||||
|
Llm: llm,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//func (o *UtilOllama) a() {
|
||||||
|
// var agent agents.Agent
|
||||||
|
// agent = agents.NewOneShotAgent(llm, tools, opts...)
|
||||||
|
//
|
||||||
|
// agents.NewExecutor()
|
||||||
|
//}
|
||||||
|
|
||||||
|
func getUrl(c *config.Config) string {
|
||||||
|
baseURL := c.Ollama.BaseURL
|
||||||
|
envURL := os.Getenv("OLLAMA_BASE_URL")
|
||||||
|
if envURL != "" {
|
||||||
|
baseURL = envURL
|
||||||
|
}
|
||||||
|
return baseURL
|
||||||
|
}
|
|
@ -65,7 +65,7 @@ func routerHttp(app *fiber.App, sessionService *services.SessionService) {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Post("/session/init", sessionService.SessionInit)
|
r.Post("/session/init", sessionService.SessionInit) // 会话初始化,不存在则创建,存在则返回会话ID和默认条数会话历史
|
||||||
r.Post("/session/list", sessionService.SessionList)
|
r.Post("/session/list", sessionService.SessionList)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,11 +8,13 @@ import (
|
||||||
|
|
||||||
type SessionService struct {
|
type SessionService struct {
|
||||||
sessionBiz *biz.SessionBiz
|
sessionBiz *biz.SessionBiz
|
||||||
|
chatBiz *biz.ChatHistoryBiz
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSessionService(sessionBiz *biz.SessionBiz) *SessionService {
|
func NewSessionService(sessionBiz *biz.SessionBiz, chatBiz *biz.ChatHistoryBiz) *SessionService {
|
||||||
return &SessionService{
|
return &SessionService{
|
||||||
sessionBiz: sessionBiz,
|
sessionBiz: sessionBiz,
|
||||||
|
chatBiz: chatBiz,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,11 +25,9 @@ func (s *SessionService) SessionInit(c *fiber.Ctx) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionId, err := s.sessionBiz.SessionInit(c.Context(), req)
|
result, err := s.sessionBiz.SessionInit(c.Context(), req)
|
||||||
|
|
||||||
return handRes(c, err, fiber.Map{
|
return handRes(c, err, result)
|
||||||
"session_id": sessionId,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionList 获取会话列表
|
// SessionList 获取会话列表
|
||||||
|
|
|
@ -0,0 +1,14 @@
|
||||||
|
package test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_task(t *testing.T) {
|
||||||
|
var c entitys.TaskConfig
|
||||||
|
config := `{"param": {"type": "object", "required": ["number"], "properties": {"number": {"type": "string", "description": "订单编号/流水号"}}}, "request": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}`
|
||||||
|
err := json.Unmarshal([]byte(config), &c)
|
||||||
|
t.Log(err)
|
||||||
|
}
|
|
@ -5,6 +5,8 @@ import (
|
||||||
"ai_scheduler/utils"
|
"ai_scheduler/utils"
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
"github.com/go-kratos/kratos/v2/log"
|
"github.com/go-kratos/kratos/v2/log"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
@ -103,7 +105,52 @@ func (k DataTemp) GetOneBySearchToStrut(cond *builder.Cond, result interface{})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k DataTemp) GetListToStruct(cond *builder.Cond, pageBoIn *ReqPageBo, result []interface{}) (pageBoOut *RespPageBo, err error) {
|
func (k DataTemp) GetListToStruct(cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}) (pageBoOut *RespPageBo, err error) {
|
||||||
|
var (
|
||||||
|
query, _ = builder.ToBoundSQL(*cond)
|
||||||
|
model = k.Db.Model(k.Model).Where(query)
|
||||||
|
total int64
|
||||||
|
)
|
||||||
|
|
||||||
|
// 1. 计算总数
|
||||||
|
if err = model.Count(&total).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 设置分页
|
||||||
|
pageBoOut = pageBoOut.SetDataByReq(total, pageBoIn)
|
||||||
|
|
||||||
|
// 3. 查询数据(确保 result 是指针,如 &[]User)
|
||||||
|
if err = model.Limit(pageBoIn.GetSize()).
|
||||||
|
Offset(pageBoIn.GetOffset()).
|
||||||
|
Order("updated_at desc").
|
||||||
|
Find(result).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 可选:用 reflect 处理结果(确保类型正确)
|
||||||
|
val := reflect.ValueOf(result)
|
||||||
|
if val.Kind() != reflect.Ptr {
|
||||||
|
return nil, fmt.Errorf("result must be a pointer")
|
||||||
|
}
|
||||||
|
|
||||||
|
val = val.Elem() // 解引用
|
||||||
|
if val.Kind() == reflect.Slice {
|
||||||
|
for i := 0; i < val.Len(); i++ {
|
||||||
|
elem := val.Index(i)
|
||||||
|
if elem.Kind() == reflect.Struct {
|
||||||
|
// 示例:打印某个字段
|
||||||
|
if field := elem.FieldByName("ID"); field.IsValid() {
|
||||||
|
fmt.Println("ID:", field.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pageBoOut, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k DataTemp) GetListToStruct2(cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}) (pageBoOut *RespPageBo, err error) {
|
||||||
var (
|
var (
|
||||||
query, _ = builder.ToBoundSQL(*cond)
|
query, _ = builder.ToBoundSQL(*cond)
|
||||||
model = k.Db.Model(k.Model).Where(query)
|
model = k.Db.Model(k.Model).Where(query)
|
||||||
|
|
Loading…
Reference in New Issue