Compare commits
2 Commits
addcf2d24d
...
0ce09611de
Author | SHA1 | Date |
---|---|---|
|
0ce09611de | |
|
03f6130a6e |
|
@ -2,18 +2,24 @@ package biz
|
|||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/data/constant"
|
||||
errors "ai_scheduler/internal/data/error"
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg"
|
||||
"ai_scheduler/internal/pkg/mapstructure"
|
||||
"ai_scheduler/internal/pkg/utils_ollama"
|
||||
"ai_scheduler/internal/tools"
|
||||
"ai_scheduler/tmpl/dataTemp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/tmc/langchaingo/llms"
|
||||
"xorm.io/builder"
|
||||
|
@ -41,6 +47,7 @@ func NewAiRouterBiz(
|
|||
hisImpl *impl.ChatImpl,
|
||||
conf *config.Config,
|
||||
utilAgent *utils_ollama.UtilOllama,
|
||||
|
||||
) entitys.RouterService {
|
||||
return &AiRouterService{
|
||||
//aiClient: aiClient,
|
||||
|
@ -91,26 +98,27 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
|
|||
return errors.SystemError
|
||||
}
|
||||
|
||||
toolDefinitions := r.registerTools(task)
|
||||
//toolDefinitions := r.registerTools(task)
|
||||
//prompt := r.getPrompt(sysInfo, history, req.Text)
|
||||
|
||||
//意图预测
|
||||
//msg, err := r.utilAgent.Llm.Call(context.TODO(), pkg.JsonStringIgonErr(prompt),
|
||||
prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
|
||||
match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
|
||||
//llms.WithTools(toolDefinitions),
|
||||
// //llms.WithToolChoice(llms.FunctionCallBehaviorAuto),
|
||||
// llms.WithJSONMode(),
|
||||
//)
|
||||
prompt := r.getPromptLLM(sysInfo, history, req.Text)
|
||||
msg, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
|
||||
llms.WithTools(toolDefinitions),
|
||||
llms.WithToolChoice("tool_name"),
|
||||
//llms.WithToolChoice("tool_name"),
|
||||
llms.WithJSONMode(),
|
||||
)
|
||||
if err != nil {
|
||||
return errors.SystemError
|
||||
}
|
||||
|
||||
c.WriteMessage(1, []byte(msg.Choices[0].Content))
|
||||
log.Info(match)
|
||||
var matchJson entitys.Match
|
||||
err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson)
|
||||
if err != nil {
|
||||
return errors.SystemError
|
||||
}
|
||||
return r.handleMatch(c, &matchJson, task)
|
||||
//c.WriteMessage(1, []byte(msg.Choices[0].Content))
|
||||
// 构建消息
|
||||
//messages := []entitys.Message{
|
||||
// {
|
||||
|
@ -193,6 +201,76 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) {
|
||||
defer func() {
|
||||
c.WriteMessage(1, []byte("EOF"))
|
||||
}()
|
||||
if !matchJson.IsMatch {
|
||||
c.WriteMessage(1, []byte(matchJson.Reasoning))
|
||||
return
|
||||
}
|
||||
var pointTask *model.AiTask
|
||||
for _, task := range tasks {
|
||||
if task.Index == matchJson.Index {
|
||||
pointTask = &task
|
||||
break
|
||||
}
|
||||
}
|
||||
if pointTask == nil || pointTask.Index == "other" {
|
||||
return r.handleOtherTask(c, matchJson)
|
||||
}
|
||||
var res []byte
|
||||
switch pointTask.Type {
|
||||
case constant.TaskTypeApi:
|
||||
res, err = r.handleApiTask(c, matchJson, pointTask)
|
||||
default:
|
||||
return r.handleOtherTask(c, matchJson)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r *AiRouterService) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) {
|
||||
|
||||
c.WriteMessage(1, []byte(matchJson.Reasoning))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (resByte []byte, err error) {
|
||||
var (
|
||||
request l_request.Request
|
||||
auth = c.Headers("X-Authorization", "")
|
||||
requestParam map[string]interface{}
|
||||
)
|
||||
err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
request.Url = strings.ReplaceAll(task.Config, "${authorization}", auth)
|
||||
for k, v := range requestParam {
|
||||
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
|
||||
}
|
||||
var configData entitys.ConfigData
|
||||
err = json.Unmarshal([]byte(task.Config), &configData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = mapstructure.Decode(configData.Do, &request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(request.Url) == 0 {
|
||||
err = errors.NewBusinessErr("00022", "api地址获取失败")
|
||||
return
|
||||
}
|
||||
res, err := request.Send()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return res.Content, nil
|
||||
}
|
||||
|
||||
func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
|
||||
|
||||
cond := builder.NewCond()
|
||||
|
@ -261,7 +339,7 @@ func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi
|
|||
return prompt
|
||||
}
|
||||
|
||||
func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []llms.MessageContent {
|
||||
func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
||||
var (
|
||||
prompt = make([]llms.MessageContent, 0)
|
||||
)
|
||||
|
@ -275,6 +353,11 @@ func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiCha
|
|||
Parts: []llms.ContentPart{
|
||||
llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))),
|
||||
},
|
||||
}, llms.MessageContent{
|
||||
Role: llms.ChatMessageTypeTool,
|
||||
Parts: []llms.ContentPart{
|
||||
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
|
||||
},
|
||||
}, llms.MessageContent{
|
||||
Role: llms.ChatMessageTypeHuman,
|
||||
Parts: []llms.ContentPart{
|
||||
|
|
|
@ -11,6 +11,6 @@ const (
|
|||
type TaskType int32
|
||||
|
||||
const (
|
||||
TaskTypeApi ConnStatus = iota + 1
|
||||
TaskTypeKnowle
|
||||
TaskTypeApi = 1
|
||||
TaskTypeKnowle = 2
|
||||
)
|
||||
|
|
|
@ -69,9 +69,9 @@ type Tool interface {
|
|||
Execute(ctx context.Context, args json.RawMessage) (interface{}, error)
|
||||
}
|
||||
|
||||
// AIClient AI客户端接口
|
||||
type AIClient interface {
|
||||
Chat(ctx context.Context, messages []Message, tools []ToolDefinition) (*ChatResponse, error)
|
||||
type ConfigData struct {
|
||||
Param map[string]interface{} `json:"param"`
|
||||
Do map[string]interface{} `json:"do"`
|
||||
}
|
||||
|
||||
// Message 消息
|
||||
|
@ -88,7 +88,13 @@ type FuncApi struct {
|
|||
type TaskConfig struct {
|
||||
Param interface{} `json:"param"`
|
||||
}
|
||||
|
||||
type Match struct {
|
||||
Confidence float64 `json:"confidence"`
|
||||
Index string `json:"index"`
|
||||
IsMatch bool `json:"is_match"`
|
||||
Parameters string `json:"parameters"`
|
||||
Reasoning string `json:"reasoning"`
|
||||
}
|
||||
type ChatHis struct {
|
||||
SessionId string `json:"session_id"`
|
||||
Messages []HisMessage `json:"messages"`
|
||||
|
|
|
@ -18,7 +18,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
|
|||
ollama.WithModel(c.Ollama.Model),
|
||||
ollama.WithHTTPClient(http.DefaultClient),
|
||||
ollama.WithServerURL(getUrl(c)),
|
||||
ollama.WithKeepAlive("1h"),
|
||||
ollama.WithKeepAlive("-1s"),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Fatal(err)
|
||||
|
|
|
@ -2,8 +2,11 @@ package test
|
|||
|
||||
import (
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg/mapstructure"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||
)
|
||||
|
||||
func Test_task(t *testing.T) {
|
||||
|
@ -12,3 +15,23 @@ func Test_task(t *testing.T) {
|
|||
err := json.Unmarshal([]byte(config), &c)
|
||||
t.Log(err)
|
||||
}
|
||||
|
||||
type configData struct {
|
||||
Param map[string]interface{} `json:"param"`
|
||||
Do map[string]interface{} `json:"do"`
|
||||
}
|
||||
|
||||
func Test_task2(t *testing.T) {
|
||||
var (
|
||||
c l_request.Request
|
||||
config configData
|
||||
)
|
||||
|
||||
configJson := `{"param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}, "do": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}`
|
||||
err := json.Unmarshal([]byte(configJson), &config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mapstructure.Decode(config.Do, &c)
|
||||
t.Log(err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue