结构修改
This commit is contained in:
parent
8a2411016c
commit
03f6130a6e
|
@ -2,18 +2,24 @@ package biz
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/data/constant"
|
||||||
errors "ai_scheduler/internal/data/error"
|
errors "ai_scheduler/internal/data/error"
|
||||||
"ai_scheduler/internal/data/impl"
|
"ai_scheduler/internal/data/impl"
|
||||||
"ai_scheduler/internal/data/model"
|
"ai_scheduler/internal/data/model"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/pkg"
|
"ai_scheduler/internal/pkg"
|
||||||
|
"ai_scheduler/internal/pkg/mapstructure"
|
||||||
"ai_scheduler/internal/pkg/utils_ollama"
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
"ai_scheduler/internal/tools"
|
"ai_scheduler/internal/tools"
|
||||||
"ai_scheduler/tmpl/dataTemp"
|
"ai_scheduler/tmpl/dataTemp"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
|
"github.com/gofiber/fiber/v2/log"
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
"github.com/tmc/langchaingo/llms"
|
"github.com/tmc/langchaingo/llms"
|
||||||
"xorm.io/builder"
|
"xorm.io/builder"
|
||||||
|
@ -41,6 +47,7 @@ func NewAiRouterBiz(
|
||||||
hisImpl *impl.ChatImpl,
|
hisImpl *impl.ChatImpl,
|
||||||
conf *config.Config,
|
conf *config.Config,
|
||||||
utilAgent *utils_ollama.UtilOllama,
|
utilAgent *utils_ollama.UtilOllama,
|
||||||
|
|
||||||
) entitys.RouterService {
|
) entitys.RouterService {
|
||||||
return &AiRouterService{
|
return &AiRouterService{
|
||||||
//aiClient: aiClient,
|
//aiClient: aiClient,
|
||||||
|
@ -91,26 +98,27 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
|
||||||
return errors.SystemError
|
return errors.SystemError
|
||||||
}
|
}
|
||||||
|
|
||||||
toolDefinitions := r.registerTools(task)
|
//toolDefinitions := r.registerTools(task)
|
||||||
//prompt := r.getPrompt(sysInfo, history, req.Text)
|
//prompt := r.getPrompt(sysInfo, history, req.Text)
|
||||||
|
|
||||||
//意图预测
|
//意图预测
|
||||||
//msg, err := r.utilAgent.Llm.Call(context.TODO(), pkg.JsonStringIgonErr(prompt),
|
prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
|
||||||
|
match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
|
||||||
//llms.WithTools(toolDefinitions),
|
//llms.WithTools(toolDefinitions),
|
||||||
// //llms.WithToolChoice(llms.FunctionCallBehaviorAuto),
|
//llms.WithToolChoice("tool_name"),
|
||||||
// llms.WithJSONMode(),
|
|
||||||
//)
|
|
||||||
prompt := r.getPromptLLM(sysInfo, history, req.Text)
|
|
||||||
msg, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
|
|
||||||
llms.WithTools(toolDefinitions),
|
|
||||||
llms.WithToolChoice("tool_name"),
|
|
||||||
llms.WithJSONMode(),
|
llms.WithJSONMode(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.SystemError
|
return errors.SystemError
|
||||||
}
|
}
|
||||||
|
log.Info(match)
|
||||||
c.WriteMessage(1, []byte(msg.Choices[0].Content))
|
var matchJson entitys.Match
|
||||||
|
err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson)
|
||||||
|
if err != nil {
|
||||||
|
return errors.SystemError
|
||||||
|
}
|
||||||
|
return r.handleMatch(c, &matchJson, task)
|
||||||
|
//c.WriteMessage(1, []byte(msg.Choices[0].Content))
|
||||||
// 构建消息
|
// 构建消息
|
||||||
//messages := []entitys.Message{
|
//messages := []entitys.Message{
|
||||||
// {
|
// {
|
||||||
|
@ -193,6 +201,76 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) {
|
||||||
|
defer func() {
|
||||||
|
c.WriteMessage(1, []byte("EOF"))
|
||||||
|
}()
|
||||||
|
if !matchJson.IsMatch {
|
||||||
|
c.WriteMessage(1, []byte(matchJson.Reasoning))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var pointTask *model.AiTask
|
||||||
|
for _, task := range tasks {
|
||||||
|
if task.Index == matchJson.Index {
|
||||||
|
pointTask = &task
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if pointTask == nil || pointTask.Index == "other" {
|
||||||
|
return r.handleOtherTask(c, matchJson)
|
||||||
|
}
|
||||||
|
var res []byte
|
||||||
|
switch pointTask.Type {
|
||||||
|
case constant.TaskTypeApi:
|
||||||
|
res, err = r.handleApiTask(c, matchJson, pointTask)
|
||||||
|
default:
|
||||||
|
return r.handleOtherTask(c, matchJson)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterService) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) {
|
||||||
|
|
||||||
|
c.WriteMessage(1, []byte(matchJson.Reasoning))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (resByte []byte, err error) {
|
||||||
|
var (
|
||||||
|
request l_request.Request
|
||||||
|
auth = c.Headers("X-Authorization", "")
|
||||||
|
requestParam map[string]interface{}
|
||||||
|
)
|
||||||
|
err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
request.Url = strings.ReplaceAll(task.Config, "${authorization}", auth)
|
||||||
|
for k, v := range requestParam {
|
||||||
|
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
|
||||||
|
}
|
||||||
|
var configData entitys.ConfigData
|
||||||
|
err = json.Unmarshal([]byte(task.Config), &configData)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = mapstructure.Decode(configData.Do, &request)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(request.Url) == 0 {
|
||||||
|
err = errors.NewBusinessErr("00022", "api地址获取失败")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res, err := request.Send()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return res.Content, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
|
func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
|
||||||
|
|
||||||
cond := builder.NewCond()
|
cond := builder.NewCond()
|
||||||
|
@ -261,7 +339,7 @@ func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi
|
||||||
return prompt
|
return prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []llms.MessageContent {
|
func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
||||||
var (
|
var (
|
||||||
prompt = make([]llms.MessageContent, 0)
|
prompt = make([]llms.MessageContent, 0)
|
||||||
)
|
)
|
||||||
|
@ -275,6 +353,11 @@ func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiCha
|
||||||
Parts: []llms.ContentPart{
|
Parts: []llms.ContentPart{
|
||||||
llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))),
|
llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))),
|
||||||
},
|
},
|
||||||
|
}, llms.MessageContent{
|
||||||
|
Role: llms.ChatMessageTypeTool,
|
||||||
|
Parts: []llms.ContentPart{
|
||||||
|
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
|
||||||
|
},
|
||||||
}, llms.MessageContent{
|
}, llms.MessageContent{
|
||||||
Role: llms.ChatMessageTypeHuman,
|
Role: llms.ChatMessageTypeHuman,
|
||||||
Parts: []llms.ContentPart{
|
Parts: []llms.ContentPart{
|
||||||
|
|
|
@ -11,6 +11,6 @@ const (
|
||||||
type TaskType int32
|
type TaskType int32
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TaskTypeApi ConnStatus = iota + 1
|
TaskTypeApi = 1
|
||||||
TaskTypeKnowle
|
TaskTypeKnowle = 2
|
||||||
)
|
)
|
||||||
|
|
|
@ -69,9 +69,9 @@ type Tool interface {
|
||||||
Execute(ctx context.Context, args json.RawMessage) (interface{}, error)
|
Execute(ctx context.Context, args json.RawMessage) (interface{}, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AIClient AI客户端接口
|
type ConfigData struct {
|
||||||
type AIClient interface {
|
Param map[string]interface{} `json:"param"`
|
||||||
Chat(ctx context.Context, messages []Message, tools []ToolDefinition) (*ChatResponse, error)
|
Do map[string]interface{} `json:"do"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Message 消息
|
// Message 消息
|
||||||
|
@ -88,7 +88,13 @@ type FuncApi struct {
|
||||||
type TaskConfig struct {
|
type TaskConfig struct {
|
||||||
Param interface{} `json:"param"`
|
Param interface{} `json:"param"`
|
||||||
}
|
}
|
||||||
|
type Match struct {
|
||||||
|
Confidence float64 `json:"confidence"`
|
||||||
|
Index string `json:"index"`
|
||||||
|
IsMatch bool `json:"is_match"`
|
||||||
|
Parameters string `json:"parameters"`
|
||||||
|
Reasoning string `json:"reasoning"`
|
||||||
|
}
|
||||||
type ChatHis struct {
|
type ChatHis struct {
|
||||||
SessionId string `json:"session_id"`
|
SessionId string `json:"session_id"`
|
||||||
Messages []HisMessage `json:"messages"`
|
Messages []HisMessage `json:"messages"`
|
||||||
|
|
|
@ -18,7 +18,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
|
||||||
ollama.WithModel(c.Ollama.Model),
|
ollama.WithModel(c.Ollama.Model),
|
||||||
ollama.WithHTTPClient(http.DefaultClient),
|
ollama.WithHTTPClient(http.DefaultClient),
|
||||||
ollama.WithServerURL(getUrl(c)),
|
ollama.WithServerURL(getUrl(c)),
|
||||||
ollama.WithKeepAlive("1h"),
|
ollama.WithKeepAlive("-1s"),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal(err)
|
logger.Fatal(err)
|
||||||
|
|
|
@ -2,8 +2,11 @@ package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
|
"ai_scheduler/internal/pkg/mapstructure"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_task(t *testing.T) {
|
func Test_task(t *testing.T) {
|
||||||
|
@ -12,3 +15,23 @@ func Test_task(t *testing.T) {
|
||||||
err := json.Unmarshal([]byte(config), &c)
|
err := json.Unmarshal([]byte(config), &c)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type configData struct {
|
||||||
|
Param map[string]interface{} `json:"param"`
|
||||||
|
Do map[string]interface{} `json:"do"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_task2(t *testing.T) {
|
||||||
|
var (
|
||||||
|
c l_request.Request
|
||||||
|
config configData
|
||||||
|
)
|
||||||
|
|
||||||
|
configJson := `{"param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}, "do": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}`
|
||||||
|
err := json.Unmarshal([]byte(configJson), &config)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
mapstructure.Decode(config.Do, &c)
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue