feat: 优化任务模型与权限控制

This commit is contained in:
renzhiyuan 2025-12-04 10:14:57 +08:00
parent ec5623bc07
commit 4339b6eee8
8 changed files with 114 additions and 62 deletions

View File

@ -3,11 +3,12 @@ server:
port: 8090
host: "0.0.0.0"
ollama:
base_url: "http://127.0.0.1:11434"
model: "qwen3-coder:480b-cloud"
generate_model: "qwen3-coder:480b-cloud"
vl_model: "qwen2.5vl:7b"
vl_model: "gemini-3-pro-preview"
timeout: "120s"
level: "info"
format: "json"
@ -19,6 +20,7 @@ sys:
channel_pool_len: 100
channel_pool_size: 32
llm_pool_len: 5
heartbeat_interval: 30
redis:
host: 47.97.27.195:6379
type: node

View File

@ -139,6 +139,7 @@ func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entit
Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
Images: requireData.ImgByte,
KeepAlive: &api.Duration{Duration: 3600 * time.Second},
//Think: &api.ThinkValue{Value: false},
})
if err != nil {
return

View File

@ -36,10 +36,11 @@ type LLM struct {
// SysConfig 系统配置
type SysConfig struct {
SessionLen int `mapstructure:"session_len"`
ChannelPoolLen int `mapstructure:"channel_pool_len"`
ChannelPoolSize int `mapstructure:"channel_pool_size"`
LlmPoolLen int `mapstructure:"llm_pool_len"`
SessionLen int `mapstructure:"session_len"`
ChannelPoolLen int `mapstructure:"channel_pool_len"`
ChannelPoolSize int `mapstructure:"channel_pool_size"`
LlmPoolLen int `mapstructure:"llm_pool_len"`
HeartbeatInterval int `mapstructure:"heartbeat_interval"`
}
// ServerConfig 服务器配置

View File

@ -3,33 +3,44 @@ package gateway
import (
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/model"
"encoding/hex"
"fmt"
"github.com/gofiber/websocket/v2"
"ai_scheduler/internal/pkg"
"context"
"encoding/binary"
"log"
"math/rand"
"time"
"github.com/gofiber/websocket/v2"
)
var (
ErrConnClosed = errors.SysErr("连接不存在或已关闭")
rng = rand.New(rand.NewSource(time.Now().UnixNano()))
idBuf = make([]byte, 20)
)
type Client struct {
id string // 客户端唯一ID
conn *websocket.Conn // WebSocket 连接
session string // 会话ID
key string // 应用密钥
auth string // 用户凭证token
codes []string // 用户权限code
sysInfo *model.AiSy // 系统信息
tasks []model.AiTask // 任务列表
sysCode string // 系统编码
id string // 客户端唯一ID
conn *websocket.Conn // WebSocket 连接
session string // 会话ID
key string // 应用密钥
auth string // 用户凭证token
codes []string // 用户权限code
sysInfo *model.AiSy // 系统信息
tasks []model.AiTask // 任务列表
sysCode string // 系统编码
Ctx context.Context
Cancel context.CancelFunc
LastActive time.Time
}
func NewClient(conn *websocket.Conn) *Client {
func NewClient(conn *websocket.Conn, ctx context.Context, cancel context.CancelFunc) *Client {
return &Client{
id: generateClientID(),
conn: conn,
id: generateClientID(),
conn: conn,
Ctx: ctx,
Cancel: cancel,
}
}
@ -103,12 +114,16 @@ func (c *Client) SendFunc(msg []byte) error {
// 生成唯一的客户端ID
func generateClientID() string {
// 使用时间戳+随机数确保唯一性
// 1. 时间戳
timestamp := time.Now().UnixNano()
randomBytes := make([]byte, 4)
rand.Read(randomBytes)
randomStr := hex.EncodeToString(randomBytes)
return fmt.Sprintf("%d%s", timestamp, randomStr)
binary.BigEndian.PutUint64(idBuf[:8], uint64(timestamp))
// 2. 随机数4字节
binary.BigEndian.PutUint32(idBuf[8:12], rng.Uint32())
// 3. 十六进制编码
n := pkg.HexEncode(idBuf[:12], idBuf[12:])
return string(idBuf[12 : 12+n])
}
// 连接数据验证和收集
@ -136,3 +151,22 @@ func (c *Client) DataAuth() (err error) {
}
return
}
func (c *Client) InitHeartbeat(timeoutSecond time.Duration) {
ticker := time.NewTicker(timeoutSecond * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
//*2是防止丢包连续丢包两次再加5s网络延迟容错
if time.Since(c.LastActive) > (timeoutSecond*2*time.Second + 5) { // 5秒容错
log.Println("Heartbeat timeout", "id", c.id)
c.conn.WriteMessage(websocket.CloseMessage, []byte("Heartbeat timeout"))
c.conn.Close()
return
}
case <-c.Ctx.Done():
return
}
}
}

View File

@ -2,7 +2,9 @@ package gateway
import (
"errors"
"log"
"sync"
"time"
)
type Gateway struct {
@ -20,14 +22,27 @@ func NewGateway() *Gateway {
func (g *Gateway) AddClient(c *Client) {
g.mu.Lock()
defer g.mu.Unlock()
defer func() {
g.mu.Unlock()
//心跳开始计时
c.LastActive = time.Now()
log.Println("client connected:", c.GetID())
log.Println("客户端已连接")
}()
g.clients[c.GetID()] = c
}
func (g *Gateway) RemoveClient(clientID string) {
func (g *Gateway) Cleanup(clientID string) {
g.mu.Lock()
defer g.mu.Unlock()
delete(g.clients, clientID)
defer func() {
if c, ex := g.clients[clientID]; ex {
delete(g.clients, clientID)
_ = c.conn.Close()
c.Cancel()
}
g.mu.Unlock()
log.Println("client disconnected:", clientID)
}()
for uid, list := range g.uidMap {
newList := []string{}
for _, cid := range list {
@ -37,6 +52,7 @@ func (g *Gateway) RemoveClient(clientID string) {
}
g.uidMap[uid] = newList
}
}
func (g *Gateway) SendToAll(msg []byte) {
@ -63,6 +79,7 @@ func (g *Gateway) BindUid(clientID, uid string) error {
return errors.New("client not found")
}
g.uidMap[uid] = append(g.uidMap[uid], clientID)
log.Printf("bind %s -> uid:%s\n", clientID, uid)
return nil
}

View File

@ -55,3 +55,13 @@ func ValidateImageURL(rawURL string) error {
return nil
}
// hexEncode 将 src 的二进制数据编码为十六进制字符串,写入 dst返回写入长度
func HexEncode(src, dst []byte) int {
const hextable = "0123456789abcdef"
for i := 0; i < len(src); i++ {
dst[i*2] = hextable[src[i]>>4]
dst[i*2+1] = hextable[src[i]&0xf]
}
return len(src) * 2
}

View File

@ -82,10 +82,7 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
func routerSocket(app *fiber.App, chatService *services.ChatService) {
ws := app.Group("ws/v1/")
// WebSocket 路由配置
ws.Get("/chat", websocket.New(func(c *websocket.Conn) {
// 可以在这里添加握手前的中间件逻辑(如头校验)
chatService.Chat(c) // 调用实际的 Chat 处理函数
}, websocket.Config{
ws.Get("/chat", websocket.New(chatService.Chat, websocket.Config{
// 可选配置:跨域检查、最大负载大小等
HandshakeTimeout: 10 * time.Second,
//Subprotocols: []string{"json", "msgpack"},

View File

@ -6,9 +6,11 @@ import (
"ai_scheduler/internal/data/constants"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/gateway"
"context"
"encoding/json"
"log"
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
@ -38,20 +40,6 @@ func NewChatService(
}
}
// ToolCallResponse 工具调用响应
type ToolCallResponse struct {
ID string `json:"id" example:"call_1"`
Type string `json:"type" example:"function"`
Function FunctionCallResponse `json:"function"`
Result interface{} `json:"result,omitempty"`
}
// FunctionCallResponse 函数调用响应
type FunctionCallResponse struct {
Name string `json:"name" example:"get_weather"`
Arguments interface{} `json:"arguments"`
}
func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
err := c.WriteMessage(websocket.TextMessage, []byte(content))
if err != nil {
@ -63,15 +51,18 @@ func (h *ChatService) ChatFail(c *websocket.Conn, content string) {
// Chat 处理WebSocket聊天连接
// 这是WebSocket处理的主入口函数
func (h *ChatService) Chat(c *websocket.Conn) {
// 创建新的客户端实例
h.mu.Lock()
client := gateway.NewClient(c)
h.mu.Unlock()
ctx, cancel := context.WithCancel(context.Background())
// 创建新的客户端实例
client := gateway.NewClient(c, ctx, cancel)
// 心跳检测
go client.InitHeartbeat(time.Duration(h.cfg.Sys.HeartbeatInterval))
// 将客户端添加到网关管理
h.Gw.AddClient(client)
log.Println("client connected:", client.GetID())
log.Println("客户端已连接")
// 确保在函数返回时移除客户端并关闭连接
defer func() {
h.Gw.Cleanup(client.GetID())
}()
// 绑定会话ID
uid := c.Query("x-session")
@ -79,7 +70,6 @@ func (h *ChatService) Chat(c *websocket.Conn) {
if err := h.Gw.BindUid(client.GetID(), uid); err != nil {
log.Println("绑定UID错误:", err)
}
log.Printf("bind %s -> uid:%s\n", client.GetID(), uid)
}
// 验证并收集连接数据,后续对话中会使用
@ -89,13 +79,6 @@ func (h *ChatService) Chat(c *websocket.Conn) {
return
}
// 确保在函数返回时移除客户端并关闭连接
defer func() {
h.Gw.RemoveClient(client.GetID())
_ = c.Close()
log.Println("client disconnected:", client.GetID())
}()
// 循环读取客户端消息
for {
// 读取消息
@ -104,7 +87,14 @@ func (h *ChatService) Chat(c *websocket.Conn) {
log.Println("读取错误:", err)
break
}
//if string(message) == `{"type":"ping"}` {
// client.LastActive = time.Now()
// if err := c.WriteMessage(websocket.TextMessage, []byte(`{"type":"pong"}`)); err != nil {
// log.Println("Heartbeat response failed", "id", client.GetID(), "err", err)
// return
// }
// continue
//}
// 处理消息
msg, chatType := h.handleMessageToString(c, messageType, message)
if chatType == constants.ConnStatusClosed {