diff --git a/internal/biz/router.go b/internal/biz/router.go index 1471009..5888c76 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -2,18 +2,24 @@ package biz import ( "ai_scheduler/internal/config" + "ai_scheduler/internal/data/constant" errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" + "ai_scheduler/internal/pkg/mapstructure" "ai_scheduler/internal/pkg/utils_ollama" "ai_scheduler/internal/tools" "ai_scheduler/tmpl/dataTemp" "context" "encoding/json" - "log" + "fmt" + "net/http" + "strings" + "gitea.cdlsxd.cn/self-tools/l_request" + "github.com/gofiber/fiber/v2/log" "github.com/gofiber/websocket/v2" "github.com/tmc/langchaingo/llms" "xorm.io/builder" @@ -41,6 +47,7 @@ func NewAiRouterBiz( hisImpl *impl.ChatImpl, conf *config.Config, utilAgent *utils_ollama.UtilOllama, + ) entitys.RouterService { return &AiRouterService{ //aiClient: aiClient, @@ -91,26 +98,27 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo return errors.SystemError } - toolDefinitions := r.registerTools(task) + //toolDefinitions := r.registerTools(task) //prompt := r.getPrompt(sysInfo, history, req.Text) //意图预测 - //msg, err := r.utilAgent.Llm.Call(context.TODO(), pkg.JsonStringIgonErr(prompt), - // llms.WithTools(toolDefinitions), - // //llms.WithToolChoice(llms.FunctionCallBehaviorAuto), - // llms.WithJSONMode(), - //) - prompt := r.getPromptLLM(sysInfo, history, req.Text) - msg, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt, - llms.WithTools(toolDefinitions), - llms.WithToolChoice("tool_name"), + prompt := r.getPromptLLM(sysInfo, history, req.Text, task) + match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt, + //llms.WithTools(toolDefinitions), + //llms.WithToolChoice("tool_name"), llms.WithJSONMode(), ) if err != nil { return errors.SystemError } - - c.WriteMessage(1, []byte(msg.Choices[0].Content)) + log.Info(match) + var matchJson entitys.Match + err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson) + if err != nil { + return errors.SystemError + } + return r.handleMatch(c, &matchJson, task) + //c.WriteMessage(1, []byte(msg.Choices[0].Content)) // 构建消息 //messages := []entitys.Message{ // { @@ -193,6 +201,76 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo return nil } +func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) { + defer func() { + c.WriteMessage(1, []byte("EOF")) + }() + if !matchJson.IsMatch { + c.WriteMessage(1, []byte(matchJson.Reasoning)) + return + } + var pointTask *model.AiTask + for _, task := range tasks { + if task.Index == matchJson.Index { + pointTask = &task + break + } + } + if pointTask == nil || pointTask.Index == "other" { + return r.handleOtherTask(c, matchJson) + } + var res []byte + switch pointTask.Type { + case constant.TaskTypeApi: + res, err = r.handleApiTask(c, matchJson, pointTask) + default: + return r.handleOtherTask(c, matchJson) + } + + return +} + +func (r *AiRouterService) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) { + + c.WriteMessage(1, []byte(matchJson.Reasoning)) + + return +} + +func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (resByte []byte, err error) { + var ( + request l_request.Request + auth = c.Headers("X-Authorization", "") + requestParam map[string]interface{} + ) + err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam) + if err != nil { + return + } + request.Url = strings.ReplaceAll(task.Config, "${authorization}", auth) + for k, v := range requestParam { + task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v)) + } + var configData entitys.ConfigData + err = json.Unmarshal([]byte(task.Config), &configData) + if err != nil { + return + } + err = mapstructure.Decode(configData.Do, &request) + if err != nil { + return + } + if len(request.Url) == 0 { + err = errors.NewBusinessErr("00022", "api地址获取失败") + return + } + res, err := request.Send() + if err != nil { + return + } + return res.Content, nil +} + func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) { cond := builder.NewCond() @@ -261,7 +339,7 @@ func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi return prompt } -func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []llms.MessageContent { +func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { var ( prompt = make([]llms.MessageContent, 0) ) @@ -275,6 +353,11 @@ func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiCha Parts: []llms.ContentPart{ llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))), }, + }, llms.MessageContent{ + Role: llms.ChatMessageTypeTool, + Parts: []llms.ContentPart{ + llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))), + }, }, llms.MessageContent{ Role: llms.ChatMessageTypeHuman, Parts: []llms.ContentPart{ diff --git a/internal/data/constant/const.go b/internal/data/constant/const.go index d611d70..19ae774 100644 --- a/internal/data/constant/const.go +++ b/internal/data/constant/const.go @@ -11,6 +11,6 @@ const ( type TaskType int32 const ( - TaskTypeApi ConnStatus = iota + 1 - TaskTypeKnowle + TaskTypeApi = 1 + TaskTypeKnowle = 2 ) diff --git a/internal/entitys/types.go b/internal/entitys/types.go index 13a1f9f..0a3373f 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -69,9 +69,9 @@ type Tool interface { Execute(ctx context.Context, args json.RawMessage) (interface{}, error) } -// AIClient AI客户端接口 -type AIClient interface { - Chat(ctx context.Context, messages []Message, tools []ToolDefinition) (*ChatResponse, error) +type ConfigData struct { + Param map[string]interface{} `json:"param"` + Do map[string]interface{} `json:"do"` } // Message 消息 @@ -88,7 +88,13 @@ type FuncApi struct { type TaskConfig struct { Param interface{} `json:"param"` } - +type Match struct { + Confidence float64 `json:"confidence"` + Index string `json:"index"` + IsMatch bool `json:"is_match"` + Parameters string `json:"parameters"` + Reasoning string `json:"reasoning"` +} type ChatHis struct { SessionId string `json:"session_id"` Messages []HisMessage `json:"messages"` diff --git a/internal/pkg/utils_ollama/ollama.go b/internal/pkg/utils_ollama/ollama.go index f5546bb..65388a3 100644 --- a/internal/pkg/utils_ollama/ollama.go +++ b/internal/pkg/utils_ollama/ollama.go @@ -18,7 +18,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama { ollama.WithModel(c.Ollama.Model), ollama.WithHTTPClient(http.DefaultClient), ollama.WithServerURL(getUrl(c)), - ollama.WithKeepAlive("1h"), + ollama.WithKeepAlive("-1s"), ) if err != nil { logger.Fatal(err) diff --git a/internal/test/chat_test.go b/internal/test/chat_test.go index ebd1bda..489eb4f 100644 --- a/internal/test/chat_test.go +++ b/internal/test/chat_test.go @@ -2,8 +2,11 @@ package test import ( "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/mapstructure" "encoding/json" "testing" + + "gitea.cdlsxd.cn/self-tools/l_request" ) func Test_task(t *testing.T) { @@ -12,3 +15,23 @@ func Test_task(t *testing.T) { err := json.Unmarshal([]byte(config), &c) t.Log(err) } + +type configData struct { + Param map[string]interface{} `json:"param"` + Do map[string]interface{} `json:"do"` +} + +func Test_task2(t *testing.T) { + var ( + c l_request.Request + config configData + ) + + configJson := `{"param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}, "do": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}` + err := json.Unmarshal([]byte(configJson), &config) + if err != nil { + panic(err) + } + mapstructure.Decode(config.Do, &c) + t.Log(err) +}