diff --git a/cmd/server/wire.go b/cmd/server/wire.go index e34c611..79f3667 100644 --- a/cmd/server/wire.go +++ b/cmd/server/wire.go @@ -5,6 +5,7 @@ package main import ( "ai_scheduler/internal/biz" + "ai_scheduler/internal/biz/handle/dingtalk" "ai_scheduler/internal/config" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/domain/workflow" @@ -31,6 +32,7 @@ func InitializeApp(*config.Config, log.AllLogger) (*server.Servers, func(), erro impl.ProviderImpl, utils.ProviderUtils, tools_bot.ProviderSetBotTools, + dingtalk.ProviderSetDingTalk, )) } diff --git a/config/config.yaml b/config/config.yaml index c1e8c08..4b568b4 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -23,7 +23,7 @@ redis: host: 47.97.27.195:6379 type: node pass: lansexiongdi@666 - key: report-api-test + key: report-api pollSize: 5 #连接池大小,不配置,或配置为0表示不启用连接池 minIdleConns: 2 #最小空闲连接数 maxIdleTime: 30 #每个连接最大空闲时间,如果超过了这个时间会被关闭 diff --git a/config/config_test.yaml b/config/config_test.yaml index be2695d..cbeb64f 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -71,6 +71,20 @@ tools: zltxOrderAfterSaleResellerBatch: enabled: true base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/afterSales/reseller_pre_ai" + weather: + enabled: true + base_url: "https://restapi.amap.com/v3/weather/weatherInfo" + api_key: "12afbde5ab78cb7e575ff76bd0bdef2b" + cozeExpress: + enabled: true + base_url: "https://api.coze.cn" + api_key: "7582477438102552616" + api_secret: "pat_eEN0BdLNDughEtABjJJRYTW71olvDU0qUbfQUeaPc2NnYWO8HeyNoui5aR9z0sSZ" + cozeCompany: + enabled: true + base_url: "https://api.coze.cn" + api_key: "7583905168607100978" + api_secret: "pat_eEN0BdLNDughEtABjJJRYTW71olvDU0qUbfQUeaPc2NnYWO8HeyNoui5aR9z0sSZ" diff --git a/go.mod b/go.mod index 483c766..521dbdb 100644 --- a/go.mod +++ b/go.mod @@ -52,6 +52,7 @@ require ( github.com/clbanning/mxj/v2 v2.5.5 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/cloudwego/eino-ext/libs/acl/openai v0.1.2 // indirect + github.com/coze-dev/coze-go v0.0.0-20251029161603-312b7fd62d20 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.4 // indirect github.com/dustin/go-humanize v1.0.1 // indirect @@ -61,8 +62,9 @@ require ( github.com/fasthttp/websocket v1.5.3 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/golang-jwt/jwt/v5 v5.2.2 // indirect github.com/goph/emperror v0.17.2 // indirect - github.com/gorilla/websocket v1.5.0 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect diff --git a/go.sum b/go.sum index 6fcccec..6a18537 100644 --- a/go.sum +++ b/go.sum @@ -141,6 +141,8 @@ github.com/cloudwego/eino-ext/libs/acl/openai v0.1.2/go.mod h1:S4OkvglPY9hsm9tXe github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/coze-dev/coze-go v0.0.0-20251029161603-312b7fd62d20 h1:m6P88V9lLrxZsE7uj9otq7l7nqDuCSAJ86KhzRlWf0M= +github.com/coze-dev/coze-go v0.0.0-20251029161603-312b7fd62d20/go.mod h1:wdT5CFt/sFsWz9hna2Z7DWzUra9spx0SoX1PUZyoSB0= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -204,6 +206,8 @@ github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPArei github.com/gofiber/websocket/v2 v2.2.1 h1:C9cjxvloojayOp9AovmpQrk8VqvVnT8Oao3+IUygH7w= github.com/gofiber/websocket/v2 v2.2.1/go.mod h1:Ao/+nyNnX5u/hIFPuHl28a+NIkrqK7PRimyKaj4JxVU= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -273,6 +277,8 @@ github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25d github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= diff --git a/internal/biz/ding_talk_bot.go b/internal/biz/ding_talk_bot.go index f2abea1..c67c8e5 100644 --- a/internal/biz/ding_talk_bot.go +++ b/internal/biz/ding_talk_bot.go @@ -2,12 +2,13 @@ package biz import ( "ai_scheduler/internal/biz/do" + "ai_scheduler/internal/biz/handle/dingtalk" "ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" - "ai_scheduler/internal/pkg/mapstructure" "context" + "encoding/json" "fmt" "github.com/gofiber/fiber/v2/log" @@ -22,6 +23,7 @@ type DingTalkBotBiz struct { botConfigImpl *impl.BotConfigImpl replier *chatbot.ChatbotReplier log log.Logger + dingTalkUser *dingtalk.User } // NewDingTalkBotBiz @@ -29,12 +31,15 @@ func NewDingTalkBotBiz( do *do.Do, handle *do.Handle, botConfigImpl *impl.BotConfigImpl, + dingTalkUser *dingtalk.User, + ) *DingTalkBotBiz { return &DingTalkBotBiz{ do: do, handle: handle, botConfigImpl: botConfigImpl, replier: chatbot.NewChatbotReplier(), + dingTalkUser: dingTalkUser, } } @@ -47,10 +52,11 @@ func (d *DingTalkBotBiz) GetDingTalkBotCfgList() (dingBotList []entitys.DingTalk err = d.botConfigImpl.GetRangeToMapStruct(&cond, &botConfig) for _, v := range botConfig { var config entitys.DingTalkBot - err = mapstructure.Decode(v, &config) + err = json.Unmarshal([]byte(v.BotConfig), &config) if err != nil { d.log.Info("初始化“%s”失败:%s", v.BotName, err.Error()) } + config.BotIndex = v.BotIndex dingBotList = append(dingBotList, config) } return @@ -64,8 +70,8 @@ func (d *DingTalkBotBiz) InitRequire(ctx context.Context, data *chatbot.BotCallb } entitys.ResLog(requireData.Ch, "recognize_start", "收到消息,正在处理中,请稍等") - requireData.Sys, err = d.do.GetSysInfoForDingTalkBot(requireData) - requireData.Tasks, err = d.do.GetTasks(requireData.Sys.SysID) + requireData.UserInfo, err = d.dingTalkUser.GetUserInfoFromBot(ctx, data.SenderStaffId, dingtalk.WithId(1)) + return } diff --git a/internal/biz/do/ctx.go b/internal/biz/do/ctx.go index d252c47..80deb73 100644 --- a/internal/biz/do/ctx.go +++ b/internal/biz/do/ctx.go @@ -131,7 +131,8 @@ func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, require // 获取任务列表的辅助函数 func (d *Do) loadTaskList(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error { if taskInfo := client.GetTasks(); len(taskInfo) == 0 { - tasks, err := d.GetTasks(requireData.Sys.SysID) + // 从数据库获取任务列表, 0 表示获取公共的任务 + tasks, err := d.GetTasks(requireData.Sys.SysID, 0) if err != nil { return err } @@ -140,6 +141,7 @@ func (d *Do) loadTaskList(ctx context.Context, client *gateway.Client, requireDa } else { requireData.Tasks = taskInfo } + return nil } @@ -225,7 +227,7 @@ func (d *Do) GetSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, e func (d *Do) GetSysInfoForDingTalkBot(requireData *entitys.RequireDataDingTalkBot) (sysInfo model.AiSy, err error) { cond := builder.NewCond() - cond = cond.And(builder.Eq{"app_key": requireData.Key}) + cond = cond.And(builder.Eq{"app_key": requireData.Auth}) cond = cond.And(builder.IsNull{"delete_at"}) cond = cond.And(builder.Eq{"status": 1}) err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo) @@ -242,12 +244,12 @@ func (d *Do) getSessionChatHis(requireData *entitys.RequireData) (his []model.Ai return } -func (d *Do) GetTasks(sysId int32) (tasks []model.AiTask, err error) { +func (d *Do) GetTasks(sysId ...int32) (tasks []model.AiTask, err error) { cond := builder.NewCond() - cond = cond.And(builder.Eq{"sys_id": sysId}) + //cond = cond.And(builder.Eq{"sys_id": sysId}) cond = cond.And(builder.IsNull{"delete_at"}) - cond = cond.And(builder.Eq{"status": 1}) + cond = cond.And(builder.Eq{"status": 1}.And(builder.In("sys_id", sysId))) _, err = d.taskImpl.GetListToStruct(&cond, nil, &tasks, "") return diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index e5b30b4..8a9162d 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -14,6 +14,7 @@ import ( "ai_scheduler/internal/pkg/mapstructure" "ai_scheduler/internal/pkg/util" "ai_scheduler/internal/tools" + "ai_scheduler/internal/tools/public" "ai_scheduler/internal/tools_bot" "context" "encoding/json" @@ -52,6 +53,7 @@ func NewHandle( func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (match entitys.Match, err error) { entitys.ResLog(rec.Ch, "recognize_start", "准备意图识别") + prompt, err := promptProcessor.CreatePrompt(ctx, rec) //意图识别 recognizeMsg, err := r.Ollama.IntentRecognize(ctx, &entitys.ToolSelect{ @@ -189,7 +191,7 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD return fmt.Errorf("tool not found: %s", configData.Tool) } - if knowledgeTool, ok := tool.(*tools.KnowledgeBaseTool); !ok { + if knowledgeTool, ok := tool.(*public.KnowledgeBaseTool); !ok { return fmt.Errorf("未找到知识库Tool: %s", configData.Tool) } else { host = knowledgeTool.GetConfig().BaseURL @@ -200,7 +202,7 @@ func (r *Handle) handleKnowle(ctx context.Context, requireData *entitys.RequireD // 知识库的session为空,请求知识库获取, 并绑定 if requireData.SessionInfo.KnowlegeSessionID == "" { // 请求知识库 - if sessionIdKnowledge, err = tools.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil { + if sessionIdKnowledge, err = public.GetKnowledgeBaseSession(host, requireData.Sys.KnowlegeBaseID, requireData.Sys.KnowlegeTenantKey); err != nil { return } diff --git a/internal/biz/handle/dingtalk/auth.go b/internal/biz/handle/dingtalk/auth.go new file mode 100644 index 0000000..2b359a9 --- /dev/null +++ b/internal/biz/handle/dingtalk/auth.go @@ -0,0 +1,118 @@ +package dingtalk + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/l_request" + "ai_scheduler/utils" + "context" + "encoding/json" + "errors" + "net/http" + "time" + + "github.com/gofiber/fiber/v2/log" + "github.com/redis/go-redis/v9" + "xorm.io/builder" +) + +type Auth struct { + redis *redis.Client + cfg *config.Config + botConfigImpl *impl.BotConfigImpl +} + +func NewAuth(cfg *config.Config, redis *utils.Rdb, botConfigImpl *impl.BotConfigImpl) *Auth { + return &Auth{ + redis: redis.Rdb, + cfg: cfg, + botConfigImpl: botConfigImpl, + } +} + +func (a *Auth) GetAccessToken(ctx context.Context, clientId string, clientSecret string) (authInfo *AuthInfo, err error) { + if clientId == "" { + return nil, errors.New("clientId is empty") + } + accessToken := a.redis.Get(ctx, a.getKey(clientId)).Val() + if accessToken == "" { + dingTalkAuthRes, _err := a.getNewAccessToken(ctx, clientId, clientSecret) + if _err != nil { + return nil, _err + } + err = a.redis.SetEx(ctx, a.getKey(clientId), dingTalkAuthRes.AccessToken, time.Duration(dingTalkAuthRes.ExpireIn-3600)*time.Second).Err() + if err != nil { + return + } + accessToken = dingTalkAuthRes.AccessToken + } + return &AuthInfo{ + ClientId: clientId, + ClientSecret: clientSecret, + AccessToken: accessToken, + }, nil +} + +func (a *Auth) getKey(clientId string) string { + return a.cfg.Redis.Key + ":" + constants.DingTalkAuthBaseKeyPrefix + ":" + clientId +} + +func (a *Auth) getNewAccessToken(ctx context.Context, clientId string, clientSecret string) (auth DingTalkAuthIRes, err error) { + if clientId == "" || clientSecret == "" { + err = errors.New("clientId or clientSecret is empty") + return + } + + req := l_request.Request{ + Method: http.MethodPost, + Url: "https://api.dingtalk.com/v1.0/oauth2/accessToken", + Json: map[string]interface{}{ + "appKey": clientId, + "appSecret": clientSecret, + }, + } + res, err := req.Send() + if err != nil { + return + } + err = json.Unmarshal(res.Content, &auth) + + return +} + +func (a *Auth) GetTokenFromBotOption(ctx context.Context, botOption ...BotOption) (token *AuthInfo, err error) { + botInfo := &Bot{} + for _, option := range botOption { + option(botInfo) + } + + if botInfo.id == 0 && botInfo.botConfig == nil { + err = errors.New("botInfo is nil") + return + } + if botInfo.botConfig == nil { + var botConfigDo model.AiBotConfig + cond := builder.NewCond() + cond = cond.And(builder.Eq{"bot_id": botInfo.id}) + err = a.botConfigImpl.GetOneBySearchToStrut(&cond, &botConfigDo) + if err != nil { + return + } + if botConfigDo.BotID == 0 { + err = errors.New("未找到机器人服务配置") + return + } + botInfo.botConfig = &botConfigDo + } + var botConfig entitys.DingTalkBot + err = json.Unmarshal([]byte(botInfo.botConfig.BotConfig), &botConfig) + if err != nil { + log.Infof("初始化“%s”失败:%s", botInfo.botConfig.BotName, err.Error()) + return + } + return a.GetAccessToken(ctx, botConfig.ClientId, botConfig.ClientSecret) + +} diff --git a/internal/biz/handle/dingtalk/dept.go b/internal/biz/handle/dingtalk/dept.go new file mode 100644 index 0000000..0a6bd15 --- /dev/null +++ b/internal/biz/handle/dingtalk/dept.go @@ -0,0 +1,107 @@ +package dingtalk + +import ( + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/l_request" + "context" + "encoding/json" + "fmt" + "net/http" + + "xorm.io/builder" +) + +type Dept struct { + dingDeptImpl *impl.BotDeptImpl + auth *Auth +} + +func NewDept(dingDeptImpl *impl.BotDeptImpl, auth *Auth) *Dept { + return &Dept{ + dingDeptImpl: dingDeptImpl, + auth: auth, + } +} + +func (d *Dept) GetDeptInfoByDeptIds(ctx context.Context, deptIds []int, authInfo *AuthInfo) (depts []*entitys.Dept, err error) { + if len(deptIds) == 0 || authInfo == nil { + return + } + var deptsInfo []model.AiBotDept + cond := builder.NewCond() + cond = cond.And(builder.Eq{"dingtalk_dept_id": deptIds}) + err = d.dingDeptImpl.GetRangeToMapStruct(&cond, &deptsInfo) + if err != nil { + return + } + var existDept = make([]int, len(deptsInfo), 0) + for _, dept := range deptsInfo { + depts = append(depts, &entitys.Dept{ + DeptId: int(dept.DeptID), + Name: dept.Name, + }) + existDept = append(existDept, int(dept.DeptID)) + } + diff := pkg.Difference(deptIds, existDept) + if len(diff) > 0 { + deptDo := make([]model.AiBotDept, 0) + for _, deptId := range diff { + deptInfo, _err := d.GetDeptInfoFromDingTalk(ctx, deptId, authInfo.AccessToken) + if _err != nil { + return nil, _err + } + depts = append(depts, &entitys.Dept{ + DeptId: deptInfo.DeptId, + Name: deptInfo.Name, + }) + deptDo = append(deptDo, model.AiBotDept{ + DingtalkDeptID: int32(deptInfo.DeptId), + Name: deptInfo.Name, + }) + } + if len(deptDo) > 0 { + _, err = d.dingDeptImpl.Add(deptDo) + if err != nil { + return nil, err + } + } + } + + return +} + +func (d *Dept) GetDeptInfoFromDingTalk(ctx context.Context, deptId int, token string) (depts DeptResResult, err error) { + if deptId == 0 || len(token) == 0 { + return + } + + req := l_request.Request{ + Method: http.MethodPost, + Url: constants.GetDingTalkRequestUrl(constants.RequestUrlGetDeptGet, map[string]string{ + "access_token": token, + }), + Json: map[string]interface{}{ + "dept_id": deptId, + }, + } + res, _err := req.Send() + if _err != nil { + err = _err + return + } + var deptInfo DeptRes + + err = json.Unmarshal(res.Content, &deptInfo) + if err != nil { + return + } + if deptInfo.Errcode != 0 { + fmt.Errorf("钉钉请求报错:%s", deptInfo.Errmsg) + } + return deptInfo.DeptResResult, err + +} diff --git a/internal/biz/handle/dingtalk/option.go b/internal/biz/handle/dingtalk/option.go new file mode 100644 index 0000000..fb473c7 --- /dev/null +++ b/internal/biz/handle/dingtalk/option.go @@ -0,0 +1,21 @@ +package dingtalk + +import "ai_scheduler/internal/data/model" + +type Bot struct { + id int + botConfig *model.AiBotConfig +} +type BotOption func(*Bot) + +func WithId(id int) BotOption { + return func(b *Bot) { + b.id = id + } +} + +func WithBootConfig(BotConfig *model.AiBotConfig) BotOption { + return func(bot *Bot) { + bot.botConfig = BotConfig + } +} diff --git a/internal/biz/handle/dingtalk/provider_set.go b/internal/biz/handle/dingtalk/provider_set.go new file mode 100644 index 0000000..579d464 --- /dev/null +++ b/internal/biz/handle/dingtalk/provider_set.go @@ -0,0 +1,11 @@ +package dingtalk + +import ( + "github.com/google/wire" +) + +var ProviderSetDingTalk = wire.NewSet( + NewUser, + NewAuth, + NewDept, +) diff --git a/internal/biz/handle/dingtalk/types.go b/internal/biz/handle/dingtalk/types.go new file mode 100644 index 0000000..a36ea76 --- /dev/null +++ b/internal/biz/handle/dingtalk/types.go @@ -0,0 +1,84 @@ +package dingtalk + +import "time" + +type DingTalkAuthIRes struct { + AccessToken string `json:"accessToken"` + ExpireIn int64 `json:"expireIn"` +} + +type UserInfoRes struct { + Errcode int `json:"errcode"` + Errmsg string `json:"errmsg"` + Result UserInfoResResult `json:"result"` + RequestId string `json:"request_id"` +} + +type UserInfoResResult struct { + Active bool `json:"active"` + Admin bool `json:"admin"` + Avatar string `json:"avatar"` + Boss bool `json:"boss"` + CreateTime time.Time `json:"create_time"` + DeptIdList []int `json:"dept_id_list"` + DeptOrderList []struct { + DeptId int `json:"dept_id"` + Order int64 `json:"order"` + } `json:"dept_order_list"` + ExclusiveAccount bool `json:"exclusive_account"` + HideMobile bool `json:"hide_mobile"` + HiredDate int64 `json:"hired_date"` + JobNumber string `json:"job_number"` + LeaderInDept []struct { + DeptId int `json:"dept_id"` + Leader bool `json:"leader"` + } `json:"leader_in_dept"` + ManagerUserid string `json:"manager_userid"` + Name string `json:"name"` + RealAuthed bool `json:"real_authed"` + RoleList []struct { + GroupName string `json:"group_name"` + Id int `json:"id"` + Name string `json:"name"` + } `json:"role_list"` + Senior bool `json:"senior"` + Title string `json:"title"` + Unionid string `json:"unionid"` + Userid string `json:"userid"` +} + +type DeptRes struct { + Errcode int `json:"errcode"` + Errmsg string `json:"errmsg"` + DeptResResult DeptResResult `json:"result"` + RequestId string `json:"request_id"` +} + +type DeptResResult struct { + DeptPermits []int `json:"dept_permits"` + OuterPermitUsers []string `json:"outer_permit_users"` + DeptManagerUseridList []string `json:"dept_manager_userid_list"` + OrgDeptOwner string `json:"org_dept_owner"` + OuterDept bool `json:"outer_dept"` + DeptGroupChatId string `json:"dept_group_chat_id"` + GroupContainSubDept bool `json:"group_contain_sub_dept"` + AutoAddUser bool `json:"auto_add_user"` + HideDept bool `json:"hide_dept"` + Name string `json:"name"` + OuterPermitDepts []int `json:"outer_permit_depts"` + UserPermits []interface{} `json:"user_permits"` + DeptId int `json:"dept_id"` + CreateDeptGroup bool `json:"create_dept_group"` + Order int `json:"order"` + Code string `json:"code"` + UnionDeptExt struct { + CorpId string `json:"corp_id"` + DeptId int `json:"dept_id"` + } `json:"union_dept_ext"` +} + +type AuthInfo struct { + ClientId string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + AccessToken string `json:"accessToken"` +} diff --git a/internal/biz/handle/dingtalk/user.go b/internal/biz/handle/dingtalk/user.go new file mode 100644 index 0000000..900a492 --- /dev/null +++ b/internal/biz/handle/dingtalk/user.go @@ -0,0 +1,126 @@ +package dingtalk + +import ( + "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/l_request" + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "time" +) + +type User struct { + dingUserImpl *impl.BotUserImpl + botConfigImpl *impl.BotConfigImpl + auth *Auth + dept *Dept +} + +func NewUser( + dingUserImpl *impl.BotUserImpl, + auth *Auth, + dept *Dept, +) *User { + return &User{ + dingUserImpl: dingUserImpl, + auth: auth, + dept: dept, + } +} + +func (u *User) GetUserInfoFromBot(ctx context.Context, staffId string, botOption ...BotOption) (userInfo *entitys.DingTalkUserInfo, err error) { + if len(staffId) == 0 { + return + } + user, err := u.dingUserImpl.GetByStaffId(staffId) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return + } + } + authInfo, err := u.auth.GetTokenFromBotOption(ctx, botOption...) + if err != nil || authInfo == nil { + return + } + //如果没有找到,则新增 + if user == nil { + DingUserInfo, _err := u.getUserInfoFromDingTalk(ctx, authInfo.AccessToken, staffId) + if _err != nil { + return nil, _err + } + user = &model.AiBotUser{ + StaffID: DingUserInfo.Userid, + Name: DingUserInfo.Name, + Title: DingUserInfo.Title, + //Extension: DingUserInfo.Extension, + DeptIDList: strings.Join(pkg.SliceIntToString(DingUserInfo.DeptIdList), ","), + IsBoss: int32(pkg.Ter(DingUserInfo.Boss, constants.IsBossTrue, constants.IsBossFalse)), + IsSenior: int32(pkg.Ter(DingUserInfo.Senior, constants.IsSeniorTrue, constants.IsSeniorFalse)), + HiredDate: time.UnixMilli(DingUserInfo.HiredDate), + } + + _, err = u.dingUserImpl.Add(user) + if err != nil { + return + } + } + userInfo = &entitys.DingTalkUserInfo{ + UserId: int(user.UserID), + StaffId: user.StaffID, + Name: user.Name, + IsBoss: constants.IsBoss(user.IsBoss), + IsSenior: constants.IsSenior(user.IsSenior), + HiredDate: user.HiredDate, + Extension: user.Extension, + } + if len(user.DeptIDList) > 0 { + deptIdList := pkg.SliceStringToInt(strings.Split(user.DeptIDList, ",")) + depts, _err := u.dept.GetDeptInfoByDeptIds(ctx, deptIdList, authInfo) + if _err != nil { + return nil, err + } + for _, dept := range depts { + userInfo.Dept = append(userInfo.Dept, dept) + } + } + + return userInfo, nil +} + +func (u *User) getUserInfoFromDingTalk(ctx context.Context, token string, staffId string) (user UserInfoResResult, err error) { + if token == "" && staffId == "" { + err = errors.New("获取钉钉用户信息的必要参数不足") + return + } + + req := l_request.Request{ + Method: http.MethodPost, + Url: constants.GetDingTalkRequestUrl(constants.RequestUrlGetUserGet, map[string]string{ + "access_token": token, + }), + Data: map[string]string{ + "userid": staffId, + }, + } + res, err := req.Send() + if err != nil { + return + } + var userInfoRes UserInfoRes + err = json.Unmarshal(res.Content, &userInfoRes) + if err != nil { + return + } + if userInfoRes.Errcode != 0 { + fmt.Errorf("钉钉请求报错:%s", userInfoRes.Errmsg) + } + return userInfoRes.Result, err +} diff --git a/internal/biz/router.go b/internal/biz/router.go index a711445..5127a54 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -7,6 +7,7 @@ import ( "ai_scheduler/internal/gateway" "context" "encoding/json" + "strings" "time" "ai_scheduler/internal/entitys" @@ -66,6 +67,7 @@ func (r *AiRouterBiz) RouteWithSocket(client *gateway.Client, req *entitys.ChatS return } //意图识别 + requireData.Match, err = r.handle.Recognize(ctx, &rec, sys) if err != nil { log.Errorf("意图识别失败: %s", err.Error()) @@ -136,9 +138,13 @@ func (r *AiRouterBiz) buildUserContent(requireData *entitys.RequireData) (*entit // 处理文件和图片 fileUrls := []string{requireData.Req.File, requireData.Req.Img} - for _, url := range fileUrls { - if url != "" { - files = append(files, &entitys.RecognizeFile{FileUrl: url}) + for _, item := range fileUrls { + // 处理逗号分隔的多个URL + urlList := strings.Split(item, ",") + for _, url := range urlList { + if url != "" { + files = append(files, &entitys.RecognizeFile{FileUrl: url}) + } } } diff --git a/internal/config/config.go b/internal/config/config.go index 5f9a5d7..91115cb 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -129,6 +129,10 @@ type ToolsConfig struct { ZltxOrderAfterSaleReseller ToolConfig `mapstructure:"zltxOrderAfterSaleReseller"` // 下游批充订单售后 ZltxOrderAfterSaleResellerBatch ToolConfig `mapstructure:"zltxOrderAfterSaleResellerBatch"` + // Coze 快递查询工具 + CozeExpress ToolConfig `mapstructure:"cozeExpress"` + // Coze 公司查询工具 + CozeCompany ToolConfig `mapstructure:"cozeCompany"` } // ToolConfig 单个工具配置 diff --git a/internal/data/constants/bot.go b/internal/data/constants/bot.go index 2b24cfe..446519a 100644 --- a/internal/data/constants/bot.go +++ b/internal/data/constants/bot.go @@ -31,3 +31,5 @@ type BotType int const ( BotTypeDingTalk BotType = 1 // 系统的bug/优化建议 ) + +const DingTalkAuthBaseKeyPrefix = "dingTalk_auth" diff --git a/internal/data/constants/const.go b/internal/data/constants/const.go index 151d259..604848e 100644 --- a/internal/data/constants/const.go +++ b/internal/data/constants/const.go @@ -31,3 +31,5 @@ var UseFulMap = map[UseFul]string{ UseFulNotUnclear: "回答不明确", UseFulNotError: "理解错误", } + +type BaseBool int32 diff --git a/internal/data/constants/dingtalk.go b/internal/data/constants/dingtalk.go new file mode 100644 index 0000000..5c9de89 --- /dev/null +++ b/internal/data/constants/dingtalk.go @@ -0,0 +1,38 @@ +package constants + +import "net/url" + +const DingTalkBseUrl = "https://oapi.dingtalk.com" + +type RequestUrl string + +const ( + RequestUrlGetUserGet RequestUrl = "/topapi/v2/user/get" + RequestUrlGetDeptGet RequestUrl = "/topapi/v2/department/get" +) + +func GetDingTalkRequestUrl(path RequestUrl, query map[string]string) string { + u, _ := url.Parse(DingTalkBseUrl + string(path)) + q := u.Query() + for key, val := range query { + q.Add(key, val) + } + u.RawQuery = q.Encode() + return u.String() +} + +// IsBoss 是否是老板 +type IsBoss int + +const ( + IsBossTrue IsBoss = 1 + IsBossFalse IsBoss = 0 +) + +// IsSenior 是否是老板 +type IsSenior int + +const ( + IsSeniorTrue IsSenior = 1 + IsSeniorFalse IsSenior = 0 +) diff --git a/internal/data/constants/user_test.go b/internal/data/constants/user_test.go new file mode 100644 index 0000000..eee1b17 --- /dev/null +++ b/internal/data/constants/user_test.go @@ -0,0 +1,91 @@ +package constants + +import ( + "ai_scheduler/internal/biz" + "ai_scheduler/internal/biz/do" + "ai_scheduler/internal/biz/llm_service" + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/gateway" + "ai_scheduler/internal/pkg/dingtalk" + "ai_scheduler/internal/pkg/utils_ollama" + "ai_scheduler/internal/pkg/utils_vllm" + "ai_scheduler/internal/server" + "ai_scheduler/internal/services" + "ai_scheduler/internal/tools" + "ai_scheduler/internal/tools_bot" + "ai_scheduler/utils" + "os" + "testing" +) + +const + +func TestMain(m *testing.M) { + bootstrap := initialize.LoadConfigWithTest() + businessLogger := log2.NewBusinessLogger(bootstrap.Logs, id, Name, Version) + helper := pkg.NewLogHelper(businessLogger, bootstrap) + + db, cleanup := utils.NewGormDb(configConfig) + sysImpl := impl.NewSysImpl(db) + taskImpl := impl.NewTaskImpl(db) + chatHisImpl := impl.NewChatHisImpl(db) + doDo := do.NewDo(sysImpl, taskImpl, chatHisImpl, configConfig) + client, cleanup2, err := utils_ollama.NewClient(configConfig) + if err != nil { + cleanup() + return nil, nil, err + } + utils_vllmClient, cleanup3, err := utils_vllm.NewClient(configConfig) + if err != nil { + cleanup2() + cleanup() + return nil, nil, err + } + ollamaService := llm_service.NewOllamaGenerate(client, utils_vllmClient, configConfig, chatHisImpl) + manager := tools.NewManager(configConfig, client) + sessionImpl := impl.NewSessionImpl(db) + botTool := tools_bot.NewBotTool(configConfig, client, sessionImpl) + handle := do.NewHandle(ollamaService, manager, configConfig, sessionImpl, botTool) + aiRouterBiz := biz.NewAiRouterBiz(doDo, handle) + chatHistoryBiz := biz.NewChatHistoryBiz(chatHisImpl, taskImpl) + gatewayGateway := gateway.NewGateway() + chatService := services.NewChatService(aiRouterBiz, chatHistoryBiz, gatewayGateway, configConfig) + sessionBiz := biz.NewSessionBiz(configConfig, sessionImpl, sysImpl, chatHisImpl) + sessionService := services.NewSessionService(sessionBiz, chatHistoryBiz) + taskBiz := biz.NewTaskBiz(configConfig, taskImpl) + taskService := services.NewTaskService(sessionBiz, taskBiz) + oldClient := dingtalk.NewOldClient(configConfig) + contactClient, err := dingtalk.NewContactClient(configConfig) + if err != nil { + cleanup3() + cleanup2() + cleanup() + return nil, nil, err + } + notableClient, err := dingtalk.NewNotableClient(configConfig) + if err != nil { + cleanup3() + cleanup2() + cleanup() + return nil, nil, err + } + callbackService := services.NewCallbackService(configConfig, gatewayGateway, oldClient, contactClient, notableClient, botTool) + historyService := services.NewHistoryService(chatHistoryBiz) + app := server.NewHTTPServer(chatService, sessionService, taskService, gatewayGateway, callbackService, historyService) + botConfigImpl := impl.NewBotConfigImpl(db) + dingTalkBotBiz := biz.NewDingTalkBotBiz(doDo, handle, botConfigImpl) + dingBotService := services.NewDingBotService(configConfig, dingTalkBotBiz) + v := server.ProvideAllDingBotServices(dingBotService) + dingTalkBotServer := server.NewDingTalkBotServer(v) + servers := server.NewServers(configConfig, app, dingTalkBotServer) + code := m.Run() + os.Exit(code) +} + +func Test_GetUserInfo(t *testing.T) { + var c entitys.TaskConfig + config := `{"param": {"type": "object", "required": ["number"], "properties": {"number": {"type": "string", "description": "订单编号/流水号"}}}, "request": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}` + err := json.Unmarshal([]byte(config), &c) + t.Log(err) +} diff --git a/internal/data/impl/bot_dept.go b/internal/data/impl/bot_dept.go new file mode 100644 index 0000000..8ae0e4b --- /dev/null +++ b/internal/data/impl/bot_dept.go @@ -0,0 +1,17 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" +) + +type BotDeptImpl struct { + dataTemp.DataTemp +} + +func NewBotDeptImpl(db *utils.Db) *BotDeptImpl { + return &BotDeptImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotDept)), + } +} diff --git a/internal/data/impl/bot_user.go b/internal/data/impl/bot_user.go new file mode 100644 index 0000000..46d8442 --- /dev/null +++ b/internal/data/impl/bot_user.go @@ -0,0 +1,27 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "ai_scheduler/tmpl/dataTemp" + "ai_scheduler/utils" + "database/sql" +) + +type BotUserImpl struct { + dataTemp.DataTemp +} + +func NewBotUserImpl(db *utils.Db) *BotUserImpl { + return &BotUserImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiBotUser)), + } +} + +func (k BotUserImpl) GetByStaffId(staffId string) (data *model.AiBotUser, err error) { + + err = k.Db.Model(k.Model).Where("staff_id = ?", staffId).Find(data).Error + if data == nil { + err = sql.ErrNoRows + } + return +} diff --git a/internal/data/impl/provider_set.go b/internal/data/impl/provider_set.go index 53d389c..1563258 100644 --- a/internal/data/impl/provider_set.go +++ b/internal/data/impl/provider_set.go @@ -10,4 +10,7 @@ var ProviderImpl = wire.NewSet( NewTaskImpl, NewChatHisImpl, NewBotConfigImpl, + NewBotDeptImpl, + NewBotUserImpl, + NewBotChatHisImpl, ) diff --git a/internal/data/model/ai_bot_config.gen.go b/internal/data/model/ai_bot_config.gen.go index 5c4f27a..6885e81 100644 --- a/internal/data/model/ai_bot_config.gen.go +++ b/internal/data/model/ai_bot_config.gen.go @@ -14,9 +14,10 @@ const TableNameAiBotConfig = "ai_bot_config" type AiBotConfig struct { BotID int32 `gorm:"column:bot_id;primaryKey;autoIncrement:true" json:"bot_id"` SysID int32 `gorm:"column:sys_id;not null" json:"sys_id"` - BotType int32 `gorm:"column:bot_type;not null;default:1" json:"bot_type"` - BotName string `gorm:"column:bot_name;not null" json:"bot_name"` - BotConfig string `gorm:"column:bot_config;not null" json:"bot_config"` + BotType int32 `gorm:"column:bot_type;not null;default:1;comment:类型,1为钉钉机器人" json:"bot_type"` // 类型,1为钉钉机器人 + BotName string `gorm:"column:bot_name;not null;comment:名字" json:"bot_name"` // 名字 + BotConfig string `gorm:"column:bot_config;not null;comment:配置" json:"bot_config"` // 配置 + BotIndex string `gorm:"column:bot_index;not null;comment:索引" json:"bot_index"` // 索引 CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` Status int32 `gorm:"column:status;not null" json:"status"` diff --git a/internal/data/model/ai_bot_dept.gen.go b/internal/data/model/ai_bot_dept.gen.go new file mode 100644 index 0000000..db8ba60 --- /dev/null +++ b/internal/data/model/ai_bot_dept.gen.go @@ -0,0 +1,26 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiBotDept = "ai_bot_dept" + +// AiBotDept mapped from table +type AiBotDept struct { + DeptID int32 `gorm:"column:dept_id;primaryKey" json:"dept_id"` + DingtalkDeptID int32 `gorm:"column:dingtalk_dept_id;not null;comment:标记部门的唯一id,钉钉:钉钉侧提供的dept_id" json:"dingtalk_dept_id"` // 标记部门的唯一id,钉钉:钉钉侧提供的dept_id + Name string `gorm:"column:name;not null;comment:用户名称" json:"name"` // 用户名称 + Status int32 `gorm:"column:status;not null" json:"status"` + DeleteAt time.Time `gorm:"column:delete_at" json:"delete_at"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` +} + +// TableName AiBotDept's table name +func (*AiBotDept) TableName() string { + return TableNameAiBotDept +} diff --git a/internal/data/model/ai_bot_user.gen.go b/internal/data/model/ai_bot_user.gen.go new file mode 100644 index 0000000..5783e51 --- /dev/null +++ b/internal/data/model/ai_bot_user.gen.go @@ -0,0 +1,33 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameAiBotUser = "ai_bot_user" + +// AiBotUser mapped from table +type AiBotUser struct { + UserID int32 `gorm:"column:user_id;primaryKey" json:"user_id"` + StaffID string `gorm:"column:staff_id;not null;comment:标记用户用的唯一id,钉钉:钉钉侧提供的user_id" json:"staff_id"` // 标记用户用的唯一id,钉钉:钉钉侧提供的user_id + Name string `gorm:"column:name;not null;comment:用户名称" json:"name"` // 用户名称 + Title string `gorm:"column:title;not null;comment:职位" json:"title"` // 职位 + Extension string `gorm:"column:extension;not null;default:1;comment:信息面板" json:"extension"` // 信息面板 + RoleList string `gorm:"column:role_list;not null;comment:角色列表。" json:"role_list"` // 角色列表。 + DeptIDList string `gorm:"column:dept_id_list;not null;comment:所在部门id列表" json:"dept_id_list"` // 所在部门id列表 + IsBoss int32 `gorm:"column:is_boss;not null;comment:是否是老板" json:"is_boss"` // 是否是老板 + IsSenior int32 `gorm:"column:is_senior;not null;comment:是否是高管" json:"is_senior"` // 是否是高管 + HiredDate time.Time `gorm:"column:hired_date;not null;default:CURRENT_TIMESTAMP;comment:入职时间" json:"hired_date"` // 入职时间 + Status int32 `gorm:"column:status;not null" json:"status"` + DeleteAt *time.Time `gorm:"column:delete_at" json:"delete_at"` + CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` +} + +// TableName AiBotUser's table name +func (*AiBotUser) TableName() string { + return TableNameAiBotUser +} diff --git a/internal/entitys/bot.go b/internal/entitys/bot.go index 568822c..af92ba8 100644 --- a/internal/entitys/bot.go +++ b/internal/entitys/bot.go @@ -8,11 +8,8 @@ import ( ) type RequireDataDingTalkBot struct { - Session string - Key string - Sys model.AiSy Histories []model.AiChatHi - SessionInfo model.AiSession + UserInfo *DingTalkUserInfo Tasks []model.AiTask Match *Match Req *chatbot.BotCallbackDataModel @@ -24,7 +21,7 @@ type RequireDataDingTalkBot struct { } type DingTalkBot struct { - BotIndex string - ClientId string - ClientSecret string + BotIndex string `json:"bot_index"` + ClientId string `json:"client_id"` + ClientSecret string `json:"client_secret"` } diff --git a/internal/entitys/dingtalk.go b/internal/entitys/dingtalk.go new file mode 100644 index 0000000..3595bf3 --- /dev/null +++ b/internal/entitys/dingtalk.go @@ -0,0 +1,22 @@ +package entitys + +import ( + "ai_scheduler/internal/data/constants" + "time" +) + +type DingTalkUserInfo struct { + UserId int `json:"user_id"` + StaffId string `json:"staff_id"` + Name string `json:"name"` + Dept []*Dept `json:"dept"` + IsBoss constants.IsBoss `json:"is_boss"` + IsSenior constants.IsSenior `json:"is_senior"` + HiredDate time.Time `json:"hired_date"` + Extension string `json:"extension"` +} + +type Dept struct { + Name string `json:"name"` + DeptId int `json:"dept_id"` +} diff --git a/internal/entitys/recognize.go b/internal/entitys/recognize.go index 45d9fcc..3d68be7 100644 --- a/internal/entitys/recognize.go +++ b/internal/entitys/recognize.go @@ -1,8 +1,6 @@ package entitys -import ( - "ai_scheduler/internal/data/constants" -) +import "ai_scheduler/internal/data/constants" type Recognize struct { SystemPrompt string // 系统提示内容 @@ -28,7 +26,7 @@ type RecognizeUserContent struct { type FileData []byte type RecognizeFile struct { - File []FileData // 文件数据(二进制格式) - FileUrl string // 文件下载链接 + FileData FileData // 文件数据(二进制格式) FileType constants.Caller // 文件类型(文件类型,能填最好填,可以跳过一层判断) + FileUrl string // 文件下载链接 } diff --git a/internal/entitys/types.go b/internal/entitys/types.go index 72ecd6c..b8f8caa 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -30,7 +30,7 @@ type FirstSockRequest struct { type ChatSockRequest struct { Text string `json:"text" binding:"required"` - Img string `json:"img" binding:"required"` + Img string `json:"img" binding:"required"` // 多图片使用 英文, 分割 File string `json:"file" binding:"required"` Tags string `json:"tags" binding:"required"` MarkHis int64 `json:"mark_his" ` diff --git a/internal/pkg/func.go b/internal/pkg/func.go index f6006ac..57321bd 100644 --- a/internal/pkg/func.go +++ b/internal/pkg/func.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net/url" + "strconv" "strings" ) @@ -65,3 +66,70 @@ func HexEncode(src, dst []byte) int { } return len(src) * 2 } + +// Ter 三目运算 Ter(true, 1, 2) +func Ter[T any](cond bool, a, b T) T { + if cond { + return a + } + return b +} + +// StringToSlice [num,num]转slice +func StringToSlice(s string) ([]int, error) { + // 1. 去掉两端的方括号 + trimmed := strings.Trim(s, "[]") + + // 2. 按逗号分割 + parts := strings.Split(trimmed, ",") + + // 3. 转换为 []int + result := make([]int, 0, len(parts)) + for _, part := range parts { + num, err := strconv.Atoi(strings.TrimSpace(part)) + if err != nil { + return nil, err + } + result = append(result, num) + } + return result, nil +} + +// Difference 差集 +func Difference[T comparable](a, b []T) []T { + // 创建 b 的映射(T 必须是可比较的类型) + bMap := make(map[T]struct{}, len(b)) + for _, item := range b { + bMap[item] = struct{}{} + } + + var diff []T // 修正为 []T 而非 []int + for _, item := range a { + if _, found := bMap[item]; !found { + diff = append(diff, item) + } + } + return diff +} + +// SliceStringToInt []string=>[]int +func SliceStringToInt(strSlice []string) []int { + numSlice := make([]int, len(strSlice)) + for i, str := range strSlice { + num, err := strconv.Atoi(str) + if err != nil { + return nil + } + numSlice[i] = num + } + return numSlice +} + +// SliceIntToString []int=>[]string +func SliceIntToString(slice []int) []string { + strSlice := make([]string, len(slice)) // len=cap=len(slice) + for i, num := range slice { + strSlice[i] = strconv.Itoa(num) // 直接赋值,无 append + } + return strSlice +} diff --git a/internal/server/ding_talk_bot.go b/internal/server/ding_talk_bot.go index 9a63812..f68f543 100644 --- a/internal/server/ding_talk_bot.go +++ b/internal/server/ding_talk_bot.go @@ -62,10 +62,10 @@ func (d *DingTalkBotServer) Run(ctx context.Context, botIndex string) { } err := cli.Start(ctx) if err != nil { - log.Info("%s启动失败", name) + log.Infof("%s启动失败", name) continue } - log.Info("%s启动成功", name) + log.Infof("%s启动成功", name) } } func DingBotServerInit(clientId string, clientSecret string, service DingBotServiceInterface) (cli *client.StreamClient) { diff --git a/internal/services/dtalk_bot.go b/internal/services/dtalk_bot.go index 864f5d7..e4e9e8d 100644 --- a/internal/services/dtalk_bot.go +++ b/internal/services/dtalk_bot.go @@ -31,7 +31,7 @@ func (d *DingBotService) OnChatBotMessageReceived(ctx context.Context, data *cha return } go func() { - defer close(requireData.Ch) + //defer close(requireData.Ch) //if match, _err := d.dingTalkBotBiz.Recognize(ctx, data); _err != nil { // requireData.Ch <- entitys.Response{ // Type: entitys.ResponseEnd, diff --git a/internal/tools/calculator.go b/internal/tools/calculator.go deleted file mode 100644 index fecde64..0000000 --- a/internal/tools/calculator.go +++ /dev/null @@ -1,121 +0,0 @@ -package tools - -import ( - "ai_scheduler/internal/entitys" - - "context" - "encoding/json" - "fmt" - "math" -) - -// CalculatorTool 计算器工具 -type CalculatorTool struct{} - -// NewCalculatorTool 创建计算器工具 -func NewCalculatorTool() *CalculatorTool { - return &CalculatorTool{} -} - -// Name 返回工具名称 -func (c *CalculatorTool) Name() string { - return "calculate" -} - -// Description 返回工具描述 -func (c *CalculatorTool) Description() string { - return "执行基本的数学运算,支持加减乘除和幂运算" -} - -// Definition 返回工具定义 -func (c *CalculatorTool) Definition() entitys.ToolDefinition { - return entitys.ToolDefinition{ - Type: "function", - Function: entitys.FunctionDef{ - Name: c.Name(), - Description: c.Description(), - Parameters: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "operation": map[string]interface{}{ - "type": "string", - "description": "运算类型", - "enum": []string{"add", "subtract", "multiply", "divide", "power"}, - }, - "a": map[string]interface{}{ - "type": "number", - "description": "第一个数字", - }, - "b": map[string]interface{}{ - "type": "number", - "description": "第二个数字", - }, - }, - "required": []string{"operation", "a", "b"}, - }, - }, - } -} - -// CalculateRequest 计算请求参数 -type CalculateRequest struct { - Operation string `json:"operation"` - A float64 `json:"a"` - B float64 `json:"b"` -} - -// CalculateResponse 计算响应 -type CalculateResponse struct { - Operation string `json:"operation"` - A float64 `json:"a"` - B float64 `json:"b"` - Result float64 `json:"result"` - Expression string `json:"expression"` -} - -// Execute 执行计算 -func (c *CalculatorTool) Execute(ctx context.Context, args json.RawMessage) (interface{}, error) { - var req CalculateRequest - if err := json.Unmarshal(args, &req); err != nil { - return nil, fmt.Errorf("invalid calculate request: %w", err) - } - - var result float64 - var expression string - - switch req.Operation { - case "add": - result = req.A + req.B - expression = fmt.Sprintf("%.2f + %.2f = %.2f", req.A, req.B, result) - case "subtract": - result = req.A - req.B - expression = fmt.Sprintf("%.2f - %.2f = %.2f", req.A, req.B, result) - case "multiply": - result = req.A * req.B - expression = fmt.Sprintf("%.2f × %.2f = %.2f", req.A, req.B, result) - case "divide": - if req.B == 0 { - return nil, fmt.Errorf("division by zero is not allowed") - } - result = req.A / req.B - expression = fmt.Sprintf("%.2f ÷ %.2f = %.2f", req.A, req.B, result) - case "power": - result = math.Pow(req.A, req.B) - expression = fmt.Sprintf("%.2f ^ %.2f = %.2f", req.A, req.B, result) - default: - return nil, fmt.Errorf("unsupported operation: %s", req.Operation) - } - - // 检查结果是否有效 - if math.IsNaN(result) || math.IsInf(result, 0) { - return nil, fmt.Errorf("calculation resulted in invalid number") - } - - return &CalculateResponse{ - Operation: req.Operation, - A: req.A, - B: req.B, - Result: result, - Expression: expression, - }, nil -} diff --git a/internal/tools/manager.go b/internal/tools/manager.go index 406d11c..cd1940d 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -4,6 +4,7 @@ import ( "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg/utils_ollama" + "ai_scheduler/internal/tools/public" zltxtool "ai_scheduler/internal/tools/zltx" "context" @@ -43,30 +44,30 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager { // 注册直连天下订单详情工具 if config.Tools.ZltxOrderDetail.Enabled { - zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm) + zltxOrderDetailTool := zltxtool.NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm) m.tools[zltxOrderDetailTool.Name()] = zltxOrderDetailTool } //注册直连天下订单日志工具 if config.Tools.ZltxOrderDirectLog.Enabled { - zltxOrderLogTool := NewZltxOrderLogTool(config.Tools.ZltxOrderDirectLog) + zltxOrderLogTool := zltxtool.NewZltxOrderLogTool(config.Tools.ZltxOrderDirectLog) m.tools[zltxOrderLogTool.Name()] = zltxOrderLogTool } //注册直连天下商品工具 if config.Tools.ZltxProduct.Enabled { - zltxProductTool := NewZltxProductTool(config.Tools.ZltxProduct) + zltxProductTool := zltxtool.NewZltxProductTool(config.Tools.ZltxProduct) m.tools[zltxProductTool.Name()] = zltxProductTool } //注册直连天下订单统计工具 if config.Tools.ZltxOrderStatistics.Enabled { - zltxOrderStatisticsTool := NewZltxOrderStatisticsTool(config.Tools.ZltxOrderStatistics) + zltxOrderStatisticsTool := zltxtool.NewZltxOrderStatisticsTool(config.Tools.ZltxOrderStatistics) m.tools[zltxOrderStatisticsTool.Name()] = zltxOrderStatisticsTool } // 注册知识库工具 if config.Tools.Knowledge.Enabled { - knowledgeTool := NewKnowledgeBaseTool(config.Tools.Knowledge) + knowledgeTool := public.NewKnowledgeBaseTool(config.Tools.Knowledge) m.tools[knowledgeTool.Name()] = knowledgeTool } @@ -85,9 +86,24 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager { zltxOrderAfterSaleResellerBatchTool := zltxtool.NewOrderAfterSaleResellerBatchTool(config.Tools.ZltxOrderAfterSaleResellerBatch) m.tools[zltxOrderAfterSaleResellerBatchTool.Name()] = zltxOrderAfterSaleResellerBatchTool } + // 注册天气工具 + if config.Tools.Weather.Enabled { + weatherTool := public.NewWeatherTool(config.Tools.Weather) + m.tools[weatherTool.Name()] = weatherTool + } + // 注册 Coze 快递查询工具 + if config.Tools.CozeExpress.Enabled { + cozeTool := public.NewCozeExpress(config.Tools.CozeExpress, m.llm) + m.tools[cozeTool.Name()] = cozeTool + } + // 注册 Coze 公司查询工具 + if config.Tools.CozeCompany.Enabled { + cozeTool := public.NewCozeCompany(config.Tools.CozeCompany, m.llm) + m.tools[cozeTool.Name()] = cozeTool + } // 普通对话 - chat := NewNormalChatTool(m.llm, config) + chat := public.NewNormalChatTool(m.llm, config) m.tools[chat.Name()] = chat return m diff --git a/internal/tools/public/coze_company.go b/internal/tools/public/coze_company.go new file mode 100644 index 0000000..3f10638 --- /dev/null +++ b/internal/tools/public/coze_company.go @@ -0,0 +1,264 @@ +package public + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/utils_ollama" + "context" + "encoding/json" + "fmt" + "github.com/ollama/ollama/api" + "net/http" + "time" + + "github.com/coze-dev/coze-go" +) + +type CozeCompany struct { + cozeApi coze.CozeAPI + config config.ToolConfig + llm *utils_ollama.Client +} + +// NewCoze 创建 Coze 实例 +func NewCozeCompany(config config.ToolConfig, llm *utils_ollama.Client) *CozeCompany { + return &CozeCompany{ + cozeApi: newCozeApi(config), + config: config, + llm: llm, + } +} + +// newCozeClient 创建 Coze 客户端 +func newCozeApi(config config.ToolConfig) coze.CozeAPI { + authCli := coze.NewTokenAuth(config.APISecret) + cozeApi := coze.NewCozeAPI(authCli, coze.WithBaseURL(config.BaseURL), coze.WithHttpClient(&http.Client{ + Timeout: time.Second * 120, + })) + return cozeApi +} + +// Name 返回工具名称 +func (c *CozeCompany) Name() string { + return "coze_company" +} + +// Description 返回工具描述 +func (c *CozeCompany) Description() string { + return "查询企业信息" +} + +// Definition 返回工具定义 +func (c *CozeCompany) Definition() entitys.ToolDefinition { + return entitys.ToolDefinition{ + Type: "function", + Function: entitys.FunctionDef{ + Name: c.Name(), + Description: c.Description(), + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "company_name": map[string]interface{}{ + "type": "string", + "description": "企业名称", + }, + }, + "required": []string{"company_name"}, + }, + }, + } +} + +// Execute 执行查询 +func (c *CozeCompany) Execute(ctx context.Context, requireData *entitys.RequireData) error { + var req map[string]interface{} + if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + return fmt.Errorf("invalid express request: %w", err) + } + + if req["company_name"] == "" { + return fmt.Errorf("company_name is required") + } + + // 调用 Coze 工作流 + rsp, err := c.callWorkflow(ctx, req) + if err != nil { + return fmt.Errorf("failed to get real weather: %w", err) + } + + companyInfo := CompanyInfo{} + err = json.Unmarshal([]byte(rsp.Data), &companyInfo) + if err != nil { + return fmt.Errorf("failed to unmarshal company info: %w", err) + } + + // 调用 LLM 模型 + err = c.llm.ChatStream(ctx, requireData.Ch, []api.Message{ + { + Role: "system", + Content: `# Role: 企业信息分析与经营诊断专家: +请基于以下12项指定数据字段(无需补充未提供的信息),完成目标企业的全维度分析总结,要求每部分结论必须100%锚定对应数据,拒绝主观推测,突出“风险可见性”与“关键信息关联性”: +一、输入数据清单(需逐一对应分析) +行政处罚:公司是否有行政处罚(含处罚事由、处罚机关、处罚日期、处罚文号) +清算信息:公司的清算信息(含清算原因、清算组构成、清算进展状态) +变更记录:公司的变更记录(含变更事项:注册资本/股东/法定代表人/经营范围、变更时间、变更前后内容) +主要成员:公司名搜索公司的主要成员(含姓名、职位、任职时间、核心履历关键词) +企业详情:公司名称,搜索公司详细信息(含成立时间、注册资本、经营范围、行业分类、注册地址) +经营异常:公司是否有经营异常(含列入原因、列入日期、移出状态) +破产重组:公司破产重组的信息(含申请法院、受理时间、重组方案核心内容) +严重违法:是否有严重违法信息(含违法事由、认定机关、公示期限) +司法信息:司法信息(含案件类型、原被告身份、判决结果/进展) +被执行信息:公司的被执行信息(含执行法院、执行标的、未履行金额、失信状态) +企业手机号:公司名称查询企业手机号(需标注是否为公开备案号) +股东信息:公司对应的股东(含股东名称、出资额、出资比例、股东类型:自然人/企业/机构) +二、分析总结框架(严格对应数据) +1. 企业基础画像(锚定公司名称,搜索公司详细信息) +核心属性:成立时间、注册资本(实缴/认缴)、行业分类(如“批发和零售业”“软件和信息技术服务业”)、主营业务(从经营范围中提炼1-2个核心赛道,如“专注于智能仓储设备研发与销售”); +注册地址:是否与主要经营地一致(若公司名称,搜索公司详细信息有披露)。 +2. 股权结构与股东特征(锚定公司对应的股东) +股权集中度:前三大股东出资占比之和(如“第一大股东持股60%,为绝对控股股东”); +股东类型分布:自然人股东、企业股东、机构股东的占比(如“70%为企业股东,含1家行业头部企业”); +关键股东亮点:若有知名企业/机构股东,需明确标注(如“股东含XX产业投资基金,具备产业链资源协同潜力”)。 +3. 经营稳定性与合规风险(锚定公司是否有行政处罚/公司是否有经营异常/是否有严重违法信息/公司的清算信息/公司破产重组的信息) +行政处罚风险:是否存在行政处罚? +若有:列清「处罚事由(如“虚假宣传”“税务逾期申报”)、处罚机关、处罚日期」,并判断“是否属于高频违规”(如1年内≥2次同类处罚→高风险); +若无:标注“无公开行政处罚记录”。 +经营异常风险:是否存在经营异常? +若有:说明「列入原因(如“通过登记的住所无法联系”“未公示年度报告”)、是否已移出」,并评估对业务的影响(如“地址失联可能导致客户信任下降”); +若无:标注“无经营异常记录”。 +严重违法红线:是否存在严重违法? +若有:明确「违法事由(如“拒不履行生效法律文书”“欺诈消费者”)、认定机关、公示期限」,并标注“触及监管红线,需重点核查整改情况”; +若无:标注“无严重违法记录”。 +极端状态预警:是否存在清算或破产重组? +若有清算:说明「清算原因(如“股东会决议解散”“营业期限届满”)、清算进展(如“清算组已完成资产清查”)」; +若有破产重组:说明「申请法院、受理时间、重组方案核心内容(如“拟引入战略投资者注资5000万”)」; +若无:标注“无清算或破产重组记录”。 +4. 司法与执行风险(锚定司法信息/公司的被执行信息) +涉诉情况:司法案件的核心特征(如“80%为买卖合同纠纷,20%为劳动争议仲裁”;“作为被告的案件占比75%”;“判决胜诉率约60%”); +被执行压力:是否存在被执行信息? +若有:列清「执行法院、执行标的金额、未履行金额、是否纳入失信被执行人名单」,并计算“未履行金额占注册资本的比例”(如“未履行800万,占注册资本的20%”); +若无:标注“无公开被执行记录”。 +5. 管理团队与联系信息(锚定公司名搜索公司的主要成员/公司名称查询企业手机号) +核心团队稳定性:主要成员的任职时间分布(如“CEO任职4年,CFO任职2年,核心团队平均任职3年”);是否有频繁变动(如“1年内2位高管离职”→需提示“管理稳定性风险”); +关键岗位资质:核心成员(如法定代表人、CEO、CFO)的过往履历亮点(如“CEO曾任职XX上市公司,主导过亿元级项目落地”); +联系信息可信度:企业手机号是否为公开备案(如“与工商登记预留电话一致→可信度高”;“为非备案号→需提示‘联系信息真实性存疑’”)。 +6. 综合结论与行动建议(整合所有数据) +整体风险评级:基于数据密度,给出“低风险/中风险/高风险”定性(示例:“无行政处罚、无被执行、仅1条普通合同纠纷→低风险”;“有失信被执行+严重违法→高风险”); +Top3核心风险:按“严重违法>被执行>破产重组>行政处罚>经营异常>司法纠纷”排序,列出最需关注的3个问题(如“1. 未履行金额占注册资本20%,存在债务违约风险;2. 1年内2次地址异常,经营稳定性弱;3. 自然人股东占比过高,决策易受个人因素影响”); + actionable 建议:针对每个核心风险,给出可执行的核查/应对动作(如“核查未履行金额对应的案件进展,评估企业偿债能力;要求企业提供近1年的地址证明,确认经营场所稳定性;穿透核查自然人股东的资产状况,降低决策风险”)。 +三、输出规则(强制遵守) +数据溯源:每句结论必须标注对应数据字段(如“根据公司是否有行政处罚,企业2023年因‘税务逾期申报’被区税务局处罚”); +量化优先:拒绝模糊表述(如不说“很多案件”,要说“近1年涉及5起买卖合同纠纷”); +风险分级:用“★”标注风险等级(★越多越严重,如“严重违法★★★”“被执行★★”“经营异常★”); +语言风格:专业简洁,避免冗余,适合风控/投资/合作前的快速决策阅读。 + +【示例输出片段】 +3. 经营稳定性与合规风险 + +行政处罚:根据公司是否有行政处罚,企业2022年11月因“发布虚假广告”被区市场监管局处以3万元罚款(文号:X市监罚字〔2022〕456号),无后续同类处罚→风险等级★; + +经营异常:根据公司是否有经营异常,企业2023年6月因“通过登记的住所无法联系”被列入异常名录,2023年12月已移出→风险等级★; + +严重违法:根据是否有严重违法信息,无公开严重违法记录→风险等级无; + +极端状态:根据公司的清算信息/公司破产重组的信息,无清算或破产重组记录→风险等级无。 + +6. 综合结论与行动建议 + +整体评级:低风险(仅1次轻微行政处罚,无重大合规瑕疵); + +Top3核心风险:1. 根据司法信息,近1年作为被告的合同纠纷占比75%,需警惕应收账款回收风险★★;2. 根据公司名搜索公司的主要成员,1年内2位销售总监离职,管理稳定性弱★;3. 根据公司名称查询企业手机号,企业手机号为非备案号,联系信息真实性存疑★; + +建议:核查合同纠纷案件的原告身份及回款情况,评估坏账概率;要求企业提供离职人员的交接说明,确认业务连续性;索要企业备案的联系方式,验证沟通有效性。 +`, + }, + + { + Role: "assistant", + Content: fmt.Sprintf(`请分析企业:%s, +公司是否有行政处罚:%v, +公司的清算信息:%v, +公司的变更记录:%v, +公司名搜索公司的主要成员:%v, +公司名称,搜索公司详细信息:%v, +公司是否有经营异常:%v, +公司破产重组的信息:%v, +是否有严重违法信息:%v, +司法信息:%v, +公司的被执行信息:%v, +公司名称查询企业手机号:%v, +公司对应的股东:%v`, req["company_name"], companyInfo.Xzcf, companyInfo.Clears, companyInfo.Changes, companyInfo.Employees, companyInfo.Searchdata, companyInfo.Operations, companyInfo.BankruptcyPublicList, companyInfo.Illegals, companyInfo.JudicialList, companyInfo.Executes, companyInfo.Phone, companyInfo.Partners), + }, + { + Role: "user", + Content: requireData.Req.Text, + }, + }, + c.Name(), "") + if err != nil { + return fmt.Errorf("failed to get express info: %w", err) + } + + //entitys.ResText(requireData.Ch, "", rsp.Data) + + return nil +} + +// CallWorkflow 调用 Coze 工作流 +// 参数: +// - ctx: 上下文,用于控制超时和取消 +// - workflowId: 工作流 ID +// - params: 工作流参数 +// 返回: +// - interface{}: 工作流执行结果 +// - error: 错误信息 +func (c *CozeCompany) callWorkflow(ctx context.Context, params map[string]interface{}) (*coze.RunWorkflowsResp, error) { + // 准备工作流请求参数 + workflowReq := &coze.RunWorkflowsReq{ + WorkflowID: c.config.APIKey, + Parameters: params, + } + + // 调用工作流 + resp, err := c.cozeApi.Workflows.Runs.Create(ctx, workflowReq) + if err != nil { + return nil, fmt.Errorf("工作流调用失败: %w", err) + } + + // 处理工作流响应 + if resp == nil { + return nil, fmt.Errorf("工作流响应为空") + } + + // 返回工作流执行结果 + return resp, nil +} + +type CompanyInfo struct { + BankruptcyPublicList interface{} `json:"bankruptcy_public_list"` // 破产公示列表 + Changes interface{} `json:"changes"` // 变更记录 + Clears interface{} `json:"clears"` // 清算记录 + Employees interface{} `json:"employees"` // 员工列表 + Executes interface{} `json:"executes"` // 执行记录 + Illegals interface{} `json:"illegals"` // 违法记录 + JudicialList interface{} `json:"judicial_list"` // 司法记录 + Operations interface{} `json:"operations"` // 经营记录 + Partners interface{} `json:"partners"` // 合伙人列表 + Phone string `json:"phone"` // 联系电话 + // 搜索数据 + Searchdata struct { + Authority interface{} `json:"authority"` + BusinessScope interface{} `json:"business_scope"` + Capital interface{} `json:"capital"` + CompanyAddress interface{} `json:"company_address"` + CompanyName string `json:"company_name"` + CompanyStatus interface{} `json:"company_status"` + CompanyType interface{} `json:"company_type"` + CreditNo interface{} `json:"credit_no"` + EstablishDate interface{} `json:"establish_date"` + Industry interface{} `json:"industry"` + LegalPerson interface{} `json:"legal_person"` + Province interface{} `json:"province"` + } `json:"searchdata"` + Xzcf interface{} `json:"xzcf"` // 行政处罚 +} diff --git a/internal/tools/public/coze_express.go b/internal/tools/public/coze_express.go new file mode 100644 index 0000000..de316b5 --- /dev/null +++ b/internal/tools/public/coze_express.go @@ -0,0 +1,140 @@ +package public + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/utils_ollama" + "context" + "encoding/json" + "fmt" + "github.com/ollama/ollama/api" + + "github.com/coze-dev/coze-go" +) + +type CozeExpress struct { + cozeApi coze.CozeAPI + config config.ToolConfig + llm *utils_ollama.Client +} + +// NewCozeExpress 创建 CozeExpress 实例 +func NewCozeExpress(config config.ToolConfig, llm *utils_ollama.Client) *CozeExpress { + return &CozeExpress{ + cozeApi: newCozeApi(config), + config: config, + llm: llm, + } +} + +// newCozeExpressClient 创建 CozeExpress 客户端 +func newCozeExpressApi(config config.ToolConfig) coze.CozeAPI { + authCli := coze.NewTokenAuth(config.APISecret) + cozeApi := coze.NewCozeAPI(authCli, coze.WithBaseURL(config.BaseURL)) + return cozeApi +} + +// Name 返回工具名称 +func (c *CozeExpress) Name() string { + return "coze_express" +} + +// Description 返回工具描述 +func (c *CozeExpress) Description() string { + return "查询快递物流信息" +} + +// Definition 返回工具定义 +func (c *CozeExpress) Definition() entitys.ToolDefinition { + return entitys.ToolDefinition{ + Type: "function", + Function: entitys.FunctionDef{ + Name: c.Name(), + Description: c.Description(), + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "express_id": map[string]interface{}{ + "type": "string", + "description": "快递单号", + }, + }, + "required": []string{"express_id"}, + }, + }, + } +} + +// Execute 执行查询 +func (c *CozeExpress) Execute(ctx context.Context, requireData *entitys.RequireData) error { + var req map[string]interface{} + if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + return fmt.Errorf("invalid express request: %w", err) + } + + if req["express_id"] == "" { + return fmt.Errorf("express_id is required") + } + + // 调用 Coze 工作流查询快递物流信息 + rsp, err := c.callWorkflow(ctx, req) + if err != nil { + return fmt.Errorf("failed to get real weather: %w", err) + } + err = c.llm.ChatStream(ctx, requireData.Ch, []api.Message{ + { + Role: "system", + Content: "你是一个快递查询助手。用户可能会提供快递单号,你需要分析快递单号,根据快递单号查询物流信息并反馈给我", + }, + { + Role: "assistant", + Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(requireData.Histories)), + }, + { + Role: "assistant", + Content: fmt.Sprintf("需要分析的快递单号:%s", rsp.Data), + }, + { + Role: "user", + Content: requireData.Req.Text, + }, + }, c.Name(), "") + if err != nil { + return fmt.Errorf("failed to get express info: %w", err) + } + + //entitys.ResText(requireData.Ch, "", rsp.Data) + + return nil +} + +// CallWorkflow 调用 Coze 工作流 +// 参数: +// - ctx: 上下文,用于控制超时和取消 +// - workflowId: 工作流 ID +// - params: 工作流参数 +// 返回: +// - interface{}: 工作流执行结果 +// - error: 错误信息 +func (c *CozeExpress) callWorkflow(ctx context.Context, params map[string]interface{}) (*coze.RunWorkflowsResp, error) { + // 准备工作流请求参数 + workflowReq := &coze.RunWorkflowsReq{ + WorkflowID: c.config.APIKey, + Parameters: params, + } + + // 调用工作流 + resp, err := c.cozeApi.Workflows.Runs.Create(ctx, workflowReq) + if err != nil { + return nil, fmt.Errorf("工作流调用失败: %w", err) + } + + // 处理工作流响应 + if resp == nil { + return nil, fmt.Errorf("工作流响应为空") + } + + // 返回工作流执行结果 + return resp, nil +} diff --git a/internal/tools/konwledge_base.go b/internal/tools/public/konwledge_base.go similarity index 99% rename from internal/tools/konwledge_base.go rename to internal/tools/public/konwledge_base.go index dbbb5ce..505a349 100644 --- a/internal/tools/konwledge_base.go +++ b/internal/tools/public/konwledge_base.go @@ -1,4 +1,4 @@ -package tools +package public import ( "ai_scheduler/internal/config" diff --git a/internal/tools/konwledge_base_test.go b/internal/tools/public/konwledge_base_test.go similarity index 97% rename from internal/tools/konwledge_base_test.go rename to internal/tools/public/konwledge_base_test.go index de1ab3f..0838587 100644 --- a/internal/tools/konwledge_base_test.go +++ b/internal/tools/public/konwledge_base_test.go @@ -1,4 +1,4 @@ -package tools +package public import ( "testing" diff --git a/internal/tools/normal_chat.go b/internal/tools/public/normal_chat.go similarity index 99% rename from internal/tools/normal_chat.go rename to internal/tools/public/normal_chat.go index 56010fc..a4c96e0 100644 --- a/internal/tools/normal_chat.go +++ b/internal/tools/public/normal_chat.go @@ -1,4 +1,4 @@ -package tools +package public import ( "ai_scheduler/internal/config" diff --git a/internal/tools/public/weather.go b/internal/tools/public/weather.go new file mode 100644 index 0000000..0292fa4 --- /dev/null +++ b/internal/tools/public/weather.go @@ -0,0 +1,345 @@ +package public + +import ( + "ai_scheduler/internal/config" + "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/l_request" + "log" + "strconv" + + "context" + "encoding/json" + "fmt" + "time" +) + +// WeatherTool 天气查询工具 +type WeatherTool struct { + mockData bool + config config.ToolConfig +} + +// NewWeatherTool 创建天气工具 +func NewWeatherTool(config config.ToolConfig) *WeatherTool { + return &WeatherTool{config: config} +} + +// Name 返回工具名称 +func (w *WeatherTool) Name() string { + return "get_weather" +} + +// Description 返回工具描述 +func (w *WeatherTool) Description() string { + return "获取指定城市的天气信息" +} + +// Definition 返回工具定义 +func (w *WeatherTool) Definition() entitys.ToolDefinition { + return entitys.ToolDefinition{ + Type: "function", + Function: entitys.FunctionDef{ + Name: w.Name(), + Description: w.Description(), + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{ + "type": "string", + "description": "城市名称,如:北京、上海、广州", + }, + "unit": map[string]interface{}{ + "type": "string", + "description": "温度单位,celsius(摄氏度)或fahrenheit(华氏度)", + "enum": []string{"celsius", "fahrenheit"}, + "default": "celsius", + }, + "extensions": map[string]interface{}{ + "type": "string", + "description": "扩展参数,base/all base:返回实况天气 all:返回预报天气", + "enum": []string{"base", "all"}, + "default": "base", + }, + }, + "required": []string{"city"}, + }, + }, + } +} + +// WeatherRequest 天气请求参数 +type WeatherRequest struct { + City string `json:"city"` + Extensions string `json:"extensions"` // 扩展参数,base/all base:返回实况天气 all:返回预报天气 + Unit string `json:"unit,omitempty"` +} + +// WeatherResponse 天气响应 +type WeatherResponse struct { + City string `json:"city"` + Unit string `json:"unit"` + Timestamp string `json:"timestamp"` + LiveWeather *LiveWeather `json:"live_weather,omitempty"` // 实时天气 + Forecasts []ForecastWeather `json:"forecasts,omitempty"` // 预报天气 +} + +// ForecastWeather 预报天气 +type ForecastWeather struct { + Date string `json:"date"` + Week string `json:"week"` + DayWeather string `json:"day_weather"` + NightWeather string `json:"night_weather"` + DayTemp float64 `json:"day_temp"` + NightTemp float64 `json:"night_temp"` + DayWind string `json:"day_wind"` + NightWind string `json:"night_wind"` + DayWindPower string `json:"day_wind_power"` + NightWindPower string `json:"night_wind_power"` +} + +// LiveWeather 实时天气 +type LiveWeather struct { + Temperature float64 `json:"temperature"` + Condition string `json:"condition"` + Humidity int `json:"humidity"` + WindSpeed float64 `json:"wind_speed"` + WindDirection string `json:"wind_direction"` +} + +// Execute 执行天气查询 +func (w *WeatherTool) Execute(ctx context.Context, requireData *entitys.RequireData) error { + var req WeatherRequest + if err := json.Unmarshal([]byte(requireData.Match.Parameters), &req); err != nil { + return fmt.Errorf("invalid weather request: %w", err) + } + + if req.City == "" { + return fmt.Errorf("city is required") + } + + if req.Unit == "" { + req.Unit = "celsius" + } + + // 设置默认获取实时天气信息 + if req.Extensions == "" { + req.Extensions = "base" + } + + // 这里可以集成真实的天气API + responseMsg, err := w.getRealWeather(req) + if err != nil { + return fmt.Errorf("failed to get real weather: %w", err) + } + + // 根据 extensions 参数返回不同的天气信息 + if req.Extensions == "base" { + entitys.ResText(requireData.Ch, "", fmt.Sprintf("%s实时天气:%s,温度:%.1f℃,湿度:%d%%,风速:%.1fkm/h,风向:%s", + req.City, + responseMsg.LiveWeather.Condition, + responseMsg.LiveWeather.Temperature, + responseMsg.LiveWeather.Humidity, + responseMsg.LiveWeather.WindSpeed, + responseMsg.LiveWeather.WindDirection)) + } else { + + rspStr := fmt.Sprintf("%s天气预报:\n", req.City) + for _, forecast := range responseMsg.Forecasts { + rspStr += fmt.Sprintf("%s 温度:%.1f℃/%.1f℃ 风力:%s %s\n", + forecast.Date, forecast.DayTemp, forecast.NightTemp, forecast.DayWind, forecast.NightWind) + } + + entitys.ResText(requireData.Ch, "", rspStr) + } + return nil +} + +//// getMockWeather 获取模拟天气数据 +//func (w *WeatherTool) getMockWeather(city, unit string) *WeatherResponse { +// rand.Seed(time.Now().UnixNano()) +// +// // 模拟不同城市的基础温度 +// baseTemp := map[string]float64{ +// "北京": 15.0, +// "上海": 18.0, +// "广州": 25.0, +// "深圳": 26.0, +// "杭州": 17.0, +// "成都": 16.0, +// } +// +// temp := baseTemp[city] +// if temp == 0 { +// temp = 20.0 // 默认温度 +// } +// +// // 添加随机变化 +// temp += (rand.Float64() - 0.5) * 10 +// +// // 转换温度单位 +// if unit == "fahrenheit" { +// temp = temp*9/5 + 32 +// } +// +// conditions := []string{"晴朗", "多云", "阴天", "小雨", "中雨"} +// condition := conditions[rand.Intn(len(conditions))] +// +// return &WeatherResponse{ +// City: city, +// Temperature: float64(int(temp*10)) / 10, // 保留一位小数 +// Unit: unit, +// Condition: condition, +// Humidity: rand.Intn(40) + 40, // 40-80% +// WindSpeed: float64(rand.Intn(20)) + 1.0, +// Timestamp: time.Now().Format("2006-01-02 15:04:05"), +// } +//} + +// getRealWeather 调用高德天气API +func (w *WeatherTool) getRealWeather(request WeatherRequest) (*WeatherResponse, error) { + // 构建请求URL + req := l_request.Request{ + Url: w.config.BaseURL, + Headers: map[string]string{}, + Params: map[string]string{ + "city": request.City, // 城市名称 + "key": w.config.APIKey, // API密钥 + // extensions: 基础天气数据 可选值:base/all base:返回实况天气 all:返回预报天气 + "extensions": request.Extensions, // 基础天气数据 + "output": "JSON", // JSON格式返回 + }, + Method: "GET", + } + res, err := req.Send() + if err != nil { + return nil, err + } + + // 解析API响应 + var apiResp struct { + Status string `json:"status"` + Count string `json:"count"` + Info string `json:"info"` + Infocode string `json:"infocode"` + + // 预报天气信息数据 + Forecasts []struct { + City string `json:"city"` + Adcode string `json:"adcode"` + Province string `json:"province"` + Reporttime string `json:"reporttime"` + Casts []struct { + Date string `json:"date"` + Week string `json:"week"` + Dayweather string `json:"dayweather"` + Nightweather string `json:"nightweather"` + Daytemp string `json:"daytemp"` + Nighttemp string `json:"nighttemp"` + Daywind string `json:"daywind"` + Nightwind string `json:"nightwind"` + Daypower string `json:"daypower"` + Nightpower string `json:"nightpower"` + DaytempFloat string `json:"daytemp_float"` + NighttempFloat string `json:"nighttemp_float"` + } `json:"casts"` + } `json:"forecasts"` + // 实况天气信息数据 + Lives []struct { + Province string `json:"province"` + City string `json:"city"` + Adcode string `json:"adcode"` + Weather string `json:"weather"` + Temperature string `json:"temperature"` + Winddirection string `json:"winddirection"` + Windpower string `json:"windpower"` + Humidity string `json:"humidity"` + Reporttime string `json:"reporttime"` + TemperatureFloat string `json:"temperature_float"` + HumidityFloat string `json:"humidity_float"` + } `json:"lives"` + } + + log.Printf("weather API response: %s", string(res.Content)) + + if err = json.Unmarshal(res.Content, &apiResp); err != nil { + return nil, fmt.Errorf("parse weather API response failed: %w", err) + } + + // 检查API返回状态 + if apiResp.Status != "1" { + return nil, fmt.Errorf("weather API returned error: %s, info: %s", apiResp.Status, apiResp.Info) + } + + // 获取城市名称 + cityName := "" + if len(apiResp.Lives) > 0 { + cityName = apiResp.Lives[0].City + } else if len(apiResp.Forecasts) > 0 { + cityName = apiResp.Forecasts[0].City + } else { + return nil, fmt.Errorf("no weather data found") + } + + // 构建响应 + response := &WeatherResponse{ + City: cityName, + Unit: request.Unit, + Timestamp: time.Now().Format("2006-01-02 15:04:05"), + } + // 处理实时天气 + if len(apiResp.Lives) > 0 { + liveData := apiResp.Lives[0] + + // 转换温度 + temp, _ := strconv.ParseFloat(liveData.Temperature, 64) + if request.Unit == "fahrenheit" { + temp = temp*9/5 + 32 + } + + // 转换湿度和风速 + humidity, _ := strconv.Atoi(liveData.Humidity) + windSpeed, _ := strconv.ParseFloat(liveData.Windpower, 64) + + response.LiveWeather = &LiveWeather{ + Temperature: temp, + Condition: liveData.Weather, + Humidity: humidity, + WindSpeed: windSpeed, + WindDirection: liveData.Winddirection, + } + } + + // 处理预报天气 + if len(apiResp.Forecasts) > 0 && len(apiResp.Forecasts[0].Casts) > 0 { + response.Forecasts = make([]ForecastWeather, 0, len(apiResp.Forecasts[0].Casts)) + + for _, cast := range apiResp.Forecasts[0].Casts { + // 转换温度 + dayTemp, _ := strconv.ParseFloat(cast.Daytemp, 64) + nightTemp, _ := strconv.ParseFloat(cast.Nighttemp, 64) + + if request.Unit == "fahrenheit" { + dayTemp = dayTemp*9/5 + 32 + nightTemp = nightTemp*9/5 + 32 + } + + forecast := ForecastWeather{ + Date: cast.Date, + Week: cast.Week, + DayWeather: cast.Dayweather, + NightWeather: cast.Nightweather, + DayTemp: dayTemp, + NightTemp: nightTemp, + DayWind: cast.Daywind, + NightWind: cast.Nightwind, + DayWindPower: cast.Daypower, + NightWindPower: cast.Nightpower, + } + + response.Forecasts = append(response.Forecasts, forecast) + } + } + + return response, nil + +} diff --git a/internal/tools/weather.go b/internal/tools/weather.go deleted file mode 100644 index 744216f..0000000 --- a/internal/tools/weather.go +++ /dev/null @@ -1,139 +0,0 @@ -package tools - -import ( - "ai_scheduler/internal/entitys" - - "context" - "encoding/json" - "fmt" - "math/rand" - "time" -) - -// WeatherTool 天气查询工具 -type WeatherTool struct { - mockData bool -} - -// NewWeatherTool 创建天气工具 -func NewWeatherTool() *WeatherTool { - return &WeatherTool{} -} - -// Name 返回工具名称 -func (w *WeatherTool) Name() string { - return "get_weather" -} - -// Description 返回工具描述 -func (w *WeatherTool) Description() string { - return "获取指定城市的天气信息" -} - -// Definition 返回工具定义 -func (w *WeatherTool) Definition() entitys.ToolDefinition { - return entitys.ToolDefinition{ - Type: "function", - Function: entitys.FunctionDef{ - Name: w.Name(), - Description: w.Description(), - Parameters: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "city": map[string]interface{}{ - "type": "string", - "description": "城市名称,如:北京、上海、广州", - }, - "unit": map[string]interface{}{ - "type": "string", - "description": "温度单位,celsius(摄氏度)或fahrenheit(华氏度)", - "enum": []string{"celsius", "fahrenheit"}, - "default": "celsius", - }, - }, - "required": []string{"city"}, - }, - }, - } -} - -// WeatherRequest 天气请求参数 -type WeatherRequest struct { - City string `json:"city"` - Unit string `json:"unit,omitempty"` -} - -// WeatherResponse 天气响应 -type WeatherResponse struct { - City string `json:"city"` - Temperature float64 `json:"temperature"` - Unit string `json:"unit"` - Condition string `json:"condition"` - Humidity int `json:"humidity"` - WindSpeed float64 `json:"wind_speed"` - Timestamp string `json:"timestamp"` -} - -// Execute 执行天气查询 -func (w *WeatherTool) Execute(ctx context.Context, args json.RawMessage) (interface{}, error) { - var req WeatherRequest - if err := json.Unmarshal(args, &req); err != nil { - return nil, fmt.Errorf("invalid weather request: %w", err) - } - - if req.City == "" { - return nil, fmt.Errorf("city is required") - } - - if req.Unit == "" { - req.Unit = "celsius" - } - - if w.mockData { - return w.getMockWeather(req.City, req.Unit), nil - } - - // 这里可以集成真实的天气API - return w.getMockWeather(req.City, req.Unit), nil -} - -// getMockWeather 获取模拟天气数据 -func (w *WeatherTool) getMockWeather(city, unit string) *WeatherResponse { - rand.Seed(time.Now().UnixNano()) - - // 模拟不同城市的基础温度 - baseTemp := map[string]float64{ - "北京": 15.0, - "上海": 18.0, - "广州": 25.0, - "深圳": 26.0, - "杭州": 17.0, - "成都": 16.0, - } - - temp := baseTemp[city] - if temp == 0 { - temp = 20.0 // 默认温度 - } - - // 添加随机变化 - temp += (rand.Float64() - 0.5) * 10 - - // 转换温度单位 - if unit == "fahrenheit" { - temp = temp*9/5 + 32 - } - - conditions := []string{"晴朗", "多云", "阴天", "小雨", "中雨"} - condition := conditions[rand.Intn(len(conditions))] - - return &WeatherResponse{ - City: city, - Temperature: float64(int(temp*10)) / 10, // 保留一位小数 - Unit: unit, - Condition: condition, - Humidity: rand.Intn(40) + 40, // 40-80% - WindSpeed: float64(rand.Intn(20)) + 1.0, - Timestamp: time.Now().Format("2006-01-02 15:04:05"), - } -} diff --git a/internal/tools/zltx_after_direct.go b/internal/tools/zltx/zltx_after_direct.go similarity index 99% rename from internal/tools/zltx_after_direct.go rename to internal/tools/zltx/zltx_after_direct.go index 5f1e20e..01a6282 100644 --- a/internal/tools/zltx_after_direct.go +++ b/internal/tools/zltx/zltx_after_direct.go @@ -1,4 +1,4 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" diff --git a/internal/tools/zltx_after_pre.go b/internal/tools/zltx/zltx_after_pre.go similarity index 99% rename from internal/tools/zltx_after_pre.go rename to internal/tools/zltx/zltx_after_pre.go index b44812b..5bbf809 100644 --- a/internal/tools/zltx_after_pre.go +++ b/internal/tools/zltx/zltx_after_pre.go @@ -1,4 +1,4 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" diff --git a/internal/tools/zltx_order_detail.go b/internal/tools/zltx/zltx_order_detail.go similarity index 99% rename from internal/tools/zltx_order_detail.go rename to internal/tools/zltx/zltx_order_detail.go index f253f59..7dae9f1 100644 --- a/internal/tools/zltx_order_detail.go +++ b/internal/tools/zltx/zltx_order_detail.go @@ -1,4 +1,4 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" diff --git a/internal/tools/zltx_order_direct_log.go b/internal/tools/zltx/zltx_order_direct_log.go similarity index 99% rename from internal/tools/zltx_order_direct_log.go rename to internal/tools/zltx/zltx_order_direct_log.go index bf82408..b1b1483 100644 --- a/internal/tools/zltx_order_direct_log.go +++ b/internal/tools/zltx/zltx_order_direct_log.go @@ -1,4 +1,4 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" diff --git a/internal/tools/zltx_product.go b/internal/tools/zltx/zltx_product.go similarity index 99% rename from internal/tools/zltx_product.go rename to internal/tools/zltx/zltx_product.go index 61b8bb1..de24236 100644 --- a/internal/tools/zltx_product.go +++ b/internal/tools/zltx/zltx_product.go @@ -1,4 +1,4 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" diff --git a/internal/tools/zltx_statistics.go b/internal/tools/zltx/zltx_statistics.go similarity index 99% rename from internal/tools/zltx_statistics.go rename to internal/tools/zltx/zltx_statistics.go index f53d4b1..4a051a8 100644 --- a/internal/tools/zltx_statistics.go +++ b/internal/tools/zltx/zltx_statistics.go @@ -1,4 +1,4 @@ -package tools +package zltx import ( "ai_scheduler/internal/config" diff --git a/utils/rds.go b/utils/rds.go index 57b89f1..874eb3e 100644 --- a/utils/rds.go +++ b/utils/rds.go @@ -13,10 +13,10 @@ type Rdb struct { var rdb *Rdb -func NewRdb(c *config.Redis) *Rdb { +func NewRdb(c *config.Config) *Rdb { if rdb == nil { //构建 redis - rdbBuild := buildRdb(c) + rdbBuild := buildRdb(&c.Redis) //退出时清理资源 rdb = &Rdb{Rdb: rdbBuild} }