l-dingtalk-stream-sdk-go/client/client.go

503 lines
13 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package client
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/card"
"io"
"net/http"
"net/url"
"sync"
"time"
"github.com/gorilla/websocket"
"gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/chatbot"
"gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/handler"
"gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/logger"
"gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/payload"
"gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/plugin"
"gitea.cdlsxd.cn/self-tools/l-dingtalk-stream-sdk-go/utils"
)
/**
* @Author linya.jj
* @Date 2023/3/22 14:23
*/
// StreamClient holds the configuration and connection state of one stream
// (websocket) client.
type StreamClient struct {
// AppCredential supplies the ClientId and ClientSecret sent when requesting a connection endpoint.
AppCredential *AppCredentialConfig
// UserAgent supplies the user agent string sent when requesting a connection endpoint.
UserAgent *UserAgentConfig
// AutoReconnect makes processLoop start a reconnect when it exits.
AutoReconnect bool
// subscriptions maps subscription type -> topic -> frame handler.
subscriptions map[string]map[string]handler.IFrameHandler
// conn is the current websocket connection; nil while disconnected.
conn *websocket.Conn
// sessionId is set to the endpoint ticket by Start and cleared by Close.
sessionId string
// mutex is held by Start and Close while they change conn and sessionId.
mutex sync.Mutex
// extras is forwarded as Extras in the connection endpoint request.
extras map[string]string
// openApiHost overrides utils.DefaultOpenApiHost when non-empty.
openApiHost string
// proxy, when non-empty, is parsed as a URL and used as the proxy for both
// the endpoint HTTP request and the websocket dial.
proxy string
// keepAliveIdle is how long processLoop waits without an incoming text
// message before sending a websocket ping.
keepAliveIdle time.Duration
}
// NewStreamClient builds a StreamClient with a 120 second keep-alive idle
// period, applies the built-in options (system "disconnect" and "ping"
// subscriptions, the SDK user agent, auto reconnect enabled) and then applies
// the caller's options in order. Nil entries in options are skipped.
func NewStreamClient(options ...ClientOption) *StreamClient {
	cli := new(StreamClient)
	cli.keepAliveIdle = 120 * time.Second

	// Built-in options run first so that caller-supplied options are applied after them.
	builtins := [...]ClientOption{
		WithSubscription(utils.SubscriptionTypeKSystem, "disconnect", cli.OnDisconnect),
		WithSubscription(utils.SubscriptionTypeKSystem, "ping", cli.OnPing),
		WithUserAgent(NewDingtalkGoSDKUserAgent()),
		WithAutoReconnect(true),
	}
	for i := range builtins {
		builtins[i](cli)
	}

	for _, apply := range options {
		if apply != nil {
			apply(cli)
		}
	}
	return cli
}
// Start opens the stream connection if the client is not already connected.
//
// It obtains an endpoint and ticket from GetConnectionEndpoint, dials the
// websocket endpoint (through cli.proxy when one is configured), stores the
// connection and the ticket as the session id, and launches processLoop in a
// new goroutine. ctx bounds both the endpoint request and the websocket
// handshake. It returns nil without doing anything when a connection already
// exists, and the underlying error when the endpoint request, proxy parsing
// or websocket handshake fails.
func (cli *StreamClient) Start(ctx context.Context) error {
	if cli.conn != nil {
		return nil
	}

	cli.mutex.Lock()
	defer cli.mutex.Unlock()

	// Re-check under the lock: another caller may have connected while this one waited.
	if cli.conn != nil {
		return nil
	}

	endpoint, err := cli.GetConnectionEndpoint(ctx)
	if err != nil {
		return err
	}

	wssUrl := fmt.Sprintf("%s?ticket=%s", endpoint.Endpoint, endpoint.Ticket)
	header := make(http.Header)

	dialer := websocket.DefaultDialer
	if len(cli.proxy) != 0 {
		proxyURL, err := url.Parse(cli.proxy)
		if err != nil {
			return err
		}
		dialer = &websocket.Dialer{
			Proxy: http.ProxyURL(proxyURL),
			// Same handshake bound as websocket.DefaultDialer, so a stalled
			// proxy cannot block Start (and the mutex it holds) forever.
			HandshakeTimeout: 45 * time.Second,
		}
	}

	// DialContext lets the caller's ctx cancel the handshake as well.
	conn, resp, err := dialer.DialContext(ctx, wssUrl, header)
	if err != nil {
		// Connection establishment failed. On a rejected handshake the dialer
		// returns the HTTP response together with the error, so the status is
		// only observable here; the error itself is returned unchanged.
		if resp != nil {
			logger.GetLogger().Errorf("connect handshake failed, status=[%s]", resp.Status)
		}
		return err
	}

	cli.conn = conn
	cli.sessionId = endpoint.Ticket
	logger.GetLogger().Infof("connect success, sessionId=[%s]", cli.sessionId)

	go cli.processLoop()
	return nil
}
// processLoop services the current connection until it fails.
//
// A reader goroutine pulls messages off the websocket and forwards text
// messages to this loop, which hands each one to processDataFrame in its own
// goroutine. When no text message arrives within cli.keepAliveIdle a ping is
// written and a pong is awaited for 5 seconds; a missing pong, a read error or
// a ping write error ends the loop. On exit (including after a panic) a
// reconnect is started when cli.AutoReconnect is set.
func (cli *StreamClient) processLoop() {
	defer func() {
		if err := recover(); err != nil {
			logger.GetLogger().Errorf("connection process panic due to unknown reason, error=[%s]", err)
		}
		if cli.AutoReconnect {
			go cli.reconnect()
		}
	}()

	// Work on a snapshot: Close sets cli.conn to nil, and the goroutines below
	// must keep using the connection this loop was started for.
	conn := cli.conn
	if conn == nil {
		logger.GetLogger().Errorf("connection process connect nil, maybe disconnected.")
		return
	}

	readChan := make(chan []byte)
	// Buffered so the pong handler never blocks the reader goroutine.
	pongChan := make(chan struct{}, 1)
	closeChan := make(chan struct{})
	// done is closed when this loop exits. The data channels are never closed;
	// senders select on done instead, so a late send cannot panic.
	done := make(chan struct{})
	defer close(done)

	conn.SetPongHandler(func(appData string) error {
		select {
		case pongChan <- struct{}{}:
		default:
			// A pong is already pending; one is enough.
		}
		return nil
	})

	// Reader goroutine: forwards text messages and reports read failures.
	go func() {
		for {
			messageType, message, err := conn.ReadMessage()
			if err != nil {
				logger.GetLogger().Errorf("connection process read message error: messageType=[%d] message=[%s] error=[%s]", messageType, string(message), err)
				select {
				case closeChan <- struct{}{}:
				case <-done:
				}
				return
			}
			if messageType != websocket.TextMessage {
				continue
			}
			select {
			case readChan <- message:
			case <-done:
				return
			}
		}
	}()

	// Event loop.
	for {
		timer := time.NewTimer(cli.keepAliveIdle)
		select {
		case msg := <-readChan:
			timer.Stop()
			go cli.processDataFrame(msg)
		case <-timer.C:
			// Discard a pong left over from earlier so it cannot satisfy this ping.
			select {
			case <-pongChan:
			default:
			}
			if e := conn.WriteMessage(websocket.PingMessage, nil); e != nil {
				logger.GetLogger().Errorf("connection write ping message error: error=[%s]", e)
				return
			}
			go func() {
				wait := time.NewTimer(5 * time.Second)
				defer wait.Stop()
				select {
				case <-pongChan:
				case <-done:
				case <-wait.C:
					logger.GetLogger().Errorf("ping time out, connection is closing")
					select {
					case closeChan <- struct{}{}:
					case <-done:
					}
				}
			}()
		case <-closeChan:
			timer.Stop()
			return
		}
	}
}
// processDataFrame decodes one raw frame, dispatches it to the handler
// registered for its type and topic, and sends the resulting response back
// over the connection. A frame without a registered handler is answered with
// the handler-not-found status; a handler that returns neither a response nor
// an error is answered with the success response. Panics are recovered and
// logged.
func (cli *StreamClient) processDataFrame(rawData []byte) {
	defer func() {
		if r := recover(); r != nil {
			logger.GetLogger().Errorf("connection processDataFrame panic, error=[%s]", r)
		}
	}()

	frame, decodeErr := payload.DecodeDataFrame(rawData)
	if decodeErr != nil {
		logger.GetLogger().Errorf("connection process decode data frame error: length=[%d] error=[%s]", len(rawData), decodeErr)
		return
	}
	if frame == nil || frame.Headers == nil {
		logger.GetLogger().Errorf("connection processDataFrame dataFrame nil.")
		return
	}

	var ack *payload.DataFrameResponse
	if frameHandler, lookupErr := cli.GetHandler(frame.Type, frame.GetTopic()); lookupErr == nil && frameHandler != nil {
		var handleErr error
		ack, handleErr = frameHandler(context.Background(), frame)
		if ack == nil && handleErr != nil {
			ack = payload.NewErrorDataFrameResponse(handleErr)
		}
	} else {
		// No handler registered: reply with the not-found status.
		ack = payload.NewDataFrameResponse(payload.DataFrameResponseStatusCodeKHandlerNotFound)
	}
	if ack == nil {
		ack = payload.NewSuccessDataFrameResponse()
	}

	// Fill in the headers the handler left empty.
	if ack.GetHeader(payload.DataFrameHeaderKMessageId) == "" {
		ack.SetHeader(payload.DataFrameHeaderKMessageId, frame.GetMessageId())
	}
	if ack.GetHeader(payload.DataFrameHeaderKContentType) == "" {
		ack.SetHeader(payload.DataFrameHeaderKContentType, payload.DataFrameContentTypeKJson)
	}

	sendErr := cli.SendDataFrameResponse(context.Background(), ack)

	wire, _ := json.Marshal(ack)
	logger.GetLogger().Debugf("[wire] [websocket] local => remote:\n%s", string(wire))

	if sendErr != nil {
		logger.GetLogger().Errorf("connection processDataFrame send response error: error=[%s]", sendErr)
	}
}
// Close closes the current websocket connection, if any, and clears the
// connection and session id. A close error is logged, not returned.
func (cli *StreamClient) Close() {
	if cli.conn == nil {
		return
	}

	cli.mutex.Lock()
	defer cli.mutex.Unlock()

	// Check again now that the lock is held.
	if active := cli.conn; active != nil {
		if closeErr := active.Close(); closeErr != nil {
			logger.GetLogger().Errorf("StreamClient close. error=[%s]", closeErr)
		}
		cli.conn = nil
		cli.sessionId = ""
	}
}
// reconnect closes the current connection and calls Start repeatedly, pausing
// 3 seconds after each failure, until Start succeeds. Panics are recovered
// and logged.
func (cli *StreamClient) reconnect() {
	defer func() {
		if r := recover(); r != nil {
			logger.GetLogger().Errorf("reconect panic due to unknown reason. error=[%s]", r)
		}
	}()

	cli.Close()

	for startErr := cli.Start(context.Background()); startErr != nil; startErr = cli.Start(context.Background()) {
		logger.GetLogger().Errorf("StreamClient reconnect error. error=[%s]", startErr)
		time.Sleep(3 * time.Second)
	}
	logger.GetLogger().Infof("StreamClient reconnect success")
}
// GetHandler returns the frame handler registered for the given subscription
// type and topic, or an error when no non-nil handler is registered for that
// pair.
func (cli *StreamClient) GetHandler(stype, stopic string) (handler.IFrameHandler, error) {
	// Indexing a missing (nil) inner map yields the zero value, so one lookup covers both levels.
	if registered := cli.subscriptions[stype][stopic]; registered != nil {
		return registered, nil
	}
	return nil, errors.New("HandlerNotRegistedForTypeTopic_" + stype + "_" + stopic)
}
// CheckConfigValid validates the client configuration: the app credential and
// user agent must pass their own Valid checks, subscriptions must be non-nil,
// every subscription type must appear in utils.SubscriptionTypeSet and have at
// least one topic, and every topic must have a non-nil handler. It returns the
// first problem found, or nil.
func (cli *StreamClient) CheckConfigValid() error {
	if credErr := cli.AppCredential.Valid(); credErr != nil {
		return credErr
	}
	if agentErr := cli.UserAgent.Valid(); agentErr != nil {
		return agentErr
	}
	if cli.subscriptions == nil {
		return errors.New("subscriptionsNil")
	}

	for ttype, topics := range cli.subscriptions {
		_, known := utils.SubscriptionTypeSet[ttype]
		switch {
		case !known:
			return errors.New("UnKnownSubscriptionType_" + ttype)
		case len(topics) == 0:
			return errors.New("NoHandlersRegistedForType_" + ttype)
		}
		for ttopic, registered := range topics {
			if registered == nil {
				return errors.New("HandlerNilForTypeTopic_" + ttype + "_" + ttopic)
			}
		}
	}
	return nil
}
// GetConnectionEndpoint asks the open API for a websocket endpoint and ticket.
//
// It validates the configuration with CheckConfigValid, POSTs the credentials,
// user agent, extras, local LAN IP (when it can be determined) and the list of
// subscribed type/topic pairs as JSON to the configured host (or
// utils.DefaultOpenApiHost), and decodes and validates the JSON response.
// The request honours ctx and a 5 second client timeout, and goes through
// cli.proxy when one is configured. A non-200 status is turned into an error
// by utils.ErrorFromHttpResponseBody.
func (cli *StreamClient) GetConnectionEndpoint(ctx context.Context) (*payload.ConnectionEndpointResponse, error) {
	if err := cli.CheckConfigValid(); err != nil {
		return nil, err
	}

	requestModel := payload.ConnectionEndpointRequest{
		ClientId:      cli.AppCredential.ClientId,
		ClientSecret:  cli.AppCredential.ClientSecret,
		UserAgent:     cli.UserAgent.UserAgent,
		Subscriptions: make([]*payload.SubscriptionModel, 0),
		Extras:        cli.extras,
	}

	// The local IP is best-effort; the request is sent without it on failure.
	if localIp, err := utils.GetFirstLanIP(); err == nil {
		requestModel.LocalIP = localIp
	}

	for ttype, subs := range cli.subscriptions {
		for ttopic := range subs {
			requestModel.Subscriptions = append(requestModel.Subscriptions, &payload.SubscriptionModel{
				Type:  ttype,
				Topic: ttopic,
			})
		}
	}

	requestJsonBody, err := json.Marshal(requestModel)
	if err != nil {
		return nil, err
	}

	targetHost := cli.openApiHost
	if len(targetHost) == 0 {
		targetHost = utils.DefaultOpenApiHost
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetHost+utils.GetConnectionEndpointAPIUrl, bytes.NewReader(requestJsonBody))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	var transport http.RoundTripper
	if len(cli.proxy) == 0 {
		transport = http.DefaultTransport
	} else {
		proxyURL, err := url.Parse(cli.proxy)
		if err != nil {
			return nil, err
		}
		transport = &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
			// This transport is created per call and then dropped, so do not
			// leave an idle keep-alive connection behind in its pool.
			DisableKeepAlives: true,
		}
	}

	httpClient := &http.Client{
		Transport: transport,
		Timeout:   5 * time.Second, // covers connection time, any redirects, and reading the response body
	}

	logger.GetLogger().Debugf("[wire] [http] local => remote:\n%s %s %s\nHost: %s\n%s\n\n%s",
		req.Method, req.URL.RequestURI(), req.Proto, req.Host,
		utils.DumpHeaders(req.Header), requestJsonBody)

	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	// Registered before the status check so the body is closed on every path.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, utils.ErrorFromHttpResponseBody(resp)
	}

	responseJsonBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	logger.GetLogger().Debugf("[wire] [http] remote => localhost:\n%s %s\n%s\n\n%s",
		resp.Proto, resp.Status,
		utils.DumpHeaders(resp.Header), responseJsonBody)

	endpoint := &payload.ConnectionEndpointResponse{}
	if err := json.Unmarshal(responseJsonBody, endpoint); err != nil {
		return nil, err
	}
	if err := endpoint.Valid(); err != nil {
		return nil, err
	}
	return endpoint, nil
}
// OnDisconnect is the frame handler NewStreamClient passes for the system
// "disconnect" topic. It logs the event and closes the connection; it always
// returns a nil response and a nil error.
func (cli *StreamClient) OnDisconnect(ctx context.Context, df *payload.DataFrame) (*payload.DataFrameResponse, error) {
	log := logger.GetLogger()
	log.Debugf("StreamClient.OnDisconnect")

	cli.Close()

	return nil, nil
}
// OnPing is the frame handler NewStreamClient passes for the system "ping"
// topic. It answers with a pong response built from the incoming frame's
// message id and carrying the frame's Data unchanged.
func (cli *StreamClient) OnPing(ctx context.Context, df *payload.DataFrame) (*payload.DataFrameResponse, error) {
	pong := payload.NewDataFrameAckPong(df.GetMessageId())
	// Echo the ping payload back.
	pong.Data = df.Data
	return pong, nil
}
// SendDataFrameResponse writes resp to the current connection as JSON.
//
// It returns an error when resp is nil, a "disconnected" error when there is
// no connection, and otherwise the result of the websocket write.
//
// NOTE(review): gorilla/websocket documents that a connection supports only
// one concurrent writer, while processDataFrame calls this method from
// per-frame goroutines; writes are not serialized here — confirm whether a
// write lock is needed.
func (cli *StreamClient) SendDataFrameResponse(ctx context.Context, resp *payload.DataFrameResponse) error {
	if resp == nil {
		return errors.New("SendDataFrameResponseError_ResponseNil")
	}

	// Read the field once: Close may set cli.conn to nil between the nil check
	// and the write, which would otherwise dereference a nil connection.
	conn := cli.conn
	if conn == nil {
		logger.GetLogger().Errorf("SendDataFrameResponse error, conn nil, maybe disconnected.")
		return errors.New("disconnected")
	}

	return conn.WriteJSON(resp)
}
// RegisterRouter is the generic registration function: it stores frameHandler
// for the given subscription type and topic, creating the outer and inner
// maps on first use and replacing any handler already stored for that pair.
func (cli *StreamClient) RegisterRouter(stype, stopic string, frameHandler handler.IFrameHandler) {
	if cli.subscriptions == nil {
		cli.subscriptions = map[string]map[string]handler.IFrameHandler{}
	}

	topics, present := cli.subscriptions[stype]
	if !present {
		topics = map[string]handler.IFrameHandler{}
		cli.subscriptions[stype] = topics
	}
	topics[stopic] = frameHandler
}
// RegisterCallbackRouter registers frameHandler for topic under the callback
// subscription type.
func (cli *StreamClient) RegisterCallbackRouter(topic string, frameHandler handler.IFrameHandler) {
	const stype = utils.SubscriptionTypeKCallback
	cli.RegisterRouter(stype, topic, frameHandler)
}
// RegisterChatBotCallbackRouter registers a chat bot message handler: it wraps
// messageHandler in the default chat bot frame handler and registers that
// wrapper's OnEventReceived for the bot message callback topic.
func (cli *StreamClient) RegisterChatBotCallbackRouter(messageHandler chatbot.IChatBotMessageHandler) {
	wrapper := chatbot.NewDefaultChatBotFrameHandler(messageHandler)
	cli.RegisterRouter(
		utils.SubscriptionTypeKCallback,
		payload.BotMessageCallbackTopic,
		wrapper.OnEventReceived,
	)
}
// RegisterPluginCallbackRouter registers an AI plugin message handler: it
// wraps messageHandler in the default plugin frame handler and registers that
// wrapper's OnEventReceived for the plugin message callback topic.
func (cli *StreamClient) RegisterPluginCallbackRouter(messageHandler plugin.IPluginMessageHandler) {
	wrapper := plugin.NewDefaultPluginFrameHandler(messageHandler)
	cli.RegisterRouter(
		utils.SubscriptionTypeKCallback,
		payload.PluginMessageCallbackTopic,
		wrapper.OnEventReceived,
	)
}
// RegisterCardCallbackRouter registers an interactive card callback handler:
// it wraps messageHandler in the card package's default frame handler and
// registers that wrapper's OnEventReceived for the card instance callback
// topic.
func (cli *StreamClient) RegisterCardCallbackRouter(messageHandler card.ICardCallbackHandler) {
	wrapper := card.NewDefaultPluginFrameHandler(messageHandler)
	cli.RegisterRouter(
		utils.SubscriptionTypeKCallback,
		payload.CardInstanceCallbackTopic,
		wrapper.OnEventReceived,
	)
}
// RegisterEventRouter registers frameHandler for topic under the event
// subscription type.
func (cli *StreamClient) RegisterEventRouter(topic string, frameHandler handler.IFrameHandler) {
	const stype = utils.SubscriptionTypeKEvent
	cli.RegisterRouter(stype, topic, frameHandler)
}
// RegisterAllEventRouter registers frameHandler under the event subscription
// type with the wildcard topic "*".
func (cli *StreamClient) RegisterAllEventRouter(frameHandler handler.IFrameHandler) {
	const allTopics = "*"
	cli.RegisterRouter(utils.SubscriptionTypeKEvent, allTopics, frameHandler)
}