ai_scheduler/internal/biz/do/handle.go

420 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package do
import (
"ai_scheduler/internal/biz/llm_service"
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/constants"
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/domain/workflow/runtime"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/gateway"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/l_request"
"ai_scheduler/internal/pkg/mapstructure"
"ai_scheduler/internal/pkg/rec_extra"
"ai_scheduler/internal/pkg/util"
"ai_scheduler/internal/tools"
"ai_scheduler/internal/tools/public"
errorsSpecial "errors"
"io"
"net/http"
"time"
"context"
"encoding/json"
"fmt"
"strings"
"github.com/coze-dev/coze-go"
"gorm.io/gorm/utils"
)
// Handle executes recognized intents. It holds the model client used for intent
// recognition plus the dependencies the individual task handlers need.
type Handle struct {
	Ollama          *llm_service.OllamaService // model client used by Recognize
	toolManager     *tools.Manager             // tool registry used by func / knowledge tasks
	conf            *config.Config             // application config (Coze settings are read from it)
	sessionImpl     *impl.SessionImpl          // session store used by knowledge tasks
	workflowManager *runtime.Registry          // eino workflow registry
}

// NewHandle builds a Handle from its dependencies; every argument is stored unchanged.
func NewHandle(
	Ollama *llm_service.OllamaService,
	toolManager *tools.Manager,
	conf *config.Config,
	sessionImpl *impl.SessionImpl,
	workflowManager *runtime.Registry,
) *Handle {
	h := new(Handle)
	h.Ollama = Ollama
	h.toolManager = toolManager
	h.conf = conf
	h.sessionImpl = sessionImpl
	h.workflowManager = workflowManager
	return h
}
// Recognize runs intent recognition for rec and stores the decoded result in rec.Match.
//
// It builds the prompt with promptProcessor, asks the Ollama service to choose among
// rec.Tasks, emits progress log events on rec.Ch, and unmarshals the model's reply into
// an entitys.Match. rec.Match is only assigned on success. Any error from prompt
// creation, the model call or decoding is returned; a decode failure is reported as a
// system error.
func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (err error) {
	entitys.ResLog(rec.Ch, "recognize_start", "准备意图识别")

	prompt, err := promptProcessor.CreatePrompt(ctx, rec)
	if err != nil {
		// Must be checked here: the next assignment to err would otherwise hide it and
		// a broken prompt would still be sent to the model.
		return err
	}

	// Intent recognition: let the model pick one of the candidate tasks.
	recognizeMsg, err := r.Ollama.IntentRecognize(ctx, &entitys.ToolSelect{
		Prompt: prompt,
		Tools:  rec.Tasks,
	})
	if err != nil {
		return err
	}
	entitys.ResLog(rec.Ch, "recognize", recognizeMsg)
	entitys.ResLog(rec.Ch, "recognize_end", "意图识别结束")

	var match entitys.Match
	if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil {
		return errors.SysErr("数据结构错误:%v", err.Error())
	}
	rec.Match = &match
	return nil
}
// handleOtherTask is the fallback for task types HandleMatch does not dispatch: it
// sends the recognizer's reasoning text on requireData.Ch with an empty index.
func (r *Handle) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
	reasoning := requireData.Match.Reasoning
	entitys.ResText(requireData.Ch, "", reasoning)
	return nil
}
// HandleMatch dispatches the recognized intent in rec.Match to the matching task handler.
//
// If the recognizer found no matching task, its chat reply is sent to rec.Ch as plain
// text, or its reasoning when the chat reply is empty. Otherwise the entry in
// requireData.Tasks whose Index equals rec.Match.Index is executed according to its Type.
// A task that is not found, or the "other" task, echoes the reasoning. An unknown type
// falls back to handleOtherTask.
//
// client is only used by the permission check, which is currently disabled.
// An error is returned if rec.Match is nil (no recognition result) or if the selected
// handler fails.
func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, rec *entitys.Recognize, requireData *entitys.RequireData) (err error) {
	// Guard against being called before Recognize populated the match; every branch
	// below dereferences it.
	if rec.Match == nil {
		return fmt.Errorf("handle match: no intent recognition result")
	}
	if !rec.Match.IsMatch {
		if len(rec.Match.Chat) != 0 {
			entitys.ResText(rec.Ch, "", rec.Match.Chat)
		} else {
			entitys.ResText(rec.Ch, "", rec.Match.Reasoning)
		}
		return nil
	}

	// pointTask points at the range variable (a copy), not at the slice element. Handlers
	// receive *model.AiTask and may modify it, so this keeps such changes out of
	// requireData.Tasks. Do not change this to &requireData.Tasks[i].
	var pointTask *model.AiTask
	for _, task := range requireData.Tasks {
		if task.Index == rec.Match.Index {
			pointTask = &task
			break
		}
	}
	if pointTask == nil || pointTask.Index == "other" {
		return r.OtherTask(ctx, rec)
	}
	// 校验用户权限 (permission check, currently disabled)
	// if err = r.PermissionAuth(client, pointTask); err != nil {
	// log.Errorf("权限验证失败: %s", err.Error())
	// return
	// }
	switch constants.TaskType(pointTask.Type) {
	case constants.TaskTypeApi:
		return r.handleApiTask(ctx, rec, pointTask)
	case constants.TaskTypeFunc:
		return r.handleTask(ctx, rec, pointTask)
	case constants.TaskTypeKnowle:
		return r.handleKnowle(ctx, rec, pointTask)
	case constants.TaskTypeEinoWorkflow:
		return r.handleEinoWorkflow(ctx, rec, pointTask)
	case constants.TaskTypeCozeWorkflow:
		return r.handleCozeWorkflow(ctx, rec, pointTask)
	default:
		return r.handleOtherTask(ctx, requireData)
	}
}
// OtherTask replies with the recognizer's reasoning text, sent on requireData.Ch with an
// empty index. It always returns nil.
func (r *Handle) OtherTask(ctx context.Context, requireData *entitys.Recognize) (err error) {
	out, text := requireData.Ch, requireData.Match.Reasoning
	entitys.ResText(out, "", text)
	return nil
}
// handleTask runs a "func" task. task.Config is decoded to find the tool name, and that
// tool is executed with rec. Decode and tool errors are returned unchanged.
func (r *Handle) handleTask(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
	var cfg entitys.ConfigDataTool
	if err = json.Unmarshal([]byte(task.Config), &cfg); err != nil {
		return err
	}
	if err = r.toolManager.ExecuteTool(ctx, cfg.Tool, rec); err != nil {
		return err
	}
	return nil
}
// 知识库
// handleKnowle answers a knowledge-base (知识库) task.
//
// Flow:
//  1. decode task.Config to get the knowledge-base tool name;
//  2. read "query" from rec.Match.Parameters, which must be present and non-blank;
//  3. load the chat session named in the task ext data;
//  4. resolve the knowledge-base host from the registered *public.KnowledgeBaseTool;
//  5. if the chat session has no knowledge-base session yet, create one and persist it
//     on the chat session so later requests reuse it;
//  6. store the knowledge request in rec.Ext and execute the tool.
//
// The query is validated before any remote call or database write, so a request without
// a usable query does not create a knowledge-base session as a side effect.
func (r *Handle) handleKnowle(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
	var configData entitys.ConfigDataTool
	if err = json.Unmarshal([]byte(task.Config), &configData); err != nil {
		return err
	}

	// User input: only the "query" field is used.
	input := make(map[string]string)
	if err = json.Unmarshal([]byte(rec.Match.Parameters), &input); err != nil {
		return err
	}
	query := input["query"]
	// Reject a missing key AND an empty/blank value; the message promises "not empty".
	if strings.TrimSpace(query) == "" {
		return fmt.Errorf("query不能为空")
	}

	ext, err := rec_extra.GetTaskRecExt(rec)
	if err != nil {
		return err
	}

	// Look up the chat session; the knowledge-base session id is stored on it.
	if len(ext.Session) == 0 {
		return errors.SessionNotFound
	}
	var has bool
	ext.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(ext.Session))
	if err != nil {
		return err
	}
	if !has {
		return errors.SessionNotFound
	}

	// Resolve the knowledge-base host from the registered tool.
	tool, exists := r.toolManager.GetTool(configData.Tool)
	if !exists {
		return fmt.Errorf("tool not found: %s", configData.Tool)
	}
	knowledgeTool, ok := tool.(*public.KnowledgeBaseTool)
	if !ok {
		return fmt.Errorf("未找到知识库Tool: %s", configData.Tool)
	}
	host := knowledgeTool.GetConfig().BaseURL

	// No knowledge-base session yet: request one and bind it to the chat session so the
	// next request can reuse it.
	if ext.SessionInfo.KnowlegeSessionID == "" {
		var sessionIDKnowledge string
		if sessionIDKnowledge, err = public.GetKnowledgeBaseSession(host, ext.Sys.KnowlegeBaseID, ext.Sys.KnowlegeTenantKey); err != nil {
			return err
		}
		ext.SessionInfo.KnowlegeSessionID = sessionIDKnowledge
		if err = r.sessionImpl.Update(&ext.SessionInfo, r.sessionImpl.WithSessionId(ext.SessionInfo.SessionID)); err != nil {
			return err
		}
	}

	ext.KnowledgeConf = entitys.KnowledgeBaseRequest{
		Session: ext.SessionInfo.KnowlegeSessionID,
		ApiKey:  ext.Sys.KnowlegeTenantKey,
		Query:   query,
	}
	rec.Ext = pkg.JsonByteIgonErr(ext)

	// Execute the knowledge-base tool with the updated ext data.
	return r.toolManager.ExecuteTool(ctx, configData.Tool, rec)
}
// handleApiTask executes an HTTP ("api") task and sends the response body to rec.Ch.
//
// task.Config is a JSON template whose placeholders are filled in before decoding:
//   - ${authorization} is replaced with the auth token from the task ext data;
//   - ${name} is replaced with string parameter "name" from rec.Match.Parameters;
//   - "${name}" (including the quotes) is replaced with the JSON encoding of a
//     non-string parameter, so numbers, booleans, objects and arrays keep their type.
//
// Inserted text is JSON-escaped, so values containing quotes or backslashes cannot break
// or rewrite the template. All placeholders are substituted in a single pass, so
// placeholder-looking text inside a value is never expanded again. task.Config itself is
// left unmodified. The rendered config's "request" section is decoded into an
// l_request.Request and sent. An empty URL is a 422 business error.
func (r *Handle) handleApiTask(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
	ext, err := rec_extra.GetTaskRecExt(rec)
	if err != nil {
		return err
	}
	var requestParam map[string]interface{}
	if err = json.Unmarshal([]byte(rec.Match.Parameters), &requestParam); err != nil {
		return err
	}

	// escape JSON-encodes s and strips the surrounding quotes, producing text that is
	// safe to place inside a JSON string literal.
	escape := func(s string) string {
		b, _ := json.Marshal(s) // marshalling a plain string never returns an error
		return string(b[1 : len(b)-1])
	}

	pairs := []string{"${authorization}", escape(ext.Auth)}
	for k, v := range requestParam {
		// ${authorization} is reserved for the auth token (it was always substituted
		// first, so a parameter with this name never took effect).
		if k == "authorization" {
			continue
		}
		if vStr, ok := v.(string); ok {
			pairs = append(pairs, "${"+k+"}", escape(vStr))
			continue
		}
		raw, marshalErr := json.Marshal(v)
		if marshalErr != nil {
			return errors.NewBusinessErr(422, "请求参数解析失败")
		}
		pairs = append(pairs, "\"${"+k+"}\"", string(raw))
	}
	// Render into a local copy: the caller's task is not mutated.
	rendered := strings.NewReplacer(pairs...).Replace(task.Config)

	var configData entitys.ConfigDataHttp
	if err = json.Unmarshal([]byte(rendered), &configData); err != nil {
		return err
	}
	var request l_request.Request
	if err = mapstructure.Decode(configData.Request, &request); err != nil {
		return err
	}
	if len(request.Url) == 0 {
		return errors.NewBusinessErr(422, "api地址获取失败")
	}

	entitys.ResLoading(rec.Ch, task.Index, "正在请求数据")
	res, err := request.Send()
	if err != nil {
		return err
	}
	entitys.ResJson(rec.Ch, task.Index, res.Text)
	return nil
}
// eino 工作流
// handleEinoWorkflow runs the registered eino workflow whose id is task.Index.
//
// The auth token from the task ext data is stored in the context passed to the workflow.
// The workflow produces its own output (工作流内部输出), so the value returned by Invoke
// is discarded and only its error is propagated.
func (r *Handle) handleEinoWorkflow(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
	ext, err := rec_extra.GetTaskRecExt(rec)
	if err != nil {
		return
	}
	tokenCtx := util.SetTokenToContext(ctx, ext.Auth)

	entitys.ResLoading(rec.Ch, task.Index, "正在执行工作流")
	if _, err = r.workflowManager.Invoke(tokenCtx, task.Index, rec); err != nil {
		return err
	}
	return nil
}
// handleCozeWorkflow runs a Coze workflow task and forwards its output to rec.Ch.
//
// The workflow id and the optional "stream" flag are read from the Json map of the
// "request" section of task.Config. An empty or missing workflow id is an error.
// rec.Match.Parameters, if non-empty, is decoded as the workflow input; malformed JSON
// is returned as an error instead of being silently dropped. When stream is true, events
// are forwarded as they arrive via handleCozeWorkflowEvents. Otherwise the workflow is
// run once and its result data is sent as JSON. A missing or non-boolean "stream" value
// selects a non-streaming run.
func (r *Handle) handleCozeWorkflow(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
	entitys.ResLoading(rec.Ch, task.Index, "正在执行工作流(coze)")

	// Only the "request" section of the task config is needed here.
	type requestParams struct {
		Request l_request.Request `json:"request"`
	}
	var cfg requestParams
	if err = json.Unmarshal([]byte(task.Config), &cfg); err != nil {
		return err
	}
	workflowID, _ := cfg.Request.Json["workflow_id"].(string)
	if workflowID == "" {
		return fmt.Errorf("workflow_id不能为空")
	}
	// Comma-ok form: an unchecked assertion panics when "stream" is absent or not a bool.
	stream, _ := cfg.Request.Json["stream"].(bool)

	// Workflow input parameters; empty means "no input".
	var data map[string]interface{}
	if len(rec.Match.Parameters) != 0 {
		if err = json.Unmarshal([]byte(rec.Match.Parameters), &data); err != nil {
			return fmt.Errorf("decode coze workflow parameters: %w", err)
		}
	}

	// The 30-minute client timeout presumably accommodates long-running workflows.
	cozeCli := coze.NewCozeAPI(
		coze.NewTokenAuth(r.conf.Coze.ApiSecret),
		coze.WithBaseURL(r.conf.Coze.BaseURL),
		coze.WithHttpClient(&http.Client{Timeout: 30 * time.Minute}),
	)
	req := &coze.RunWorkflowsReq{
		WorkflowID: workflowID,
		Parameters: data,
	}

	entitys.ResLog(rec.Ch, task.Index, "工作流执行中...")
	if stream {
		streamResp, runErr := cozeCli.Workflows.Runs.Stream(ctx, req)
		if runErr != nil {
			return runErr
		}
		handleCozeWorkflowEvents(ctx, streamResp, cozeCli, workflowID, rec.Ch, task.Index)
		return nil
	}
	resp, runErr := cozeCli.Workflows.Runs.Create(ctx, req)
	if runErr != nil {
		return runErr
	}
	entitys.ResJson(rec.Ch, task.Index, resp.Data)
	return nil
}
// handleCozeWorkflowEvents 处理 coze 工作流事件
// handleCozeWorkflowEvents drains a Coze workflow event stream and forwards events to ch
// under the given index until the stream ends or fails:
//   - message events are sent with ResStream;
//   - error events are sent with ResError;
//   - the done event is sent with ResEnd;
//   - an interrupt event triggers a resume request, and the resumed run's stream is
//     handled recursively before reading continues on resp.
//
// A receive error other than io.EOF is both printed and forwarded to ch with ResError,
// so the consumer of ch sees the failure. resp is always closed on return.
func handleCozeWorkflowEvents(ctx context.Context, resp coze.Stream[coze.WorkflowEvent], cozeCli coze.CozeAPI, workflowID string, ch chan entitys.Response, index string) {
	defer resp.Close()
	for {
		event, err := resp.Recv()
		if errorsSpecial.Is(err, io.EOF) {
			fmt.Println("Stream finished")
			break
		}
		if err != nil {
			fmt.Println("Error receiving event:", err)
			// Surface the failure to the client instead of ending the stream silently.
			entitys.ResError(ch, index, fmt.Sprintf("工作流事件接收错误: %s", err.Error()))
			break
		}
		switch event.Event {
		case coze.WorkflowEventTypeMessage:
			entitys.ResStream(ch, index, event.Message.Content)
		case coze.WorkflowEventTypeError:
			entitys.ResError(ch, index, fmt.Sprintf("工作流执行错误: %s", event.Error))
		case coze.WorkflowEventTypeDone:
			entitys.ResEnd(ch, index, "工作流执行完成")
		case coze.WorkflowEventTypeInterrupt:
			resumeReq := &coze.ResumeRunWorkflowsReq{
				WorkflowID: workflowID,
				EventID:    event.Interrupt.InterruptData.EventID,
				// NOTE(review): "your data" looks like a placeholder. The workflow is
				// resumed with this literal rather than real user input.
				// TODO: pass the actual resume data.
				ResumeData:    "your data",
				InterruptType: event.Interrupt.InterruptData.Type,
			}
			newResp, err := cozeCli.Workflows.Runs.Resume(ctx, resumeReq)
			if err != nil {
				entitys.ResError(ch, index, fmt.Sprintf("工作流恢复执行错误: %s", err.Error()))
				return
			}
			entitys.ResLog(ch, index, "工作流恢复执行中...")
			handleCozeWorkflowEvents(ctx, newResp, cozeCli, workflowID, ch, index)
		}
	}
	fmt.Printf("done, log:%s\n", resp.Response().LogID())
}
// 权限验证
// PermissionAuth checks whether client may run pointTask. Access is granted when the
// task's Index is among the codes returned by client.GetCodes(). It returns nil when
// granted and an error naming the task otherwise.
func (r *Handle) PermissionAuth(client *gateway.Client, pointTask *model.AiTask) (err error) {
	granted := utils.Contains(client.GetCodes(), pointTask.Index)
	if granted {
		return nil
	}
	return fmt.Errorf("用户权限不足: %s", pointTask.Name)
}