feat: 优化任务模型与权限控制
This commit is contained in:
parent
ec5623bc07
commit
4339b6eee8
|
|
@ -3,11 +3,12 @@ server:
|
||||||
port: 8090
|
port: 8090
|
||||||
host: "0.0.0.0"
|
host: "0.0.0.0"
|
||||||
|
|
||||||
|
|
||||||
ollama:
|
ollama:
|
||||||
base_url: "http://127.0.0.1:11434"
|
base_url: "http://127.0.0.1:11434"
|
||||||
model: "qwen3-coder:480b-cloud"
|
model: "qwen3-coder:480b-cloud"
|
||||||
generate_model: "qwen3-coder:480b-cloud"
|
generate_model: "qwen3-coder:480b-cloud"
|
||||||
vl_model: "qwen2.5vl:7b"
|
vl_model: "gemini-3-pro-preview"
|
||||||
timeout: "120s"
|
timeout: "120s"
|
||||||
level: "info"
|
level: "info"
|
||||||
format: "json"
|
format: "json"
|
||||||
|
|
@ -19,6 +20,7 @@ sys:
|
||||||
channel_pool_len: 100
|
channel_pool_len: 100
|
||||||
channel_pool_size: 32
|
channel_pool_size: 32
|
||||||
llm_pool_len: 5
|
llm_pool_len: 5
|
||||||
|
heartbeat_interval: 30
|
||||||
redis:
|
redis:
|
||||||
host: 47.97.27.195:6379
|
host: 47.97.27.195:6379
|
||||||
type: node
|
type: node
|
||||||
|
|
|
||||||
|
|
@ -139,6 +139,7 @@ func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entit
|
||||||
Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||||
Images: requireData.ImgByte,
|
Images: requireData.ImgByte,
|
||||||
KeepAlive: &api.Duration{Duration: 3600 * time.Second},
|
KeepAlive: &api.Duration{Duration: 3600 * time.Second},
|
||||||
|
//Think: &api.ThinkValue{Value: false},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -36,10 +36,11 @@ type LLM struct {
|
||||||
|
|
||||||
// SysConfig 系统配置
|
// SysConfig 系统配置
|
||||||
type SysConfig struct {
|
type SysConfig struct {
|
||||||
SessionLen int `mapstructure:"session_len"`
|
SessionLen int `mapstructure:"session_len"`
|
||||||
ChannelPoolLen int `mapstructure:"channel_pool_len"`
|
ChannelPoolLen int `mapstructure:"channel_pool_len"`
|
||||||
ChannelPoolSize int `mapstructure:"channel_pool_size"`
|
ChannelPoolSize int `mapstructure:"channel_pool_size"`
|
||||||
LlmPoolLen int `mapstructure:"llm_pool_len"`
|
LlmPoolLen int `mapstructure:"llm_pool_len"`
|
||||||
|
HeartbeatInterval int `mapstructure:"heartbeat_interval"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerConfig 服务器配置
|
// ServerConfig 服务器配置
|
||||||
|
|
|
||||||
|
|
@ -3,33 +3,44 @@ package gateway
|
||||||
import (
|
import (
|
||||||
errors "ai_scheduler/internal/data/error"
|
errors "ai_scheduler/internal/data/error"
|
||||||
"ai_scheduler/internal/data/model"
|
"ai_scheduler/internal/data/model"
|
||||||
"encoding/hex"
|
"ai_scheduler/internal/pkg"
|
||||||
"fmt"
|
"context"
|
||||||
"github.com/gofiber/websocket/v2"
|
"encoding/binary"
|
||||||
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gofiber/websocket/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrConnClosed = errors.SysErr("连接不存在或已关闭")
|
ErrConnClosed = errors.SysErr("连接不存在或已关闭")
|
||||||
|
rng = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
idBuf = make([]byte, 20)
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
id string // 客户端唯一ID
|
id string // 客户端唯一ID
|
||||||
conn *websocket.Conn // WebSocket 连接
|
conn *websocket.Conn // WebSocket 连接
|
||||||
session string // 会话ID
|
session string // 会话ID
|
||||||
key string // 应用密钥
|
key string // 应用密钥
|
||||||
auth string // 用户凭证token
|
auth string // 用户凭证token
|
||||||
codes []string // 用户权限code
|
codes []string // 用户权限code
|
||||||
sysInfo *model.AiSy // 系统信息
|
sysInfo *model.AiSy // 系统信息
|
||||||
tasks []model.AiTask // 任务列表
|
tasks []model.AiTask // 任务列表
|
||||||
sysCode string // 系统编码
|
sysCode string // 系统编码
|
||||||
|
Ctx context.Context
|
||||||
|
Cancel context.CancelFunc
|
||||||
|
LastActive time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(conn *websocket.Conn) *Client {
|
func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelFunc) *Client {
|
||||||
|
|
||||||
return &Client{
|
return &Client{
|
||||||
id: generateClientID(),
|
id: generateClientID(),
|
||||||
conn: conn,
|
conn: conn,
|
||||||
|
Ctx: ctx,
|
||||||
|
Cancel: cancel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -103,12 +114,16 @@ func (c *Client) SendFunc(msg []byte) error {
|
||||||
|
|
||||||
// 生成唯一的客户端ID
|
// 生成唯一的客户端ID
|
||||||
func generateClientID() string {
|
func generateClientID() string {
|
||||||
// 使用时间戳+随机数确保唯一性
|
// 1. 时间戳
|
||||||
timestamp := time.Now().UnixNano()
|
timestamp := time.Now().UnixNano()
|
||||||
randomBytes := make([]byte, 4)
|
binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp))
|
||||||
rand.Read(randomBytes)
|
|
||||||
randomStr := hex.EncodeToString(randomBytes)
|
// 2. 随机数(4字节)
|
||||||
return fmt.Sprintf("%d%s", timestamp, randomStr)
|
binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32())
|
||||||
|
|
||||||
|
// 3. 十六进制编码
|
||||||
|
n := pkg.HexEncode(idBuf[:12], idBuf[12:])
|
||||||
|
return string(idBuf[12 : 12+n])
|
||||||
}
|
}
|
||||||
|
|
||||||
// 连接数据验证和收集
|
// 连接数据验证和收集
|
||||||
|
|
@ -136,3 +151,22 @@ func (c *Client) DataAuth() (err error) {
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) InitHeartbeat(timeoutSecond time.Duration) {
|
||||||
|
ticker := time.NewTicker(timeoutSecond * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
//*2是防止丢包,连续丢包两次,再加5s网络延迟容错
|
||||||
|
if time.Since(c.LastActive) > (timeoutSecond*2*time.Second + 5) { // 5秒容错
|
||||||
|
log.Println("Heartbeat timeout", "id", c.id)
|
||||||
|
c.conn.WriteMessage(websocket.CloseMessage, []byte("Heartbeat timeout"))
|
||||||
|
c.conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-c.Ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,9 @@ package gateway
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"log"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Gateway struct {
|
type Gateway struct {
|
||||||
|
|
@ -20,14 +22,27 @@ func NewGateway() *Gateway {
|
||||||
|
|
||||||
func (g *Gateway) AddClient(c *Client) {
|
func (g *Gateway) AddClient(c *Client) {
|
||||||
g.mu.Lock()
|
g.mu.Lock()
|
||||||
defer g.mu.Unlock()
|
defer func() {
|
||||||
|
g.mu.Unlock()
|
||||||
|
//心跳开始计时
|
||||||
|
c.LastActive = time.Now()
|
||||||
|
log.Println("client connected:", c.GetID())
|
||||||
|
log.Println("客户端已连接")
|
||||||
|
}()
|
||||||
g.clients[c.GetID()] = c
|
g.clients[c.GetID()] = c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Gateway) RemoveClient(clientID string) {
|
func (g *Gateway) Cleanup(clientID string) {
|
||||||
g.mu.Lock()
|
g.mu.Lock()
|
||||||
defer g.mu.Unlock()
|
defer func() {
|
||||||
delete(g.clients, clientID)
|
if c, ex := g.clients[clientID]; ex {
|
||||||
|
delete(g.clients, clientID)
|
||||||
|
_ = c.conn.Close()
|
||||||
|
c.Cancel()
|
||||||
|
}
|
||||||
|
g.mu.Unlock()
|
||||||
|
log.Println("client disconnected:", clientID)
|
||||||
|
}()
|
||||||
for uid, list := range g.uidMap {
|
for uid, list := range g.uidMap {
|
||||||
newList := []string{}
|
newList := []string{}
|
||||||
for _, cid := range list {
|
for _, cid := range list {
|
||||||
|
|
@ -37,6 +52,7 @@ func (g *Gateway) RemoveClient(clientID string) {
|
||||||
}
|
}
|
||||||
g.uidMap[uid] = newList
|
g.uidMap[uid] = newList
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Gateway) SendToAll(msg []byte) {
|
func (g *Gateway) SendToAll(msg []byte) {
|
||||||
|
|
@ -63,6 +79,7 @@ func (g *Gateway) BindUid(clientID, uid string) error {
|
||||||
return errors.New("client not found")
|
return errors.New("client not found")
|
||||||
}
|
}
|
||||||
g.uidMap[uid] = append(g.uidMap[uid], clientID)
|
g.uidMap[uid] = append(g.uidMap[uid], clientID)
|
||||||
|
log.Printf("bind %s -> uid:%s\n", clientID, uid)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -55,3 +55,13 @@ func ValidateImageURL(rawURL string) error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// hexEncode 将 src 的二进制数据编码为十六进制字符串,写入 dst,返回写入长度
|
||||||
|
func HexEncode(src, dst []byte) int {
|
||||||
|
const hextable = "0123456789abcdef"
|
||||||
|
for i := 0; i < len(src); i++ {
|
||||||
|
dst[i*2] = hextable[src[i]>>4]
|
||||||
|
dst[i*2+1] = hextable[src[i]&0xf]
|
||||||
|
}
|
||||||
|
return len(src) * 2
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -82,10 +82,7 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
|
||||||
func routerSocket(app *fiber.App, chatService *services.ChatService) {
|
func routerSocket(app *fiber.App, chatService *services.ChatService) {
|
||||||
ws := app.Group("ws/v1/")
|
ws := app.Group("ws/v1/")
|
||||||
// WebSocket 路由配置
|
// WebSocket 路由配置
|
||||||
ws.Get("/chat", websocket.New(func(c *websocket.Conn) {
|
ws.Get("/chat", websocket.New(chatService.Chat, websocket.Config{
|
||||||
// 可以在这里添加握手前的中间件逻辑(如头校验)
|
|
||||||
chatService.Chat(c) // 调用实际的 Chat 处理函数
|
|
||||||
}, websocket.Config{
|
|
||||||
// 可选配置:跨域检查、最大负载大小等
|
// 可选配置:跨域检查、最大负载大小等
|
||||||
HandshakeTimeout: 10 * time.Second,
|
HandshakeTimeout: 10 * time.Second,
|
||||||
//Subprotocols: []string{"json", "msgpack"},
|
//Subprotocols: []string{"json", "msgpack"},
|
||||||
|
|
|
||||||
|
|
@ -6,9 +6,11 @@ import (
|
||||||
"ai_scheduler/internal/data/constants"
|
"ai_scheduler/internal/data/constants"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/gateway"
|
"ai_scheduler/internal/gateway"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log"
|
"log"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
|
|
@ -38,20 +40,6 @@ func NewChatService(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToolCallResponse 工具调用响应
|
|
||||||
type ToolCallResponse struct {
|
|
||||||
ID string `json:"id" example:"call_1"`
|
|
||||||
Type string `json:"type" example:"function"`
|
|
||||||
Function FunctionCallResponse `json:"function"`
|
|
||||||
Result interface{} `json:"result,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// FunctionCallResponse 函数调用响应
|
|
||||||
type FunctionCallResponse struct {
|
|
||||||
Name string `json:"name" example:"get_weather"`
|
|
||||||
Arguments interface{} `json:"arguments"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
|
func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
|
||||||
err := c.WriteMessage(websocket.TextMessage, []byte(content))
|
err := c.WriteMessage(websocket.TextMessage, []byte(content))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -63,15 +51,18 @@ func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
|
||||||
// Chat 处理WebSocket聊天连接
|
// Chat 处理WebSocket聊天连接
|
||||||
// 这是WebSocket处理的主入口函数
|
// 这是WebSocket处理的主入口函数
|
||||||
func (h *ChatService) Chat(c *websocket.Conn) {
|
func (h *ChatService) Chat(c *websocket.Conn) {
|
||||||
// 创建新的客户端实例
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
h.mu.Lock()
|
|
||||||
client := gateway.NewClient(c)
|
|
||||||
h.mu.Unlock()
|
|
||||||
|
|
||||||
|
// 创建新的客户端实例
|
||||||
|
client := gateway.NewClient(c, ctx, cancel)
|
||||||
|
// 心跳检测
|
||||||
|
go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval))
|
||||||
// 将客户端添加到网关管理
|
// 将客户端添加到网关管理
|
||||||
h.Gw.AddClient(client)
|
h.Gw.AddClient(client)
|
||||||
log.Println("client connected:", client.GetID())
|
// 确保在函数返回时移除客户端并关闭连接
|
||||||
log.Println("客户端已连接")
|
defer func() {
|
||||||
|
h.Gw.Cleanup(client.GetID())
|
||||||
|
}()
|
||||||
|
|
||||||
// 绑定会话ID
|
// 绑定会话ID
|
||||||
uid := c.Query("x-session")
|
uid := c.Query("x-session")
|
||||||
|
|
@ -79,7 +70,6 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
||||||
if err := h.Gw.BindUid(client.GetID(), uid); err != nil {
|
if err := h.Gw.BindUid(client.GetID(), uid); err != nil {
|
||||||
log.Println("绑定UID错误:", err)
|
log.Println("绑定UID错误:", err)
|
||||||
}
|
}
|
||||||
log.Printf("bind %s -> uid:%s\n", client.GetID(), uid)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证并收集连接数据,后续对话中会使用
|
// 验证并收集连接数据,后续对话中会使用
|
||||||
|
|
@ -89,13 +79,6 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 确保在函数返回时移除客户端并关闭连接
|
|
||||||
defer func() {
|
|
||||||
h.Gw.RemoveClient(client.GetID())
|
|
||||||
_ = c.Close()
|
|
||||||
log.Println("client disconnected:", client.GetID())
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 循环读取客户端消息
|
// 循环读取客户端消息
|
||||||
for {
|
for {
|
||||||
// 读取消息
|
// 读取消息
|
||||||
|
|
@ -104,7 +87,14 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
||||||
log.Println("读取错误:", err)
|
log.Println("读取错误:", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
//if string(message) == `{"type":"ping"}` {
|
||||||
|
// client.LastActive = time.Now()
|
||||||
|
// if err := c.WriteMessage(websocket.TextMessage, []byte(`{"type":"pong"}`)); err != nil {
|
||||||
|
// log.Println("Heartbeat response failed", "id", client.GetID(), "err", err)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// continue
|
||||||
|
//}
|
||||||
// 处理消息
|
// 处理消息
|
||||||
msg, chatType := h.handleMessageToString(c, messageType, message)
|
msg, chatType := h.handleMessageToString(c, messageType, message)
|
||||||
if chatType == constants.ConnStatusClosed {
|
if chatType == constants.ConnStatusClosed {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue