From 3de460658d28d42a76ed61b68d7a1660a773388b Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Fri, 21 Nov 2025 09:31:55 +0800 Subject: [PATCH 1/4] =?UTF-8?q?feat:=E6=8E=A5=E5=85=A5=E6=9D=83=E9=99=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_test.yaml | 10 ++++++- internal/biz/do/handle.go | 61 +++++++++++++++++++++++++++++++++++++++ internal/config/config.go | 26 ++++++++++++----- 3 files changed, 88 insertions(+), 9 deletions(-) diff --git a/config/config_test.yaml b/config/config_test.yaml index ff2d2d8..973afa0 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -64,4 +64,12 @@ default_prompt: 提取出图片中对用户可能有用的关键信息(例如金额、日期、标题、编号、联系信息、商品名称等)。 若图片为文档类(如合同、发票、收据),请结构化输出关键字段(如客户名称、金额、开票日期等)。 ' - user_prompt: '识别图片内容' \ No newline at end of file + user_prompt: '识别图片内容' +# 权限配置 +permissionConfig: + # 统一登录平台基础URL + unified_login_platform_base_url: "https://api.test.user.1688sup.com" + # 白名单接口 + white_list: + - "chat" # 聊天接口 + - "bug_optimization_submit" # 优化建议提交接口 diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index 4c2803d..bfbd76f 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -16,6 +16,9 @@ import ( "context" "encoding/json" "fmt" + "github.com/gofiber/fiber/v2/log" + "gorm.io/gorm/utils" + "net/http" "strings" ) @@ -90,6 +93,14 @@ func (r *Handle) HandleMatch(ctx context.Context, requireData *entitys.RequireDa if pointTask == nil || pointTask.Index == "other" { return r.OtherTask(ctx, requireData) } + + // 校验用户权限 + if err = r.PermissionAuth(requireData, pointTask); err != nil { + log.Errorf("权限验证失败: %s", err.Error()) + entitys.ResLog(requireData.Ch, "", "权限验证失败:"+err.Error()) + return + } + switch constants.TaskType(pointTask.Type) { case constants.TaskTypeApi: return r.handleApiTask(ctx, requireData, pointTask) @@ -252,3 +263,53 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require return } + +// 权限验证 +func (r *Handle) PermissionAuth(requireData *entitys.RequireData, pointTask *model.AiTask) (err error) { + // 白名单接口不要校验权限 + if utils.Contains(r.conf.PermissionConfig.WhiteList, pointTask.Index) { + return nil + } + + // 查询用户权限 + var ( + request l_request.Request + ) + + request.Url = r.conf.PermissionConfig.UnifiedLoginPlatformBaseURL + + request.Method = "GET" + request.Headers = map[string]string{ + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", + "Accept": "application/json, text/plain, */*", + "Authorization": "Bearer " + requireData.Auth, + } + + // 发送请求 + res, err := request.Send() + if err != nil { + return err + } + + // 检查响应状态码 + if res.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code: %d", res.StatusCode) + } + + type resp struct { + Codes []string `json:"codes"` + } + // 解析响应体 + var respBody resp + err = json.Unmarshal([]byte(res.Text), &respBody) + if err != nil { + return err + } + + // 检查权限 + if !utils.Contains(respBody.Codes, pointTask.Index) { + return fmt.Errorf("用户权限不足: %s", pointTask.Name) + } + + return nil +} diff --git a/internal/config/config.go b/internal/config/config.go index d39d80a..b2350f7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,14 +9,15 @@ import ( // Config 应用配置 type Config struct { - Server ServerConfig `mapstructure:"server"` - Ollama OllamaConfig `mapstructure:"ollama"` - Sys SysConfig `mapstructure:"sys"` - Tools ToolsConfig `mapstructure:"tools"` - Logging LoggingConfig `mapstructure:"logging"` - Redis Redis `mapstructure:"redis"` - DB DB `mapstructure:"db"` - DefaultPrompt SysPrompt `mapstructure:"default_prompt"` + Server ServerConfig `mapstructure:"server"` + Ollama OllamaConfig `mapstructure:"ollama"` + Sys SysConfig `mapstructure:"sys"` + Tools ToolsConfig `mapstructure:"tools"` + Logging LoggingConfig `mapstructure:"logging"` + Redis Redis `mapstructure:"redis"` + DB DB `mapstructure:"db"` + DefaultPrompt SysPrompt `mapstructure:"default_prompt"` + PermissionConfig PermissionConfig `mapstructure:"permissionConfig"` // LLM *LLM `mapstructure:"llm"` } @@ -107,6 +108,15 @@ type LoggingConfig struct { Format string `mapstructure:"format"` } +// PermissionConfig 权限校验配置 +type PermissionConfig struct { + UnifiedLoginPlatformBaseURL string `mapstructure:"unified_login_platform_base_url"` // 统一登录平台基础URL + // 白名单任务 + WhiteList []string `mapstructure:"white_list"` // 白名单任务列表 +} + +// 权限校验配置 + // LoadConfig 加载配置 func LoadConfig(configPath string) (*Config, error) { viper.SetConfigFile(configPath) From 10db75a7aa57c1e2368003b2cbd4a034bbb8a62b Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Fri, 21 Nov 2025 11:57:18 +0800 Subject: [PATCH 2/4] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E7=B3=BB?= =?UTF-8?q?=E7=BB=9F=E6=8F=90=E7=A4=BA=E8=AF=8D=E5=A2=9E=E5=8A=A0=E6=97=B6?= =?UTF-8?q?=E9=97=B4=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/llm_service/common.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/biz/llm_service/common.go b/internal/biz/llm_service/common.go index f10406b..1d62ed7 100644 --- a/internal/biz/llm_service/common.go +++ b/internal/biz/llm_service/common.go @@ -15,7 +15,7 @@ type LlmService interface { // buildSystemPrompt 构建系统提示词 func buildSystemPrompt(prompt string) string { if len(prompt) == 0 { - prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON:{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式,禁用markdown格式返回\n3.只返回json字符串,不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容" + prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON:{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式,禁用markdown格式返回\n3.只返回json字符串,不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容,\n当前时间是:" + time.Now().Format(time.DateTime) } return prompt From c71c72038e65eea1b3386d46d1f70528a64ee477 Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Fri, 21 Nov 2025 15:17:47 +0800 Subject: [PATCH 3/4] =?UTF-8?q?feat:=E8=B0=83=E6=95=B4chat=E9=83=A8?= =?UTF-8?q?=E5=88=86=E4=BB=A3=E7=A0=81=EF=BC=8C=E7=94=A8=E6=88=B7=E6=9D=83?= =?UTF-8?q?=E9=99=90=E6=A0=A1=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config_test.yaml | 11 +- internal/biz/do/ctx.go | 167 +++++++++++++++++++----------- internal/biz/do/handle.go | 53 +++------- internal/biz/router.go | 38 +++++-- internal/config/config.go | 15 ++- internal/data/error/error_code.go | 1 + internal/entitys/response.go | 1 - internal/gateway/client.go | 138 ++++++++++++++++++++++++ internal/gateway/gateway.go | 7 +- internal/services/chat.go | 158 ++++++++++++++++++++-------- 10 files changed, 419 insertions(+), 170 deletions(-) create mode 100644 internal/gateway/client.go diff --git a/config/config_test.yaml b/config/config_test.yaml index 973afa0..f45fff4 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -67,9 +67,14 @@ default_prompt: user_prompt: '识别图片内容' # 权限配置 permissionConfig: - # 统一登录平台基础URL - unified_login_platform_base_url: "https://api.test.user.1688sup.com" - # 白名单接口 + # 不同系统的权限校验配置 + sys_permission: + # 直连天下系统 + zltx: + permission_url: "https://gateway.dev.cdlsxd.cn/zltx_api/test/v1/menu/myCodes?systemCode=zltx" + white_list: + - "knowledge_qa" # 知识问答 + # 通用的白名单接口 white_list: - "chat" # 聊天接口 - "bug_optimization_submit" # 优化建议提交接口 diff --git a/internal/biz/do/ctx.go b/internal/biz/do/ctx.go index a160c8c..b0e7299 100644 --- a/internal/biz/do/ctx.go +++ b/internal/biz/do/ctx.go @@ -6,6 +6,7 @@ import ( "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/gateway" "ai_scheduler/internal/pkg" "ai_scheduler/tmpl/dataTemp" "context" @@ -22,7 +23,7 @@ import ( ) type Do struct { - Ctx *entitys.RequireData + //Ctx *entitys.RequireData sessionImpl *impl.SessionImpl sysImpl *impl.SysImpl taskImpl *impl.TaskImpl @@ -44,78 +45,120 @@ func NewDo( } } -func (d *Do) InitCtx(req *entitys.ChatSockRequest) *Do { - d.Ctx = &entitys.RequireData{ - Req: req, +func (d *Do) DataAuth(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) (err error) { + // 1. 验证客户端数据 + if err = d.validateClientData(client, requireData); err != nil { + return err } - return d + + // 2. 加载系统信息 + if err = d.loadSystemInfo(ctx, client, requireData); err != nil { + return fmt.Errorf("获取系统信息失败: %w", err) + } + + // 3. 加载任务列表 + if err = d.loadTaskList(ctx, client, requireData); err != nil { + return fmt.Errorf("获取任务列表失败: %w", err) + } + + // 4. 加载聊天历史 + if err = d.loadChatHistory(ctx, requireData); err != nil { + return fmt.Errorf("获取历史记录失败: %w", err) + } + + // 5. 加载图片数据 + if err = d.getImgData(requireData); err != nil { + return err + } + + return nil } -func (d *Do) DataAuth(c *websocket.Conn) (err error) { - d.Ctx.Session = c.Query("x-session", "") - if len(d.Ctx.Session) == 0 { - err = errors.SessionNotFound - return - } - d.Ctx.Auth = c.Query("x-authorization", "") - if len(d.Ctx.Auth) == 0 { - err = errors.AuthNotFound - return - } - d.Ctx.Key = c.Query("x-app-key", "") - if len(d.Ctx.Key) == 0 { - err = errors.KeyNotFound - return +// 提取数据验证为单独函数 +func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.RequireData) error { + requireData.Session = client.GetSession() + if len(requireData.Session) == 0 { + return errors.SessionNotFound } - d.Ctx.Sys, err = d.getSysInfo() - if err != nil { - err = errors.SysErr("获取系统信息失败:%v", err.Error()) - return - } - d.Ctx.Histories, err = d.getSessionChatHis() - if err != nil { - err = errors.SysErr("获取历史记录失败:%v", err.Error()) - return + requireData.Auth = client.GetAuth() + if len(requireData.Auth) == 0 { + return errors.AuthNotFound } - d.Ctx.Tasks, err = d.getTasks(d.Ctx.Sys.SysID) - if err != nil { - err = errors.SysErr("获取任务列表失败:%v", err.Error()) - return - } - if err = d.getImgData(); err != nil { - return + requireData.Key = client.GetKey() + if len(requireData.Key) == 0 { + return errors.KeyNotFound } - return + return nil } -func (d *Do) MakeCh(c *websocket.Conn) (ctx context.Context, deferFunc func()) { - d.Ctx.Ch = make(chan entitys.Response) +// 获取系统信息的辅助函数 +func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error { + if sysInfo := client.GetSysInfo(); sysInfo == nil { + sys, err := d.getSysInfo(requireData) + if err != nil { + return err + } + client.SetSysInfo(&sys) + requireData.Sys = sys + } else { + requireData.Sys = *sysInfo + } + return nil +} + +// 获取任务列表的辅助函数 +func (d *Do) loadTaskList(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error { + if taskInfo := client.GetTasks(); len(taskInfo) == 0 { + tasks, err := d.getTasks(requireData.Sys.SysID) + if err != nil { + return err + } + requireData.Tasks = tasks + client.SetTasks(tasks) + } else { + requireData.Tasks = taskInfo + } + return nil +} + +// 获取历史记录的辅助函数 +func (d *Do) loadChatHistory(ctx context.Context, requireData *entitys.RequireData) error { + histories, err := d.getSessionChatHis(requireData) + if err != nil { + return err + } + requireData.Histories = histories + return nil +} + +func (d *Do) MakeCh(c *websocket.Conn, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) { + requireData.Ch = make(chan entitys.Response) ctx, cancel := context.WithCancel(context.Background()) - done := d.startMessageHandler(ctx, c) + done := d.startMessageHandler(ctx, c, requireData) return ctx, func() { - close(d.Ctx.Ch) //关闭主通道 - <-done // 等待消息处理完成 + close(requireData.Ch) //关闭主通道 + <-done // 等待消息处理完成 cancel() } } -func (d *Do) getImgData() (err error) { - if len(d.Ctx.Req.Img) == 0 { +func (d *Do) getImgData(requireData *entitys.RequireData) (err error) { + if len(requireData.Req.Img) == 0 { return } - imgs := strings.Split(d.Ctx.Req.Img, ",") + imgs := strings.Split(requireData.Req.Img, ",") if len(imgs) == 0 { return } for k, img := range imgs { baseErr := "获取第" + strconv.Itoa(k+1) + "张图片失败:" - entitys.ResLog(d.Ctx.Ch, "img_get_start", "正在获取第"+strconv.Itoa(k+1)+"张图片") + entitys.ResLog(requireData.Ch, "img_get_start", "正在获取第"+strconv.Itoa(k+1)+"张图片") if err = pkg.ValidateImageURL(img); err != nil { - entitys.ResLog(d.Ctx.Ch, "", baseErr+":expected image content") + entitys.ResLog(requireData.Ch, "", baseErr+":expected image content") continue } req := l_request.Request{ @@ -128,20 +171,20 @@ func (d *Do) getImgData() (err error) { } res, _err := req.Send() if _err != nil { - entitys.ResLog(d.Ctx.Ch, "", baseErr+_err.Error()) + entitys.ResLog(requireData.Ch, "", baseErr+_err.Error()) continue } if _, ex := res.Headers["Content-Type"]; !ex { - entitys.ResLog(d.Ctx.Ch, "", baseErr+":Content-Type不存在") + entitys.ResLog(requireData.Ch, "", baseErr+":Content-Type不存在") continue } if !strings.HasPrefix(res.Headers["Content-Type"], "image/") { - entitys.ResLog(d.Ctx.Ch, "", baseErr+":expected image content") + entitys.ResLog(requireData.Ch, "", baseErr+":expected image content") continue } - d.Ctx.ImgByte = append(d.Ctx.ImgByte, res.Content) - d.Ctx.ImgUrls = append(d.Ctx.ImgUrls, img) - entitys.ResLog(d.Ctx.Ch, "img_get_end", "第"+strconv.Itoa(k+1)+"张图片获取成功") + requireData.ImgByte = append(requireData.ImgByte, res.Content) + requireData.ImgUrls = append(requireData.ImgUrls, img) + entitys.ResLog(requireData.Ch, "img_get_end", "第"+strconv.Itoa(k+1)+"张图片获取成功") } return @@ -152,19 +195,19 @@ func (d *Do) getRequireData() (err error) { return } -func (d *Do) getSysInfo() (sysInfo model.AiSy, err error) { +func (d *Do) getSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) { cond := builder.NewCond() - cond = cond.And(builder.Eq{"app_key": d.Ctx.Key}) + cond = cond.And(builder.Eq{"app_key": requireData.Key}) cond = cond.And(builder.IsNull{"delete_at"}) cond = cond.And(builder.Eq{"status": 1}) err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo) return } -func (d *Do) getSessionChatHis() (his []model.AiChatHi, err error) { +func (d *Do) getSessionChatHis(requireData *entitys.RequireData) (his []model.AiChatHi, err error) { cond := builder.NewCond() - cond = cond.And(builder.Eq{"session_id": d.Ctx.Session}) + cond = cond.And(builder.Eq{"session_id": requireData.Session}) _, err = d.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: d.conf.Sys.SessionLen}, &his, "his_id desc") @@ -186,7 +229,7 @@ func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) { func (d *Do) startMessageHandler( ctx context.Context, c *websocket.Conn, - + requireData *entitys.RequireData, ) <-chan struct{} { done := make(chan struct{}) var chat []string @@ -200,10 +243,10 @@ func (d *Do) startMessageHandler( ) if len(chat) > 0 { AiRes := &model.AiChatHi{ - SessionID: d.Ctx.Session, - Ques: d.Ctx.Req.Text, + SessionID: requireData.Session, + Ques: requireData.Req.Text, Ans: strings.Join(chat, ""), - Files: d.Ctx.Req.Img, + Files: requireData.Req.Img, } d.hisImpl.AddWithData(AiRes) hisLog.HisId = AiRes.HisID @@ -216,7 +259,7 @@ func (d *Do) startMessageHandler( }() - for v := range d.Ctx.Ch { // 自动检测通道关闭 + for v := range requireData.Ch { // 自动检测通道关闭 if err := sendWithTimeout(c, v, 2*time.Second); err != nil { log.Errorf("Send error: %v", err) return diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index bfbd76f..828e443 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -8,6 +8,7 @@ import ( "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/gateway" "ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg/l_request" "ai_scheduler/internal/pkg/mapstructure" @@ -18,7 +19,6 @@ import ( "fmt" "github.com/gofiber/fiber/v2/log" "gorm.io/gorm/utils" - "net/http" "strings" ) @@ -71,7 +71,7 @@ func (r *Handle) handleOtherTask(ctx context.Context, requireData *entitys.Requi return } -func (r *Handle) HandleMatch(ctx context.Context, requireData *entitys.RequireData) (err error) { +func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) (err error) { if !requireData.Match.IsMatch { if len(requireData.Match.Chat) != 0 { @@ -95,9 +95,8 @@ func (r *Handle) HandleMatch(ctx context.Context, requireData *entitys.RequireDa } // 校验用户权限 - if err = r.PermissionAuth(requireData, pointTask); err != nil { + if err = r.PermissionAuth(client, pointTask); err != nil { log.Errorf("权限验证失败: %s", err.Error()) - entitys.ResLog(requireData.Ch, "", "权限验证失败:"+err.Error()) return } @@ -265,49 +264,21 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require } // 权限验证 -func (r *Handle) PermissionAuth(requireData *entitys.RequireData, pointTask *model.AiTask) (err error) { - // 白名单接口不要校验权限 +func (r *Handle) PermissionAuth(client *gateway.Client, pointTask *model.AiTask) (err error) { + // 通用权限校验 if utils.Contains(r.conf.PermissionConfig.WhiteList, pointTask.Index) { return nil } - // 查询用户权限 - var ( - request l_request.Request - ) - - request.Url = r.conf.PermissionConfig.UnifiedLoginPlatformBaseURL - - request.Method = "GET" - request.Headers = map[string]string{ - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", - "Accept": "application/json, text/plain, */*", - "Authorization": "Bearer " + requireData.Auth, + // 系统权限校验 + if v, ok := r.conf.PermissionConfig.SysPermission[client.GetSysCode()]; !ok { + return fmt.Errorf("未配置系统权限校验: %s", client.GetSysCode()) + } else if utils.Contains(v.WhiteList, pointTask.Index) { + return nil } - // 发送请求 - res, err := request.Send() - if err != nil { - return err - } - - // 检查响应状态码 - if res.StatusCode != http.StatusOK { - return fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - - type resp struct { - Codes []string `json:"codes"` - } - // 解析响应体 - var respBody resp - err = json.Unmarshal([]byte(res.Text), &respBody) - if err != nil { - return err - } - - // 检查权限 - if !utils.Contains(respBody.Codes, pointTask.Index) { + // 授权检查权限 + if !utils.Contains(client.GetCodes(), pointTask.Index) { return fmt.Errorf("用户权限不足: %s", pointTask.Name) } diff --git a/internal/biz/router.go b/internal/biz/router.go index 287d87b..70a069e 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -2,11 +2,11 @@ package biz import ( "ai_scheduler/internal/biz/do" + "ai_scheduler/internal/gateway" "ai_scheduler/internal/entitys" "github.com/gofiber/fiber/v2/log" - "github.com/gofiber/websocket/v2" ) // AiRouterBiz 智能路由服务 @@ -26,28 +26,48 @@ func NewAiRouterBiz( } } -func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) { - //必要数据验证和获取 - dos := r.do.InitCtx(req) +// 路由处理WebSocket请求 +// +// 参数: +// - client: 网关客户端 +// - req: 聊天请求结构体 +// +// 返回: +// - err: 错误信息 +func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatSockRequest) (err error) { + // 创建请求上下文数据 + requireData := &entitys.RequireData{ + Req: req, + } + // 获取WebSocket连接 + conn := client.GetConn() //初始化通道/上下文 - ctx, clearFunc := dos.MakeCh(c) - defer clearFunc() + ctx, clearFunc := r.do.MakeCh(conn, requireData) + defer func() { + if err != nil { + requireData.Ch <- entitys.Response{ + Content: err.Error(), + Type: entitys.ResponseErr, + } + } + clearFunc() + }() //数据验证和收集 - if err = dos.DataAuth(c); err != nil { + if err = r.do.DataAuth(ctx, client, requireData); err != nil { log.Errorf("数据验证和收集失败: %s", err.Error()) return } //意图识别 - if err = r.handle.Recognize(ctx, dos.Ctx); err != nil { + if err = r.handle.Recognize(ctx, requireData); err != nil { log.Errorf("意图识别失败: %s", err.Error()) return } //向下传递 - if err = r.handle.HandleMatch(ctx, dos.Ctx); err != nil { + if err = r.handle.HandleMatch(ctx, client, requireData); err != nil { log.Errorf("任务处理失败: %s", err.Error()) return } diff --git a/internal/config/config.go b/internal/config/config.go index b2350f7..27d46e4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -110,9 +110,18 @@ type LoggingConfig struct { // PermissionConfig 权限校验配置 type PermissionConfig struct { - UnifiedLoginPlatformBaseURL string `mapstructure:"unified_login_platform_base_url"` // 统一登录平台基础URL - // 白名单任务 - WhiteList []string `mapstructure:"white_list"` // 白名单任务列表 + // 不同系统的权限校验配置 + SysPermission map[string]SysWhiteList `mapstructure:"sys_permission"` + // 通用的白名单任务列表,不需要权限校验 + WhiteList []string `mapstructure:"white_list"` +} + +// 细分系统的白名单任务列表 +type SysWhiteList struct { + // 获取权限的地址 + PermissionURL string `mapstructure:"permission_url"` + // 系统的白名单任务列表 + WhiteList []string `mapstructure:"white_list"` } // 权限校验配置 diff --git a/internal/data/error/error_code.go b/internal/data/error/error_code.go index e9cc67d..9c28865 100644 --- a/internal/data/error/error_code.go +++ b/internal/data/error/error_code.go @@ -13,6 +13,7 @@ var ( AuthNotFound = &BusinessErr{code: 408, message: "身份验证失败"} KeyNotFound = &BusinessErr{code: 409, message: "身份验证失败"} SysNotFound = &BusinessErr{code: 410, message: "未找到系统信息"} + SysCodeNotFound = &BusinessErr{code: 411, message: "未找到系统编码"} InvalidParam = &BusinessErr{code: InvalidParamCode, message: "无效参数"} ) diff --git a/internal/entitys/response.go b/internal/entitys/response.go index 73bd581..5262102 100644 --- a/internal/entitys/response.go +++ b/internal/entitys/response.go @@ -2,7 +2,6 @@ package entitys import ( "encoding/json" - "github.com/gofiber/websocket/v2" ) diff --git a/internal/gateway/client.go b/internal/gateway/client.go new file mode 100644 index 0000000..b49bf45 --- /dev/null +++ b/internal/gateway/client.go @@ -0,0 +1,138 @@ +package gateway + +import ( + errors "ai_scheduler/internal/data/error" + "ai_scheduler/internal/data/model" + "encoding/hex" + "fmt" + "github.com/gofiber/websocket/v2" + "math/rand" + "time" +) + +var ( + ErrConnClosed = errors.SysErr("连接不存在或已关闭") +) + +type Client struct { + id string // 客户端唯一ID + conn *websocket.Conn // WebSocket 连接 + session string // 会话ID + key string // 应用密钥 + auth string // 用户凭证token + codes []string // 用户权限code + sysInfo *model.AiSy // 系统信息 + tasks []model.AiTask // 任务列表 + sysCode string // 系统编码 +} + +func NewClient(conn *websocket.Conn) *Client { + return &Client{ + id: generateClientID(), + conn: conn, + } +} + +// GetID 获取客户端的唯一ID +func (c *Client) GetID() string { + return c.id +} + +// GetConn 获取客户端的 WebSocket 连接 +func (c *Client) GetConn() *websocket.Conn { + return c.conn +} + +// GetSession 获取会话ID +func (c *Client) GetSession() string { + return c.session +} + +// GetKey 获取应用密钥 +func (c *Client) GetKey() string { + return c.key +} + +// GetAuth 获取用户凭证token +func (c *Client) GetAuth() string { + return c.auth +} + +// GetCodes 获取用户权限code +func (c *Client) GetCodes() []string { + return c.codes +} + +// GetSysCode 获取系统编码 +func (c *Client) GetSysCode() string { + return c.sysCode +} + +// GetSysInfo 获取系统信息 +func (c *Client) GetSysInfo() *model.AiSy { + return c.sysInfo +} + +// SetSysInfo 设置系统信息 +func (c *Client) SetSysInfo(sysInfo *model.AiSy) { + c.sysInfo = sysInfo +} + +// GetTasks 获取任务列表 +func (c *Client) GetTasks() []model.AiTask { + return c.tasks +} + +// SetTasks 设置任务列表 +func (c *Client) SetTasks(tasks []model.AiTask) { + c.tasks = tasks +} + +// 设置用户权限code +func (c *Client) SetCodes(codes []string) { + c.codes = codes +} + +// SendFunc 发送消息到客户端 +func (c *Client) SendFunc(msg []byte) error { + if c.conn != nil { + return c.conn.WriteMessage(websocket.TextMessage, msg) + } + return ErrConnClosed +} + +// 生成唯一的客户端ID +func generateClientID() string { + // 使用时间戳+随机数确保唯一性 + timestamp := time.Now().UnixNano() + randomBytes := make([]byte, 4) + rand.Read(randomBytes) + randomStr := hex.EncodeToString(randomBytes) + return fmt.Sprintf("%d%s", timestamp, randomStr) +} + +// 连接数据验证和收集 +func (c *Client) DataAuth() (err error) { + c.session = c.conn.Query("x-session", "") + if len(c.session) == 0 { + err = errors.SessionNotFound + return + } + c.auth = c.conn.Query("x-authorization", "") + if len(c.auth) == 0 { + err = errors.AuthNotFound + return + } + c.key = c.conn.Query("x-app-key", "") + if len(c.key) == 0 { + err = errors.KeyNotFound + return + } + // 系统编码 + c.sysCode = c.conn.Query("x-sys-code", "") + if len(c.sysCode) == 0 { + err = errors.SysCodeNotFound + return + } + return +} diff --git a/internal/gateway/gateway.go b/internal/gateway/gateway.go index f04ed3a..0f0e6f3 100644 --- a/internal/gateway/gateway.go +++ b/internal/gateway/gateway.go @@ -5,11 +5,6 @@ import ( "sync" ) -type Client struct { - ID string - SendFunc func(data []byte) error -} - type Gateway struct { mu sync.RWMutex clients map[string]*Client // clientID -> Client @@ -26,7 +21,7 @@ func NewGateway() *Gateway { func (g *Gateway) AddClient(c *Client) { g.mu.Lock() defer g.mu.Unlock() - g.clients[c.ID] = c + g.clients[c.GetID()] = c } func (g *Gateway) RemoveClient(clientID string) { diff --git a/internal/services/chat.go b/internal/services/chat.go index 95cd130..bb28412 100644 --- a/internal/services/chat.go +++ b/internal/services/chat.go @@ -2,19 +2,18 @@ package services import ( "ai_scheduler/internal/biz" + "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" + errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" - "encoding/hex" + "ai_scheduler/internal/pkg/l_request" "encoding/json" - "fmt" - "log" - "math/rand" - "sync" - "time" - "github.com/gofiber/fiber/v2" "github.com/gofiber/websocket/v2" + "log" + "net/http" + "sync" ) // ChatHandler 聊天处理器 @@ -23,6 +22,7 @@ type ChatService struct { Gw *gateway.Gateway mu sync.Mutex ChatHis *biz.ChatHistoryBiz + cfg *config.Config } // NewChatHandler 创建聊天处理器 @@ -30,11 +30,13 @@ func NewChatService( routerService *biz.AiRouterBiz, chatHis *biz.ChatHistoryBiz, gw *gateway.Gateway, + cfg *config.Config, ) *ChatService { return &ChatService{ routerBiz: routerService, Gw: gw, ChatHis: chatHis, + cfg: cfg, } } @@ -60,36 +62,61 @@ func (h *ChatService) ChatFail(c *websocket.Conn, content string) { _ = c.Close() } -func generateClientID() string { - // 使用时间戳+随机数确保唯一性 - timestamp := time.Now().UnixNano() - randomBytes := make([]byte, 4) - rand.Read(randomBytes) - randomStr := hex.EncodeToString(randomBytes) - return fmt.Sprintf("%d%s", timestamp, randomStr) -} +// Chat 处理WebSocket聊天连接 +// 这是WebSocket处理的主入口函数 func (h *ChatService) Chat(c *websocket.Conn) { + // 创建新的客户端实例 h.mu.Lock() - clientID := generateClientID() + client := gateway.NewClient(c) h.mu.Unlock() - client := &gateway.Client{ - ID: clientID, - SendFunc: func(data []byte) error { - return c.WriteMessage(websocket.TextMessage, data) - }, - } + + // 将客户端添加到网关管理 h.Gw.AddClient(client) - log.Println("client connected:", clientID) + log.Println("client connected:", client.GetID()) log.Println("客户端已连接") + // 绑定会话ID + uid := c.Query("x-session") + if uid != "" { + if err := h.Gw.BindUid(client.GetID(), uid); err != nil { + log.Println("绑定UID错误:", err) + } + log.Printf("bind %s -> uid:%s\n", client.GetID(), uid) + } + + // 验证并收集连接数据,后续对话中会使用 + if err := client.DataAuth(); err != nil { + log.Println("数据验证错误:", err) + h.ChatFail(c, err.Error()) + return + } + + // 获取用户权限 + codes, err := h.GetUserPermission(client) + if err != nil { + log.Println("获取用户权限错误:", err) + h.ChatFail(c, err.Error()) + return + } + client.SetCodes(codes) + + // 确保在函数返回时移除客户端并关闭连接 + defer func() { + h.Gw.RemoveClient(client.GetID()) + _ = c.Close() + log.Println("client disconnected:", client.GetID()) + }() + // 循环读取客户端消息 for { + // 读取消息 messageType, message, err := c.ReadMessage() if err != nil { log.Println("读取错误:", err) break } + // 处理消息 msg, chatType := h.handleMessageToString(c, messageType, message) if chatType == constants.ConnStatusClosed { break @@ -99,39 +126,31 @@ func (h *ChatService) Chat(c *websocket.Conn) { } log.Printf("收到消息: %s", string(msg)) + + // 解析请求 var req entitys.ChatSockRequest - if err := json.Unmarshal(msg, &req); err != nil { + if err = json.Unmarshal(msg, &req); err != nil { log.Println("JSON parse error:", err) continue } - //简单协议:bind: - // if c.Headers("Sec-Websocket-Protocol") == "bind" && req.SessionID != "" { - // uid := c.Query("x-session") - // _ = h.Gw.BindUid(clientID, req.SessionID) - // log.Printf("bind %s -> uid:%s\n", clientID, uid) - // } - uid := c.Query("x-session") - if uid != "" { - _ = h.Gw.BindUid(clientID, uid) - log.Printf("bind %s -> uid:%s\n", clientID, uid) - } - - err = h.routerBiz.RouteWithSocket(c, &req) + // 路由处理请求 + err = h.routerBiz.RouteWithSocket(client, &req) if err != nil { log.Println("处理失败:", err) - entitys.MsgSend(c, entitys.Response{ - Content: err.Error(), - Type: entitys.ResponseText, - }) } - } - h.Gw.RemoveClient(clientID) - _ = c.Close() - log.Println("client disconnected:", clientID) } +// handleMessageToString 处理不同类型的WebSocket消息 +// 参数: +// - c: WebSocket连接 +// - msgType: 消息类型 +// - msg: 消息内容 +// +// 返回: +// - text: 处理后的文本内容 +// - chatType: 连接状态 func (h *ChatService) handleMessageToString(c *websocket.Conn, msgType int, msg any) (text []byte, chatType constants.ConnStatus) { switch msgType { case websocket.TextMessage: @@ -172,3 +191,52 @@ func (s *ChatService) UsefulList(c *fiber.Ctx) error { return c.JSON(constants.UseFulMap) } + +// 从统一登录平台获取用户权限 +func (s *ChatService) GetUserPermission(client *gateway.Client) (codes []string, err error) { + var ( + request l_request.Request + ) + + // 系统编码 + systemCode := client.GetSysCode() + + // 检查系统编码是否配置 + if v, ok := s.cfg.PermissionConfig.SysPermission[systemCode]; !ok { + err = errors.SysErr("系统编码 %s 未配置", systemCode) + return + } else { + request.Url = v.PermissionURL + } + + request.Method = "GET" + request.Headers = map[string]string{ + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", + "Accept": "application/json, text/plain, */*", + "Authorization": "Bearer " + client.GetAuth(), + } + + // 发送请求 + res, err := request.Send() + if err != nil { + return + } + + // 检查响应状态码 + if res.StatusCode != http.StatusOK { + err = errors.SysErr("获取用户权限失败") + return + } + + type resp struct { + Codes []string `json:"codes"` + } + // 解析响应体 + var respBody resp + err = json.Unmarshal([]byte(res.Text), &respBody) + if err != nil { + return + } + + return respBody.Codes, nil +} From 2cd19e5fdfbd381ddc06db2c74aadff3034fc1c9 Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Fri, 21 Nov 2025 15:25:01 +0800 Subject: [PATCH 4/4] feat: chat --- internal/biz/router.go | 5 +---- internal/entitys/response.go | 7 +++++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/internal/biz/router.go b/internal/biz/router.go index 70a069e..6dcc233 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -46,10 +46,7 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS ctx, clearFunc := r.do.MakeCh(conn, requireData) defer func() { if err != nil { - requireData.Ch <- entitys.Response{ - Content: err.Error(), - Type: entitys.ResponseErr, - } + entitys.ResError(requireData.Ch, "", err.Error()) } clearFunc() }() diff --git a/internal/entitys/response.go b/internal/entitys/response.go index 5262102..cdadc98 100644 --- a/internal/entitys/response.go +++ b/internal/entitys/response.go @@ -67,6 +67,13 @@ func ResLoading(ch chan Response, index string, content string) { Type: ResponseLoading, } } +func ResError(ch chan Response, index string, content string) { + ch <- Response{ + Index: index, + Content: content, + Type: ResponseErr, + } +} type ResponseData struct { Done bool