l-dingtalk-stream-sdk-go/client/client.go

400 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package client
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/handler"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/logger"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/payload"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/plugin"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/utils"
)
// Author: linya.jj
// Date:   2023/3/22 14:23

// StreamClient talks to the DingTalk Stream gateway: it requests a connection
// endpoint over HTTP, keeps a websocket connection to that endpoint, and
// dispatches incoming data frames to handlers registered per subscription
// type and topic.
type StreamClient struct {
	// AppCredential supplies the ClientId/ClientSecret sent with the endpoint request.
	AppCredential *AppCredentialConfig
	// UserAgent is reported to the gateway in the endpoint request.
	UserAgent *UserAgentConfig
	// AutoReconnect, when true, makes the read loop schedule reconnect() once it exits.
	AutoReconnect bool

	// subscriptions maps subscription type -> topic -> handler.
	subscriptions map[string]map[string]handler.IFrameHandler

	// conn is the live websocket connection, or nil when disconnected.
	conn *websocket.Conn
	// sessionId is the ticket of the current connection ("" once closed).
	sessionId string
	// mutex serializes Start and Close, which install and clear conn/sessionId.
	mutex sync.Mutex
}
// NewStreamClient builds a StreamClient with the SDK defaults installed:
// system "disconnect" and "ping" handlers, the DingTalk Go SDK user agent,
// and AutoReconnect enabled. Caller-supplied options are applied afterwards,
// so they take precedence wherever they write the same setting. Nil options
// are ignored.
func NewStreamClient(options ...ClientOption) *StreamClient {
	cli := &StreamClient{}

	builtins := []ClientOption{
		WithSubscription(utils.SubscriptionTypeKSystem, "disconnect", cli.OnDisconnect),
		WithSubscription(utils.SubscriptionTypeKSystem, "ping", cli.OnPing),
		WithUserAgent(NewDingtalkGoSDKUserAgent()),
		WithAutoReconnect(true),
	}
	for i := range builtins {
		builtins[i](cli)
	}

	for _, apply := range options {
		if apply != nil {
			apply(cli)
		}
	}
	return cli
}
// Start fetches a connection endpoint, dials the websocket gateway, and
// launches the read loop (processLoop) on a new goroutine.
//
// It returns nil without doing anything when a connection is already open.
// ctx is used for both the endpoint request and the websocket handshake.
func (cli *StreamClient) Start(ctx context.Context) error {
	// Inspect cli.conn only while holding the lock. The former unlocked fast
	// path raced with Close, which writes cli.conn under this same mutex.
	cli.mutex.Lock()
	defer cli.mutex.Unlock()
	if cli.conn != nil {
		return nil
	}

	endpoint, err := cli.GetConnectionEndpoint(ctx)
	if err != nil {
		return err
	}

	wssUrl := fmt.Sprintf("%s?ticket=%s", endpoint.Endpoint, endpoint.Ticket)
	header := make(http.Header)
	// DialContext (unlike Dial) honours ctx cancellation/deadline while the
	// handshake is in progress.
	conn, resp, err := websocket.DefaultDialer.DialContext(ctx, wssUrl, header)
	if err != nil {
		return err
	}
	// The gateway rejected the connection: close the socket instead of
	// leaking it, then report the HTTP error.
	if resp != nil && resp.StatusCode >= http.StatusBadRequest {
		_ = conn.Close()
		return utils.ErrorFromHttpResponseBody(resp)
	}

	cli.conn = conn
	cli.sessionId = endpoint.Ticket
	logger.GetLogger().Infof("connect success, sessionId=[%s]", cli.sessionId)
	go cli.processLoop()
	return nil
}
// processLoop reads messages from the connection that was current when it
// started, and hands each one to processDataFrame on its own goroutine.
// It returns on the first read error (for example, after Close). If
// AutoReconnect is set, it then schedules reconnect(). Panics are recovered
// and logged.
func (cli *StreamClient) processLoop() {
	defer func() {
		if err := recover(); err != nil {
			logger.GetLogger().Errorf("connection process panic due to unknown reason, error=[%s]", err)
		}
		if cli.AutoReconnect {
			go cli.reconnect()
		}
	}()

	// Snapshot the connection once, under the lock. Re-reading cli.conn on
	// every iteration without synchronisation raced with Close/Start. It could
	// also make this (stale) loop start reading a newer connection installed
	// by a later Start, giving that connection two concurrent readers.
	cli.mutex.Lock()
	conn := cli.conn
	cli.mutex.Unlock()

	if conn == nil {
		logger.GetLogger().Errorf("connection process connect nil, maybe disconnected.")
		return
	}

	for {
		messageType, message, err := conn.ReadMessage()
		if err != nil {
			logger.GetLogger().Errorf("connection process read message error: messageType=[%d] message=[%s] error=[%s]", messageType, string(message), err)
			return
		}
		logger.GetLogger().Debugf("ReadRawMessage : messageType=[%d] message=[%s]", messageType, string(message))
		go cli.processDataFrame(message)
	}
}
// processDataFrame decodes one raw websocket message and dispatches it to the
// handler registered for its (type, topic). It then sends the resulting ack
// back over the connection. Panics are recovered and logged.
//
// Ack selection:
//   - no registered handler -> HandlerNotFound response
//   - handler error with no response -> error response built from that error
//   - nil response otherwise -> generic success response
//
// The ack always carries the frame's message id and a JSON content type
// unless the handler already set those headers.
func (cli *StreamClient) processDataFrame(rawData []byte) {
	defer func() {
		if err := recover(); err != nil {
			logger.GetLogger().Errorf("connection processDataFrame panic, error=[%s]", err)
		}
	}()

	dataFrame, err := payload.DecodeDataFrame(rawData)
	if err != nil {
		logger.GetLogger().Errorf("connection process decode data frame error: length=[%d] error=[%s]", len(rawData), err)
		return
	}
	if dataFrame == nil || dataFrame.Headers == nil {
		logger.GetLogger().Errorf("connection processDataFrame dataFrame nil.")
		return
	}

	var dataAck *payload.DataFrameResponse
	frameHandler, err := cli.GetHandler(dataFrame.Type, dataFrame.GetTopic())
	if err != nil || frameHandler == nil {
		// No handler registered: reply with the "handler not found" (404) status.
		dataAck = payload.NewDataFrameResponse(payload.DataFrameResponseStatusCodeKHandlerNotFound)
	} else {
		dataAck, err = frameHandler(context.Background(), dataFrame)
		if err != nil && dataAck == nil {
			dataAck = payload.NewErrorDataFrameResponse(err)
		}
	}
	if dataAck == nil {
		dataAck = payload.NewSuccessDataFrameResponse()
	}
	if dataAck.GetHeader(payload.DataFrameHeaderKMessageId) == "" {
		dataAck.SetHeader(payload.DataFrameHeaderKMessageId, dataFrame.GetMessageId())
	}
	if dataAck.GetHeader(payload.DataFrameHeaderKContentType) == "" {
		dataAck.SetHeader(payload.DataFrameHeaderKContentType, payload.DataFrameContentTypeKJson)
	}

	errSend := cli.SendDataFrameResponse(context.Background(), dataAck)
	// Fixed: the format string was missing its closing bracket.
	logger.GetLogger().Debugf("SendFrameAck dataAck=[%v]", dataAck)
	if errSend != nil {
		logger.GetLogger().Errorf("connection processDataFrame send response error: error=[%s]", errSend)
	}
}
// Close closes the current websocket connection, if any, and clears the
// session id. Calling it when already closed is a no-op.
//
// NOTE(review): with AutoReconnect enabled, closing the connection makes
// processLoop's read fail, which schedules reconnect(). Close alone therefore
// does not permanently stop the client.
func (cli *StreamClient) Close() {
	// Always take the lock before reading cli.conn. The former unlocked fast
	// path was a data race with Start and with concurrent Close calls.
	cli.mutex.Lock()
	defer cli.mutex.Unlock()
	if cli.conn == nil {
		return
	}
	if err := cli.conn.Close(); err != nil {
		logger.GetLogger().Errorf("StreamClient close. error=[%s]", err)
	}
	cli.conn = nil
	cli.sessionId = ""
}
// reconnect drops the current connection and then calls Start with a
// background context, retrying every 3 seconds until it succeeds. Any panic is
// recovered and logged.
func (cli *StreamClient) reconnect() {
	defer func() {
		if r := recover(); r != nil {
			logger.GetLogger().Errorf("reconect panic due to unknown reason. error=[%s]", r)
		}
	}()

	cli.Close()

	const retryDelay = 3 * time.Second
	for {
		startErr := cli.Start(context.Background())
		if startErr == nil {
			break
		}
		logger.GetLogger().Errorf("StreamClient reconnect error. error=[%s]", startErr)
		time.Sleep(retryDelay)
	}
	logger.GetLogger().Infof("StreamClient reconnect success")
}
// GetHandler looks up the handler registered for an exact (stype, stopic)
// pair. It returns an error when no handler, or a nil handler, is registered.
func (cli *StreamClient) GetHandler(stype, stopic string) (handler.IFrameHandler, error) {
	// Indexing a nil map yields the zero value, so a missing type or topic
	// both fall through to the error below.
	if found := cli.subscriptions[stype][stopic]; found != nil {
		return found, nil
	}
	return nil, errors.New("HandlerNotRegistedForTypeTopic_" + stype + "_" + stopic)
}
// CheckConfigValid reports the first configuration problem it finds, checking:
//   - the app credential and user agent (via their Valid methods)
//   - that at least one subscription map exists
//   - that every subscription type is known (listed in utils.SubscriptionTypeSet)
//   - that every type has at least one topic
//   - that every topic maps to a non-nil handler
//
// It returns nil when everything passes.
func (cli *StreamClient) CheckConfigValid() error {
	if credErr := cli.AppCredential.Valid(); credErr != nil {
		return credErr
	}
	if uaErr := cli.UserAgent.Valid(); uaErr != nil {
		return uaErr
	}
	if cli.subscriptions == nil {
		return errors.New("subscriptionsNil")
	}

	for subType, topics := range cli.subscriptions {
		if _, known := utils.SubscriptionTypeSet[subType]; !known {
			return errors.New("UnKnownSubscriptionType_" + subType)
		}
		if len(topics) == 0 {
			return errors.New("NoHandlersRegistedForType_" + subType)
		}
		for topic, h := range topics {
			if h != nil {
				continue
			}
			return errors.New("HandlerNilForTypeTopic_" + subType + "_" + topic)
		}
	}
	return nil
}
// GetConnectionEndpoint validates the client configuration, then POSTs the
// following as JSON to utils.GetConnectionEndpointAPIUrl:
//   - the client credentials
//   - the user agent
//   - the first LAN IP (best effort)
//   - every registered (type, topic) subscription
//
// It decodes and validates the returned endpoint. The request is bounded by
// ctx and a 5-second client timeout. A non-200 status is converted into an
// error via utils.ErrorFromHttpResponseBody.
func (cli *StreamClient) GetConnectionEndpoint(ctx context.Context) (*payload.ConnectionEndpointResponse, error) {
	if err := cli.CheckConfigValid(); err != nil {
		return nil, err
	}

	requestModel := payload.ConnectionEndpointRequest{
		ClientId:      cli.AppCredential.ClientId,
		ClientSecret:  cli.AppCredential.ClientSecret,
		UserAgent:     cli.UserAgent.UserAgent,
		Subscriptions: make([]*payload.SubscriptionModel, 0),
	}
	// The local IP is optional metadata; a lookup failure is deliberately ignored.
	if localIp, err := utils.GetFirstLanIP(); err == nil {
		requestModel.LocalIP = localIp
	}
	for ttype, subs := range cli.subscriptions {
		for ttopic := range subs {
			requestModel.Subscriptions = append(requestModel.Subscriptions, &payload.SubscriptionModel{
				Type:  ttype,
				Topic: ttopic,
			})
		}
	}

	// Previously the marshal error was silently discarded.
	requestJsonBody, err := json.Marshal(requestModel)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, utils.GetConnectionEndpointAPIUrl, bytes.NewReader(requestJsonBody))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	httpClient := &http.Client{
		Transport: http.DefaultTransport,
		Timeout:   5 * time.Second, // covers connect, any redirects, and reading the response body
	}
	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	// Close the body on every path. The defer used to come after the status
	// check, so the body of every non-200 response was leaked.
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, utils.ErrorFromHttpResponseBody(resp)
	}

	responseJsonBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	endpoint := &payload.ConnectionEndpointResponse{}
	if err := json.Unmarshal(responseJsonBody, endpoint); err != nil {
		return nil, err
	}
	if err := endpoint.Valid(); err != nil {
		return nil, err
	}
	return endpoint, nil
}
// OnDisconnect handles the system "disconnect" frame by closing the current
// connection. It returns no response of its own. processDataFrame then falls
// back to a generic success ack, whose send will typically fail with
// "disconnected" because the connection has just been cleared.
func (cli *StreamClient) OnDisconnect(ctx context.Context, df *payload.DataFrame) (*payload.DataFrameResponse, error) {
	logger.GetLogger().Debugf("StreamClient.OnDisconnect")
	defer cli.Close()
	return nil, nil
}
// OnPing answers the system "ping" frame with a pong ack that carries the
// ping's message id and echoes its Data payload back unchanged.
func (cli *StreamClient) OnPing(ctx context.Context, df *payload.DataFrame) (*payload.DataFrameResponse, error) {
	pong := payload.NewDataFrameAckPong(df.GetMessageId())
	pong.Data = df.Data
	return pong, nil
}
// SendDataFrameResponse writes resp to the current websocket connection as JSON.
//
// Incoming frames are handled on separate goroutines (processLoop starts one
// processDataFrame per message), so this method is called concurrently.
// gorilla/websocket permits only one concurrent writer, and Close may clear
// cli.conn at any moment. The nil check and the write are therefore done
// while holding cli.mutex, which both serializes writers and guarantees the
// write uses the connection that was checked. Consequence: Close and Start
// wait for an in-flight write to finish.
func (cli *StreamClient) SendDataFrameResponse(ctx context.Context, resp *payload.DataFrameResponse) error {
	if resp == nil {
		return errors.New("SendDataFrameResponseError_ResponseNil")
	}

	cli.mutex.Lock()
	defer cli.mutex.Unlock()
	if cli.conn == nil {
		logger.GetLogger().Errorf("SendDataFrameResponse error, conn nil, maybe disconnected.")
		return errors.New("disconnected")
	}
	return cli.conn.WriteJSON(resp)
}
// RegisterRouter is the generic registration function. It maps
// (stype, stopic) to frameHandler, lazily creating the outer and per-type maps,
// and replaces any handler previously registered for the same pair.
func (cli *StreamClient) RegisterRouter(stype, stopic string, frameHandler handler.IFrameHandler) {
	if cli.subscriptions == nil {
		cli.subscriptions = make(map[string]map[string]handler.IFrameHandler)
	}
	topics, exists := cli.subscriptions[stype]
	if !exists {
		topics = make(map[string]handler.IFrameHandler)
		cli.subscriptions[stype] = topics
	}
	topics[stopic] = frameHandler
}
// RegisterCallbackRouter registers frameHandler for callback-type frames on
// the given topic.
func (cli *StreamClient) RegisterCallbackRouter(topic string, frameHandler handler.IFrameHandler) {
	subscriptionType := utils.SubscriptionTypeKCallback
	cli.RegisterRouter(subscriptionType, topic, frameHandler)
}
// RegisterChatBotCallbackRouter registers a chatbot message handler. It wraps
// messageHandler in the SDK's default chatbot frame handler and registers that
// for the bot-message callback topic.
func (cli *StreamClient) RegisterChatBotCallbackRouter(messageHandler chatbot.IChatBotMessageHandler) {
	botFrameHandler := chatbot.NewDefaultChatBotFrameHandler(messageHandler)
	cli.RegisterRouter(utils.SubscriptionTypeKCallback, payload.BotMessageCallbackTopic, botFrameHandler.OnEventReceived)
}
// RegisterPluginCallbackRouter registers an AI plugin message handler. It
// wraps messageHandler in the SDK's default plugin frame handler and registers
// that for the plugin-message callback topic.
func (cli *StreamClient) RegisterPluginCallbackRouter(messageHandler plugin.IPluginMessageHandler) {
	pluginFrameHandler := plugin.NewDefaultPluginFrameHandler(messageHandler)
	cli.RegisterRouter(utils.SubscriptionTypeKCallback, payload.PluginMessageCallbackTopic, pluginFrameHandler.OnEventReceived)
}
// RegisterEventRouter registers frameHandler for event-type frames on the
// given topic.
func (cli *StreamClient) RegisterEventRouter(topic string, frameHandler handler.IFrameHandler) {
	subscriptionType := utils.SubscriptionTypeKEvent
	cli.RegisterRouter(subscriptionType, topic, frameHandler)
}
// RegisterAllEventRouter registers frameHandler for event-type frames under
// the wildcard topic "*".
//
// NOTE(review): GetHandler matches topics exactly, so this handler is used
// only for frames whose topic is literally "*". Confirm that this is how the
// gateway labels events for a wildcard subscription.
func (cli *StreamClient) RegisterAllEventRouter(frameHandler handler.IFrameHandler) {
	const wildcardTopic = "*"
	cli.RegisterRouter(utils.SubscriptionTypeKEvent, wildcardTopic, frameHandler)
}