ai_scheduler/internal/biz/group_config.go

637 lines
17 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package biz
import (
"ai_scheduler/internal/biz/tools_regis"
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/domain/workflow/recharge"
"ai_scheduler/internal/domain/workflow/runtime"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/dingtalk"
"ai_scheduler/internal/pkg/l_request"
"ai_scheduler/internal/pkg/lsxd"
"ai_scheduler/internal/pkg/utils_oss"
"ai_scheduler/internal/tools"
"ai_scheduler/internal/tools/bbxt"
"ai_scheduler/utils"
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/coze-dev/coze-go"
"github.com/gofiber/fiber/v2/log"
"xorm.io/builder"
)
// GroupConfigBiz bundles the dependencies used to serve bot-group requests:
// group configuration lookup, report generation, tool execution, Coze
// workflows and knowledge-base queries.
type GroupConfigBiz struct {
// botGroupConfigImpl is used to look up bot group configuration records.
botGroupConfigImpl *impl.BotGroupConfigImpl
// reportDailyCacheImpl reads and writes the per-day report cache records.
reportDailyCacheImpl *impl.ReportDailyCacheImpl
// ossClient is handed to the report generators and the report uploader.
ossClient *utils_oss.Client
// workflowManager invokes registered workflows by ID.
workflowManager *runtime.Registry
// botTools is the list of registered bot tools matched against recognition results.
botTools []model.AiBotTool
// toolManager executes function-type tools.
toolManager *tools.Manager
// conf is the application configuration (Coze credentials, report settings, ...).
conf *config.Config
// rdb is passed to lsxd.NewLogin when building report tools.
rdb *utils.Rdb
// dingtalkOauth2Client obtains DingTalk access tokens.
dingtalkOauth2Client *dingtalk.Oauth2Client
// dingtalkRobotClient sends DingTalk group messages.
dingtalkRobotClient *dingtalk.RobotClient
}
// NewGroupConfigBiz wires the injected dependencies into a GroupConfigBiz.
// The list of registered bot tools is taken from tools.BootTools.
func NewGroupConfigBiz(
	tools *tools_regis.ToolRegis,
	ossClient *utils_oss.Client,
	botGroupConfigImpl *impl.BotGroupConfigImpl,
	workflowManager *runtime.Registry,
	conf *config.Config,
	reportDailyCacheImpl *impl.ReportDailyCacheImpl,
	rdb *utils.Rdb,
	toolManager *tools.Manager,
	dingtalkOauth2Client *dingtalk.Oauth2Client,
	dingtalkRobotClient *dingtalk.RobotClient,
) *GroupConfigBiz {
	biz := new(GroupConfigBiz)

	// Data access.
	biz.botGroupConfigImpl = botGroupConfigImpl
	biz.reportDailyCacheImpl = reportDailyCacheImpl

	// Infrastructure clients and configuration.
	biz.ossClient = ossClient
	biz.conf = conf
	biz.rdb = rdb
	biz.dingtalkOauth2Client = dingtalkOauth2Client
	biz.dingtalkRobotClient = dingtalkRobotClient

	// Tooling and workflows.
	biz.botTools = tools.BootTools
	biz.toolManager = toolManager
	biz.workflowManager = workflowManager

	return biz
}
// GetGroupConfig looks up the bot group configuration whose config_id equals
// configId. The returned pointer is never nil; when the lookup fails it points
// at whatever the data layer left in the struct and the error is returned
// alongside it.
func (g *GroupConfigBiz) GetGroupConfig(ctx context.Context, configId int32) (*model.AiBotGroupConfig, error) {
	cond := builder.NewCond().And(builder.Eq{"config_id": configId})
	groupConfig := new(model.AiBotGroupConfig)
	err := g.botGroupConfigImpl.GetOneBySearchToStrut(&cond, groupConfig)
	return groupConfig, err
}
// GetReportLists builds the daily report set for a group: the bbxt daily
// reports followed by the e-commerce recharge statistics report.
//
// A nil groupConfig yields (nil, nil). When groupConfig.ProductName is empty
// no product filter is applied to the daily report. If the recharge report
// fails, the daily reports collected so far are returned together with that
// error; an empty recharge result is silently skipped.
func (g *GroupConfigBiz) GetReportLists(ctx context.Context, groupConfig *model.AiBotGroupConfig) (reports []*bbxt.ReportRes, err error) {
	if groupConfig == nil {
		return nil, nil
	}

	var product []string
	if groupConfig.ProductName != "" {
		product = strings.Split(groupConfig.ProductName, ",")
	}

	// Capture the timestamp once so both report families describe the same
	// day even when the call straddles midnight.
	now := time.Now()

	reportList, err := bbxt.NewBbxtTools(g.conf, lsxd.NewLogin(g.conf, g.rdb))
	if err != nil {
		return nil, err
	}

	reports, err = reportList.DailyReport(ctx, now, bbxt.DownWardValue, product, bbxt.SumFilter, g.ossClient, g.GetReportCache)
	if err != nil {
		return reports, err
	}

	// Append the e-commerce recharge statistics; everything is returned as
	// []*bbxt.ReportRes.
	// NOTE(review): the product filter is not forwarded here (nil), unlike the
	// "report_daily" path in handleReport — confirm this is intentional.
	rechargeReports, err := g.rechargeDailyReport(ctx, now, nil, g.ossClient)
	if err != nil || len(rechargeReports) == 0 {
		return reports, err
	}

	return append(reports, rechargeReports...), nil
}
// rechargeDailyReport invokes the recharge statistics workflow and converts
// its output into a single-element report list.
//
// The call is best-effort: a panic inside the workflow, or workflow output
// whose "title", "path", "url" or "data" fields are missing or of the wrong
// type, is logged and reported as (nil, nil) so callers can still deliver the
// other reports. Only an error returned by the workflow itself is propagated.
// ossClient is accepted for signature compatibility and is not used here.
func (g *GroupConfigBiz) rechargeDailyReport(ctx context.Context, now time.Time, productNames []string, ossClient *utils_oss.Client) (reports []*bbxt.ReportRes, err error) {
	defer func() {
		// Safety net for panics raised by the workflow runtime.
		if r := recover(); r != nil {
			log.Error(r)
		}
	}()

	workflowID := recharge.WorkflowIDStatisticsOursProduct
	args := &runtime.WorkflowArgs{
		Args: map[string]any{
			"product_names": productNames,
			"now":           now,
		},
	}

	res, err := g.workflowManager.Invoke(ctx, workflowID, args)
	if err != nil {
		return nil, err
	}

	// Validate the output shape explicitly instead of relying on
	// panic/recover from unchecked type assertions.
	title, okTitle := res["title"].(string)
	path, okPath := res["path"].(string)
	url, okURL := res["url"].(string)
	data, okData := res["data"].([][]string)
	if !okTitle || !okPath || !okURL || !okData {
		log.Errorf("recharge workflow %v returned malformed output: title=%T path=%T url=%T data=%T",
			workflowID, res["title"], res["path"], res["url"], res["data"])
		return nil, nil
	}

	log.Infof("imgUrl: %s", url)
	reports = []*bbxt.ReportRes{
		{
			ReportName: "我们的商品统计(电商充值系统)",
			Title:      title,
			Path:       path,
			Url:        url,
			Data:       data,
		},
	}
	return reports, nil
}
// handleReport generates the report selected by rec.Match.Index for the time
// found in rec.Match.Parameters, uploads each resulting report and streams a
// markdown image link for it to rec.Ch.
//
// The time is parsed with the layouts "2006-01-02 15:04:05",
// "2006-01-02 15:04" and "2006-01-02", in that order. If none matches, the
// user is asked for a more specific time and the function returns nil without
// generating anything. An unknown report index yields an error. A failed
// upload is logged and that report is skipped.
func (g *GroupConfigBiz) handleReport(ctx context.Context, rec *entitys.Recognize, task *model.AiBotTool, groupConfig *model.AiBotGroupConfig) error {
	var configData entitys.ConfigDataReport
	if err := json.Unmarshal([]byte(rec.Match.Parameters), &configData); err != nil {
		return err
	}

	// Try the accepted layouts from most to least specific.
	var (
		t        time.Time
		parseErr error
	)
	for _, layout := range []string{time.DateTime, "2006-01-02 15:04", time.DateOnly} {
		if t, parseErr = time.Parse(layout, configData.Time); parseErr == nil {
			break
		}
	}
	if parseErr != nil {
		log.Infof("时间识别失败:%s", configData.Time)
		entitys.ResText(rec.Ch, "", "时间识别失败了可以给我一份比较具体的时间吗例如“2025-12-31 12:00,抱歉抱歉😀")
		// The user has been told; do not build reports for the zero time.
		return nil
	}

	rep, err := bbxt.NewBbxtTools(g.conf, lsxd.NewLogin(g.conf, g.rdb))
	if err != nil {
		return err
	}
	uploader := bbxt.NewUploader(g.ossClient, g.conf)

	// Same product filter rule as GetReportLists: no filter when the group
	// has no product names configured (avoids passing [""]).
	var product []string
	if groupConfig != nil && groupConfig.ProductName != "" {
		product = strings.Split(groupConfig.ProductName, ",")
	}

	var reports []*bbxt.ReportRes
	switch rec.Match.Index {
	case "report_loss_analysis":
		repo, repErr := rep.StatisOursProductLossSum(ctx, t, g.GetReportCache)
		if repErr != nil {
			return repErr
		}
		reports = append(reports, repo...)
	case "report_sales_analysis":
		repo, repErr := rep.GetStatisOfficialProductSum(t, product)
		if repErr != nil {
			return repErr
		}
		reports = append(reports, repo)
	case "report_ranking_of_distributors":
		repo, repErr := rep.GetProfitRankingSum(t)
		if repErr != nil {
			return repErr
		}
		reports = append(reports, repo)
	case "report_daily":
		repo, repErr := rep.DailyReport(ctx, t, bbxt.DownWardValue, product, bbxt.SumFilter, nil, g.GetReportCache)
		if repErr != nil {
			return repErr
		}
		rechargeReport, repErr := g.rechargeDailyReport(ctx, t, product, nil)
		if repErr != nil {
			return repErr
		}
		reports = append(reports, repo...)
		reports = append(reports, rechargeReport...)
	case "report_daily_recharge":
		repo, repErr := g.rechargeDailyReport(ctx, t, product, nil)
		if repErr != nil || len(repo) == 0 {
			return repErr
		}
		reports = append(reports, repo...)
	case "report_sale_down_analysis":
		repo, repErr := rep.GetStatisOfficialProductSumDecline(t, bbxt.DownWardValue, product, bbxt.SumFilter)
		if repErr != nil {
			return repErr
		}
		reports = append(reports, repo)
	default:
		return fmt.Errorf("未找到的报表:%s", rec.Match.Index)
	}

	for _, report := range reports {
		if report == nil {
			continue
		}
		if err = uploader.Run(report); err != nil {
			// One failed upload must not block the remaining reports.
			log.Error(err)
			continue
		}
		entitys.ResText(rec.Ch, "", fmt.Sprintf("%s![图片](%s)", report.Title, report.Url))
	}
	return nil
}
// handleMatch dispatches a recognition result to the handler for the matched
// tool.
//
// When nothing matched, the chat reply (or, if that is empty, the reasoning)
// is sent back and nil is returned. Otherwise the tool whose Index equals
// rec.Match.Index is looked up in the registered tools; a missing tool, the
// "other" tool, or an unknown task type all fall back to otherTask.
func (g *GroupConfigBiz) handleMatch(ctx context.Context, rec *entitys.Recognize, groupConfig *model.AiBotGroupConfig) (err error) {
	if !rec.Match.IsMatch {
		if len(rec.Match.Chat) == 0 {
			entitys.ResText(rec.Ch, "", rec.Match.Reasoning)
		} else {
			entitys.ResText(rec.Ch, "", rec.Match.Chat)
		}
		return nil
	}

	var matched *model.AiBotTool
	for i := range g.botTools {
		if g.botTools[i].Index != rec.Match.Index {
			continue
		}
		// Work on a copy so handlers never alias the shared tool slice.
		tool := g.botTools[i]
		matched = &tool
		break
	}

	if matched == nil || matched.Index == "other" {
		return g.otherTask(ctx, rec)
	}

	switch constants.TaskType(matched.Type) {
	case constants.TaskTypeFunc:
		return g.handleTask(ctx, rec, matched)
	case constants.TaskTypeReport:
		return g.handleReport(ctx, rec, matched, groupConfig)
	case constants.TaskTypeCozeWorkflow:
		return g.handleCozeWorkflow(ctx, rec, matched)
	case constants.TaskTypeKnowle:
		// Knowledge base, V2 (streaming) implementation.
		return g.handleKnowledgeV2(ctx, rec, matched)
	default:
		return g.otherTask(ctx, rec)
	}
}
// getGroupTools returns the registered tools a group may use: every tool
// whose PermissionType is constants.PermissionTypeNone, plus every tool whose
// ToolID appears in the comma-separated groupConfig.ToolList. Entries of
// ToolList that are empty or not integers are ignored. The error result is
// always nil.
func (g *GroupConfigBiz) getGroupTools(ctx context.Context, groupConfig *model.AiBotGroupConfig) (tools []model.AiBotTool, err error) {
	if len(g.botTools) == 0 {
		return nil, nil
	}

	// Collect the tool IDs explicitly granted to this group. Splitting an
	// empty list yields a single empty entry, which fails Atoi and is skipped.
	granted := make(map[int]struct{})
	for _, raw := range strings.Split(groupConfig.ToolList, ",") {
		id, convErr := strconv.Atoi(raw)
		if convErr != nil {
			continue
		}
		granted[id] = struct{}{}
	}

	for _, tool := range g.botTools {
		_, allowed := granted[int(tool.ToolID)]
		if tool.PermissionType == constants.PermissionTypeNone || allowed {
			tools = append(tools, tool)
		}
	}
	return tools, nil
}
// handleTask runs a function-type tool: it decodes the tool name from
// task.Config and asks the tool manager to execute it for rec. Both the
// decode error and the execution error are returned unchanged.
func (g *GroupConfigBiz) handleTask(ctx context.Context, rec *entitys.Recognize, task *model.AiBotTool) (err error) {
	var toolConf entitys.ConfigDataTool
	if err = json.Unmarshal([]byte(task.Config), &toolConf); err != nil {
		return err
	}
	return g.toolManager.ExecuteTool(ctx, toolConf.Tool, rec)
}
// handleCozeWorkflow runs the Coze workflow configured on task.
//
// task.Config must be JSON of the form {"request": {...}} whose request JSON
// body carries a non-empty string "workflow_id" and an optional boolean
// "stream" (treated as false when absent or not a boolean).
// rec.Match.Parameters, when non-blank, must be a JSON object and is passed
// as the workflow parameters; a decode failure aborts before the workflow is
// started. In streaming mode events are forwarded to rec.Ch as they arrive;
// otherwise the workflow result data is sent as a single JSON response.
func (g *GroupConfigBiz) handleCozeWorkflow(ctx context.Context, rec *entitys.Recognize, task *model.AiBotTool) (err error) {
	entitys.ResLoading(rec.Ch, task.Index, "正在执行工作流(coze)\n")

	// Workflows can run for a long time, hence the generous client timeout.
	customClient := &http.Client{
		Timeout: time.Minute * 30,
	}
	authCli := coze.NewTokenAuth(g.conf.Coze.ApiSecret)
	cozeCli := coze.NewCozeAPI(
		authCli,
		coze.WithBaseURL(g.conf.Coze.BaseURL),
		coze.WithHttpClient(customClient),
	)

	// Read the workflow ID from the task configuration.
	type requestParams struct {
		Request l_request.Request `json:"request"`
	}
	var taskConf requestParams
	if err = json.Unmarshal([]byte(task.Config), &taskConf); err != nil {
		return err
	}
	workflowID, ok := taskConf.Request.Json["workflow_id"].(string)
	if !ok || workflowID == "" {
		return fmt.Errorf("workflow_id不能为空")
	}

	// Extract the workflow parameters. Blank parameters mean "no parameters";
	// anything else must decode, otherwise the workflow would run with
	// silently dropped input.
	var data map[string]interface{}
	if params := strings.TrimSpace(string(rec.Match.Parameters)); params != "" {
		if err = json.Unmarshal([]byte(params), &data); err != nil {
			return fmt.Errorf("解析工作流参数失败: %w", err)
		}
	}

	req := &coze.RunWorkflowsReq{
		WorkflowID: workflowID,
		Parameters: data,
	}

	// A missing or non-boolean "stream" falls back to the non-streaming call
	// instead of panicking.
	stream, _ := taskConf.Request.Json["stream"].(bool)

	entitys.ResLog(rec.Ch, task.Index, "工作流执行中...")

	if stream {
		streamResp, runErr := cozeCli.Workflows.Runs.Stream(ctx, req)
		if runErr != nil {
			return runErr
		}
		g.handleCozeWorkflowEvents(ctx, streamResp, cozeCli, workflowID, rec.Ch, task.Index)
		return nil
	}

	resp, runErr := cozeCli.Workflows.Runs.Create(ctx, req)
	if runErr != nil {
		return runErr
	}
	entitys.ResJson(rec.Ch, task.Index, resp.Data)
	return nil
}
// handleCozeWorkflowEvents drains a Coze workflow event stream and forwards
// each event to ch under the given index.
//
// Message events are streamed, error events are reported, the done event
// sends an end marker, and interrupt events resume the workflow and process
// the resumed stream recursively. The loop ends on io.EOF or on any other
// receive error (which is only printed). The stream is closed on return.
func (g *GroupConfigBiz) handleCozeWorkflowEvents(ctx context.Context, resp coze.Stream[coze.WorkflowEvent], cozeCli coze.CozeAPI, workflowID string, ch chan entitys.Response, index string) {
	defer resp.Close()

	for {
		event, recvErr := resp.Recv()
		if recvErr != nil {
			if errors.Is(recvErr, io.EOF) {
				fmt.Println("Stream finished")
			} else {
				fmt.Println("Error receiving event:", recvErr)
			}
			break
		}

		switch event.Event {
		case coze.WorkflowEventTypeMessage:
			entitys.ResStream(ch, index, event.Message.Content)

		case coze.WorkflowEventTypeError:
			entitys.ResError(ch, index, fmt.Sprintf("工作流执行错误: %v", event.Error))

		case coze.WorkflowEventTypeDone:
			entitys.ResEnd(ch, index, "工作流执行完成")

		case coze.WorkflowEventTypeInterrupt:
			interrupt := event.Interrupt.InterruptData
			// NOTE(review): ResumeData is a fixed placeholder string — confirm
			// what the workflow actually expects when it is resumed.
			resumed, resumeErr := cozeCli.Workflows.Runs.Resume(ctx, &coze.ResumeRunWorkflowsReq{
				WorkflowID:    workflowID,
				EventID:       interrupt.EventID,
				ResumeData:    "your data",
				InterruptType: interrupt.Type,
			})
			if resumeErr != nil {
				entitys.ResError(ch, index, fmt.Sprintf("工作流恢复执行错误: %s", resumeErr.Error()))
				return
			}
			entitys.ResLog(ch, index, "工作流恢复执行中...")
			// The resumed run has its own stream; drain it before reading
			// further from the current one.
			g.handleCozeWorkflowEvents(ctx, resumed, cozeCli, workflowID, ch, index)
		}
	}

	fmt.Printf("done, log:%s\n", resp.Response().LogID())
}
// otherTask is the fallback handler: it replies with the recognizer's
// reasoning text and always returns nil.
func (g *GroupConfigBiz) otherTask(ctx context.Context, rec *entitys.Recognize) (err error) {
	reasoning := rec.Match.Reasoning
	entitys.ResText(rec.Ch, "", reasoning)
	return nil
}
// GetReportCache fills in the manager and loss-reason fields of totalDetail
// for the given day.
//
// The reseller/product relation data is looked up in the daily report cache
// under (bbxt.IndexLossSumDetail, day formatted as YYYY-MM-DD). When no cache
// row exists (ID == 0) the data is fetched through bbxtObj and stored as a new
// cache row; otherwise the cached JSON value is decoded. Resellers and
// products absent from the relation data are left untouched.
func (g *GroupConfigBiz) GetReportCache(ctx context.Context, day time.Time, totalDetail []*bbxt.ResellerLoss, bbxtObj *bbxt.BbxtTools) error {
	var (
		relations map[int32]*bbxt.ResellerLossSumProductRelation
		cached    model.AiReportDailyCache
		err       error
	)

	cacheKey := day.Format(time.DateOnly)
	cond := builder.NewCond().
		And(builder.Eq{"cache_index": bbxt.IndexLossSumDetail}).
		And(builder.Eq{"cache_key": cacheKey})

	if err = g.reportDailyCacheImpl.GetOneBySearchToStrut(&cond, &cached); err != nil {
		return err
	}

	if cached.ID != 0 {
		// Cache hit: decode the stored relation data.
		if err = json.Unmarshal([]byte(cached.Value), &relations); err != nil {
			return err
		}
	} else {
		// Cache miss: fetch from the API and persist for later calls.
		relations, err = bbxtObj.GetResellerLossMannagerAndLossReasonFromApi(ctx, totalDetail)
		if err != nil {
			return err
		}
		cached = model.AiReportDailyCache{
			CacheKey:   cacheKey,
			CacheIndex: bbxt.IndexLossSumDetail,
			Value:      pkg.JsonStringIgonErr(relations),
		}
		if _, err = g.reportDailyCacheImpl.Add(&cached); err != nil {
			return err
		}
	}

	for _, detail := range totalDetail {
		relation, found := relations[detail.ResellerId]
		if !found {
			continue
		}
		detail.Manager = relation.AfterSaleName
		for _, loss := range detail.ProductLoss {
			product, known := relation.Products[loss.ProductId]
			if !known {
				continue
			}
			loss.LossReason = product.LossReason
		}
	}
	return nil
}
// handleKnowledgeV2 answers the user's question through the knowledge-base
// service (V2, streaming).
//
// The user's text is POSTed to the local knowledge-base query endpoint with
// streaming enabled and the SSE answer is forwarded to rec.Ch in paragraph
// mode. If the stream reports that nothing was retrieved, a DingTalk group
// message containing the question is sent instead. Request, non-2xx status,
// stream, access-token and DingTalk errors are all returned to the caller.
//
// NOTE(review): the endpoint URL and tenant header are hard-coded, and ctx is
// not propagated to the HTTP request — consider moving them to configuration
// and making the request cancellable.
func (g *GroupConfigBiz) handleKnowledgeV2(ctx context.Context, rec *entitys.Recognize, pointTask *model.AiBotTool) (err error) {
	req := l_request.Request{
		Method: "POST",
		Url:    "http://127.0.0.1:9600/query",
		Headers: map[string]string{
			"Content-Type": "application/json",
			"X-Tenant-ID":  "default",
		},
		Json: map[string]interface{}{
			"query":  rec.UserContent.Text,
			"mode":   "naive",
			"stream": true,
			"think":  false,
		},
	}

	resp, err := req.SendNoParseResponse()
	if err != nil {
		return fmt.Errorf("请求失败err: %w", err)
	}
	defer resp.Body.Close()

	// An error response carries no SSE events; without this check it would be
	// read as an empty but "retrieved" answer.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return fmt.Errorf("知识库请求返回异常状态: %s", resp.Status)
	}

	isRetrieved, err := g.connectAndReadSSE(resp, rec.Ch, true)
	if err != nil {
		return err
	}
	if isRetrieved {
		return nil
	}

	// Nothing matched in the knowledge base: escalate to the DingTalk group.
	accessToken, err := g.dingtalkOauth2Client.GetAccessToken()
	if err != nil {
		return fmt.Errorf("获取钉钉accessToken失败err: %w", err)
	}
	if _, err = g.dingtalkRobotClient.SendGroupMessages(accessToken, rec.UserContent.Text); err != nil {
		return fmt.Errorf("发送钉钉卡片失败err: %w", err)
	}
	return nil
}
// connectAndReadSSE reads a server-sent-events response body and forwards the
// answer text to channel.
//
// Expected stream shape (one "event:" line followed by its "data:" line):
//
//	event: thinking
//	data: {"text": "1. retrieving context...\n"}
//	event: answer
//	data: {"text": "According to"}
//
// Handling per event type:
//   - thinking: the data is discarded.
//   - system:   the data must be "retrieved" (plain, quoted, or as the "text"
//     field of a JSON object); any other value means nothing was retrieved
//     and the function returns (false, nil) immediately.
//   - anything else: the "text" field of the JSON data is emitted.
//
// With useParagraphMode the text is buffered and flushed up to the last
// newline seen so far, with the remainder flushed at end of stream; otherwise
// every chunk is sent as it arrives. Data lines that are not valid JSON are
// logged and skipped. A read error is returned as (true, err).
func (g *GroupConfigBiz) connectAndReadSSE(resp *http.Response, channel chan entitys.Response, useParagraphMode bool) (isRetrieved bool, err error) {
	const (
		eventThinking = "thinking"
		eventSystem   = "system"
		retrievedFlag = "retrieved"
		// maxLineSize raises the scanner's 64 KiB per-line default so long
		// SSE data lines do not abort the stream.
		maxLineSize = 1 << 20
	)

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxLineSize)

	var (
		pending      strings.Builder // paragraph-mode text not yet flushed
		currentEvent string          // type of the event whose data comes next
	)

	for scanner.Scan() {
		line := scanner.Text()

		// A blank line terminates an SSE event; the next data line belongs to
		// a new event unless a fresh "event:" line says otherwise.
		if strings.TrimSpace(line) == "" {
			currentEvent = ""
			continue
		}

		// The event line only records the type; its payload is on the
		// following data line.
		if strings.HasPrefix(line, "event:") {
			currentEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
			continue
		}

		if !strings.HasPrefix(line, "data:") {
			continue
		}
		dataStr := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		if dataStr == "" {
			continue
		}

		switch currentEvent {
		case eventThinking:
			// Thinking output is never shown to the user.
			continue
		case eventSystem:
			status := dataStr
			var sys struct {
				Text string `json:"text"`
			}
			if jsonErr := json.Unmarshal([]byte(dataStr), &sys); jsonErr == nil && sys.Text != "" {
				status = sys.Text
			}
			if strings.Trim(strings.TrimSpace(status), `"`) != retrievedFlag {
				// Nothing retrieved: stop reading and let the caller decide.
				return false, nil
			}
			continue
		}

		var payload struct {
			Text string `json:"text"`
		}
		if jsonErr := json.Unmarshal([]byte(dataStr), &payload); jsonErr != nil {
			log.Errorf("SSE数据解析失败: %v body: %s", jsonErr, dataStr)
			continue
		}
		if payload.Text == "" {
			continue
		}

		if !useParagraphMode {
			// Chunk-by-chunk mode: forward immediately.
			entitys.ResStream(channel, "", payload.Text)
			continue
		}

		// Paragraph mode: send everything up to the last newline and keep the
		// unfinished tail buffered.
		pending.WriteString(payload.Text)
		content := pending.String()
		if idx := strings.LastIndex(content, "\n"); idx != -1 {
			entitys.ResStream(channel, "", content[:idx+1])
			rest := content[idx+1:]
			pending.Reset()
			pending.WriteString(rest)
		}
	}

	if scanErr := scanner.Err(); scanErr != nil {
		return true, fmt.Errorf("读取SSE流中断: %w", scanErr)
	}

	// Flush whatever is left in the buffer (only ever non-empty in paragraph
	// mode).
	if useParagraphMode && pending.Len() > 0 {
		entitys.ResStream(channel, "", pending.String())
	}
	return true, nil
}
// handleKnowledgeV3 answers the user's question through the knowledge-base
// service (V3, synchronous).
//
// The user's text is POSTed to the local knowledge-base query endpoint with
// streaming disabled, and the "response" field of the JSON reply is sent to
// rec.Ch as a single text message. Other fields of the reply are ignored, so
// non-string metadata in the reply does not break decoding.
//
// NOTE(review): the endpoint URL and tenant header are hard-coded, and ctx is
// not propagated to the HTTP request.
func (g *GroupConfigBiz) handleKnowledgeV3(ctx context.Context, rec *entitys.Recognize, pointTask *model.AiBotTool) (err error) {
	req := l_request.Request{
		Method: "POST",
		Url:    "http://127.0.0.1:9600/query",
		Headers: map[string]string{
			"Content-Type": "application/json",
			"X-Tenant-ID":  "default",
		},
		Json: map[string]interface{}{
			"query":  rec.UserContent.Text,
			"mode":   "naive",
			"stream": false,
			"think":  false,
		},
	}

	resp, err := req.Send()
	if err != nil {
		return fmt.Errorf("请求失败err: %w", err)
	}

	// Decode only the field that is needed; a map[string]string target would
	// reject any reply that also carries non-string values.
	var reply struct {
		Response string `json:"response"`
	}
	if err = json.Unmarshal([]byte(resp.Text), &reply); err != nil {
		return fmt.Errorf("解析响应失败err: %w", err)
	}

	entitys.ResText(rec.Ch, "", reply.Response)
	return nil
}