ai_scheduler/internal/biz/do/handle.go

709 lines
20 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package do
import (
"ai_scheduler/internal/biz/llm_service"
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants"
errorcode "ai_scheduler/internal/data/error"
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/domain/tools/common/knowledge_base"
"ai_scheduler/internal/domain/workflow/runtime"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/gateway"
"ai_scheduler/internal/pkg/dingtalk"
"ai_scheduler/internal/pkg/l_request"
"ai_scheduler/internal/pkg/mapstructure"
"ai_scheduler/internal/pkg/rec_extra"
"ai_scheduler/internal/pkg/util"
"ai_scheduler/internal/tools"
"bufio"
errorsSpecial "errors"
"io"
"net/http"
"time"
"context"
"encoding/json"
"fmt"
"strings"
"github.com/coze-dev/coze-go"
"github.com/gofiber/fiber/v2/log"
"github.com/ollama/ollama/api"
"gorm.io/gorm/utils"
)
// Handle executes recognized user intents: it runs intent recognition and
// dispatches matched tasks to API, knowledge-base, tool, bot and workflow
// handlers.
type Handle struct {
	// Ollama is the LLM service used for intent recognition, query rewriting
	// and issue classification.
	Ollama *llm_service.OllamaService
	// toolManager executes function-type tasks.
	toolManager *tools.Manager
	// conf is the application configuration.
	conf *config.Config
	// sessionImpl is used to look up chat sessions.
	sessionImpl *impl.SessionImpl
	// workflowManager invokes eino workflows by task index.
	workflowManager *runtime.Registry
	// dingtalkOldClient supplies DingTalk access tokens and user details.
	dingtalkOldClient *dingtalk.OldClient
	// dingtalkContactClient searches DingTalk users by name.
	dingtalkContactClient *dingtalk.ContactClient
	// dingtalkNotableClient inserts records into DingTalk Notable tables.
	dingtalkNotableClient *dingtalk.NotableClient
}

// NewHandle assembles a Handle from its dependencies.
func NewHandle(
	Ollama *llm_service.OllamaService,
	toolManager *tools.Manager,
	conf *config.Config,
	sessionImpl *impl.SessionImpl,
	workflowManager *runtime.Registry,
	dingtalkOldClient *dingtalk.OldClient,
	dingtalkContactClient *dingtalk.ContactClient,
	dingtalkNotableClient *dingtalk.NotableClient,
) *Handle {
	h := new(Handle)
	h.Ollama = Ollama
	h.toolManager = toolManager
	h.conf = conf
	h.sessionImpl = sessionImpl
	h.workflowManager = workflowManager
	h.dingtalkOldClient = dingtalkOldClient
	h.dingtalkContactClient = dingtalkContactClient
	h.dingtalkNotableClient = dingtalkNotableClient
	return h
}
// Recognize runs LLM intent recognition for rec and stores the decoded result
// in rec.Match.
//
// The prompt is built by promptProcessor, sent to the Ollama intent recognizer
// together with rec.Tasks, and the model reply is decoded as an entitys.Match.
// Progress is logged on rec.Ch. A prompt-building error, a recognizer error,
// or an undecodable reply is returned and leaves rec.Match unchanged.
func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (err error) {
	entitys.ResLog(rec.Ch, "recognize_start", "准备意图识别")

	prompt, err := promptProcessor.CreatePrompt(ctx, rec)
	if err != nil {
		// Must be checked here: the next assignment to err would otherwise hide
		// the failure and an unusable prompt would be sent to the model.
		return err
	}

	// Intent recognition.
	recognizeMsg, err := r.Ollama.IntentRecognize(ctx, &entitys.ToolSelect{
		Prompt: prompt,
		Tools:  rec.Tasks,
	})
	if err != nil {
		return err
	}
	entitys.ResLog(rec.Ch, "recognize", recognizeMsg)
	entitys.ResLog(rec.Ch, "recognize_end", "意图识别结束")

	var match entitys.Match
	if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil {
		return errors.SysErrf("数据结构错误:%v", err.Error())
	}
	rec.Match = &match
	return nil
}
// RewriteQuery rewrites currentQuery into a self-contained search query using
// the multi-turn chat history, so that pronouns and references ("it", "that
// question") are resolved before knowledge-base retrieval.
//
// With no history, currentQuery is returned unchanged. History entries whose
// Role is "user" are labelled as the user; all others as the assistant. The
// model output is trimmed; if the model returns nothing usable, currentQuery
// is returned instead so retrieval never runs with an empty query.
func (r *Handle) RewriteQuery(ctx context.Context, history []model.AiBotChatHi, currentQuery string) (string, error) {
	if len(history) == 0 {
		return currentQuery, nil
	}

	var histStr strings.Builder
	for _, h := range history {
		role := "用户"
		if h.Role != "user" {
			role = "助手"
		}
		fmt.Fprintf(&histStr, "%s: %s\n", role, h.Content)
	}

	systemPrompt := `你是一个搜索查询改写专家。请结合用户的历史对话上下文,将用户当前的输入改写为一个独立的、语义完整的、适合知识库检索的中文查询词。
要求:
1. 保持原意,补全指代(如“它”、“刚才那个问题”)。
2. 只返回改写后的查询词,不要有任何解释。
3. 如果当前输入已经很完整,直接返回原句。`
	userPrompt := fmt.Sprintf("### 历史对话:\n%s\n### 当前输入:\n%s\n### 改写后的查询词:", histStr.String(), currentQuery)
	messages := []api.Message{
		{Role: "system", Content: systemPrompt},
		{Role: "user", Content: userPrompt},
	}

	rewritten, err := r.Ollama.Chat(ctx, messages)
	if err != nil {
		return "", err
	}
	// Models often pad the answer with newlines; an empty answer is useless
	// for retrieval, so fall back to the user's own query.
	rewritten = strings.TrimSpace(rewritten)
	if rewritten == "" {
		return currentQuery, nil
	}
	return rewritten, nil
}
// IssueClassification is the routing decision that ClassifyIssue decodes from
// the model's JSON reply.
type IssueClassification struct {
	// SysName is the chosen system name; the prompt asks for "全局" when the
	// input mentions no system.
	SysName string `json:"sys_name"`
	// IssueTypeName is the chosen issue type name.
	IssueTypeName string `json:"issue_type_name"`
	// Summary is a short description (prompted as at most 15 characters) used
	// to name a group chat.
	Summary string `json:"summary"`
	// Reason is the model's justification for the issue type and system.
	Reason string `json:"reason"`
}

// ClassifyIssue asks the LLM to classify userInput into one of systems and one
// of issueTypes, and returns the decoded result.
//
// The model is told to reply with JSON only, but it may still wrap the object
// in a markdown code fence (```json, ```JSON, ```) or add surrounding text, so
// the outermost {...} span of the reply is extracted before decoding. The
// returned names are not validated against systems/issueTypes.
func (r *Handle) ClassifyIssue(ctx context.Context, systems []string, issueTypes []string, userInput string) (*IssueClassification, error) {
	systemPrompt := fmt.Sprintf(`你是一个技术支持路由专家。请分析用户的输入,并将其归类到最合适的系统和问题类型中。
可用系统列表: [%s]
可用问题类型: [%s]
请仅以 JSON 格式回复,包含以下字段:
- sys_name: 系统名称,若未提及系统关键词,则为"全局"
- issue_type_name: 问题类型名称
- summary: 15字以内的问题简述用于群命名
- reason: 分类判断理由;系统名称判断理由`, strings.Join(systems, ", "), strings.Join(issueTypes, ", "))
	messages := []api.Message{
		{Role: "system", Content: systemPrompt},
		{Role: "user", Content: userInput},
	}
	resp, err := r.Ollama.Chat(ctx, messages)
	if err != nil {
		return nil, err
	}

	var result IssueClassification
	if err := json.Unmarshal([]byte(extractClassificationJSON(resp)), &result); err != nil {
		return nil, fmt.Errorf("解析分类结果失败: %w, 原文: %s", err, resp)
	}
	return &result, nil
}

// extractClassificationJSON returns the outermost {...} span of s, dropping
// any code fence or prose around the JSON object. If s has no such span, the
// whitespace-trimmed s is returned so the decoder reports a meaningful error.
func extractClassificationJSON(s string) string {
	s = strings.TrimSpace(s)
	start := strings.IndexByte(s, '{')
	end := strings.LastIndexByte(s, '}')
	if start < 0 || end < start {
		return s
	}
	return s[start : end+1]
}
// handleOtherTask answers a task whose type has no dedicated handler by
// sending the model's reasoning (from requireData.Match) as plain text.
func (r *Handle) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
	out, reasoning := requireData.Ch, requireData.Match.Reasoning
	entitys.ResText(out, "", reasoning)
	return nil
}
// HandleMatch routes a recognized intent to the handler for its task.
//
// If the recognizer found no match, the model's chat reply (falling back to
// its reasoning) is sent to the client as text. Otherwise the task whose Index
// equals rec.Match.Index is looked up in requireData.Tasks. A missing task or
// the "other" task is answered via OtherTask. For a known task, the client's
// permission is checked first, then the request is dispatched by task type.
func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, rec *entitys.Recognize, requireData *entitys.RequireData) (err error) {
	if rec.Match == nil {
		// Recognize has not produced a result; fail instead of panicking below.
		return errors.SysErrf("意图识别结果为空")
	}
	if !rec.Match.IsMatch {
		if len(rec.Match.Chat) != 0 {
			entitys.ResText(rec.Ch, "", rec.Match.Chat)
		} else {
			entitys.ResText(rec.Ch, "", rec.Match.Reasoning)
		}
		return nil
	}

	var pointTask *model.AiTask
	for _, task := range requireData.Tasks {
		if task.Index == rec.Match.Index {
			// Keep a private copy: handlers may modify the task and must not
			// write back into requireData.Tasks.
			matched := task
			pointTask = &matched
			break
		}
	}
	if pointTask == nil || pointTask.Index == "other" {
		return r.OtherTask(ctx, rec)
	}

	// Check the user's permission for this task.
	if err = r.PermissionAuth(client, pointTask); err != nil {
		log.Errorf("权限验证失败: %s", err.Error())
		return err
	}

	switch constants.TaskType(pointTask.Type) {
	case constants.TaskTypeApi:
		return r.handleApiTask(ctx, rec, pointTask)
	case constants.TaskTypeKnowle:
		return r.handleKnowleV2(ctx, rec, pointTask)
	case constants.TaskTypeFunc:
		return r.handleTask(ctx, rec, pointTask)
	case constants.TaskTypeBot:
		return r.HandleBot(ctx, rec, &entitys.Task{
			Index: pointTask.Index,
		})
	case constants.TaskTypeEinoWorkflow:
		return r.handleEinoWorkflow(ctx, rec, pointTask)
	case constants.TaskTypeCozeWorkflow:
		return r.handleCozeWorkflow(ctx, rec, pointTask)
	default:
		// Unknown task type: reply like the "other" task. This uses rec, whose
		// Match is verified above, instead of the unchecked requireData.Match.
		return r.OtherTask(ctx, rec)
	}
}
// OtherTask replies to a request that matched no actionable task by sending
// the model's reasoning to the client as plain text.
func (r *Handle) OtherTask(ctx context.Context, requireData *entitys.Recognize) (err error) {
	reasoning := requireData.Match.Reasoning
	entitys.ResText(requireData.Ch, "", reasoning)
	return nil
}
// handleTask runs a function-type task. The tool to run is named in the
// task's JSON config and is executed through the tool manager.
func (r *Handle) handleTask(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
	var cfg entitys.ConfigDataTool
	if err = json.Unmarshal([]byte(task.Config), &cfg); err != nil {
		return err
	}
	return r.toolManager.ExecuteTool(ctx, cfg.Tool, rec)
}
// 知识库
// func (r *Handle) handleKnowle(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
// var (
// configData entitys.ConfigDataTool
// sessionIdKnowledge string
// query string
// host string
// )
// err = json.Unmarshal([]byte(task.Config), &configData)
// if err != nil {
// return
// }
// ext, err := rec_extra.GetTaskRecExt(rec)
// if err != nil {
// return
// }
// // 通过session 找到知识库session
// var has bool
// if len(ext.Session) == 0 {
// return errors.SessionNotFound
// }
// ext.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(ext.Session))
// if err != nil {
// return
// } else if !has {
// return errors.SessionNotFound
// }
// // 找到知识库的host
// {
// tool, exists := r.toolManager.GetTool(configData.Tool)
// if !exists {
// return fmt.Errorf("tool not found: %s", configData.Tool)
// }
// if knowledgeTool, ok := tool.(*public.KnowledgeBaseTool); !ok {
// return fmt.Errorf("未找到知识库Tool: %s", configData.Tool)
// } else {
// host = knowledgeTool.GetConfig().BaseURL
// }
// }
// // 知识库的session为空请求知识库获取, 并绑定
// if ext.SessionInfo.KnowlegeSessionID == "" {
// // 请求知识库
// if sessionIdKnowledge, err = public.GetKnowledgeBaseSession(host, ext.Sys.KnowlegeBaseID, ext.Sys.KnowlegeTenantKey); err != nil {
// return
// }
// // 绑定知识库session下次可以使用
// ext.SessionInfo.KnowlegeSessionID = sessionIdKnowledge
// if err = r.sessionImpl.Update(&ext.SessionInfo, r.sessionImpl.WithSessionId(ext.SessionInfo.SessionID)); err != nil {
// return
// }
// }
// // 用户输入解析
// var ok bool
// input := make(map[string]string)
// if err = json.Unmarshal([]byte(rec.Match.Parameters), &input); err != nil {
// return
// }
// if query, ok = input["query"]; !ok {
// return fmt.Errorf("query不能为空")
// }
// ext.KnowledgeConf = entitys.KnowledgeBaseRequest{
// Session: ext.SessionInfo.KnowlegeSessionID,
// ApiKey: ext.Sys.KnowlegeTenantKey,
// Query: query,
// }
// rec.Ext = pkg.JsonByteIgonErr(ext)
// // 执行工具
// err = r.toolManager.ExecuteTool(ctx, configData.Tool, rec)
// if err != nil {
// return
// }
// return
// }
// handleKnowleV2 answers a knowledge-base task from the self-hosted LightRAG
// knowledge base and streams the SSE reply to rec.Ch.
//
// The tenant comes from the session's system config (KnowlegeTenantKey, in
// the form "{biz}-{user}", e.g. "zltx-platform"). The query is the raw user
// text. task is currently unused.
func (r *Handle) handleKnowleV2(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
	// Load the user's session info.
	ext, err := rec_extra.GetTaskRecExt(rec)
	if err != nil {
		return err
	}
	tenantID := ext.Sys.KnowlegeTenantKey

	knowledgeBase := knowledge_base.New(r.conf.KnowledgeConfig)
	knowledgeResp, err := knowledgeBase.Query(&knowledge_base.QueryRequest{
		TenantID: tenantID,
		Query:    rec.UserContent.Text,
		Mode:     constants.KnowledgeModeMix,
		Stream:   true,
		Think:    false,
		OnlyRAG:  true,
	})
	if err != nil {
		return fmt.Errorf("请求知识库工具失败err: %w", err)
	}
	// readKnowledgeSSE only reads the stream. This function opened it, so it
	// closes it; otherwise the underlying resource (presumably an HTTP
	// response body) leaks.
	defer knowledgeResp.Close()

	return r.readKnowledgeSSE(knowledgeResp, rec.Ch, false)
}
// readKnowledgeSSE reads an OpenAI-style SSE stream from resp and forwards
// its deltas to channel under the "knowledgeBase" task index.
//
// A delta carrying reasoning content is forwarded as-is; its content field is
// then ignored. Content is forwarded either immediately
// (useParagraphMode == false) or buffered and flushed up to each newline
// (useParagraphMode == true). In paragraph mode, a trailing partial paragraph
// is flushed once the stream ends. Reading stops at the stream's done marker.
// resp is not closed here; the caller owns it.
func (r *Handle) readKnowledgeSSE(resp io.ReadCloser, channel chan entitys.Response, useParagraphMode bool) (err error) {
	const (
		taskIndex = "knowledgeBase"
		// bufio.Scanner rejects lines over 64 KiB by default (bufio.ErrTooLong).
		// A single SSE "data:" line with a large JSON chunk can exceed that.
		maxLineSize = 1 << 20
	)
	scanner := bufio.NewScanner(resp)
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	var buffer strings.Builder
	for scanner.Scan() {
		delta, done, parseErr := knowledge_base.ParseOpenAIStreamData(scanner.Text())
		if parseErr != nil {
			return fmt.Errorf("解析SSE数据失败: %w", parseErr)
		}
		if done {
			break
		}
		if delta == nil {
			continue
		}

		switch {
		case delta.ReasoningContent != "":
			// Reasoning output.
			entitys.ResStream(channel, taskIndex, delta.ReasoningContent)
		case delta.Content == "":
			// Nothing to emit.
		case useParagraphMode:
			// Paragraph mode: buffer until a newline arrives. After a flush the
			// buffer holds no newline, so only a delta containing one can complete
			// a paragraph.
			buffer.WriteString(delta.Content)
			if strings.Contains(delta.Content, "\n") {
				content := buffer.String()
				idx := strings.LastIndex(content, "\n")
				entitys.ResStream(channel, taskIndex, content[:idx+1])
				buffer.Reset()
				buffer.WriteString(content[idx+1:])
			}
		default:
			// Token-by-token mode.
			entitys.ResStream(channel, taskIndex, delta.Content)
		}
	}
	if scanErr := scanner.Err(); scanErr != nil {
		return fmt.Errorf("读取SSE流中断: %w", scanErr)
	}

	// Flush what is left of the last paragraph (paragraph mode only).
	if useParagraphMode && buffer.Len() > 0 {
		entitys.ResStream(channel, taskIndex, buffer.String())
	}
	return nil
}
// HandleBot handles bot-type tasks. It is a temporary implementation that is
// meant to move into an eino workflow later.
//
// Only task.Index == "bug_optimization_submit" is supported. It records the
// user's text and first attachment (if any) as a new row in the DingTalk
// Notable table configured at conf.Dingtalk.TableDemand. The requesting user's
// DingTalk unionId is used as both operator and creator, and the reply on
// rec.Ch links to the table. The user is resolved from the task ext's session
// ID, falling back to its user name. Any other task index returns a 422
// business error.
func (r *Handle) HandleBot(ctx context.Context, rec *entitys.Recognize, task *entitys.Task) (err error) {
	if task.Index != "bug_optimization_submit" {
		return errors.NewBusinessErr(422, "bot 任务未实现")
	}

	entitys.ResLoading(rec.Ch, task.Index, "需求记录中...\n")

	// Every DingTalk call below needs the access token; do not continue without it.
	accessToken, err := r.dingtalkOldClient.GetAccessToken()
	if err != nil {
		return fmt.Errorf("获取钉钉accessToken失败: %w", err)
	}

	taskExt := rec.GetTaskExt()
	if taskExt == nil {
		return errorcode.ParamErr("taskExt参数错误")
	}

	// Resolve the creator's DingTalk unionId, preferring the session.
	var unionId string
	switch {
	case len(taskExt.Session) > 0:
		unionId = r.getUserDingtalkUnionId(ctx, accessToken, taskExt.Session)
	case len(taskExt.UserName) > 0:
		unionId = r.getUserDingtalkUnionIdWithUserName(ctx, accessToken, taskExt.UserName)
	default:
		return errorcode.ParamErr("taskExt参数错误,重要参数缺失")
	}
	if unionId == "" {
		// The lookup helpers only log failures. Without a unionId the record
		// would be inserted with no operator, so fail with a clear message.
		return errors.NewBusinessErr(422, "未找到您的钉钉用户信息,请联系管理员")
	}

	// Only the first attachment is recorded.
	var attachmentUrl string
	for _, file := range rec.UserContent.File {
		attachmentUrl = file.FileUrl
		break
	}

	req := &dingtalk.InsertRecordReq{
		BaseId:         r.conf.Dingtalk.TableDemand.BaseId,
		SheetIdOrName:  r.conf.Dingtalk.TableDemand.SheetIdOrName,
		OperatorId:     unionId,
		CreatorUnionId: unionId,
		Content:        rec.UserContent.Text,
		AttachmentUrl:  attachmentUrl,
	}
	recordId, err := r.dingtalkNotableClient.InsertRecord(accessToken, req)
	if err != nil {
		// The user lacks edit permission on the demand table.
		if r.dingtalkNotableClient.GetHTTPStatus(err) == http.StatusForbidden {
			return errorcode.ForbiddenErr("您当前没有AI需求表编辑权限请联系管理员添加权限")
		}
		return err
	}
	if recordId == "" {
		return errors.NewBusinessErr(422, "创建记录失败,请联系管理员")
	}

	entitys.ResLog(rec.Ch, task.Index, "需求记录完成")

	// Build the "view" link in the format the output scene renders.
	var detailPage string
	switch rec.OutPutScene {
	case entitys.OutPutSceneDingTalk:
		detailPage = "[去查看](" + r.conf.Dingtalk.TableDemand.Url + ")"
	default:
		detailPage = util.BuildJumpLink(r.conf.Dingtalk.TableDemand.Url, "去查看")
	}
	entitys.ResText(rec.Ch, task.Index, fmt.Sprintf("需求已记录,正在分配相关人员处理,请您耐心等待处理结果。点击查看工单进度:%s", detailPage))
	return nil
}
// getUserDingtalkUnionId resolves the DingTalk unionId of the user who owns
// the given session. It looks up the session's user name, then delegates to
// getUserDingtalkUnionIdWithUserName. It returns "" when sessionID is empty,
// the session cannot be loaded, or the DingTalk lookup fails. Failures are
// logged, not returned.
func (r *Handle) getUserDingtalkUnionId(ctx context.Context, accessToken, sessionID string) (unionId string) {
	if sessionID == "" {
		return ""
	}
	sess, found, lookupErr := r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(sessionID))
	if lookupErr != nil || !found {
		log.Warnf("session not found: %s", sessionID)
		return ""
	}
	return r.getUserDingtalkUnionIdWithUserName(ctx, accessToken, sess.UserName)
}
// getUserDingtalkUnionIdWithUserName maps a user name to a DingTalk unionId in
// two steps: user name -> DingTalk user ID (contact search), then user ID ->
// unionId (user details). It returns "" and logs a warning if either step
// fails or no details are found.
func (r *Handle) getUserDingtalkUnionIdWithUserName(ctx context.Context, accessToken, userName string) (unionId string) {
	appKey := dingtalk.AppKey{AccessToken: accessToken}
	dingUserID, searchErr := r.dingtalkContactClient.SearchUserOne(appKey, userName)
	if searchErr != nil {
		log.Warnf("search dingtalk user one failed: %v", searchErr)
		return ""
	}

	details, detailErr := r.dingtalkOldClient.QueryUserDetails(ctx, dingUserID)
	switch {
	case detailErr != nil:
		log.Warnf("query user dingtalk details failed: %v", detailErr)
		return ""
	case details == nil:
		log.Warnf("user details not found: %s", dingUserID)
		return ""
	}
	return details.UnionID
}
// handleApiTask runs an HTTP-API task and streams the response text to rec.Ch.
//
// task.Config is a JSON template with ${name} placeholders; see
// renderApiTaskConfig for how they are filled from ext.Auth and the recognized
// parameters (rec.Match.Parameters, a JSON object). The rendered config's
// "request" section is decoded into an l_request.Request and sent. task itself
// is not modified.
func (r *Handle) handleApiTask(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
	ext, err := rec_extra.GetTaskRecExt(rec)
	if err != nil {
		return err
	}

	var requestParam map[string]interface{}
	if err = json.Unmarshal([]byte(rec.Match.Parameters), &requestParam); err != nil {
		return err
	}

	renderedConfig, err := renderApiTaskConfig(task.Config, ext.Auth, requestParam)
	if err != nil {
		return err
	}

	var configData entitys.ConfigDataHttp
	if err = json.Unmarshal([]byte(renderedConfig), &configData); err != nil {
		return err
	}
	var request l_request.Request
	if err = mapstructure.Decode(configData.Request, &request); err != nil {
		return err
	}
	if len(request.Url) == 0 {
		return errors.NewBusinessErr(422, "api地址获取失败")
	}

	entitys.ResLoading(rec.Ch, task.Index, "正在请求数据")
	res, err := request.Send()
	if err != nil {
		return err
	}
	entitys.ResJson(rec.Ch, task.Index, res.Text)
	return nil
}

// renderApiTaskConfig fills the placeholders of an API task's JSON config
// template. ${authorization} becomes auth, and each ${key} becomes the value
// of params[key].
//
// String values are JSON-escaped before substitution. Quotes, backslashes or
// control characters in user-derived input therefore cannot break out of the
// surrounding JSON string literal, and the text decodes back to the original
// value. For a non-string value, the whole quoted placeholder ("${key}") is
// replaced with its JSON encoding, so numbers, booleans, arrays and objects
// keep their type.
func renderApiTaskConfig(tpl, auth string, params map[string]interface{}) (string, error) {
	out := strings.ReplaceAll(tpl, "${authorization}", escapeJSONStringContent(auth))
	for k, v := range params {
		if s, ok := v.(string); ok {
			out = strings.ReplaceAll(out, "${"+k+"}", escapeJSONStringContent(s))
			continue
		}
		encoded, err := json.Marshal(v)
		if err != nil {
			return "", errors.NewBusinessErr(422, "请求参数解析失败")
		}
		out = strings.ReplaceAll(out, "\"${"+k+"}\"", string(encoded))
	}
	return out, nil
}

// escapeJSONStringContent returns s encoded as the body of a JSON string
// literal, without the surrounding quotes.
func escapeJSONStringContent(s string) string {
	b, _ := json.Marshal(s) // marshaling a string cannot fail
	return string(b[1 : len(b)-1])
}
// handleEinoWorkflow runs the eino workflow registered under task.Index. The
// session's auth token is placed into ctx before the run. The workflow emits
// its own output; only its error is returned here.
func (r *Handle) handleEinoWorkflow(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
	ext, err := rec_extra.GetTaskRecExt(rec)
	if err != nil {
		return
	}
	tokenCtx := util.SetTokenToContext(ctx, ext.Auth)

	entitys.ResLoading(rec.Ch, task.Index, "正在执行工作流")

	args := &runtime.WorkflowArgs{Recognize: rec}
	if _, err = r.workflowManager.Invoke(tokenCtx, task.Index, args); err != nil {
		return err
	}
	return nil
}
// handleCozeWorkflow runs a Coze workflow for task and sends its output to rec.Ch.
//
// task.Config must contain {"request": {"json": {"workflow_id": "...",
// "stream": bool}}}. The recognized parameters (rec.Match.Parameters, a JSON
// object, may be empty) become the workflow inputs. With stream == true,
// events are forwarded as they arrive via handleCozeWorkflowEvents. Otherwise
// (including when "stream" is absent) the run is synchronous and its result
// data is sent as JSON.
func (r *Handle) handleCozeWorkflow(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
	entitys.ResLoading(rec.Ch, task.Index, "正在执行工作流(coze)")

	// Read the workflow ID and mode from the task config.
	type requestParams struct {
		Request l_request.Request `json:"request"`
	}
	var config requestParams
	if err = json.Unmarshal([]byte(task.Config), &config); err != nil {
		return err
	}
	workflowId, ok := config.Request.Json["workflow_id"].(string)
	if !ok || workflowId == "" {
		return fmt.Errorf("workflow_id不能为空")
	}
	// Comma-ok form: a missing or non-bool "stream" means a synchronous run
	// rather than a panic.
	stream, _ := config.Request.Json["stream"].(bool)

	// Extract the workflow parameters.
	var data map[string]interface{}
	if len(rec.Match.Parameters) > 0 {
		if err = json.Unmarshal([]byte(rec.Match.Parameters), &data); err != nil {
			return fmt.Errorf("解析工作流参数失败: %w", err)
		}
	}

	cozeCli := coze.NewCozeAPI(
		coze.NewTokenAuth(r.conf.Coze.ApiSecret),
		coze.WithBaseURL(r.conf.Coze.BaseURL),
		// Workflows can run for a long time; allow up to 30 minutes per request.
		coze.WithHttpClient(&http.Client{Timeout: 30 * time.Minute}),
	)
	req := &coze.RunWorkflowsReq{
		WorkflowID: workflowId,
		Parameters: data,
	}

	entitys.ResLog(rec.Ch, task.Index, "工作流执行中...")
	if !stream {
		resp, createErr := cozeCli.Workflows.Runs.Create(ctx, req)
		if createErr != nil {
			return createErr
		}
		entitys.ResJson(rec.Ch, task.Index, resp.Data)
		return nil
	}

	streamResp, streamErr := cozeCli.Workflows.Runs.Stream(ctx, req)
	if streamErr != nil {
		return streamErr
	}
	handleCozeWorkflowEvents(ctx, streamResp, cozeCli, workflowId, rec.Ch, task.Index)
	return nil
}
// handleCozeWorkflowEvents forwards events from a Coze workflow stream to ch
// under the given task index.
//
// Event handling:
//   - message events are streamed;
//   - error events are reported with ResError;
//   - the done event emits ResEnd;
//   - an interrupt event resumes the workflow, and the resumed stream is
//     handled recursively.
//
// It returns when the stream ends (io.EOF), a receive error occurs (reported
// to ch unless ctx is already done), or resuming fails. resp is always closed.
func handleCozeWorkflowEvents(ctx context.Context, resp coze.Stream[coze.WorkflowEvent], cozeCli coze.CozeAPI, workflowID string, ch chan entitys.Response, index string) {
	defer resp.Close()
	for {
		event, err := resp.Recv()
		if errorsSpecial.Is(err, io.EOF) {
			break
		}
		if err != nil {
			log.Errorf("coze workflow %s: receiving event failed: %v", workflowID, err)
			// Tell the client the stream broke. Skip this if the request was
			// cancelled: the receiver may already be gone.
			if ctx.Err() == nil {
				entitys.ResError(ch, index, fmt.Sprintf("工作流执行错误: %s", err.Error()))
			}
			break
		}

		switch event.Event {
		case coze.WorkflowEventTypeMessage:
			entitys.ResStream(ch, index, event.Message.Content)
		case coze.WorkflowEventTypeError:
			entitys.ResError(ch, index, fmt.Sprintf("工作流执行错误: %v", event.Error))
		case coze.WorkflowEventTypeDone:
			entitys.ResEnd(ch, index, "工作流执行完成")
		case coze.WorkflowEventTypeInterrupt:
			resumeReq := &coze.ResumeRunWorkflowsReq{
				WorkflowID: workflowID,
				EventID:    event.Interrupt.InterruptData.EventID,
				// TODO(review): "your data" looks like a placeholder copied from an SDK
				// example; the real resume input (e.g. the user's answer) must be
				// supplied here.
				ResumeData:    "your data",
				InterruptType: event.Interrupt.InterruptData.Type,
			}
			newResp, resumeErr := cozeCli.Workflows.Runs.Resume(ctx, resumeReq)
			if resumeErr != nil {
				entitys.ResError(ch, index, fmt.Sprintf("工作流恢复执行错误: %s", resumeErr.Error()))
				return
			}
			entitys.ResLog(ch, index, "工作流恢复执行中...")
			handleCozeWorkflowEvents(ctx, newResp, cozeCli, workflowID, ch, index)
		}
	}
	log.Infof("coze workflow %s done, logid: %s", workflowID, resp.Response().LogID())
}
// PermissionAuth reports whether client may run pointTask. The task's Index
// must be among the client's permission codes. Otherwise an error naming the
// task is returned.
func (r *Handle) PermissionAuth(client *gateway.Client, pointTask *model.AiTask) (err error) {
	granted := utils.Contains(client.GetCodes(), pointTask.Index)
	if granted {
		return nil
	}
	return fmt.Errorf("用户权限不足: %s", pointTask.Name)
}