diff --git a/internal/biz/router.go b/internal/biz/router.go index 32ef390..be1358d 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -38,6 +38,7 @@ type AiRouterBiz struct { utilAgent *utils_ollama.UtilOllama ollama *utils_ollama.Client channelPool *pkg.SafeChannelPool + rds *pkg.Rdb } // NewRouterService 创建路由服务 @@ -52,6 +53,7 @@ func NewAiRouterBiz( utilAgent *utils_ollama.UtilOllama, channelPool *pkg.SafeChannelPool, ollama *utils_ollama.Client, + ) *AiRouterBiz { return &AiRouterBiz{ //aiClient: aiClient, @@ -75,15 +77,15 @@ func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*ent func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) (err error) { - session := c.Headers("X-Session", "") + session := c.Query("x-session", "") if len(session) == 0 { return errors.SessionNotFound } - auth := c.Headers("X-Authorization", "") + auth := c.Query("x-authorization", "") if len(auth) == 0 { return errors.AuthNotFound } - key := c.Headers("X-App-Key", "") + key := c.Query("x-app-key", "") if len(key) == 0 { return errors.KeyNotFound } @@ -367,7 +369,7 @@ func (r *AiRouterBiz) handleKnowle(channel chan entitys.Response, c *websocket.C func (r *AiRouterBiz) handleApiTask(channels chan entitys.Response, c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) { var ( request l_request.Request - auth = c.Headers("X-Authorization", "") + auth = c.Query("x-authorization", "") requestParam map[string]interface{} ) err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam) diff --git a/internal/data/error/error_code.go b/internal/data/error/error_code.go index edd0a55..a1e1f4a 100644 --- a/internal/data/error/error_code.go +++ b/internal/data/error/error_code.go @@ -6,12 +6,12 @@ var ( NotFoundError = &BusinessErr{code: 404, message: "请求地址未找到"} SystemError = &BusinessErr{code: 405, message: "系统错误"} - SupplierNotFound = &BusinessErr{code: 406, message: "供应商不存在"} - SessionNotFound = &BusinessErr{code: 407, message: "未找到会话信息"} - AuthNotFound = &BusinessErr{code: 408, message: "身份验证失败"} - KeyNotFound = &BusinessErr{code: 409, message: "身份验证失败"} - SysNotFound = &BusinessErr{code: 410, message: "未找到系统信息"} - InvalidParam = &BusinessErr{code: InvalidParamCode, message: "无效参数"} + ClientNotFound = &BusinessErr{code: 406, message: "未找到client_id"} + SessionNotFound = &BusinessErr{code: 407, message: "未找到会话信息"} + AuthNotFound = &BusinessErr{code: 408, message: "身份验证失败"} + KeyNotFound = &BusinessErr{code: 409, message: "身份验证失败"} + SysNotFound = &BusinessErr{code: 410, message: "未找到系统信息"} + InvalidParam = &BusinessErr{code: InvalidParamCode, message: "无效参数"} ) const ( diff --git a/internal/entitys/response.go b/internal/entitys/response.go index c4e65ee..0031d5e 100644 --- a/internal/entitys/response.go +++ b/internal/entitys/response.go @@ -18,6 +18,7 @@ const ( ResponseFile ResponseType = "file" ResponseErr ResponseType = "error" ResponseLog ResponseType = "log" + ResponseAuth ResponseType = "auth" ) type ResponseData struct { diff --git a/internal/entitys/types.go b/internal/entitys/types.go index ad12791..51030f2 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -20,6 +20,12 @@ type ChatRequestMeta struct { Authorization string `json:"authorization"` } +type FirstSockRequest struct { + Authorization string `json:"authorization"` + SessionID string `json:"session_id"` + AppKey string `json:"app_key"` +} + type ChatSockRequest struct { Text string `json:"text" binding:"required"` Img string `json:"img" binding:"required"` diff --git a/internal/services/chat.go b/internal/services/chat.go index e29d82b..36c89e0 100644 --- a/internal/services/chat.go +++ b/internal/services/chat.go @@ -82,12 +82,7 @@ func (h *ChatService) Chat(c *websocket.Conn) { log.Println("读取错误:", err) break } - //简单协议:bind: - if c.Headers("Sec-Websocket-Protocol") == "bind" && c.Headers("X-Session") != "" { - uid := c.Headers("X-Session") - _ = h.Gw.BindUid(clientID, uid) - log.Printf("bind %s -> uid:%s\n", clientID, uid) - } + msg, chatType := h.handleMessageToString(c, messageType, message) if chatType == constants.ConnStatusClosed { break @@ -102,6 +97,14 @@ func (h *ChatService) Chat(c *websocket.Conn) { log.Println("JSON parse error:", err) continue } + + //简单协议:bind: + if c.Headers("Sec-Websocket-Protocol") == "bind" && req.SessionID != "" { + uid := c.Headers("X-Session") + _ = h.Gw.BindUid(clientID, req.SessionID) + log.Printf("bind %s -> uid:%s\n", clientID, uid) + } + err = h.routerBiz.RouteWithSocket(c, &req) if err != nil { log.Println("处理失败:", err) diff --git a/internal/tools/zltx_order_detail.go b/internal/tools/zltx_order_detail.go index 79f56f7..cc55115 100644 --- a/internal/tools/zltx_order_detail.go +++ b/internal/tools/zltx_order_detail.go @@ -100,7 +100,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(ch chan entitys.Response, c *we //查询订单详情 var auth string if c != nil { - auth = c.Headers("X-Authorization", "") + auth = c.Query("x-authorization", "") } if len(auth) == 0 { auth = w.config.APIKey diff --git a/internal/tools/zltx_order_direct_log.go b/internal/tools/zltx_order_direct_log.go index c6a0a95..ce193f0 100644 --- a/internal/tools/zltx_order_direct_log.go +++ b/internal/tools/zltx_order_direct_log.go @@ -82,7 +82,7 @@ func (t *ZltxOrderLogTool) getZltxOrderLog(channel chan entitys.Response, c *web //查询订单详情 var auth string if c != nil { - auth = c.Headers("X-Authorization", "") + auth = c.Query("x-authorization", "") } if len(auth) == 0 { auth = t.config.APIKey diff --git a/internal/tools/zltx_product.go b/internal/tools/zltx_product.go index 71a7daf..915e651 100644 --- a/internal/tools/zltx_product.go +++ b/internal/tools/zltx_product.go @@ -135,7 +135,7 @@ type ZltxProductData struct { func (z ZltxProductTool) getZltxProduct(channel chan entitys.Response, c *websocket.Conn, id string, name string) error { var auth string if c != nil { - auth = c.Headers("X-Authorization", "") + auth = c.Query("x-authorization", "") } if len(auth) == 0 { auth = z.config.APIKey diff --git a/internal/tools/zltx_statistics.go b/internal/tools/zltx_statistics.go index 64bc80b..83c2a40 100644 --- a/internal/tools/zltx_statistics.go +++ b/internal/tools/zltx_statistics.go @@ -78,7 +78,7 @@ func (z ZltxOrderStatisticsTool) getZltxOrderStatistics(channel chan entitys.Res //查询订单详情 var auth string if c != nil { - auth = c.Headers("X-Authorization", "") + auth = c.Query("x-authorization", "") } if len(auth) == 0 { auth = z.config.APIKey