diff --git a/internal/biz/handle/qywx/auth.go b/internal/biz/handle/qywx/auth.go new file mode 100644 index 0000000..c0c9737 --- /dev/null +++ b/internal/biz/handle/qywx/auth.go @@ -0,0 +1,82 @@ +package qywx + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/utils" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "time" + + "github.com/redis/go-redis/v9" +) + +type Auth struct { + redis *redis.Client + cfg *config.Config +} + +func NewAuth(cfg *config.Config, redis *utils.Rdb) *Auth { + return &Auth{ + redis: redis.Rdb, + cfg: cfg, + } +} + +func (a *Auth) GetAccessToken(ctx context.Context, corpid string, corpsecret string) (authInfo *AuthInfo, err error) { + if corpid == "" { + return nil, errors.New("corpid is empty") + } + accessToken := a.redis.Get(ctx, a.getKey(corpsecret)).Val() + var expire time.Duration + if accessToken == "" { + authRes, _err := a.getNewAccessToken(ctx, corpid, corpsecret) + if _err != nil { + return nil, _err + } + expire = time.Duration(authRes.ExpiresIn-60) * time.Second + err = a.redis.SetEx(ctx, a.getKey(corpsecret), authRes.AccessToken, expire).Err() + if err != nil { + return + } + accessToken = authRes.AccessToken + } else { + expire, _ = a.redis.TTL(ctx, a.getKey(corpsecret)).Result() + } + return &AuthInfo{ + Corpid: corpid, + Corpsecret: corpsecret, + AccessToken: accessToken, + Expire: expire, + }, nil +} + +func (a *Auth) getKey(corpsecret string) string { + return a.cfg.Redis.Key + ":" + constants.QywxAuthBaseKeyPrefix + ":" + corpsecret +} + +func (a *Auth) getNewAccessToken(ctx context.Context, corpid string, corpsecret string) (auth AuthRes, err error) { + if corpid == "" || corpsecret == "" { + err = errors.New("corpid or corpsecret is empty") + return + } + + req := l_request.Request{ + Method: http.MethodGet, + Url: "https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid=" + corpid + "&corpsecret=" + corpsecret, + } + res, err := req.Send() + if err != nil { + return + } + err = json.Unmarshal(res.Content, &auth) + if auth.Errcode != 0 { + err = fmt.Errorf("请求失败:%s", auth.Errmsg) + return + } + return +} diff --git a/internal/biz/handle/qywx/group.go b/internal/biz/handle/qywx/group.go new file mode 100644 index 0000000..21fe52f --- /dev/null +++ b/internal/biz/handle/qywx/group.go @@ -0,0 +1,15 @@ +package qywx + +import "ai_scheduler/internal/data/impl" + +type Group struct { + groupImpl *impl.BotGroupQywxImpl + auth *Auth +} + +func NewGroup(groupImpl *impl.BotGroupQywxImpl, auth *Auth) *Group { + return &Group{ + groupImpl: groupImpl, + auth: auth, + } +} diff --git a/internal/biz/handle/qywx/types.go b/internal/biz/handle/qywx/types.go new file mode 100644 index 0000000..09fdf67 --- /dev/null +++ b/internal/biz/handle/qywx/types.go @@ -0,0 +1,17 @@ +package qywx + +import "time" + +type AuthRes struct { + Errcode int `json:"errcode"` + Errmsg string `json:"errmsg"` + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` +} + +type AuthInfo struct { + Corpid string `json:"corpid"` + Corpsecret string `json:"corpsecret"` + AccessToken string `json:"accessToken"` + Expire time.Duration `json:"expireIn"` +} diff --git a/internal/data/constants/bot.go b/internal/data/constants/bot.go index d0ca85c..8516626 100644 --- a/internal/data/constants/bot.go +++ b/internal/data/constants/bot.go @@ -36,6 +36,8 @@ const DingTalkAuthBaseKeyPrefix = "dingTalk_auth" const DingTalkAuthBaseKeyBotPrefix = "dingTalk_auth_bot" +const QywxAuthBaseKeyPrefix = "qywx_auth" + // PermissionType 工具使用权限 type PermissionType int32 diff --git a/internal/data/impl/bot_group_qywx.go b/internal/data/impl/bot_group_qywx.go new file mode 100644 index 0000000..87b5ce5 --- /dev/null +++ b/internal/data/impl/bot_group_qywx.go @@ -0,0 +1,17 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" +) + +type BotGroupQywxImpl struct { + dataTemp.DataTemp +} + +func NewBotGroupQywxImpl(db *utils.Db) *BotGroupImpl { + return &BotGroupImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotGroupQywx)), + } +} diff --git a/internal/services/callback.go b/internal/services/callback.go index c68e1c3..b45ef41 100644 --- a/internal/services/callback.go +++ b/internal/services/callback.go @@ -103,40 +103,36 @@ func (s *CallbackService) Callback(c *fiber.Ctx) error { } } -// func validateTimestamp(ts string, window time.Duration) bool { -// // 期望毫秒时间戳或秒级,简单容错 -// // 尝试解析为整数 -// var n int64 -// for _, base := range []int64{1, 1000} { // 秒或毫秒 -// if v, ok := parseInt64(ts); ok { -// n = v -// // 归一为毫秒 -// if base == 1 && len(ts) <= 10 { -// n = n * 1000 -// } -// now := time.Now().UnixMilli() -// diff := now - n -// if diff < 0 { -// diff = -diff -// } -// if diff <= window.Milliseconds() { -// return true -// } -// } -// } -// return false -// } +func (s *CallbackService) CallbackQr(c *fiber.Ctx) error { + // 读取头 + sourceKey := strings.TrimSpace(c.Get("X-Source-Key")) + ts := strings.TrimSpace(c.Get("X-Timestamp")) -// func parseInt64(s string) (int64, bool) { -// var n int64 -// for _, ch := range s { -// if ch < '0' || ch > '9' { -// return 0, false -// } -// n = n*10 + int64(ch-'0') -// } -// return n, true -// } + // 时间窗口(如果提供了 ts 则校验,否则跳过),窗口 5 分钟 + // if ts != "" && !validateTimestamp(ts, 5*time.Minute) { + if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) { + return errorcode.AuthNotFound + } + + // 解析 Envelope + var env Envelope + if err := json.Unmarshal(c.Body(), &env); err != nil { + return errorcode.ParamErrf("invalid json: %v", err) + } + if env.Action == "" || env.TaskID == "" { + return errorcode.ParamErrf("missing action/task_id") + } + if env.Data == nil { + return errorcode.ParamErrf("missing data") + } + + switch sourceKey { + case "dingtalk": + return s.handleDingTalkCallback(c, env) + default: + return errorcode.AuthNotFound + } +} func (s *CallbackService) handleDingTalkCallback(c *fiber.Ctx, env Envelope) error { // 校验taskId