420 lines
11 KiB
Go
420 lines
11 KiB
Go
package do
|
||
|
||
import (
|
||
"ai_scheduler/internal/biz/llm_service"
|
||
"ai_scheduler/internal/config"
|
||
"ai_scheduler/internal/data/constants"
|
||
errors "ai_scheduler/internal/data/error"
|
||
"ai_scheduler/internal/data/impl"
|
||
"ai_scheduler/internal/data/model"
|
||
"ai_scheduler/internal/domain/workflow/runtime"
|
||
"ai_scheduler/internal/entitys"
|
||
"ai_scheduler/internal/gateway"
|
||
"ai_scheduler/internal/pkg"
|
||
"ai_scheduler/internal/pkg/l_request"
|
||
"ai_scheduler/internal/pkg/mapstructure"
|
||
"ai_scheduler/internal/pkg/rec_extra"
|
||
"ai_scheduler/internal/pkg/util"
|
||
"ai_scheduler/internal/tools"
|
||
"ai_scheduler/internal/tools/public"
|
||
errorsSpecial "errors"
|
||
"io"
|
||
"net/http"
|
||
"time"
|
||
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"strings"
|
||
|
||
"github.com/coze-dev/coze-go"
|
||
"gorm.io/gorm/utils"
|
||
)
|
||
|
||
type Handle struct {
|
||
Ollama *llm_service.OllamaService
|
||
toolManager *tools.Manager
|
||
|
||
conf *config.Config
|
||
sessionImpl *impl.SessionImpl
|
||
workflowManager *runtime.Registry
|
||
}
|
||
|
||
func NewHandle(
|
||
Ollama *llm_service.OllamaService,
|
||
toolManager *tools.Manager,
|
||
conf *config.Config,
|
||
sessionImpl *impl.SessionImpl,
|
||
|
||
workflowManager *runtime.Registry,
|
||
) *Handle {
|
||
return &Handle{
|
||
Ollama: Ollama,
|
||
toolManager: toolManager,
|
||
conf: conf,
|
||
sessionImpl: sessionImpl,
|
||
|
||
workflowManager: workflowManager,
|
||
}
|
||
}
|
||
|
||
func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptProcessor PromptOption) (err error) {
|
||
entitys.ResLog(rec.Ch, "recognize_start", "准备意图识别")
|
||
|
||
prompt, err := promptProcessor.CreatePrompt(ctx, rec)
|
||
//意图识别
|
||
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, &entitys.ToolSelect{
|
||
Prompt: prompt,
|
||
Tools: rec.Tasks,
|
||
})
|
||
if err != nil {
|
||
return
|
||
}
|
||
entitys.ResLog(rec.Ch, "recognize", recognizeMsg)
|
||
entitys.ResLog(rec.Ch, "recognize_end", "意图识别结束")
|
||
var match entitys.Match
|
||
if err = json.Unmarshal([]byte(recognizeMsg), &match); err != nil {
|
||
err = errors.SysErrf("数据结构错误:%v", err.Error())
|
||
return
|
||
}
|
||
rec.Match = &match
|
||
|
||
return
|
||
}
|
||
|
||
func (r *Handle) handleOtherTask(ctx context.Context, requireData *entitys.RequireData) (err error) {
|
||
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
|
||
return
|
||
}
|
||
|
||
func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, rec *entitys.Recognize, requireData *entitys.RequireData) (err error) {
|
||
|
||
if !rec.Match.IsMatch {
|
||
if len(rec.Match.Chat) != 0 {
|
||
entitys.ResText(rec.Ch, "", rec.Match.Chat)
|
||
} else {
|
||
entitys.ResText(rec.Ch, "", rec.Match.Reasoning)
|
||
}
|
||
|
||
return
|
||
}
|
||
var pointTask *model.AiTask
|
||
for _, task := range requireData.Tasks {
|
||
if task.Index == rec.Match.Index {
|
||
pointTask = &task
|
||
break
|
||
}
|
||
}
|
||
|
||
if pointTask == nil || pointTask.Index == "other" {
|
||
return r.OtherTask(ctx, rec)
|
||
}
|
||
|
||
// 校验用户权限
|
||
// if err = r.PermissionAuth(client, pointTask); err != nil {
|
||
// log.Errorf("权限验证失败: %s", err.Error())
|
||
// return
|
||
// }
|
||
|
||
switch constants.TaskType(pointTask.Type) {
|
||
case constants.TaskTypeApi:
|
||
return r.handleApiTask(ctx, rec, pointTask)
|
||
case constants.TaskTypeFunc:
|
||
return r.handleTask(ctx, rec, pointTask)
|
||
case constants.TaskTypeKnowle:
|
||
return r.handleKnowle(ctx, rec, pointTask)
|
||
case constants.TaskTypeEinoWorkflow:
|
||
return r.handleEinoWorkflow(ctx, rec, pointTask)
|
||
case constants.TaskTypeCozeWorkflow:
|
||
return r.handleCozeWorkflow(ctx, rec, pointTask)
|
||
default:
|
||
return r.handleOtherTask(ctx, requireData)
|
||
}
|
||
}
|
||
|
||
func (r *Handle) OtherTask(ctx context.Context, requireData *entitys.Recognize) (err error) {
|
||
entitys.ResText(requireData.Ch, "", requireData.Match.Reasoning)
|
||
return
|
||
}
|
||
|
||
func (r *Handle) handleTask(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
|
||
var configData entitys.ConfigDataTool
|
||
err = json.Unmarshal([]byte(task.Config), &configData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
err = r.toolManager.ExecuteTool(ctx, configData.Tool, rec)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
// 知识库
|
||
func (r *Handle) handleKnowle(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
|
||
|
||
var (
|
||
configData entitys.ConfigDataTool
|
||
sessionIdKnowledge string
|
||
query string
|
||
host string
|
||
)
|
||
err = json.Unmarshal([]byte(task.Config), &configData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
ext, err := rec_extra.GetTaskRecExt(rec)
|
||
if err != nil {
|
||
return
|
||
}
|
||
// 通过session 找到知识库session
|
||
var has bool
|
||
if len(ext.Session) == 0 {
|
||
return errors.SessionNotFound
|
||
}
|
||
ext.SessionInfo, has, err = r.sessionImpl.FindOne(r.sessionImpl.WithSessionId(ext.Session))
|
||
if err != nil {
|
||
return
|
||
} else if !has {
|
||
return errors.SessionNotFound
|
||
}
|
||
|
||
// 找到知识库的host
|
||
{
|
||
tool, exists := r.toolManager.GetTool(configData.Tool)
|
||
if !exists {
|
||
return fmt.Errorf("tool not found: %s", configData.Tool)
|
||
}
|
||
|
||
if knowledgeTool, ok := tool.(*public.KnowledgeBaseTool); !ok {
|
||
return fmt.Errorf("未找到知识库Tool: %s", configData.Tool)
|
||
} else {
|
||
host = knowledgeTool.GetConfig().BaseURL
|
||
}
|
||
|
||
}
|
||
|
||
// 知识库的session为空,请求知识库获取, 并绑定
|
||
if ext.SessionInfo.KnowlegeSessionID == "" {
|
||
// 请求知识库
|
||
if sessionIdKnowledge, err = public.GetKnowledgeBaseSession(host, ext.Sys.KnowlegeBaseID, ext.Sys.KnowlegeTenantKey); err != nil {
|
||
return
|
||
}
|
||
|
||
// 绑定知识库session,下次可以使用
|
||
ext.SessionInfo.KnowlegeSessionID = sessionIdKnowledge
|
||
if err = r.sessionImpl.Update(&ext.SessionInfo, r.sessionImpl.WithSessionId(ext.SessionInfo.SessionID)); err != nil {
|
||
return
|
||
}
|
||
}
|
||
|
||
// 用户输入解析
|
||
var ok bool
|
||
input := make(map[string]string)
|
||
if err = json.Unmarshal([]byte(rec.Match.Parameters), &input); err != nil {
|
||
return
|
||
}
|
||
if query, ok = input["query"]; !ok {
|
||
return fmt.Errorf("query不能为空")
|
||
}
|
||
|
||
ext.KnowledgeConf = entitys.KnowledgeBaseRequest{
|
||
Session: ext.SessionInfo.KnowlegeSessionID,
|
||
ApiKey: ext.Sys.KnowlegeTenantKey,
|
||
Query: query,
|
||
}
|
||
rec.Ext = pkg.JsonByteIgonErr(ext)
|
||
// 执行工具
|
||
err = r.toolManager.ExecuteTool(ctx, configData.Tool, rec)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
func (r *Handle) handleApiTask(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
|
||
var (
|
||
request l_request.Request
|
||
requestParam map[string]interface{}
|
||
)
|
||
ext, err := rec_extra.GetTaskRecExt(rec)
|
||
if err != nil {
|
||
return
|
||
}
|
||
err = json.Unmarshal([]byte(rec.Match.Parameters), &requestParam)
|
||
if err != nil {
|
||
return
|
||
}
|
||
// request.Url = strings.ReplaceAll(task.Config, "${authorization}", requireData.Auth)
|
||
task.Config = strings.ReplaceAll(task.Config, "${authorization}", ext.Auth)
|
||
for k, v := range requestParam {
|
||
if vStr, ok := v.(string); ok {
|
||
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", vStr)
|
||
} else {
|
||
var jsonStr []byte
|
||
jsonStr, err = json.Marshal(v)
|
||
if err != nil {
|
||
return errors.NewBusinessErr(422, "请求参数解析失败")
|
||
}
|
||
task.Config = strings.ReplaceAll(task.Config, "\"${"+k+"}\"", string(jsonStr))
|
||
}
|
||
}
|
||
var configData entitys.ConfigDataHttp
|
||
err = json.Unmarshal([]byte(task.Config), &configData)
|
||
if err != nil {
|
||
return
|
||
}
|
||
err = mapstructure.Decode(configData.Request, &request)
|
||
if err != nil {
|
||
return
|
||
}
|
||
if len(request.Url) == 0 {
|
||
err = errors.NewBusinessErr(422, "api地址获取失败")
|
||
return
|
||
}
|
||
|
||
entitys.ResLoading(rec.Ch, task.Index, "正在请求数据")
|
||
|
||
res, err := request.Send()
|
||
if err != nil {
|
||
return
|
||
}
|
||
entitys.ResJson(rec.Ch, task.Index, res.Text)
|
||
|
||
return
|
||
}
|
||
|
||
// eino 工作流
|
||
func (r *Handle) handleEinoWorkflow(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
|
||
// token 写入ctx
|
||
ext, err := rec_extra.GetTaskRecExt(rec)
|
||
if err != nil {
|
||
return
|
||
}
|
||
ctx = util.SetTokenToContext(ctx, ext.Auth)
|
||
|
||
entitys.ResLoading(rec.Ch, task.Index, "正在执行工作流")
|
||
|
||
// 工作流内部输出
|
||
workflowId := task.Index
|
||
_, err = r.workflowManager.Invoke(ctx, workflowId, rec)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *Handle) handleCozeWorkflow(ctx context.Context, rec *entitys.Recognize, task *model.AiTask) (err error) {
|
||
entitys.ResLoading(rec.Ch, task.Index, "正在执行工作流(coze)")
|
||
|
||
customClient := &http.Client{
|
||
Timeout: time.Minute * 30,
|
||
}
|
||
|
||
authCli := coze.NewTokenAuth(r.conf.Coze.ApiSecret)
|
||
cozeCli := coze.NewCozeAPI(
|
||
authCli,
|
||
coze.WithBaseURL(r.conf.Coze.BaseURL),
|
||
coze.WithHttpClient(customClient),
|
||
)
|
||
|
||
// 从参数中获取workflowID
|
||
type requestParams struct {
|
||
Request l_request.Request `json:"request"`
|
||
}
|
||
var config requestParams
|
||
err = json.Unmarshal([]byte(task.Config), &config)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
workflowId, ok := config.Request.Json["workflow_id"].(string)
|
||
if !ok {
|
||
return fmt.Errorf("workflow_id不能为空")
|
||
}
|
||
// 提取参数
|
||
var data map[string]interface{}
|
||
err = json.Unmarshal([]byte(rec.Match.Parameters), &data)
|
||
|
||
req := &coze.RunWorkflowsReq{
|
||
WorkflowID: workflowId,
|
||
Parameters: data,
|
||
// IsAsync: true,
|
||
}
|
||
|
||
stream := config.Request.Json["stream"].(bool)
|
||
|
||
entitys.ResLog(rec.Ch, task.Index, "工作流执行中...")
|
||
|
||
if stream {
|
||
streamResp, err := cozeCli.Workflows.Runs.Stream(ctx, req)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
handleCozeWorkflowEvents(ctx, streamResp, cozeCli, workflowId, rec.Ch, task.Index)
|
||
} else {
|
||
resp, err := cozeCli.Workflows.Runs.Create(ctx, req)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
entitys.ResJson(rec.Ch, task.Index, resp.Data)
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
// handleCozeWorkflowEvents 处理 coze 工作流事件
|
||
func handleCozeWorkflowEvents(ctx context.Context, resp coze.Stream[coze.WorkflowEvent], cozeCli coze.CozeAPI, workflowID string, ch chan entitys.Response, index string) {
|
||
defer resp.Close()
|
||
for {
|
||
event, err := resp.Recv()
|
||
if errorsSpecial.Is(err, io.EOF) {
|
||
fmt.Println("Stream finished")
|
||
break
|
||
}
|
||
if err != nil {
|
||
fmt.Println("Error receiving event:", err)
|
||
break
|
||
}
|
||
|
||
switch event.Event {
|
||
case coze.WorkflowEventTypeMessage:
|
||
entitys.ResStream(ch, index, event.Message.Content)
|
||
case coze.WorkflowEventTypeError:
|
||
entitys.ResError(ch, index, fmt.Sprintf("工作流执行错误: %s", event.Error))
|
||
case coze.WorkflowEventTypeDone:
|
||
entitys.ResEnd(ch, index, "工作流执行完成")
|
||
case coze.WorkflowEventTypeInterrupt:
|
||
resumeReq := &coze.ResumeRunWorkflowsReq{
|
||
WorkflowID: workflowID,
|
||
EventID: event.Interrupt.InterruptData.EventID,
|
||
ResumeData: "your data",
|
||
InterruptType: event.Interrupt.InterruptData.Type,
|
||
}
|
||
newResp, err := cozeCli.Workflows.Runs.Resume(ctx, resumeReq)
|
||
if err != nil {
|
||
entitys.ResError(ch, index, fmt.Sprintf("工作流恢复执行错误: %s", err.Error()))
|
||
return
|
||
}
|
||
entitys.ResLog(ch, index, "工作流恢复执行中...")
|
||
handleCozeWorkflowEvents(ctx, newResp, cozeCli, workflowID, ch, index)
|
||
}
|
||
}
|
||
fmt.Printf("done, log:%s\n", resp.Response().LogID())
|
||
}
|
||
|
||
// 权限验证
|
||
func (r *Handle) PermissionAuth(client *gateway.Client, pointTask *model.AiTask) (err error) {
|
||
// 授权检查权限
|
||
if !utils.Contains(client.GetCodes(), pointTask.Index) {
|
||
return fmt.Errorf("用户权限不足: %s", pointTask.Name)
|
||
}
|
||
|
||
return nil
|
||
}
|