diff --git a/config/config_test.yaml b/config/config_test.yaml index c5c8451..21e44fc 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -73,14 +73,4 @@ default_prompt: user_prompt: '识别图片内容' # 权限配置 permissionConfig: - # 不同系统的权限校验配置 - sys_permission: - # 直连天下系统 - zltx: - permission_url: "https://gateway.dev.cdlsxd.cn/zltx_api/test/v1/menu/myCodes?systemCode=zltx" - white_list: - - "knowledge_qa" # 知识问答 - # 通用的白名单接口 - white_list: - - "chat" # 聊天接口 - - "bug_optimization_submit" # 优化建议提交接口 + permission_url: "http://api.test.user.1688sup.cn:8001/v1/menu/myCodes?systemId=" diff --git a/internal/biz/do/ctx.go b/internal/biz/do/ctx.go index b0e7299..ed8749d 100644 --- a/internal/biz/do/ctx.go +++ b/internal/biz/do/ctx.go @@ -10,7 +10,9 @@ import ( "ai_scheduler/internal/pkg" "ai_scheduler/tmpl/dataTemp" "context" + "encoding/json" "fmt" + "net/http" "strconv" "strings" "time" @@ -71,6 +73,11 @@ func (d *Do) DataAuth(ctx context.Context, client *gateway.Client, requireData * return err } + // 6. 加载用户权限 + if _, err = d.LoadUserPermission(client, requireData); err != nil { + return fmt.Errorf("获取用户权限失败: %w", err) + } + return nil } @@ -298,3 +305,51 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura return sendCtx.Err() } } + +// 从统一登录平台获取用户权限 +func (d *Do) LoadUserPermission(client *gateway.Client, requireData *entitys.RequireData) (codes []string, err error) { + if len(client.GetCodes()) > 0 { + return client.GetCodes(), nil + } + + var ( + request l_request.Request + ) + + // 构建请求URL + request.Url = d.conf.PermissionConfig.PermissionURL + strconv.Itoa(int(requireData.Sys.SysID)) + + request.Method = "GET" + request.Headers = map[string]string{ + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", + "Accept": "application/json, text/plain, */*", + "Authorization": "Bearer " + client.GetAuth(), + } + + // 发送请求 + res, err := request.Send() + if err != nil { + return + } + + // 检查响应状态码 + if res.StatusCode != http.StatusOK { + err = errors.SysErr("获取用户权限失败") + return + } + + type resp struct { + Codes []string `json:"codes"` + } + // 解析响应体 + var respBody resp + err = json.Unmarshal([]byte(res.Text), &respBody) + if err != nil { + return + } + + // 设置客户端权限 + client.SetCodes(respBody.Codes) + + return respBody.Codes, nil +} diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index e3fe794..e014d25 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -17,9 +17,8 @@ import ( "context" "encoding/json" "fmt" - "strings" - "gorm.io/gorm/utils" + "strings" ) type Handle struct { @@ -265,18 +264,6 @@ func (r *Handle) handleApiTask(ctx context.Context, requireData *entitys.Require // 权限验证 func (r *Handle) PermissionAuth(client *gateway.Client, pointTask *model.AiTask) (err error) { - // 通用权限校验 - if utils.Contains(r.conf.PermissionConfig.WhiteList, pointTask.Index) { - return nil - } - - // 系统权限校验 - if v, ok := r.conf.PermissionConfig.SysPermission[client.GetSysCode()]; !ok { - return fmt.Errorf("未配置系统权限校验: %s", client.GetSysCode()) - } else if utils.Contains(v.WhiteList, pointTask.Index) { - return nil - } - // 授权检查权限 if !utils.Contains(client.GetCodes(), pointTask.Index) { return fmt.Errorf("用户权限不足: %s", pointTask.Name) diff --git a/internal/biz/task.go b/internal/biz/task.go index 2c67591..e1a8689 100644 --- a/internal/biz/task.go +++ b/internal/biz/task.go @@ -1,10 +1,16 @@ package biz import ( + errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/l_request" "context" + "encoding/json" + "gorm.io/gorm/utils" + "net/http" + "strconv" "xorm.io/builder" @@ -24,12 +30,66 @@ func NewTaskBiz(conf *config.Config, taskRepo *impl.TaskImpl) *TaskBiz { } // taskList 功能列表 -func (t *TaskBiz) TaskList(ctx context.Context, req *entitys.TaskRequest) (list []model.AiTask, err error) { +func (t *TaskBiz) TaskList(ctx context.Context, req *entitys.TaskRequest, auth string) (list []model.AiTask, err error) { + tasks := make([]model.AiTask, 0) cond := builder.NewCond() cond = cond.And(builder.Eq{"status": 1}) cond = cond.And(builder.Eq{"sys_id": req.SysId}) - err = t.taskRepo.GetRangeToMapStruct(&cond, &list) + err = t.taskRepo.GetRangeToMapStruct(&cond, &tasks) + + codes, err := t.GetUserPermission(req, auth) + if err != nil { + return + } + + // 检查用户是否有权限 + for _, task := range tasks { + if utils.Contains(codes, task.Index) { + list = append(list, task) + } + } return } + +// 从统一登录平台获取用户权限 +func (t *TaskBiz) GetUserPermission(req *entitys.TaskRequest, auth string) (codes []string, err error) { + var ( + request l_request.Request + ) + + // 构建请求URL + request.Url = t.conf.PermissionConfig.PermissionURL + strconv.Itoa(int(req.SysId)) + + request.Method = "GET" + request.Headers = map[string]string{ + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", + "Accept": "application/json, text/plain, */*", + "Authorization": auth, + } + + // 发送请求 + res, err := request.Send() + if err != nil { + return + } + + // 检查响应状态码 + if res.StatusCode != http.StatusOK { + err = errors.SysErr("获取用户权限失败") + return + } + + type resp struct { + Codes []string `json:"codes"` + } + // 解析响应体 + var respBody resp + err = json.Unmarshal([]byte(res.Text), &respBody) + if err != nil { + return + } + + return respBody.Codes, nil +} diff --git a/internal/config/config.go b/internal/config/config.go index 6002a36..32710ff 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -118,22 +118,10 @@ type LoggingConfig struct { // PermissionConfig 权限校验配置 type PermissionConfig struct { - // 不同系统的权限校验配置 - SysPermission map[string]SysWhiteList `mapstructure:"sys_permission"` - // 通用的白名单任务列表,不需要权限校验 - WhiteList []string `mapstructure:"white_list"` -} - -// 细分系统的白名单任务列表 -type SysWhiteList struct { // 获取权限的地址 PermissionURL string `mapstructure:"permission_url"` - // 系统的白名单任务列表 - WhiteList []string `mapstructure:"white_list"` } -// 权限校验配置 - // LoadConfig 加载配置 func LoadConfig(configPath string) (*Config, error) { viper.SetConfigFile(configPath) diff --git a/internal/services/chat.go b/internal/services/chat.go index 5dc0cf5..ff75bee 100644 --- a/internal/services/chat.go +++ b/internal/services/chat.go @@ -4,13 +4,10 @@ import ( "ai_scheduler/internal/biz" "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" - errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" - "ai_scheduler/internal/pkg/l_request" "encoding/json" "log" - "net/http" "sync" "github.com/gofiber/fiber/v2" @@ -92,15 +89,6 @@ func (h *ChatService) Chat(c *websocket.Conn) { return } - // 获取用户权限 - codes, err := h.GetUserPermission(client) - if err != nil { - log.Println("获取用户权限错误:", err) - h.ChatFail(c, err.Error()) - return - } - client.SetCodes(codes) - // 确保在函数返回时移除客户端并关闭连接 defer func() { h.Gw.RemoveClient(client.GetID()) @@ -192,52 +180,3 @@ func (s *ChatService) UsefulList(c *fiber.Ctx) error { return c.JSON(constants.UseFulMap) } - -// 从统一登录平台获取用户权限 -func (s *ChatService) GetUserPermission(client *gateway.Client) (codes []string, err error) { - var ( - request l_request.Request - ) - - // 系统编码 - systemCode := client.GetSysCode() - - // 检查系统编码是否配置 - if v, ok := s.cfg.PermissionConfig.SysPermission[systemCode]; !ok { - err = errors.SysErr("系统编码 %s 未配置", systemCode) - return - } else { - request.Url = v.PermissionURL - } - - request.Method = "GET" - request.Headers = map[string]string{ - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", - "Accept": "application/json, text/plain, */*", - "Authorization": "Bearer " + client.GetAuth(), - } - - // 发送请求 - res, err := request.Send() - if err != nil { - return - } - - // 检查响应状态码 - if res.StatusCode != http.StatusOK { - err = errors.SysErr("获取用户权限失败") - return - } - - type resp struct { - Codes []string `json:"codes"` - } - // 解析响应体 - var respBody resp - err = json.Unmarshal([]byte(res.Text), &respBody) - if err != nil { - return - } - - return respBody.Codes, nil -} diff --git a/internal/services/task.go b/internal/services/task.go index 380f027..3ace0d5 100644 --- a/internal/services/task.go +++ b/internal/services/task.go @@ -3,7 +3,6 @@ package services import ( "ai_scheduler/internal/biz" "ai_scheduler/internal/entitys" - "github.com/gofiber/fiber/v2" ) @@ -25,7 +24,15 @@ func (s *TaskService) Tasks(c *fiber.Ctx) error { return err } - result, err := s.taskBiz.TaskList(c.Context(), req) + auth := "" + if auths := c.GetReqHeaders()["Authorization"]; len(auths) > 0 { + auth = auths[0] + } + if auth == "" { + return fiber.ErrUnauthorized + } + + result, err := s.taskBiz.TaskList(c.Context(), req, auth) if err != nil { return err