feat: 优化任务模型与权限控制

This commit is contained in:
renzhiyuan 2025-12-04 10:14:57 +08:00
parent ec5623bc07
commit 4339b6eee8
8 changed files with 114 additions and 62 deletions

View File

@ -3,11 +3,12 @@ server:
port: 8090 port: 8090
host: "0.0.0.0" host: "0.0.0.0"
ollama: ollama:
base_url: "http://127.0.0.1:11434" base_url: "http://127.0.0.1:11434"
model: "qwen3-coder:480b-cloud" model: "qwen3-coder:480b-cloud"
generate_model: "qwen3-coder:480b-cloud" generate_model: "qwen3-coder:480b-cloud"
vl_model: "qwen2.5vl:7b" vl_model: "gemini-3-pro-preview"
timeout: "120s" timeout: "120s"
level: "info" level: "info"
format: "json" format: "json"
@ -19,6 +20,7 @@ sys:
channel_pool_len: 100 channel_pool_len: 100
channel_pool_size: 32 channel_pool_size: 32
llm_pool_len: 5 llm_pool_len: 5
heartbeat_interval: 30
redis: redis:
host: 47.97.27.195:6379 host: 47.97.27.195:6379
type: node type: node

View File

@ -139,6 +139,7 @@ func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entit
Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt, Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
Images: requireData.ImgByte, Images: requireData.ImgByte,
KeepAlive: &api.Duration{Duration: 3600 * time.Second}, KeepAlive: &api.Duration{Duration: 3600 * time.Second},
//Think: &api.ThinkValue{Value: false},
}) })
if err != nil { if err != nil {
return return

View File

@ -36,10 +36,11 @@ type LLM struct {
// SysConfig 系统配置 // SysConfig 系统配置
type SysConfig struct { type SysConfig struct {
SessionLen int `mapstructure:"session_len"` SessionLen int `mapstructure:"session_len"`
ChannelPoolLen int `mapstructure:"channel_pool_len"` ChannelPoolLen int `mapstructure:"channel_pool_len"`
ChannelPoolSize int `mapstructure:"channel_pool_size"` ChannelPoolSize int `mapstructure:"channel_pool_size"`
LlmPoolLen int `mapstructure:"llm_pool_len"` LlmPoolLen int `mapstructure:"llm_pool_len"`
HeartbeatInterval int `mapstructure:"heartbeat_interval"`
} }
// ServerConfig 服务器配置 // ServerConfig 服务器配置

View File

@ -3,33 +3,44 @@ package gateway
import ( import (
errors "ai_scheduler/internal/data/error" errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/model" "ai_scheduler/internal/data/model"
"encoding/hex" "ai_scheduler/internal/pkg"
"fmt" "context"
"github.com/gofiber/websocket/v2" "encoding/binary"
"log"
"math/rand" "math/rand"
"time" "time"
"github.com/gofiber/websocket/v2"
) )
var ( var (
ErrConnClosed = errors.SysErr("连接不存在或已关闭") ErrConnClosed = errors.SysErr("连接不存在或已关闭")
rng = rand.New(rand.NewSource(time.Now().UnixNano()))
idBuf = make([]byte, 20)
) )
type Client struct { type Client struct {
id string // 客户端唯一ID id string // 客户端唯一ID
conn *websocket.Conn // WebSocket 连接 conn *websocket.Conn // WebSocket 连接
session string // 会话ID session string // 会话ID
key string // 应用密钥 key string // 应用密钥
auth string // 用户凭证token auth string // 用户凭证token
codes []string // 用户权限code codes []string // 用户权限code
sysInfo *model.AiSy // 系统信息 sysInfo *model.AiSy // 系统信息
tasks []model.AiTask // 任务列表 tasks []model.AiTask // 任务列表
sysCode string // 系统编码 sysCode string // 系统编码
Ctx context.Context
Cancel context.CancelFunc
LastActive time.Time
} }
func NewClient(conn *websocket.Conn) *Client { func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelFunc) *Client {
return &Client{ return &Client{
id: generateClientID(), id: generateClientID(),
conn: conn, conn: conn,
Ctx: ctx,
Cancel: cancel,
} }
} }
@ -103,12 +114,16 @@ func (c *Client) SendFunc(msg []byte) error {
// 生成唯一的客户端ID // 生成唯一的客户端ID
func generateClientID() string { func generateClientID() string {
// 使用时间戳+随机数确保唯一性 // 1. 时间戳
timestamp := time.Now().UnixNano() timestamp := time.Now().UnixNano()
randomBytes := make([]byte, 4) binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp))
rand.Read(randomBytes)
randomStr := hex.EncodeToString(randomBytes) // 2. 随机数4字节
return fmt.Sprintf("%d%s", timestamp, randomStr) binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32())
// 3. 十六进制编码
n := pkg.HexEncode(idBuf[:12], idBuf[12:])
return string(idBuf[12 : 12+n])
} }
// 连接数据验证和收集 // 连接数据验证和收集
@ -136,3 +151,22 @@ func (c *Client) DataAuth() (err error) {
} }
return return
} }
func (c *Client) InitHeartbeat(timeoutSecond time.Duration) {
ticker := time.NewTicker(timeoutSecond * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
//*2是防止丢包连续丢包两次再加5s网络延迟容错
if time.Since(c.LastActive) > (timeoutSecond*2*time.Second + 5) { // 5秒容错
log.Println("Heartbeat timeout", "id", c.id)
c.conn.WriteMessage(websocket.CloseMessage, []byte("Heartbeat timeout"))
c.conn.Close()
return
}
case <-c.Ctx.Done():
return
}
}
}

View File

@ -2,7 +2,9 @@ package gateway
import ( import (
"errors" "errors"
"log"
"sync" "sync"
"time"
) )
type Gateway struct { type Gateway struct {
@ -20,14 +22,27 @@ func NewGateway() *Gateway {
func (g *Gateway) AddClient(c *Client) { func (g *Gateway) AddClient(c *Client) {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer func() {
g.mu.Unlock()
//心跳开始计时
c.LastActive = time.Now()
log.Println("client connected:", c.GetID())
log.Println("客户端已连接")
}()
g.clients[c.GetID()] = c g.clients[c.GetID()] = c
} }
func (g *Gateway) RemoveClient(clientID string) { func (g *Gateway) Cleanup(clientID string) {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer func() {
delete(g.clients, clientID) if c, ex := g.clients[clientID]; ex {
delete(g.clients, clientID)
_ = c.conn.Close()
c.Cancel()
}
g.mu.Unlock()
log.Println("client disconnected:", clientID)
}()
for uid, list := range g.uidMap { for uid, list := range g.uidMap {
newList := []string{} newList := []string{}
for _, cid := range list { for _, cid := range list {
@ -37,6 +52,7 @@ func (g *Gateway) RemoveClient(clientID string) {
} }
g.uidMap[uid] = newList g.uidMap[uid] = newList
} }
} }
func (g *Gateway) SendToAll(msg []byte) { func (g *Gateway) SendToAll(msg []byte) {
@ -63,6 +79,7 @@ func (g *Gateway) BindUid(clientID, uid string) error {
return errors.New("client not found") return errors.New("client not found")
} }
g.uidMap[uid] = append(g.uidMap[uid], clientID) g.uidMap[uid] = append(g.uidMap[uid], clientID)
log.Printf("bind %s -> uid:%s\n", clientID, uid)
return nil return nil
} }

View File

@ -55,3 +55,13 @@ func ValidateImageURL(rawURL string) error {
return nil return nil
} }
// hexEncode 将 src 的二进制数据编码为十六进制字符串,写入 dst返回写入长度
func HexEncode(src, dst []byte) int {
const hextable = "0123456789abcdef"
for i := 0; i < len(src); i++ {
dst[i*2] = hextable[src[i]>>4]
dst[i*2+1] = hextable[src[i]&0xf]
}
return len(src) * 2
}

View File

@ -82,10 +82,7 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
func routerSocket(app *fiber.App, chatService *services.ChatService) { func routerSocket(app *fiber.App, chatService *services.ChatService) {
ws := app.Group("ws/v1/") ws := app.Group("ws/v1/")
// WebSocket 路由配置 // WebSocket 路由配置
ws.Get("/chat", websocket.New(func(c *websocket.Conn) { ws.Get("/chat", websocket.New(chatService.Chat, websocket.Config{
// 可以在这里添加握手前的中间件逻辑(如头校验)
chatService.Chat(c) // 调用实际的 Chat 处理函数
}, websocket.Config{
// 可选配置:跨域检查、最大负载大小等 // 可选配置:跨域检查、最大负载大小等
HandshakeTimeout: 10 * time.Second, HandshakeTimeout: 10 * time.Second,
//Subprotocols: []string{"json", "msgpack"}, //Subprotocols: []string{"json", "msgpack"},

View File

@ -6,9 +6,11 @@ import (
"ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/constants"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/gateway" "ai_scheduler/internal/gateway"
"context"
"encoding/json" "encoding/json"
"log" "log"
"sync" "sync"
"time"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2" "github.com/gofiber/websocket/v2"
@ -38,20 +40,6 @@ func NewChatService(
} }
} }
// ToolCallResponse 工具调用响应
type ToolCallResponse struct {
ID string `json:"id" example:"call_1"`
Type string `json:"type" example:"function"`
Function FunctionCallResponse `json:"function"`
Result interface{} `json:"result,omitempty"`
}
// FunctionCallResponse 函数调用响应
type FunctionCallResponse struct {
Name string `json:"name" example:"get_weather"`
Arguments interface{} `json:"arguments"`
}
func (h *ChatService) ChatFail(c *websocket.Conn, content string) { func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
err := c.WriteMessage(websocket.TextMessage, []byte(content)) err := c.WriteMessage(websocket.TextMessage, []byte(content))
if err != nil { if err != nil {
@ -63,15 +51,18 @@ func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
// Chat 处理WebSocket聊天连接 // Chat 处理WebSocket聊天连接
// 这是WebSocket处理的主入口函数 // 这是WebSocket处理的主入口函数
func (h *ChatService) Chat(c *websocket.Conn) { func (h *ChatService) Chat(c *websocket.Conn) {
// 创建新的客户端实例 ctx, cancel := context.WithCancel(context.Background())
h.mu.Lock()
client := gateway.NewClient(c)
h.mu.Unlock()
// 创建新的客户端实例
client := gateway.NewClient(c, ctx, cancel)
// 心跳检测
go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval))
// 将客户端添加到网关管理 // 将客户端添加到网关管理
h.Gw.AddClient(client) h.Gw.AddClient(client)
log.Println("client connected:", client.GetID()) // 确保在函数返回时移除客户端并关闭连接
log.Println("客户端已连接") defer func() {
h.Gw.Cleanup(client.GetID())
}()
// 绑定会话ID // 绑定会话ID
uid := c.Query("x-session") uid := c.Query("x-session")
@ -79,7 +70,6 @@ func (h *ChatService) Chat(c *websocket.Conn) {
if err := h.Gw.BindUid(client.GetID(), uid); err != nil { if err := h.Gw.BindUid(client.GetID(), uid); err != nil {
log.Println("绑定UID错误:", err) log.Println("绑定UID错误:", err)
} }
log.Printf("bind %s -> uid:%s\n", client.GetID(), uid)
} }
// 验证并收集连接数据,后续对话中会使用 // 验证并收集连接数据,后续对话中会使用
@ -89,13 +79,6 @@ func (h *ChatService) Chat(c *websocket.Conn) {
return return
} }
// 确保在函数返回时移除客户端并关闭连接
defer func() {
h.Gw.RemoveClient(client.GetID())
_ = c.Close()
log.Println("client disconnected:", client.GetID())
}()
// 循环读取客户端消息 // 循环读取客户端消息
for { for {
// 读取消息 // 读取消息
@ -104,7 +87,14 @@ func (h *ChatService) Chat(c *websocket.Conn) {
log.Println("读取错误:", err) log.Println("读取错误:", err)
break break
} }
//if string(message) == `{"type":"ping"}` {
// client.LastActive = time.Now()
// if err := c.WriteMessage(websocket.TextMessage, []byte(`{"type":"pong"}`)); err != nil {
// log.Println("Heartbeat response failed", "id", client.GetID(), "err", err)
// return
// }
// continue
//}
// 处理消息 // 处理消息
msg, chatType := h.handleMessageToString(c, messageType, message) msg, chatType := h.handleMessageToString(c, messageType, message)
if chatType == constants.ConnStatusClosed { if chatType == constants.ConnStatusClosed {