diff --git a/internal/pkg/lsxd/login.go b/internal/pkg/lsxd/login.go index e2c5509..04dc821 100644 --- a/internal/pkg/lsxd/login.go +++ b/internal/pkg/lsxd/login.go @@ -3,10 +3,16 @@ package lsxd import ( "ai_scheduler/internal/config" "ai_scheduler/internal/data/constants" - "ai_scheduler/internal/pkg/l_request" "ai_scheduler/utils" + "bytes" "context" "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "sync" + "time" "github.com/go-kratos/kratos/v2/log" "github.com/redis/go-redis/v9" @@ -15,6 +21,7 @@ import ( type Login struct { config *config.Config redisCli *redis.Client + mu sync.Mutex } func NewLogin(config *config.Config, rdb *utils.Rdb) *Login { @@ -25,96 +32,183 @@ func NewLogin(config *config.Config, rdb *utils.Rdb) *Login { } func (l *Login) GetToken() string { - ctx := context.Background() - // 1.取缓存 - token, err := l.redisCli.Get(ctx, constants.CACHE_KEY_LSXD_TOKEN).Result() + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + token, err := l.getCachedToken(ctx) if err != nil { log.Errorf("lsxd get token from redis failed, err: %v", err) } - - // 2.缓存存在 - if token != "" { - // 3.缓存有效直接输出 - if l.checkTokenValid(token) { - return token - } + if token != "" && l.checkTokenValid(ctx, token) { + return token } - // 4.缓存不存在或缓存无效,调用登录接口获取token - token, err = l.login() + l.mu.Lock() + defer l.mu.Unlock() + + token, err = l.getCachedToken(ctx) + if err != nil { + log.Errorf("lsxd get token from redis failed, err: %v", err) + } + if token != "" && l.checkTokenValid(ctx, token) { + return token + } + + token, err = l.login(ctx) if err != nil { log.Errorf("lsxd login failed, err: %v", err) return "" } - // 5.缓存token - l.redisCli.Set(ctx, constants.CACHE_KEY_LSXD_TOKEN, token, constants.EXPIRE_LSXD_TOKEN) + if token == "" { + log.Errorf("lsxd login failed, token is empty") + return "" + } + + if err := l.cacheToken(ctx, token); err != nil { + log.Errorf("lsxd cache token failed, err: %v", err) + } return token } // 校验token是否有效 -func (l *Login) checkTokenValid(token string) bool { +func (l *Login) checkTokenValid(ctx context.Context, token string) bool { // 欢迎页校验token有效 checkTokenURL := l.config.LSXD.CheckTokenURL - // 调用欢迎页校验token有效 - r := l_request.Request{ - Url: checkTokenURL, - Headers: map[string]string{ - "Content-Type": "application/json", - "Authorization": token, - }, - Method: "GET", + if checkTokenURL == "" { + return token != "" } - res, err := r.Send() + + status, err := l.doRequest(ctx, http.MethodGet, checkTokenURL, token, nil) if err != nil { log.Errorf("lsxd check token valid failed, err: %v", err) + return true + } + if status == http.StatusOK { + return true + } + if status == http.StatusUnauthorized || status == http.StatusForbidden { return false } - if res.StatusCode != 200 { - log.Errorf("lsxd check token valid failed, status code: %d", res.StatusCode) - return false - } - - log.Info("lsxd check token valid success") - return true } // 调用登录接口获取token -func (l *Login) login() (string, error) { +func (l *Login) login(ctx context.Context) (string, error) { // 1.获取配置 loginURL := l.config.LSXD.LoginURL phone := l.config.LSXD.Phone password := l.config.LSXD.Password // 2.调用登录接口获取token - r := l_request.Request{ - Url: loginURL, - Headers: map[string]string{ - "Content-Type": "application/json", - }, - Method: "POST", - Json: map[string]any{ - "phone": phone, - "password": password, - "code": "123456", - }, + if loginURL == "" { + return "", errors.New("login url is empty") } - res, err := r.Send() + if phone == "" || password == "" { + return "", errors.New("phone or password is empty") + } + + reqBody := map[string]any{ + "phone": phone, + "password": password, + "code": "123456", + } + bodyBytes, err := json.Marshal(reqBody) if err != nil { return "", err } - // 3.解析token - var resp struct { - Token string `json:"accessToken"` - } - err = json.Unmarshal(res.Content, &resp) + status, respBody, err := l.doRequestWithBody(ctx, http.MethodPost, loginURL, "", "application/json", bodyBytes) if err != nil { return "", err } - token := resp.Token + if status != http.StatusOK { + return "", fmt.Errorf("login status code: %d", status) + } + + type loginResp struct { + Code int `json:"code"` + Msg string `json:"msg"` + Message string `json:"message"` + AccessToken string `json:"accessToken"` + Data struct { + AccessToken string `json:"accessToken"` + Token string `json:"token"` + } `json:"data"` + } + + var resp loginResp + if err := json.Unmarshal(respBody, &resp); err != nil { + return "", err + } + + token := resp.AccessToken + if token == "" { + token = resp.Data.AccessToken + } + if token == "" { + token = resp.Data.Token + } + + if token == "" { + return "", errors.New("token is empty") + } - // 4.返回token return token, nil } + +func (l *Login) getCachedToken(ctx context.Context) (string, error) { + token, err := l.redisCli.Get(ctx, constants.CACHE_KEY_LSXD_TOKEN).Result() + if err == nil { + return token, nil + } + if errors.Is(err, redis.Nil) { + return "", nil + } + return "", err +} + +func (l *Login) cacheToken(ctx context.Context, token string) error { + if token == "" { + return errors.New("token is empty") + } + return l.redisCli.Set(ctx, constants.CACHE_KEY_LSXD_TOKEN, token, constants.EXPIRE_LSXD_TOKEN).Err() +} + +func (l *Login) doRequest(ctx context.Context, method string, url string, authorization string, body []byte) (int, error) { + status, _, err := l.doRequestWithBody(ctx, method, url, authorization, "", body) + return status, err +} + +func (l *Login) doRequestWithBody(ctx context.Context, method string, url string, authorization string, contentType string, body []byte) (int, []byte, error) { + reqCtx, cancel := context.WithTimeout(ctx, 6*time.Second) + defer cancel() + + var reader io.Reader + if len(body) > 0 { + reader = bytes.NewReader(body) + } + + req, err := http.NewRequestWithContext(reqCtx, method, url, reader) + if err != nil { + return 0, nil, err + } + if contentType != "" { + req.Header.Set("Content-Type", contentType) + } + if authorization != "" { + req.Header.Set("Authorization", authorization) + } + + resp, err := (&http.Client{}).Do(req) + if err != nil { + return 0, nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return resp.StatusCode, nil, err + } + return resp.StatusCode, respBody, nil +}