diff --git a/config/config.yaml b/config/config.yaml index c66ba00..06ff171 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -53,7 +53,8 @@ tools: enabled: true DingTalkBot: enabled: true - + api_key: "dingwxaswioywe5dyw7u" + api_secret: "WZtr20zOMlQBEIkOdh2-K1no_spVWYNzD8LJZm1fPGuSSQQT1lu0iqTIqnJhCG0Q" default_prompt: img_recognize: diff --git a/internal/biz/session.go b/internal/biz/session.go index e1dd17e..19b3e8a 100644 --- a/internal/biz/session.go +++ b/internal/biz/session.go @@ -59,10 +59,11 @@ func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRe } else if !has { // 不存在,创建一个 session = model.AiSession{ - SysID: sysConfig.SysID, - SessionID: utils.UUID(), - UserID: req.UserId, - UserName: req.UserName, + SysID: sysConfig.SysID, + SessionID: utils.UUID(), + UserID: req.UserId, + UserName: req.UserName, + DingUserId: req.DingUserId, } err = s.sessionRepo.Create(&session) if err != nil { diff --git a/internal/config/config.go b/internal/config/config.go index f0c80da..d39d80a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -93,9 +93,10 @@ type ToolsConfig struct { // ToolConfig 单个工具配置 type ToolConfig struct { - Enabled bool `mapstructure:"enabled"` - BaseURL string `mapstructure:"base_url"` - APIKey string `mapstructure:"api_key"` + Enabled bool `mapstructure:"enabled"` + BaseURL string `mapstructure:"base_url"` + APIKey string `mapstructure:"api_key"` + APISecret string `mapstructure:"api_secret"` //附加地址 AddURL string `mapstructure:"add_url"` } diff --git a/internal/data/model/ai_session.gen.go b/internal/data/model/ai_session.gen.go index d872adb..7a2ce0a 100644 --- a/internal/data/model/ai_session.gen.go +++ b/internal/data/model/ai_session.gen.go @@ -20,8 +20,9 @@ type AiSession struct { UpdateAt *time.Time `gorm:"column:update_at;default:CURRENT_TIMESTAMP" json:"update_at"` Status int32 `gorm:"column:status;not null" json:"status"` DeleteAt *time.Time `gorm:"column:delete_at" json:"delete_at"` - UserID string `gorm:"column:user_id;comment:用户id" json:"user_id"` // 用户id - UserName string `gorm:"column:user_name;comment:用户名称" json:"user_name"` // 用户id + UserID string `gorm:"column:user_id;comment:用户id" json:"user_id"` // 用户id + UserName string `gorm:"column:user_name;comment:用户名称" json:"user_name"` // 用户名称 + DingUserId string `gorm:"column:ding_user_id;comment:钉钉用户id" json:"ding_user_id"` // 钉钉用户id } // TableName AiSession's table name diff --git a/internal/entitys/session.go b/internal/entitys/session.go index 74a517a..e02ddb9 100644 --- a/internal/entitys/session.go +++ b/internal/entitys/session.go @@ -1,9 +1,10 @@ package entitys type SessionInitRequest struct { - SysId string `json:"sys_id"` - UserId string `json:"user_id"` - UserName string `json:"user_name"` + SysId string `json:"sys_id"` + UserId string `json:"user_id"` + UserName string `json:"user_name"` + DingUserId string `json:"ding_user_id"` } type SessionInitResponse struct { diff --git a/internal/pkg/dingtalk/client.go b/internal/pkg/dingtalk/client.go new file mode 100644 index 0000000..bc3ebd0 --- /dev/null +++ b/internal/pkg/dingtalk/client.go @@ -0,0 +1,153 @@ +package dingtalk + +import ( + "ai_scheduler/internal/config" + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/url" +) + +// Client 使用官方接口获取 AccessToken,并通过 HTTP 直接调用 TopAPI +type Client struct { + cfg *config.Config +} + +// NewDingTalkClient 基于配置创建钉钉客户端(无 fastwego 依赖) +func NewDingTalkClient(cfg *config.Config) *Client { return &Client{cfg: cfg} } + +type UserDetail struct { + UserID string `json:"userid"` + Name string `json:"name"` + UnionID string `json:"unionid"` +} + +// getAccessToken 通过官方接口获取应用 AccessToken(简单实现,每次请求刷新,不做缓存) +func (c *Client) getAccessToken(ctx context.Context) (string, error) { + // 调用旧版 OAPI 获取 access_token,供 /topapi 使用 + appKey := c.cfg.Tools.DingTalkBot.APIKey + appSecret := c.cfg.Tools.DingTalkBot.APISecret + params := url.Values{} + params.Add("appkey", appKey) + params.Add("appsecret", appSecret) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://oapi.dingtalk.com/gettoken?"+params.Encode(), nil) + if err != nil { + return "", err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + b, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + var r struct { + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + } + if err := json.Unmarshal(b, &r); err != nil { + return "", err + } + if r.ErrCode != 0 || r.AccessToken == "" { + if r.ErrMsg == "" { + r.ErrMsg = "gettoken failed" + } + return "", errors.New(r.ErrMsg) + } + return r.AccessToken, nil +} + +// QueryUserDetails 按 userId 查询用户详情 +func (c *Client) QueryUserDetails(ctx context.Context, userId string) (*UserDetail, error) { + accessToken, err := c.getAccessToken(ctx) + if err != nil { + return nil, err + } + body := struct { + UserId string `json:"userid"` + Language string `json:"language,omitempty"` + }{UserId: userId, Language: "zh_CN"} + b, _ := json.Marshal(body) + // 直接调用 TopAPI(oapi.dingtalk.com),通过 query 参数携带 access_token + // 官方 SDK 无 topapi 封装,此处保留原接口以保证行为一致 + params := url.Values{} + params.Add("access_token", accessToken) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://oapi.dingtalk.com/topapi/v2/user/get?"+params.Encode(), bytes.NewReader(b)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + httpResp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer httpResp.Body.Close() + res, err := io.ReadAll(httpResp.Body) + if err != nil { + return nil, err + } + + var resp struct { + Code int `json:"errcode"` + Msg string `json:"errmsg"` + Result UserDetail `json:"result"` + } + if err := json.Unmarshal(res, &resp); err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, errors.New(resp.Msg) + } + + return &resp.Result, nil +} + +// QueryUserDetailsByMobile 按手机号查询用户详情 +func (c *Client) QueryUserDetailsByMobile(ctx context.Context, mobile string) (*UserDetail, error) { + accessToken, err := c.getAccessToken(ctx) + if err != nil { + return nil, err + } + body := struct { + Mobile string `json:"mobile"` + }{Mobile: mobile} + b, _ := json.Marshal(body) + // 直接调用 TopAPI(oapi.dingtalk.com),通过 query 参数携带 access_token + params := url.Values{} + params.Add("access_token", accessToken) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://oapi.dingtalk.com/topapi/v2/user/getbymobile?"+params.Encode(), bytes.NewReader(b)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + httpResp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer httpResp.Body.Close() + res, err := io.ReadAll(httpResp.Body) + if err != nil { + return nil, err + } + + var resp struct { + Code int `json:"errcode"` + Msg string `json:"errmsg"` + Result UserDetail `json:"result"` + } + if err := json.Unmarshal(res, &resp); err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, errors.New(resp.Msg) + } + + return &resp.Result, nil +} diff --git a/internal/pkg/provider_set.go b/internal/pkg/provider_set.go index e20b9c2..e7674b7 100644 --- a/internal/pkg/provider_set.go +++ b/internal/pkg/provider_set.go @@ -1,6 +1,7 @@ package pkg import ( + "ai_scheduler/internal/pkg/dingtalk" "ai_scheduler/internal/pkg/utils_langchain" "ai_scheduler/internal/pkg/utils_ollama" @@ -13,4 +14,5 @@ var ProviderSetClient = wire.NewSet( utils_langchain.NewUtilLangChain, utils_ollama.NewClient, NewSafeChannelPool, + dingtalk.NewDingTalkClient, ) diff --git a/internal/server/http.go b/internal/server/http.go index d4df3a1..4cdc393 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -11,22 +11,24 @@ import ( ) type HTTPServer struct { - app *fiber.App - service *services.ChatService - session *services.SessionService - gateway *gateway.Gateway + app *fiber.App + service *services.ChatService + session *services.SessionService + gateway *gateway.Gateway + callback *services.CallbackService } func NewHTTPServer( - service *services.ChatService, - session *services.SessionService, - task *services.TaskService, - gateway *gateway.Gateway, + service *services.ChatService, + session *services.SessionService, + task *services.TaskService, + gateway *gateway.Gateway, + callback *services.CallbackService, ) *fiber.App { - //构建 server - app := initRoute() - router.SetupRoutes(app, service, session, task, gateway) - return app + //构建 server + app := initRoute() + router.SetupRoutes(app, service, session, task, gateway, callback) + return app } func initRoute() *fiber.App { diff --git a/internal/server/router/router.go b/internal/server/router/router.go index 04fbc7c..a2bbd81 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -22,7 +22,7 @@ type RouterServer struct { } // SetupRoutes 设置路由 -func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService, gateway *gateway.Gateway) { +func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService, gateway *gateway.Gateway, callbackService *services.CallbackService) { app.Use(func(c *fiber.Ctx) error { // 设置 CORS 头 c.Set("Access-Control-Allow-Origin", "*") @@ -51,6 +51,8 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi r.Post("/session/init", sessionService.SessionInit) // 会话初始化,不存在则创建,存在则返回会话ID和默认条数会话历史 r.Post("/session/list", sessionService.SessionList) r.Post("/sys/tasks", task.Tasks) + // 回调 + r.Post("/callback", callbackService.Callback) //广播 r.Get("/broadcast", func(ctx *fiber.Ctx) error { action := ctx.Query("action") diff --git a/internal/services/callback.go b/internal/services/callback.go new file mode 100644 index 0000000..6450ebd --- /dev/null +++ b/internal/services/callback.go @@ -0,0 +1,164 @@ +package services + +import ( + "ai_scheduler/internal/config" + errorcode "ai_scheduler/internal/data/error" + "ai_scheduler/internal/gateway" + "ai_scheduler/internal/pkg/dingtalk" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/gofiber/fiber/v2" +) + +// CallbackService 统一回调入口 +type CallbackService struct { + cfg *config.Config + gateway *gateway.Gateway + taskMap map[string]string // task_id -> session_id + dingtalkClient *dingtalk.Client +} + +func NewCallbackService(cfg *config.Config, gateway *gateway.Gateway, dingtalkClient *dingtalk.Client) *CallbackService { + return &CallbackService{ + cfg: cfg, + gateway: gateway, + taskMap: map[string]string{}, + dingtalkClient: dingtalkClient, + } +} + +// Envelope 回调统一请求体 +type Envelope struct { + Action string `json:"action"` + TaskID string `json:"task_id"` + Data map[string]string `json:"data"` +} + +// SetTaskMapping 设置 task_id 到 session_id 的映射(内存版)。 +// 注意:生产环境建议使用 Redis/DB + TTL,确保幂等与过期清理。 +func (s *CallbackService) SetTaskMapping(taskID, sessionID string) { + if taskID == "" || sessionID == "" { + return + } + s.taskMap[taskID] = sessionID +} + +// GetSessionByTaskID 读取映射 +func (s *CallbackService) GetSessionByTaskID(taskID string) (string, bool) { + return "bf0c6873-1df2-4d46-aa3c-8f7456e7efca", true + v, ok := s.taskMap[taskID] + return v, ok +} + +// Callback 统一回调处理 +// 头部:X-Source-Key / X-Timestamp +func (s *CallbackService) Callback(c *fiber.Ctx) error { + // 读取头 + sourceKey := strings.TrimSpace(c.Get("X-Source-Key")) + ts := strings.TrimSpace(c.Get("X-Timestamp")) + + // 时间窗口(如果提供了 ts 则校验,否则跳过),窗口 5 分钟 + if ts != "" && !validateTimestamp(ts, 300*time.Minute) { + return errorcode.AuthNotFound + } + + // 解析 Envelope + var env Envelope + if err := json.Unmarshal(c.Body(), &env); err != nil { + return errorcode.ParamErr("invalid json: %v", err) + } + if env.Action == "" || env.TaskID == "" { + return errorcode.ParamErr("missing action/task_id") + } + if len(env.Data) == 0 { + return errorcode.ParamErr("missing data") + } + + switch sourceKey { + case "dingtalk": + return s.handleDingTalkCallback(c, env) + default: + return errorcode.AuthNotFound + } +} + +func validateTimestamp(ts string, window time.Duration) bool { + // 期望毫秒时间戳或秒级,简单容错 + // 尝试解析为整数 + var n int64 + for _, base := range []int64{1, 1000} { // 秒或毫秒 + if v, ok := parseInt64(ts); ok { + n = v + // 归一为毫秒 + if base == 1 && len(ts) <= 10 { + n = n * 1000 + } + now := time.Now().UnixMilli() + diff := now - n + if diff < 0 { + diff = -diff + } + if diff <= window.Milliseconds() { + return true + } + } + } + return false +} + +func parseInt64(s string) (int64, bool) { + var n int64 + for _, ch := range s { + if ch < '0' || ch > '9' { + return 0, false + } + n = n*10 + int64(ch-'0') + } + return n, true +} + +func (s *CallbackService) handleDingTalkCallback(c *fiber.Ctx, env Envelope) error { + switch env.Action { + // bug/优化完成回调 + case "bug_optimization_submit_done": + // 获取 session_id + sessionID, ok := s.GetSessionByTaskID(env.TaskID) + if !ok { + return errorcode.ParamErr("missing session_id for task_id: %s", env.TaskID) + } + + // 获取接收者姓名 + receivers := env.Data["receivers"] + if receivers == "" { + return errorcode.ParamErr("missing receivers") + } + var receiverIds []string + if err := json.Unmarshal([]byte(receivers), &receiverIds); err != nil { + return errorcode.ParamErr("invalid receivers: %v", err) + } + if len(receiverIds) == 0 { + return errorcode.ParamErr("empty receivers") + } + + aaa, _ := s.dingtalkClient.QueryUserDetailsByMobile(c.Context(), "13126622913") + + userDetails, err := s.dingtalkClient.QueryUserDetails(c.Context(), aaa.UserID) + // userDetails, err := s.dingtalkClient.QueryUserDetails(c.Context(), receiverIds[0]) + if err != nil { + return errorcode.ParamErr("query user details failed: %v", err) + } + if userDetails == nil { + return errorcode.ParamErr("user details is nil") + } + + msg := fmt.Sprintf(env.Data["msg"], userDetails.Name) + + s.gateway.SendToUid(sessionID, []byte(msg)) + return c.JSON(fiber.Map{"code": 0, "message": "ok"}) + default: + return errorcode.ParamErr("unknown action: %s", env.Action) + } +} diff --git a/internal/services/provider_set.go b/internal/services/provider_set.go index 70987fa..0e1284b 100644 --- a/internal/services/provider_set.go +++ b/internal/services/provider_set.go @@ -6,4 +6,4 @@ import ( "github.com/google/wire" ) -var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService) +var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService) diff --git a/internal/tools_bot/dtalk_bot.go b/internal/tools_bot/dtalk_bot.go index 99270b0..80755c8 100644 --- a/internal/tools_bot/dtalk_bot.go +++ b/internal/tools_bot/dtalk_bot.go @@ -49,6 +49,17 @@ func (w *BotTool) Execute(ctx context.Context, toolName string, requireData *ent return } +const ( + // 工单QA + BotBugOptimizationSubmitQA = "温子新" + BotBugOptimizationSubmitPM = "贺泽琨" +) + +// 现存问题: +// 1. 回调时 session 直接传入不安全 todo +// 2. 创建人无法指定[钉钉用户],影响后续状态变化时通知 +// 3. 回调接口,[接收人]、[文档地址不能]动态配置 +// 4. 测试环境与线上环境,使用的不是同一个钉钉主体 func (w *BotTool) BugOptimizationSubmit(ctx context.Context, requireData *entitys.RequireData) (err error) { // 获取用户信息 cond := builder.NewCond() @@ -85,14 +96,14 @@ func (w *BotTool) BugOptimizationSubmit(ctx context.Context, requireData *entity data := make(map[string]bool) if err = json.Unmarshal(res.Content, &data); err != nil { - return fmt.Errorf("解析商品数据失败:%w", err) + return fmt.Errorf("解析工单回调失败:%w", err) } if data["success"] { - entitys.ResLoading(requireData.Ch, requireData.Match.Index, "问题信息记录中...") + entitys.ResLoading(requireData.Ch, requireData.Match.Index, "问题内容记录中...") return } - entitys.ResJson(requireData.Ch, requireData.Match.Index, "bug问题请咨询 @温子新 ,优化建议请咨询 @贺泽琨 。") + entitys.ResJson(requireData.Ch, requireData.Match.Index, fmt.Sprintf("bug问题请咨询 @%s ,优化建议请咨询 @%s 。", BotBugOptimizationSubmitQA, BotBugOptimizationSubmitPM)) return }