Compare commits
No commits in common. "0ce09611defb9207db32d51b01f8d6aa15508f15" and "addcf2d24d4ce60fa254c9f458f8cb56bccfe86d" have entirely different histories.
0ce09611de
...
addcf2d24d
|
@ -2,24 +2,18 @@ package biz
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/data/constant"
|
|
||||||
errors "ai_scheduler/internal/data/error"
|
errors "ai_scheduler/internal/data/error"
|
||||||
"ai_scheduler/internal/data/impl"
|
"ai_scheduler/internal/data/impl"
|
||||||
"ai_scheduler/internal/data/model"
|
"ai_scheduler/internal/data/model"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/pkg"
|
"ai_scheduler/internal/pkg"
|
||||||
"ai_scheduler/internal/pkg/mapstructure"
|
|
||||||
"ai_scheduler/internal/pkg/utils_ollama"
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
"ai_scheduler/internal/tools"
|
"ai_scheduler/internal/tools"
|
||||||
"ai_scheduler/tmpl/dataTemp"
|
"ai_scheduler/tmpl/dataTemp"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"log"
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
|
||||||
"github.com/gofiber/fiber/v2/log"
|
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
"github.com/tmc/langchaingo/llms"
|
"github.com/tmc/langchaingo/llms"
|
||||||
"xorm.io/builder"
|
"xorm.io/builder"
|
||||||
|
@ -47,7 +41,6 @@ func NewAiRouterBiz(
|
||||||
hisImpl *impl.ChatImpl,
|
hisImpl *impl.ChatImpl,
|
||||||
conf *config.Config,
|
conf *config.Config,
|
||||||
utilAgent *utils_ollama.UtilOllama,
|
utilAgent *utils_ollama.UtilOllama,
|
||||||
|
|
||||||
) entitys.RouterService {
|
) entitys.RouterService {
|
||||||
return &AiRouterService{
|
return &AiRouterService{
|
||||||
//aiClient: aiClient,
|
//aiClient: aiClient,
|
||||||
|
@ -98,27 +91,26 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
|
||||||
return errors.SystemError
|
return errors.SystemError
|
||||||
}
|
}
|
||||||
|
|
||||||
//toolDefinitions := r.registerTools(task)
|
toolDefinitions := r.registerTools(task)
|
||||||
//prompt := r.getPrompt(sysInfo, history, req.Text)
|
//prompt := r.getPrompt(sysInfo, history, req.Text)
|
||||||
|
|
||||||
//意图预测
|
//意图预测
|
||||||
prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
|
//msg, err := r.utilAgent.Llm.Call(context.TODO(), pkg.JsonStringIgonErr(prompt),
|
||||||
match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
|
|
||||||
// llms.WithTools(toolDefinitions),
|
// llms.WithTools(toolDefinitions),
|
||||||
//llms.WithToolChoice("tool_name"),
|
// //llms.WithToolChoice(llms.FunctionCallBehaviorAuto),
|
||||||
|
// llms.WithJSONMode(),
|
||||||
|
//)
|
||||||
|
prompt := r.getPromptLLM(sysInfo, history, req.Text)
|
||||||
|
msg, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
|
||||||
|
llms.WithTools(toolDefinitions),
|
||||||
|
llms.WithToolChoice("tool_name"),
|
||||||
llms.WithJSONMode(),
|
llms.WithJSONMode(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.SystemError
|
return errors.SystemError
|
||||||
}
|
}
|
||||||
log.Info(match)
|
|
||||||
var matchJson entitys.Match
|
c.WriteMessage(1, []byte(msg.Choices[0].Content))
|
||||||
err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson)
|
|
||||||
if err != nil {
|
|
||||||
return errors.SystemError
|
|
||||||
}
|
|
||||||
return r.handleMatch(c, &matchJson, task)
|
|
||||||
//c.WriteMessage(1, []byte(msg.Choices[0].Content))
|
|
||||||
// 构建消息
|
// 构建消息
|
||||||
//messages := []entitys.Message{
|
//messages := []entitys.Message{
|
||||||
// {
|
// {
|
||||||
|
@ -201,76 +193,6 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) {
|
|
||||||
defer func() {
|
|
||||||
c.WriteMessage(1, []byte("EOF"))
|
|
||||||
}()
|
|
||||||
if !matchJson.IsMatch {
|
|
||||||
c.WriteMessage(1, []byte(matchJson.Reasoning))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var pointTask *model.AiTask
|
|
||||||
for _, task := range tasks {
|
|
||||||
if task.Index == matchJson.Index {
|
|
||||||
pointTask = &task
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if pointTask == nil || pointTask.Index == "other" {
|
|
||||||
return r.handleOtherTask(c, matchJson)
|
|
||||||
}
|
|
||||||
var res []byte
|
|
||||||
switch pointTask.Type {
|
|
||||||
case constant.TaskTypeApi:
|
|
||||||
res, err = r.handleApiTask(c, matchJson, pointTask)
|
|
||||||
default:
|
|
||||||
return r.handleOtherTask(c, matchJson)
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterService) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) {
|
|
||||||
|
|
||||||
c.WriteMessage(1, []byte(matchJson.Reasoning))
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (resByte []byte, err error) {
|
|
||||||
var (
|
|
||||||
request l_request.Request
|
|
||||||
auth = c.Headers("X-Authorization", "")
|
|
||||||
requestParam map[string]interface{}
|
|
||||||
)
|
|
||||||
err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
request.Url = strings.ReplaceAll(task.Config, "${authorization}", auth)
|
|
||||||
for k, v := range requestParam {
|
|
||||||
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
|
|
||||||
}
|
|
||||||
var configData entitys.ConfigData
|
|
||||||
err = json.Unmarshal([]byte(task.Config), &configData)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = mapstructure.Decode(configData.Do, &request)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(request.Url) == 0 {
|
|
||||||
err = errors.NewBusinessErr("00022", "api地址获取失败")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
res, err := request.Send()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return res.Content, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
|
func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
|
||||||
|
|
||||||
cond := builder.NewCond()
|
cond := builder.NewCond()
|
||||||
|
@ -339,7 +261,7 @@ func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi
|
||||||
return prompt
|
return prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []llms.MessageContent {
|
||||||
var (
|
var (
|
||||||
prompt = make([]llms.MessageContent, 0)
|
prompt = make([]llms.MessageContent, 0)
|
||||||
)
|
)
|
||||||
|
@ -353,11 +275,6 @@ func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiCha
|
||||||
Parts: []llms.ContentPart{
|
Parts: []llms.ContentPart{
|
||||||
llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))),
|
llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))),
|
||||||
},
|
},
|
||||||
}, llms.MessageContent{
|
|
||||||
Role: llms.ChatMessageTypeTool,
|
|
||||||
Parts: []llms.ContentPart{
|
|
||||||
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
|
|
||||||
},
|
|
||||||
}, llms.MessageContent{
|
}, llms.MessageContent{
|
||||||
Role: llms.ChatMessageTypeHuman,
|
Role: llms.ChatMessageTypeHuman,
|
||||||
Parts: []llms.ContentPart{
|
Parts: []llms.ContentPart{
|
||||||
|
|
|
@ -11,6 +11,6 @@ const (
|
||||||
type TaskType int32
|
type TaskType int32
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TaskTypeApi = 1
|
TaskTypeApi ConnStatus = iota + 1
|
||||||
TaskTypeKnowle = 2
|
TaskTypeKnowle
|
||||||
)
|
)
|
||||||
|
|
|
@ -69,9 +69,9 @@ type Tool interface {
|
||||||
Execute(ctx context.Context, args json.RawMessage) (interface{}, error)
|
Execute(ctx context.Context, args json.RawMessage) (interface{}, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConfigData struct {
|
// AIClient AI客户端接口
|
||||||
Param map[string]interface{} `json:"param"`
|
type AIClient interface {
|
||||||
Do map[string]interface{} `json:"do"`
|
Chat(ctx context.Context, messages []Message, tools []ToolDefinition) (*ChatResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Message 消息
|
// Message 消息
|
||||||
|
@ -88,13 +88,7 @@ type FuncApi struct {
|
||||||
type TaskConfig struct {
|
type TaskConfig struct {
|
||||||
Param interface{} `json:"param"`
|
Param interface{} `json:"param"`
|
||||||
}
|
}
|
||||||
type Match struct {
|
|
||||||
Confidence float64 `json:"confidence"`
|
|
||||||
Index string `json:"index"`
|
|
||||||
IsMatch bool `json:"is_match"`
|
|
||||||
Parameters string `json:"parameters"`
|
|
||||||
Reasoning string `json:"reasoning"`
|
|
||||||
}
|
|
||||||
type ChatHis struct {
|
type ChatHis struct {
|
||||||
SessionId string `json:"session_id"`
|
SessionId string `json:"session_id"`
|
||||||
Messages []HisMessage `json:"messages"`
|
Messages []HisMessage `json:"messages"`
|
||||||
|
|
|
@ -18,7 +18,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
|
||||||
ollama.WithModel(c.Ollama.Model),
|
ollama.WithModel(c.Ollama.Model),
|
||||||
ollama.WithHTTPClient(http.DefaultClient),
|
ollama.WithHTTPClient(http.DefaultClient),
|
||||||
ollama.WithServerURL(getUrl(c)),
|
ollama.WithServerURL(getUrl(c)),
|
||||||
ollama.WithKeepAlive("-1s"),
|
ollama.WithKeepAlive("1h"),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal(err)
|
logger.Fatal(err)
|
||||||
|
|
|
@ -2,11 +2,8 @@ package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/pkg/mapstructure"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_task(t *testing.T) {
|
func Test_task(t *testing.T) {
|
||||||
|
@ -15,23 +12,3 @@ func Test_task(t *testing.T) {
|
||||||
err := json.Unmarshal([]byte(config), &c)
|
err := json.Unmarshal([]byte(config), &c)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
type configData struct {
|
|
||||||
Param map[string]interface{} `json:"param"`
|
|
||||||
Do map[string]interface{} `json:"do"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_task2(t *testing.T) {
|
|
||||||
var (
|
|
||||||
c l_request.Request
|
|
||||||
config configData
|
|
||||||
)
|
|
||||||
|
|
||||||
configJson := `{"param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}, "do": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}`
|
|
||||||
err := json.Unmarshal([]byte(configJson), &config)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
mapstructure.Decode(config.Do, &c)
|
|
||||||
t.Log(err)
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue