Compare commits
34 Commits
ecaf2635b5
...
aa86882f7d
| Author | SHA1 | Date |
|---|---|---|
|
|
aa86882f7d | |
|
|
e82ecaf9f7 | |
|
|
3818b3f6cf | |
|
|
a556a18d1a | |
|
|
f11be8c1c0 | |
|
|
ad58547eed | |
|
|
4b5f7355b6 | |
|
|
50a5193922 | |
|
|
c07b57a705 | |
|
|
b996f62098 | |
|
|
dc85df0add | |
|
|
8811ea1cff | |
|
|
c130900ac6 | |
|
|
3b91505a12 | |
|
|
2da5e15a99 | |
|
|
ff81cddc46 | |
|
|
3f37427181 | |
|
|
1412554661 | |
|
|
570e13e527 | |
|
|
32866d59a1 | |
|
|
b81c9ef137 | |
|
|
0ce09611de | |
|
|
03f6130a6e | |
|
|
88338b2889 | |
|
|
addcf2d24d | |
|
|
8a2411016c | |
|
|
068563b914 | |
|
|
303dd39cb3 | |
|
|
c91e5519e9 | |
|
|
99c397aabf | |
|
|
a43c19eb48 | |
|
|
11241af84c | |
|
|
904b67a608 | |
|
|
949f80d417 |
|
|
@ -10,18 +10,21 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
configPath := flag.String("config", "./config/config.yaml", "Path to configuration file")
|
||||||
configPath := flag.String("config", "config.yaml", "Path to configuration file")
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
bc, err := config.LoadConfig(*configPath)
|
bc, err := config.LoadConfig(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("加载配置失败: %v", err)
|
log.Fatalf("加载配置失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app, cleanup, err := InitializeApp(bc, log.DefaultLogger())
|
app, cleanup, err := InitializeApp(bc, log.DefaultLogger())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("项目初始化失败: %v", err)
|
log.Fatalf("项目初始化失败: %v", err)
|
||||||
}
|
}
|
||||||
defer cleanup()
|
defer func() {
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
}()
|
||||||
|
|
||||||
log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port)))
|
log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port)))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,15 +5,17 @@ server:
|
||||||
|
|
||||||
ollama:
|
ollama:
|
||||||
base_url: "http://localhost:11434"
|
base_url: "http://localhost:11434"
|
||||||
model: "deepseek-r1:8b"
|
model: "qwen3:8b"
|
||||||
timeout: "120s"
|
timeout: "120s"
|
||||||
level: "info"
|
level: "info"
|
||||||
format: "json"
|
format: "json"
|
||||||
|
|
||||||
sys:
|
sys:
|
||||||
session_len: 3
|
session_len: 3
|
||||||
|
channel_pool_len: 100
|
||||||
|
channel_pool_size: 32
|
||||||
|
|
||||||
Redis:
|
redis:
|
||||||
host: 47.97.27.195:6379
|
host: 47.97.27.195:6379
|
||||||
type: node
|
type: node
|
||||||
pass: lansexiongdi@666
|
pass: lansexiongdi@666
|
||||||
|
|
@ -23,6 +25,28 @@ Redis:
|
||||||
maxIdleTime: 30 #每个连接最大空闲时间,如果超过了这个时间会被关闭
|
maxIdleTime: 30 #每个连接最大空闲时间,如果超过了这个时间会被关闭
|
||||||
tls: 30
|
tls: 30
|
||||||
db:
|
db:
|
||||||
DB:
|
db:
|
||||||
driver: mysql
|
driver: mysql
|
||||||
source: transfer:Lsxd@34234QW@tcp(lsxdpolar.rwlb.rds.aliyuncs.com:3306)/transfer?charset=utf8mb4&parseTime=true&
|
source: root:SD###sdf323r343@tcp(121.199.38.107:3306)/sys_ai?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai
|
||||||
|
|
||||||
|
tools:
|
||||||
|
zltxOrderDetail:
|
||||||
|
enabled: true
|
||||||
|
base_url: "https://gateway.dev.cdlsxd.cn/"
|
||||||
|
api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU4MDkxOTU4LCJuYmYiOjE3NTgwOTAxNTgsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.Bjsx9f8yfcrV9EWxb0n6POwnXVOq9XPRD78JFZnnf1_VAVMN78W4W570SZL27PWuDnkD7E4oUg6RzeZwZgl7BZrNpNr-a-QpNC5qCptqrqXeNfVStmX7pxWA8GqnzI8ybkZgbhQ58Gje7DzdJtBq_8zte_LDaYhTYXdIc5EAG0AbCzAk22nPTl47nkMeHtmisXQVLEsdibl1hW3ViFJlXwfXvUrOENItmL1_mRYkggUB0MaTu2nHJOYM6PaOVGLHx-74eepnmK2rm6konFEb6ed-Ukc6gVR-nM9yWZaYLYNGNKJLwZoCX3tRuerq74n4kzQgWmUEJeaVI1yIGSw1zw"
|
||||||
|
zltxOrderDirectLog:
|
||||||
|
enabled: true
|
||||||
|
base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/direct/log/"
|
||||||
|
api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU2MTgyNTM1LCJuYmYiOjE3NTYxODA3MzUsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.N1xv1PYbcO8_jR5adaczc16YzGsr4z101gwEZdulkRaREBJNYTOnFrvRxTFx3RJTooXsqTqroE1MR84v_1WPX6BS6kKonA-kC1Jgot6yrt5rFWhGNGb2Cpr9rKIFCCQYmiGd3AUgDazEeaQ0_sodv3E-EXg9VfE1SX8nMcck9Yjnc8NCy7RTWaBIaSeOdZcEl-JfCD0S6GSx3oErp_hk-U9FKGwf60wAuDGTY1R0BP4BYpcEqS-C2LSnsSGyURi54Cuk5xH8r1WuF0Dm5bwAj5d7Hvs77-N_sUF-C5ONqyZJRAEhYLgcmN9RX_WQZfizdQJxizlTczdpzYfy-v-1eQ"
|
||||||
|
zltxProduct:
|
||||||
|
enabled: true
|
||||||
|
base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/oursProduct"
|
||||||
|
add_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/platformProduct/"
|
||||||
|
api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU2MTgyNTM1LCJuYmYiOjE3NTYxODA3MzUsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.N1xv1PYbcO8_jR5adaczc16YzGsr4z101gwEZdulkRaREBJNYTOnFrvRxTFx3RJTooXsqTqroE1MR84v_1WPX6BS6kKonA-kC1Jgot6yrt5rFWhGNGb2Cpr9rKIFCCQYmiGd3AUgDazEeaQ0_sodv3E-EXg9VfE1SX8nMcck9Yjnc8NCy7RTWaBIaSeOdZcEl-JfCD0S6GSx3oErp_hk-U9FKGwf60wAuDGTY1R0BP4BYpcEqS-C2LSnsSGyURi54Cuk5xH8r1WuF0Dm5bwAj5d7Hvs77-N_sUF-C5ONqyZJRAEhYLgcmN9RX_WQZfizdQJxizlTczdpzYfy-v-1eQ"
|
||||||
|
zltxOrderStatistics:
|
||||||
|
base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/direct/ai/search/"
|
||||||
|
enabled: true
|
||||||
|
api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU2MTgyNTM1LCJuYmYiOjE3NTYxODA3MzUsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.N1xv1PYbcO8_jR5adaczc16YzGsr4z101gwEZdulkRaREBJNYTOnFrvRxTFx3RJTooXsqTqroE1MR84v_1WPX6BS6kKonA-kC1Jgot6yrt5rFWhGNGb2Cpr9rKIFCCQYmiGd3AUgDazEeaQ0_sodv3E-EXg9VfE1SX8nMcck9Yjnc8NCy7RTWaBIaSeOdZcEl-JfCD0S6GSx3oErp_hk-U9FKGwf60wAuDGTY1R0BP4BYpcEqS-C2LSnsSGyURi54Cuk5xH8r1WuF0Dm5bwAj5d7Hvs77-N_sUF-C5ONqyZJRAEhYLgcmN9RX_WQZfizdQJxizlTczdpzYfy-v-1eQ"
|
||||||
|
knowledge:
|
||||||
|
base_url: "http://117.175.169.61:10000"
|
||||||
|
enabled: true
|
||||||
|
|
|
||||||
11
go.mod
11
go.mod
|
|
@ -5,15 +5,17 @@ go 1.24.0
|
||||||
toolchain go1.24.7
|
toolchain go1.24.7
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
gitea.cdlsxd.cn/self-tools/l_request v1.0.8
|
||||||
github.com/emirpasic/gods v1.18.1
|
github.com/emirpasic/gods v1.18.1
|
||||||
github.com/go-kratos/kratos/v2 v2.9.1
|
github.com/go-kratos/kratos/v2 v2.9.1
|
||||||
github.com/gofiber/fiber/v2 v2.52.9
|
github.com/gofiber/fiber/v2 v2.52.9
|
||||||
github.com/gofiber/websocket/v2 v2.2.1
|
github.com/gofiber/websocket/v2 v2.2.1
|
||||||
github.com/google/wire v0.7.0
|
github.com/google/wire v0.7.0
|
||||||
github.com/ollama/ollama v0.11.10
|
github.com/ollama/ollama v0.12.0
|
||||||
github.com/redis/go-redis/v9 v9.14.0
|
github.com/redis/go-redis/v9 v9.14.0
|
||||||
github.com/spf13/viper v1.17.0
|
github.com/spf13/viper v1.17.0
|
||||||
google.golang.org/grpc v1.61.1
|
github.com/tmc/langchaingo v0.1.13
|
||||||
|
google.golang.org/grpc v1.64.0
|
||||||
google.golang.org/protobuf v1.34.1
|
google.golang.org/protobuf v1.34.1
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
gorm.io/driver/mysql v1.6.0
|
gorm.io/driver/mysql v1.6.0
|
||||||
|
|
@ -26,10 +28,10 @@ require (
|
||||||
github.com/andybalholm/brotli v1.1.0 // indirect
|
github.com/andybalholm/brotli v1.1.0 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
|
github.com/dlclark/regexp2 v1.11.4 // indirect
|
||||||
github.com/fasthttp/websocket v1.5.3 // indirect
|
github.com/fasthttp/websocket v1.5.3 // indirect
|
||||||
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
||||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||||
github.com/golang/protobuf v1.5.4 // indirect
|
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
|
|
@ -41,6 +43,7 @@ require (
|
||||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
|
github.com/pkoukk/tiktoken-go v0.1.6 // indirect
|
||||||
github.com/rivo/uniseg v0.2.0 // indirect
|
github.com/rivo/uniseg v0.2.0 // indirect
|
||||||
github.com/sagikazarmark/locafero v0.3.0 // indirect
|
github.com/sagikazarmark/locafero v0.3.0 // indirect
|
||||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||||
|
|
@ -60,6 +63,6 @@ require (
|
||||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
||||||
golang.org/x/sys v0.31.0 // indirect
|
golang.org/x/sys v0.31.0 // indirect
|
||||||
golang.org/x/text v0.23.0 // indirect
|
golang.org/x/text v0.23.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240102182953-50ed04b92917 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect
|
||||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
|
||||||
26
go.sum
26
go.sum
|
|
@ -38,6 +38,8 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f
|
||||||
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
||||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||||
|
gitea.cdlsxd.cn/self-tools/l_request v1.0.8 h1:FaKRql9mCVcSoaGqPeBOAruZ52slzRngQ6VRTYKNSsA=
|
||||||
|
gitea.cdlsxd.cn/self-tools/l_request v1.0.8/go.mod h1:Qf4hVXm2Eu5vOvwXk8D7U0q/aekMCkZ4Fg9wnRKlasQ=
|
||||||
gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGqTOXqu2aRi/XEQxDCBwM8yJtE6s=
|
gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGqTOXqu2aRi/XEQxDCBwM8yJtE6s=
|
||||||
gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:EXuID2Zs0pAQhH8yz+DNjUbjppKQzKFAn28TMYPB6IU=
|
gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:EXuID2Zs0pAQhH8yz+DNjUbjppKQzKFAn28TMYPB6IU=
|
||||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||||
|
|
@ -64,6 +66,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
|
github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
|
||||||
|
github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||||
|
|
@ -114,8 +118,6 @@ github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvq
|
||||||
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
|
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
|
||||||
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
|
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
|
||||||
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
|
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
|
||||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
|
||||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
|
||||||
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||||
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||||
|
|
@ -185,12 +187,14 @@ github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6T
|
||||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||||
github.com/ollama/ollama v0.11.10 h1:J9zaoTPwIXOrYXCRAqI7rV4cJ+FOMuQc/vBqQ5GIdWg=
|
github.com/ollama/ollama v0.12.0 h1:BRry7G2Skz7Mu+E6rz40tzBXNbLTEhheGT8umc1zvxo=
|
||||||
github.com/ollama/ollama v0.11.10/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms=
|
github.com/ollama/ollama v0.12.0/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
|
github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
|
||||||
|
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
|
||||||
|
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
|
@ -234,6 +238,8 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
|
||||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||||
|
github.com/tmc/langchaingo v0.1.13 h1:rcpMWBIi2y3B90XxfE4Ao8dhCQPVDMaNPnN5cGB1CaA=
|
||||||
|
github.com/tmc/langchaingo v0.1.13/go.mod h1:vpQ5NOIhpzxDfTZK9B6tf2GM/MoaHewPWM5KXXGh7hg=
|
||||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||||
github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA=
|
github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA=
|
||||||
|
|
@ -522,8 +528,8 @@ google.golang.org/genproto v0.0.0-20201210142538-e3217bee35cc/go.mod h1:FWY/as6D
|
||||||
google.golang.org/genproto v0.0.0-20201214200347-8c77b98c765d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
google.golang.org/genproto v0.0.0-20201214200347-8c77b98c765d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
||||||
google.golang.org/genproto v0.0.0-20210108203827-ffc7fda8c3d7/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
google.golang.org/genproto v0.0.0-20210108203827-ffc7fda8c3d7/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
||||||
google.golang.org/genproto v0.0.0-20210226172003-ab064af71705/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
google.golang.org/genproto v0.0.0-20210226172003-ab064af71705/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240102182953-50ed04b92917 h1:6G8oQ016D88m1xAKljMlBOOGWDZkes4kMhgGFlf8WcQ=
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 h1:1GBuWVLM/KMVUv1t1En5Gs+gFZCNd360GGb4sSxtrhU=
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240102182953-50ed04b92917/go.mod h1:xtjpI3tXFPP051KaWnhvxkiubL/6dJ18vLVf7q2pTOU=
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0=
|
||||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||||
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
|
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
|
||||||
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
|
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
|
||||||
|
|
@ -540,8 +546,8 @@ google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM
|
||||||
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
|
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
|
||||||
google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA51WJ8=
|
google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA51WJ8=
|
||||||
google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
|
google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
|
||||||
google.golang.org/grpc v1.61.1 h1:kLAiWrZs7YeDM6MumDe7m3y4aM6wacLzM1Y/wiLP9XY=
|
google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY=
|
||||||
google.golang.org/grpc v1.61.1/go.mod h1:VUbo7IFqmF1QtCAstipjG0GIoq49KvMe9+h1jFLBNJs=
|
google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg=
|
||||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||||
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
||||||
|
|
@ -562,6 +568,8 @@ gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||||
|
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|
@ -579,5 +587,7 @@ honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9
|
||||||
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
|
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
|
||||||
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
|
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
|
||||||
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
|
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
|
||||||
|
sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo=
|
||||||
|
sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8=
|
||||||
xorm.io/builder v0.3.13 h1:a3jmiVVL19psGeXx8GIurTp7p0IIgqeDmwhcR6BAOAo=
|
xorm.io/builder v0.3.13 h1:a3jmiVVL19psGeXx8GIurTp7p0IIgqeDmwhcR6BAOAo=
|
||||||
xorm.io/builder v0.3.13/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE=
|
xorm.io/builder v0.3.13/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE=
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,49 @@
|
||||||
|
package biz
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/data/impl"
|
||||||
|
"ai_scheduler/internal/data/model"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChatHistoryBiz struct {
|
||||||
|
chatRepo *impl.ChatImpl
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChatHistoryBiz(chatRepo *impl.ChatImpl) *ChatHistoryBiz {
|
||||||
|
s := &ChatHistoryBiz{
|
||||||
|
chatRepo: chatRepo,
|
||||||
|
}
|
||||||
|
go s.AsyncProcess(context.Background())
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ChatHistoryBiz) create(ctx context.Context, sessionID, role, content string) error {
|
||||||
|
chat := model.AiChatHi{
|
||||||
|
SessionID: sessionID,
|
||||||
|
Role: role,
|
||||||
|
Content: content,
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.chatRepo.Create(&chat)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加会话历史
|
||||||
|
func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error {
|
||||||
|
return s.create(ctx, chat.SessionID, chat.Role.String(), chat.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步添加会话历史
|
||||||
|
func (s *ChatHistoryBiz) AsyncCreate(ctx context.Context, chat entitys.ChatHistory) {
|
||||||
|
s.chatRepo.AsyncCreate(ctx, model.AiChatHi{
|
||||||
|
SessionID: chat.SessionID,
|
||||||
|
Role: chat.Role.String(),
|
||||||
|
Content: chat.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步处理会话历史
|
||||||
|
func (s *ChatHistoryBiz) AsyncProcess(ctx context.Context) {
|
||||||
|
s.chatRepo.AsyncProcess(ctx)
|
||||||
|
}
|
||||||
|
|
@ -2,4 +2,4 @@ package biz
|
||||||
|
|
||||||
import "github.com/google/wire"
|
import "github.com/google/wire"
|
||||||
|
|
||||||
var ProviderSetBiz = wire.NewSet(NewAiRouterBiz)
|
var ProviderSetBiz = wire.NewSet(NewAiRouterBiz, NewSessionBiz, NewChatHistoryBiz)
|
||||||
|
|
|
||||||
|
|
@ -2,179 +2,438 @@ package biz
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/data/constants"
|
||||||
errors "ai_scheduler/internal/data/error"
|
errors "ai_scheduler/internal/data/error"
|
||||||
"ai_scheduler/internal/data/impl"
|
"ai_scheduler/internal/data/impl"
|
||||||
"ai_scheduler/internal/data/model"
|
"ai_scheduler/internal/data/model"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/tools"
|
"ai_scheduler/internal/pkg"
|
||||||
|
"ai_scheduler/internal/pkg/mapstructure"
|
||||||
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
|
tools "ai_scheduler/internal/tools"
|
||||||
"ai_scheduler/tmpl/dataTemp"
|
"ai_scheduler/tmpl/dataTemp"
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
|
"github.com/gofiber/fiber/v2/log"
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
|
"github.com/tmc/langchaingo/llms"
|
||||||
"xorm.io/builder"
|
"xorm.io/builder"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AiRouterService 智能路由服务
|
// AiRouterBiz 智能路由服务
|
||||||
type AiRouterService struct {
|
type AiRouterBiz struct {
|
||||||
aiClient entitys.AIClient
|
//aiClient entitys.AIClient
|
||||||
toolManager *tools.Manager
|
toolManager *tools.Manager
|
||||||
sessionImpl *impl.SessionImpl
|
sessionImpl *impl.SessionImpl
|
||||||
|
sysImpl *impl.SysImpl
|
||||||
|
taskImpl *impl.TaskImpl
|
||||||
|
hisImpl *impl.ChatImpl
|
||||||
conf *config.Config
|
conf *config.Config
|
||||||
|
utilAgent *utils_ollama.UtilOllama
|
||||||
|
channelPool *pkg.SafeChannelPool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRouterService 创建路由服务
|
// NewRouterService 创建路由服务
|
||||||
func NewAiRouterBiz(aiClient entitys.AIClient, toolManager *tools.Manager, sessionImpl *impl.SessionImpl, conf *config.Config) entitys.RouterService {
|
func NewAiRouterBiz(
|
||||||
return &AiRouterService{
|
//aiClient entitys.AIClient,
|
||||||
aiClient: aiClient,
|
toolManager *tools.Manager,
|
||||||
|
sessionImpl *impl.SessionImpl,
|
||||||
|
sysImpl *impl.SysImpl,
|
||||||
|
taskImpl *impl.TaskImpl,
|
||||||
|
hisImpl *impl.ChatImpl,
|
||||||
|
conf *config.Config,
|
||||||
|
utilAgent *utils_ollama.UtilOllama,
|
||||||
|
channelPool *pkg.SafeChannelPool,
|
||||||
|
) *AiRouterBiz {
|
||||||
|
return &AiRouterBiz{
|
||||||
|
//aiClient: aiClient,
|
||||||
toolManager: toolManager,
|
toolManager: toolManager,
|
||||||
sessionImpl: sessionImpl,
|
sessionImpl: sessionImpl,
|
||||||
conf: conf,
|
conf: conf,
|
||||||
|
sysImpl: sysImpl,
|
||||||
|
hisImpl: hisImpl,
|
||||||
|
taskImpl: taskImpl,
|
||||||
|
utilAgent: utilAgent,
|
||||||
|
channelPool: channelPool,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Route 执行智能路由
|
// Route 执行智能路由
|
||||||
func (r *AiRouterService) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) {
|
func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) {
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Route 执行智能路由
|
// Route 执行智能路由
|
||||||
func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error {
|
func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) {
|
||||||
session := c.Headers("x-session", "")
|
ch, err := r.channelPool.Get()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
entitys.MsgSend(c, entitys.ResponseData{
|
||||||
|
Done: false,
|
||||||
|
Content: err.Error(),
|
||||||
|
Type: entitys.ResponseErr,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
entitys.MsgSend(c, entitys.ResponseData{
|
||||||
|
Done: true,
|
||||||
|
Content: "",
|
||||||
|
Type: entitys.ResponseEnd,
|
||||||
|
})
|
||||||
|
err = r.channelPool.Put(ch)
|
||||||
|
if err != nil {
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
session := c.Headers("X-Session", "")
|
||||||
if len(session) == 0 {
|
if len(session) == 0 {
|
||||||
return errors.SessionNotFound
|
return errors.SessionNotFound
|
||||||
}
|
}
|
||||||
auth := c.Headers("x-authorization", "")
|
auth := c.Headers("X-Authorization", "")
|
||||||
|
|
||||||
if len(auth) == 0 {
|
if len(auth) == 0 {
|
||||||
return errors.AuthNotFound
|
return errors.AuthNotFound
|
||||||
}
|
}
|
||||||
key := c.Headers("x-app-key", "")
|
key := c.Headers("X-App-Key", "")
|
||||||
if len(key) == 0 {
|
if len(key) == 0 {
|
||||||
return errors.KeyNotFound
|
return errors.KeyNotFound
|
||||||
}
|
}
|
||||||
var sysInfo model.AiSy
|
|
||||||
cond := builder.NewCond()
|
sysInfo, err := r.getSysInfo(key)
|
||||||
cond = cond.And(builder.Eq{"app_key": key})
|
|
||||||
err := r.sessionImpl.GetOneBySearchToStrut(&cond, &sysInfo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.SysNotFound
|
return errors.SysNotFound
|
||||||
}
|
}
|
||||||
cond = builder.NewCond()
|
|
||||||
cond = cond.And(builder.Eq{"session_id": session})
|
history, err := r.getSessionChatHis(session)
|
||||||
history, _, err := r.sessionImpl.GetList(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.SystemError
|
return errors.SystemError
|
||||||
}
|
}
|
||||||
fmt.Printf("history:%v\n", history)
|
|
||||||
var (
|
|
||||||
messages = make([]entitys.Message, 0)
|
|
||||||
onece sync.Once
|
|
||||||
)
|
|
||||||
onece.Do(func() {
|
|
||||||
|
|
||||||
messages = append(messages, entitys.Message{
|
task, err := r.getTasks(sysInfo.SysID)
|
||||||
|
if err != nil {
|
||||||
|
return errors.SystemError
|
||||||
|
}
|
||||||
|
//意图预测
|
||||||
|
prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
|
||||||
|
match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
|
||||||
|
llms.WithJSONMode(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return errors.SystemError
|
||||||
|
}
|
||||||
|
log.Info(match.Choices[0].Content)
|
||||||
|
ch <- entitys.ResponseData{
|
||||||
|
Done: false,
|
||||||
|
Content: match.Choices[0].Content,
|
||||||
|
Type: entitys.ResponseLog,
|
||||||
|
}
|
||||||
|
var matchJson entitys.Match
|
||||||
|
err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson)
|
||||||
|
if err != nil {
|
||||||
|
return errors.SystemError
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = fmt.Errorf("recovered from panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer close(ch)
|
||||||
|
err = r.handleMatch(c, ch, &matchJson, task, sysInfo)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for v := range ch {
|
||||||
|
if err := entitys.MsgSend(c, v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match) (err error) {
|
||||||
|
ch <- entitys.ResponseData{
|
||||||
|
Done: false,
|
||||||
|
Content: matchJson.Reasoning,
|
||||||
|
Type: entitys.ResponseText,
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterBiz) handleMatch(c *websocket.Conn, ch chan entitys.ResponseData, matchJson *entitys.Match, tasks []model.AiTask, sysInfo model.AiSy) (err error) {
|
||||||
|
|
||||||
|
if !matchJson.IsMatch {
|
||||||
|
ch <- entitys.ResponseData{
|
||||||
|
Done: false,
|
||||||
|
Content: matchJson.Reasoning,
|
||||||
|
Type: entitys.ResponseText,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var pointTask *model.AiTask
|
||||||
|
for _, task := range tasks {
|
||||||
|
if task.Index == matchJson.Index {
|
||||||
|
pointTask = &task
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if pointTask == nil || pointTask.Index == "other" {
|
||||||
|
return r.handleOtherTask(c, ch, matchJson)
|
||||||
|
}
|
||||||
|
switch pointTask.Type {
|
||||||
|
case constants.TaskTypeApi:
|
||||||
|
return r.handleApiTask(ch, c, matchJson, pointTask)
|
||||||
|
case constants.TaskTypeFunc:
|
||||||
|
return r.handleTask(ch, c, matchJson, pointTask)
|
||||||
|
case constants.TaskTypeKnowle:
|
||||||
|
return r.handleKnowle(ch, c, matchJson, pointTask, sysInfo)
|
||||||
|
default:
|
||||||
|
return r.handleOtherTask(c, ch, matchJson)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterBiz) handleTask(channel chan entitys.ResponseData, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
|
||||||
|
|
||||||
|
var configData entitys.ConfigDataTool
|
||||||
|
err = json.Unmarshal([]byte(task.Config), &configData)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = r.toolManager.ExecuteTool(channel, c, configData.Tool, []byte(matchJson.Parameters))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 知识库
|
||||||
|
func (r *AiRouterBiz) handleKnowle(channel chan entitys.ResponseData, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask, sysInfo model.AiSy) (err error) {
|
||||||
|
|
||||||
|
var (
|
||||||
|
configData entitys.ConfigDataTool
|
||||||
|
sessionIdKnowledge string
|
||||||
|
query string
|
||||||
|
host string
|
||||||
|
)
|
||||||
|
err = json.Unmarshal([]byte(task.Config), &configData)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 通过session 找到知识库session
|
||||||
|
session := c.Headers("X-Session", "")
|
||||||
|
if len(session) == 0 {
|
||||||
|
return errors.SessionNotFound
|
||||||
|
}
|
||||||
|
sessionInfo, has, err := r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(session))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
} else if !has {
|
||||||
|
return errors.SessionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// 找到知识库的host
|
||||||
|
{
|
||||||
|
tool, exists := r.toolManager.GetTool(configData.Tool)
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("tool not found: %s", configData.Tool)
|
||||||
|
}
|
||||||
|
|
||||||
|
if knowledgeTool, ok := tool.(*tools.KnowledgeBaseTool); !ok {
|
||||||
|
return fmt.Errorf("未找到知识库Tool: %s", configData.Tool)
|
||||||
|
} else {
|
||||||
|
host = knowledgeTool.GetConfig().BaseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// 知识库的session为空,请求知识库获取, 并绑定
|
||||||
|
if sessionInfo.KnowlegeSessionID == "" {
|
||||||
|
// 请求知识库
|
||||||
|
if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, sysInfo.KnowlegeBaseID, sysInfo.KnowlegeTenantKey); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 绑定知识库session,下次可以使用
|
||||||
|
sessionInfo.KnowlegeSessionID = sessionIdKnowledge
|
||||||
|
if err = r.sessionImpl.Update(&sessionInfo, r.sessionImpl.WithSessionId(sessionInfo.SessionID)); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用户输入解析
|
||||||
|
var ok bool
|
||||||
|
input := make(map[string]string)
|
||||||
|
if err = json.Unmarshal([]byte(matchJson.Parameters), &input); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if query, ok = input["query"]; !ok {
|
||||||
|
return fmt.Errorf("query不能为空")
|
||||||
|
}
|
||||||
|
|
||||||
|
knowledgeConfig := tools.KnowledgeBaseRequest{
|
||||||
|
Session: sessionInfo.KnowlegeSessionID,
|
||||||
|
ApiKey: sysInfo.KnowlegeTenantKey,
|
||||||
|
Query: query,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(knowledgeConfig)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行工具
|
||||||
|
err = r.toolManager.ExecuteTool(channel, c, configData.Tool, b)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterBiz) handleApiTask(channels chan entitys.ResponseData, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
|
||||||
|
var (
|
||||||
|
request l_request.Request
|
||||||
|
auth = c.Headers("X-Authorization", "")
|
||||||
|
requestParam map[string]interface{}
|
||||||
|
)
|
||||||
|
err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
request.Url = strings.ReplaceAll(task.Config, "${authorization}", auth)
|
||||||
|
for k, v := range requestParam {
|
||||||
|
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
|
||||||
|
}
|
||||||
|
var configData entitys.ConfigDataHttp
|
||||||
|
err = json.Unmarshal([]byte(task.Config), &configData)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = mapstructure.Decode(configData.Request, &request)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(request.Url) == 0 {
|
||||||
|
err = errors.NewBusinessErr("00022", "api地址获取失败")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res, err := request.Send()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.WriteMessage(1, res.Content)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
|
||||||
|
|
||||||
|
cond := builder.NewCond()
|
||||||
|
cond = cond.And(builder.Eq{"session_id": sessionId})
|
||||||
|
|
||||||
|
_, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his, "his_id asc")
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterBiz) getSysInfo(appKey string) (sysInfo model.AiSy, err error) {
|
||||||
|
cond := builder.NewCond()
|
||||||
|
cond = cond.And(builder.Eq{"app_key": appKey})
|
||||||
|
cond = cond.And(builder.IsNull{"delete_at"})
|
||||||
|
cond = cond.And(builder.Eq{"status": 1})
|
||||||
|
err = r.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
||||||
|
|
||||||
|
cond := builder.NewCond()
|
||||||
|
cond = cond.And(builder.Eq{"sys_id": sysId})
|
||||||
|
cond = cond.And(builder.IsNull{"delete_at"})
|
||||||
|
cond = cond.And(builder.Eq{"status": 1})
|
||||||
|
_, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks, "")
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool {
|
||||||
|
taskPrompt := make([]llms.Tool, 0)
|
||||||
|
for _, task := range tasks {
|
||||||
|
var taskConfig entitys.TaskConfig
|
||||||
|
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
taskPrompt = append(taskPrompt, llms.Tool{
|
||||||
|
Type: "function",
|
||||||
|
Function: &llms.FunctionDefinition{
|
||||||
|
Name: task.Index,
|
||||||
|
Description: task.Desc,
|
||||||
|
Parameters: taskConfig.Param,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
return taskPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message {
|
||||||
|
var (
|
||||||
|
prompt = make([]entitys.Message, 0)
|
||||||
|
)
|
||||||
|
prompt = append(prompt, entitys.Message{
|
||||||
Role: "system",
|
Role: "system",
|
||||||
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
|
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
|
||||||
})
|
}, entitys.Message{
|
||||||
})
|
|
||||||
messages = append(messages, entitys.Message{}, entitys.Message{
|
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: r.buildIntentPrompt(req.Text),
|
Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
|
||||||
}, entitys.Message{
|
}, entitys.Message{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: req.Text,
|
Content: reqInput,
|
||||||
})
|
})
|
||||||
// 构建消息
|
return prompt
|
||||||
//messages := []entitys.Message{
|
}
|
||||||
// {
|
|
||||||
// Role: "user",
|
|
||||||
// Content: req.UserInput,
|
|
||||||
// },
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//// 第1次调用AI,获取用户意图
|
|
||||||
//intentResponse, err := r.aiClient.Chat(ctx, messages, nil)
|
|
||||||
//if err != nil {
|
|
||||||
// return nil, fmt.Errorf("AI响应失败: %w", err)
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//// 从AI响应中提取意图
|
|
||||||
//intent := r.extractIntent(intentResponse)
|
|
||||||
//if intent == "" {
|
|
||||||
// return nil, fmt.Errorf("未识别到用户意图")
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//switch intent {
|
|
||||||
//case "order_diagnosis":
|
|
||||||
// // 订单诊断意图
|
|
||||||
// return r.handleOrderDiagnosis(ctx, req, messages)
|
|
||||||
//case "knowledge_qa":
|
|
||||||
// // 知识问答意图
|
|
||||||
// return r.handleKnowledgeQA(ctx, req, messages)
|
|
||||||
//default:
|
|
||||||
// // 未知意图
|
|
||||||
// return nil, fmt.Errorf("意图识别失败,请明确您的需求呢,我可以为您")
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//// 获取工具定义
|
|
||||||
//toolDefinitions := r.toolManager.GetToolDefinitions(constants.Caller(req.Caller))
|
|
||||||
//
|
|
||||||
//// 第2次调用AI,获取是否需要使用工具
|
|
||||||
//response, err := r.aiClient.Chat(ctx, messages, toolDefinitions)
|
|
||||||
//if err != nil {
|
|
||||||
// return nil, fmt.Errorf("failed to chat with AI: %w", err)
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//// 如果没有工具调用,直接返回
|
|
||||||
//if len(response.ToolCalls) == 0 {
|
|
||||||
// return response, nil
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//// 执行工具调用
|
|
||||||
//toolResults, err := r.toolManager.ExecuteToolCalls(ctx, response.ToolCalls)
|
|
||||||
//if err != nil {
|
|
||||||
// return nil, fmt.Errorf("failed to execute tools: %w", err)
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//// 构建包含工具结果的消息
|
|
||||||
//messages = append(messages, entitys.Message{
|
|
||||||
// Role: "assistant",
|
|
||||||
// Content: response.Message,
|
|
||||||
//})
|
|
||||||
//
|
|
||||||
//// 添加工具调用结果
|
|
||||||
//for _, toolResult := range toolResults {
|
|
||||||
// toolResultStr, _ := json.Marshal(toolResult.Result)
|
|
||||||
// messages = append(messages, entitys.Message{
|
|
||||||
// Role: "tool",
|
|
||||||
// Content: fmt.Sprintf("Tool %s result: %s", toolResult.Function.Name, string(toolResultStr)),
|
|
||||||
// })
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//// 第二次调用AI,生成最终回复
|
|
||||||
//finalResponse, err := r.aiClient.Chat(ctx, messages, nil)
|
|
||||||
//if err != nil {
|
|
||||||
// return nil, fmt.Errorf("failed to generate final response: %w", err)
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//// 合并工具调用信息到最终响应
|
|
||||||
//finalResponse.ToolCalls = toolResults
|
|
||||||
//
|
|
||||||
//log.Printf("Router processed request: %s, used %d tools", req.UserInput, len(toolResults))
|
|
||||||
|
|
||||||
//return finalResponse, nil
|
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
||||||
return nil
|
var (
|
||||||
|
prompt = make([]llms.MessageContent, 0)
|
||||||
|
)
|
||||||
|
prompt = append(prompt, llms.MessageContent{
|
||||||
|
Role: llms.ChatMessageTypeSystem,
|
||||||
|
Parts: []llms.ContentPart{
|
||||||
|
llms.TextPart(r.buildSystemPrompt(sysInfo.SysPrompt)),
|
||||||
|
},
|
||||||
|
}, llms.MessageContent{
|
||||||
|
Role: llms.ChatMessageTypeTool,
|
||||||
|
Parts: []llms.ContentPart{
|
||||||
|
llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))),
|
||||||
|
},
|
||||||
|
}, llms.MessageContent{
|
||||||
|
Role: llms.ChatMessageTypeTool,
|
||||||
|
Parts: []llms.ContentPart{
|
||||||
|
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
|
||||||
|
},
|
||||||
|
}, llms.MessageContent{
|
||||||
|
Role: llms.ChatMessageTypeHuman,
|
||||||
|
Parts: []llms.ContentPart{
|
||||||
|
llms.TextPart(reqInput),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildSystemPrompt 构建系统提示词
|
// buildSystemPrompt 构建系统提示词
|
||||||
func (r *AiRouterService) buildSystemPrompt(prompt string) string {
|
func (r *AiRouterBiz) buildSystemPrompt(prompt string) string {
|
||||||
if len(prompt) == 0 {
|
if len(prompt) == 0 {
|
||||||
prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON:{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式,禁用markdown格式返回\n3.只返回json字符串,不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
|
prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON:{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式,禁用markdown格式返回\n3.只返回json字符串,不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
|
||||||
}
|
}
|
||||||
|
|
@ -182,86 +441,26 @@ func (r *AiRouterService) buildSystemPrompt(prompt string) string {
|
||||||
return prompt
|
return prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildIntentPrompt 构建意图识别提示词
|
func (r *AiRouterBiz) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
|
||||||
func (r *AiRouterService) buildIntentPrompt(userInput string) string {
|
for _, item := range his {
|
||||||
prompt := `##任务
|
if len(chatHis.SessionId) == 0 {
|
||||||
分析用户输入,判断用户的意图类型,没有使用Markdown格式的json格式回复
|
chatHis.SessionId = item.SessionID
|
||||||
##意图类型
|
|
||||||
1. product_diagnosis - 商品诊断:用户想要查询、诊断或了解商品相关信息
|
|
||||||
2. order_diagnosis - 订单诊断:用户想要查询、诊断或了解订单相关信息
|
|
||||||
3. knowledge_qa - 知识问答:用户想要进行一般性问答或获取知识信息
|
|
||||||
##判断规则
|
|
||||||
1.当用户意图不够清晰且不匹配 knowledge_qa 以外意图时,使用knowledge_qa
|
|
||||||
2.当用户意图非常不清晰时使用 unknown
|
|
||||||
##格式要求
|
|
||||||
1.返回以下格式的JSON:
|
|
||||||
{ "intent": "product_diagnosis" | "order_diagnosis" | "knowledge_qa" | "unknown", "confidence": 0.0-1.0,"reasoning": "判断理由"}
|
|
||||||
2.严格返回字符串格式,禁用markdown格式返回
|
|
||||||
3.只返回json字符串,不包含任何其他解释性文字
|
|
||||||
## 用户当前的问题是:
|
|
||||||
{user_input}
|
|
||||||
`
|
|
||||||
|
|
||||||
prompt = strings.ReplaceAll(prompt, "{user_input}", userInput)
|
|
||||||
|
|
||||||
return prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractIntent 从AI响应中提取意图
|
|
||||||
func (r *AiRouterService) extractIntent(response *entitys.ChatResponse) string {
|
|
||||||
if response == nil || response.Message == "" {
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
chatHis.Messages = append(chatHis.Messages, entitys.HisMessage{
|
||||||
// 尝试解析JSON
|
Role: item.Role,
|
||||||
var intent struct {
|
Content: item.Content,
|
||||||
Intent string `json:"intent"`
|
Timestamp: item.CreateAt.Format("2006-01-02 15:04:05"),
|
||||||
Confidence string `json:"confidence"`
|
})
|
||||||
Reasoning string `json:"reasoning"`
|
|
||||||
}
|
}
|
||||||
err := json.Unmarshal([]byte(response.Message), &intent)
|
chatHis.Context = entitys.HisContext{
|
||||||
if err != nil {
|
UserLanguage: "zh-CN",
|
||||||
log.Printf("Failed to parse intent JSON: %v", err)
|
SystemMode: "technical_support",
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
return
|
||||||
return intent.Intent
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleOrderDiagnosis 处理订单诊断意图
|
|
||||||
func (r *AiRouterService) handleOrderDiagnosis(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
|
|
||||||
// 调用订单详情工具
|
|
||||||
//orderDetailTool, ok := r.toolManager.GetTool("zltxOrderDetail")
|
|
||||||
//if orderDetailTool == nil || !ok {
|
|
||||||
// return nil, fmt.Errorf("order detail tool not found")
|
|
||||||
//}
|
|
||||||
//orderDetailTool.Execute(ctx, json.RawMessage{})
|
|
||||||
//
|
|
||||||
//// 获取相关工具定义
|
|
||||||
//toolDefinitions := r.toolManager.GetToolDefinitions(constants.Caller(req.Caller))
|
|
||||||
//
|
|
||||||
//// 调用AI,获取是否需要使用工具
|
|
||||||
//response, err := r.aiClient.Chat(ctx, messages, toolDefinitions)
|
|
||||||
//if err != nil {
|
|
||||||
// return nil, fmt.Errorf("failed to chat with AI: %w", err)
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//// 如果没有工具调用,直接返回
|
|
||||||
//if len(response.ToolCalls) == 0 {
|
|
||||||
// return response, nil
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//// 执行工具调用
|
|
||||||
//toolResults, err := r.toolManager.ExecuteToolCalls(ctx, response.ToolCalls)
|
|
||||||
//if err != nil {
|
|
||||||
// return nil, fmt.Errorf("failed to execute tools: %w", err)
|
|
||||||
//}
|
|
||||||
|
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleKnowledgeQA 处理知识问答意图
|
// handleKnowledgeQA 处理知识问答意图
|
||||||
func (r *AiRouterService) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
|
func (r *AiRouterBiz) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,114 @@
|
||||||
|
package biz
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/data/impl"
|
||||||
|
"ai_scheduler/internal/data/model"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"ai_scheduler/internal/pkg"
|
||||||
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
|
"ai_scheduler/internal/tools"
|
||||||
|
"ai_scheduler/utils"
|
||||||
|
"encoding/json"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_task(t *testing.T) {
|
||||||
|
var c entitys.TaskConfig
|
||||||
|
config := `{"param": {"type": "object", "required": ["number"], "properties": {"number": {"type": "string", "description": "订单编号/流水号"}}}, "request": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}`
|
||||||
|
err := json.Unmarshal([]byte(config), &c)
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type configData struct {
|
||||||
|
Param map[string]interface{} `json:"param"`
|
||||||
|
Do map[string]interface{} `json:"do"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Order(t *testing.T) {
|
||||||
|
routerBiz := in()
|
||||||
|
ch := make(chan entitys.ResponseData, 5)
|
||||||
|
defer close(ch)
|
||||||
|
err := routerBiz.handleTask(ch, nil, &entitys.Match{Index: "order_diagnosis", Parameters: `{"order_number":"822895927188791297"}`}, &model.AiTask{Config: `{"tool": "zltxOrderDetail", "param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}}`})
|
||||||
|
select {
|
||||||
|
case v := <-ch: // 尝试接收
|
||||||
|
fmt.Println("接收到值:", v)
|
||||||
|
default:
|
||||||
|
fmt.Println("无数据可接收")
|
||||||
|
}
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_OrderLog(t *testing.T) {
|
||||||
|
routerBiz := in()
|
||||||
|
ch := make(chan entitys.ResponseData, 5)
|
||||||
|
defer close(ch)
|
||||||
|
err := routerBiz.handleTask(ch, nil, &entitys.Match{Index: "order_diagnosis", Parameters: `{"order_number":"822979421673758721","serial_number":"822979421979938817"}`}, &model.AiTask{Config: `{"tool": "zltxOrderDirectLog", "param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}}`})
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_ProductLog(t *testing.T) {
|
||||||
|
routerBiz := in()
|
||||||
|
ch := make(chan entitys.ResponseData, 5)
|
||||||
|
defer close(ch)
|
||||||
|
err := routerBiz.handleTask(ch, nil, &entitys.Match{Index: "order_diagnosis", Parameters: `{"id":"101","serial_number":"822979421979938817"}`}, &model.AiTask{Config: `{"tool": "zltxProduct", "param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}}`})
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func in() *AiRouterBiz {
|
||||||
|
|
||||||
|
modDir, err := getModuleDir()
|
||||||
|
if err != nil {
|
||||||
|
panic("1")
|
||||||
|
}
|
||||||
|
configPath := flag.String("config", fmt.Sprintf("%s/config/config.yaml", modDir), "Path to configuration file")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
configConfig, err := config.LoadConfig(*configPath)
|
||||||
|
if err != nil {
|
||||||
|
panic("加载配置失败")
|
||||||
|
}
|
||||||
|
client, _, err := utils_ollama.NewClient(configConfig)
|
||||||
|
allLogger := log.DefaultLogger()
|
||||||
|
utilOllama := utils_ollama.NewUtilOllama(configConfig, allLogger)
|
||||||
|
manager := tools.NewManager(configConfig, client)
|
||||||
|
|
||||||
|
db, _ := utils.NewGormDb(configConfig)
|
||||||
|
sessionImpl := impl.NewSessionImpl(db)
|
||||||
|
sysImpl := impl.NewSysImpl(db)
|
||||||
|
taskImpl := impl.NewTaskImpl(db)
|
||||||
|
chatImpl := impl.NewChatImpl(db)
|
||||||
|
safeChannelPool, _ := pkg.NewSafeChannelPool(configConfig)
|
||||||
|
routerBiz := NewAiRouterBiz(manager, sessionImpl, sysImpl, taskImpl, chatImpl, configConfig, utilOllama, safeChannelPool)
|
||||||
|
|
||||||
|
return routerBiz
|
||||||
|
}
|
||||||
|
|
||||||
|
func getModuleDir() (string, error) {
|
||||||
|
dir, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
modPath := filepath.Join(dir, "go.mod")
|
||||||
|
if _, err := os.Stat(modPath); err == nil {
|
||||||
|
return dir, nil // 找到 go.mod
|
||||||
|
}
|
||||||
|
|
||||||
|
// 向上查找父目录
|
||||||
|
parent := filepath.Dir(dir)
|
||||||
|
if parent == dir {
|
||||||
|
break // 到达根目录,未找到
|
||||||
|
}
|
||||||
|
dir = parent
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("go.mod not found in current directory or parents")
|
||||||
|
}
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package biz
|
package biz
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"ai_scheduler/internal/data/constants"
|
||||||
"ai_scheduler/internal/data/impl"
|
"ai_scheduler/internal/data/impl"
|
||||||
"ai_scheduler/internal/data/model"
|
"ai_scheduler/internal/data/model"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
|
|
@ -15,6 +16,7 @@ import (
|
||||||
type SessionBiz struct {
|
type SessionBiz struct {
|
||||||
sessionRepo *impl.SessionImpl
|
sessionRepo *impl.SessionImpl
|
||||||
sysRepo *impl.SysImpl
|
sysRepo *impl.SysImpl
|
||||||
|
chatRepo *impl.ChatImpl
|
||||||
|
|
||||||
conf *config.Config
|
conf *config.Config
|
||||||
}
|
}
|
||||||
|
|
@ -28,14 +30,19 @@ func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitSession 初始化会话 ,当天存在则返回会话,如果不存在则创建一个
|
// InitSession 初始化会话 ,当天存在则返回会话,如果不存在则创建一个
|
||||||
func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRequest) (sessionId string, err error) {
|
func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRequest) (result *entitys.SessionInitResponse, err error) {
|
||||||
|
|
||||||
// 获取系统配置
|
// 获取系统配置
|
||||||
sysConfig, has, err := s.sysRepo.FindOne(s.sysRepo.WithSysId(req.SysId))
|
sysConfig, has, err := s.sysRepo.FindOne(s.sysRepo.WithSysId(req.SysId))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return
|
||||||
} else if !has {
|
} else if !has {
|
||||||
return "", fmt.Errorf("sys not found")
|
err = fmt.Errorf("sys not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result = &entitys.SessionInitResponse{
|
||||||
|
Chat: make([]entitys.ChatHistory, 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取 当天的session
|
// 获取 当天的session
|
||||||
|
|
@ -46,22 +53,56 @@ func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRe
|
||||||
s.sysRepo.WithSysId(sysConfig.SysID), // 条件:系统ID
|
s.sysRepo.WithSysId(sysConfig.SysID), // 条件:系统ID
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return
|
||||||
} else if !has {
|
} else if !has {
|
||||||
// 不存在,创建一个
|
// 不存在,创建一个
|
||||||
session = model.AiSession{
|
session = model.AiSession{
|
||||||
SysID: sysConfig.SysID,
|
SysID: sysConfig.SysID,
|
||||||
SessionID: utils.UUID(),
|
SessionID: utils.UUID(),
|
||||||
CreateAt: time.Now(),
|
UserID: req.UserId,
|
||||||
UpdateAt: time.Now(),
|
|
||||||
}
|
}
|
||||||
err = s.sessionRepo.Create(&session)
|
err = s.sessionRepo.Create(&session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
chat := entitys.ChatHistory{
|
||||||
|
SessionID: session.SessionID,
|
||||||
|
Role: constants.RoleSystem,
|
||||||
|
Content: sysConfig.Prologue,
|
||||||
|
}
|
||||||
|
result.Chat = append(result.Chat, chat)
|
||||||
|
|
||||||
|
// 开场白写入会话历史
|
||||||
|
s.chatRepo.AsyncCreate(ctx, model.AiChatHi{
|
||||||
|
SessionID: chat.SessionID,
|
||||||
|
Role: chat.Role.String(),
|
||||||
|
Content: chat.Content,
|
||||||
|
})
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// 存在,返回会话历史
|
||||||
|
var chatList []model.AiChatHi
|
||||||
|
chatList, err = s.chatRepo.FindAll(
|
||||||
|
s.chatRepo.WithSessionId(session.SessionID), // 条件:会话ID
|
||||||
|
s.chatRepo.OrderByDesc("create_at"), // 排序:按创建时间降序
|
||||||
|
s.chatRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为 entitys.ChatHistory 类型
|
||||||
|
for _, chat := range chatList {
|
||||||
|
result.Chat = append(result.Chat, entitys.ChatHistory{
|
||||||
|
SessionID: chat.SessionID,
|
||||||
|
Role: constants.Caller(chat.Role),
|
||||||
|
Content: chat.Content,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return session.SessionID, nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionList 会话列表
|
// SessionList 会话列表
|
||||||
|
|
|
||||||
|
|
@ -14,13 +14,20 @@ type Config struct {
|
||||||
Sys SysConfig `mapstructure:"sys"`
|
Sys SysConfig `mapstructure:"sys"`
|
||||||
Tools ToolsConfig `mapstructure:"tools"`
|
Tools ToolsConfig `mapstructure:"tools"`
|
||||||
Logging LoggingConfig `mapstructure:"logging"`
|
Logging LoggingConfig `mapstructure:"logging"`
|
||||||
Redis *Redis `protobuf:"bytes,1,opt,name=Redis,proto3" json:"Redis,omitempty"`
|
Redis Redis `mapstructure:"redis"`
|
||||||
DB *DB `protobuf:"bytes,3,opt,name=TransDB,proto3" json:"TransDB,omitempty"`
|
DB DB `mapstructure:"db"`
|
||||||
|
// LLM *LLM `mapstructure:"llm"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LLM struct {
|
||||||
|
Model string `mapstructure:"model"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SysConfig 系统配置
|
// SysConfig 系统配置
|
||||||
type SysConfig struct {
|
type SysConfig struct {
|
||||||
SessionLen int `mapstructure:"session_len"`
|
SessionLen int `mapstructure:"session_len"`
|
||||||
|
ChannelPoolLen int `mapstructure:"channel_pool_len"`
|
||||||
|
ChannelPoolSize int `mapstructure:"channel_pool_size"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerConfig 服务器配置
|
// ServerConfig 服务器配置
|
||||||
|
|
@ -37,24 +44,24 @@ type OllamaConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Redis struct {
|
type Redis struct {
|
||||||
Host string `protobuf:"bytes,1,opt,name=host,proto3" json:"host,omitempty"`
|
Host string `mapstructure:"host"`
|
||||||
Type string `protobuf:"bytes,2,opt,name=type,proto3" json:"type,omitempty"`
|
Type string `mapstructure:"type"`
|
||||||
Pass string `protobuf:"bytes,3,opt,name=pass,proto3" json:"pass,omitempty"`
|
Pass string `mapstructure:"pass"`
|
||||||
Key string `protobuf:"bytes,4,opt,name=key,proto3" json:"key,omitempty"`
|
Key string `mapstructure:"key"`
|
||||||
Tls int32 `protobuf:"varint,5,opt,name=tls,proto3" json:"tls,omitempty"`
|
Tls int32 `mapstructure:"tls"`
|
||||||
Db int32 `protobuf:"varint,6,opt,name=db,proto3" json:"db,omitempty"`
|
Db int32 `mapstructure:"db"`
|
||||||
MaxIdle int32 `protobuf:"varint,7,opt,name=maxIdle,proto3" json:"maxIdle,omitempty"`
|
MaxIdle int32 `mapstructure:"maxIdle"`
|
||||||
PoolSize int32 `protobuf:"varint,8,opt,name=poolSize,proto3" json:"poolSize,omitempty"`
|
PoolSize int32 `mapstructure:"poolSize"`
|
||||||
MaxIdleTime int32 `protobuf:"varint,9,opt,name=maxIdleTime,proto3" json:"maxIdleTime,omitempty"`
|
MaxIdleTime int32 `mapstructure:"maxIdleTime"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type DB struct {
|
type DB struct {
|
||||||
Driver string `protobuf:"bytes,1,opt,name=driver,proto3" json:"driver,omitempty"`
|
Driver string `mapstructure:"driver"`
|
||||||
Source string `protobuf:"bytes,2,opt,name=source,proto3" json:"source,omitempty"`
|
Source string `mapstructure:"source"`
|
||||||
MaxIdle int32 `protobuf:"varint,3,opt,name=maxIdle,proto3" json:"maxIdle,omitempty"`
|
MaxIdle int32 `mapstructure:"maxIdle"`
|
||||||
MaxOpen int32 `protobuf:"varint,4,opt,name=maxOpen,proto3" json:"maxOpen,omitempty"`
|
MaxOpen int32 `mapstructure:"maxOpen"`
|
||||||
MaxLifetime int32 `protobuf:"varint,5,opt,name=maxLifetime,proto3" json:"maxLifetime,omitempty"`
|
MaxLifetime int32 `mapstructure:"maxLifetime"`
|
||||||
IsDebug bool `protobuf:"varint,6,opt,name=isDebug,proto3" json:"isDebug,omitempty"`
|
IsDebug bool `mapstructure:"isDebug"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToolsConfig 工具配置
|
// ToolsConfig 工具配置
|
||||||
|
|
@ -62,8 +69,12 @@ type ToolsConfig struct {
|
||||||
Weather ToolConfig `mapstructure:"weather"`
|
Weather ToolConfig `mapstructure:"weather"`
|
||||||
Calculator ToolConfig `mapstructure:"calculator"`
|
Calculator ToolConfig `mapstructure:"calculator"`
|
||||||
ZltxOrderDetail ToolConfig `mapstructure:"zltxOrderDetail"`
|
ZltxOrderDetail ToolConfig `mapstructure:"zltxOrderDetail"`
|
||||||
ZltxOrderLog ToolConfig `mapstructure:"zltxOrderLog"`
|
ZltxOrderDirectLog ToolConfig `mapstructure:"zltxOrderDirectLog"`
|
||||||
Knowledge ToolConfig `mapstructure:"knowledge"`
|
Knowledge ToolConfig `mapstructure:"knowledge"`
|
||||||
|
//通过ID获取我们的商品信息
|
||||||
|
ZltxProduct ToolConfig `mapstructure:"zltxProduct"`
|
||||||
|
//通过账号获取订单统计信息
|
||||||
|
ZltxOrderStatistics ToolConfig `mapstructure:"zltxOrderStatistics"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToolConfig 单个工具配置
|
// ToolConfig 单个工具配置
|
||||||
|
|
@ -71,7 +82,8 @@ type ToolConfig struct {
|
||||||
Enabled bool `mapstructure:"enabled"`
|
Enabled bool `mapstructure:"enabled"`
|
||||||
BaseURL string `mapstructure:"base_url"`
|
BaseURL string `mapstructure:"base_url"`
|
||||||
APIKey string `mapstructure:"api_key"`
|
APIKey string `mapstructure:"api_key"`
|
||||||
BizSystem string `mapstructure:"biz_system"`
|
//附加地址
|
||||||
|
AddURL string `mapstructure:"add_url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoggingConfig 日志配置
|
// LoggingConfig 日志配置
|
||||||
|
|
|
||||||
|
|
@ -1,12 +0,0 @@
|
||||||
package constants
|
|
||||||
|
|
||||||
type Caller string
|
|
||||||
|
|
||||||
const (
|
|
||||||
CallerZltx Caller = "zltx" // 直连天下
|
|
||||||
CallerHyt Caller = "hyt" // 货易通
|
|
||||||
)
|
|
||||||
|
|
||||||
func (c Caller) String() string {
|
|
||||||
return string(c)
|
|
||||||
}
|
|
||||||
|
|
@ -1,9 +0,0 @@
|
||||||
package constant
|
|
||||||
|
|
||||||
type ConnStatus int8
|
|
||||||
|
|
||||||
const (
|
|
||||||
ConnStatusClosed ConnStatus = iota
|
|
||||||
ConnStatusNormal
|
|
||||||
ConnStatusIgnore
|
|
||||||
)
|
|
||||||
|
|
@ -1,3 +0,0 @@
|
||||||
package constant
|
|
||||||
|
|
||||||
const ()
|
|
||||||
|
|
@ -0,0 +1,20 @@
|
||||||
|
package constants
|
||||||
|
|
||||||
|
type Caller string
|
||||||
|
|
||||||
|
const (
|
||||||
|
CallerZltx Caller = "zltx" // 直连天下
|
||||||
|
CallerHyt Caller = "hyt" // 货易通
|
||||||
|
|
||||||
|
// 角色, 系统角色,用户角色
|
||||||
|
RoleSystem Caller = "system" // 系统角色
|
||||||
|
RoleUser Caller = "user" // 用户角色
|
||||||
|
RoleAssistant Caller = "assistant" // 助手角色
|
||||||
|
|
||||||
|
// 分页默认条数
|
||||||
|
ChatHistoryLimit = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c Caller) String() string {
|
||||||
|
return string(c)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,17 @@
|
||||||
|
package constants
|
||||||
|
|
||||||
|
type ConnStatus int8
|
||||||
|
|
||||||
|
const (
|
||||||
|
ConnStatusClosed ConnStatus = iota
|
||||||
|
ConnStatusNormal
|
||||||
|
ConnStatusIgnore
|
||||||
|
)
|
||||||
|
|
||||||
|
type TaskType int32
|
||||||
|
|
||||||
|
const (
|
||||||
|
TaskTypeApi = 1
|
||||||
|
TaskTypeKnowle = 2
|
||||||
|
TaskTypeFunc = 3
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
package constants
|
||||||
|
|
||||||
|
const ()
|
||||||
|
|
@ -20,6 +20,7 @@ BaseModel 是一个泛型结构体,用于封装GORM数据库通用操作。
|
||||||
|
|
||||||
// 定义受支持的PO类型集合(可根据需要扩展), 只有包含表结构才能使用BaseModel,避免使用出现问题
|
// 定义受支持的PO类型集合(可根据需要扩展), 只有包含表结构才能使用BaseModel,避免使用出现问题
|
||||||
type PO interface {
|
type PO interface {
|
||||||
|
model.AiChatHi |
|
||||||
model.AiSy | model.AiSession | model.AiTask
|
model.AiSy | model.AiSession | model.AiTask
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -51,6 +52,7 @@ type BaseRepository[P PO] interface {
|
||||||
WithId(id interface{}) CondFunc // 查询id
|
WithId(id interface{}) CondFunc // 查询id
|
||||||
WithStatus(status int) CondFunc // 查询status
|
WithStatus(status int) CondFunc // 查询status
|
||||||
GetDb() *gorm.DB // 获取数据库连接
|
GetDb() *gorm.DB // 获取数据库连接
|
||||||
|
WithLimit(limit int) CondFunc // 限制返回条数
|
||||||
}
|
}
|
||||||
|
|
||||||
// PaginationResult 分页查询结果
|
// PaginationResult 分页查询结果
|
||||||
|
|
@ -206,3 +208,9 @@ func (this *BaseModel[P]) WithStatus(status int) CondFunc {
|
||||||
func (this *BaseModel[P]) GetDb() *gorm.DB {
|
func (this *BaseModel[P]) GetDb() *gorm.DB {
|
||||||
return this.Db
|
return this.Db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (this *BaseModel[P]) WithLimit(limit int) CondFunc {
|
||||||
|
return func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Limit(limit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
package impl
|
||||||
|
|
||||||
|
//import (
|
||||||
|
// "ai_scheduler/internal/data/model"
|
||||||
|
// "ai_scheduler/tmpl/dataTemp"
|
||||||
|
// "ai_scheduler/utils"
|
||||||
|
//)
|
||||||
|
|
||||||
|
//type ChatHisImpl struct {
|
||||||
|
// dataTemp.DataTemp
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//func NewChatHisImpl(db *utils.Db) *ChatHisImpl {
|
||||||
|
// return &ChatHisImpl{*dataTemp.NewDataTemp(db, new(model.AiChatHi))}
|
||||||
|
//}
|
||||||
|
|
@ -0,0 +1,57 @@
|
||||||
|
package impl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/data/model"
|
||||||
|
"ai_scheduler/tmpl/dataTemp"
|
||||||
|
"ai_scheduler/utils"
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2/log"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChatImpl struct {
|
||||||
|
dataTemp.DataTemp
|
||||||
|
BaseRepository[model.AiChatHi]
|
||||||
|
chatChannel chan model.AiChatHi
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChatImpl(db *utils.Db) *ChatImpl {
|
||||||
|
return &ChatImpl{
|
||||||
|
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiChatHi)),
|
||||||
|
BaseRepository: NewBaseModel[model.AiChatHi](db.Client),
|
||||||
|
chatChannel: make(chan model.AiChatHi, 100),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSessionId 条件:会话ID
|
||||||
|
func (impl *ChatImpl) WithSessionId(sessionId interface{}) CondFunc {
|
||||||
|
return func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Where("session_id = ?", sessionId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步添加会话历史
|
||||||
|
func (impl *ChatImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) {
|
||||||
|
impl.chatChannel <- chat
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步处理会话历史
|
||||||
|
func (impl *ChatImpl) AsyncProcess(ctx context.Context) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case chat := <-impl.chatChannel:
|
||||||
|
log.Infof("ChatHistoryAsyncProcess chat: %v", chat)
|
||||||
|
if err := impl.Create(&chat); err != nil {
|
||||||
|
log.Errorf("ChatHistoryAsyncProcess err: %v", err)
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Infof("ChatHistoryAsyncProcess ctx done")
|
||||||
|
return
|
||||||
|
// 定时打印通道大小
|
||||||
|
case <-time.After(time.Second * 5):
|
||||||
|
//log.Infof("ChatHistoryAsyncProcess channel len: %d", len(impl.chatChannel))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -4,4 +4,4 @@ import (
|
||||||
"github.com/google/wire"
|
"github.com/google/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ProviderImpl = wire.NewSet(NewSessionImpl)
|
var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatImpl)
|
||||||
|
|
|
||||||
|
|
@ -10,13 +10,13 @@ import (
|
||||||
|
|
||||||
type SessionImpl struct {
|
type SessionImpl struct {
|
||||||
dataTemp.DataTemp
|
dataTemp.DataTemp
|
||||||
BaseModel[model.AiSession]
|
BaseRepository[model.AiSession]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSessionImpl(db *utils.Db) *SessionImpl {
|
func NewSessionImpl(db *utils.Db) *SessionImpl {
|
||||||
return &SessionImpl{
|
return &SessionImpl{
|
||||||
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiSession)),
|
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiSession)),
|
||||||
BaseModel: BaseModel[model.AiSession]{},
|
BaseRepository: NewBaseModel[model.AiSession](db.Client),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -40,3 +40,9 @@ func (s *SessionImpl) WithSysId(sysId interface{}) CondFunc {
|
||||||
return db.Where("sys_id = ?", sysId)
|
return db.Where("sys_id = ?", sysId)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (impl *SessionImpl) WithSessionId(sessionId interface{}) CondFunc {
|
||||||
|
return func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Where("session_id = ?", sessionId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||||
|
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||||
|
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||||
|
|
||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const TableNameAiChatHi = "ai_chat_his"
|
||||||
|
|
||||||
|
// AiChatHi mapped from table <ai_chat_his>
|
||||||
|
type AiChatHi struct {
|
||||||
|
HisID int64 `gorm:"column:his_id;primaryKey" json:"his_id"`
|
||||||
|
SessionID string `gorm:"column:session_id;not null" json:"session_id"`
|
||||||
|
Role string `gorm:"column:role;not null;comment:system系统输出,assistant助手输出,user用户输入" json:"role"` // system系统输出,assistant助手输出,user用户输入
|
||||||
|
Content string `gorm:"column:content;not null" json:"content"`
|
||||||
|
CreateAt *time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName AiChatHi's table name
|
||||||
|
func (*AiChatHi) TableName() string {
|
||||||
|
return TableNameAiChatHi
|
||||||
|
}
|
||||||
|
|
@ -12,14 +12,15 @@ const TableNameAiSession = "ai_session"
|
||||||
|
|
||||||
// AiSession mapped from table <ai_session>
|
// AiSession mapped from table <ai_session>
|
||||||
type AiSession struct {
|
type AiSession struct {
|
||||||
SysID int32 `gorm:"column:sys_id;not null" json:"sys_id"`
|
|
||||||
SessionID string `gorm:"column:session_id;primaryKey" json:"session_id"`
|
SessionID string `gorm:"column:session_id;primaryKey" json:"session_id"`
|
||||||
|
SysID int32 `gorm:"column:sys_id;not null" json:"sys_id"`
|
||||||
KnowlegeSessionID string `gorm:"column:knowlege_session_id;not null" json:"knowlege_session_id"`
|
KnowlegeSessionID string `gorm:"column:knowlege_session_id;not null" json:"knowlege_session_id"`
|
||||||
Title string `gorm:"column:title;not null" json:"title"`
|
Title string `gorm:"column:title;not null" json:"title"`
|
||||||
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
|
CreateAt *time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
|
||||||
UpdateAt time.Time `gorm:"column:update_at;default:CURRENT_TIMESTAMP" json:"update_at"`
|
UpdateAt *time.Time `gorm:"column:update_at;default:CURRENT_TIMESTAMP" json:"update_at"`
|
||||||
Status int32 `gorm:"column:status;not null" json:"status"`
|
Status int32 `gorm:"column:status;not null" json:"status"`
|
||||||
DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"`
|
DeleteAt *time.Time `gorm:"column:delete_at" json:"delete_at"`
|
||||||
|
UserID string `gorm:"column:user_id;comment:用户id" json:"user_id"` // 用户id
|
||||||
}
|
}
|
||||||
|
|
||||||
// TableName AiSession's table name
|
// TableName AiSession's table name
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ type AiSy struct {
|
||||||
UpdateAt time.Time `gorm:"column:update_at;default:CURRENT_TIMESTAMP" json:"update_at"`
|
UpdateAt time.Time `gorm:"column:update_at;default:CURRENT_TIMESTAMP" json:"update_at"`
|
||||||
Status int32 `gorm:"column:status;not null" json:"status"`
|
Status int32 `gorm:"column:status;not null" json:"status"`
|
||||||
DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"`
|
DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"`
|
||||||
|
Prologue string `gorm:"column:prologue;comment:会话开场白" json:"prologue"` // 会话开场白
|
||||||
}
|
}
|
||||||
|
|
||||||
// TableName AiSy's table name
|
// TableName AiSy's table name
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,8 @@ type AiTask struct {
|
||||||
Name string `gorm:"column:name;not null" json:"name"`
|
Name string `gorm:"column:name;not null" json:"name"`
|
||||||
Index string `gorm:"column:index;not null" json:"index"`
|
Index string `gorm:"column:index;not null" json:"index"`
|
||||||
Desc string `gorm:"column:desc;not null" json:"desc"`
|
Desc string `gorm:"column:desc;not null" json:"desc"`
|
||||||
|
Type int32 `gorm:"column:type;not null;comment:类型,1:api,2:知识库" json:"type"` // 类型,1:api,2:知识库
|
||||||
|
Config string `gorm:"column:config" json:"config"`
|
||||||
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
|
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
|
||||||
UpdateAt time.Time `gorm:"column:update_at;default:CURRENT_TIMESTAMP" json:"update_at"`
|
UpdateAt time.Time `gorm:"column:update_at;default:CURRENT_TIMESTAMP" json:"update_at"`
|
||||||
Status int32 `gorm:"column:status;not null;default:1" json:"status"`
|
Status int32 `gorm:"column:status;not null;default:1" json:"status"`
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,11 @@
|
||||||
|
package entitys
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/data/constants"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChatHistory struct {
|
||||||
|
SessionID string `json:"session_id"`
|
||||||
|
Role constants.Caller `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,51 @@
|
||||||
|
package entitys
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/gofiber/websocket/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Response string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ResponseJson Response = "json"
|
||||||
|
ResponseLoading Response = "loading"
|
||||||
|
ResponseEnd Response = "end"
|
||||||
|
ResponseStream Response = "stream"
|
||||||
|
ResponseText Response = "txt"
|
||||||
|
ResponseImg Response = "img"
|
||||||
|
ResponseFile Response = "file"
|
||||||
|
ResponseErr Response = "error"
|
||||||
|
ResponseLog Response = "log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ResponseData struct {
|
||||||
|
Done bool
|
||||||
|
Content string
|
||||||
|
Type Response
|
||||||
|
}
|
||||||
|
|
||||||
|
func MsgSet(msgType Response, msg string, done bool) []byte {
|
||||||
|
jsonByte, err := json.Marshal(ResponseData{
|
||||||
|
Done: done,
|
||||||
|
Content: msg,
|
||||||
|
|
||||||
|
Type: msgType,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return jsonByte
|
||||||
|
}
|
||||||
|
|
||||||
|
func MsgSend(c *websocket.Conn, msg ResponseData) error {
|
||||||
|
jsonByte, _ := json.Marshal(msg)
|
||||||
|
|
||||||
|
return c.WriteMessage(websocket.TextMessage, jsonByte)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MsgSendByte(c *websocket.Conn, msg []byte) {
|
||||||
|
|
||||||
|
_ = c.WriteMessage(websocket.TextMessage, msg)
|
||||||
|
}
|
||||||
|
|
@ -5,6 +5,10 @@ type SessionInitRequest struct {
|
||||||
UserId string `json:"user_id"`
|
UserId string `json:"user_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SessionInitResponse struct {
|
||||||
|
Chat []ChatHistory `json:"chat"`
|
||||||
|
}
|
||||||
|
|
||||||
type SessionListRequest struct {
|
type SessionListRequest struct {
|
||||||
SysId string `json:"sys_id"`
|
SysId string `json:"sys_id"`
|
||||||
UserId string `json:"user_id"`
|
UserId string `json:"user_id"`
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,9 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ChatRequest 聊天请求
|
// ChatRequest 聊天请求
|
||||||
|
|
@ -65,12 +67,18 @@ type Tool interface {
|
||||||
Name() string
|
Name() string
|
||||||
Description() string
|
Description() string
|
||||||
Definition() ToolDefinition
|
Definition() ToolDefinition
|
||||||
Execute(ctx context.Context, args json.RawMessage) (interface{}, error)
|
Execute(channel chan ResponseData, c *websocket.Conn, args json.RawMessage) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// AIClient AI客户端接口
|
type ConfigDataHttp struct {
|
||||||
type AIClient interface {
|
Param map[string]interface{} `json:"param"`
|
||||||
Chat(ctx context.Context, messages []Message, tools []ToolDefinition) (*ChatResponse, error)
|
Request map[string]interface{} `json:"request"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConfigDataTool struct {
|
||||||
|
Param map[string]interface{} `json:"param"`
|
||||||
|
Request map[string]interface{} `json:"request"`
|
||||||
|
Tool string `json:"tool"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Message 消息
|
// Message 消息
|
||||||
|
|
@ -79,6 +87,46 @@ type Message struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type FuncApi struct {
|
||||||
|
Param interface{} `json:"param"`
|
||||||
|
Request l_request.Request `json:"request"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TaskConfig struct {
|
||||||
|
Param interface{} `json:"param"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TaskConfigDetail struct {
|
||||||
|
Param ConfigParam `json:"param"`
|
||||||
|
}
|
||||||
|
type ConfigParam struct {
|
||||||
|
Properties map[string]api.ToolProperty
|
||||||
|
Required []string `json:"required"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
type Match struct {
|
||||||
|
Confidence string `json:"confidence"`
|
||||||
|
Index string `json:"index"`
|
||||||
|
IsMatch bool `json:"is_match"`
|
||||||
|
Parameters string `json:"parameters"`
|
||||||
|
Reasoning string `json:"reasoning"`
|
||||||
|
}
|
||||||
|
type ChatHis struct {
|
||||||
|
SessionId string `json:"session_id"`
|
||||||
|
Messages []HisMessage `json:"messages"`
|
||||||
|
Context HisContext `json:"context"`
|
||||||
|
}
|
||||||
|
type HisMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Timestamp string `json:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HisContext struct {
|
||||||
|
UserLanguage string `json:"user_language"`
|
||||||
|
SystemMode string `json:"system_mode"`
|
||||||
|
}
|
||||||
|
|
||||||
// RouterService 路由服务接口
|
// RouterService 路由服务接口
|
||||||
type RouterService interface {
|
type RouterService interface {
|
||||||
Route(ctx context.Context, req *ChatRequest) (*ChatResponse, error)
|
Route(ctx context.Context, req *ChatRequest) (*ChatResponse, error)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,104 @@
|
||||||
|
package gateway
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
ID string
|
||||||
|
SendFunc func(data []byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Gateway struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
clients map[string]*Client // clientID -> Client
|
||||||
|
uidMap map[string][]string // uid -> []clientID
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGateway() *Gateway {
|
||||||
|
return &Gateway{
|
||||||
|
clients: make(map[string]*Client),
|
||||||
|
uidMap: make(map[string][]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gateway) AddClient(c *Client) {
|
||||||
|
g.mu.Lock()
|
||||||
|
defer g.mu.Unlock()
|
||||||
|
g.clients[c.ID] = c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gateway) RemoveClient(clientID string) {
|
||||||
|
g.mu.Lock()
|
||||||
|
defer g.mu.Unlock()
|
||||||
|
delete(g.clients, clientID)
|
||||||
|
for uid, list := range g.uidMap {
|
||||||
|
newList := []string{}
|
||||||
|
for _, cid := range list {
|
||||||
|
if cid != clientID {
|
||||||
|
newList = append(newList, cid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
g.uidMap[uid] = newList
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gateway) SendToAll(msg []byte) {
|
||||||
|
g.mu.RLock()
|
||||||
|
defer g.mu.RUnlock()
|
||||||
|
for _, c := range g.clients {
|
||||||
|
_ = c.SendFunc(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gateway) SendToClient(clientID string, msg []byte) error {
|
||||||
|
g.mu.RLock()
|
||||||
|
defer g.mu.RUnlock()
|
||||||
|
if c, ok := g.clients[clientID]; ok {
|
||||||
|
return c.SendFunc(msg)
|
||||||
|
}
|
||||||
|
return errors.New("client not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gateway) BindUid(clientID, uid string) error {
|
||||||
|
g.mu.Lock()
|
||||||
|
defer g.mu.Unlock()
|
||||||
|
if _, ok := g.clients[clientID]; !ok {
|
||||||
|
return errors.New("client not found")
|
||||||
|
}
|
||||||
|
g.uidMap[uid] = append(g.uidMap[uid], clientID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gateway) SendToUid(uid string, msg []byte) {
|
||||||
|
g.mu.RLock()
|
||||||
|
defer g.mu.RUnlock()
|
||||||
|
if list, ok := g.uidMap[uid]; ok {
|
||||||
|
for _, cid := range list {
|
||||||
|
if c, ok := g.clients[cid]; ok {
|
||||||
|
_ = c.SendFunc(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gateway) ListClients() []string {
|
||||||
|
g.mu.RLock()
|
||||||
|
defer g.mu.RUnlock()
|
||||||
|
ids := make([]string, 0, len(g.clients))
|
||||||
|
for id := range g.clients {
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gateway) ListUids() map[string][]string {
|
||||||
|
g.mu.RLock()
|
||||||
|
defer g.mu.RUnlock()
|
||||||
|
result := make(map[string][]string, len(g.uidMap))
|
||||||
|
for uid, list := range g.uidMap {
|
||||||
|
result[uid] = append([]string(nil), list...)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
package pkg
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SafeChannelPool struct {
|
||||||
|
pool chan chan entitys.ResponseData // 存储空闲 channel 的队列
|
||||||
|
bufSize int // channel 缓冲大小
|
||||||
|
mu sync.Mutex
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSafeChannelPool(c *config.Config) (*SafeChannelPool, func()) {
|
||||||
|
pool := &SafeChannelPool{
|
||||||
|
pool: make(chan chan entitys.ResponseData, c.Sys.ChannelPoolLen),
|
||||||
|
bufSize: c.Sys.ChannelPoolSize,
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup := pool.Close
|
||||||
|
return pool, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从池中获取 channel(若无空闲则创建新 channel)
|
||||||
|
func (p *SafeChannelPool) Get() (chan entitys.ResponseData, error) {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
if p.closed {
|
||||||
|
return nil, errors.New("pool is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case ch := <-p.pool: // 从池中取
|
||||||
|
return ch, nil
|
||||||
|
default: // 池为空,创建新 channel
|
||||||
|
return make(chan entitys.ResponseData, p.bufSize), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将 channel 放回池中(必须确保 channel 已清空!)
|
||||||
|
func (p *SafeChannelPool) Put(ch chan entitys.ResponseData) error {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
if p.closed {
|
||||||
|
return errors.New("pool is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清空 channel(防止复用时读取旧数据)
|
||||||
|
go func() {
|
||||||
|
for range ch {
|
||||||
|
// 丢弃所有数据(或根据业务需求处理)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case p.pool <- ch: // 尝试放回池中
|
||||||
|
default: // 池已满,直接关闭 channel(避免泄漏)
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 关闭池(释放所有资源)
|
||||||
|
func (p *SafeChannelPool) Close() {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
p.closed = true
|
||||||
|
close(p.pool) // 关闭池队列
|
||||||
|
// 需额外逻辑关闭所有内部 channel(此处简化)
|
||||||
|
}
|
||||||
|
|
@ -1 +1,8 @@
|
||||||
package pkg
|
package pkg
|
||||||
|
|
||||||
|
import "encoding/json"
|
||||||
|
|
||||||
|
func JsonStringIgonErr(data interface{}) string {
|
||||||
|
dataByte, _ := json.Marshal(data)
|
||||||
|
return string(dataByte)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ go get -u gitea.cdlsxd.cn/rzy_tools/request/tags
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## 正常使用
|
||||||
|
```go
|
||||||
|
req := request.Request{
|
||||||
|
Method: "POST",
|
||||||
|
Url: reqUrl,
|
||||||
|
Json: RequestBody,
|
||||||
|
Headers: header,
|
||||||
|
}
|
||||||
|
resp, _ := req.Send()
|
||||||
|
```
|
||||||
|
|
||||||
|
## 同时大量请求或者在协程中使用建议使用
|
||||||
|
|
||||||
|
```go
|
||||||
|
r := RequestPools.Get()
|
||||||
|
defer RequestPools.ClearAndPut(r)
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
package l_request
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestPool struct {
|
||||||
|
pool sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
var RequestPools = &RequestPool{
|
||||||
|
pool: sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return new(Request)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func (re *RequestPool) Get() *Request {
|
||||||
|
return re.pool.Get().(*Request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (re *RequestPool) Put(r *Request) {
|
||||||
|
re.pool.Put(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重置对象
|
||||||
|
func (re *RequestPool) Reset(r *Request) {
|
||||||
|
r.Method = ""
|
||||||
|
r.Url = ""
|
||||||
|
r.Params = nil
|
||||||
|
r.Headers = nil
|
||||||
|
r.Cookies = nil
|
||||||
|
r.Data = nil
|
||||||
|
r.Json = nil
|
||||||
|
r.Files = nil
|
||||||
|
r.Raw = ""
|
||||||
|
r.JsonByte = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (re *RequestPool) ClearAndPut(r *Request) {
|
||||||
|
re.Reset(r)
|
||||||
|
re.Put(r)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,11 @@
|
||||||
|
package l_request
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestPool(t *testing.T) {
|
||||||
|
r := RequestPools.Get()
|
||||||
|
r.Url = "http://www.baidu.com"
|
||||||
|
RequestPools.ClearAndPut(r)
|
||||||
|
a := RequestPools.Get()
|
||||||
|
t.Log(a.Url)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,169 @@
|
||||||
|
package l_request
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 请求结构体
|
||||||
|
type Request struct {
|
||||||
|
Method string `json:"method"` // 请求方法
|
||||||
|
Url string `json:"url"` // 请求url
|
||||||
|
Params map[string]string `json:"params"` // Query参数
|
||||||
|
Headers map[string]string `json:"headers"` // 请求头
|
||||||
|
Cookies map[string]string `json:"cookies"` // todo 处理 Cookies
|
||||||
|
Data map[string]string `json:"data"` // 表单格式请求数据
|
||||||
|
Json map[string]interface{} `json:"json"` // JSON格式请求数据 todo 多层 嵌套
|
||||||
|
Files map[string]string `json:"files"` // todo 处理 Files
|
||||||
|
Raw string `json:"raw"` // 原始请求数据
|
||||||
|
JsonByte []byte `json:"json_raw"` // JSON格式请求数据 todo 多层 嵌套
|
||||||
|
Xml []byte `json:"xml"` // xml
|
||||||
|
}
|
||||||
|
|
||||||
|
// 响应结构体
|
||||||
|
type Response struct {
|
||||||
|
StatusCode int `json:"status_code"` // 状态码
|
||||||
|
Reason string `json:"reason"` // 状态码说明
|
||||||
|
Elapsed float64 `json:"elapsed"` // 请求耗时(秒)
|
||||||
|
Content []byte `json:"content"` // 响应二进制内容
|
||||||
|
Text string `json:"text"` // 响应文本
|
||||||
|
Headers map[string]string `json:"headers"` // 响应头
|
||||||
|
Cookies map[string]string `json:"cookies"` // todo 添加响应Cookies
|
||||||
|
Request *Request `json:"request"` // 原始请求
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理请求方法
|
||||||
|
func (r *Request) getMethod() string {
|
||||||
|
return strings.ToUpper(r.Method) // 必须转为全部大写
|
||||||
|
}
|
||||||
|
|
||||||
|
// 组装URL
|
||||||
|
func (r *Request) getUrl() string {
|
||||||
|
if r.Params != nil {
|
||||||
|
urlValues := url.Values{}
|
||||||
|
Url, _ := url.Parse(r.Url) // todo 处理err
|
||||||
|
for key, value := range r.Params {
|
||||||
|
urlValues.Set(key, value)
|
||||||
|
}
|
||||||
|
Url.RawQuery = urlValues.Encode()
|
||||||
|
return Url.String()
|
||||||
|
}
|
||||||
|
return r.Url
|
||||||
|
}
|
||||||
|
|
||||||
|
// 组装请求数据
|
||||||
|
func (r *Request) getData() io.Reader {
|
||||||
|
var reqBody string
|
||||||
|
if r.Headers == nil {
|
||||||
|
r.Headers = make(map[string]string, 1)
|
||||||
|
}
|
||||||
|
if r.Raw != "" {
|
||||||
|
reqBody = r.Raw
|
||||||
|
} else if r.Data != nil {
|
||||||
|
urlValues := url.Values{}
|
||||||
|
for key, value := range r.Data {
|
||||||
|
urlValues.Add(key, value)
|
||||||
|
}
|
||||||
|
reqBody = urlValues.Encode()
|
||||||
|
if _, ex := r.Headers["Content-Type"]; !ex {
|
||||||
|
r.Headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if r.Json != nil {
|
||||||
|
bytesData, _ := json.Marshal(r.Json)
|
||||||
|
reqBody = string(bytesData)
|
||||||
|
if _, ex := r.Headers["Content-Type"]; !ex {
|
||||||
|
r.Headers["Content-Type"] = "application/json"
|
||||||
|
}
|
||||||
|
} else if r.JsonByte != nil {
|
||||||
|
reqBody = string(r.JsonByte)
|
||||||
|
if _, ex := r.Headers["Content-Type"]; !ex {
|
||||||
|
r.Headers["Content-Type"] = "application/json"
|
||||||
|
}
|
||||||
|
} else if r.Xml != nil {
|
||||||
|
reqBody = string(r.Xml)
|
||||||
|
if _, ex := r.Headers["Content-Type"]; !ex {
|
||||||
|
r.Headers["Content-Type"] = "application/xml"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.NewReader(reqBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加请求头-需要在getData后使用
|
||||||
|
func (r *Request) addHeaders(req *http.Request) {
|
||||||
|
if r.Headers != nil {
|
||||||
|
for key, value := range r.Headers {
|
||||||
|
req.Header.Add(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 准备请求
|
||||||
|
func (r *Request) prepare() *http.Request {
|
||||||
|
Method := r.getMethod()
|
||||||
|
Url := r.getUrl()
|
||||||
|
Data := r.getData()
|
||||||
|
req, _ := http.NewRequest(Method, Url, Data)
|
||||||
|
r.addHeaders(req)
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
// 组装响应对象
|
||||||
|
func (r *Request) packResponse(res *http.Response, elapsed float64) Response {
|
||||||
|
var resp Response
|
||||||
|
resBody, _ := io.ReadAll(res.Body)
|
||||||
|
resp.Content = resBody
|
||||||
|
resp.Text = string(resBody)
|
||||||
|
resp.StatusCode = res.StatusCode
|
||||||
|
resp.Reason = strings.Split(res.Status, " ")[1]
|
||||||
|
resp.Elapsed = elapsed
|
||||||
|
resp.Headers = map[string]string{}
|
||||||
|
for key, value := range res.Header {
|
||||||
|
resp.Headers[key] = strings.Join(value, ";")
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
func (r *Request) Send() (Response, error) {
|
||||||
|
req := r.prepare()
|
||||||
|
client := &http.Client{}
|
||||||
|
start := time.Now()
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
elapsed := time.Since(start).Seconds()
|
||||||
|
return r.packResponse(res, elapsed), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 跳过证书发送请求
|
||||||
|
func (r *Request) SendWithoutSsl() (Response, error) {
|
||||||
|
tr := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
|
}
|
||||||
|
req := r.prepare()
|
||||||
|
client := &http.Client{Transport: tr}
|
||||||
|
start := time.Now()
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
elapsed := time.Since(start).Seconds()
|
||||||
|
return r.packResponse(res, elapsed), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
func (r *Request) SendNoParseResponse() (*http.Response, error) {
|
||||||
|
req := r.prepare()
|
||||||
|
client := &http.Client{}
|
||||||
|
res, err := client.Do(req)
|
||||||
|
return res, err
|
||||||
|
}
|
||||||
|
|
@ -1,12 +1,15 @@
|
||||||
package pkg
|
package pkg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/pkg/ollama"
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
|
|
||||||
"github.com/google/wire"
|
"github.com/google/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ProviderSetClient = wire.NewSet(
|
var ProviderSetClient = wire.NewSet(
|
||||||
NewRdb,
|
NewRdb,
|
||||||
NewGormDb,
|
NewGormDb,
|
||||||
ollama.NewClient,
|
utils_ollama.NewUtilOllama,
|
||||||
|
utils_ollama.NewClient,
|
||||||
|
NewSafeChannelPool,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ func NewRdb(c *config.Config) *Rdb {
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildRdb 构建redis client
|
// buildRdb 构建redis client
|
||||||
func buildRdb(c *config.Redis) *redis.Client {
|
func buildRdb(c config.Redis) *redis.Client {
|
||||||
|
|
||||||
rdb := redis.NewClient(&redis.Options{
|
rdb := redis.NewClient(&redis.Options{
|
||||||
Addr: c.Host,
|
Addr: c.Host,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,168 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/pkg/l_request"
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KnowledgeBase struct {
|
||||||
|
session string
|
||||||
|
url string
|
||||||
|
apiKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewKnowledgeBase(url, apiKey, session string) *KnowledgeBase {
|
||||||
|
return &KnowledgeBase{
|
||||||
|
session: session,
|
||||||
|
url: url,
|
||||||
|
apiKey: apiKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 请求知识库聊天
|
||||||
|
func (this *KnowledgeBase) Chat(ctx context.Context, query string) (text string, err error) {
|
||||||
|
|
||||||
|
req := l_request.Request{
|
||||||
|
Method: "post",
|
||||||
|
Url: this.url + "/api/v1/knowledge-chat/" + this.session,
|
||||||
|
Params: nil,
|
||||||
|
Headers: map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"X-API-Key": this.apiKey,
|
||||||
|
},
|
||||||
|
Cookies: nil,
|
||||||
|
Data: nil,
|
||||||
|
Json: map[string]interface{}{
|
||||||
|
"query": query,
|
||||||
|
},
|
||||||
|
Files: nil,
|
||||||
|
Raw: "",
|
||||||
|
JsonByte: nil,
|
||||||
|
Xml: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
rsp, err := req.SendNoParseResponse()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer rsp.Body.Close()
|
||||||
|
|
||||||
|
err = connectAndReadSSE(rsp)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message 表示解析后的 SSE 消息
|
||||||
|
type Message struct {
|
||||||
|
Event string // 事件类型(默认 "message")
|
||||||
|
Data string // 消息内容(可能多行)
|
||||||
|
ID string // 消息 ID(可选)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 连接 SSE 并读取数据
|
||||||
|
func connectAndReadSSE(resp *http.Response) error {
|
||||||
|
|
||||||
|
// 验证响应状态和格式
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("非 200 状态码: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
if !strings.Contains(contentType, "text/event-stream") {
|
||||||
|
return fmt.Errorf("不支持的 Content-Type: %s", contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 逐行读取响应流
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
var currentMsg Message // 当前正在组装的消息
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if line == "" {
|
||||||
|
// 空行表示一条消息结束,处理当前消息
|
||||||
|
if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" {
|
||||||
|
printMessage(currentMsg)
|
||||||
|
currentMsg = Message{} // 重置消息
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析字段(格式:"field: value")
|
||||||
|
parts := strings.SplitN(line, ":", 2)
|
||||||
|
if len(parts) < 2 {
|
||||||
|
continue // 无效行(无冒号),跳过
|
||||||
|
}
|
||||||
|
field := strings.TrimSpace(parts[0])
|
||||||
|
value := strings.TrimSpace(parts[1])
|
||||||
|
|
||||||
|
switch field {
|
||||||
|
case "event":
|
||||||
|
currentMsg.Event = value
|
||||||
|
case "data":
|
||||||
|
// data 可能多行,用换行符拼接(最后一条消息可能无结尾空行)
|
||||||
|
currentMsg.Data += value + ""
|
||||||
|
//case "id":
|
||||||
|
// currentMsg.ID = value
|
||||||
|
// 可选:处理 "retry" 字段(服务器建议的重连时间,单位秒)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查扫描错误(如连接断开)
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return fmt.Errorf("读取流失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理最后一条未结束的消息(无结尾空行)
|
||||||
|
if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" {
|
||||||
|
printMessage(currentMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type MegContent struct {
|
||||||
|
Id string `json:"id"` // 消息 ID
|
||||||
|
ResponseType string `json:"response_type"` // 响应类型,answer 或 references
|
||||||
|
Content string `json:"content"` // 消息内容
|
||||||
|
Done bool `json:"done"` // 是否完成
|
||||||
|
KnowledgeReferences interface{} `json:"knowledge_references"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// printMessage 打印解析后的 SSE 消息
|
||||||
|
func printMessage(msg Message) {
|
||||||
|
|
||||||
|
//fmt.Printf("--- 收到 SSE 消息 ---")
|
||||||
|
//fmt.Printf("事件类型: %s,", msg.Event)
|
||||||
|
//fmt.Printf("消息 ID: %s,", msg.ID)
|
||||||
|
//fmt.Printf("内容:%s,", strings.TrimSpace(msg.Data)) // 去除末尾多余换行
|
||||||
|
|
||||||
|
var content MegContent
|
||||||
|
_ = json.Unmarshal([]byte(msg.Data), &content)
|
||||||
|
fmt.Println(msg.Data)
|
||||||
|
|
||||||
|
//if content.ResponseType == "answer" {
|
||||||
|
// //fmt.Printf("%s", content.Content)
|
||||||
|
// fmt.Println(content)
|
||||||
|
//} else {
|
||||||
|
// fmt.Printf("--- 收到 SSE 消息 ---")
|
||||||
|
// fmt.Printf("事件类型: %s,", msg.Event)
|
||||||
|
// fmt.Printf("消息 ID: %s,", msg.ID)
|
||||||
|
// fmt.Printf("内容:%s,", strings.TrimSpace(msg.Data)) // 去除末尾多余换行
|
||||||
|
//}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRetryAfter 从响应头获取重连时间(示例,需根据实际响应头调整)
|
||||||
|
func getRetryAfter(url string) int {
|
||||||
|
// 实际需重新请求并获取响应头(此处简化为固定值)
|
||||||
|
// 正确做法:在 connectAndReadSSE 中记录响应头的 Retry-After 字段
|
||||||
|
return 5 // 示例:等待 5 秒
|
||||||
|
}
|
||||||
|
|
@ -10,7 +10,7 @@ import (
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func DBConn(c *config.DB) (*gorm.DB, func()) {
|
func DBConn(c config.DB) (*gorm.DB, func()) {
|
||||||
mysqlConn, err := sql.Open(c.Driver, c.Source)
|
mysqlConn, err := sql.Open(c.Driver, c.Source)
|
||||||
gormDB, err := gorm.Open(
|
gormDB, err := gorm.Open(
|
||||||
mysql.New(mysql.Config{Conn: mysqlConn}),
|
mysql.New(mysql.Config{Conn: mysqlConn}),
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
package ollama
|
package utils_ollama
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"net/http"
|
||||||
"fmt"
|
"net/url"
|
||||||
"time"
|
"os"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
@ -18,79 +19,76 @@ type Client struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient 创建新的Ollama客户端
|
// NewClient 创建新的Ollama客户端
|
||||||
func NewClient(config *config.Config) (entitys.AIClient, func(), error) {
|
func NewClient(config *config.Config) (client *Client, cleanFunc func(), err error) {
|
||||||
client, err := api.ClientFromEnvironment()
|
client = &Client{
|
||||||
|
config: &config.Ollama,
|
||||||
|
}
|
||||||
|
url, err := client.getUrl()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
client.client = api.NewClient(url, http.DefaultClient)
|
||||||
|
|
||||||
cleanup := func() {
|
cleanup := func() {
|
||||||
if client != nil {
|
if client != nil {
|
||||||
client = nil
|
client = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err != nil {
|
|
||||||
return nil, cleanup, fmt.Errorf("failed to create ollama client: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Client{
|
return client, cleanup, nil
|
||||||
client: client,
|
|
||||||
config: &config.Ollama,
|
|
||||||
}, cleanup, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Chat 实现聊天功能
|
// ToolSelect 工具选择
|
||||||
func (c *Client) Chat(ctx context.Context, messages []entitys.Message, tools []entitys.ToolDefinition) (*entitys.ChatResponse, error) {
|
func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools []api.Tool) (res api.ChatResponse, err error) {
|
||||||
// 构建聊天请求
|
// 构建聊天请求
|
||||||
req := &api.ChatRequest{
|
req := &api.ChatRequest{
|
||||||
Model: c.config.Model,
|
Model: c.config.Model,
|
||||||
Messages: make([]api.Message, len(messages)),
|
Messages: messages,
|
||||||
Stream: new(bool), // 设置为false,不使用流式响应
|
Stream: new(bool), // 设置为false,不使用流式响应
|
||||||
Think: &api.ThinkValue{Value: true},
|
Think: &api.ThinkValue{Value: true},
|
||||||
|
Tools: tools,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 转换消息格式
|
err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
||||||
for i, msg := range messages {
|
res = resp
|
||||||
req.Messages[i] = api.Message{
|
|
||||||
Role: msg.Role,
|
|
||||||
Content: msg.Content,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加工具定义
|
|
||||||
if len(tools) > 0 {
|
|
||||||
req.Tools = make([]api.Tool, len(tools))
|
|
||||||
for i, tool := range tools {
|
|
||||||
toolData, _ := json.Marshal(tool)
|
|
||||||
var apiTool api.Tool
|
|
||||||
json.Unmarshal(toolData, &apiTool)
|
|
||||||
req.Tools[i] = apiTool
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 发送请求
|
|
||||||
responseChan := make(chan api.ChatResponse)
|
|
||||||
errorChan := make(chan error)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
err := c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
|
||||||
responseChan <- resp
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorChan <- err
|
return
|
||||||
}
|
}
|
||||||
close(responseChan)
|
|
||||||
close(errorChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 等待响应
|
return
|
||||||
select {
|
}
|
||||||
case resp := <-responseChan:
|
|
||||||
return c.convertResponse(&resp), nil
|
func (c *Client) ChatStream(ctx context.Context, ch chan entitys.ResponseData, messages []api.Message) (err error) {
|
||||||
case err := <-errorChan:
|
// 构建聊天请求
|
||||||
return nil, fmt.Errorf("chat request failed: %w", err)
|
req := &api.ChatRequest{
|
||||||
case <-ctx.Done():
|
Model: c.config.Model,
|
||||||
return nil, ctx.Err()
|
Messages: messages,
|
||||||
case <-time.After(c.config.Timeout):
|
Stream: nil,
|
||||||
return nil, fmt.Errorf("chat request timeout")
|
Think: &api.ThinkValue{Value: true},
|
||||||
}
|
}
|
||||||
|
var w sync.WaitGroup
|
||||||
|
w.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer w.Done()
|
||||||
|
err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
||||||
|
if resp.Message.Content != "" {
|
||||||
|
ch <- entitys.ResponseData{
|
||||||
|
Done: false,
|
||||||
|
Content: resp.Message.Content,
|
||||||
|
Type: entitys.ResponseStream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
w.Wait()
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// convertResponse 转换响应格式
|
// convertResponse 转换响应格式
|
||||||
|
|
@ -121,3 +119,13 @@ func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
|
||||||
//return result
|
//return result
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) getUrl() (*url.URL, error) {
|
||||||
|
baseURL := c.config.BaseURL
|
||||||
|
envURL := os.Getenv("OLLAMA_BASE_URL")
|
||||||
|
if envURL != "" {
|
||||||
|
baseURL = envURL
|
||||||
|
}
|
||||||
|
|
||||||
|
return url.Parse(baseURL)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,47 @@
|
||||||
|
package utils_ollama
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2/log"
|
||||||
|
"github.com/tmc/langchaingo/llms/ollama"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UtilOllama struct {
|
||||||
|
Llm *ollama.LLM
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
|
||||||
|
llm, err := ollama.New(
|
||||||
|
ollama.WithModel(c.Ollama.Model),
|
||||||
|
ollama.WithHTTPClient(http.DefaultClient),
|
||||||
|
ollama.WithServerURL(getUrl(c)),
|
||||||
|
ollama.WithKeepAlive("-1s"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal(err)
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UtilOllama{
|
||||||
|
Llm: llm,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//func (o *UtilOllama) a() {
|
||||||
|
// var agent agents.Agent
|
||||||
|
// agent = agents.NewOneShotAgent(llm, tools, opts...)
|
||||||
|
//
|
||||||
|
// agents.NewExecutor()
|
||||||
|
//}
|
||||||
|
|
||||||
|
func getUrl(c *config.Config) string {
|
||||||
|
baseURL := c.Ollama.BaseURL
|
||||||
|
envURL := os.Getenv("OLLAMA_BASE_URL")
|
||||||
|
if envURL != "" {
|
||||||
|
baseURL = envURL
|
||||||
|
}
|
||||||
|
return baseURL
|
||||||
|
}
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"ai_scheduler/internal/gateway"
|
||||||
"ai_scheduler/internal/server/router"
|
"ai_scheduler/internal/server/router"
|
||||||
"ai_scheduler/internal/services"
|
"ai_scheduler/internal/services"
|
||||||
|
|
||||||
|
|
@ -12,10 +13,11 @@ import (
|
||||||
func NewHTTPServer(
|
func NewHTTPServer(
|
||||||
service *services.ChatService,
|
service *services.ChatService,
|
||||||
session *services.SessionService,
|
session *services.SessionService,
|
||||||
|
gateway *gateway.Gateway,
|
||||||
) *fiber.App {
|
) *fiber.App {
|
||||||
//构建 server
|
//构建 server
|
||||||
app := initRoute()
|
app := initRoute()
|
||||||
router.SetupRoutes(app, service, session)
|
router.SetupRoutes(app, service, session, gateway)
|
||||||
return app
|
return app
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,10 @@ package router
|
||||||
|
|
||||||
import (
|
import (
|
||||||
errors "ai_scheduler/internal/data/error"
|
errors "ai_scheduler/internal/data/error"
|
||||||
|
"ai_scheduler/internal/gateway"
|
||||||
"ai_scheduler/internal/services"
|
"ai_scheduler/internal/services"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -13,7 +15,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetupRoutes 设置路由
|
// SetupRoutes 设置路由
|
||||||
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService) {
|
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, gateway *gateway.Gateway) {
|
||||||
app.Use(func(c *fiber.Ctx) error {
|
app.Use(func(c *fiber.Ctx) error {
|
||||||
// 设置 CORS 头
|
// 设置 CORS 头
|
||||||
c.Set("Access-Control-Allow-Origin", "*")
|
c.Set("Access-Control-Allow-Origin", "*")
|
||||||
|
|
@ -28,7 +30,7 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
|
||||||
// 继续处理后续中间件或路由
|
// 继续处理后续中间件或路由
|
||||||
return c.Next()
|
return c.Next()
|
||||||
})
|
})
|
||||||
routerHttp(app, sessionService)
|
routerHttp(app, sessionService, gateway)
|
||||||
routerSocket(app, ChatService)
|
routerSocket(app, ChatService)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
@ -56,7 +58,7 @@ var bufferPool = &sync.Pool{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func routerHttp(app *fiber.App, sessionService *services.SessionService) {
|
func routerHttp(app *fiber.App, sessionService *services.SessionService, gateway *gateway.Gateway) {
|
||||||
r := app.Group("api/v1/")
|
r := app.Group("api/v1/")
|
||||||
registerResponse(r)
|
registerResponse(r)
|
||||||
// 注册 CORS 中间件
|
// 注册 CORS 中间件
|
||||||
|
|
@ -65,9 +67,28 @@ func routerHttp(app *fiber.App, sessionService *services.SessionService) {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Post("/session/init", sessionService.SessionInit)
|
r.Post("/session/init", sessionService.SessionInit) // 会话初始化,不存在则创建,存在则返回会话ID和默认条数会话历史
|
||||||
r.Post("/session/list", sessionService.SessionList)
|
r.Post("/session/list", sessionService.SessionList)
|
||||||
|
//广播
|
||||||
|
r.Get("/broadcast", func(ctx *fiber.Ctx) error {
|
||||||
|
action := ctx.Query("action")
|
||||||
|
uid := ctx.Query("uid")
|
||||||
|
msg := ctx.Query("msg")
|
||||||
|
|
||||||
|
switch action {
|
||||||
|
case "sendToAll":
|
||||||
|
gateway.SendToAll([]byte(msg))
|
||||||
|
return ctx.SendString("sent to all")
|
||||||
|
case "sendToUid":
|
||||||
|
if uid == "" {
|
||||||
|
return ctx.Status(400).SendString("missing uid")
|
||||||
|
}
|
||||||
|
gateway.SendToUid(uid, []byte(msg))
|
||||||
|
return ctx.SendString(fmt.Sprintf("sent to uid %s", uid))
|
||||||
|
default:
|
||||||
|
return ctx.Status(400).SendString("unknown action")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerResponse(router fiber.Router) {
|
func registerResponse(router fiber.Router) {
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,33 @@
|
||||||
package services
|
package services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/data/constant"
|
"ai_scheduler/internal/biz"
|
||||||
|
"ai_scheduler/internal/data/constants"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
|
"ai_scheduler/internal/gateway"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"math/rand"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ChatHandler 聊天处理器
|
// ChatHandler 聊天处理器
|
||||||
type ChatService struct {
|
type ChatService struct {
|
||||||
routerService entitys.RouterService
|
routerBiz *biz.AiRouterBiz
|
||||||
|
Gw *gateway.Gateway
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewChatHandler 创建聊天处理器
|
// NewChatHandler 创建聊天处理器
|
||||||
func NewChatService(routerService entitys.RouterService) *ChatService {
|
func NewChatService(routerService *biz.AiRouterBiz, gw *gateway.Gateway) *ChatService {
|
||||||
return &ChatService{
|
return &ChatService{
|
||||||
routerService: routerService,
|
routerBiz: routerService,
|
||||||
|
Gw: gw,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -36,12 +46,34 @@ type FunctionCallResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
|
func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
|
||||||
//c.WriteMessage(messageType, message)
|
err := c.WriteMessage(websocket.TextMessage, []byte(content))
|
||||||
|
if err != nil {
|
||||||
|
log.Println("发送错误:", err)
|
||||||
|
}
|
||||||
|
_ = c.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateClientID() string {
|
||||||
|
// 使用时间戳+随机数确保唯一性
|
||||||
|
timestamp := time.Now().UnixNano()
|
||||||
|
randomBytes := make([]byte, 4)
|
||||||
|
rand.Read(randomBytes)
|
||||||
|
randomStr := hex.EncodeToString(randomBytes)
|
||||||
|
return fmt.Sprintf("%d%s", timestamp, randomStr)
|
||||||
|
}
|
||||||
func (h *ChatService) Chat(c *websocket.Conn) {
|
func (h *ChatService) Chat(c *websocket.Conn) {
|
||||||
|
h.mu.Lock()
|
||||||
|
clientID := generateClientID()
|
||||||
|
h.mu.Unlock()
|
||||||
|
client := &gateway.Client{
|
||||||
|
ID: clientID,
|
||||||
|
SendFunc: func(data []byte) error {
|
||||||
|
return c.WriteMessage(websocket.TextMessage, data)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h.Gw.AddClient(client)
|
||||||
|
log.Println("client connected:", clientID)
|
||||||
log.Println("客户端已连接")
|
log.Println("客户端已连接")
|
||||||
defer c.Close()
|
|
||||||
// 循环读取客户端消息
|
// 循环读取客户端消息
|
||||||
for {
|
for {
|
||||||
messageType, message, err := c.ReadMessage()
|
messageType, message, err := c.ReadMessage()
|
||||||
|
|
@ -49,11 +81,17 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
||||||
log.Println("读取错误:", err)
|
log.Println("读取错误:", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
//简单协议:bind:<uid>
|
||||||
|
if c.Headers("Sec-Websocket-Protocol") == "bind" && c.Headers("X-Session") != "" {
|
||||||
|
uid := c.Headers("X-Session")
|
||||||
|
_ = h.Gw.BindUid(clientID, uid)
|
||||||
|
log.Printf("bind %s -> uid:%s\n", clientID, uid)
|
||||||
|
}
|
||||||
msg, chatType := h.handleMessageToString(c, messageType, message)
|
msg, chatType := h.handleMessageToString(c, messageType, message)
|
||||||
if chatType == constant.ConnStatusClosed {
|
if chatType == constants.ConnStatusClosed {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if chatType == constant.ConnStatusIgnore {
|
if chatType == constants.ConnStatusIgnore {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -63,32 +101,34 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
||||||
log.Println("JSON parse error:", err)
|
log.Println("JSON parse error:", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = h.routerService.RouteWithSocket(c, &req)
|
err = h.routerBiz.RouteWithSocket(c, &req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("处理失败:", err)
|
log.Println("处理失败:", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Println("客户端已断开")
|
h.Gw.RemoveClient(clientID)
|
||||||
|
_ = c.Close()
|
||||||
|
log.Println("client disconnected:", clientID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constant.ConnStatus) {
|
func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) {
|
||||||
switch msgType {
|
switch msgType {
|
||||||
case websocket.TextMessage:
|
case websocket.TextMessage:
|
||||||
return msg.([]byte), constant.ConnStatusNormal
|
return msg.([]byte), constants.ConnStatusNormal
|
||||||
case websocket.BinaryMessage:
|
case websocket.BinaryMessage:
|
||||||
return msg.([]byte), constant.ConnStatusNormal
|
return msg.([]byte), constants.ConnStatusNormal
|
||||||
case websocket.CloseMessage:
|
case websocket.CloseMessage:
|
||||||
|
|
||||||
return nil, constant.ConnStatusClosed
|
return nil, constants.ConnStatusClosed
|
||||||
case websocket.PingMessage:
|
case websocket.PingMessage:
|
||||||
// 可选:回复 Pong
|
// 可选:回复 Pong
|
||||||
c.WriteMessage(websocket.PongMessage, nil)
|
c.WriteMessage(websocket.PongMessage, nil)
|
||||||
return nil, constant.ConnStatusIgnore
|
return nil, constants.ConnStatusIgnore
|
||||||
case websocket.PongMessage:
|
case websocket.PongMessage:
|
||||||
return nil, constant.ConnStatusIgnore
|
return nil, constants.ConnStatusIgnore
|
||||||
default:
|
default:
|
||||||
return nil, constant.ConnStatusIgnore
|
return nil, constants.ConnStatusIgnore
|
||||||
}
|
}
|
||||||
return msg.([]byte), constant.ConnStatusIgnore
|
return msg.([]byte), constants.ConnStatusIgnore
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
package services
|
package services
|
||||||
|
|
||||||
import "github.com/google/wire"
|
import (
|
||||||
|
"ai_scheduler/internal/gateway"
|
||||||
|
"github.com/google/wire"
|
||||||
|
)
|
||||||
|
|
||||||
var ProviderSetServices = wire.NewSet(NewChatService)
|
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway)
|
||||||
|
|
|
||||||
|
|
@ -8,11 +8,13 @@ import (
|
||||||
|
|
||||||
type SessionService struct {
|
type SessionService struct {
|
||||||
sessionBiz *biz.SessionBiz
|
sessionBiz *biz.SessionBiz
|
||||||
|
chatBiz *biz.ChatHistoryBiz
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSession(sessionBiz *biz.SessionBiz) *SessionService {
|
func NewSessionService(sessionBiz *biz.SessionBiz, chatBiz *biz.ChatHistoryBiz) *SessionService {
|
||||||
return &SessionService{
|
return &SessionService{
|
||||||
sessionBiz: sessionBiz,
|
sessionBiz: sessionBiz,
|
||||||
|
chatBiz: chatBiz,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -23,11 +25,9 @@ func (s *SessionService) SessionInit(c *fiber.Ctx) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionId, err := s.sessionBiz.SessionInit(c.Context(), req)
|
result, err := s.sessionBiz.SessionInit(c.Context(), req)
|
||||||
|
|
||||||
return handRes(c, err, fiber.Map{
|
return handRes(c, err, result)
|
||||||
"session_id": sessionId,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionList 获取会话列表
|
// SessionList 获取会话列表
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,61 @@
|
||||||
|
package test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"ai_scheduler/internal/pkg/mapstructure"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_task(t *testing.T) {
|
||||||
|
var c entitys.TaskConfig
|
||||||
|
config := `{"param": {"type": "object", "required": ["number"], "properties": {"number": {"type": "string", "description": "订单编号/流水号"}}}, "request": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}`
|
||||||
|
err := json.Unmarshal([]byte(config), &c)
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type configData struct {
|
||||||
|
Param map[string]interface{} `json:"param"`
|
||||||
|
Do map[string]interface{} `json:"do"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_task2(t *testing.T) {
|
||||||
|
var (
|
||||||
|
c l_request.Request
|
||||||
|
config configData
|
||||||
|
)
|
||||||
|
|
||||||
|
configJson := `{"tool": "zltxOrderDetail", "param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}}`
|
||||||
|
err := json.Unmarshal([]byte(configJson), &config)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
mapstructure.Decode(config.Do, &c)
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func producer(ch chan<- int) {
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
ch <- i // 发送数据到通道
|
||||||
|
fmt.Printf("Sent: %d\n", i)
|
||||||
|
time.Sleep(500 * time.Millisecond) // 模拟生产延迟
|
||||||
|
}
|
||||||
|
close(ch) // 关闭通道,通知接收方数据发送完毕
|
||||||
|
}
|
||||||
|
|
||||||
|
func consumer(ch <-chan int) {
|
||||||
|
for v := range ch { // 阻塞等待数据,有数据立即处理
|
||||||
|
fmt.Printf("Received: %d\n", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_a(t *testing.T) {
|
||||||
|
ch := make(chan int, 3) // 有缓冲通道(可选)
|
||||||
|
|
||||||
|
go producer(ch)
|
||||||
|
consumer(ch) // 主线程阻塞,直到通道关闭
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,248 @@
|
||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"ai_scheduler/internal/pkg/l_request"
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gofiber/websocket/v2"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 知识库工具
|
||||||
|
type KnowledgeBaseTool struct {
|
||||||
|
config config.ToolConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewKnowledgeBaseTool 创建知识库工具
|
||||||
|
func NewKnowledgeBaseTool(config config.ToolConfig) *KnowledgeBaseTool {
|
||||||
|
return &KnowledgeBaseTool{config: config}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KnowledgeBaseTool) GetConfig() config.ToolConfig {
|
||||||
|
return k.config
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name 返回工具名称
|
||||||
|
func (k *KnowledgeBaseTool) Name() string {
|
||||||
|
return "knowledgeBase"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Description 返回工具描述
|
||||||
|
func (k *KnowledgeBaseTool) Description() string {
|
||||||
|
return "请求知识库"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Definition 返回工具定义
|
||||||
|
func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition {
|
||||||
|
return entitys.ToolDefinition{
|
||||||
|
Type: "function",
|
||||||
|
Function: entitys.FunctionDef{
|
||||||
|
Name: k.Name(),
|
||||||
|
Description: k.Description(),
|
||||||
|
Parameters: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"query": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "知识库查询条件",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"query"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute 执行知识库查询
|
||||||
|
func (k *KnowledgeBaseTool) Execute(channel chan entitys.ResponseData, c *websocket.Conn, args json.RawMessage) error {
|
||||||
|
|
||||||
|
var params KnowledgeBaseRequest
|
||||||
|
if err := json.Unmarshal(args, ¶ms); err != nil {
|
||||||
|
return fmt.Errorf("unmarshal args failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return k.chat(channel, c, params)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
type KnowledgeBaseRequest struct {
|
||||||
|
Session string // 知识库会话id
|
||||||
|
ApiKey string // 知识库apiKey
|
||||||
|
Query string // 用户输入
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message 表示解析后的 SSE 消息
|
||||||
|
type Message struct {
|
||||||
|
Event string // 事件类型(默认 "message")
|
||||||
|
Data string // 消息内容(可能多行)
|
||||||
|
ID string // 消息 ID(可选)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MegContent struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
ResponseType string `json:"response_type"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Done bool `json:"done"`
|
||||||
|
KnowledgeReferences interface{} `json:"knowledge_references"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 请求知识库聊天
|
||||||
|
func (this *KnowledgeBaseTool) chat(channel chan entitys.ResponseData, c *websocket.Conn, param KnowledgeBaseRequest) (err error) {
|
||||||
|
|
||||||
|
req := l_request.Request{
|
||||||
|
Method: "post",
|
||||||
|
Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + param.Session,
|
||||||
|
Params: nil,
|
||||||
|
Headers: map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"X-API-Key": param.ApiKey,
|
||||||
|
},
|
||||||
|
Cookies: nil,
|
||||||
|
Data: nil,
|
||||||
|
Json: map[string]interface{}{
|
||||||
|
"query": param.Query,
|
||||||
|
},
|
||||||
|
Files: nil,
|
||||||
|
Raw: "",
|
||||||
|
JsonByte: nil,
|
||||||
|
Xml: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
rsp, err := req.SendNoParseResponse()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer rsp.Body.Close()
|
||||||
|
|
||||||
|
err = connectAndReadSSE(rsp, channel)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 连接 SSE 并读取数据
|
||||||
|
func connectAndReadSSE(resp *http.Response, channel chan entitys.ResponseData) error {
|
||||||
|
|
||||||
|
// 验证响应状态和格式
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("非 200 状态码: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
if !strings.Contains(contentType, "text/event-stream") {
|
||||||
|
return fmt.Errorf("不支持的 Content-Type: %s", contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 逐行读取响应流
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
var currentMsg Message // 当前正在组装的消息
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if line == "" {
|
||||||
|
// 空行表示一条消息结束,处理当前消息
|
||||||
|
if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" {
|
||||||
|
channel <- entitys.ResponseData{
|
||||||
|
Done: false,
|
||||||
|
Content: currentMsg.Data,
|
||||||
|
Type: entitys.ResponseJson,
|
||||||
|
}
|
||||||
|
currentMsg = Message{} // 重置消息
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析字段(格式:"field: value")
|
||||||
|
parts := strings.SplitN(line, ":", 2)
|
||||||
|
if len(parts) < 2 {
|
||||||
|
continue // 无效行(无冒号),跳过
|
||||||
|
}
|
||||||
|
field := strings.TrimSpace(parts[0])
|
||||||
|
value := strings.TrimSpace(parts[1])
|
||||||
|
|
||||||
|
switch field {
|
||||||
|
case "event":
|
||||||
|
currentMsg.Event = value
|
||||||
|
case "data":
|
||||||
|
// data 可能多行,用换行符拼接(最后一条消息可能无结尾空行)
|
||||||
|
currentMsg.Data += value + ""
|
||||||
|
default:
|
||||||
|
// 忽略未知字段
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查扫描错误(如连接断开)
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return fmt.Errorf("读取流失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理最后一条未结束的消息(无结尾空行)
|
||||||
|
if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" {
|
||||||
|
channel <- entitys.ResponseData{
|
||||||
|
Done: false,
|
||||||
|
Content: currentMsg.Data,
|
||||||
|
Type: entitys.ResponseJson,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取知识库 session
|
||||||
|
func GetKnowledgeBaseSession(host, baseId, apiKey string) (string, error) {
|
||||||
|
req := l_request.Request{
|
||||||
|
Method: "post",
|
||||||
|
Url: host + "/api/v1/sessions",
|
||||||
|
Params: nil,
|
||||||
|
Headers: map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"X-API-Key": apiKey,
|
||||||
|
},
|
||||||
|
Cookies: nil,
|
||||||
|
Data: nil,
|
||||||
|
Json: map[string]interface{}{
|
||||||
|
"knowledge_base_id": baseId,
|
||||||
|
},
|
||||||
|
Files: nil,
|
||||||
|
Raw: "",
|
||||||
|
JsonByte: nil,
|
||||||
|
Xml: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
rsp, err := req.Send()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var result sessionRsp
|
||||||
|
err = json.Unmarshal(rsp.Content, &result)
|
||||||
|
|
||||||
|
return result.Data.Id, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type sessionRsp struct {
|
||||||
|
Data struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
TenantId int `json:"tenant_id"`
|
||||||
|
KnowledgeBaseId string `json:"knowledge_base_id"`
|
||||||
|
MaxRounds int `json:"max_rounds"`
|
||||||
|
EnableRewrite bool `json:"enable_rewrite"`
|
||||||
|
FallbackStrategy string `json:"fallback_strategy"`
|
||||||
|
FallbackResponse string `json:"fallback_response"`
|
||||||
|
EmbeddingTopK int `json:"embedding_top_k"`
|
||||||
|
KeywordThreshold float64 `json:"keyword_threshold"`
|
||||||
|
VectorThreshold float64 `json:"vector_threshold"`
|
||||||
|
RerankModelId string `json:"rerank_model_id"`
|
||||||
|
RerankTopK int `json:"rerank_top_k"`
|
||||||
|
RerankThreshold float64 `json:"rerank_threshold"`
|
||||||
|
SummaryModelId string `json:"summary_model_id"`
|
||||||
|
} `json:"data"`
|
||||||
|
Success bool `json:"success"`
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,34 @@
|
||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestKnowledgeBaseTool_Execute(t *testing.T) {
|
||||||
|
|
||||||
|
kb := NewKnowledgeBaseTool(config.ToolConfig{})
|
||||||
|
channel := make(chan entitys.ResponseData)
|
||||||
|
err := kb.Execute(channel, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Execute() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// session
|
||||||
|
func TestKnowledgeBaseTool_Submit(t *testing.T) {
|
||||||
|
|
||||||
|
apiKey := "sk-EfnUANKMj3DUOiEPJZ5xS8SGMsbO6be_qYAg9uZ8T3zyoFM-"
|
||||||
|
baseId := "kb-00000001"
|
||||||
|
host := "http://117.175.169.61:10000"
|
||||||
|
|
||||||
|
sessionId, err := GetKnowledgeBaseSession(host, baseId, apiKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("GetKnowledgeBaseSession() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("sessionId:", sessionId)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -2,36 +2,40 @@ package tools
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/constants"
|
"ai_scheduler/internal/data/constants"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/gofiber/websocket/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Manager 工具管理器
|
// Manager 工具管理器
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
tools map[string]entitys.Tool
|
tools map[string]entitys.Tool
|
||||||
|
llm *utils_ollama.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager 创建工具管理器
|
// NewManager 创建工具管理器
|
||||||
func NewManager(config *config.Config) *Manager {
|
func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager {
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
tools: make(map[string]entitys.Tool),
|
tools: make(map[string]entitys.Tool),
|
||||||
|
llm: llm,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 注册天气工具
|
// 注册天气工具
|
||||||
if config.Tools.Weather.Enabled {
|
//if config.Tools.Weather.Enabled {
|
||||||
weatherTool := NewWeatherTool()
|
// weatherTool := NewWeatherTool()
|
||||||
m.tools[weatherTool.Name()] = weatherTool
|
// m.tools[weatherTool.Name()] = weatherTool
|
||||||
}
|
//}
|
||||||
|
//
|
||||||
// 注册计算器工具
|
//// 注册计算器工具
|
||||||
if config.Tools.Calculator.Enabled {
|
//if config.Tools.Calculator.Enabled {
|
||||||
calcTool := NewCalculatorTool()
|
// calcTool := NewCalculatorTool()
|
||||||
m.tools[calcTool.Name()] = calcTool
|
// m.tools[calcTool.Name()] = calcTool
|
||||||
}
|
//}
|
||||||
|
|
||||||
// 注册知识库工具
|
// 注册知识库工具
|
||||||
// if config.Knowledge.Enabled {
|
// if config.Knowledge.Enabled {
|
||||||
|
|
@ -41,16 +45,32 @@ func NewManager(config *config.Config) *Manager {
|
||||||
|
|
||||||
// 注册直连天下订单详情工具
|
// 注册直连天下订单详情工具
|
||||||
if config.Tools.ZltxOrderDetail.Enabled {
|
if config.Tools.ZltxOrderDetail.Enabled {
|
||||||
zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail)
|
zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm)
|
||||||
m.tools[zltxOrderDetailTool.Name()] = zltxOrderDetailTool
|
m.tools[zltxOrderDetailTool.Name()] = zltxOrderDetailTool
|
||||||
}
|
}
|
||||||
|
|
||||||
// 注册直连天下订单日志工具
|
//注册直连天下订单日志工具
|
||||||
// if config.ZltxOrderLog.Enabled {
|
if config.Tools.ZltxOrderDirectLog.Enabled {
|
||||||
// zltxOrderLogTool := NewZltxOrderLogTool(config.ZltxOrderLog)
|
zltxOrderLogTool := NewZltxOrderLogTool(config.Tools.ZltxOrderDirectLog)
|
||||||
// m.tools[zltxOrderLogTool.Name()] = zltxOrderLogTool
|
m.tools[zltxOrderLogTool.Name()] = zltxOrderLogTool
|
||||||
// }
|
}
|
||||||
|
|
||||||
|
//注册直连天下商品工具
|
||||||
|
if config.Tools.ZltxProduct.Enabled {
|
||||||
|
zltxProductTool := NewZltxProductTool(config.Tools.ZltxProduct)
|
||||||
|
m.tools[zltxProductTool.Name()] = zltxProductTool
|
||||||
|
}
|
||||||
|
//注册直连天下订单统计工具
|
||||||
|
if config.Tools.ZltxOrderStatistics.Enabled {
|
||||||
|
zltxOrderStatisticsTool := NewZltxOrderStatisticsTool(config.Tools.ZltxOrderStatistics)
|
||||||
|
m.tools[zltxOrderStatisticsTool.Name()] = zltxOrderStatisticsTool
|
||||||
|
}
|
||||||
|
|
||||||
|
// 注册知识库工具
|
||||||
|
if config.Tools.Knowledge.Enabled {
|
||||||
|
knowledgeTool := NewKnowledgeBaseTool(config.Tools.Knowledge)
|
||||||
|
m.tools[knowledgeTool.Name()] = knowledgeTool
|
||||||
|
}
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -80,43 +100,43 @@ func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefi
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteTool 执行工具
|
// ExecuteTool 执行工具
|
||||||
func (m *Manager) ExecuteTool(ctx context.Context, name string, args json.RawMessage) (interface{}, error) {
|
func (m *Manager) ExecuteTool(channel chan entitys.ResponseData, c *websocket.Conn, name string, args json.RawMessage) error {
|
||||||
tool, exists := m.GetTool(name)
|
tool, exists := m.GetTool(name)
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil, fmt.Errorf("tool not found: %s", name)
|
return fmt.Errorf("tool not found: %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tool.Execute(ctx, args)
|
return tool.Execute(channel, c, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteToolCalls 执行多个工具调用
|
// ExecuteToolCalls 执行多个工具调用
|
||||||
func (m *Manager) ExecuteToolCalls(ctx context.Context, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) {
|
//func (m *Manager) ExecuteToolCalls(c *websocket.Conn, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) {
|
||||||
results := make([]entitys.ToolCall, len(toolCalls))
|
// results := make([]entitys.ToolCall, len(toolCalls))
|
||||||
|
//
|
||||||
for i, toolCall := range toolCalls {
|
// for i, toolCall := range toolCalls {
|
||||||
results[i] = toolCall
|
// results[i] = toolCall
|
||||||
|
//
|
||||||
// 执行工具
|
// // 执行工具
|
||||||
result, err := m.ExecuteTool(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
|
// err := m.ExecuteTool(c, toolCall.Function.Name, toolCall.Function.Arguments)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
// 将错误信息作为结果返回
|
// // 将错误信息作为结果返回
|
||||||
errorResult := map[string]interface{}{
|
// errorResult := map[string]interface{}{
|
||||||
"error": err.Error(),
|
// "error": err.Error(),
|
||||||
}
|
// }
|
||||||
resultBytes, _ := json.Marshal(errorResult)
|
// resultBytes, _ := json.Marshal(errorResult)
|
||||||
results[i].Result = resultBytes
|
// results[i].Result = resultBytes
|
||||||
} else {
|
// } else {
|
||||||
// 将成功结果序列化
|
// // 将成功结果序列化
|
||||||
resultBytes, err := json.Marshal(result)
|
// resultBytes, err := json.Marshal(result)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
errorResult := map[string]interface{}{
|
// errorResult := map[string]interface{}{
|
||||||
"error": fmt.Sprintf("failed to serialize result: %v", err),
|
// "error": fmt.Sprintf("failed to serialize result: %v", err),
|
||||||
}
|
// }
|
||||||
resultBytes, _ = json.Marshal(errorResult)
|
// resultBytes, _ = json.Marshal(errorResult)
|
||||||
}
|
// }
|
||||||
results[i].Result = resultBytes
|
// results[i].Result = resultBytes
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
return results, nil
|
// return results, nil
|
||||||
}
|
//}
|
||||||
|
|
|
||||||
|
|
@ -3,21 +3,25 @@ package tools
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
|
"github.com/gofiber/websocket/v2"
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ZltxOrderDetailTool 直连天下订单详情工具
|
// ZltxOrderDetailTool 直连天下订单详情工具
|
||||||
type ZltxOrderDetailTool struct {
|
type ZltxOrderDetailTool struct {
|
||||||
config config.ToolConfig
|
config config.ToolConfig
|
||||||
|
llm *utils_ollama.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewZltxOrderDetailTool 创建直连天下订单详情工具
|
// NewZltxOrderDetailTool 创建直连天下订单详情工具
|
||||||
func NewZltxOrderDetailTool(config config.ToolConfig) *ZltxOrderDetailTool {
|
func NewZltxOrderDetailTool(config config.ToolConfig, llm *utils_ollama.Client) *ZltxOrderDetailTool {
|
||||||
return &ZltxOrderDetailTool{config: config}
|
return &ZltxOrderDetailTool{config: config, llm: llm}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Name 返回工具名称
|
// Name 返回工具名称
|
||||||
|
|
@ -53,7 +57,7 @@ func (w *ZltxOrderDetailTool) Definition() entitys.ToolDefinition {
|
||||||
|
|
||||||
// ZltxOrderDetailRequest 直连天下订单详情请求参数
|
// ZltxOrderDetailRequest 直连天下订单详情请求参数
|
||||||
type ZltxOrderDetailRequest struct {
|
type ZltxOrderDetailRequest struct {
|
||||||
Number string `json:"number"`
|
OrderNumber string `json:"order_number"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ZltxOrderDetailResponse 直连天下订单详情响应
|
// ZltxOrderDetailResponse 直连天下订单详情响应
|
||||||
|
|
@ -61,6 +65,13 @@ type ZltxOrderDetailResponse struct {
|
||||||
Code int `json:"code"`
|
Code int `json:"code"`
|
||||||
Error string `json:"error"`
|
Error string `json:"error"`
|
||||||
Data ZltxOrderDetailData `json:"data"`
|
Data ZltxOrderDetailData `json:"data"`
|
||||||
|
Mes string `json:"mes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ZltxOrderLogResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
Data any `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ZltxOrderDetailData 直连天下订单详情数据
|
// ZltxOrderDetailData 直连天下订单详情数据
|
||||||
|
|
@ -70,37 +81,103 @@ type ZltxOrderDetailData struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute 执行直连天下订单详情查询
|
// Execute 执行直连天下订单详情查询
|
||||||
func (w *ZltxOrderDetailTool) Execute(ctx context.Context, args json.RawMessage) (interface{}, error) {
|
func (w *ZltxOrderDetailTool) Execute(channel chan entitys.ResponseData, c *websocket.Conn, args json.RawMessage) error {
|
||||||
var req ZltxOrderDetailRequest
|
var req ZltxOrderDetailRequest
|
||||||
if err := json.Unmarshal(args, &req); err != nil {
|
if err := json.Unmarshal(args, &req); err != nil {
|
||||||
return nil, fmt.Errorf("invalid zltxOrderDetail request: %w", err)
|
return fmt.Errorf("invalid zltxOrderDetail request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Number == "" {
|
if req.OrderNumber == "" {
|
||||||
return nil, fmt.Errorf("number is required")
|
return fmt.Errorf("number is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 这里可以集成真实的直连天下订单详情API
|
// 这里可以集成真实的直连天下订单详情API
|
||||||
return w.getZltxOrderDetail(ctx, req.Number), nil
|
return w.getZltxOrderDetail(channel, c, req.OrderNumber)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据
|
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据
|
||||||
func (w *ZltxOrderDetailTool) getZltxOrderDetail(ctx context.Context, number string) *ZltxOrderDetailResponse {
|
func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.ResponseData, c *websocket.Conn, number string) (err error) {
|
||||||
url := fmt.Sprintf("%s/admin/direct/ai/%s", w.config.BaseURL, number)
|
//查询订单详情
|
||||||
authorization := fmt.Sprintf("Bearer %s", w.config.APIKey)
|
var auth string
|
||||||
|
if c != nil {
|
||||||
// 发送http请求
|
auth = c.Headers("X-Authorization", "")
|
||||||
req, err := http.NewRequest("GET", url, nil)
|
|
||||||
if err != nil {
|
|
||||||
return &ZltxOrderDetailResponse{}
|
|
||||||
}
|
}
|
||||||
req.Header.Set("Authorization", authorization)
|
if len(auth) == 0 {
|
||||||
|
auth = w.config.APIKey
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return &ZltxOrderDetailResponse{}
|
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
req := l_request.Request{
|
||||||
|
Url: fmt.Sprintf("%szltx_api/admin/direct/ai/%s", w.config.BaseURL, number),
|
||||||
|
Headers: map[string]string{
|
||||||
|
"Authorization": fmt.Sprintf("Bearer %s", auth),
|
||||||
|
},
|
||||||
|
Method: "GET",
|
||||||
|
}
|
||||||
|
res, err := req.Send()
|
||||||
|
|
||||||
return &ZltxOrderDetailResponse{}
|
if err != nil {
|
||||||
|
return fmt.Errorf("订单查询失败:%s", err.Error())
|
||||||
|
}
|
||||||
|
var codeMap map[string]interface{}
|
||||||
|
if err = json.Unmarshal(res.Content, &codeMap); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if codeMap["code"].(float64) != 200 {
|
||||||
|
return fmt.Errorf("订单查询失败:%s", res.Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resData ZltxOrderDetailResponse
|
||||||
|
if err = json.Unmarshal(res.Content, &resData); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ch <- entitys.ResponseData{
|
||||||
|
Done: false,
|
||||||
|
Content: res.Text,
|
||||||
|
Type: entitys.ResponseJson,
|
||||||
|
}
|
||||||
|
if resData.Data.Direct != nil && resData.Data.Direct["needAi"].(bool) {
|
||||||
|
ch <- entitys.ResponseData{
|
||||||
|
Done: false,
|
||||||
|
Content: "正在分析订单日志",
|
||||||
|
Type: entitys.ResponseLoading,
|
||||||
|
}
|
||||||
|
|
||||||
|
req = l_request.Request{
|
||||||
|
Url: fmt.Sprintf("%szltx_api/admin/direct/log/%s/%s", w.config.BaseURL, resData.Data.Direct["orderOrderNumber"].(string), resData.Data.Direct["serialNumber"].(string)),
|
||||||
|
Headers: map[string]string{
|
||||||
|
"Authorization": fmt.Sprintf("Bearer %s", auth),
|
||||||
|
},
|
||||||
|
Method: "GET",
|
||||||
|
}
|
||||||
|
res, err = req.Send()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var orderLog ZltxOrderLogResponse
|
||||||
|
if err = json.Unmarshal(res.Content, &orderLog); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if orderLog.Code != 200 {
|
||||||
|
return fmt.Errorf("订单日志查询失败:%s", orderLog.Error)
|
||||||
|
}
|
||||||
|
dataJson, err := json.Marshal(orderLog.Data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("订单日志解析失败:%s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = w.llm.ChatStream(context.TODO(), ch, []api.Message{
|
||||||
|
{
|
||||||
|
Role: "system",
|
||||||
|
Content: "你是一个订单日志助手。用户可能会提供订单日志,你需要分析订单日志,提取出订单失败的原因。",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: fmt.Sprintf("订单日志:%s", string(dataJson)),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("订单日志解析失败:%s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,126 @@
|
||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
|
"github.com/gofiber/websocket/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ZltxOrderLogTool struct {
|
||||||
|
config config.ToolConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ZltxOrderLogTool) Name() string {
|
||||||
|
return "zltxOrderDirectLog"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ZltxOrderLogTool) Description() string {
|
||||||
|
return "查询订单日志"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ZltxOrderLogTool) Definition() entitys.ToolDefinition {
|
||||||
|
return entitys.ToolDefinition{
|
||||||
|
Type: "function",
|
||||||
|
Function: entitys.FunctionDef{
|
||||||
|
Name: t.Name(),
|
||||||
|
Description: t.Description(),
|
||||||
|
Parameters: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"order_number": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "订单编号",
|
||||||
|
},
|
||||||
|
"serial_number": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "流水号",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"order_number", "serial_number"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZltxOrderDetailRequest 直连天下订单详情请求参数
|
||||||
|
type ZltxOrderLogRequest struct {
|
||||||
|
OrderNumber string `json:"order_number"`
|
||||||
|
SerialNumber string `json:"serial_number"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZltxOrderDetailResponse 直连天下订单详情响应
|
||||||
|
type ZltxOrderDirectLogResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
Data []ZltxOrderDirectLogData `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZltxOrderLogData 直连天下订单详情数据
|
||||||
|
type ZltxOrderDirectLogData struct {
|
||||||
|
Datetime string `json:"datetime"`
|
||||||
|
ServerID string `json:"serverId"`
|
||||||
|
Mes string `json:"mes"`
|
||||||
|
Data map[string]interface{} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ZltxOrderLogTool) Execute(channel chan entitys.ResponseData, c *websocket.Conn, args json.RawMessage) error {
|
||||||
|
var req ZltxOrderLogRequest
|
||||||
|
if err := json.Unmarshal(args, &req); err != nil {
|
||||||
|
return fmt.Errorf("invalid zltxOrderLog request: %w", err)
|
||||||
|
}
|
||||||
|
if req.OrderNumber == "" || req.SerialNumber == "" {
|
||||||
|
return fmt.Errorf("orderNumber and serialNumber is required")
|
||||||
|
}
|
||||||
|
return t.getZltxOrderLog(channel, c, req.OrderNumber, req.SerialNumber)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ZltxOrderLogTool) getZltxOrderLog(channel chan entitys.ResponseData, c *websocket.Conn, orderNumber, serialNumber string) (err error) {
|
||||||
|
//查询订单详情
|
||||||
|
var auth string
|
||||||
|
if c != nil {
|
||||||
|
auth = c.Headers("X-Authorization", "")
|
||||||
|
}
|
||||||
|
if len(auth) == 0 {
|
||||||
|
auth = t.config.APIKey
|
||||||
|
}
|
||||||
|
url := fmt.Sprintf("%s%s/%s", t.config.BaseURL, orderNumber, serialNumber)
|
||||||
|
req := l_request.Request{
|
||||||
|
Url: url,
|
||||||
|
Headers: map[string]string{
|
||||||
|
"Authorization": fmt.Sprintf("Bearer %s", auth),
|
||||||
|
},
|
||||||
|
Method: "GET",
|
||||||
|
}
|
||||||
|
res, err := req.Send()
|
||||||
|
var resData ZltxOrderDirectLogResponse
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if resData.Code != 200 {
|
||||||
|
return fmt.Errorf("订单查询失败:%s", resData.Error)
|
||||||
|
}
|
||||||
|
if err = json.Unmarshal(res.Content, &resData); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if c != nil {
|
||||||
|
_ = c.WriteMessage(websocket.TextMessage, res.Content)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
channel <- entitys.ResponseData{
|
||||||
|
Done: false,
|
||||||
|
Content: res.Text,
|
||||||
|
Type: entitys.ResponseJson,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewZltxOrderLogTool(config config.ToolConfig) *ZltxOrderLogTool {
|
||||||
|
return &ZltxOrderLogTool{
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,259 @@
|
||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
|
"github.com/gofiber/websocket/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ZltxProductTool struct {
|
||||||
|
config config.ToolConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z ZltxProductTool) Name() string {
|
||||||
|
return "zltxProduct"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z ZltxProductTool) Description() string {
|
||||||
|
return "获取直连天下商品信息"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z ZltxProductTool) Definition() entitys.ToolDefinition {
|
||||||
|
return entitys.ToolDefinition{
|
||||||
|
Type: "function",
|
||||||
|
Function: entitys.FunctionDef{
|
||||||
|
Name: z.Name(),
|
||||||
|
Description: z.Description(),
|
||||||
|
Parameters: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"id": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "商品ID",
|
||||||
|
},
|
||||||
|
"name": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "商品名称",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"id", "name"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ZltxProductRequest struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z ZltxProductTool) Execute(channel chan entitys.ResponseData, c *websocket.Conn, args json.RawMessage) error {
|
||||||
|
var req ZltxProductRequest
|
||||||
|
if err := json.Unmarshal(args, &req); err != nil {
|
||||||
|
return fmt.Errorf("invalid zltxProduct request: %w", err)
|
||||||
|
}
|
||||||
|
return z.getZltxProduct(channel, c, req.Id, req.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ZltxProductResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Data struct {
|
||||||
|
DataCount int `json:"dataCount"`
|
||||||
|
List []ZltxProductData `json:"list"`
|
||||||
|
} `json:"data"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ZltxProductDataById struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Data ZltxProductData `json:"data"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ZltxProductData struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
OursProductCategoryID int `json:"ours_product_category_id"`
|
||||||
|
OfficialProductID int `json:"official_product_id"`
|
||||||
|
Tag string `json:"tag"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type int `json:"type"`
|
||||||
|
Discount string `json:"discount"`
|
||||||
|
Preview string `json:"preview"`
|
||||||
|
Describe string `json:"describe"`
|
||||||
|
Price string `json:"price"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
CreateTime string `json:"create_time"`
|
||||||
|
UpdateTime string `json:"update_time"`
|
||||||
|
Extend string `json:"extend"`
|
||||||
|
Wight int `json:"wight"`
|
||||||
|
Property int `json:"property"`
|
||||||
|
AuthProductInfo []any `json:"auth_product_info"`
|
||||||
|
AuthProductIds string `json:"auth_product_ids"`
|
||||||
|
Category struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
CreateTime string `json:"create_time"`
|
||||||
|
Pid int `json:"pid"`
|
||||||
|
} `json:"category"`
|
||||||
|
OfficialProduct struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
OfficialID int `json:"official_id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Describe string `json:"describe"`
|
||||||
|
Preview string `json:"preview"`
|
||||||
|
Price float64 `json:"price"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
CreateTime string `json:"create_time"`
|
||||||
|
UpdateTime string `json:"update_time"`
|
||||||
|
Type int `json:"type"`
|
||||||
|
Daies int `json:"daies"`
|
||||||
|
PreviewURL string `json:"preview_url"`
|
||||||
|
Official struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Describe string `json:"describe"`
|
||||||
|
Num int `json:"num"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
WebURL string `json:"web_url"`
|
||||||
|
RechargeURL string `json:"recharge_url"`
|
||||||
|
CreateTime string `json:"create_time"`
|
||||||
|
UpdateTime string `json:"update_time"`
|
||||||
|
Type int `json:"type"`
|
||||||
|
Tag string `json:"tag"`
|
||||||
|
} `json:"official"`
|
||||||
|
} `json:"official_product"`
|
||||||
|
Statistics interface{} `json:"statistics"`
|
||||||
|
PlatformProductList interface{} `json:"platform_product_list"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z ZltxProductTool) getZltxProduct(channel chan entitys.ResponseData, c *websocket.Conn, id string, name string) error {
|
||||||
|
var auth string
|
||||||
|
if c != nil {
|
||||||
|
auth = c.Headers("X-Authorization", "")
|
||||||
|
}
|
||||||
|
if len(auth) == 0 {
|
||||||
|
auth = z.config.APIKey
|
||||||
|
}
|
||||||
|
var Url string
|
||||||
|
var params map[string]string
|
||||||
|
if id != "" {
|
||||||
|
Url = fmt.Sprintf("%s/%s", z.config.BaseURL, id)
|
||||||
|
} else {
|
||||||
|
Url = fmt.Sprintf("%s?keyword=%s&limit=10&page=1", z.config.BaseURL, name)
|
||||||
|
params = map[string]string{
|
||||||
|
"keyword": name,
|
||||||
|
"limit": "10",
|
||||||
|
"page": "1",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
req := l_request.Request{
|
||||||
|
//get /admin/oursProduct/{product_id} 通过商品id获取我们的商品信息
|
||||||
|
//get /admin/oursProduct?keyword={name}&limit=10&page=1 通过商品name获取我们的商品列表
|
||||||
|
//根据商品ID或名称走不同的接口查询
|
||||||
|
Url: Url,
|
||||||
|
Headers: map[string]string{
|
||||||
|
"Authorization": fmt.Sprintf("Bearer %s", auth),
|
||||||
|
},
|
||||||
|
Params: params,
|
||||||
|
Method: "GET",
|
||||||
|
}
|
||||||
|
res, err := req.Send()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var resp ZltxProductResponse
|
||||||
|
if err := json.Unmarshal(res.Content, &resp); err != nil {
|
||||||
|
return fmt.Errorf("解析商品数据失败:%w", err)
|
||||||
|
}
|
||||||
|
if resp.Code != 200 {
|
||||||
|
return fmt.Errorf("商品查询失败:%s", resp.Error)
|
||||||
|
}
|
||||||
|
if resp.Data.List == nil || len(resp.Data.List) == 0 {
|
||||||
|
var respData ZltxProductDataById
|
||||||
|
if err := json.Unmarshal(res.Content, &respData); err != nil {
|
||||||
|
return fmt.Errorf("解析商品数据失败:%w", err)
|
||||||
|
}
|
||||||
|
resp.Data.List = []ZltxProductData{respData.Data}
|
||||||
|
resp.Data.DataCount = 1
|
||||||
|
}
|
||||||
|
//调用 平台商品列表
|
||||||
|
if resp.Data.List != nil && len(resp.Data.List) > 0 {
|
||||||
|
for i := range resp.Data.List {
|
||||||
|
// 调用 平台商品列表
|
||||||
|
if resp.Data.List[i].AuthProductIds != "" {
|
||||||
|
platformProductList := z.ExecutePlatformProductList(auth, resp.Data.List[i].AuthProductIds)
|
||||||
|
resp.Data.List[i].PlatformProductList = platformProductList
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
marshal, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
channel <- entitys.ResponseData{
|
||||||
|
Done: false,
|
||||||
|
Content: string(marshal),
|
||||||
|
Type: entitys.ResponseJson,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z ZltxProductTool) ExecutePlatformProductList(auth string, authProductIds string) []map[string]any {
|
||||||
|
if authProductIds == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//authProductIds 以逗号分割
|
||||||
|
authProductIdList := strings.Split(authProductIds, ",")
|
||||||
|
var result []map[string]any
|
||||||
|
|
||||||
|
for _, authProductId := range authProductIdList {
|
||||||
|
var platformProductResponse map[string]any
|
||||||
|
req := l_request.Request{
|
||||||
|
//https://gateway.dev.cdlsxd.cn/zltx_api/admin/platformProduct/{product_id}?id={product_id}
|
||||||
|
Url: fmt.Sprintf("%s/%s?id=%s", z.config.AddURL, authProductId, authProductId),
|
||||||
|
Headers: map[string]string{
|
||||||
|
"Authorization": fmt.Sprintf("Bearer %s", auth),
|
||||||
|
},
|
||||||
|
Method: "GET",
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := req.Send()
|
||||||
|
if err != nil {
|
||||||
|
// 可以考虑记录日志而不是直接跳过
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(res.Content, &platformProductResponse); err != nil {
|
||||||
|
// 可以考虑记录日志而不是直接跳过
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 只提取 data 部分
|
||||||
|
if data, ok := platformProductResponse["data"]; ok {
|
||||||
|
if dataMap, ok := data.(map[string]any); ok {
|
||||||
|
result = append(result, dataMap)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
type ZltxProductPlatformProductList struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Data map[string]any `json:"data"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewZltxProductTool(config config.ToolConfig) *ZltxProductTool {
|
||||||
|
return &ZltxProductTool{
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,119 @@
|
||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
|
"github.com/gofiber/websocket/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ZltxOrderStatisticsTool struct {
|
||||||
|
config config.ToolConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z ZltxOrderStatisticsTool) Name() string {
|
||||||
|
return "zltxOrderStatistics"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z ZltxOrderStatisticsTool) Description() string {
|
||||||
|
return "通过账号获取订单统计信息"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z ZltxOrderStatisticsTool) Definition() entitys.ToolDefinition {
|
||||||
|
return entitys.ToolDefinition{
|
||||||
|
Type: "function",
|
||||||
|
Function: entitys.FunctionDef{
|
||||||
|
Name: z.Name(),
|
||||||
|
Description: z.Description(),
|
||||||
|
Parameters: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"number": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "账号或分销商号",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"number"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ZltxOrderStatisticsRequest struct {
|
||||||
|
Number string `json:"number"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z ZltxOrderStatisticsTool) Execute(channel chan entitys.ResponseData, c *websocket.Conn, args json.RawMessage) error {
|
||||||
|
var req ZltxOrderStatisticsRequest
|
||||||
|
if err := json.Unmarshal(args, &req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if req.Number == "" {
|
||||||
|
return fmt.Errorf("number is required")
|
||||||
|
}
|
||||||
|
return z.getZltxOrderStatistics(channel, c, req.Number)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ZltxOrderStatisticsResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Data []ZltxOrderStatisticsData `json:"data"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ZltxOrderStatisticsData struct {
|
||||||
|
Date string `json:"date"`
|
||||||
|
Number string `json:"number"`
|
||||||
|
Success int `json:"success"`
|
||||||
|
Fail int `json:"fail"`
|
||||||
|
Total int `json:"total"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(channel chan entitys.ResponseData, c *websocket.Conn, number string) error {
|
||||||
|
//查询订单详情
|
||||||
|
var auth string
|
||||||
|
if c != nil {
|
||||||
|
auth = c.Headers("X-Authorization", "")
|
||||||
|
}
|
||||||
|
if len(auth) == 0 {
|
||||||
|
auth = z.config.APIKey
|
||||||
|
}
|
||||||
|
url := fmt.Sprintf("%s%s", z.config.BaseURL, number)
|
||||||
|
req := l_request.Request{
|
||||||
|
Url: url,
|
||||||
|
Headers: map[string]string{
|
||||||
|
"Authorization": fmt.Sprintf("Bearer %s", auth),
|
||||||
|
},
|
||||||
|
Method: "GET",
|
||||||
|
}
|
||||||
|
res, err := req.Send()
|
||||||
|
var resData ZltxOrderStatisticsResponse
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(res.Content, &resData); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if resData.Code != 200 {
|
||||||
|
return fmt.Errorf("zltx order statistics error: %s", resData.Error)
|
||||||
|
}
|
||||||
|
if c != nil {
|
||||||
|
_ = c.WriteMessage(websocket.TextMessage, res.Content)
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
channel <- entitys.ResponseData{
|
||||||
|
Done: false,
|
||||||
|
Content: res.Text,
|
||||||
|
Type: entitys.ResponseJson,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewZltxOrderStatisticsTool(config config.ToolConfig) *ZltxOrderStatisticsTool {
|
||||||
|
return &ZltxOrderStatisticsTool{
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -5,6 +5,8 @@ import (
|
||||||
"ai_scheduler/utils"
|
"ai_scheduler/utils"
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
"github.com/go-kratos/kratos/v2/log"
|
"github.com/go-kratos/kratos/v2/log"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
@ -103,6 +105,65 @@ func (k DataTemp) GetOneBySearchToStrut(cond *builder.Cond, result interface{})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (k DataTemp) GetListToStruct(cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}, orderBy string) (pageBoOut *RespPageBo, err error) {
|
||||||
|
var (
|
||||||
|
query, _ = builder.ToBoundSQL(*cond)
|
||||||
|
model = k.Db.Model(k.Model).Where(query)
|
||||||
|
total int64
|
||||||
|
)
|
||||||
|
if len(orderBy) == 0 {
|
||||||
|
orderBy = "updated_at desc"
|
||||||
|
}
|
||||||
|
// 1. 计算总数
|
||||||
|
if err = model.Count(&total).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 设置分页
|
||||||
|
pageBoOut = pageBoOut.SetDataByReq(total, pageBoIn)
|
||||||
|
|
||||||
|
// 3. 查询数据(确保 result 是指针,如 &[]User)
|
||||||
|
if err = model.Limit(pageBoIn.GetSize()).
|
||||||
|
Offset(pageBoIn.GetOffset()).
|
||||||
|
Order(orderBy).
|
||||||
|
Find(result).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 可选:用 reflect 处理结果(确保类型正确)
|
||||||
|
val := reflect.ValueOf(result)
|
||||||
|
if val.Kind() != reflect.Ptr {
|
||||||
|
return nil, fmt.Errorf("result must be a pointer")
|
||||||
|
}
|
||||||
|
|
||||||
|
val = val.Elem() // 解引用
|
||||||
|
if val.Kind() == reflect.Slice {
|
||||||
|
for i := 0; i < val.Len(); i++ {
|
||||||
|
elem := val.Index(i)
|
||||||
|
if elem.Kind() == reflect.Struct {
|
||||||
|
// 示例:打印某个字段
|
||||||
|
if field := elem.FieldByName("ID"); field.IsValid() {
|
||||||
|
fmt.Println("ID:", field.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pageBoOut, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k DataTemp) GetListToStruct2(cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}) (pageBoOut *RespPageBo, err error) {
|
||||||
|
var (
|
||||||
|
query, _ = builder.ToBoundSQL(*cond)
|
||||||
|
model = k.Db.Model(k.Model).Where(query)
|
||||||
|
total int64
|
||||||
|
)
|
||||||
|
model.Count(&total)
|
||||||
|
pageBoOut = pageBoOut.SetDataByReq(total, pageBoIn)
|
||||||
|
model.Limit(pageBoIn.GetSize()).Offset(pageBoIn.GetOffset()).Order("updated_at desc").Find(&result)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (k DataTemp) UpdateByCond(cond *builder.Cond, data interface{}) (err error) {
|
func (k DataTemp) UpdateByCond(cond *builder.Cond, data interface{}) (err error) {
|
||||||
var (
|
var (
|
||||||
query, _ = builder.ToBoundSQL(*cond)
|
query, _ = builder.ToBoundSQL(*cond)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue